public abstract class GraphWiseModel<Config extends GraphWiseModelConfig,Metadata extends oracle.pgx.api.internal.mllib.GraphWiseModelMetadata<Config>,ModelType extends GraphWiseModel<Config,Metadata,ModelType>> extends Model<ModelType>
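The declaration uses an F-bounded (self-referential) type parameter: ModelType must itself be a GraphWiseModel parameterized with ModelType. A common reason for this pattern is to let methods declared in a base class such as Model return or accept the concrete subclass type. The sketch below shows the pattern with hypothetical stand-in classes (BaseModel, SupervisedToyModel). They are not PGX types, and the sketch makes no claim about which PGX methods rely on the bound.

```java
import java.util.ArrayList;
import java.util.List;

public class FBoundedSketch {
    // Hypothetical base class: M is bounded by the class itself (F-bounded),
    // mirroring "ModelType extends GraphWiseModel<Config,Metadata,ModelType>".
    static abstract class BaseModel<M extends BaseModel<M>> {
        private final List<String> tags = new ArrayList<>();

        // Each subclass returns itself with its concrete static type.
        protected abstract M self();

        // Because the return type is M, chained calls keep the subtype.
        M tag(String t) {
            tags.add(t);
            return self();
        }

        List<String> tags() {
            return tags;
        }
    }

    // Hypothetical concrete model: the type argument is the class itself.
    static final class SupervisedToyModel extends BaseModel<SupervisedToyModel> {
        @Override
        protected SupervisedToyModel self() {
            return this;
        }

        String kind() {
            return "supervised";
        }
    }

    public static void main(String[] args) {
        // tag(...) returns SupervisedToyModel, so kind() is callable without a cast.
        SupervisedToyModel m = new SupervisedToyModel().tag("a").tag("b");
        System.out.println(m.kind() + " " + m.tags().size()); // prints "supervised 2"
    }
}
```

Without the bound, tag(...) would have to return BaseModel, and callers would need a cast to reach subclass-only methods.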
Constructor | Description
---|---
GraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor) | This constructor should never be used to get a model.
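The constructor is not meant to be called directly; the detail section points to SupervisedGraphWiseModelBuilder instead. The sketch below illustrates the general builder idea with a hypothetical ToyModelConfig class whose getters mirror a few of the GraphWiseModel getters. The setter names, validation rules, and starting values are assumptions for illustration only; they are not the PGX builder API and not its defaults.

```java
public class BuilderSketch {
    // Hypothetical immutable hyper-parameter holder; getter names mirror GraphWiseModel.
    static final class ToyModelConfig {
        private final int batchSize;
        private final int numEpochs;
        private final double learningRate;
        private final int embeddingDim;
        private final int seed;

        private ToyModelConfig(Builder b) {
            this.batchSize = b.batchSize;
            this.numEpochs = b.numEpochs;
            this.learningRate = b.learningRate;
            this.embeddingDim = b.embeddingDim;
            this.seed = b.seed;
        }

        int getBatchSize() { return batchSize; }
        int getNumEpochs() { return numEpochs; }
        double getLearningRate() { return learningRate; }
        int getEmbeddingDim() { return embeddingDim; }
        int getSeed() { return seed; }

        static Builder builder() { return new Builder(); }

        // Hypothetical builder: callers set only what they need, then call build().
        static final class Builder {
            // Illustrative starting values only; these are NOT the PGX defaults.
            private int batchSize = 32;
            private int numEpochs = 10;
            private double learningRate = 0.01;
            private int embeddingDim = 64;
            private int seed = 42;

            Builder setBatchSize(int v) { batchSize = v; return this; }
            Builder setNumEpochs(int v) { numEpochs = v; return this; }
            Builder setLearningRate(double v) { learningRate = v; return this; }
            Builder setEmbeddingDim(int v) { embeddingDim = v; return this; }
            Builder setSeed(int v) { seed = v; return this; }

            // Validates once, so every constructed config is consistent.
            ToyModelConfig build() {
                if (batchSize <= 0 || numEpochs <= 0 || embeddingDim <= 0) {
                    throw new IllegalArgumentException("sizes must be positive");
                }
                if (learningRate <= 0.0) {
                    throw new IllegalArgumentException("learning rate must be positive");
                }
                return new ToyModelConfig(this);
            }
        }
    }

    public static void main(String[] args) {
        ToyModelConfig cfg = ToyModelConfig.builder().setBatchSize(128).setSeed(7).build();
        System.out.println(cfg.getBatchSize() + " " + cfg.getSeed()); // prints "128 7"
    }
}
```

Routing construction through a builder keeps validation in one place and leaves the resulting object immutable, which is one plausible reason a library would discourage direct constructor calls.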
Modifier and Type | Method | Description
---|---|---
void | destroy() | Blocking version of destroyAsync().
PgxFuture<java.lang.Void> | destroyAsync() | Destroys this GraphWise model.
double | fit(PgxGraph graph) | Blocking version of fitAsync(PgxGraph).
abstract PgxFuture<java.lang.Double> | fitAsync(PgxGraph graph) | Trains the GraphWise model on the input graph.
int | getBatchSize() | Gets the batch size.
Config | getConfig() | Gets the model configuration object.
GraphWiseConvLayerConfig[] | getConvLayerConfigs() | Gets the configuration objects for the convolutional layers.
int | getEdgeInputFeatureDim() | Gets the edge input feature dimension, that is, the dimension of all the input edge properties when concatenated.
java.util.List<java.lang.String> | getEdgeInputPropertyNames() | Gets the edge input feature names.
int | getEmbeddingDim() | Gets the dimension of the embeddings.
int | getInputFeatureDim() | Gets the input feature dimension, that is, the dimension of all the input vertex properties when concatenated.
double | getLearningRate() | Gets the initial learning rate.
int | getNumEpochs() | Gets the number of epochs to train the model.
int | getSeed() | Gets the random seed.
double | getTrainingLoss() | Gets the final training loss.
java.util.List<java.lang.String> | getVertexInputPropertyNames() | Gets the vertex input feature names.
<ID> PgxFrame | inferEmbeddings(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) | Blocking version of inferEmbeddingsAsync(PgxGraph, Iterable).
abstract <ID> PgxFuture<PgxFrame> | inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) | Infers the embeddings for the specified vertices.
boolean | isFitted() | Checks whether the model has been fitted.
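getInputFeatureDim() and getEdgeInputFeatureDim() are described as the dimension of the input properties when concatenated. Assuming a scalar property contributes 1 and a vector property contributes its length, that dimension is the sum of the per-property dimensions. The sketch below computes the sum and performs the concatenation for hypothetical property names and values; it does not query PGX.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class FeatureDimSketch {
    // Sums per-property dimensions, i.e. the length of the concatenated feature vector.
    static int concatenatedDim(Map<String, Integer> propertyDims) {
        int total = 0;
        for (int dim : propertyDims.values()) {
            if (dim <= 0) {
                throw new IllegalArgumentException("dimension must be positive: " + dim);
            }
            total += dim;
        }
        return total;
    }

    // Concatenates per-property values in the map's iteration order.
    static double[] concatenate(Map<String, double[]> values) {
        int n = 0;
        for (double[] v : values.values()) {
            n += v.length;
        }
        double[] out = new double[n];
        int offset = 0;
        for (double[] v : values.values()) {
            System.arraycopy(v, 0, out, offset, v.length);
            offset += v.length;
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical vertex properties: two scalars and one 3-dimensional vector.
        Map<String, Integer> dims = new LinkedHashMap<>();
        dims.put("age", 1);
        dims.put("score", 1);
        dims.put("embedding", 3);
        System.out.println(concatenatedDim(dims)); // prints 5
    }
}
```

A LinkedHashMap is used so the concatenation order is deterministic; the order PGX uses for its input properties is not stated on this page.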
public GraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor)
This constructor should never be used to get a model. Use SupervisedGraphWiseModelBuilder instead.
Parameters:
session - PgxSession to which the model is connected
core - Core to which the model is connected
modelMetadata - Metadata concerning the different hyper-parameters of the GraphWise model

public void destroy() throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of destroyAsync(). Calls destroyAsync() and waits for the returned PgxFuture to complete.
Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.

public PgxFuture<java.lang.Void> destroyAsync()
Destroys this GraphWise model.
Overrides or implements: destroyAsync in class Model<ModelType extends GraphWiseModel<Config,Metadata,ModelType>>
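destroy(), fit(PgxGraph) and inferEmbeddings(...) follow the same pattern: call the asynchronous variant and wait for its future. The sketch below mimics that pattern with java.util.concurrent.CompletableFuture standing in for PgxFuture (an assumption made so the example runs without PGX); fitAsyncStub and the loss value 0.25 are hypothetical placeholders. It shows why the blocking methods declare ExecutionException: a failure inside the asynchronous task surfaces as the nested cause of that exception.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

public class BlockingWrapperSketch {
    // Hypothetical async operation standing in for fitAsync(PgxGraph).
    static CompletableFuture<Double> fitAsyncStub(boolean fail) {
        return CompletableFuture.supplyAsync(() -> {
            if (fail) {
                throw new IllegalStateException("training failed");
            }
            return 0.25; // placeholder result, not a real loss
        });
    }

    // Blocking version: start the async call and wait for the future to complete.
    static double fitStub(boolean fail) throws ExecutionException, InterruptedException {
        return fitAsyncStub(fail).get();
    }

    // Runs the blocking call and reports either the result or the nested cause.
    static String describe(boolean fail) {
        try {
            return "result=" + fitStub(fail);
        } catch (ExecutionException e) {
            // get() wraps the task's exception; the original is available as the cause.
            return "cause=" + e.getCause().getMessage();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            return "interrupted";
        }
    }

    public static void main(String[] args) {
        System.out.println(describe(false)); // prints "result=0.25"
        System.out.println(describe(true));  // prints "cause=training failed"
    }
}
```

Restoring the interrupt flag when catching InterruptedException lets code further up the call stack still observe the interruption.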
public double fit(PgxGraph graph) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of fitAsync(PgxGraph). Calls fitAsync(PgxGraph) and waits for the returned PgxFuture to complete.
Parameters:
graph - input graph to fit on.
Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.

public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph graph)
Trains the GraphWise model on the input graph.
Parameters:
graph - input graph to fit on.

public int getBatchSize()
public Config getConfig()
public GraphWiseConvLayerConfig[] getConvLayerConfigs()
public int getEdgeInputFeatureDim()
public java.util.List<java.lang.String> getEdgeInputPropertyNames()
public int getEmbeddingDim()
public int getInputFeatureDim()
public double getLearningRate()
public int getNumEpochs()
public int getSeed()
public double getTrainingLoss()
public java.util.List<java.lang.String> getVertexInputPropertyNames()
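getNumEpochs() and getBatchSize() together bound how much work a call to fit performs. Under the usual mini-batch convention (an assumption; this page does not describe how PGX samples or batches vertices), each epoch visits every training vertex once in batches of getBatchSize() vertices, so the number of optimizer steps is numEpochs * ceil(numTrainingVertices / batchSize). The hypothetical helper below illustrates that arithmetic.

```java
public class TrainingStepsSketch {
    // Number of mini-batches per epoch: ceil(numVertices / batchSize), using integer math.
    static long batchesPerEpoch(long numVertices, int batchSize) {
        if (numVertices < 0 || batchSize <= 0) {
            throw new IllegalArgumentException("invalid sizes");
        }
        return (numVertices + batchSize - 1) / batchSize;
    }

    // Total optimizer steps over all epochs under the standard mini-batch convention.
    static long totalSteps(long numVertices, int batchSize, int numEpochs) {
        if (numEpochs < 0) {
            throw new IllegalArgumentException("numEpochs must be non-negative");
        }
        return batchesPerEpoch(numVertices, batchSize) * numEpochs;
    }

    public static void main(String[] args) {
        // 1,000 vertices in batches of 128 -> 8 batches per epoch (the last one partial).
        System.out.println(batchesPerEpoch(1_000, 128)); // prints 8
        // Over 10 epochs -> 80 steps.
        System.out.println(totalSteps(1_000, 128, 10)); // prints 80
    }
}
```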
public <ID> PgxFrame inferEmbeddings(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Blocking version of inferEmbeddingsAsync(PgxGraph, Iterable). Infers the embeddings for the specified vertices.

public abstract <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
public boolean isFitted()
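isFitted() lets callers check whether fit has run before calling inferEmbeddings. The toy class below sketches that lifecycle under stated assumptions: ToyModel, its placeholder "training", the assumption that fit returns the same value later reported by getTrainingLoss(), and the choice to throw IllegalStateException when inferring before fitting are all hypothetical, since this page does not document PGX's behavior in those cases. A Map from vertex ID to vector stands in for the returned PgxFrame, and the generic <ID> parameter mirrors the one on inferEmbeddings.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

public class LifecycleSketch {
    // Hypothetical model with the same fit / isFitted / inferEmbeddings shape.
    static final class ToyModel {
        private final int embeddingDim;
        private boolean fitted = false;
        private double trainingLoss = Double.NaN;

        ToyModel(int embeddingDim) {
            this.embeddingDim = embeddingDim;
        }

        // Stand-in for training: records a placeholder loss and marks the model fitted.
        double fit() {
            trainingLoss = 0.5; // placeholder value, not a real training result
            fitted = true;
            return trainingLoss;
        }

        boolean isFitted() {
            return fitted;
        }

        double getTrainingLoss() {
            return trainingLoss;
        }

        // One row per requested vertex, each of length embeddingDim.
        <ID> Map<ID, double[]> inferEmbeddings(Iterable<ID> vertices) {
            if (!fitted) {
                throw new IllegalStateException("model is not fitted");
            }
            Map<ID, double[]> rows = new LinkedHashMap<>();
            for (ID v : vertices) {
                rows.put(v, new double[embeddingDim]); // zero vectors as placeholders
            }
            return rows;
        }
    }

    public static void main(String[] args) {
        ToyModel model = new ToyModel(4);
        if (!model.isFitted()) {
            model.fit();
        }
        Map<Long, double[]> emb = model.inferEmbeddings(Arrays.asList(1L, 2L, 3L));
        System.out.println(emb.size() + " x " + emb.get(1L).length); // prints "3 x 4"
    }
}
```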