8.3.10 Explaining a Prediction for an Unsupervised GraphWise Model
To understand which features and vertices are important for a prediction of the Unsupervised GraphWise model, you can generate an UnsupervisedGnnExplanation using a technique similar to the GNNExplainer by Ying et al.
The explanation holds information related to:
- graph structure: an importance score for each vertex
- features: an importance score for each graph property
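The note that follows describes how these scores are scaled. As a rough illustration only, the sketch below pins the explained vertex to 1 and divides every feature score by the largest one, so that the most important feature ends up at 1. The function name scale_importances, its dictionary inputs and the raw scores in the example are hypothetical. This is plain Python under the assumption that raw, non-negative scores are already available, not the PGX implementation.

```python
def scale_importances(raw_vertex_scores, raw_feature_scores, explained_vertex):
    """Hypothetical illustration of the scaling rules described in this section.

    raw_vertex_scores:  vertex id -> raw importance score
    raw_feature_scores: feature name -> raw, non-negative importance score
    """
    # Feature importances: divide by the largest score so the top feature is 1.0
    max_score = max(raw_feature_scores.values(), default=0.0)
    if max_score > 0:
        features = {name: s / max_score for name, s in raw_feature_scores.items()}
    else:
        features = {name: 0.0 for name in raw_feature_scores}
    # Vertex importances: the explained vertex is always reported as 1.0
    vertices = dict(raw_vertex_scores)
    vertices[explained_vertex] = 1.0
    return vertices, features


# Example with made-up raw scores
vertices, features = scale_importances(
    {0: 0.3, 1: 0.9, 2: 0.1}, {"feat1": 0.5, "feat2": 2.0}, explained_vertex=0
)
print(vertices)  # {0: 1.0, 1: 0.9, 2: 0.1}
print(features)  # {'feat1': 0.25, 'feat2': 1.0}
```

Because of this scaling, feature importances are relative to each other within one explanation rather than absolute values.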
Note:
The vertex being explained is always assigned importance 1. Further, the feature importances are scaled such that the most important feature has importance 1.
Additionally, an UnsupervisedGnnExplanation contains the inferred embedding. The inferAndGetExplanation method is available on all fitted UnsupervisedGraphWiseModel models. For best results, the features should be centered around 0.
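One common way to center features before inference is to subtract each feature's mean over all vertices. The sketch below assumes the vertex features have already been copied into a plain Python dictionary (vertex id to feature name to value); that layout and the function name center_features are hypothetical, and writing the centered values back into graph properties is not shown.

```python
from statistics import fmean


def center_features(features):
    """Return a copy of `features` in which every feature has mean 0.

    features: vertex id -> {feature name -> numeric value}
    (a hypothetical in-memory copy of the graph's vertex properties)
    """
    if not features:
        return {}
    names = list(next(iter(features.values())))
    # Mean of each feature over all vertices
    means = {n: fmean(row[n] for row in features.values()) for n in names}
    # Subtract the mean so the feature is centered around 0
    return {v: {n: row[n] - means[n] for n in names} for v, row in features.items()}


raw = {0: {"feat1": 1.0, "feat2": 10.0}, 1: {"feat1": 3.0, "feat2": 20.0}}
print(center_features(raw))
# {0: {'feat1': -1.0, 'feat2': -5.0}, 1: {'feat1': 1.0, 'feat2': 5.0}}
```

Centering only shifts each feature; if features have very different ranges, you may also want to divide by the standard deviation, but the requirement stated here is only that they be centered around 0.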
For example, assume a simple graph, componentGraph, which contains k densely connected components; that is, there are many edges between vertices of the same component and few edges between any two components. By training an Unsupervised GraphWise model on this graph, you can expect a model that produces similar embeddings for vertices in the same densely connected component.
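To experiment with this setting, a graph of this shape can be generated with a simple planted-partition (stochastic block model) process: each pair of vertices in the same component is connected with a high probability p_in, and each pair in different components with a low probability p_out. The generator below is a hypothetical helper in plain Python; it only produces an edge list and a component assignment, and converting them into a file that readGraphWithProperties can load is not shown.

```python
import random


def make_component_graph(k, size, p_in, p_out, seed=0):
    """Planted-partition (stochastic block model) graph with k components.

    Vertices 0 .. k*size-1 are split into k components of `size` vertices.
    Each same-component pair is linked with probability p_in, each
    cross-component pair with probability p_out (choose p_out << p_in).
    Returns an undirected edge list and a vertex -> component mapping.
    """
    rng = random.Random(seed)
    n = k * size
    component = {v: v // size for v in range(n)}
    edges = []
    for u in range(n):
        for v in range(u + 1, n):
            p = p_in if component[u] == component[v] else p_out
            if rng.random() < p:
                edges.append((u, v))
    return edges, component


edges, component = make_component_graph(k=3, size=20, p_in=0.5, p_out=0.01)
intra = sum(component[u] == component[v] for u, v in edges)
print(intra, len(edges) - intra)  # many intra-component edges, few between components
```

With p_in much larger than p_out, the resulting components are internally dense and only sparsely linked, which is the structure the example below relies on.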
The following example shows how to generate an explanation for a vertex of componentGraph. Vertices from the same component are expected to have a higher importance than vertices from a different component. Note that the feature importances are not relevant in this example.
opg4j> var componentGraph = session.readGraphWithProperties("<path_to_component_graph.json>")
// retrieve the vertex features feat1 and feat2
opg4j> var feat1Property = componentGraph.getVertexProperty("feat1")
opg4j> var feat2Property = componentGraph.getVertexProperty("feat2")
// build and train an Unsupervised GraphWise model
// explain prediction of vertex 0
// setting the numClusters argument to the expected number of clusters may improve
// explanation results
opg4j> var explanation = model.inferAndGetExplanation(componentGraph, componentGraph.getVertex(0), 10)
// retrieve computation graph with importances
opg4j> var importanceGraph = explanation.getImportanceGraph()
// retrieve importance of vertices
// vertex 1 is in the same densely connected component as vertex 0
// vertex 2 is in a different component
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0) // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1) // high importance
opg4j> var importanceVertex2 = importanceProperty.get(2) // low importance
// optionally retrieve feature importances (not relevant for this example)
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceFeat1Prop = featureImportances.get(feat1Property)
opg4j> var importanceFeat2Prop = featureImportances.get(feat2Property)
PgxGraph componentGraph = session.readGraphWithProperties("<path_to_component_graph.json>"); // load component graph
VertexProperty<Integer, Float> feat1Property = componentGraph.getVertexProperty("feat1");
VertexProperty<Integer, Float> feat2Property = componentGraph.getVertexProperty("feat2");
// build and train an Unsupervised GraphWise model
// explain prediction of vertex 0
// setting the numClusters argument to the expected number of clusters may improve
// explanation results
UnsupervisedGnnExplanation<Integer> explanation = model.inferAndGetExplanation(componentGraph, componentGraph.getVertex(0), 10);
// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();
// retrieve importance of vertices
// vertex 1 is in the same densely connected component as vertex 0
// vertex 2 is in a different component
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0); // has importance 1
float importanceVertex1 = importanceProperty.get(1); // high importance
float importanceVertex2 = importanceProperty.get(2); // low importance
// retrieve feature importance (not relevant for this example)
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceFeat1Prop = featureImportances.get(feat1Property);
float importanceFeat2Prop = featureImportances.get(feat2Property);
# load 'component_graph' with vertex features 'feat1' and 'feat2'
feat1_property = component_graph.get_vertex_property("feat1")
feat2_property = component_graph.get_vertex_property("feat2")
# build and train an Unsupervised GraphWise model
# explain prediction of vertex 0
# setting the num_clusters argument to the expected number of clusters may improve
# explanation results
explanation = model.infer_and_get_explanation(
graph=component_graph,
vertex=component_graph.get_vertex(0),
num_clusters=10,
)
# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()
# retrieve importance of vertices
# vertex 1 is in the same densely connected component as vertex 0
# vertex 2 is in a different component
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0] # has importance 1
importance_vertex_1 = importance_property[1] # high importance
importance_vertex_2 = importance_property[2] # low importance
# retrieve feature importance (not relevant for this example)
feature_importances = explanation.get_vertex_feature_importance()
importance_feat1_prop = feature_importances[feat1_property]
importance_feat2_prop = feature_importances[feat2_property]
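To check the expectation stated for this example, namely that vertices in the same component as the explained vertex receive higher importance, you can compare average importances inside and outside that component. The sketch assumes the importances have been copied into a dictionary (for example by indexing importance_property for each vertex of interest, as in the Python example) and that the component of each vertex is known, as it would be for a synthetic graph. The helper name component_importance_gap and the scores in the example are hypothetical.

```python
from statistics import fmean


def component_importance_gap(importances, component, explained_vertex):
    """Mean importance of the explained vertex's own component minus the
    mean importance of all other vertices.

    importances: vertex id -> importance score from the explanation
    component:   vertex id -> component id (known for a synthetic graph)
    The explained vertex itself is excluded, since it always has importance 1.
    """
    target = component[explained_vertex]
    same = [s for v, s in importances.items()
            if v != explained_vertex and component[v] == target]
    other = [s for v, s in importances.items() if component[v] != target]
    if not same or not other:
        raise ValueError("need vertices both inside and outside the component")
    return fmean(same) - fmean(other)


# Made-up scores: vertices 1 and 3 share vertex 0's component
scores = {0: 1.0, 1: 0.75, 2: 0.25, 3: 0.5, 4: 0.0}
components = {0: 0, 1: 0, 2: 1, 3: 0, 4: 1}
print(component_importance_gap(scores, components, explained_vertex=0))  # 0.5
```

A clearly positive gap is consistent with the explanation favoring the explained vertex's own component; a gap near zero suggests the explanation does not separate the components.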