PGX 21.1.1



SupervisedGraphWise is an inductive vertex representation learning algorithm that can leverage vertex feature information. It can be applied to a wide variety of tasks, including vertex classification and link prediction.

SupervisedGraphWise is based on GraphSAGE by Hamilton et al.

Model Structure

A SupervisedGraphWise model consists of two graph convolutional layers followed by several prediction layers.

The forward pass through a convolutional layer for a vertex proceeds as follows:

  1. A set of neighbors of the vertex is sampled.

  2. The previous layer representations of the neighbors are mean-aggregated, and the aggregated features are concatenated with the previous layer representation of the vertex.

  3. This concatenated vector is multiplied by a weight matrix, and a bias vector is added.

  4. The result is normalized so that the layer output has unit norm.
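The four steps above can be sketched in plain Java for a single vertex. This is an illustrative sketch, not PGX's implementation: it assumes uniform neighbor sampling with replacement and a dense weight matrix of shape [out][2d], and it applies no activation function because none appears in the steps as listed. The class and method names are hypothetical.

```java
import java.util.Random;

public class ConvLayerSketch {

    /**
     * One convolution step for a single vertex (illustrative only).
     *
     * @param self       previous-layer representation of the vertex (length d)
     * @param neighbors  previous-layer representations of its neighbors (each of length d)
     * @param numSamples how many neighbors to sample, uniformly with replacement
     * @param weights    weight matrix of shape [out][2 * d]
     * @param bias       bias vector of length out
     */
    static double[] forward(double[] self, double[][] neighbors, int numSamples,
                            double[][] weights, double[] bias, Random rng) {
        int d = self.length;
        // Steps 1 and 2: sample neighbors and mean-aggregate their representations.
        int samples = neighbors.length == 0 ? 0 : numSamples;
        double[] mean = new double[d];
        for (int s = 0; s < samples; s++) {
            double[] nb = neighbors[rng.nextInt(neighbors.length)];
            for (int i = 0; i < d; i++) {
                mean[i] += nb[i] / samples;
            }
        }
        // Step 2 (cont.): concatenate [self ; mean].
        double[] concat = new double[2 * d];
        System.arraycopy(self, 0, concat, 0, d);
        System.arraycopy(mean, 0, concat, d, d);
        // Step 3: multiply by the weight matrix and add the bias.
        double[] out = new double[bias.length];
        for (int r = 0; r < out.length; r++) {
            double sum = bias[r];
            for (int c = 0; c < concat.length; c++) {
                sum += weights[r][c] * concat[c];
            }
            out[r] = sum;
        }
        // Step 4: L2-normalize to unit norm (a zero vector is left unchanged).
        double norm = 0.0;
        for (double v : out) {
            norm += v * v;
        }
        norm = Math.sqrt(norm);
        if (norm > 0.0) {
            for (int r = 0; r < out.length; r++) {
                out[r] /= norm;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] self = {1.0, 0.0};
        double[][] neighbors = {{0.0, 1.0}, {0.0, 1.0}, {0.0, 1.0}};
        double[][] weights = {{1, 0, 0, 0}, {0, 0, 0, 1}}; // picks self[0] and mean[1]
        double[] bias = {0.0, 0.0};
        double[] h = forward(self, neighbors, 2, weights, bias, new Random(42));
        System.out.println(java.util.Arrays.toString(h)); // both components equal 1/sqrt(2)
    }
}
```

Because all neighbors in the usage example share the same representation, the sampled mean does not depend on the random seed, which makes the result easy to check by hand.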

The prediction layers are standard neural network layers.
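For illustration, a standard fully connected layer computes W x + b and then applies an activation function. The sketch below assumes ReLU as that activation; in PGX the activation is presumably set through the prediction layer configuration, and a final output layer would typically skip it so that it emits raw logits. The class and method names are hypothetical.

```java
public class DenseLayerSketch {

    /** ReLU activation, one common choice for hidden layers. */
    static double relu(double v) {
        return Math.max(0.0, v);
    }

    /** Computes relu(W x + b), where weights has shape [out][in]. */
    static double[] forward(double[] x, double[][] weights, double[] bias) {
        double[] y = new double[bias.length];
        for (int r = 0; r < y.length; r++) {
            double sum = bias[r];
            for (int c = 0; c < x.length; c++) {
                sum += weights[r][c] * x[c];
            }
            y[r] = relu(sum);
        }
        return y;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0};
        double[][] w = {{1.0, 1.0}, {-1.0, 0.0}};
        double[] b = {0.5, 0.0};
        double[] y = forward(x, w, b);
        System.out.println(y[0] + " " + y[1]); // prints "3.5 0.0"
    }
}
```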


This section describes how to use the main functionality of our GraphSAGE implementation in PGX, using the Cora graph as an example.

Loading a Graph

First, we import the required classes and create a session and an analyst:

import java.util.List;
import java.util.stream.Collectors;
import oracle.pgx.api.*;
import oracle.pgx.api.mllib.SupervisedGraphWiseModel;
import oracle.pgx.api.frames.*;
import oracle.pgx.config.mllib.ActivationFunction;
import oracle.pgx.config.mllib.GraphWiseConvLayerConfig;
import oracle.pgx.config.mllib.GraphWisePredictionLayerConfig;
import oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig;
import oracle.pgx.config.mllib.WeightInitScheme;

PgxSession session = Pgx.createSession("my-session");
Analyst analyst = session.createAnalyst();

Next, we load the full Cora graph and its training subgraph. The test vertices are those of the full graph that are not part of the training graph:

pgx> var fullGraph = session.readGraphWithProperties("<path>/cora_full.json")
pgx> var trainGraph = session.readGraphWithProperties("<path>/cora_train.json")
pgx> var testVertices = fullGraph.getVertices().
         filter(v -> !trainGraph.hasVertex(v.getId())).
         stream().
         collect(Collectors.toList())
PgxGraph fullGraph = session.readGraphWithProperties("<path>/cora_full.json");
PgxGraph trainGraph = session.readGraphWithProperties("<path>/cora_train.json");
List<PgxVertex> testVertices = fullGraph.getVertices()
    .filter(v -> !trainGraph.hasVertex(v.getId()))
    .stream()
    .collect(Collectors.toList());

Building a GraphWise Model (minimal)

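We build a GraphWise model using the minimal configuration and default hyperparameters. Note that even though only one feature property is specified in this example, you can specify arbitrarily many. Replace <featureProperty> and <labelProperty> with the names of the vertex feature and label properties of your graph:

pgx> var model = analyst.supervisedGraphWiseModelBuilder().
         setVertexInputPropertyNames("<featureProperty>").setVertexTargetPropertyName("<labelProperty>").build()
SupervisedGraphWiseModel model = analyst.supervisedGraphWiseModelBuilder()
    .setVertexInputPropertyNames("<featureProperty>").setVertexTargetPropertyName("<labelProperty>").build();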

Advanced Hyperparameter Customization

The implementation allows for very rich hyperparameter customization. This is done through two sub-config classes: GraphWiseConvLayerConfig and GraphWisePredictionLayerConfig. In the following, we build such configurations and use them in a model.

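pgx> var weightProperty = analyst.pagerank(trainGraph).getName()
pgx> var convLayerConfig = analyst.graphWiseConvLayerConfigBuilder().
         setWeightedAggregationProperty(weightProperty).build()
pgx> var predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder().
         setActivationFunction(ActivationFunction.RELU).build()
pgx> var model = analyst.supervisedGraphWiseModelBuilder().
         setVertexInputPropertyNames("<featureProperty>").setVertexTargetPropertyName("<labelProperty>").
         setConvLayerConfigs(convLayerConfig).setPredictionLayerConfigs(predictionLayerConfig).build()
String weightProperty = analyst.pagerank(trainGraph).getName();
GraphWiseConvLayerConfig convLayerConfig = analyst.graphWiseConvLayerConfigBuilder()
    .setWeightedAggregationProperty(weightProperty) // weight neighbor aggregation by PageRank
    .build();
GraphWisePredictionLayerConfig predictionLayerConfig = analyst.graphWisePredictionLayerConfigBuilder()
    .setActivationFunction(ActivationFunction.RELU)
    .build();
SupervisedGraphWiseModel model = analyst.supervisedGraphWiseModelBuilder()
    .setVertexInputPropertyNames("<featureProperty>").setVertexTargetPropertyName("<labelProperty>")
    .setConvLayerConfigs(convLayerConfig)
    .setPredictionLayerConfigs(predictionLayerConfig)
    .build();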

For a full description of all available hyperparameters and their default values, see the SupervisedGraphWiseModelBuilder, GraphWiseConvLayerConfigBuilder and GraphWisePredictionLayerConfigBuilder javadocs.
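The advanced example computes PageRank on the training graph and keeps the resulting property name as weightProperty. Assuming this property is used to weight each neighbor's contribution during aggregation, the following sketch contrasts the plain mean aggregation of the convolutional layer with a weighted mean. It illustrates the idea only and is not PGX's exact formula; the class and method names are hypothetical.

```java
import java.util.Arrays;

public class WeightedAggregationSketch {

    /** Plain mean of the neighbor vectors (every neighbor has weight 1). */
    static double[] mean(double[][] neighbors) {
        double[] ones = new double[neighbors.length];
        Arrays.fill(ones, 1.0);
        return weightedMean(neighbors, ones);
    }

    /** Weighted mean: sum_i w_i * h_i / sum_i w_i (weights assumed positive). */
    static double[] weightedMean(double[][] neighbors, double[] weights) {
        int d = neighbors[0].length;
        double[] agg = new double[d];
        double total = 0.0;
        for (int i = 0; i < neighbors.length; i++) {
            total += weights[i];
            for (int j = 0; j < d; j++) {
                agg[j] += weights[i] * neighbors[i][j];
            }
        }
        for (int j = 0; j < d; j++) {
            agg[j] /= total;
        }
        return agg;
    }

    public static void main(String[] args) {
        double[][] neighbors = {{0.0}, {1.0}};
        double[] pagerank = {0.25, 0.75}; // hypothetical PageRank scores of the two neighbors
        System.out.println(mean(neighbors)[0]);                   // prints 0.5
        System.out.println(weightedMean(neighbors, pagerank)[0]); // prints 0.75
    }
}
```

With weighting, the neighbor with the higher PageRank pulls the aggregate toward its own representation.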

Training the SupervisedGraphWiseModel

We can train a SupervisedGraphWiseModel on a graph:

pgx> model.fit(trainGraph)
model.fit(trainGraph);

Getting Loss Value

We can fetch the training loss value:

pgx> var loss = model.getTrainingLoss()
double loss = model.getTrainingLoss();
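The documentation does not state which loss getTrainingLoss() reports. For a vertex classification model, a common choice is the mean softmax cross-entropy over the training vertices; the sketch below computes that quantity from per-class logits and integer class labels, under that assumption. The class and method names are hypothetical.

```java
public class CrossEntropySketch {

    /** Numerically stable log-softmax of one logit vector. */
    static double[] logSoftmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max);
        }
        double logSum = max + Math.log(sumExp);
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = logits[i] - logSum;
        }
        return out;
    }

    /** Mean over vertices of -log p(true class). */
    static double loss(double[][] logits, int[] labels) {
        double total = 0.0;
        for (int v = 0; v < logits.length; v++) {
            total -= logSoftmax(logits[v])[labels[v]];
        }
        return total / logits.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 0.0}, {0.0, 0.0}}; // hypothetical logits for two vertices
        int[] labels = {0, 1};
        System.out.println(loss(logits, labels));
    }
}
```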

Inferring Vertex Labels

We can infer the labels for vertices on any graph (including vertices or graphs that were not seen during training):

pgx> var labels = model.inferLabels(fullGraph, testVertices)
pgx> labels.head().print()
PgxFrame labels = model.inferLabels(fullGraph, testVertices);

The output will be similar to the following example output:

| vertexId | label                 |
| 2        | Neural Networks       |
| 6        | Theory                |
| 7        | Case Based            |
| 22       | Rule Learning         |
| 30       | Theory                |
| 34       | Neural Networks       |
| 47       | Case Based            |
| 48       | Probabilistic Methods |
| 50       | Theory                |
| 52       | Theory                |

In a similar fashion, you can obtain the model's confidence for each class by inferring the prediction logits:

pgx> var logits = model.inferLogits(fullGraph, testVertices)
pgx> logits.head().print()
PgxFrame logits = model.inferLogits(fullGraph, testVertices);
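Logits are conventionally unnormalized scores. Assuming that is what inferLogits returns, a softmax turns each row into probabilities that sum to 1, and the class with the largest probability is the predicted label. The sketch below works on plain arrays rather than a PgxFrame; the class and method names are hypothetical.

```java
public class LogitsSketch {

    /** Softmax over one logit vector; subtracting the max keeps exp() from overflowing. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sum = 0.0;
        double[] p = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    /** Index of the largest value (the predicted class). */
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] logits = {0.5, 2.0, -1.0}; // hypothetical logits for three classes
        double[] p = softmax(logits);
        int predicted = argmax(p);
        System.out.println("predicted class " + predicted + " with probability " + p[predicted]);
    }
}
```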

Evaluating Model Performance

evaluateLabels is a convenience method to evaluate various classification metrics for the model:

pgx> model.evaluateLabels(fullGraph, testVertices).print()
model.evaluateLabels(fullGraph, testVertices).print();

The output will be similar to the following example output:

| Accuracy | Precision | Recall | F1-Score |
| 0.8488   | 0.8523    | 0.831  | 0.8367   |
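To make the reported metrics concrete, the sketch below computes accuracy together with precision, recall and F1-score from true and predicted class indices. It assumes macro averaging (per-class values averaged with equal weight, F1 averaged per class); the documentation does not say which averaging evaluateLabels uses, so its numbers may differ. The class and method names are hypothetical.

```java
import java.util.Arrays;

public class MetricsSketch {

    /** Returns {accuracy, precision, recall, f1}, with precision, recall and F1 macro-averaged. */
    static double[] evaluate(int[] truth, int[] pred, int numClasses) {
        int[] tp = new int[numClasses];
        int[] fp = new int[numClasses];
        int[] fn = new int[numClasses];
        int correct = 0;
        for (int i = 0; i < truth.length; i++) {
            if (truth[i] == pred[i]) {
                correct++;
                tp[truth[i]]++;
            } else {
                fp[pred[i]]++;  // the predicted class gets a false positive
                fn[truth[i]]++; // the true class gets a false negative
            }
        }
        double precision = 0.0, recall = 0.0, f1 = 0.0;
        for (int c = 0; c < numClasses; c++) {
            double p = tp[c] + fp[c] == 0 ? 0.0 : (double) tp[c] / (tp[c] + fp[c]);
            double r = tp[c] + fn[c] == 0 ? 0.0 : (double) tp[c] / (tp[c] + fn[c]);
            precision += p;
            recall += r;
            f1 += p + r == 0.0 ? 0.0 : 2.0 * p * r / (p + r);
        }
        return new double[] {
            (double) correct / truth.length,
            precision / numClasses,
            recall / numClasses,
            f1 / numClasses
        };
    }

    public static void main(String[] args) {
        int[] truth = {0, 0, 1, 1};
        int[] pred = {0, 1, 1, 1};
        System.out.println(Arrays.toString(evaluate(truth, pred, 2)));
    }
}
```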

Inferring Embeddings

We can use a trained model to infer embeddings for unseen vertices and store them in a CSV file:

pgx> var vertexVectors = model.inferEmbeddings(fullGraph, fullGraph.getVertices()).flattenAll()
pgx> vertexVectors.write().
PgxFrame vertexVectors = model.inferEmbeddings(fullGraph, fullGraph.getVertices()).flattenAll();

The schema of vertexVectors without flattening would be as follows (flattenAll splits the vector-valued embedding column into separate double-valued columns):

| vertexId | embedding |
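Flattening can be pictured as expanding the single vector-valued embedding column into one double column per dimension before the frame is written as CSV. The sketch below builds such rows by hand from plain arrays; the embedding_<i> column names are hypothetical and need not match what flattenAll actually produces.

```java
import java.util.ArrayList;
import java.util.List;

public class FlattenSketch {

    /** CSV header: vertexId followed by one (hypothetical) column name per embedding dimension. */
    static String header(int dim) {
        StringBuilder sb = new StringBuilder("vertexId");
        for (int i = 0; i < dim; i++) {
            sb.append(",embedding_").append(i);
        }
        return sb.toString();
    }

    /** One CSV row: the vertex id followed by each embedding component as a separate value. */
    static String row(long vertexId, double[] embedding) {
        StringBuilder sb = new StringBuilder().append(vertexId);
        for (double v : embedding) {
            sb.append(',').append(v);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<String> lines = new ArrayList<>();
        lines.add(header(3));
        lines.add(row(2L, new double[] {0.1, -0.4, 0.25}));
        lines.forEach(System.out::println);
        // prints:
        // vertexId,embedding_0,embedding_1,embedding_2
        // 2,0.1,-0.4,0.25
    }
}
```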

Storing a Trained Model

Models can be stored either to the server file system, or to a database.

The following shows how to store a trained SupervisedGraphWise model to a specified file path:

pgx> model.export().file().path("<path>/<modelName>").store()
model.export().file().path("<path>/<modelName>").store();

When storing models in a database, they are stored as a row inside a model store table. The following shows how to store a trained SupervisedGraphWise model in a database, in a specific model store table:

pgx> model.export().db().
       username("user"). // DB user to use for storing the model
       password("password"). // password of the DB user
       jdbcUrl("jdbcUrl"). // JDBC URL to the DB
       modelstore("modelstoretablename"). // name of the model store table
       modelname("model"). // name to give to the model (primary key of model store table)
       description("a model description"). // description to store alongside the model
       store()
model.export().db()
       .username("user") // DB user to use for storing the model
       .password("password") // password of the DB user
       .jdbcUrl("jdbcUrl") // JDBC URL to the DB
       .modelstore("modelstoretablename") // name of the model store table
       .modelname("model") // name to give to the model (primary key of model store table)
       .description("a model description") // description to store alongside the model
       .store();

Loading a Pre-trained Model

Similarly to storing, models can be loaded from a file in the server file system, or from a database.

We can load a pre-trained SupervisedGraphWise model from a specified file path as follows:

pgx> var model = analyst.loadSupervisedGraphWiseModel().file().path("<path>/<modelName>").load()
SupervisedGraphWiseModel model = analyst.loadSupervisedGraphWiseModel().file().path("<path>/<modelName>").load();

We can load a pre-trained SupervisedGraphWise model from a model store table in a database as follows:

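pgx> var model = analyst.loadSupervisedGraphWiseModel().db().
     username("user"). // DB user
     password("password"). // password of the DB user
     jdbcUrl("jdbcUrl"). // JDBC URL to the DB
     modelstore("modelstoretablename"). // name of the model store table
     modelname("model"). // name of the stored model (primary key of model store table)
     load()
SupervisedGraphWiseModel model = analyst.loadSupervisedGraphWiseModel().db()
    .username("user") // DB user
    .password("password") // password of the DB user
    .jdbcUrl("jdbcUrl") // JDBC URL to the DB
    .modelstore("modelstoretablename") // name of the model store table
    .modelname("model") // name of the stored model (primary key of model store table)
    .load();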

Destroying a Model

We can destroy a model as follows:

pgx> model.destroy()
model.destroy();