8.2.13 Explaining a Prediction of a Supervised GraphWise Model

In order to understand which features and vertices are important for a prediction of the Supervised GraphWise model, you can generate a SupervisedGnnExplanation using a technique similar to the GNNExplainer by Ying et al.

The explanation holds information related to:

  • graph structure: an importance score for each vertex
  • features: an importance score for each graph property

Note:

The vertex being explained is always assigned importance 1. Further, the feature importances are scaled such that the most important feature has importance 1.

Additionally, a SupervisedGnnExplanation contains the inferred embedding, logits, and label. The inferAndGetExplanation method can be used on any fitted SupervisedGraphWiseModel that does not rely on edge features. In order to achieve the best results, the features should be centered around 0.
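If a raw feature is not already centered, you can shift it by its mean before adding it as a vertex property. The following is a minimal Java sketch; the array contents and variable names are illustrative and not part of the PGX API:

// illustrative raw feature values for vertices 0..3 (example data, not part of the PGX API)
double[] rawFeature = {1.2, 0.8, 1.0, 1.4};

// compute the mean and subtract it so that the centered values average to 0
double mean = java.util.Arrays.stream(rawFeature).average().orElse(0.0);
double[] centeredFeature = java.util.Arrays.stream(rawFeature).map(v -> v - mean).toArray();

// use centeredFeature[i] instead of rawFeature[i] when calling setProperty(...) on the graph builder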

For example, assume a simple graph that contains a feature that correlates with the label and another feature that does not. The feature importances are therefore expected to differ significantly (with the feature that correlates with the label being more important), while the structural importance does not play a big role. In this case, you can generate an explanation as shown in the following examples:

opg4j> var simpleGraph = session.createGraphBuilder().
                          addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5).
                          setProperty("label", true).
                          addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5).
                          setProperty("label", false).
                          addEdge(0, 1).build()

// build and train a Supervised GraphWise model

// explain prediction of vertex 0
opg4j> var explanation = model.inferAndGetExplanation(simpleGraph, simpleGraph.getVertex(0))
opg4j> var constProperty = simpleGraph.getVertexProperty("const_feature")
opg4j> var labelProperty = simpleGraph.getVertexProperty("label_feature")

// retrieve feature importances
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceConstProp = featureImportances.get(constProperty) // small as unimportant
opg4j> var importanceLabelProp = featureImportances.get(labelProperty) // large (1) as important

// retrieve computation graph with importances
opg4j> var importanceGraph = explanation.getImportanceGraph()

// retrieve importance of vertices
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0) // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1) // available if vertex 1 part of computation
    
PgxGraph simpleGraph = session.createGraphBuilder()
    .addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5)
    .setProperty("label", true)
    .addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5)
    .setProperty("label", false)
    .addEdge(0, 1).build();

// build and train a Supervised GraphWise model

// explain prediction of vertex 0
SupervisedGnnExplanation<Integer> explanation = model.inferAndGetExplanation(simpleGraph,
    simpleGraph.getVertex(0));

VertexProperty<Integer, Float> constProperty = simpleGraph.getVertexProperty("const_feature");
VertexProperty<Integer, Float> labelProperty = simpleGraph.getVertexProperty("label_feature");

// retrieve feature importances
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceConstProp = featureImportances.get(constProperty); // small as unimportant
float importanceLabelProp = featureImportances.get(labelProperty); // large (1) as important

// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();

// retrieve importance of vertices
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0); // has importance 1
float importanceVertex1 = importanceProperty.get(1); // available if vertex 1 part of computation

simple_graph = session.create_graph_builder() \
    .add_vertex(0).set_property("label_feature", 0.5).set_property("const_feature", 0.5) \
    .set_property("label", True) \
    .add_vertex(1).set_property("label_feature", -0.5).set_property("const_feature", 0.5) \
    .set_property("label", False) \
    .add_edge(0, 1).build()

# build and train a Supervised GraphWise model

# explain prediction of vertex 0
explanation = model.infer_and_get_explanation(simple_graph, simple_graph.get_vertex(0))

const_property = simple_graph.get_vertex_property("const_feature")
label_property = simple_graph.get_vertex_property("label_feature")

# retrieve feature importances
feature_importances = explanation.get_vertex_feature_importance()
importance_const_prop = feature_importances[const_property]  # small as unimportant
importance_label_prop = feature_importances[label_property]  # large (1) as important

# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()

# retrieve importance of vertices
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0]  # has importance 1
importance_vertex_1 = importance_property[1]  # available if vertex 1 is part of the computation
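
Beyond the importances retrieved above, the inferred embedding, logits, and label mentioned earlier can also be read from the explanation object. The following Java sketch assumes accessor names getEmbedding(), getLogits(), and getLabel() on SupervisedGnnExplanation; verify them against the Javadoc of your PGX version:

// retrieve the quantities inferred for the explained vertex
// (accessor names are assumptions; check the SupervisedGnnExplanation Javadoc)
var embedding = explanation.getEmbedding(); // vertex embedding produced during inference
var logits = explanation.getLogits();       // raw class scores before the softmax
var label = explanation.getLabel();         // predicted label for the explained vertex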