public abstract class EdgeWiseModel<Config extends EdgeWiseModelConfig,Metadata extends oracle.pgx.api.internal.mllib.EdgeWiseModelMetadata<Config>,ModelType extends EdgeWiseModel<Config,Metadata,ModelType>> extends Model<ModelType>
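The declaration bounds ModelType by EdgeWiseModel<Config,Metadata,ModelType> itself. This is the recursive-generics (self-type) idiom. This reference does not say why PGX uses it. A common reason is to let methods declared in a base class be typed with the concrete subclass. The minimal sketch below uses hypothetical classes (SelfTypedModel, ToyEdgeWiseModel) that are not part of PGX. It shows what the bound makes possible: chained calls keep the subclass type without casts.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical base class with the same self-referential bound as
// Model<ModelType>: T is always the concrete subclass itself.
abstract class SelfTypedModel<T extends SelfTypedModel<T>> {
    private final List<String> tags = new ArrayList<>();

    // Declared to return T, so a chained call keeps the subclass type
    // and no cast is needed at the call site.
    T addTag(String tag) {
        tags.add(tag);
        return self();
    }

    List<String> getTags() {
        return tags;
    }

    // Each concrete subclass returns `this`, typed as itself.
    protected abstract T self();
}

// Hypothetical concrete subclass standing in for a specific model type.
final class ToyEdgeWiseModel extends SelfTypedModel<ToyEdgeWiseModel> {
    private int numEpochs;

    ToyEdgeWiseModel setNumEpochs(int numEpochs) {
        this.numEpochs = numEpochs;
        return this;
    }

    int getNumEpochs() {
        return numEpochs;
    }

    @Override
    protected ToyEdgeWiseModel self() {
        return this;
    }
}
```

With this bound, `new ToyEdgeWiseModel().addTag("demo").setNumEpochs(5)` compiles. addTag is declared in the base class but returns ToyEdgeWiseModel.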
Constructor and Description |
---|
EdgeWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor) This constructor should never be used to get a model. |
Modifier and Type | Method and Description |
---|---|
void | destroy() Blocking version of destroyAsync(). |
PgxFuture<java.lang.Void> | destroyAsync() Destroys an EdgeWise model. |
double | fit(PgxGraph graph) Blocking version of fitAsync(PgxGraph). |
abstract PgxFuture<java.lang.Double> | fitAsync(PgxGraph graph) Trains the EdgeWise model on the input graph. |
int | getBatchSize() Gets the batch size. |
Config | getConfig() Gets the model configuration object. |
GraphWiseBaseConvLayerConfig[] | getConvLayerConfigs() Gets the configuration objects for the convolutional layers. |
EdgeCombinationMethod | getEdgeCombinationMethod() Gets the edge combination method used to produce the edge embeddings. |
int | getEdgeInputFeatureDim() Gets the edge input feature dimension, that is, the dimension of all the input edge properties when concatenated. |
java.util.List<java.lang.String> | getEdgeInputPropertyNames() Gets the names of the edge input properties. |
int | getEmbeddingDim() Gets the dimension of the embeddings. |
int | getInputFeatureDim() Gets the vertex input feature dimension, that is, the dimension of all the input vertex properties when concatenated. |
double | getLearningRate() Gets the initial learning rate. |
int | getNumEpochs() Gets the number of epochs to train the model. |
java.lang.Integer | getSeed() Gets the random seed. |
double | getTrainingLoss() Gets the final training loss. |
java.util.List<java.lang.String> | getVertexInputPropertyNames() Gets the names of the vertex input properties. |
PgxFrame | inferEmbeddings(PgxGraph graph, java.lang.Iterable<PgxEdge> edges) Blocking version of inferEmbeddingsAsync(PgxGraph, Iterable). |
abstract PgxFuture<PgxFrame> | inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxEdge> edges) Infers the embeddings for the specified edges. |
boolean | isFitted() Checks whether the model is fitted. |
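getInputFeatureDim() and getEdgeInputFeatureDim() both describe the dimension of the input properties "when concatenated". The sketch below illustrates that arithmetic with plain arrays. It assumes a scalar property contributes 1 and a vector property contributes its length. FeatureConcat and its methods are hypothetical helpers, not PGX API. For example, a scalar edge property weight plus a 3-dimensional vector property features gives a concatenated edge input dimension of 1 + 3 = 4.

```java
import java.util.List;
import java.util.Map;

// Hypothetical helper: builds one input feature vector by concatenating
// per-property vectors in a fixed property order (for example, the order
// of a property-name list such as getEdgeInputPropertyNames()).
final class FeatureConcat {
    private FeatureConcat() {}

    // Total input dimension = sum of the dimensions of the listed properties.
    static int inputFeatureDim(List<String> propertyNames, Map<String, Integer> propertyDims) {
        int dim = 0;
        for (String name : propertyNames) {
            Integer d = propertyDims.get(name);
            if (d == null) {
                throw new IllegalArgumentException("unknown property: " + name);
            }
            dim += d;
        }
        return dim;
    }

    // Concatenates the values of the listed properties, in list order.
    static double[] concat(List<String> propertyNames, Map<String, double[]> values) {
        int dim = 0;
        for (String name : propertyNames) {
            double[] v = values.get(name);
            if (v == null) {
                throw new IllegalArgumentException("missing value for property: " + name);
            }
            dim += v.length;
        }
        double[] out = new double[dim];
        int offset = 0;
        for (String name : propertyNames) {
            double[] v = values.get(name);
            System.arraycopy(v, 0, out, offset, v.length);
            offset += v.length;
        }
        return out;
    }
}
```

The property order matters. Each position in the concatenated vector has a consistent meaning only if every vertex or edge is concatenated in the same order.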
public EdgeWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor)
This constructor should never be used to get a model. Use SupervisedEdgeWiseModelBuilder instead.
Parameters:
session - PgxSession to which the model is connected
core - Core to which the model is connected
modelMetadata - Metadata concerning the different hyper-parameters of the EdgeWise model
public void destroy() throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of destroyAsync().
Calls destroyAsync() and waits for the returned PgxFuture to complete.
Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
public PgxFuture<java.lang.Void> destroyAsync()
Destroys an EdgeWise model.
Its contract comes from destroyAsync in class Model<ModelType extends EdgeWiseModel<Config,Metadata,ModelType>>.
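destroy(), fit(PgxGraph) and inferEmbeddings(PgxGraph, Iterable) are each described as a blocking version of an asynchronous method. Each one calls the async method and waits for the returned future. The minimal sketch below shows that pattern under one assumption: java.util.concurrent.CompletableFuture stands in for PgxFuture. AsyncResource is a hypothetical class, not PGX code. The sketch also shows where the "nested" exception ends up. A failure inside the task is reported by get() as an ExecutionException, and the original exception is its cause.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

// Hypothetical resource mirroring the destroy()/destroyAsync() pairing.
// CompletableFuture stands in for PgxFuture here.
final class AsyncResource {
    private volatile boolean destroyed = false;

    // Asynchronous version: returns immediately with a future.
    CompletableFuture<Void> destroyAsync() {
        return CompletableFuture.runAsync(() -> destroyed = true);
    }

    // Blocking version: calls destroyAsync() and waits for the future.
    // get() throws InterruptedException if this thread is interrupted while
    // waiting, and ExecutionException if the asynchronous task failed.
    void destroy() throws ExecutionException, InterruptedException {
        destroyAsync().get();
    }

    boolean isDestroyed() {
        return destroyed;
    }

    // A failure inside the task is reported by get() as an
    // ExecutionException whose cause is the original exception.
    static Throwable causeOfFailedWait() throws InterruptedException {
        CompletableFuture<Void> failing = CompletableFuture.runAsync(() -> {
            throw new IllegalStateException("simulated failure");
        });
        try {
            failing.get();
            return null;
        } catch (ExecutionException e) {
            return e.getCause();
        }
    }
}
```

Callers of a blocking method can catch ExecutionException and inspect getCause() to find out what actually went wrong in the asynchronous step.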
public double fit(PgxGraph graph) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of fitAsync(PgxGraph).
Calls fitAsync(PgxGraph) and waits for the returned PgxFuture to complete.
Parameters:
graph - input graph to fit on.
Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
public abstract PgxFuture<java.lang.Double> fitAsync(PgxGraph graph)
Trains the EdgeWise model on the input graph.
Parameters:
graph - input graph to fit on.
public int getBatchSize()
public Config getConfig()
public GraphWiseBaseConvLayerConfig[] getConvLayerConfigs()
public EdgeCombinationMethod getEdgeCombinationMethod()
public int getEdgeInputFeatureDim()
public java.util.List<java.lang.String> getEdgeInputPropertyNames()
public int getEmbeddingDim()
public int getInputFeatureDim()
public double getLearningRate()
public int getNumEpochs()
public java.lang.Integer getSeed()
public double getTrainingLoss()
public java.util.List<java.lang.String> getVertexInputPropertyNames()
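getEdgeCombinationMethod() returns the method used to produce the edge embeddings. This reference does not list the available methods or their exact semantics. Those are defined by EdgeCombinationMethod. The sketch below assumes, for illustration only, a concatenation-style combination. It joins a source vertex vector, a destination vertex vector and the edge's own feature vector, and each of the three can be switched off. ConcatEdgeCombinationSketch and its flags are hypothetical names, not PGX API.

```java
// Hypothetical concatenation-style edge combination: the result is built
// from the source vertex vector, the destination vertex vector and the
// edge's own feature vector, each of which can be switched off.
final class ConcatEdgeCombinationSketch {
    private final boolean useSource;
    private final boolean useDestination;
    private final boolean useEdge;

    ConcatEdgeCombinationSketch(boolean useSource, boolean useDestination, boolean useEdge) {
        if (!useSource && !useDestination && !useEdge) {
            throw new IllegalArgumentException("at least one input must be used");
        }
        this.useSource = useSource;
        this.useDestination = useDestination;
        this.useEdge = useEdge;
    }

    // Output length = sum of the lengths of the enabled inputs.
    int outputDim(int vertexDim, int edgeDim) {
        return (useSource ? vertexDim : 0) + (useDestination ? vertexDim : 0) + (useEdge ? edgeDim : 0);
    }

    // Concatenates the enabled inputs in the order source, destination, edge.
    double[] combine(double[] source, double[] destination, double[] edge) {
        int length = (useSource ? source.length : 0)
                + (useDestination ? destination.length : 0)
                + (useEdge ? edge.length : 0);
        double[] out = new double[length];
        int offset = append(out, 0, useSource, source);
        offset = append(out, offset, useDestination, destination);
        append(out, offset, useEdge, edge);
        return out;
    }

    // Copies `part` into `out` at `offset` if enabled; returns the next offset.
    private static int append(double[] out, int offset, boolean enabled, double[] part) {
        if (!enabled) {
            return offset;
        }
        System.arraycopy(part, 0, out, offset, part.length);
        return offset + part.length;
    }
}
```

Under this assumed scheme, the length of the combined vector depends on which inputs are enabled. It is 2 × vertex dimension + edge dimension when all three are used.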
public PgxFrame inferEmbeddings(PgxGraph graph, java.lang.Iterable<PgxEdge> edges)
Blocking version of inferEmbeddingsAsync(PgxGraph, Iterable).
Infers the embeddings for the specified edges.
public abstract PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxEdge> edges)
public boolean isFitted()
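The methods above suggest a workflow: fit the model, check isFitted(), read getTrainingLoss(), then infer embeddings for a set of edges. The toy sketch below walks through that sequence without PGX. CompletableFuture stands in for PgxFuture, and plain int ids stand in for PgxEdge. ToyLifecycleModel is hypothetical. Two further choices belong to this sketch only. First, fit here returns the final training loss, but this reference only gives fit's return type (double). Second, the sketch rejects inference on an unfitted model, which is not documented behavior of EdgeWiseModel.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

// Hypothetical toy model showing a fit -> isFitted -> inferEmbeddings workflow.
// CompletableFuture stands in for PgxFuture and int ids stand in for PgxEdge.
final class ToyLifecycleModel {
    private boolean fitted = false;
    private double trainingLoss = Double.NaN;

    // Asynchronous training; completes with the final (toy) training loss.
    CompletableFuture<Double> fitAsync(int numEpochs) {
        return CompletableFuture.supplyAsync(() -> {
            double loss = 1.0;
            for (int epoch = 0; epoch < numEpochs; epoch++) {
                loss *= 0.5; // toy update: halve the loss every epoch
            }
            recordFit(loss);
            return loss;
        });
    }

    // Blocking version: waits for fitAsync and returns its result.
    double fit(int numEpochs) throws ExecutionException, InterruptedException {
        return fitAsync(numEpochs).get();
    }

    private synchronized void recordFit(double loss) {
        fitted = true;
        trainingLoss = loss;
    }

    synchronized boolean isFitted() {
        return fitted;
    }

    synchronized double getTrainingLoss() {
        return trainingLoss;
    }

    // Toy inference: one two-dimensional row per requested edge id.
    List<double[]> inferEmbeddings(Iterable<Integer> edgeIds) {
        if (!isFitted()) {
            throw new IllegalStateException("fit the model before inferring embeddings");
        }
        List<double[]> rows = new ArrayList<>();
        for (int id : edgeIds) {
            rows.add(new double[] {id, 2.0 * id});
        }
        return rows;
    }
}
```

In this toy, fit(3) halves the starting loss of 1.0 three times. After it returns, isFitted() is true and getTrainingLoss() reports the same value that fit returned.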