Package oracle.pgx.api.mllib
Class GraphWiseModel<Config extends GraphWiseModelConfig,Metadata extends oracle.pgx.api.internal.mllib.GraphWiseModelMetadata<Config>,ModelType extends GraphWiseModel<Config,Metadata,ModelType>>
- java.lang.Object
  - oracle.pgx.api.mllib.Model<ModelType>
    - oracle.pgx.api.mllib.GraphWiseModel<Config,Metadata,ModelType>
- All Implemented Interfaces:
java.lang.AutoCloseable
- Direct Known Subclasses:
SupervisedGraphWiseModel, UnsupervisedGraphWiseModel
public abstract class GraphWiseModel<Config extends GraphWiseModelConfig,Metadata extends oracle.pgx.api.internal.mllib.GraphWiseModelMetadata<Config>,ModelType extends GraphWiseModel<Config,Metadata,ModelType>> extends Model<ModelType>
Base class for GraphWise models.
- Since:
- 19.4
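The class declaration uses a self-referential (F-bounded) type parameter, ModelType extends GraphWiseModel<Config,Metadata,ModelType>. This page does not state the motivation. A common reason for the pattern is to let methods defined once in a base class accept or return the concrete subclass type without casts. The sketch below illustrates that general Java idiom with hypothetical toy classes (ToyModel, ToySupervisedModel, ToySupervisedConfig). None of them are part of the PGX API.

```java
// Minimal sketch of the F-bounded ("self-referential") generic pattern.
// All classes here are hypothetical toys, not PGX classes.
class FBoundedDemo {
    public static void main(String[] args) {
        ToySupervisedModel model = new ToySupervisedModel(new ToySupervisedConfig("label"));
        // markFitted() is declared in ToyModel but returns ToySupervisedModel,
        // so a subclass-only method can be chained without a cast.
        System.out.println(model.markFitted().describeTarget());
    }
}

abstract class ToyModel<Config, ModelType extends ToyModel<Config, ModelType>> {
    private final Config config;
    private boolean fitted;

    ToyModel(Config config) {
        this.config = config;
    }

    Config getConfig() {
        return config;
    }

    boolean isFitted() {
        return fitted;
    }

    @SuppressWarnings("unchecked")
    ModelType markFitted() {
        fitted = true;
        // Safe as long as every subclass passes itself as ModelType.
        return (ModelType) this;
    }
}

final class ToySupervisedConfig {
    final String targetProperty;

    ToySupervisedConfig(String targetProperty) {
        this.targetProperty = targetProperty;
    }
}

final class ToySupervisedModel extends ToyModel<ToySupervisedConfig, ToySupervisedModel> {
    ToySupervisedModel(ToySupervisedConfig config) {
        super(config);
    }

    // getConfig() is already typed as ToySupervisedConfig here.
    String describeTarget() {
        return "predicts " + getConfig().targetProperty;
    }
}
```

The same shape explains why getConfig() on this page returns Config rather than GraphWiseModelConfig: each subclass fixes the type parameters once, and the inherited accessors become precisely typed.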
-
-
Constructor Summary
- GraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor): This constructor should never be used to get a model.
-
Method Summary
- void destroy(): Blocking version of destroyAsync().
- PgxFuture<java.lang.Void> destroyAsync(): Destroys a GraphWise model.
- double fit(PgxGraph graph): Blocking version of fitAsync(PgxGraph).
- double fit(PgxGraph trainGraph, PgxGraph valGraph): Blocking version of fitAsync(PgxGraph, PgxGraph).
- abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph graph): Trains the GraphWise model on the input graph.
- abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph trainGraph, PgxGraph valGraph): Trains the GraphWise model on trainGraph and evaluates it on valGraph.
- int getBatchSize(): Gets the batch size.
- Config getConfig(): Gets the model configuration object.
- GraphWiseBaseConvLayerConfig[] getConvLayerConfigs(): Gets the configuration objects for the convolutional layers.
- int getEdgeInputFeatureDim(): Gets the edge input feature dimension, that is, the dimension of all the input edge properties when concatenated.
- java.util.List<java.lang.String> getEdgeInputPropertyNames(): Gets the names of the edge properties used as input features.
- int getEmbeddingDim(): Gets the dimension of the embeddings.
- int getInputFeatureDim(): Gets the input feature dimension, that is, the dimension of all the input vertex properties when concatenated.
- double getLearningRate(): Gets the initial learning rate.
- int getNumEpochs(): Gets the number of epochs to train the model.
- java.lang.Integer getSeed(): Gets the random seed.
- PgxFrame getTrainingLog(): Blocking version of getTrainingLogAsync().
- abstract PgxFuture<PgxFrame> getTrainingLogAsync(): Gets the training log, which contains the evaluation results from validation.
- double getTrainingLoss(): Gets the final training loss.
- java.util.List<java.lang.String> getVertexInputPropertyNames(): Gets the names of the vertex properties used as input features.
- <ID> PgxFrame inferEmbeddings(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices): Blocking version of inferEmbeddingsAsync(PgxGraph, Iterable).
- abstract <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices): Infers the embeddings for the specified vertices.
- boolean isFitted(): Checks whether the model has been fitted.
-
-
-
Constructor Detail
-
GraphWiseModel
public GraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor)
This constructor should never be used to get a model. Use SupervisedGraphWiseModelBuilder instead.
- Parameters:
session - PgxSession to which the model is connected
core - Core to which the model is connected
modelMetadata - metadata concerning the different hyper-parameters of the GraphWise model
- Since:
- 19.4
-
-
Method Detail
-
destroyAsync
public PgxFuture<java.lang.Void> destroyAsync()
Destroys a GraphWise model.
- Specified by:
destroyAsync in class Model<ModelType extends GraphWiseModel<Config,Metadata,ModelType>>
- Returns:
- a future which will be completed once the destruction request finishes.
- Since:
- 19.4
-
destroy
public void destroy() throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of destroyAsync(). Calls destroyAsync() and waits for the returned PgxFuture to complete.
- Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
- Since:
- 19.4
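Blocking methods on this page, such as destroy(), are described as calling the asynchronous variant and waiting for the returned PgxFuture. The sketch below shows that pattern in plain Java, using java.util.concurrent.CompletableFuture as a stand-in for PgxFuture. The destroyAsync and destroy methods here are hypothetical toys, not the PGX implementation. The sketch also shows how the original failure ends up nested as the cause of the ExecutionException.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

class BlockingWrapperDemo {
    // Hypothetical async operation standing in for destroyAsync(); fails on demand.
    static CompletableFuture<Void> destroyAsync(boolean fail) {
        return CompletableFuture.runAsync(() -> {
            if (fail) {
                throw new IllegalStateException("model already destroyed");
            }
        });
    }

    // Blocking version: start the async call, then wait for its future to complete.
    static void destroy(boolean fail) throws ExecutionException, InterruptedException {
        destroyAsync(fail).get();
    }

    public static void main(String[] args) throws InterruptedException {
        try {
            destroy(true);
        } catch (ExecutionException e) {
            // get() wraps the failure; the original exception is the cause.
            System.out.println("nested cause: " + e.getCause());
        }
    }
}
```

Callers of the real blocking methods therefore usually inspect ExecutionException.getCause() to find out what actually went wrong.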
-
getNumEpochs
public int getNumEpochs()
Gets the number of epochs to train the model.
- Returns:
- number of epochs to train the model
- Since:
- 19.4
-
getLearningRate
public double getLearningRate()
Gets the initial learning rate.
- Returns:
- initial learning rate
- Since:
- 19.4
-
getBatchSize
public int getBatchSize()
Gets the batch size.
- Returns:
- batch size
- Since:
- 19.4
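getNumEpochs() and getBatchSize() together determine how much work a fit(...) call does. This page does not specify how GraphWise forms its batches; assuming the common mini-batch scheme in which each epoch visits every training vertex once, in batches of at most getBatchSize() vertices, the number of batches per epoch is a ceiling division and the final batch of an epoch may be smaller than the others. That is worth keeping in mind when reading the value returned by fit(...), which is documented as the training loss of the last batch. BatchMath and the numbers in main are hypothetical.

```java
class BatchMath {
    // Assumption: one epoch visits every training vertex once, in batches of at most batchSize.
    static int batchesPerEpoch(int numTrainVertices, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        // Ceiling division: a partial final batch still counts as a batch.
        return (numTrainVertices + batchSize - 1) / batchSize;
    }

    // Size of the final batch of an epoch, which may be smaller than batchSize.
    static int lastBatchSize(int numTrainVertices, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        if (numTrainVertices == 0) {
            return 0;
        }
        int remainder = numTrainVertices % batchSize;
        return remainder == 0 ? batchSize : remainder;
    }

    public static void main(String[] args) {
        int numEpochs = 3;            // hypothetical value of getNumEpochs()
        int batchSize = 128;          // hypothetical value of getBatchSize()
        int numTrainVertices = 1000;  // hypothetical training-set size
        int perEpoch = batchesPerEpoch(numTrainVertices, batchSize);
        System.out.println(perEpoch + " batches per epoch, " + (perEpoch * numEpochs)
                + " batches in total, last batch has "
                + lastBatchSize(numTrainVertices, batchSize) + " vertices");
    }
}
```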
-
getEmbeddingDim
public int getEmbeddingDim()
Gets the dimension of the embeddings.
- Returns:
- embedding dimension
- Since:
- 19.4
-
getSeed
public java.lang.Integer getSeed()
Gets the random seed.
- Returns:
- random seed
- Since:
- 19.4
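getSeed() returns a boxed java.lang.Integer rather than an int. This page does not say what a null value means; the sketch below assumes that null stands for "no fixed seed" and shows why a fixed seed matters: two random initializations drawn from java.util.Random with the same seed are identical. SeedDemo and initWeights are hypothetical and do not reflect how GraphWise actually initializes its weights.

```java
import java.util.Arrays;
import java.util.Random;

class SeedDemo {
    // Assumption: a null seed means "no fixed seed"; a non-null seed makes runs repeatable.
    static Random rngFor(Integer seed) {
        return seed == null ? new Random() : new Random(seed);
    }

    // Draws n small Gaussian values, standing in for random weight initialization.
    static double[] initWeights(Integer seed, int n) {
        Random rng = rngFor(seed);
        double[] weights = new double[n];
        for (int i = 0; i < n; i++) {
            weights[i] = rng.nextGaussian() * 0.01;
        }
        return weights;
    }

    public static void main(String[] args) {
        double[] first = initWeights(42, 3);
        double[] second = initWeights(42, 3);
        System.out.println("same seed, same weights: " + Arrays.equals(first, second));
    }
}
```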
-
getConvLayerConfigs
public GraphWiseBaseConvLayerConfig[] getConvLayerConfigs()
Gets the configuration objects for the convolutional layers.
- Returns:
- configurations of the convolutional layers
- Since:
- 19.4
-
getVertexInputPropertyNames
public java.util.List<java.lang.String> getVertexInputPropertyNames()
Gets the names of the vertex properties used as input features.
- Returns:
- names of the vertex input properties
- Since:
- 19.4
-
getEdgeInputPropertyNames
public java.util.List<java.lang.String> getEdgeInputPropertyNames()
Gets the names of the edge properties used as input features.
- Returns:
- names of the edge input properties
- Since:
- 21.2
-
isFitted
public boolean isFitted()
Checks whether the model has been fitted.
- Returns:
- true if the model is fitted
- Since:
- 19.4
-
getTrainingLoss
public double getTrainingLoss()
Gets the final training loss.
- Returns:
- training loss
- Since:
- 19.4
-
getInputFeatureDim
public int getInputFeatureDim()
Gets the input feature dimension, that is, the dimension of all the input vertex properties when concatenated.
- Returns:
- input feature dimension
- Since:
- 19.4
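getInputFeatureDim() is defined as the dimension of all input vertex properties when concatenated. Assuming each numeric scalar property contributes one value and each vector property contributes its length (categorical or otherwise encoded properties may be handled differently), the dimension is the sum of the per-property widths. The property names below (age, income, text_emb) and the FeatureDim helper are hypothetical; under the same assumption, the arithmetic for getEdgeInputFeatureDim() would be identical, applied to edge properties.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class FeatureDim {
    // Sums per-property widths: 1 for a scalar property, k for a length-k vector property.
    static int concatenatedDim(Map<String, Integer> propertyWidths) {
        int total = 0;
        for (int width : propertyWidths.values()) {
            total += width;
        }
        return total;
    }

    // Concatenates the per-property values of one vertex into a single feature vector.
    static double[] concatenate(double[]... parts) {
        int length = 0;
        for (double[] part : parts) {
            length += part.length;
        }
        double[] out = new double[length];
        int pos = 0;
        for (double[] part : parts) {
            System.arraycopy(part, 0, out, pos, part.length);
            pos += part.length;
        }
        return out;
    }

    public static void main(String[] args) {
        Map<String, Integer> widths = new LinkedHashMap<>();
        widths.put("age", 1);        // hypothetical scalar property
        widths.put("income", 1);     // hypothetical scalar property
        widths.put("text_emb", 16);  // hypothetical vector property of length 16
        double[] features = concatenate(new double[] {35.0}, new double[] {52000.0}, new double[16]);
        System.out.println("input feature dim = " + concatenatedDim(widths)
                + ", concatenated vector length = " + features.length);
    }
}
```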
-
getEdgeInputFeatureDim
public int getEdgeInputFeatureDim()
Gets the edge input feature dimension, that is, the dimension of all the input edge properties when concatenated.
- Returns:
- edge input feature dimension
- Since:
- 21.2
-
getConfig
public Config getConfig()
Gets the model configuration object.
- Returns:
- model configuration
- Since:
- 19.4
-
getTrainingLogAsync
public abstract PgxFuture<PgxFrame> getTrainingLogAsync()
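Gets the training log, which contains the evaluation results from validation. The log is available only after the model has been trained with a validation graph, that is, with fit(PgxGraph, PgxGraph) or fitAsync(PgxGraph, PgxGraph).
- Returns: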
- training log
- Since:
- 24.2
-
getTrainingLog
public PgxFrame getTrainingLog() throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of getTrainingLogAsync(). Calls getTrainingLogAsync() and waits for the returned PgxFuture to complete.
- Returns:
- training log
- Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
- Since:
- 24.2
-
fitAsync
public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph graph)
Trains the GraphWise model on the input graph.
- Parameters:
graph - input graph to fit on.
- Returns:
- a future that completes with the training loss of the last batch
- Since:
- 19.4
-
fitAsync
public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph trainGraph, PgxGraph valGraph)
Trains the GraphWise model on trainGraph and evaluates it on valGraph.
- Parameters:
trainGraph - input graph to fit on.
valGraph - input graph to evaluate on for validation.
- Returns:
- a future that completes with the training loss of the last batch
- Since:
- 24.2
-
fit
public double fit(PgxGraph graph) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of fitAsync(PgxGraph). Calls fitAsync(PgxGraph) and waits for the returned PgxFuture to complete.
- Parameters:
graph - input graph to fit on.
- Returns:
- the training loss of the last batch
- Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
- Since:
- 19.4
-
fit
public double fit(PgxGraph trainGraph, PgxGraph valGraph) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of fitAsync(PgxGraph, PgxGraph). Calls fitAsync(PgxGraph, PgxGraph) and waits for the returned PgxFuture to complete.
- Parameters:
trainGraph - input graph to fit on.
valGraph - input graph to evaluate on for validation.
- Returns:
- the training loss of the last batch
- Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
- Since:
- 24.2
-
inferEmbeddingsAsync
public abstract <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
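Infers the embeddings for the specified vertices.
- Parameters:
graph - input graph on which to infer the embeddings
vertices - vertices whose embeddings are inferred
- Returns: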
- PgxFrame containing the embeddings for each vertex.
- Since:
- 19.4
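Because fitAsync(...) and inferEmbeddingsAsync(...) both return futures, training and inference can be chained so that the calling thread blocks only at the very end, if at all. The sketch below shows that chaining with CompletableFuture.thenCompose. It uses hypothetical stand-in methods, plain Integer vertex IDs and a Map in place of PgxGraph, PgxVertex and PgxFrame, and CompletableFuture in place of PgxFuture; the 0.25 loss and the zero-filled vectors are placeholder values, not real model output.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;

class AsyncChainDemo {
    // Hypothetical stand-in for fitAsync(graph): completes with a placeholder last-batch loss.
    static CompletableFuture<Double> fitAsync(String trainGraphName) {
        return CompletableFuture.supplyAsync(() -> 0.25);
    }

    // Hypothetical stand-in for inferEmbeddingsAsync(graph, vertices):
    // completes with one zero-filled vector of length embeddingDim per vertex ID.
    static CompletableFuture<Map<Integer, double[]>> inferEmbeddingsAsync(
            String graphName, List<Integer> vertexIds, int embeddingDim) {
        return CompletableFuture.supplyAsync(() -> vertexIds.stream()
                .collect(Collectors.toMap(Function.identity(), id -> new double[embeddingDim])));
    }

    // Trains first, then starts inference only after training has completed.
    static CompletableFuture<Map<Integer, double[]>> trainThenInfer(List<Integer> vertexIds, int embeddingDim) {
        return fitAsync("train")
                .thenCompose(loss -> {
                    System.out.println("loss of last batch: " + loss);
                    return inferEmbeddingsAsync("test", vertexIds, embeddingDim);
                });
    }

    public static void main(String[] args) {
        // join() is the only blocking call, at the end of the chain.
        Map<Integer, double[]> embeddings = trainThenInfer(List.of(1, 2, 3), 4).join();
        System.out.println(embeddings.size() + " embeddings of length " + embeddings.get(1).length);
    }
}
```

With the real API, the blocking fit(...) and inferEmbeddings(...) methods are simpler when the caller has nothing else to do while waiting; the chained form pays off when several models or graphs are processed concurrently.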
-
inferEmbeddings
public <ID> PgxFrame inferEmbeddings(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Blocking version of inferEmbeddingsAsync(PgxGraph, Iterable). Infers the embeddings for the specified vertices.
- Returns:
- PgxFrame containing the embeddings for each vertex.
- Since:
- 19.4
-
-