17.4.13 Explaining a Prediction for an Unsupervised GraphWise Model
To understand which features and vertices are important for a prediction of an Unsupervised GraphWise model, you can generate an UnsupervisedGnnExplanation using a technique similar to the GNNExplainer by Ying et al.
The explanation holds information related to:
- Graph structure: an importance score for each vertex
- Features: an importance score for each graph property
Note:
The vertex being explained is always assigned importance 1. Furthermore, the feature importances are scaled such that the most important feature has importance 1.
Additionally, an UnsupervisedGnnExplanation contains the inferred embedding. You can get explanations for a model's predictions by using the UnsupervisedGnnExplainer object, which can be obtained using the gnnExplainer method. After obtaining the UnsupervisedGnnExplainer object, you can use the inferAndExplain method to request an explanation for a vertex.
The parameters of the explainer can be configured while the explainer is being created or afterwards using the relevant setter functions. The configurable parameters for the UnsupervisedGnnExplainer are as follows:
- numOptimizationSteps: Number of optimization steps used by the explainer.
- learningRate: Learning rate of the explainer.
- marginalize: Determines whether the explainer loss is marginalized over features. This can help in cases where important features take values close to zero. Without marginalization, the explainer can learn to mask such features out even if they are important. Marginalization solves this by learning a mask for the deviation from the estimated input distribution.
- numClusters: Number of clusters to use in the explainer loss. The unsupervised explainer uses k-means clustering to compute the explainer loss that is optimized. If the approximate number of components in the graph is known, it is a good idea to set the number of clusters to this number.
- numSamples: Number of vertex samples used to optimize the explainer. For the sake of performance, the explainer computes the loss on this number of randomly sampled vertices. Using more samples is more accurate but takes longer and uses more resources.
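The role of numClusters can be pictured with a toy k-means assignment step on fixed embeddings. This is only a hedged illustration of the clustering idea behind the explainer loss, not PGX internals; the embeddings and helper below are invented for the sketch.

```python
# Toy illustration of the k-means idea behind the unsupervised explainer loss.
# The embeddings below are made up; PGX computes real ones internally.

def assign_to_clusters(embeddings, centroids):
    """Assign each embedding to its nearest centroid (squared Euclidean distance)."""
    assignments = []
    for e in embeddings:
        dists = [sum((a - b) ** 2 for a, b in zip(e, c)) for c in centroids]
        assignments.append(dists.index(min(dists)))
    return assignments

# Two tight groups of embeddings, as vertices of two components might produce
embeddings = [(0.1, 0.0), (0.2, 0.1), (0.0, 0.2),   # component A
              (5.0, 5.1), (5.2, 4.9), (4.9, 5.0)]   # component B
centroids = [(0.0, 0.0), (5.0, 5.0)]                # num_clusters = 2

assignments = assign_to_clusters(embeddings, centroids)
print(assignments)  # vertices of the same component fall into the same cluster
```

Setting numClusters close to the true number of components makes such an assignment clean, which is why the documentation recommends it when the component count is approximately known.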
Note that, in order to achieve best results, the features should be centered around 0.
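Since the explainer works best with zero-centered features, you may want to center each feature before training. A minimal sketch of per-feature mean-centering on plain Python lists (the raw values here are invented):

```python
def center_features(values):
    """Subtract the mean so the feature values are centered around 0."""
    mean = sum(values) / len(values)
    return [v - mean for v in values]

feat1 = [2.0, 4.0, 6.0]  # hypothetical raw values of a vertex feature
feat1_centered = center_features(feat1)
print(feat1_centered)  # [-2.0, 0.0, 2.0]
```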
For example, assume a simple graph, componentGraph, which contains k densely connected components; that is, there are many edges between vertices of the same component and few edges between any two components. By training an Unsupervised GraphWise model on this graph, you can expect a model that produces similar embeddings for vertices in a densely connected component.
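The notion of densely connected components can be made concrete with a tiny edge list. In this hedged sketch (the graph data is invented), most edges stay within a component and only one edge crosses between components:

```python
# Vertices 0-2 form component A, vertices 3-5 form component B
component = {0: "A", 1: "A", 2: "A", 3: "B", 4: "B", 5: "B"}
edges = [(0, 1), (0, 2), (1, 2),   # dense within A
         (3, 4), (3, 5), (4, 5),   # dense within B
         (2, 3)]                   # single edge between components

intra = sum(1 for u, v in edges if component[u] == component[v])
inter = len(edges) - intra
print(intra, inter)  # 6 intra-component edges, 1 inter-component edge
```

On a graph with this structure, the model's embeddings for vertices 0-2 should end up close to each other and far from those of vertices 3-5.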
The following example shows how to generate an explanation for a prediction on componentGraph. Vertices from the same component are expected to have a higher importance than vertices from a different component. Note that the feature importances are not relevant in this example.
opg4j> var componentGraph = session.readGraphByName("<graph>",GraphSource.PG_PGQL)
// retrieve the vertex feature properties
opg4j> var feat1Property = componentGraph.getVertexProperty("feat1")
opg4j> var feat2Property = componentGraph.getVertexProperty("feat2")
// build and train an Unsupervised GraphWise model as explained in Advanced Hyperparameter Customization
// obtain and configure the explainer
// setting the numClusters argument to the expected number of clusters may improve
// explanation results as the explainer optimization will try to cluster samples into
// this number of clusters
opg4j> var explainer = model.gnnExplainer().numClusters(50)
// set the number of samples to compute the loss over during explainer optimization
opg4j> explainer.numSamples(10000)
// explain prediction of vertex 0
opg4j> var explanation = explainer.inferAndExplain(componentGraph, componentGraph.getVertex(0))
// retrieve computation graph with importance
opg4j> var importanceGraph = explanation.getImportanceGraph()
// retrieve importance of vertices
// vertex 1 is in the same densely connected component as vertex 0
// vertex 2 is in a different component
opg4j> var importanceProperty = explanation.getVertexImportanceProperty()
opg4j> var importanceVertex0 = importanceProperty.get(0) // has importance 1
opg4j> var importanceVertex1 = importanceProperty.get(1) // high importance
opg4j> var importanceVertex2 = importanceProperty.get(2) // low importance
// optionally retrieve feature importance
opg4j> var featureImportances = explanation.getVertexFeatureImportance()
opg4j> var importanceFeat1Prop = featureImportances.get(feat1Property)
opg4j> var importanceFeat2Prop = featureImportances.get(feat2Property)
PgxGraph componentGraph = session.readGraphByName("<graph>",GraphSource.PG_PGQL); // load graph
VertexProperty<Integer, Float> feat1Property = componentGraph.getVertexProperty("feat1");
VertexProperty<Integer, Float> feat2Property = componentGraph.getVertexProperty("feat2");
// build and train an Unsupervised GraphWise model as explained in Advanced Hyperparameter Customization
// obtain and configure the explainer
// setting the numClusters argument to the expected number of clusters may improve
// explanation results as the explainer optimization will try to cluster samples into
// this number of clusters
UnsupervisedGnnExplainer explainer = model.gnnExplainer().numClusters(50);
// set the number of samples to compute the loss over during explainer optimization
explainer.numSamples(10000);
// explain prediction of vertex 0
UnsupervisedGnnExplanation<Integer> explanation = explainer.inferAndExplain(componentGraph, componentGraph.getVertex(0));
// retrieve computation graph with importances
PgxGraph importanceGraph = explanation.getImportanceGraph();
// retrieve importance of vertices
// vertex 1 is in the same densely connected component as vertex 0
// vertex 2 is in a different component
VertexProperty<Integer, Float> importanceProperty = explanation.getVertexImportanceProperty();
float importanceVertex0 = importanceProperty.get(0); // has importance 1
float importanceVertex1 = importanceProperty.get(1); // high importance
float importanceVertex2 = importanceProperty.get(2); // low importance
// retrieve feature importance (not relevant for this example)
Map<VertexProperty<Integer, ?>, Float> featureImportances = explanation.getVertexFeatureImportance();
float importanceFeat1Prop = featureImportances.get(feat1Property);
float importanceFeat2Prop = featureImportances.get(feat2Property);
# load 'component_graph' with vertex features 'feat1' and 'feat2'
feat1_property = component_graph.get_vertex_property("feat1")
feat2_property = component_graph.get_vertex_property("feat2")
# build and train an Unsupervised GraphWise model as explained in Advanced Hyperparameter Customization
# obtain and configure the explainer
# setting the num_clusters argument to the expected number of clusters may improve
# explanation results as the explainer optimization will try to cluster samples into
# this number of clusters
explainer = model.gnn_explainer(num_clusters=50)
# set the number of samples to compute the loss over during explainer optimization
explainer.num_samples = 10000
# explain prediction of vertex 0
explanation = explainer.infer_and_explain(
graph=component_graph,
vertex=component_graph.get_vertex(0)
)
# retrieve computation graph with importances
importance_graph = explanation.get_importance_graph()
# retrieve importance of vertices
# vertex 1 is in the same densely connected component as vertex 0
# vertex 2 is in a different component
importance_property = explanation.get_vertex_importance_property()
importance_vertex_0 = importance_property[0] # has importance 1
importance_vertex_1 = importance_property[1] # high importance
importance_vertex_2 = importance_property[2] # low importance
# retrieve feature importance (not relevant for this example)
feature_importances = explanation.get_vertex_feature_importance()
importance_feat1_prop = feature_importances[feat1_property]
importance_feat2_prop = feature_importances[feat2_property]