PGX 20.2.2



SupervisedGraphWise is an inductive vertex representation learning algorithm that can leverage vertex feature information. It can be applied to a wide variety of tasks, including vertex classification and link prediction.

SupervisedGraphWise is based on GraphSAGE by Hamilton et al.

Model Structure

A SupervisedGraphWise model consists of two graph convolutional layers followed by several prediction layers.

The forward pass through a convolutional layer for a vertex proceeds as follows:

  1. A set of neighbors of the vertex is sampled.

  2. The previous layer representations of the neighbors are mean-aggregated, and the aggregated features are concatenated with the previous layer representation of the vertex.

  3. This concatenated vector is multiplied by a weight matrix, and a bias vector is added.

  4. The result is normalized such that the layer output has unit norm.
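The four steps above can be illustrated with a small, self-contained Java sketch for a single vertex. This is a minimal sketch under stated assumptions, not the PGX implementation. The class and method names (`ConvLayerSketch`, `sampleNeighbors`, `forward`) are hypothetical, neighbors are sampled uniformly without replacement, vertices are plain array indices, and the weight matrix is assumed to have shape (outDim x 2d) so that it can multiply the concatenated vector. No activation function is applied, because the steps above do not list one.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of one GraphSAGE-style convolutional layer applied to a
// single vertex. Illustrative only; this is not the PGX implementation.
final class ConvLayerSketch {

    // Step 1: sample up to k neighbors uniformly without replacement.
    static List<Integer> sampleNeighbors(List<Integer> neighbors, int k, Random rng) {
        List<Integer> shuffled = new ArrayList<>(neighbors);
        Collections.shuffle(shuffled, rng);
        return new ArrayList<>(shuffled.subList(0, Math.min(k, shuffled.size())));
    }

    // Steps 2-4 for one vertex. prev[v] is the previous-layer representation of
    // vertex v (dimension d); weights has shape (outDim x 2d); bias has length outDim.
    static double[] forward(int vertex, double[][] prev, List<Integer> sampled,
                            double[][] weights, double[] bias) {
        int d = prev[vertex].length;

        // Step 2a: mean-aggregate the sampled neighbors' representations.
        double[] mean = new double[d];
        for (int n : sampled) {
            for (int i = 0; i < d; i++) {
                mean[i] += prev[n][i] / sampled.size();
            }
        }

        // Step 2b: concatenate [own representation, neighbor mean] (length 2d).
        double[] concat = new double[2 * d];
        System.arraycopy(prev[vertex], 0, concat, 0, d);
        System.arraycopy(mean, 0, concat, d, d);

        // Step 3: multiply by the weight matrix and add the bias vector.
        double[] out = new double[bias.length];
        for (int r = 0; r < out.length; r++) {
            double sum = bias[r];
            for (int c = 0; c < concat.length; c++) {
                sum += weights[r][c] * concat[c];
            }
            out[r] = sum;
        }

        // Step 4: L2-normalize so that the output has unit norm.
        double norm = 0.0;
        for (double v : out) {
            norm += v * v;
        }
        norm = Math.sqrt(norm);
        if (norm > 0.0) {
            for (int r = 0; r < out.length; r++) {
                out[r] /= norm;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Previous-layer representations of three vertices (d = 2).
        double[][] prev = {{1, 0}, {0, 1}, {1, 1}};
        // Vertex 0 has neighbors 1 and 2; sample at most 2 of them.
        List<Integer> sampled = sampleNeighbors(List.of(1, 2), 2, new Random(42));
        double[][] weights = {{1, 0, 0, 0}, {0, 0, 1, 1}}; // outDim = 2, 2d = 4
        double[] bias = {0, 0};
        System.out.println(Arrays.toString(forward(0, prev, sampled, weights, bias)));
    }
}
```

Because every layer output is rescaled to unit length, representations stay on a comparable scale from one layer to the next regardless of how many neighbors a vertex has.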

The prediction layers are standard neural network layers.
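As a rough illustration, assuming that "standard" means fully connected (dense) layers, the hypothetical sketch below applies one dense layer with a ReLU activation and then a linear output layer that produces one logit per class. The names `PredictionHeadSketch` and `dense`, the layer sizes, and the choice of ReLU are assumptions for illustration only.

```java
import java.util.Arrays;

// Hypothetical sketch of a stack of dense prediction layers; not PGX code.
final class PredictionHeadSketch {

    // One dense layer: y = W x + b, optionally followed by a ReLU activation.
    static double[] dense(double[] x, double[][] w, double[] b, boolean relu) {
        double[] y = new double[b.length];
        for (int r = 0; r < y.length; r++) {
            double sum = b[r];
            for (int c = 0; c < x.length; c++) {
                sum += w[r][c] * x[c];
            }
            y[r] = relu ? Math.max(0.0, sum) : sum;
        }
        return y;
    }

    public static void main(String[] args) {
        double[] embedding = {0.6, 0.8}; // e.g. the unit-norm output of the last conv layer
        double[] hidden = dense(embedding, new double[][] {{1, -1}, {1, 1}}, new double[] {0, 0}, true);
        // Linear output layer: one logit per class (three classes here).
        double[] logits = dense(hidden, new double[][] {{1, 0}, {0, 1}, {1, 1}}, new double[] {0, 0, 0}, false);
        System.out.println(Arrays.toString(logits));
    }
}
```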


The following sections describe how to use the main functionality of our GraphSAGE implementation in PGX, using the Cora graph as an example.

Loading a Graph

First, we create a session and an analyst:

import java.util.List;

import oracle.pgx.api.*;
import oracle.pgx.api.beta.mllib.SupervisedGraphWiseModel;
import oracle.pgx.api.beta.frames.*;
import oracle.pgx.config.mllib.ActivationFunction;
import oracle.pgx.config.mllib.GraphWiseConvLayerConfig;
import oracle.pgx.config.mllib.GraphWisePredictionLayerConfig;
import oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig;
import oracle.pgx.config.mllib.WeightInitScheme;

PgxSession session = Pgx.createSession("my-session");
Analyst analyst = session.createAnalyst();
pgx> var fullGraph = session.readGraphWithProperties("<path>/cora_full.json")
pgx> var trainGraph = session.readGraphWithProperties("<path>/cora_train.json")
pgx> var testVertices = fullGraph.getVertices().
         filter(v -> !trainGraph.hasVertex(v.getId())).
PgxGraph fullGraph = session.readGraphWithProperties("<path>/cora_full.json");
PgxGraph trainGraph = session.readGraphWithProperties("<path>/cora_train.json");
List<PgxVertex> testVertices = fullGraph.getVertices()
    .filter(v -> !trainGraph.hasVertex(v.getId()))
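The test split above consists of every vertex of the full graph that does not appear in the training graph. The same set difference can be sketched with plain Java collections. This hypothetical `TestSplitSketch` class uses integer vertex IDs instead of PGX types and is not a replacement for the PGX calls.

```java
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical sketch: the test split is the set difference between the vertex
// IDs of the full graph and those of the training graph. Plain integers stand in
// for PGX vertices here.
final class TestSplitSketch {

    static List<Integer> testVertexIds(List<Integer> fullGraphIds, Set<Integer> trainGraphIds) {
        return fullGraphIds.stream()
                .filter(id -> !trainGraphIds.contains(id)) // keep vertices unseen in training
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> full = List.of(1, 2, 3, 4, 5);
        Set<Integer> train = Set.of(1, 3);
        System.out.println(testVertexIds(full, train));
    }
}
```

A hash-based set keeps each membership check constant time on average, playing the same role as `trainGraph.hasVertex` in the snippet above.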

Building a GraphWise Model (minimal)

We build a GraphWise model using the minimal configuration and default hyperparameters. Note that although only one feature property is specified in this example, you can specify arbitrarily many.

pgx> var model = analyst.supervisedGraphWiseModelBuilder().
SupervisedGraphWiseModel model = analyst.supervisedGraphWiseModelBuilder()

Advanced Hyperparameter Customization

The implementation supports rich hyperparameter customization through two sub-configuration classes: GraphWiseConvLayerConfig and GraphWisePredictionLayerConfig. The following example builds one configuration of each kind and uses them in a model.

pgx> var weightProperty = analyst.pagerank(trainGraph).getName()
pgx> var convLayerConfig = analyst.graphWiseConvLayerConfigBuilder().
pgx> var predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder().
pgx> var model = analyst.supervisedGraphWiseModelBuilder().
String weightProperty = analyst.pagerank(trainGraph).getName();
GraphWiseConvLayerConfig convLayerConfig = analyst.graphWiseConvLayerConfigBuilder()

GraphWisePredictionLayerConfig predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder()

SupervisedGraphWiseModel model = analyst.supervisedGraphWiseModelBuilder()

For a full description of all available hyperparameters and their default values, see the SupervisedGraphWiseModelBuilder, GraphWiseConvLayerConfigBuilder and GraphWisePredictionLayerConfigBuilder javadocs.

Training the SupervisedGraphWiseModel

We can train a SupervisedGraphWiseModel on a graph:


Getting the Loss Value

We can fetch the training loss value:

pgx> var loss = model.getTrainingLoss()
double loss = model.getTrainingLoss();

Inferring Vertex Labels

We can infer the labels for vertices on any graph (including vertices or graphs that were not seen during training):

pgx> var labels = model.inferLabels(fullGraph, testVertices)
pgx> labels.head().print()
PgxFrame labels = model.inferLabels(fullGraph, testVertices);

The output will be similar to the following example output:

| vertexId | label                 |
|----------|-----------------------|
| 2        | Neural Networks       |
| 6        | Theory                |
| 7        | Case Based            |
| 22       | Rule Learning         |
| 30       | Theory                |
| 34       | Neural Networks       |
| 47       | Case Based            |
| 48       | Probabalistic Methods |
| 50       | Theory                |
| 52       | Theory                |

In a similar fashion, you can get the model confidence for each class by inferring the prediction logits:

pgx> var logits = model.inferLogits(fullGraph, testVertices)
pgx> logits.head().print()
PgxFrame logits = model.inferLogits(fullGraph, testVertices);
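Logits are unnormalized scores. A common way to turn them into per-class confidences is a softmax, and the most likely class is the one with the largest logit. The sketch below shows both for a single vertex. It assumes that `inferLogits` returns raw (pre-softmax) scores and that the predicted label corresponds to the largest logit; the class `LogitsSketch`, the class names, and the logit values are hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch: turning one vertex's logits into per-class probabilities.
final class LogitsSketch {

    // Numerically stable softmax: subtract the max logit before exponentiating.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    // Index of the largest logit, i.e. the most likely class.
    static int argmax(double[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        String[] classes = {"Theory", "Neural Networks", "Case Based"}; // hypothetical subset of labels
        double[] logits = {0.5, 2.0, -1.0};                            // made-up logits for one vertex
        System.out.println(classes[argmax(logits)] + " " + Arrays.toString(softmax(logits)));
    }
}
```

Subtracting the maximum logit does not change the result of the softmax but prevents `Math.exp` from overflowing for large scores.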

Evaluating Model Performance

evaluateLabels is a convenience method to evaluate various classification metrics for the model:

pgx> model.evaluateLabels(fullGraph, testVertices).print()
model.evaluateLabels(fullGraph, testVertices).print();

The output will be similar to the following example output:

| Accuracy | Precision | Recall | F1-Score |
|----------|-----------|--------|----------|
| 0.8488   | 0.8523    | 0.831  | 0.8367   |
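The table above reports a single precision, recall, and F1 value for a multi-class problem, so per-class scores have to be averaged somehow. The source does not say which averaging PGX uses. The hypothetical `MetricsSketch` below assumes macro averaging (every class weighted equally) and computes all four metrics from integer-encoded labels.

```java
import java.util.Arrays;

// Hypothetical sketch of accuracy plus macro-averaged precision, recall and F1
// for multi-class labels encoded as integers 0..numClasses-1. The macro
// averaging is an assumption, not documented PGX behavior.
final class MetricsSketch {

    // Returns {accuracy, precision, recall, f1}.
    static double[] evaluate(int[] actual, int[] predicted, int numClasses) {
        int[] tp = new int[numClasses];
        int[] fp = new int[numClasses];
        int[] fn = new int[numClasses];
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] == predicted[i]) {
                correct++;
                tp[actual[i]]++;
            } else {
                fp[predicted[i]]++; // the predicted class gets a false positive
                fn[actual[i]]++;    // the true class gets a false negative
            }
        }
        double precision = 0.0, recall = 0.0, f1 = 0.0;
        for (int c = 0; c < numClasses; c++) {
            double p = tp[c] + fp[c] == 0 ? 0.0 : (double) tp[c] / (tp[c] + fp[c]);
            double r = tp[c] + fn[c] == 0 ? 0.0 : (double) tp[c] / (tp[c] + fn[c]);
            precision += p;
            recall += r;
            f1 += p + r == 0.0 ? 0.0 : 2 * p * r / (p + r);
        }
        return new double[] {
            (double) correct / actual.length,
            precision / numClasses,
            recall / numClasses,
            f1 / numClasses
        };
    }

    public static void main(String[] args) {
        int[] actual    = {0, 0, 1, 1};
        int[] predicted = {0, 1, 1, 1};
        System.out.println(Arrays.toString(evaluate(actual, predicted, 2)));
    }
}
```

Under macro averaging, F1 is the mean of the per-class F1 scores, which in general differs from the harmonic mean of the averaged precision and recall.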

Inferring Embeddings

We can use a trained model to infer embeddings, including for vertices not seen during training, and store them in a CSV file:

pgx> var vertexVectors = model.inferEmbeddings(fullGraph, fullGraph.getVertices()).flattenAll()
pgx> vertexVectors.write().
PgxFrame vertexVectors = model.inferEmbeddings(fullGraph, fullGraph.getVertices()).flattenAll();

Without flattening, vertexVectors would have the following schema (flattenAll splits the vector-valued embedding column into separate double-valued columns):

| vertexId | embedding |
|----------|-----------|
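To illustrate what flattening does, the hypothetical `FlattenSketch` below takes one embedding vector per vertex and renders it as CSV with one double-valued column per dimension. The column names (`embedding_0`, `embedding_1`, ...) are assumptions; the names that `flattenAll` actually produces may differ.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.StringJoiner;

// Hypothetical sketch of "flattening" a vector-valued embedding column into one
// double-valued column per dimension and rendering the result as CSV.
final class FlattenSketch {

    static String toFlattenedCsv(Map<Integer, double[]> embeddings, int dim) {
        StringBuilder csv = new StringBuilder();
        StringJoiner header = new StringJoiner(",");
        header.add("vertexId");
        for (int i = 0; i < dim; i++) {
            header.add("embedding_" + i); // hypothetical column names
        }
        csv.append(header).append('\n');
        for (Map.Entry<Integer, double[]> e : embeddings.entrySet()) {
            if (e.getValue().length != dim) {
                throw new IllegalArgumentException("vertex " + e.getKey() + " has the wrong dimension");
            }
            StringJoiner row = new StringJoiner(",");
            row.add(String.valueOf(e.getKey()));
            for (double v : e.getValue()) {
                row.add(String.valueOf(v)); // one column per embedding dimension
            }
            csv.append(row).append('\n');
        }
        return csv.toString();
    }

    public static void main(String[] args) {
        Map<Integer, double[]> emb = new LinkedHashMap<>(); // keeps insertion order
        emb.put(2, new double[] {0.1, -0.3});
        emb.put(6, new double[] {0.7, 0.2});
        System.out.print(toFlattenedCsv(emb, 2));
    }
}
```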

Storing a Trained Model

We can store a trained SupervisedGraphWise model, encrypted, at a specified path:


Loading a Pre-trained Model

We can load a pre-trained, encrypted SupervisedGraphWise model from a specified path:

pgx> var model = analyst.loadSupervisedGraphWiseModel("<path>/<modelName>","encryption_key")
SupervisedGraphWiseModel model = analyst.loadSupervisedGraphWiseModel("<path>/<modelName>","encryption_key");

Destroying a Model

We can destroy a model as follows:

pgx> model.destroy()
model.destroy();