17.2.6 Classification Versus Regression Models on Supervised GraphWise Models

When predicting a property, the loss function determines whether the model performs a classification task or a regression task.
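As a rough illustration, the Python sketch below maps each loss named in this section to the task it implies. The helper `task_for_loss` and its string keys are hypothetical and are not PGX identifiers. They only restate the rule that the loss, not the data type of the property, decides the task.

```python
# Hypothetical lookup built from the losses listed in this section.
# The keys are illustrative strings, not PGX API names.
LOSS_TO_TASK = {
    "softmax cross entropy": "classification",
    "sigmoid cross entropy": "classification",
    "devnet": "classification",
    "mse": "regression",
}


def task_for_loss(loss_name: str) -> str:
    """Return 'classification' or 'regression' for a loss name (case-insensitive)."""
    key = loss_name.strip().lower()
    if key not in LOSS_TO_TASK:
        raise ValueError(f"unknown loss: {loss_name!r}")
    return LOSS_TO_TASK[key]


print(task_for_loss("MSE"))                    # regression
print(task_for_loss("Softmax Cross Entropy"))  # classification
```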

For classification tasks, the Supervised GraphWise model infers labels. Even if the target property is numeric, the model assigns one label to each distinct value it finds and classifies over those labels. The available losses for classification tasks are softmax cross entropy, sigmoid cross entropy, and DevNet loss.
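The following library-free Python sketch illustrates the idea of treating numeric values as classes: every distinct value becomes one integer label, and softmax cross entropy is then computed over those labels. It is an assumption-based illustration, not PGX's internal implementation; `encode_labels` and `softmax_cross_entropy` are hypothetical helpers, and the logits stand in for the model's raw outputs.

```python
import math


def encode_labels(values):
    """Assign one integer class label per distinct target value (sorted for determinism)."""
    classes = sorted(set(values))
    index = {v: i for i, v in enumerate(classes)}
    return [index[v] for v in values], classes


def softmax_cross_entropy(logits, label):
    """Cross entropy of one example: -log(softmax(logits)[label])."""
    m = max(logits)  # subtract the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


scores = [3.0, 1.0, 3.0, 2.0]        # numeric target property values
labels, classes = encode_labels(scores)
print(classes)  # [1.0, 2.0, 3.0]
print(labels)   # [2, 0, 2, 1]

# Small loss: the highest logit is on the correct class (label 2).
print(softmax_cross_entropy([0.1, 0.2, 2.0], labels[0]))
```

Because each distinct value becomes its own class, a classifier of this kind can only output values it saw during training; for a continuous target, a regression loss is usually the better fit.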

For regression tasks, the Supervised GraphWise model infers a numeric value for the property. The loss for regression tasks is the mean squared error (MSE) loss.
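For reference, MSE is the average of (prediction - target)^2 over all n examples. The minimal Python sketch below computes it directly; `mse_loss` is a hypothetical helper for illustration, not part of the PGX or pypgx API.

```python
def mse_loss(predictions, targets):
    """Mean squared error: average of (prediction - target)^2 over all examples."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("predictions and targets must be non-empty and the same length")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)


# Squared errors are 0.25, 0.0, 1.0 and 0.25, so their mean is 1.5 / 4.
print(mse_loss([2.5, 1.0, 3.0, 0.0], [3.0, 1.0, 2.0, 0.5]))  # 0.375
```

Unlike the classification case, the output here is a continuous value, so the model can predict values that never appeared in the training data.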

The following example builds a regression model that uses the MSE loss.

JShell:

opg4j> import oracle.pgx.config.mllib.loss.LossFunctions
opg4j> import oracle.pgx.config.mllib.batchgenerator.BatchGenerators
opg4j> var model = analyst.supervisedGraphWiseModelBuilder().
     setVertexInputPropertyNames("vertex_features").
     setEdgeInputPropertyNames("edge_features").
     setVertexTargetPropertyName("scores").
     setConvLayerConfigs(convLayerConfig).
     setPredictionLayerConfigs(predictionLayerConfig).
     setLossFunction(LossFunctions.MSELoss()).
     setBatchGenerator(BatchGenerators.STRATIFIED_OVERSAMPLING).
     build()

Java:

import oracle.pgx.api.mllib.SupervisedGraphWiseModel;
import oracle.pgx.config.mllib.batchgenerator.BatchGenerators;
import oracle.pgx.config.mllib.loss.LossFunctions;

// analyst, convLayerConfig, and predictionLayerConfig are assumed to be defined beforehand
SupervisedGraphWiseModel model = analyst.supervisedGraphWiseModelBuilder()
    .setVertexInputPropertyNames("vertex_features")
    .setEdgeInputPropertyNames("edge_features")
    .setVertexTargetPropertyName("scores")
    .setConvLayerConfigs(convLayerConfig)
    .setPredictionLayerConfigs(predictionLayerConfig)
    .setLossFunction(LossFunctions.MSELoss())
    .setBatchGenerator(BatchGenerators.STRATIFIED_OVERSAMPLING)
    .build();

Python:

from pypgx.api.mllib import MSELoss

# analyst, conv_layer, and pred_layer are assumed to be defined beforehand
params = dict(vertex_target_property_name="scores",
              conv_layer_config=[conv_layer],
              pred_layer_config=[pred_layer],
              vertex_input_property_names=["vertex_features"],
              edge_input_property_names=["edge_features"],
              batch_gen='Stratified_Oversampling',
              loss_fn=MSELoss())

model = analyst.supervised_graphwise_builder(**params)