PyPGX MLlib

Graph machine learning tools for use with PGX.

class pypgx.api.mllib.CorruptionFunction(java_corruption_function)

Bases: object

Abstract corruption function that generates the corrupted subgraph for DGI (Deep Graph Infomax).

class pypgx.api.mllib.DeepWalkModel(java_deepwalk_model)

Bases: pypgx.api._pgx_context_manager.PgxContextManager

DeepWalk model object.

close()

Call destroy().

compute_similars(v, k)

Compute the top-k similar vertices for a given vertex.

Parameters
  • v – ID of the vertex, or list of vertex IDs, for which to compute the similar vertices

  • k – number of similar vertices to return
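For intuition, the ranking behind compute_similars can be sketched as a top-k nearest-neighbour search over the trained vectors. This plain-Python sketch assumes cosine similarity as the metric, takes a single vertex ID, excludes the query vertex from its own results, and uses a hypothetical in-memory dict of toy vectors in place of trained_vectors. It illustrates the technique and is not PyPGX's implementation.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_similar(vectors, v, k):
    """Return the k (id, score) pairs most similar to vertex v, excluding v itself.

    vectors: hypothetical mapping of vertex ID -> embedding (list of floats),
    standing in for the trained vertex vectors (assumption).
    """
    query = vectors[v]
    scores = [(other, cosine(query, vec)) for other, vec in vectors.items() if other != v]
    # Highest similarity first.
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:k]


# Toy embeddings for illustration only.
vectors = {
    "a": [1.0, 0.0],
    "b": [0.9, 0.1],
    "c": [0.0, 1.0],
    "d": [-1.0, 0.0],
}
print(top_k_similar(vectors, "a", 2))
```

With these toy vectors, "b" ranks first and "c" (orthogonal to "a", score 0.0) second.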

destroy()

Destroy this model object.

export()

Return a ModelStore object which can be used to save the model.

Returns

ModelStore object

fit(graph)

Fit the model on a graph.

Parameters

graph – Graph to fit on

store(path, key, overwrite=False)

Store the model in a file.

Parameters
  • path – Path where to store the model

  • key – Encryption key

  • overwrite – Whether to overwrite a pre-existing file

property trained_vectors

Get the trained vertex vectors for the current DeepWalk model.

Returns

PgxFrame object with the trained vertex vectors

class pypgx.api.mllib.GraphWiseConvLayerConfig(java_config, params)

Bases: object

GraphWise conv layer configuration.

class pypgx.api.mllib.GraphWiseDgiLayerConfig(java_config, params)

Bases: object

GraphWise DGI layer configuration.

get_corruption_function()

Return the corruption function.

get_discriminator()

Return the discriminator.

get_readout_function()

Return the readout function.

set_corruption_function(corruption_function)

Set the corruption function.

Parameters

corruption_function (CorruptionFunction) – the corruption function

Currently supported: PermutationCorruption

set_discriminator(discriminator)

Set the discriminator.

Parameters

discriminator (str) – the discriminator function

Currently supported: 'BILINEAR'

set_readout_function(readout_function)

Set the readout function.

Parameters

readout_function (str) – the readout function

Currently supported: 'MEAN'
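The names 'MEAN' and 'BILINEAR' correspond to the components of the original Deep Graph Infomax formulation: the readout summarizes all node embeddings into one summary vector s, and the discriminator scores a (node embedding h, summary s) pair as sigmoid(h^T W s). The plain-Python sketch below follows that formulation. Applying a sigmoid inside the readout follows the DGI paper and is an assumption about what PGX does internally; the toy weight matrix W is hypothetical (in practice it is learned).

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def mean_readout(embeddings):
    """Summary vector: element-wise mean of node embeddings, squashed by a sigmoid (as in the DGI paper)."""
    n = len(embeddings)
    dim = len(embeddings[0])
    return [sigmoid(sum(h[j] for h in embeddings) / n) for j in range(dim)]


def bilinear_discriminator(h, s, w):
    """Probability that (h, s) is a positive pair: sigmoid(h^T W s)."""
    # Compute W s first, then the dot product with h.
    ws = [sum(w[i][j] * s[j] for j in range(len(s))) for i in range(len(w))]
    return sigmoid(sum(h_i * ws_i for h_i, ws_i in zip(h, ws)))


# Toy example: two 2-d node embeddings and an identity weight matrix (hypothetical).
embeddings = [[1.0, 0.0], [0.0, 1.0]]
summary = mean_readout(embeddings)
identity = [[1.0, 0.0], [0.0, 1.0]]
print(bilinear_discriminator(embeddings[0], summary, identity))
```

During DGI training, the discriminator is pushed towards 1 for embeddings of the real graph and towards 0 for embeddings of the corrupted graph, both paired with the real graph's summary.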

class pypgx.api.mllib.GraphWiseModelConfig(java_graphwise_model_config)

Bases: object

GraphWise model configuration class.

get_conv_layer_configs()

Return a list of the convolutional layer configurations.

set_batch_size(batch_size)

Set the batch size.

Parameters

batch_size (int) – the batch size

set_edge_input_feature_dim(edge_input_feature_dim)

Set the edge input feature dimension.

Parameters

edge_input_feature_dim (int) – the edge input feature dimension

set_embedding_dim(embedding_dim)

Set the embedding dimension.

Parameters

embedding_dim (int) – the embedding dimension

set_fitted(fitted)

Set the fitted flag.

Parameters

fitted (bool) – the fitted flag

set_input_feature_dim(input_feature_dim)

Set the input feature dimension.

Parameters

input_feature_dim (int) – the input feature dimension

set_learning_rate(learning_rate)

Set the learning rate.

Parameters

learning_rate (float) – the learning rate

set_num_epochs(num_epochs)

Set the number of epochs.

Parameters

num_epochs (int) – the number of epochs

set_seed(seed)

Set the seed.

Parameters

seed (int) – the random seed

set_shuffle(shuffle)

Set the shuffling flag.

Parameters

shuffle (bool) – the shuffling flag

set_standarize(standardize)

Set the standardize flag.

Parameters

standardize (bool) – the standardize flag
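Assuming the standardize flag enables per-feature z-score standardization of the input features (an assumption; the entry above only names the flag), the transformation looks roughly like this. Using the population standard deviation and only centering constant columns are illustrative choices, not documented PGX behaviour.

```python
import statistics


def standardize(features):
    """Z-score each feature column: subtract the column mean, divide by its population std.

    Assumption: columns with zero variance are only centered, to avoid division by zero.
    """
    columns = list(zip(*features))
    means = [statistics.mean(col) for col in columns]
    stds = [statistics.pstdev(col) for col in columns]
    return [
        [(x - m) / s if s > 0 else x - m for x, m, s in zip(row, means, stds)]
        for row in features
    ]


# Toy feature matrix: one row per vertex, one column per feature.
print(standardize([[1.0, 10.0], [3.0, 10.0]]))
```

Standardizing keeps features with large numeric ranges from dominating the learned weights early in training.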

set_training_loss(training_loss)

Set the training loss.

Parameters

training_loss (float) – the training loss

class pypgx.api.mllib.GraphWisePredictionLayerConfig(java_config, params)

Bases: object

GraphWise prediction layer configuration.

class pypgx.api.mllib.PermutationCorruption(java_permutation_corruption)

Bases: pypgx.api.mllib._corruption_function.CorruptionFunction

Permutation corruption function that shuffles the nodes to generate the corrupted subgraph for DGI.
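In Deep Graph Infomax, the usual corruption keeps the graph structure fixed and shuffles the rows of the node-feature matrix, so each node receives some node's features (possibly its own). The plain-Python sketch below shows that idea. The function name permutation_corruption and the list-of-lists feature matrix are hypothetical and do not describe how PermutationCorruption is implemented in PGX.

```python
import random


def permutation_corruption(features, seed=None):
    """Return a corrupted copy of the node-feature matrix for DGI.

    Only the assignment of feature rows to nodes is shuffled; the edges
    (not represented here) would stay untouched. The input is not modified.
    """
    rng = random.Random(seed)
    order = list(range(len(features)))
    rng.shuffle(order)
    return [list(features[i]) for i in order]


# Four nodes with one feature each (toy data).
feats = [[0.0], [1.0], [2.0], [3.0]]
print(permutation_corruption(feats, seed=7))
```

Because the corrupted graph has the same feature distribution but scrambled feature-to-structure alignment, the discriminator must rely on how features relate to the graph structure to tell the two apart.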

class pypgx.api.mllib.Pg2vecModel(java_pg2vec_model)

Bases: pypgx.api._pgx_context_manager.PgxContextManager

Pg2Vec model object.

close()

Call destroy().

compute_similars(graphlet_id, k)

Compute the top-k similar graphlets for a list of input graphlets.

Parameters
  • graphlet_id – graphlet ID or iterable of graphlet IDs

  • k – number of similar graphlets to return

destroy()

Destroy this model object.

export()

Return a ModelStore object which can be used to save the model.

Returns

ModelStore object

fit(graph)

Fit the model on a graph.

Parameters

graph – Graph to fit on

infer_graphlet_vector(graph)

Infer a vector for the given graphlet.

Parameters

graph – graphlet for which to infer a vector

infer_graphlet_vector_batched(graph)

Infer vectors for a batch of graphlets.

Parameters

graph – graphlets for which to infer vectors, given as a single graph in which each graphlet has a distinct graphlet ID

store(path, key, overwrite=False)

Store the model in a file.

Parameters
  • path – Path where to store the model

  • key – Encryption key

  • overwrite – Whether to overwrite a pre-existing file

property trained_graphlet_vectors

Get the trained graphlet vectors for the current pg2vec model.

Returns

PgxFrame containing the trained graphlet vectors

class pypgx.api.mllib.SupervisedGraphWiseModel(java_graphwise_model, params={})

Bases: pypgx.api.mllib._graphwise_model.GraphWiseModel

SupervisedGraphWise model object.

evaluate_labels(graph, vertices)

Evaluate (macro averaged) classification performance statistics for the specified vertices.

Parameters
  • graph – the graph

  • vertices – the vertices to evaluate on

Returns

PgxFrame containing the metrics
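Macro averaging computes a metric separately for each class and then takes the unweighted mean, so rare classes count as much as frequent ones. The sketch below shows this for precision, recall and F1 on plain Python lists. Which metrics the returned PgxFrame actually contains is not specified above, so the dictionary keys here are illustrative assumptions.

```python
def macro_scores(y_true, y_pred):
    """Macro-averaged precision, recall and F1: per-class scores, then their unweighted mean."""
    classes = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        # Define a score as 0.0 when its denominator is empty.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(classes)
    return {
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }


print(macro_scores(["a", "a", "b", "b"], ["a", "b", "b", "b"]))
```

In this example, class "a" has precision 1.0 and recall 0.5, and class "b" has precision 2/3 and recall 1.0, so the macro recall is 0.75.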

export()

Return a ModelStore object which can be used to save the model.

Returns

ModelStore object

fit(graph)

Fit the model on a graph.

Parameters

graph – Graph to fit on

infer_and_get_explanation(graph, vertex)

Perform inference on the specified vertex and generate an explanation containing importance scores that indicate how much each property and each vertex in the computation graph contributes to the prediction.

Parameters
  • graph – the graph

  • vertex – the vertex

Returns

Explanation containing feature importance and vertex importance.

infer_embeddings(graph, vertices)

Infer the embeddings for the specified vertices.

Parameters
  • graph – the graph

  • vertices – the vertices to infer embeddings for

Returns

PgxFrame containing the embeddings for each vertex

infer_labels(graph, vertices)

Infer the labels for the specified vertices.

Parameters
  • graph – the graph

  • vertices – the vertices to infer labels for

Returns

PgxFrame containing the labels for each vertex

infer_logits(graph, vertices)

Infer the prediction logits for the specified vertices.

Parameters
  • graph – the graph

  • vertices – the vertices to infer logits for

Returns

PgxFrame containing the logits for each vertex
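Logits are unnormalized class scores. Assuming a single-label, multi-class setup, they can be turned into probabilities with a softmax, and the predicted label is the class with the largest logit. This sketch is illustrative only: the helper names and label list are hypothetical, and the assumption that infer_labels relates to infer_logits in exactly this way is not confirmed by the documentation above.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, labels):
    """Pick the label with the largest logit (softmax is monotonic, so the argmax is unchanged)."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return labels[best]


# Toy logits for one vertex over three hypothetical classes.
logits = [2.0, 1.0, 0.1]
print(softmax(logits), predict_label(logits, ["x", "y", "z"]))
```

Subtracting the maximum logit leaves the softmax result unchanged mathematically but avoids overflow in math.exp for large logits.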

store(path, key, overwrite=False)

Store the model in a file.

Parameters
  • path – Path where to store the model

  • key – Encryption key

  • overwrite – Whether to overwrite a pre-existing file

class pypgx.api.mllib.UnsupervisedGraphWiseModel(java_graphwise_model, params={})

Bases: pypgx.api.mllib._graphwise_model.GraphWiseModel

UnsupervisedGraphWise model object.

export()

Return a ModelStore object which can be used to save the model.

Returns

ModelStore object

fit(graph)

Fit the model on a graph.

Parameters

graph – Graph to fit on

store(path, key, overwrite=False)

Store the model in a file.

Parameters
  • path – Path where to store the model

  • key – Encryption key

  • overwrite – Whether to overwrite a pre-existing file