Apply Machine Learning on a Graph

You can use machine learning on your property graph data in Graph Studio using the PGX machine learning library.

The following are a few of the supported machine learning algorithms:

  • DeepWalk
  • Supervised GraphWise
  • Unsupervised GraphWise
  • Pg2vec

See Using the Machine Learning Library (PgxML) for Graphs in Oracle Database Graph Developer's Guide for Property Graph for more information.

You can run machine learning algorithms in a notebook paragraph using the Java (PGX) or Python (PGX) interpreter. Where a step shows two code samples, the first is for Java (PGX) and the second is for Python (PGX).

For example, the following steps describe how to use the DeepWalk model on a graph in a notebook paragraph.

  1. Load the required graph into memory and reference the graph in the notebook.
    See Reference Graphs in Notebook Paragraphs for more information.
  2. Build a DeepWalk model using customized hyper-parameters.
    Java (PGX):
    import oracle.pgx.api.mllib.DeepWalkModel
    var model = session.createAnalyst().deepWalkModelBuilder().
                    setMinWordFrequency(1).
                    setBatchSize(512).
                    setNumEpochs(1).
                    setLayerSize(100).
                    setLearningRate(0.05).
                    setMinLearningRate(0.0001).
                    setWindowSize(3).
                    setWalksPerVertex(6).
                    setWalkLength(4).
                    setSampleRate(0.00001).
                    setNegativeSample(2).
                    setValidationFraction(0.01).
                    build()
    Python (PGX):
    model = analyst.deepwalk_builder(min_word_frequency=1,
                                     batch_size=512,
                                     num_epochs=1,
                                     layer_size=100,
                                     learning_rate=0.05,
                                     min_learning_rate=0.0001,
                                     window_size=3,
                                     walks_per_vertex=6,
                                     walk_length=4,
                                     sample_rate=0.00001,
                                     negative_sample=2,
                                     validation_fraction=0.01)
  3. Train the DeepWalk model on the graph data.
    Java (PGX):
    model.fit(g)
    Python (PGX):
    model.fit(g)
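
To give a feel for what the walks_per_vertex and walk_length hyper-parameters control, the following is a minimal plain-Python sketch of the uniform random-walk sampling that DeepWalk-style models train on. It is not the PGX implementation and does not call the PGX API: generate_walks and toy_graph are hypothetical names, and the sketch assumes that walk_length counts the vertices in a walk.

```python
import random


def generate_walks(adjacency, walks_per_vertex, walk_length, seed=0):
    """Sample uniform random walks from every vertex.

    Assumption: walk_length is the number of vertices in a walk.
    A walk stops early only if it reaches a vertex with no neighbors.
    """
    rng = random.Random(seed)
    walks = []
    for start in adjacency:
        for _ in range(walks_per_vertex):
            walk = [start]
            while len(walk) < walk_length:
                neighbors = adjacency[walk[-1]]
                if not neighbors:
                    break
                walk.append(rng.choice(neighbors))
            walks.append(walk)
    return walks


# Hypothetical toy graph, stored as adjacency lists.
toy_graph = {
    "a": ["b", "c"],
    "b": ["a", "c"],
    "c": ["a", "b", "d"],
    "d": ["c"],
}

# Same values as in the builder example above.
walks = generate_walks(toy_graph, walks_per_vertex=6, walk_length=4)
print(len(walks))   # one entry per (vertex, walk) pair
print(walks[0])     # a single walk starting at vertex "a"
```

In this sketch the number of walks grows as the number of vertices times walks_per_vertex, so raising either hyper-parameter increases the amount of training data and, most likely, the training time.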

You can now perform one or more of the following operations on the DeepWalk model:

  1. Compute the loss value on the data.
    Java (PGX):
    var loss = model.getLoss()
    Python (PGX):
    loss = model.loss
  2. Fetch similar vertices for a list of input vertices.
    Java (PGX):
    import oracle.pgx.api.frames.*
    List<java.lang.Object> vertices = Arrays.asList("3244407212344026742", "371586706748522153")
    var batchSimilars = model.computeSimilars(vertices, 2)
    batchSimilars.print(out, 10, 0)
    Python (PGX):
    vertices = ["3244407212344026742", "371586706748522153"]
    batch_similars = model.compute_similars(vertices, 2)
    batch_similars.print()
    The output appears in the following format:
    +----------------------------------------------------------------+
    | srcVertex           | dstVertex           | similarity         |
    +----------------------------------------------------------------+
    | 3244407212344026742 | 3244407212344026742 | 1.0                |
    | 3244407212344026742 | 3510061098087750671 | 0.2863036096096039 |
    | 371586706748522153  | 371586706748522153  | 1.0                |
    | 371586706748522153  | 2128822953047004384 | 0.3220503330230713 |
    +----------------------------------------------------------------+
  3. Retrieve and store all trained vertex vectors to the database.
    Java (PGX):
    var vertexVectors = model.getTrainedVertexVectors().flattenAll()
    vertexVectors.write().db().name("deepwalkframe").tablename("vertexVectors").overwrite(true).store()
    Python (PGX):
    vertex_vectors = model.trained_vectors.flatten_all()
    vertex_vectors.write().db().table_name("vertex_vectors").overwrite(True).store()
    If you are using an Always Free Autonomous Database instance (that is, one with only 1 OCPU and 20 GB of storage), you must also limit the write to a single connection when storing the PgxFrame to the table in a Java (PGX) notebook paragraph. For example, invoke write() as shown:
    vertexVectors.write().db().name("deepwalkframe").tablename("vertexVectors").overwrite(true).connections(1).store()
    The columns in the database table for the flattened vectors appear as:
    +------------------------------+---------------------+---------------------+
    | vertexid                     | embedding_0         | embedding_1         |
    +------------------------------+---------------------+---------------------+
  4. Store the trained model to the database.
    Java (PGX):
    model.export().db().modelstore("bank_model").modelname("model").description("DeepWalk Model for Bank data").store()
    Python (PGX):
    model.export().db(model_store="bank_model",
                      model_name="model",
                      model_description="DeepWalk Model for Bank data",
                      overwrite=True)
    The model is stored as a row in the model store table.
  5. Load a pre-trained model from the database.
    Java (PGX):
    var model = session.createAnalyst().loadDeepWalkModel().db().modelstore("bank_model").modelname("model").load()
    Python (PGX):
    model = analyst.get_deepwalk_model_loader().db(model_store="bank_model",
                                                   model_name="model")
  6. Destroy a DeepWalk model.
    Java (PGX):
    model.destroy()
    Python (PGX):
    model.destroy()
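
Two of the operations above, fetching similar vertices and flattening vertex vectors, can be illustrated with a small plain-Python sketch. It does not use the PGX API, and the names cosine_similarity, top_k_similar, flatten_vectors, and toy_embeddings are hypothetical. It assumes that similarity is measured as cosine similarity, which is consistent with each vertex scoring 1.0 against itself in the sample output but is not stated in this section.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def top_k_similar(embeddings, source, k):
    """Return the k (vertex, similarity) pairs most similar to `source`.

    The source vertex itself is included, as in the sample output above.
    """
    scored = [
        (other, cosine_similarity(embeddings[source], vector))
        for other, vector in embeddings.items()
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


def flatten_vectors(embeddings):
    """Turn each vector into one row: vertexid, embedding_0, embedding_1, ..."""
    rows = []
    for vertex_id, vector in embeddings.items():
        row = {"vertexid": vertex_id}
        for index, value in enumerate(vector):
            row[f"embedding_{index}"] = value
        rows.append(row)
    return rows


# Hypothetical two-dimensional embeddings; a real model trained with
# layer_size=100 would presumably yield 100 values per vertex.
toy_embeddings = {
    "v1": [1.0, 0.0],
    "v2": [0.8, 0.6],
    "v3": [0.0, 1.0],
}

for vertex, score in top_k_similar(toy_embeddings, "v1", 2):
    print("v1", vertex, score)

print(flatten_vectors(toy_embeddings)[0])
```

Under these assumptions, the first result for any source vertex is the vertex itself, and the flattened rows have one column per embedding dimension, matching the embedding_0, embedding_1 column pattern shown in step 3.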