8.2.4 カスタム損失関数およびバッチ・ジェネレータの設定(異常検出用)

LossFunctionオブジェクトを指定することで教師ありモデルに対して様々な損失関数を選択でき、BatchGeneratorオブジェクトを指定することで様々なバッチ・ジェネレータを選択できます。これは、標準の教師ありフレームワークにキャストできるが様々な損失関数およびバッチ・ジェネレータが必要となる、異常検出などのアプリケーションに役立ちます。

SupervisedGraphWiseモデルでは、DevNetLossおよびStratifiedOversamplingBatchGeneratorを使用できます。DevNetLossは、信頼度マージン、およびターゲット・プロパティで異常によってもたらされる値という2つのパラメータを受け入れます。

次の例では、convLayerConfigがすでに定義されているとします。

opg4j> import oracle.pgx.config.mllib.loss.LossFunctions
opg4j> import oracle.pgx.config.mllib.batchgenerator.BatchGenerators
opg4j> var predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder().
         setHiddenDimension(32).
         setActivationFunction(ActivationFunction.LINEAR).
         build()
opg4j> var model = analyst.supervisedGraphWiseModelBuilder().
         setVertexInputPropertyNames("vertex_features").
         setEdgeInputPropertyNames("edge_features").
         setVertexTargetPropertyName("labels").
         setConvLayerConfigs(convLayerConfig).
         setPredictionLayerConfigs(predictionLayerConfig).
         setLossFunction(LossFunctions.devNetLoss(5.0, true)).
         setBatchGenerator(BatchGenerators.STRATIFIED_OVERSAMPLING).
         build()
import oracle.pgx.config.mllib.loss.LossFunctions;
import oracle.pgx.config.mllib.batchgenerator.BatchGenerators;

GraphWisePredictionLayerConfig predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder()
    .setHiddenDimension(32)
    .setActivationFunction(ActivationFunction.LINEAR)
    .build();

SupervisedGraphWiseModel model = analyst.supervisedGraphWiseModelBuilder()
    .setVertexInputPropertyNames("vertex_features")
    .setEdgeInputPropertyNames("edge_features")
    .setVertexTargetPropertyName("labels")
    .setConvLayerConfigs(convLayerConfig)
    .setPredictionLayerConfigs(predictionLayerConfig)
    .setLossFunction(LossFunctions.devNetLoss(5.0, true))
    .setBatchGenerator(BatchGenerators.STRATIFIED_OVERSAMPLING)
    .build();
from pypgx.api.mllib import DevNetLoss

pred_layer_config = dict(hidden_dim=32,
                         activation_fn='LINEAR')

pred_layer = analyst.graphwise_pred_layer_config(**pred_layer_config)

params = dict(vertex_target_property_name="labels",
              conv_layer_config=[conv_layer],
              pred_layer_config=[pred_layer],
              vertex_input_property_names=["vertex_features"],
              edge_input_property_names=["edge_features"],
              loss_fn=DevNetLoss(5.0, True),
              batch_gen='Stratified_Oversampling',
              seed=17)

model = analyst.supervised_graphwise_builder(**params)