17.2.7 Setting a Custom Loss Function and Batch Generator (for Anomaly Detection)
You can select a different loss function for the supervised model by providing a LossFunction object, and a different batch generator by providing a BatchGenerator object. This is useful for applications such as anomaly detection, which can be cast into the standard supervised framework but require different loss functions and batch generators.
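A short calculation shows why the default batching can struggle on anomaly data. Assume a hypothetical graph in which 1% of the vertices are anomalies (an illustrative rate, not a figure from this guide). If each vertex in a batch of 32 is drawn independently and uniformly, the probability that the batch contains no anomaly at all is 0.99^32, so most gradient steps would see only normal vertices. The helper name below is hypothetical:

```python
def prob_batch_without_anomaly(anomaly_rate: float, batch_size: int) -> float:
    """Probability that a uniformly sampled batch contains no anomaly.

    Assumes each vertex is drawn independently (sampling with replacement),
    a simplification of sampling without replacement from a large graph.
    """
    return (1.0 - anomaly_rate) ** batch_size


# Hypothetical setting: 1% anomalies, batches of 32 vertices.
p = prob_batch_without_anomaly(0.01, 32)
print(f"{p:.3f}")  # 0.725
```

Under these assumptions roughly 72% of batches carry no anomaly signal, which is the imbalance that a specialized loss function and batch generator are meant to address.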
The SupervisedGraphWise model can use DevNetLoss together with StratifiedOversamplingBatchGenerator. DevNetLoss takes two parameters: the confidence margin, and the value that anomalous vertices take in the target property.
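To make the two DevNetLoss parameters concrete, the sketch below implements the deviation loss from the DevNet paper (Pang et al., "Deep Anomaly Detection with Deviation Networks", KDD 2019). It is an assumption that Oracle's DevNetLoss computes exactly this; the function and its arguments (`devnet_loss`, `num_reference`, `seed`) are hypothetical. Each score is converted into a deviation from reference scores drawn from a standard normal prior; normal vertices are pulled toward the reference mean, while anomalies (vertices whose label equals the anomaly value) are pushed at least `margin` standard deviations above it:

```python
import math
import random
from typing import Sequence


def devnet_loss(scores: Sequence[float], labels: Sequence[bool],
                margin: float = 5.0, anomaly_value: bool = True,
                num_reference: int = 5000, seed: int = 0) -> float:
    """Mean DevNet-style deviation loss over a non-empty batch (hypothetical helper).

    scores: raw anomaly scores produced by the model, one per vertex.
    labels: target-property values; a vertex is anomalous when its label
            equals anomaly_value.
    """
    rng = random.Random(seed)
    # Reference scores sampled from a standard normal prior.
    ref = [rng.gauss(0.0, 1.0) for _ in range(num_reference)]
    mu = sum(ref) / len(ref)
    sigma = math.sqrt(sum((r - mu) ** 2 for r in ref) / len(ref))

    total = 0.0
    for score, label in zip(scores, labels):
        dev = (score - mu) / sigma  # deviation from the reference distribution
        if label == anomaly_value:
            # Anomaly: penalize only if it sits less than `margin` std devs above.
            total += max(0.0, margin - dev)
        else:
            # Normal vertex: penalize any deviation from the reference mean.
            total += abs(dev)
    return total / len(scores)
```

With the settings used in the examples that follow (a margin of 5.0 and `true` as the anomaly value), an anomaly scored about 20 standard deviations above the reference contributes no loss, while an anomaly scored near the reference mean contributes roughly 5.0.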
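The following examples assume that the convolutional layer configuration has already been defined (convLayerConfig in JShell and Java, conv_layer in Python). They use a confidence margin of 5.0 and treat vertices whose labels property is true as anomalies: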
opg4j> import oracle.pgx.config.mllib.ActivationFunction
opg4j> import oracle.pgx.config.mllib.loss.LossFunctions
opg4j> import oracle.pgx.config.mllib.batchgenerator.BatchGenerators
opg4j> var predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder().
setHiddenDimension(32).
setActivationFunction(ActivationFunction.LINEAR).
build()
opg4j> var model = analyst.supervisedGraphWiseModelBuilder().
setVertexInputPropertyNames("vertex_features").
setEdgeInputPropertyNames("edge_features").
setVertexTargetPropertyName("labels").
setConvLayerConfigs(convLayerConfig).
setPredictionLayerConfigs(predictionLayerConfig).
setLossFunction(LossFunctions.devNetLoss(5.0, true)).
setBatchGenerator(BatchGenerators.STRATIFIED_OVERSAMPLING).
build()
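import oracle.pgx.api.mllib.SupervisedGraphWiseModel;
import oracle.pgx.config.mllib.ActivationFunction;
import oracle.pgx.config.mllib.GraphWisePredictionLayerConfig;
import oracle.pgx.config.mllib.loss.LossFunctions;
import oracle.pgx.config.mllib.batchgenerator.BatchGenerators;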
GraphWisePredictionLayerConfig predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder()
.setHiddenDimension(32)
.setActivationFunction(ActivationFunction.LINEAR)
.build();
SupervisedGraphWiseModel model = analyst.supervisedGraphWiseModelBuilder()
.setVertexInputPropertyNames("vertex_features")
.setEdgeInputPropertyNames("edge_features")
.setVertexTargetPropertyName("labels")
.setConvLayerConfigs(convLayerConfig)
.setPredictionLayerConfigs(predictionLayerConfig)
.setLossFunction(LossFunctions.devNetLoss(5.0, true))
.setBatchGenerator(BatchGenerators.STRATIFIED_OVERSAMPLING)
.build();
from pypgx.api.mllib import DevNetLoss
pred_layer_config = dict(hidden_dim=32,
activation_fn='LINEAR')
pred_layer = analyst.graphwise_pred_layer_config(**pred_layer_config)
params = dict(vertex_target_property_name="labels",
conv_layer_config=[conv_layer],
pred_layer_config=[pred_layer],
vertex_input_property_names=["vertex_features"],
edge_input_property_names=["edge_features"],
loss_fn=DevNetLoss(5.0, True),
batch_gen='Stratified_Oversampling',
seed=17)
model = analyst.supervised_graphwise_builder(**params)
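This guide does not spell out how the stratified oversampling batch generator fills each batch. The sketch below shows one common interpretation, offered as an assumption rather than Oracle's implementation: vertices are grouped by label (the strata), each batch takes an equal share from every stratum, and the minority class is drawn with replacement so that it can be repeated. The helper name `stratified_oversampling_batches` is hypothetical:

```python
import random
from typing import Dict, List


def stratified_oversampling_batches(labels: Dict[int, bool], batch_size: int,
                                    num_batches: int, seed: int = 0) -> List[List[int]]:
    """Build class-balanced batches of vertex ids (hypothetical helper).

    Each class receives batch_size // num_classes slots per batch; any
    remainder is dropped. Drawing with replacement oversamples the minority class.
    """
    rng = random.Random(seed)
    # Group vertex ids by their label value (the strata).
    strata: Dict[bool, List[int]] = {}
    for vertex, label in labels.items():
        strata.setdefault(label, []).append(vertex)
    classes = sorted(strata)
    per_class = batch_size // len(classes)

    batches = []
    for _ in range(num_batches):
        batch: List[int] = []
        for c in classes:
            # Sample with replacement so small strata can fill their share.
            batch.extend(rng.choices(strata[c], k=per_class))
        rng.shuffle(batch)  # mix classes within the batch
        batches.append(batch)
    return batches


# Hypothetical graph: vertex 0 is the only anomaly among 100 vertices.
labels = {v: (v == 0) for v in range(100)}
batches = stratified_oversampling_batches(labels, batch_size=32, num_batches=10)
```

In this example every batch of 32 contains 16 draws of the single anomaly and 16 normal vertices. The trade-off is that the few anomalies are revisited many times per epoch, which balances the training signal at the cost of a higher risk of overfitting to them.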