8.2.13 Supervised GraphWiseモデルの予測の説明
Supervised GraphWiseモデルの予測に重要な特徴と頂点を理解するために、YingなどによってGNNExplainer
と同様の手法を使用してSupervisedGnnExplanation
を生成できます。
説明には、次に関連する情報が保持されます。
- グラフ構造: 各頂点の重要度スコア
- 特徴: 各グラフ・プロパティの重要度スコア
ノート:
説明されている頂点には、常に重要度1が割り当てられます。さらに、特徴の重要度は、最も重要な特徴の重要度が1となるようにスケーリングされます。また、SupervisedGnnExplanation
には、推測された埋込み、ロジットおよびラベルが含まれます。inferAndGetExplanation
メソッドは、エッジの特徴に依存しないすべての適合SupervisedGraphWiseModel
モデルで使用できます。最適な結果を得るには、特徴は0を中心とする必要があります。
たとえば、ラベルと相関する特徴と、そうでない別の特徴が含まれる単純なグラフであるとします。したがって、特徴の重要度は大幅に異なること(ラベルと相関する特徴がより重要)が予想される一方、構造上の重要度は大きな役割を果たしません。この場合は、次のように説明を生成できます。
opg4j> var simpleGraph = session.createGraphBuilder().
addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5).
setProperty("label", true).
addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5).
setProperty("label", false).
addEdge(0, 1).build()
// build and train a Supervised GraphWise model
// explain prediction of vertex 0
opg4j> var explanation = model.inferAndGetExplanation(simpleGraph, simpleGraph.getVertex(0))
opg4j> var constProperty = simpleGraph.getVertexProperty("const_feature")
opg4j> var labelProperty = simpleGraph.getVertexProperty("label_feature")
// retrieve feature importances
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceConstProp = featureImportances.get(constProperty) // small as unimportant
opg4j> var importanceLabelProp = featureImportances.get(labelProperty) // large (1) as important
// retrieve computation graph with importances
opg4j> var importanceGraph = explanation.getImportanceGraph()
// retrieve importance of vertices
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0) // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1) // available if vertex 1 part of computation
PgxGraph simpleGraph = session.createGraphBuilder()
.addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5)
.setProperty("label", true)
.addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5)
.setProperty("label", false)
.addEdge(0, 1).build();
// build and train a Supervised GraphWise model
// explain prediction of vertex 0
SupervisedGnnExplanation<Integer> explanation = model.inferAndGetExplanation(simpleGraph,
simpleGraph.getVertex(0));
VertexProperty<Integer, Float> constProperty = simpleGraph.getVertexProperty("const_feature");
VertexProperty<Integer, Float> labelProperty = simpleGraph.getVertexProperty("label_feature");
// retrieve feature importances
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceConstProp = featureImportances.get(constProperty); // small as unimportant
float importanceLabelProp = featureImportances.get(labelProperty); // large (1) as important
// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();
// retrieve importance of vertices
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0); // has importance 1
float importanceVertex1 = importanceProperty.get(1); // available if vertex 1 part of computation
simple_graph = session.create_graph_builder()
.add_vertex(0).set_property("label_feature", 0.5).set_property("const_feature", 0.5)
.set_property("label", true)
.add_vertex(1).set_property("label_feature", -0.5).set_property("const_feature", 0.5)
.set_property("label", false)
.add_edge(0, 1).build()
# build and train a Supervised GraphWise model
# explain prediction of vertex 0
explanation = model.infer_and_get_explanation(simple_graph, simple_graph.get_vertex(0))
const_property = simple_graph.get_vertex_property("const_feature")
label_property = simple_graph.get_vertex_property("label_feature")
# retrieve feature importances
feature_importances = explanation.get_vertex_feature_importance()
importance_const_prop = feature_importances[const_property]
importance_label_prop = feature_importances[label_property]
# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()
# retrieve importance of vertices
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0]
importance_vertex_1 = importance_property[1]