17.4.14 Unsupervised GraphWiseモデルの予測の説明
Unsupervised GraphWiseモデルの予測に重要な特徴と頂点を理解するために、YingなどによるGNNExplainer
と同様の手法を使用してUnsupervisedGnnExplanation
を生成できます。
説明には、次に関連する情報が保持されます。
- グラフ構造: 各頂点の重要度スコア
- 特徴: 各グラフ・プロパティの重要度スコア
ノート:
説明されている頂点には、常に重要度1が割り当てられます。さらに、特徴の重要度は、最も重要な特徴の重要度が1となるようにスケーリングされます。また、UnsupervisedGnnExplanation
には、推測された埋込みが含まれます。モデルの予測の説明は、UnsupervisedGnnExplainer
オブジェクトを使用して取得できます。オブジェクトは、gnnExplainer
メソッドを使用して取得できます。UnsupervisedGnnExplainer
オブジェクトを取得した後、inferAndExplain
メソッドを使用して頂点の説明をリクエストできます。
explainerのパラメータは、explainerの作成中または作成後に、関連するsetter関数を使用して構成できます。UnsupervisedGnnExplainer
の構成可能なパラメータは次のとおりです。
numOptimizationSteps:
explainerが使用する最適化ステップの数。learningRate:
explainerの学習率。marginalize:
explainerの損失が特徴よりも周辺化されているかどうかを判断します。これは、ゼロに近い値を取る重要な特徴がある場合に役立ちます。周辺化なしでは、explainerは、たとえ重要であってもそのような特徴をマスクすることを学習できます。周辺化では、推定入力分布からの偏差に対するマスクを学習することによってこれを解決します。numClusters:
explainerの損失で使用するクラスタの数。教師なしexplainerは、k-meansクラスタリングを使用して、最適化されたexplainerの損失を計算します。グラフ内のコンポーネントのおおよその数がわかっている場合は、クラスタの数をこの数に設定することをお薦めします。numSamples:
explainerを最適化するために使用する頂点サンプルの数。パフォーマンス向上のため、explainerでは、ランダムにサンプリングされたこの数の頂点について損失を計算します。使用するサンプルが多いほど正確になりますが、所要時間が長くなり、使用するリソースが増えます。
最適な結果を得るには、特徴は0を中心とする必要があることに注意してください。
たとえば、k
個の密に接続されたコンポーネントを含む(つまり、同じコンポーネントの頂点の間にエッジが多数あり、任意の2つのコンポーネントの間にエッジが少数あります)、単純なグラフcomponentGraph
があるとします。このグラフでUnsupervised GraphWiseモデルをトレーニングすると、密に接続されたコンポーネントの頂点に対して同様の埋込みを生成するモデルになると予想できます。
次の例では、推論componentGraph
で説明を生成する方法を示します。同じコンポーネントの頂点は、異なるコンポーネントの頂点よりも重要度が高くなると予想されます。なお、この例では、特徴の重要度は関連ありません。
opg4j> var componentGraph = session.readGraphByName("<graph>",GraphSource.PG_PGQL)
// explain prediction of vertex 0
opg4j> var feat1Property = componentGraph.getVertexProperty("feat1")
opg4j> var feat2Property = componentGraph.getVertexProperty("feat2")
// build and train an Unsupervised GraphWise model as explained in Advanced Hyperparameter Customization
// obtain and configure the explainer
// setting the numClusters argument to the expected number of clusters may improve
// explanation results as the explainer optimization will try to cluster samples into
// this number of clusters
opg4j> var explainer = model.gnnExplainer().numClusters(50)
// set the number of samples to compute the loss over during explainer optimization
opg4j> explainer.numSamples(10000)
// explain prediction of vertex 0
opg4j> var explanation = explainer.inferAndExplain(componentGraph, componentGraph.getVertex(0), 10)
// retrieve computation graph with importance
opg4j> var importanceGraph = explanation.getImportanceGraph()
// retrieve importance of vertices
// vertex 1 is in the same densely connected component as vertex 0
// vertex 2 is in a different component
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0) // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1) // high importance
opg4j> var importanceVertex2 = importanceProperty.get(2) // low importance
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceConstProp = featureImportances.get(constProperty) // small as unimportant
opg4j> var importanceLabelProp = featureImportances.get(labelProperty) // large (1) as important
// optionally retrieve feature importance
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceFeat1Prop = featureImportances.get(feat1Property)
opg4j> var importanceFeat2Prop = featureImportances.get(feat2Property)
PgxGraph componentGraph = session.readGraphByName("<graph>",GraphSource.PG_PGQL); // load graph
VertexProperty<Integer, Float> feat1Property = componentGraph.getVertexProperty("feat1");
VertexProperty<Integer, Float> feat2Property = componentGraph.getVertexProperty("feat2");
// build and train an Unsupervised GraphWise model as explained in Advanced Hyperparameter Customization
// obtain and configure the explainer
// setting the numClusters argument to the expected number of clusters may improve
// explanation results as the explainer optimization will try to cluster samples into
// this number of clusters
UnsupervisedGnnExplainer explainer = model.gnnExplainer().numClusters(50);
// set the number of samples to compute the loss over during explainer optimization
explainer.numSamples(10000);
// explain prediction of vertex 0
UnsupervisedGnnExplanation<Integer> explanation = explainer.inferAndExplain(componentGraph, componentGraph.getVertex(0));
// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();
// retrieve importance of vertices
// vertex 1 is in the same densely connected component as vertex 0
// vertex 2 is in a different component
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0); // has importance 1
float importanceVertex1 = importanceProperty.get(1); // high importance
float importanceVertex2 = importanceProperty.get(2); // low importance
// retrieve feature importance (not relevant for this example)
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceFeat1Prop = featureImportances.get(feat1Property);
float importanceFeat2Prop = featureImportances.get(feat2Property);
# load 'component_graph' with vertex features 'feat1' and 'feat2'
feat1_property = component_graph.get_vertex_property("feat1")
feat2_property = component_graph.get_vertex_property("feat2")
# build and train an Unsupervised GraphWise model as explained in Advanced Hyperparameter Customization
# obtain and configure the explainer
# setting the num_clusters argument to the expected number of clusters may improve
# explanation results as the explainer optimization will try to cluster samples into
# this number of clusters
explainer = model.gnn_explainer(num_clusters=50)
# set the number of samples to compute the loss over during explainer optimization
explainer.num_samples = 10000
# explain prediction of vertex 0
explanation = explainer.infer_and_explain(
graph=component_graph,
vertex=component_graph.get_vertex(0)
)
# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()
# retrieve importance of vertices
# vertex 1 is in the same densely connected component as vertex 0
# vertex 2 is in a different component
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0] # has importance 1
importance_vertex_1 = importance_property[1] # high importance
importance_vertex_2 = importance_property[2] # low importance
# retrieve feature importance (not relevant for this example)
feature_importances = explanation.get_vertex_feature_importance()
importance_feat1_prop = feature_importances[feat1_property]
importance_feat2_prop = feature_importances[feat2_property]