Class EdgeWiseModel<Config extends EdgeWiseModelConfig, Metadata extends oracle.pgx.api.internal.mllib.EdgeWiseModelMetadata<Config>, ModelType extends EdgeWiseModel<Config, Metadata, ModelType>>

  • All Implemented Interfaces:
    java.lang.AutoCloseable
    Direct Known Subclasses:
    SupervisedEdgeWiseModel, UnsupervisedEdgeWiseModel

    public abstract class EdgeWiseModel<Config extends EdgeWiseModelConfig, Metadata extends oracle.pgx.api.internal.mllib.EdgeWiseModelMetadata<Config>, ModelType extends EdgeWiseModel<Config, Metadata, ModelType>>
    extends Model<ModelType>
    Base class for EdgeWise models.
    Since:
    23.1
    • Constructor Detail

      • EdgeWiseModel

        public EdgeWiseModel(PgxSession session,
                             oracle.pgx.api.internal.Core core,
                             java.util.function.Supplier<java.lang.String> keystorePathSupplier,
                             java.util.function.Supplier<char[]> keystorePasswordSupplier,
                             Metadata modelMetadata,
                             java.util.function.BiFunction<PgxSession, oracle.pgx.api.internal.Graph, PgxGraph> graphConstructor)
        This constructor should not be used to obtain a model. Use SupervisedEdgeWiseModelBuilder (or UnsupervisedEdgeWiseModelBuilder for an unsupervised model) instead.
        Parameters:
        session - PgxSession to which the model is connected
        core - Core to which the model is connected
        modelMetadata - metadata describing the hyper-parameters of the EdgeWise model
        Since:
        23.1
    • Method Detail

      • destroy

        public void destroy()
                     throws java.util.concurrent.ExecutionException,
                            java.lang.InterruptedException
        Blocking version of destroyAsync(). Calls destroyAsync() and waits for the returned PgxFuture to complete.
        Throws:
        java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
        java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
        Since:
        23.1
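The relationship between `destroy()` and `destroyAsync()` follows a common blocking-wrapper pattern. The sketch below models it with a hypothetical `ToyModel` built on `java.util.concurrent.CompletableFuture`, which is assumed here to stand in for `PgxFuture`. It is not the PGX implementation, but it shows how a failure during asynchronous execution reaches the caller nested inside an `ExecutionException`.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

// Hypothetical stand-in for a server-side model; not the PGX implementation.
class ToyModel {
    private final boolean failOnDestroy;
    private volatile boolean destroyed = false;

    ToyModel(boolean failOnDestroy) {
        this.failOnDestroy = failOnDestroy;
    }

    // Asynchronous variant: returns a future immediately.
    CompletableFuture<Void> destroyAsync() {
        return CompletableFuture.runAsync(() -> {
            if (failOnDestroy) {
                throw new IllegalStateException("simulated failure while destroying");
            }
            destroyed = true;
        });
    }

    // Blocking variant: waits for the future. A failure surfaces as an
    // ExecutionException whose cause is the original exception.
    void destroy() throws ExecutionException, InterruptedException {
        destroyAsync().get();
    }

    boolean isDestroyed() {
        return destroyed;
    }
}

class DestroyDemo {
    public static void main(String[] args) throws InterruptedException {
        try {
            new ToyModel(true).destroy();
        } catch (ExecutionException e) {
            // The actual exception is nested inside the ExecutionException.
            System.out.println("nested cause: " + e.getCause());
        }
    }
}
```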
      • getNumEpochs

        public int getNumEpochs()
        Gets the number of epochs to train the model
        Returns:
        number of epochs to train the model
        Since:
        23.1
      • getLearningRate

        public double getLearningRate()
        Gets the initial learning rate
        Returns:
        initial learning rate
        Since:
        23.1
      • getBatchSize

        public int getBatchSize()
        Gets the batch size
        Returns:
        batch size
        Since:
        23.1
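How the number of epochs and the batch size interact can be worked out with simple arithmetic. The sketch below assumes, as a hypothesis rather than documented PGX behavior, that each epoch visits every training edge once in mini-batches of `getBatchSize()` edges, with a smaller final batch. `TrainingSchedule` and its methods are hypothetical.

```java
// Hypothetical helper: assumes each epoch covers every training edge once,
// split into mini-batches of batchSize edges (the last batch may be smaller).
final class TrainingSchedule {
    private TrainingSchedule() {}

    static long batchesPerEpoch(long numTrainingEdges, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        // Ceiling division: a partial final batch still counts as one step.
        return (numTrainingEdges + batchSize - 1) / batchSize;
    }

    static long totalSteps(long numTrainingEdges, int batchSize, int numEpochs) {
        return batchesPerEpoch(numTrainingEdges, batchSize) * numEpochs;
    }

    public static void main(String[] args) {
        // Example: 10,000 training edges, batch size 128, 3 epochs.
        System.out.println(totalSteps(10_000, 128, 3));
    }
}
```

With these example numbers, each epoch has 79 batches (78 full batches plus one partial batch), giving 237 optimizer steps in total.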
      • getEmbeddingDim

        public int getEmbeddingDim()
        Gets the dimension of the embeddings
        Returns:
        embedding dimension
        Since:
        23.1
      • getSeed

        public java.lang.Integer getSeed()
        Gets the random seed
        Returns:
        random seed
        Since:
        23.1
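`getSeed()` returns a boxed `java.lang.Integer` rather than an `int`. A plausible reading, which is an assumption here, is that the seed may be unset and returned as `null`, so callers should avoid auto-unboxing it. A minimal sketch with a hypothetical `rngFor` helper:

```java
import java.util.Random;

final class SeedHandling {
    private SeedHandling() {}

    // Hypothetical helper: a null seed is treated as "no fixed seed"
    // instead of being auto-unboxed, which would throw NullPointerException.
    static Random rngFor(Integer seed) {
        return seed != null ? new Random(seed) : new Random();
    }

    public static void main(String[] args) {
        Integer unset = null; // assumed value when no seed was configured
        System.out.println(rngFor(unset).nextInt(10));
        System.out.println(rngFor(42).nextInt(10)); // reproducible across runs
    }
}
```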
      • getConvLayerConfigs

        public GraphWiseBaseConvLayerConfig[] getConvLayerConfigs()
        Gets the configuration objects for the convolutional layers
        Returns:
        configurations
        Since:
        23.1
      • getVertexInputPropertyNames

        public java.util.List<java.lang.String> getVertexInputPropertyNames()
        Gets the names of the vertex properties used as input features
        Returns:
        vertex input feature names
        Since:
        23.1
      • getEdgeInputPropertyNames

        public java.util.List<java.lang.String> getEdgeInputPropertyNames()
        Gets the names of the edge properties used as input features
        Returns:
        edge input feature names
        Since:
        23.1
      • isFitted

        public boolean isFitted()
        Checks if the model is fitted
        Returns:
        true if the model is fitted
        Since:
        23.1
      • getTrainingLoss

        public double getTrainingLoss()
        Gets the final training loss
        Returns:
        training loss
        Since:
        23.1
      • getInputFeatureDim

        public int getInputFeatureDim()
        Gets the input feature dimension, that is, the dimension of all the input vertex properties when concatenated
        Returns:
        input feature dimension
        Since:
        23.1
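Because the input feature dimension is the dimension of all input vertex properties when concatenated, it is the sum of the per-property widths. The sketch below assumes a scalar property contributes 1 and a vector property of length k contributes k; the property names and the `concatenatedDim` helper are hypothetical. The same arithmetic applies to `getEdgeInputFeatureDim()` for edge properties.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class FeatureDims {
    private FeatureDims() {}

    // Hypothetical: each input property contributes its own width
    // (1 for a scalar property, k for a vector property of length k).
    static int concatenatedDim(Map<String, Integer> propertyDims) {
        int total = 0;
        for (int dim : propertyDims.values()) {
            if (dim <= 0) {
                throw new IllegalArgumentException("dimension must be positive");
            }
            total += dim;
        }
        return total;
    }

    public static void main(String[] args) {
        Map<String, Integer> vertexProps = new LinkedHashMap<>();
        vertexProps.put("features", 16); // hypothetical vector property of length 16
        vertexProps.put("age", 1);       // hypothetical scalar property
        System.out.println(concatenatedDim(vertexProps));
    }
}
```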
      • getEdgeInputFeatureDim

        public int getEdgeInputFeatureDim()
        Gets the edge input feature dimension, that is, the dimension of all the input edge properties when concatenated
        Returns:
        edge input feature dimension
        Since:
        23.1
      • getEdgeCombinationMethod

        public EdgeCombinationMethod getEdgeCombinationMethod()
        Gets the edge combination method used to produce the edge embeddings
        Returns:
        edge combination method
        Since:
        23.1
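One common way to build an edge embedding is to concatenate the source vertex embedding, the destination vertex embedding and, optionally, the edge's own input features. The sketch below illustrates that idea with a hypothetical `ConcatCombination` class; it is an assumption about what a concatenation-style combination does, not the PGX `EdgeCombinationMethod` implementation.

```java
import java.util.Arrays;

// Hypothetical concatenation-style edge combination; not the PGX implementation.
final class ConcatCombination {
    private final boolean useSource;
    private final boolean useDestination;
    private final boolean useEdgeFeatures;

    ConcatCombination(boolean useSource, boolean useDestination, boolean useEdgeFeatures) {
        if (!useSource && !useDestination && !useEdgeFeatures) {
            throw new IllegalArgumentException("at least one component must be enabled");
        }
        this.useSource = useSource;
        this.useDestination = useDestination;
        this.useEdgeFeatures = useEdgeFeatures;
    }

    // Concatenates the enabled parts in the order: source, destination, edge features.
    double[] combine(double[] source, double[] destination, double[] edgeFeatures) {
        int dim = (useSource ? source.length : 0)
                + (useDestination ? destination.length : 0)
                + (useEdgeFeatures ? edgeFeatures.length : 0);
        double[] out = new double[dim];
        int offset = 0;
        offset = append(useSource, source, out, offset);
        offset = append(useDestination, destination, out, offset);
        append(useEdgeFeatures, edgeFeatures, out, offset);
        return out;
    }

    // Copies part into out at offset if enabled; returns the next free offset.
    private static int append(boolean enabled, double[] part, double[] out, int offset) {
        if (!enabled) {
            return offset;
        }
        System.arraycopy(part, 0, out, offset, part.length);
        return offset + part.length;
    }

    public static void main(String[] args) {
        double[] src = {0.1, 0.2};
        double[] dst = {0.3, 0.4};
        double[] edge = {1.0};
        System.out.println(Arrays.toString(new ConcatCombination(true, true, true).combine(src, dst, edge)));
    }
}
```

Under this assumption, the resulting edge embedding width is the sum of the widths of whichever parts are enabled.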
      • getConfig

        public Config getConfig()
        Gets the model configuration object
        Returns:
        model configuration
        Since:
        23.1
      • getTrainingLogAsync

        public abstract PgxFuture<PgxFrame> getTrainingLogAsync()
        Gets the training log, which contains the evaluation results from validation. It is available only after the model has been trained with a validation graph.
        Returns:
        training log
        Since:
        24.2
      • getTrainingLog

        public PgxFrame getTrainingLog()
                                throws java.util.concurrent.ExecutionException,
                                       java.lang.InterruptedException
        Blocking version of getTrainingLogAsync(). Calls getTrainingLogAsync() and waits for the returned PgxFuture to complete.
        Returns:
        training log
        Throws:
        java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
        java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
        Since:
        24.2
      • fitAsync

        public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph graph)
        Trains the EdgeWise model on the input graph.
        Parameters:
        graph - input graph to fit on.
        Returns:
        a future that completes with the training loss of the last batch
        Since:
        23.1
      • fitAsync

        public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph trainGraph,
                                                             PgxGraph valGraph)
        Trains the EdgeWise model on the input trainGraph and evaluates it on the input valGraph.
        Parameters:
        trainGraph - input graph to fit on.
        valGraph - input graph to evaluate on for validation.
        Returns:
        a future that completes with the training loss of the last batch
        Since:
        24.2
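Instead of blocking on `fit`, a caller can attach follow-up work to the future returned by `fitAsync`. The sketch below assumes that `PgxFuture` can be composed like a `java.util.concurrent.CompletableFuture`. It uses a plain `CompletableFuture<Double>` in place of a real training call, and `FitReport.describe` is a hypothetical helper.

```java
import java.util.Locale;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

final class FitReport {
    private FitReport() {}

    // Hypothetical helper: turns a pending training result into a status message
    // without blocking the calling thread.
    static CompletableFuture<String> describe(CompletableFuture<Double> fitFuture) {
        return fitFuture
                .thenApply(loss -> String.format(Locale.ROOT, "last-batch loss: %.3f", loss))
                .exceptionally(err -> {
                    // Dependent stages wrap failures in CompletionException; unwrap for the message.
                    Throwable cause = (err instanceof CompletionException && err.getCause() != null)
                            ? err.getCause() : err;
                    return "training failed: " + cause.getMessage();
                });
    }

    public static void main(String[] args) {
        // Stand-in for fitAsync(trainGraph, valGraph); the loss value is made up.
        CompletableFuture<Double> fake = CompletableFuture.supplyAsync(() -> 0.25);
        System.out.println(describe(fake).join());
    }
}
```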
      • fit

        public double fit(PgxGraph graph)
                   throws java.util.concurrent.ExecutionException,
                          java.lang.InterruptedException
        Blocking version of fitAsync(PgxGraph). Calls fitAsync(PgxGraph) and waits for the returned PgxFuture to complete.
        Parameters:
        graph - input graph to fit on.
        Returns:
        the training loss of the last batch
        Throws:
        java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
        java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
        Since:
        23.1
      • fit

        public double fit(PgxGraph trainGraph,
                          PgxGraph valGraph)
                   throws java.util.concurrent.ExecutionException,
                          java.lang.InterruptedException
        Blocking version of fitAsync(PgxGraph, PgxGraph). Calls fitAsync(PgxGraph, PgxGraph) and waits for the returned PgxFuture to complete.
        Parameters:
        trainGraph - input graph to fit on.
        valGraph - input graph to evaluate on for validation.
        Returns:
        the training loss of the last batch
        Throws:
        java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
        java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
        Since:
        24.2
      • inferEmbeddingsAsync

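        public abstract PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph,
                                                                 java.lang.Iterable<PgxEdge> edges)
        Infers the embeddings for the specified edges.
        Parameters:
        graph - graph containing the edges
        edges - edges for which to infer embeddings
        Returns:
        a future that completes with a PgxFrame containing the embedding of each edge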
        Since:
        23.1