PyPGX MLlib
Graph machine learning tools for use with PGX.
- class pypgx.api.mllib.CorruptionFunction(java_corruption_function)
Bases:
object
Abstract corruption function that generates the corrupted subgraph for DGI (Deep Graph Infomax).
- class pypgx.api.mllib.DeepWalkModel(java_deepwalk_model)
Bases:
pypgx.api._pgx_context_manager.PgxContextManager
DeepWalk model object.
- close()
Call destroy().
- compute_similars(v, k)
Compute the top-k similar vertices for a given vertex or list of vertices.
- Parameters
v – ID of the vertex, or list of vertex IDs, for which to compute the similar vertices
k – number of similar vertices to return
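To illustrate what a top-k similarity query over trained vectors computes, here is a minimal pure-Python sketch. It assumes cosine similarity as the measure and a plain dict mapping vertex IDs to vectors; both are assumptions for illustration, not a statement of how PGX implements compute_similars. The helper name `top_k_similar` is hypothetical.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k_similar(vectors, v, k):
    """Hypothetical helper: the k vertices most similar to v, or to each ID in a list.

    vectors: dict mapping vertex ID -> embedding vector (assumed format).
    """
    if isinstance(v, (list, tuple)):
        # Mirror the documented behaviour of accepting one ID or a list of IDs
        return {vid: top_k_similar(vectors, vid, k) for vid in v}
    query = vectors[v]
    # This sketch excludes the query vertex itself from its own results
    scored = [(other, cosine(query, vec)) for other, vec in vectors.items() if other != v]
    # Highest similarity first; ties broken by ID for deterministic output
    scored.sort(key=lambda pair: (-pair[1], pair[0]))
    return scored[:k]
```

For example, with `vectors = {"a": [1.0, 0.0], "b": [0.9, 0.1], "c": [0.0, 1.0]}`, `top_k_similar(vectors, "a", 1)` returns a single pair whose vertex is `"b"`.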
- destroy()
Destroy this model object.
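close() simply calls destroy(), and the class derives from PgxContextManager, which suggests the model can be used in a `with` statement so that it is destroyed on exit. The sketch below mimics that pattern with a hypothetical stand-in class, `ContextManagedModel`; whether PgxContextManager's `__exit__` calls close() exactly this way is an assumption.

```python
class ContextManagedModel:
    """Hypothetical stand-in mirroring the documented close() -> destroy() relationship."""

    def __init__(self):
        self.destroyed = False

    def destroy(self):
        # In PyPGX this would release the model; here we only record that it happened
        self.destroyed = True

    def close(self):
        # Documented behaviour: close() calls destroy()
        self.destroy()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Assumed behaviour of a context-manager base class: clean up on exit
        self.close()
        return False  # do not suppress exceptions raised inside the block
```

Used as `with ContextManagedModel() as model: ...`, the object is destroyed as soon as the block exits, even if an exception is raised.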
- export()
Return a ModelStore object which can be used to save the model.
- Returns
ModelStore object
- fit(graph)
Fit the model on a graph.
- Parameters
graph – Graph to fit on
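DeepWalk learns vertex vectors by generating truncated random walks over the graph and training a skip-gram model on them. The sketch below shows only the walk-generation step, in plain Python over an adjacency dict; the parameter names `walk_length`, `walks_per_vertex` and `seed` are illustrative and are not claimed to match PGX's builder options.

```python
import random


def random_walks(adjacency, walk_length, walks_per_vertex, seed=None):
    """Generate truncated uniform random walks, the corpus for DeepWalk's skip-gram step.

    adjacency: dict mapping vertex -> list of neighbour vertices (assumed format).
    """
    rng = random.Random(seed)  # seeded RNG for reproducible walks
    walks = []
    for _ in range(walks_per_vertex):
        vertices = list(adjacency)
        rng.shuffle(vertices)  # visit start vertices in a random order each pass
        for start in vertices:
            walk = [start]
            while len(walk) < walk_length:
                neighbours = adjacency[walk[-1]]
                if not neighbours:  # dead end: stop this walk early
                    break
                walk.append(rng.choice(neighbours))
            walks.append(walk)
    return walks
```

Each walk is then treated like a "sentence" of vertex IDs when training the skip-gram model.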
- store(path, key, overwrite=False)
Store the model in a file.
- Parameters
path – Path at which to store the model
key – Encryption key
overwrite – Whether to overwrite a pre-existing file
- property trained_vectors
Get the trained vertex vectors for the current DeepWalk model.
- Returns
PgxFrame object with the trained vertex vectors
- class pypgx.api.mllib.GraphWiseConvLayerConfig(java_config, params)
Bases:
object
GraphWise conv layer configuration.
- class pypgx.api.mllib.GraphWiseDgiLayerConfig(java_config, params)
Bases:
object
GraphWise DGI layer configuration.
- get_corruption_function()
Return the corruption function.
- get_discriminator()
Return the discriminator.
- get_readout_function()
Return the readout function.
- set_corruption_function(corruption_function)
Set the corruption function.
- Parameters
corruption_function (CorruptionFunction) – the corruption function. Currently supported: PermutationCorruption
- set_discriminator(discriminator)
Set the discriminator.
- Parameters
discriminator (str) – the discriminator function. Currently supported: ‘BILINEAR’
- set_readout_function(readout_function)
Set the readout function.
- Parameters
readout_function (str) – the readout function. Currently supported: ‘MEAN’
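In Deep Graph Infomax, the readout function summarizes all node embeddings into one summary vector s, and the discriminator scores (node embedding, summary) pairs. The plain-Python sketch below follows the formulation from the DGI paper: a sigmoid of the mean of node embeddings for a mean readout, and σ(hᵀ W s) for a bilinear discriminator. Whether PGX's 'MEAN' and 'BILINEAR' options apply the sigmoid in exactly this way is an assumption; the function names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def mean_readout(embeddings):
    """Summary vector s = sigmoid(mean of the node embeddings), element-wise."""
    n = len(embeddings)
    dim = len(embeddings[0])
    return [sigmoid(sum(h[j] for h in embeddings) / n) for j in range(dim)]


def bilinear_discriminator(h, weight, s):
    """Score D(h, s) = sigmoid(h^T W s) for one node embedding h and summary s.

    weight: learnable matrix W with len(h) rows and len(s) columns.
    """
    ws = [sum(weight[i][j] * s[j] for j in range(len(s))) for i in range(len(weight))]
    return sigmoid(sum(h_i * ws_i for h_i, ws_i in zip(h, ws)))
```

A score close to 1 means the discriminator believes the node embedding comes from the real graph rather than the corrupted one.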
- class pypgx.api.mllib.GraphWiseModelConfig(java_graphwise_model_config)
Bases:
object
GraphWise model configuration class.
- get_conv_layer_configs()
Return a list of conv layer configs.
- set_batch_size(batch_size)
Set the batch size.
- Parameters
batch_size (int) – number of samples per batch
- set_edge_input_feature_dim(edge_input_feature_dim)
Set the edge input feature dimension.
- Parameters
edge_input_feature_dim (int) – dimension of the edge input features
- set_embedding_dim(embedding_dim)
Set the embedding dimension.
- Parameters
embedding_dim (int) – dimension of the output embeddings
- set_fitted(fitted)
Set the fitted flag.
- Parameters
fitted (bool) – whether the model has been fitted
- set_input_feature_dim(input_feature_dim)
Set the input feature dimension.
- Parameters
input_feature_dim (int) – dimension of the input features
- set_learning_rate(learning_rate)
Set the learning rate.
- Parameters
learning_rate (float) – learning rate used during training
- set_num_epochs(num_epochs)
Set the number of epochs.
- Parameters
num_epochs (int) – number of training epochs
- set_seed(seed)
Set the seed.
- Parameters
seed (int) – random seed
- set_shuffle(shuffle)
Set the shuffling flag.
- Parameters
shuffle (bool) – whether to shuffle
- set_standarize(standardize)
Set the standardize flag.
- Parameters
standardize (bool) – whether to standardize
- set_training_loss(training_loss)
Set the training loss.
- Parameters
training_loss (float) – training loss value
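To make the interaction between batch_size, shuffle, seed and num_epochs concrete, here is a generic mini-batch schedule in plain Python. It shows what these hyperparameters conventionally mean in mini-batch training; it is not a description of PGX's internal training loop, and the function name is hypothetical.

```python
import random


def minibatch_schedule(items, batch_size, num_epochs, shuffle=True, seed=None):
    """Yield (epoch, batch) pairs: every item appears once per epoch, in batches of batch_size."""
    rng = random.Random(seed)  # a fixed seed makes the shuffling reproducible
    order = list(items)
    for epoch in range(num_epochs):
        if shuffle:
            rng.shuffle(order)  # a new order for each epoch
        for start in range(0, len(order), batch_size):
            # The last batch of an epoch may be smaller than batch_size
            yield epoch, order[start:start + batch_size]
```

With five items, `batch_size=2` and `num_epochs=3`, this yields three batches per epoch (of sizes 2, 2 and 1), and the same seed always reproduces the same batches.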
- class pypgx.api.mllib.GraphWisePredictionLayerConfig(java_config, params)
Bases:
object
GraphWise prediction layer configuration.
- class pypgx.api.mllib.PermutationCorruption(java_permutation_corruption)
Bases:
pypgx.api.mllib._corruption_function.CorruptionFunction
Permutation corruption function that shuffles the nodes to generate the corrupted subgraph for DGI.
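In DGI, a common corruption keeps the graph structure fixed and randomly permutes the node features among the nodes, producing a "negative" graph for the discriminator. The plain-Python sketch below illustrates that idea on a feature dict; it illustrates the technique and is not PermutationCorruption's implementation, and the helper name is hypothetical.

```python
import random


def permutation_corruption(features, seed=None):
    """Return corrupted features: same vertices (and edges), feature rows shuffled among vertices.

    features: dict mapping vertex ID -> feature vector (assumed format).
    The input dict is not modified.
    """
    rng = random.Random(seed)
    vertices = list(features)
    rows = [features[v] for v in vertices]
    rng.shuffle(rows)  # permute which vertex receives which feature row
    return dict(zip(vertices, rows))
```

Because only the feature assignment changes, the corrupted graph has the same topology and the same multiset of feature vectors as the original.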
- class pypgx.api.mllib.Pg2vecModel(java_pg2vec_model)
Bases:
pypgx.api._pgx_context_manager.PgxContextManager
Pg2Vec model object.
- close()
Call destroy().
- compute_similars(graphlet_id, k)
Compute the top-k similar graphlets for one or more input graphlets.
- Parameters
graphlet_id – graphlet ID or iterable of graphlet IDs
k – number of similar graphlets to return
- destroy()
Destroy this model object.
- export()
Return a ModelStore object which can be used to save the model.
- Returns
ModelStore object
- fit(graph)
Fit the model on a graph.
- Parameters
graph – Graph to fit on
- infer_graphlet_vector(graph)
Infer a vector for a single graphlet.
- Parameters
graph – graphlet for which to infer a vector
- infer_graphlet_vector_batched(graph)
Infer vectors for several graphlets at once.
- Parameters
graph – graphlets for which to infer vectors, given as a single graph in which each graphlet has a different graphlet ID
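infer_graphlet_vector_batched takes several graphlets packed into one graph and distinguished by their graphlet ID. The sketch below illustrates that packing convention in plain Python by splitting vertices back into per-graphlet groups. The dict-based representation and the function name are hypothetical and do not reflect how PGX reads the graphlet ID from the graph.

```python
from collections import defaultdict


def split_by_graphlet(vertex_graphlet_ids):
    """Group vertex IDs by graphlet ID.

    vertex_graphlet_ids: dict mapping vertex ID -> graphlet ID (hypothetical format).
    Returns a dict mapping graphlet ID -> sorted list of that graphlet's vertex IDs.
    """
    groups = defaultdict(list)
    for vertex, graphlet_id in vertex_graphlet_ids.items():
        groups[graphlet_id].append(vertex)
    return {gid: sorted(vs) for gid, vs in groups.items()}
```

Each resulting group corresponds to one graphlet, and therefore to one inferred vector.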
- store(path, key, overwrite=False)
Store the model in a file.
- Parameters
path – Path at which to store the model
key – Encryption key
overwrite – Whether to overwrite a pre-existing file
- property trained_graphlet_vectors
Get the trained graphlet vectors for the current pg2vec model.
- Returns
PgxFrame containing the trained graphlet vectors
- class pypgx.api.mllib.SupervisedGraphWiseModel(java_graphwise_model, params={})
Bases:
pypgx.api.mllib._graphwise_model.GraphWiseModel
SupervisedGraphWise model object.
- evaluate_labels(graph, vertices)
Evaluate (macro averaged) classification performance statistics for the specified vertices.
- Parameters
graph – the graph
vertices – the vertices to evaluate on
- Returns
PgxFrame containing the metrics
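"Macro averaged" means each metric is computed separately for every class and then averaged with equal weight per class, so rare classes count as much as frequent ones. The stdlib sketch below shows that computation; which metrics and column names appear in the returned PgxFrame is not specified here, so the keys of the returned dict are hypothetical.

```python
def macro_metrics(y_true, y_pred):
    """Macro-averaged precision, recall and F1 (per class, then unweighted mean), plus accuracy."""
    labels = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for label in labels:
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(labels)
    accuracy = sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)
    return {
        "accuracy": accuracy,
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }
```

For `y_true = ["x", "x", "y", "y"]` and `y_pred = ["x", "y", "y", "y"]`, class "x" has F1 2/3 and class "y" has F1 0.8, so the macro F1 is their plain average, about 0.733.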
- export()
Return a ModelStore object which can be used to save the model.
- Returns
ModelStore object
- fit(graph)
Fit the model on a graph.
- Parameters
graph – Graph to fit on
- infer_and_get_explanation(graph, vertex)
Perform inference on the specified vertex and generate an explanation containing importance scores, with respect to the prediction, for each property and each vertex in the computation graph.
- Parameters
graph – the graph
vertex – the vertex
- Returns
Explanation containing feature importances and vertex importances.
- infer_embeddings(graph, vertices)
Infer the embeddings for the specified vertices.
- Parameters
graph – the graph
vertices – the vertices to infer embeddings for
- Returns
PgxFrame containing the embeddings for each vertex
- infer_labels(graph, vertices)
Infer the labels for the specified vertices.
- Parameters
graph – the graph
vertices – the vertices to infer labels for
- Returns
PgxFrame containing the labels for each vertex
- infer_logits(graph, vertices)
Infer the prediction logits for the specified vertices.
- Parameters
graph – the graph
vertices – the vertices to infer logits for
- Returns
PgxFrame containing the logits for each vertex
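Logits are raw, unnormalized class scores. For single-label classification, the usual way to turn them into a predicted label is a softmax followed by an argmax, sketched below in plain Python. That infer_labels derives its output exactly this way (and how multi-label tasks are handled) is an assumption; the function and class names are illustrative.

```python
import math


def softmax(logits):
    """Convert logits to probabilities; subtracting the max keeps exp() numerically stable."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, class_names):
    """Return the most probable class (which is also the class with the highest logit)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]
```

Since softmax is monotonic, the argmax over probabilities equals the argmax over logits; the softmax is only needed when a calibrated probability is wanted alongside the label.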
- store(path, key, overwrite=False)
Store the model in a file.
- Parameters
path – Path at which to store the model
key – Encryption key
overwrite – Whether to overwrite a pre-existing file
- class pypgx.api.mllib.UnsupervisedGraphWiseModel(java_graphwise_model, params={})
Bases:
pypgx.api.mllib._graphwise_model.GraphWiseModel
UnsupervisedGraphWise model object.
- export()
Return a ModelStore object which can be used to save the model.
- Returns
ModelStore object
- fit(graph)
Fit the model on a graph.
- Parameters
graph – Graph to fit on
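Unsupervised GraphWise training is configured through GraphWiseDgiLayerConfig, which suggests a Deep Graph Infomax objective. DGI's loss is a binary cross-entropy that pushes discriminator scores towards 1 for (real node, summary) pairs and towards 0 for (corrupted node, summary) pairs. The plain-Python sketch below computes that loss from precomputed scores; it shows the objective only, not PGX's training code, and the function name is hypothetical.

```python
import math


def dgi_loss(positive_scores, negative_scores, eps=1e-12):
    """Binary cross-entropy over discriminator outputs in (0, 1).

    positive_scores: D(h_i, s) for nodes of the real graph (target 1).
    negative_scores: D(h~_j, s) for nodes of the corrupted graph (target 0).
    eps guards against log(0).
    """
    pos = sum(math.log(max(p, eps)) for p in positive_scores)
    neg = sum(math.log(max(1.0 - q, eps)) for q in negative_scores)
    return -(pos + neg) / (len(positive_scores) + len(negative_scores))
```

A discriminator that cannot tell real from corrupted nodes (all scores 0.5) gets a loss of ln 2; the loss falls towards 0 as it separates the two cases.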
- store(path, key, overwrite=False)
Store the model in a file.
- Parameters
path – Path at which to store the model
key – Encryption key
overwrite – Whether to overwrite a pre-existing file