8.2.14 Explaining a Prediction of a Supervised GraphWise Model
To understand which features and vertices are important for a prediction of the Supervised GraphWise model, you can generate a SupervisedGnnExplanation using a technique similar to the GNNExplainer by Ying et al.
The explanation holds information related to:
- Graph structure: An importance score for each vertex
- Features: An importance score for each graph property
Note:
The vertex being explained is always assigned importance 1. Further, the feature importances are scaled such that the most important feature has importance 1.
Additionally, a SupervisedGnnExplanation contains the inferred embeddings, logits, and label. You can get explanations for a model's predictions by using the SupervisedGnnExplainer object. The object can be obtained using the gnnExplainer method. After obtaining the SupervisedGnnExplainer object, you can use the inferAndExplain method to request an explanation for a vertex.
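The scaling rule for feature importances described in the note can be illustrated with a small standalone Python sketch. It is not part of the PGX API: the function name scale_importances and the raw scores are hypothetical, and the sketch assumes the scaling is a plain division by the largest score, which this section does not specify.

```python
def scale_importances(raw_scores):
    """Rescale non-negative scores so that the largest one becomes 1.0."""
    largest = max(raw_scores.values())
    if largest == 0:
        # All scores are zero, so there is nothing to rescale.
        return {name: 0.0 for name in raw_scores}
    return {name: score / largest for name, score in raw_scores.items()}


# Hypothetical raw scores, chosen only for illustration.
raw = {"label_feature": 4.0, "const_feature": 1.0}
scaled = scale_importances(raw)
print(scaled)
```

Under this assumption, the most important feature always ends up with importance 1, and every other importance can be read as a fraction of it.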
The parameters of the explainer can be configured while the explainer is being created or afterwards using the relevant setter functions. The configurable parameters for the SupervisedGnnExplainer are as follows:
- numOptimizationSteps: Number of optimization steps used by the explainer.
- learningRate: Learning rate of the explainer.
- marginalize: Determines if the explainer loss is marginalized over features. This can help in cases where there are important features that take values close to zero. Without marginalization the explainer can learn to mask such features out even if they are important. Marginalization solves this by learning a mask for the deviation from the estimated input distribution.
Note that, in order to achieve best results, the features should be centered around 0.
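The effect that marginalization addresses can be seen in a small standalone Python sketch. It does not use the PGX API and is not the library's implementation. It assumes the masked input takes the form z + mask * (x - z), where z is a value drawn from the estimated input distribution. The function names and the numbers are hypothetical.

```python
def plain_mask(x, mask):
    """Plain feature masking: scale the feature value by the mask."""
    return mask * x


def marginalized_mask(x, mask, z):
    """Mask the deviation of x from a reference value z instead of x itself."""
    return z + mask * (x - z)


x = 0.0  # an important feature whose value happens to be zero
z = 0.5  # hypothetical sample from the estimated input distribution

# Evaluate both schemes with the mask fully off (0.0) and fully on (1.0).
plain_outputs = [plain_mask(x, m) for m in (0.0, 1.0)]
marginalized_outputs = [marginalized_mask(x, m, z) for m in (0.0, 1.0)]

print(plain_outputs)         # identical for both mask values
print(marginalized_outputs)  # differs between the two mask values
```

With plain masking the model input is the same whether the mask is on or off, so the optimization receives no signal about the feature. With the marginalized form, the mask changes the input, so its value can be learned.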
For example, assume a simple graph that contains one feature that correlates with the label and another feature that does not. The importances of the two features are therefore expected to differ significantly, with the feature that correlates with the label being more important, while structural importance does not play a big role. In this case, you can generate an explanation as shown in the following JShell, Java, and Python examples:
opg4j> var simpleGraph = session.createGraphBuilder().
addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5).
setProperty("label", true).
addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5).
setProperty("label", false).
addEdge(0, 1).build()
// build and train a Supervised GraphWise model as explained in Advanced Hyperparameter Customization
// obtain and configure GnnExplainer
opg4j> var explainer = model.gnnExplainer().learningRate(0.05)
opg4j> explainer.numOptimizationSteps(200)
// explain prediction of vertex 0
opg4j> var explanation = explainer.inferAndExplain(simpleGraph, simpleGraph.getVertex(0))
// if you used the devNet loss, you can add the decision threshold as an extra parameter:
// var explanation = explainer.inferAndExplain(simpleGraph, simpleGraph.getVertex(0), 6f)
opg4j> var constProperty = simpleGraph.getVertexProperty("const_feature")
opg4j> var labelProperty = simpleGraph.getVertexProperty("label_feature")
// retrieve feature importances
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceConstProp = featureImportances.get(constProperty) // small as unimportant
opg4j> var importanceLabelProp = featureImportances.get(labelProperty) // large (1) as important
// retrieve computation graph with importances
opg4j> var importanceGraph = explanation.getImportanceGraph()
// retrieve importance of vertices
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0) // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1) // available if vertex 1 part of computation
PgxGraph simpleGraph = session.createGraphBuilder()
.addVertex(0).setProperty("label_feature", 0.5).setProperty("const_feature", 0.5)
.setProperty("label", true)
.addVertex(1).setProperty("label_feature", -0.5).setProperty("const_feature", 0.5)
.setProperty("label", false)
.addEdge(0, 1).build();
// build and train a Supervised GraphWise model as explained in Advanced Hyperparameter Customization
// obtain and configure the explainer
SupervisedGnnExplainer explainer = model.gnnExplainer().learningRate(0.05);
explainer.numOptimizationSteps(200);
// explain prediction of vertex 0
SupervisedGnnExplanation<Integer> explanation = explainer.inferAndExplain(simpleGraph,
simpleGraph.getVertex(0));
// if we used the devNet loss, we can add the decision threshold as an extra parameter:
// SupervisedGnnExplanation<Integer> explanation = explainer.inferAndExplain(simpleGraph, simpleGraph.getVertex(0), 6f);
VertexProperty<Integer, Float> constProperty = simpleGraph.getVertexProperty("const_feature");
VertexProperty<Integer, Float> labelProperty = simpleGraph.getVertexProperty("label_feature");
// retrieve feature importances
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceConstProp = featureImportances.get(constProperty); // small as unimportant
float importanceLabelProp = featureImportances.get(labelProperty); // large (1) as important
// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();
// retrieve importance of vertices
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0); // has importance 1
float importanceVertex1 = importanceProperty.get(1); // available if vertex 1 part of computation
simple_graph = (
    session.create_graph_builder()
    .add_vertex(0).set_property("label_feature", 0.5).set_property("const_feature", 0.5)
    .set_property("label", True)
    .add_vertex(1).set_property("label_feature", -0.5).set_property("const_feature", 0.5)
    .set_property("label", False)
    .add_edge(0, 1).build()
)
# build and train a Supervised GraphWise model as explained in Advanced Hyperparameter Customization
# obtain the explainer
explainer = model.gnn_explainer(learning_rate=0.05)
explainer.num_optimization_steps = 200
# explain prediction of vertex 0
explanation = explainer.infer_and_explain(simple_graph, simple_graph.get_vertex(0))
# if we used the devNet loss, we can add the decision threshold as an extra parameter:
# explanation = explainer.infer_and_explain(simple_graph, simple_graph.get_vertex(0), 6)
const_property = simple_graph.get_vertex_property("const_feature")
label_property = simple_graph.get_vertex_property("label_feature")
# retrieve feature importances
feature_importances = explanation.get_vertex_feature_importance()
importance_const_prop = feature_importances[const_property]
importance_label_prop = feature_importances[label_property]
# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()
# retrieve importance of vertices
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0]
importance_vertex_1 = importance_property[1]