8.3.10 Explaining a Prediction for an Unsupervised GraphWise Model

To understand which features and vertices are important for a prediction of the Unsupervised GraphWise model, you can generate an UnsupervisedGnnExplanation using a technique similar to GNNExplainer (Ying et al.).

The explanation holds information related to:

  • graph structure: an importance score for each vertex
  • features: an importance score for each graph property

Note:

The vertex being explained is always assigned importance 1. Further, the feature importances are scaled such that the most important feature has importance 1.
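This scaling convention can be illustrated with a short standalone sketch (plain Python, not part of the PGX API): dividing every raw importance by the maximum guarantees that the most important feature ends up at exactly 1.

```python
# Standalone sketch (not the PGX API): scale raw feature importances
# so that the largest importance becomes exactly 1, mirroring the
# convention described above. The property names are illustrative.
def scale_importances(raw):
    peak = max(raw.values())
    return {name: value / peak for name, value in raw.items()}

raw = {"feat1": 0.2, "feat2": 0.8}
scaled = scale_importances(raw)
# the most important feature ("feat2") now has importance 1.0
```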

Additionally, an UnsupervisedGnnExplanation contains the inferred embedding. The inferAndGetExplanation method can be used on any fitted UnsupervisedGraphWiseModel. For best results, the features should be centered around 0.
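One way to center features around 0 is to subtract the mean from each feature before training. The sketch below is plain Python and illustrative only; in practice you would recompute the graph's vertex properties through the PGX API.

```python
# Standalone sketch (not the PGX API): shift feature values so their
# mean is 0, as recommended for good explanation results.
def center(values):
    """Return the values shifted so that their mean is 0."""
    mean = sum(values) / len(values)
    return [v - mean for v in values]

feat1 = [3.0, 5.0, 7.0]
centered = center(feat1)  # mean of the result is 0
```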

For example, assume a simple graph, componentGraph, which contains k densely connected components; that is, there are many edges between vertices of the same component and few edges between any two components. By training an Unsupervised GraphWise model on this graph, you can expect the model to produce similar embeddings for vertices in the same densely connected component.
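A graph with this structure can be sketched as an edge list in plain Python (illustrative only, not the PGX graph-construction API): each vertex pair is connected with high probability inside a component and low probability across components.

```python
import random

# Standalone sketch (not the PGX API): generate an edge list for
# k components of `size` vertices each, dense within a component
# and sparse between components. Parameters are illustrative.
def component_edge_list(k, size, p_intra=0.8, p_inter=0.01, seed=42):
    rng = random.Random(seed)
    n = k * size
    edges = []
    for u in range(n):
        for v in range(u + 1, n):
            same_component = (u // size) == (v // size)
            p = p_intra if same_component else p_inter
            if rng.random() < p:
                edges.append((u, v))
    return edges

edges = component_edge_list(k=3, size=10)
# the vast majority of edges connect vertices of the same component
```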

The following example shows how to generate an explanation for a vertex of componentGraph. Vertices from the same component as the explained vertex are expected to have a higher importance than vertices from a different component. Note that the feature importances are not relevant in this example.

opg4j> var componentGraph = session.readGraphWithProperties("<path_to_component_graph.json>")
// retrieve the vertex features
opg4j> var feat1Property = componentGraph.getVertexProperty("feat1")
opg4j> var feat2Property = componentGraph.getVertexProperty("feat2")

// build and train an Unsupervised GraphWise model 

// explain prediction of vertex 0
// setting the numClusters argument to the expected number of clusters may improve
// explanation results
opg4j> var explanation = model.inferAndGetExplanation(componentGraph, componentGraph.getVertex(0), 10)

// retrieve computation graph with importance
opg4j> var importanceGraph = explanation.getImportanceGraph()

// retrieve importance of vertices
// vertex 1 is in the same densely connected component as vertex 0
// vertex 2 is in a different component
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0)  // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1)  // high importance
opg4j> var importanceVertex2 = importanceProperty.get(2)  // low importance

// optionally retrieve feature importance
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceFeat1Prop = featureImportances.get(feat1Property)
opg4j> var importanceFeat2Prop = featureImportances.get(feat2Property)
    
PgxGraph componentGraph = session.readGraphWithProperties("<path_to_component_graph.json>"); // load component graph
VertexProperty<Integer, Float> feat1Property = componentGraph.getVertexProperty("feat1");
VertexProperty<Integer, Float> feat2Property = componentGraph.getVertexProperty("feat2");

// build and train an Unsupervised GraphWise model

// explain prediction of vertex 0
// setting the numClusters argument to the expected number of clusters may improve
// explanation results
UnsupervisedGnnExplanation<Integer> explanation = model.inferAndGetExplanation(componentGraph, componentGraph.getVertex(0), 10);

// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();

// retrieve importance of vertices
// vertex 1 is in the same densely connected component as vertex 0
// vertex 2 is in a different component
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0);  // has importance 1
float importanceVertex1 = importanceProperty.get(1);  // high importance
float importanceVertex2 = importanceProperty.get(2);  // low importance

// retrieve feature importance (not relevant for this example)
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceFeat1Prop = featureImportances.get(feat1Property);
float importanceFeat2Prop = featureImportances.get(feat2Property);

# load 'component_graph' with vertex features 'feat1' and 'feat2'
feat1_property = component_graph.get_vertex_property("feat1")
feat2_property = component_graph.get_vertex_property("feat2")

# build and train an Unsupervised GraphWise model

# explain prediction of vertex 0
# setting the num_clusters argument to the expected number of clusters may improve 
# explanation results
explanation = model.infer_and_get_explanation(
    graph=component_graph,
    vertex=component_graph.get_vertex(0),
    num_clusters=10,
)

# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()

# retrieve importance of vertices
# vertex 1 is in the same densely connected component as vertex 0
# vertex 2 is in a different component
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0]  # has importance 1
importance_vertex_1 = importance_property[1]  # high importance
importance_vertex_2 = importance_property[2]  # low importance

# retrieve feature importance (not relevant for this example)
feature_importances = explanation.get_vertex_feature_importance()
importance_feat1_prop = feature_importances[feat1_property]
importance_feat2_prop = feature_importances[feat2_property]
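Once the vertex importances are retrieved, a common follow-up is to rank vertices by importance to inspect the explained vertex's neighborhood. The sketch below uses a plain Python dict of hypothetical scores, not the PGX property API.

```python
# Hypothetical importance scores keyed by vertex ID (the explained
# vertex, here vertex 0, always has importance 1).
importances = {0: 1.0, 1: 0.7, 2: 0.05, 3: 0.6}

# rank vertices by importance, most important first
ranking = sorted(importances, key=importances.get, reverse=True)
# vertices from the explained vertex's component are expected near the top
```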