17.2.16 Supervised GraphWiseモデルの予測の説明

Supervised GraphWiseモデルの予測に重要な特徴と頂点を理解するために、YingなどによってGNNExplainerと同様の手法を使用してSupervisedGnnExplanationを生成できます。

説明には、次に関連する情報が保持されます。

  • グラフ構造: 各頂点の重要度スコア
  • 特徴: 各グラフ・プロパティの重要度スコア

ノート:

説明されている頂点には、常に重要度1が割り当てられます。さらに、特徴の重要度は、最も重要な特徴の重要度が1となるようにスケーリングされます。

また、SupervisedGnnExplanationには、推測された埋込み、ロジットおよびラベルが含まれます。モデルの予測の説明は、SupervisedGnnExplainerオブジェクトを使用して取得できます。オブジェクトは、gnnExplainerメソッドを使用して取得できます。SupervisedGnnExplainerオブジェクトを取得した後、inferAndExplainメソッドを使用して頂点の説明をリクエストできます。

explainerのパラメータは、explainerの作成中または作成後に、関連するsetter関数を使用して構成できます。SupervisedGnnExplainerの構成可能なパラメータは次のとおりです。

  • numOptimizationSteps: explainerが使用する最適化ステップの数。
  • learningRate: explainerの学習率。
  • marginalize: explainerの損失が特徴よりも周辺化されているかどうかを判断します。これは、ゼロに近い値を取る重要な特徴がある場合に役立ちます。周辺化なしでは、explainerは、たとえ重要であってもそのような特徴をマスクすることを学習できます。周辺化では、推定入力分布からの偏差に対するマスクを学習することによってこれを解決します。

最適な結果を得るには、特徴は0を中心とする必要があることに注意してください。

たとえば、ラベルと相関する特徴と、そうでない別の特徴が含まれる単純なグラフであるとします。したがって、特徴の重要度は大幅に異なること(ラベルと相関する特徴がより重要)が予想される一方、構造上の重要度は大きな役割を果たしません。この場合は、次のように説明を生成できます。

opg4j> var simpleGraph = session.createGraphBuilder().
                          addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5).
                          setProperty("label", true).
                          addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5).
                          setProperty("label", false).
                          addEdge(0, 1).build()

// build and train a Supervised GraphWise model as explained in Advanced Hyperparameter Customization

// obtain and configure GnnExplainer
var explainer = model.gnnExplainer().learningRate(0.05)
explainer.numOptimizationSteps(200)

// explain prediction of vertex 0
opg4j> var explanation = explainer.inferAndExplain(simpleGraph, simpleGraph.getVertex(0))
// if you used the devNet loss, you can add the decision threshold as an extra parameter:
// var explanation = explainer.inferAndExplain(simpleGraph, simpleGraph.getVertex(0), 6f)

opg4j> var constProperty = simpleGraph.getVertexProperty("const_feature")
opg4j> var labelProperty = simpleGraph.getVertexProperty("label_feature")

// retrieve feature importances
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceConstProp = featureImportances.get(constProperty) // small as unimportant
opg4j> var importanceLabelProp = featureImportances.get(labelProperty) // large (1) as important

// retrieve computation graph with importances
opg4j> var importanceGraph = explanation.getImportanceGraph()

// retrieve importance of vertices
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0) // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1) // available if vertex 1 part of computation
    
PgxGraph simpleGraph = session.createGraphBuilder()
    .addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5)
    .setProperty("label", true)
    .addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5)
    .setProperty("label", false)
    .addEdge(0, 1).build();

// build and train a Supervised GraphWise model as explained in Advanced Hyperparameter Customization

// obtain and configure the explainer
SupervisedGnnExplainerexplainer=model.gnnExplainer().learningRate(0.05);
explainer.numOptimizationSteps(200);

// explain prediction of vertex 0
SupervisedGnnExplanation<Integer> explanation = explainer.inferAndExplain(simpleGraph,
    simpleGraph.getVertex(0));

// if we used the devNet loss, we can add the decision threshold as an extra parameter:
// SupervisedGnnExplanation<Integer> explanation = explainer.inferAndExplain(simpleGraph, simpleGraph.getVertex(0), 6f);

VertexProperty<Integer, Float> constProperty = simpleGraph.getVertexProperty("const_feature");
VertexProperty<Integer, Float> labelProperty = simpleGraph.getVertexProperty("label_feature");

// retrieve feature importances
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceConstProp = featureImportances.get(constProperty); // small as unimportant
float importanceLabelProp = featureImportances.get(labelProperty); // large (1) as important

// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();

// retrieve importance of vertices
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0); // has importance 1
float importanceVertex1 = importanceProperty.get(1); // available if vertex 1 part of computation
simple_graph = session.create_graph_builder()
    .add_vertex(0).set_property("label_feature", 0.5).set_property("const_feature", 0.5)
    .set_property("label", true)
    .add_vertex(1).set_property("label_feature", -0.5).set_property("const_feature", 0.5)
    .set_property("label", false)
    .add_edge(0, 1).build()

# build and train a Supervised GraphWise model as explained in Advanced Hyperparameter Customization

# obtain the explainer
explainer = model.gnn_explainer(learning_rate=0.05)
explainer.num_optimization_steps=200

# explain prediction of vertex 0
explanation = explainer.inferAndExplain(simple_graph,simple_graph.get_vertex(0))
# if we used the devNet loss, we can add the decision threshold as an extra parameter:
# explanation = explainer.inferAndExplain(simple_graph, simple_graph.get_vertex(0), 6)

const_property = simple_graph.get_vertex_property("const_feature")
label_property = simple_graph.get_vertex_property("label_feature")

# retrieve feature importances
feature_importances = explanation.get_vertex_feature_importance()
importance_const_prop = feature_importances[const_property]
importance_label_prop = feature_importances[label_property]

# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()

# retrieve importance of vertices
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0]
importance_vertex_1 = importance_property[1]