8.2.13 Supervised GraphWiseモデルの予測の説明

Supervised GraphWiseモデルの予測に重要な特徴と頂点を理解するために、YingなどによってGNNExplainerと同様の手法を使用してSupervisedGnnExplanationを生成できます。

説明には、次に関連する情報が保持されます。

  • グラフ構造: 各頂点の重要度スコア
  • 特徴: 各グラフ・プロパティの重要度スコア

ノート:

説明されている頂点には、常に重要度1が割り当てられます。さらに、特徴の重要度は、最も重要な特徴の重要度が1となるようにスケーリングされます。

また、SupervisedGnnExplanationには、推測された埋込み、ロジットおよびラベルが含まれます。inferAndGetExplanationメソッドは、エッジの特徴に依存しないすべての適合SupervisedGraphWiseModelモデルで使用できます。最適な結果を得るには、特徴は0を中心とする必要があります。

たとえば、ラベルと相関する特徴と、そうでない別の特徴が含まれる単純なグラフであるとします。したがって、特徴の重要度は大幅に異なること(ラベルと相関する特徴がより重要)が予想される一方、構造上の重要度は大きな役割を果たしません。この場合は、次のように説明を生成できます。

opg4j> var simpleGraph = session.createGraphBuilder().
                          addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5).
                          setProperty("label", true).
                          addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5).
                          setProperty("label", false).
                          addEdge(0, 1).build()

// build and train a Supervised GraphWise model

// explain prediction of vertex 0
opg4j> var explanation = model.inferAndGetExplanation(simpleGraph, simpleGraph.getVertex(0))
opg4j> var constProperty = simpleGraph.getVertexProperty("const_feature")
opg4j> var labelProperty = simpleGraph.getVertexProperty("label_feature")

// retrieve feature importances
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceConstProp = featureImportances.get(constProperty) // small as unimportant
opg4j> var importanceLabelProp = featureImportances.get(labelProperty) // large (1) as important

// retrieve computation graph with importances
opg4j> var importanceGraph = explanation.getImportanceGraph()

// retrieve importance of vertices
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0) // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1) // available if vertex 1 part of computation
    
PgxGraph simpleGraph = session.createGraphBuilder()
    .addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5)
    .setProperty("label", true)
    .addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5)
    .setProperty("label", false)
    .addEdge(0, 1).build();

// build and train a Supervised GraphWise model

// explain prediction of vertex 0
SupervisedGnnExplanation<Integer> explanation = model.inferAndGetExplanation(simpleGraph,
    simpleGraph.getVertex(0));

VertexProperty<Integer, Float> constProperty = simpleGraph.getVertexProperty("const_feature");
VertexProperty<Integer, Float> labelProperty = simpleGraph.getVertexProperty("label_feature");

// retrieve feature importances
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceConstProp = featureImportances.get(constProperty); // small as unimportant
float importanceLabelProp = featureImportances.get(labelProperty); // large (1) as important

// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();

// retrieve importance of vertices
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0); // has importance 1
float importanceVertex1 = importanceProperty.get(1); // available if vertex 1 part of computation
simple_graph = session.create_graph_builder()
    .add_vertex(0).set_property("label_feature", 0.5).set_property("const_feature", 0.5)
    .set_property("label", true)
    .add_vertex(1).set_property("label_feature", -0.5).set_property("const_feature", 0.5)
    .set_property("label", false)
    .add_edge(0, 1).build()

# build and train a Supervised GraphWise model

# explain prediction of vertex 0
explanation = model.infer_and_get_explanation(simple_graph, simple_graph.get_vertex(0))

const_property = simple_graph.get_vertex_property("const_feature")
label_property = simple_graph.get_vertex_property("label_feature")

# retrieve feature importances
feature_importances = explanation.get_vertex_feature_importance()
importance_const_prop = feature_importances[const_property]
importance_label_prop = feature_importances[label_property]

# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()

# retrieve importance of vertices
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0]
importance_vertex_1 = importance_property[1]