8.2.4 カスタム損失関数およびバッチ・ジェネレータの設定(異常検出用)
LossFunction
オブジェクトを指定することで教師ありモデルに対して様々な損失関数を選択でき、BatchGenerator
オブジェクトを指定することで様々なバッチ・ジェネレータを選択できます。これは、標準の教師ありフレームワークにキャストできるが様々な損失関数およびバッチ・ジェネレータが必要となる、異常検出などのアプリケーションに役立ちます。
SupervisedGraphWiseモデルでは、DevNetLoss
およびStratifiedOversamplingBatchGenerator
を使用できます。DevNetLoss
は、信頼度マージン、およびターゲット・プロパティで異常によってもたらされる値という2つのパラメータを受け入れます。
次の例では、convLayerConfig
がすでに定義されているとします。
opg4j> import oracle.pgx.config.mllib.loss.LossFunctions
opg4j> import oracle.pgx.config.mllib.batchgenerator.BatchGenerators
opg4j> var predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder().
setHiddenDimension(32).
setActivationFunction(ActivationFunction.LINEAR).
build()
opg4j> var model = analyst.supervisedGraphWiseModelBuilder().
setVertexInputPropertyNames("vertex_features").
setEdgeInputPropertyNames("edge_features").
setVertexTargetPropertyName("labels").
setConvLayerConfigs(convLayerConfig).
setPredictionLayerConfigs(predictionLayerConfig).
setLossFunction(LossFunctions.devNetLoss(5.0, true)).
setBatchGenerator(BatchGenerators.STRATIFIED_OVERSAMPLING).
build()
import oracle.pgx.config.mllib.loss.LossFunctions;
import oracle.pgx.config.mllib.batchgenerator.BatchGenerators;
GraphWisePredictionLayerConfig predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder()
.setHiddenDimension(32)
.setActivationFunction(ActivationFunction.LINEAR)
.build();
SupervisedGraphWiseModel model = analyst.supervisedGraphWiseModelBuilder()
.setVertexInputPropertyNames("vertex_features")
.setEdgeInputPropertyNames("edge_features")
.setVertexTargetPropertyName("labels")
.setConvLayerConfigs(convLayerConfig)
.setPredictionLayerConfigs(predictionLayerConfig)
.setLossFunction(LossFunctions.devNetLoss(5.0, true))
.setBatchGenerator(BatchGenerators.STRATIFIED_OVERSAMPLING)
.build();
from pypgx.api.mllib import DevNetLoss
pred_layer_config = dict(hidden_dim=32,
activation_fn='LINEAR')
pred_layer = analyst.graphwise_pred_layer_config(**pred_layer_config)
params = dict(vertex_target_property_name="labels",
conv_layer_config=[conv_layer],
pred_layer_config=[pred_layer],
vertex_input_property_names=["vertex_features"],
edge_input_property_names=["edge_features"],
loss_fn=DevNetLoss(5.0, True),
batch_gen='Stratified_Oversampling',
seed=17)
model = analyst.supervised_graphwise_builder(**params)