public class SupervisedGraphWiseModel extends GraphWiseModel<SupervisedGraphWiseModelConfig,oracle.pgx.api.internal.mllib.SupervisedGraphWiseModelMetadata,SupervisedGraphWiseModel>
See SupervisedGraphWiseModelBuilder for documentation of the hyperparameters.

Modifier and Type | Class and Description |
---|---|
static class | SupervisedGraphWiseModel.SupervisedGraphWiseInferenceType |

Modifier and Type | Field and Description |
---|---|
static java.lang.String | ALGORITHM_NAME |

Constructor and Description |
---|
SupervisedGraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor, oracle.pgx.api.internal.mllib.ModelMetadata modelMetadata) |
SupervisedGraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor, oracle.pgx.api.internal.mllib.SupervisedGraphWiseModelMetadata modelMetadata) This constructor should never be used to get a model. |

Modifier and Type | Method and Description |
---|---|
<ID> PgxFrame | evaluateLabels(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) Blocking version of evaluateLabelsAsync(PgxGraph, Iterable). |
<ID> PgxFrame | evaluateLabels(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices, float threshold) Blocking version of evaluateLabelsAsync(PgxGraph, Iterable, float). |
<ID> PgxFuture<PgxFrame> | evaluateLabelsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) Evaluates macro-averaged classification performance statistics for the specified vertices. |
<ID> PgxFuture<PgxFrame> | evaluateLabelsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices, float threshold) Evaluates macro-averaged classification performance statistics for the specified vertices, using the given decision threshold. |
PgxFuture<java.lang.Double> | fitAsync(PgxGraph graph) Trains the GraphWise model on the input graph. |
java.util.Map<?,java.lang.Float> | getClassWeights() Gets the class weights. |
LossFunction | getLossFunctionClass() Gets the loss function. |
GraphWisePredictionLayerConfig[] | getPredictionLayerConfigs() Gets the configuration objects for the prediction layers. |
java.util.List<java.util.Set<java.lang.String>> | getTargetVertexLabels() Gets the target vertex labels. |
java.lang.String | getVertexTargetPropertyName() Gets the target property name. |
SupervisedGnnExplainer | gnnExplainer() Gets a GnnExplainer object that can explain this model's predictions. |
<ID> PgxFuture<PgxFrame> | inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) Infers the embeddings for the specified vertices. |
<ID> PgxFrame | inferLabels(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) Blocking version of inferLabelsAsync(PgxGraph, Iterable). |
<ID> PgxFrame | inferLabels(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices, float threshold) Blocking version of inferLabelsAsync(PgxGraph, Iterable, float). |
<ID> PgxFuture<PgxFrame> | inferLabelsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) Infers the labels for the specified vertices. |
<ID> PgxFuture<PgxFrame> | inferLabelsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices, float threshold) Infers the labels for the specified vertices, using the given decision threshold. |
<ID> PgxFrame | inferLogits(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) Blocking version of inferLogitsAsync(PgxGraph, Iterable). |
<ID> PgxFuture<PgxFrame> | inferLogitsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices) Infers the prediction logits for the specified vertices. |
void | store(java.lang.String path, java.lang.String key) Blocking version of storeAsync(String, String). |
void | store(java.lang.String path, java.lang.String key, boolean overwrite) Blocking version of storeAsync(String, String, boolean). |
PgxFuture<java.lang.Void> | storeAsync(java.lang.String path, java.lang.String key) Stores the GraphWise model in the specified path, encrypting it with the given key unless the key is null. |
PgxFuture<java.lang.Void> | storeAsync(java.lang.String path, java.lang.String key, boolean overwrite) Stores the GraphWise model in the specified path, encrypting it with the given key unless the key is null. |

Inherited methods: destroy, destroyAsync, fit, getBatchSize, getConfig, getConvLayerConfigs, getEdgeInputFeatureDim, getEdgeInputPropertyNames, getEmbeddingDim, getInputFeatureDim, getLearningRate, getNumEpochs, getSeed, getTrainingLoss, getVertexInputPropertyNames, inferEmbeddings, isFitted
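The evaluateLabels overloads summarized above report macro-averaged classification statistics. This page does not list the exact metrics or the column layout of the returned PgxFrame. The following plain-Java sketch therefore only illustrates what macro averaging means: precision, recall and F1 are computed per class, then averaged with equal weight per class. The class MacroAverage and its compute method are hypothetical and not part of the PGX API.

```java
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical helper, not part of PGX: macro-averaged precision, recall and F1
// for single-label classification, given the true and predicted label of each vertex.
final class MacroAverage {
    private MacroAverage() {}

    /** Returns {precision, recall, f1}, each the unweighted mean over every class seen in either list. */
    static double[] compute(List<String> actual, List<String> predicted) {
        if (actual.isEmpty() || actual.size() != predicted.size()) {
            throw new IllegalArgumentException("need two non-empty lists of equal length");
        }
        Set<String> classes = new TreeSet<>(actual);
        classes.addAll(predicted);
        double precisionSum = 0, recallSum = 0, f1Sum = 0;
        for (String c : classes) {
            int tp = 0, fp = 0, fn = 0;
            for (int i = 0; i < actual.size(); i++) {
                boolean isActual = actual.get(i).equals(c);
                boolean isPredicted = predicted.get(i).equals(c);
                if (isActual && isPredicted) tp++;
                else if (isPredicted) fp++;
                else if (isActual) fn++;
            }
            // Assumed convention: a ratio with a zero denominator counts as 0.
            double precision = (tp + fp == 0) ? 0 : (double) tp / (tp + fp);
            double recall = (tp + fn == 0) ? 0 : (double) tp / (tp + fn);
            double f1 = (precision + recall == 0) ? 0 : 2 * precision * recall / (precision + recall);
            precisionSum += precision;
            recallSum += recall;
            f1Sum += f1;
        }
        // Macro averaging: every class contributes equally, regardless of how many vertices carry it.
        int n = classes.size();
        return new double[] {precisionSum / n, recallSum / n, f1Sum / n};
    }
}
```

Consider true labels [a, a, b, b] and predictions [a, b, b, b]. Class a has precision 1 and recall 0.5. Class b has precision 2/3 and recall 1. The macro-averaged precision is therefore 5/6, the macro-averaged recall is 0.75, and the macro-averaged F1 is 11/15.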
public static final java.lang.String ALGORITHM_NAME
public SupervisedGraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor, oracle.pgx.api.internal.mllib.ModelMetadata modelMetadata)
public SupervisedGraphWiseModel(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor, oracle.pgx.api.internal.mllib.SupervisedGraphWiseModelMetadata modelMetadata)
This constructor should never be used to get a model. Use SupervisedGraphWiseModelBuilder instead.
Parameters:
session - PgxSession to which the model is connected
core - Core to which the model is connected
graphConstructor - Constructor for a PgxGraph
modelMetadata - Metadata concerning the different hyper-parameters of the GraphWise model
public <ID> PgxFrame evaluateLabels(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Blocking version of evaluateLabelsAsync(PgxGraph, Iterable). Evaluates macro-averaged classification performance statistics for the specified vertices.
Parameters:
graph - the input graph
vertices - the vertices to evaluate the model on
public <ID> PgxFrame evaluateLabels(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices, float threshold)
Blocking version of evaluateLabelsAsync(PgxGraph, Iterable, float). Evaluates macro-averaged classification performance statistics for the specified vertices.
Parameters:
graph - the input graph
vertices - the vertices to evaluate the model on
threshold - the decision threshold
public <ID> PgxFuture<PgxFrame> evaluateLabelsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
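Evaluates macro-averaged classification performance statistics for the specified vertices.
Parameters:
graph - the input graph
vertices - the vertices to evaluate the model on
public <ID> PgxFuture<PgxFrame> evaluateLabelsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices, float threshold)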
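Evaluates macro-averaged classification performance statistics for the specified vertices, using the given decision threshold.
Parameters:
graph - the input graph
vertices - the vertices to evaluate the model on
threshold - the decision threshold
public PgxFuture<java.lang.Double> fitAsync(PgxGraph graph)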
Trains the GraphWise model on the input graph.
Overrides fitAsync in class GraphWiseModel<SupervisedGraphWiseModelConfig,oracle.pgx.api.internal.mllib.SupervisedGraphWiseModelMetadata,SupervisedGraphWiseModel>
Parameters:
graph - input graph to fit on
public java.util.Map<?,java.lang.Float> getClassWeights()
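getClassWeights returns a map from class to a Float weight. This page does not spell out how the weights enter the loss. A common convention for supervised classifiers is that each weight scales the loss contribution of vertices whose true label is that class. Under that assumption, the effect on a softmax cross-entropy loss can be sketched as follows. WeightedCrossEntropy is hypothetical. For simplicity it keys the map by integer class index, whereas the real map has a wildcard key type.

```java
import java.util.Map;

// Hypothetical sketch, not the PGX implementation: softmax cross-entropy where each
// sample's loss is multiplied by the weight of its true class.
final class WeightedCrossEntropy {
    private WeightedCrossEntropy() {}

    /** Numerically stable log-softmax of one logit vector, evaluated at class index k. */
    static double logSoftmax(float[] logits, int k) {
        double max = Double.NEGATIVE_INFINITY;
        for (float z : logits) max = Math.max(max, z);
        double sumExp = 0;
        for (float z : logits) sumExp += Math.exp(z - max);
        return logits[k] - max - Math.log(sumExp);
    }

    /**
     * Mean weighted loss over a batch. labels[i] is the class index of sample i.
     * Classes missing from classWeights default to weight 1.0f (an assumption).
     */
    static double loss(float[][] logits, int[] labels, Map<Integer, Float> classWeights) {
        double total = 0;
        for (int i = 0; i < logits.length; i++) {
            float w = classWeights.getOrDefault(labels[i], 1.0f);
            total += -w * logSoftmax(logits[i], labels[i]);
        }
        // Divides by the batch size; some libraries divide by the sum of the weights instead.
        return total / logits.length;
    }
}
```

Take two classes with equal logits and a weight of 2.0 on class 0. A vertex labeled 0 contributes 2 ln 2, and a vertex labeled 1 contributes ln 2, so the batch mean is 1.5 ln 2. This page does not state which normalization PGX uses.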
public LossFunction getLossFunctionClass()
public GraphWisePredictionLayerConfig[] getPredictionLayerConfigs()
public java.util.List<java.util.Set<java.lang.String>> getTargetVertexLabels()
public java.lang.String getVertexTargetPropertyName()
public SupervisedGnnExplainer gnnExplainer()
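inferLogits returns raw prediction logits. The inferLabels and evaluateLabels overloads that take a float threshold apply a decision threshold, but this page does not describe how. The sketch below assumes a common convention. For multi-label targets, a class is predicted when the sigmoid of its logit reaches the threshold. For single-label targets, the class with the largest logit wins. LogitDecoder and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch, not the PGX implementation: turning one vertex's logits into labels.
final class LogitDecoder {
    private LogitDecoder() {}

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Multi-label decoding (assumed): keep every class whose sigmoid probability is at least the threshold. */
    static List<Integer> thresholdLabels(float[] logits, float threshold) {
        List<Integer> labels = new ArrayList<>();
        for (int k = 0; k < logits.length; k++) {
            if (sigmoid(logits[k]) >= threshold) labels.add(k);
        }
        return labels;
    }

    /** Single-label decoding: softmax is monotonic, so the largest logit gives the most probable class. */
    static int argmax(float[] logits) {
        int best = 0;
        for (int k = 1; k < logits.length; k++) {
            if (logits[k] > logits[best]) best = k;
        }
        return best;
    }
}
```

For logits {2.0, -1.0, 0.0} and threshold 0.5, thresholdLabels keeps classes 0 and 2. Class 2 is kept because sigmoid(0) is exactly 0.5 and the comparison is inclusive; whether PGX uses an inclusive comparison is an assumption.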
public <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Infers the embeddings for the specified vertices.
Overrides inferEmbeddingsAsync in class GraphWiseModel<SupervisedGraphWiseModelConfig,oracle.pgx.api.internal.mllib.SupervisedGraphWiseModelMetadata,SupervisedGraphWiseModel>
Parameters:
graph - the input graph
vertices - the vertices to produce embeddings for
public <ID> PgxFrame inferLabels(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Blocking version of inferLabelsAsync(PgxGraph, Iterable). Infers the labels for the specified vertices.
Parameters:
graph - the input graph
vertices - the vertices to produce labels for
public <ID> PgxFrame inferLabels(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices, float threshold)
Blocking version of inferLabelsAsync(PgxGraph, Iterable, float). Infers the labels for the specified vertices.
Parameters:
graph - the input graph
vertices - the vertices to produce labels for
threshold - the decision threshold
public <ID> PgxFuture<PgxFrame> inferLabelsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Infers the labels for the specified vertices.
Parameters:
graph - the input graph
vertices - the vertices to produce labels for
public <ID> PgxFuture<PgxFrame> inferLabelsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices, float threshold)
Infers the labels for the specified vertices, using the given decision threshold.
Parameters:
graph - the input graph
vertices - the vertices to produce labels for
threshold - the decision threshold
public <ID> PgxFrame inferLogits(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Blocking version of inferLogitsAsync(PgxGraph, Iterable). Infers the prediction logits for the specified vertices.
Parameters:
graph - the input graph
vertices - the vertices to produce logits for
public <ID> PgxFuture<PgxFrame> inferLogitsAsync(PgxGraph graph, java.lang.Iterable<PgxVertex<ID>> vertices)
Infers the prediction logits for the specified vertices.
Parameters:
graph - the input graph
vertices - the vertices to produce logits for
public void store(java.lang.String path, java.lang.String key) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of storeAsync(String, String). Calls storeAsync(String, String) and waits for the returned PgxFuture to complete.
Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
public void store(java.lang.String path, java.lang.String key, boolean overwrite) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
Blocking version of storeAsync(String, String, boolean). Calls storeAsync(String, String, boolean) and waits for the returned PgxFuture to complete.
Throws:
java.lang.InterruptedException - if the caller thread gets interrupted while waiting for completion.
java.util.concurrent.ExecutionException - if any exception occurred during asynchronous execution. The actual exception will be nested.
public PgxFuture<java.lang.Void> storeAsync(java.lang.String path, java.lang.String key) throws java.util.concurrent.ExecutionException, java.lang.InterruptedException
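Stores the GraphWise model in the specified path, encrypting it with the given key unless the key is null.
Parameters:
path - path to store the model
key - the encryption key, or null if no encryption should be used
Throws:
java.util.concurrent.ExecutionException
java.lang.InterruptedException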
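The store methods call the matching storeAsync overload and wait for the returned PgxFuture. The other blocking methods on this page are likewise described as blocking versions of their Async counterparts. Exceptions raised during asynchronous execution surface as an ExecutionException with the actual exception nested. The sketch below reproduces that pattern, with java.util.concurrent.CompletableFuture standing in for PgxFuture. BlockingWrapper and its null-path validation rule are hypothetical.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

// Hypothetical sketch of the blocking-over-async pattern; CompletableFuture stands in for PgxFuture.
final class BlockingWrapper {
    private BlockingWrapper() {}

    /** Stand-in for an async call such as storeAsync; fails if path is null (hypothetical rule). */
    static CompletableFuture<Void> storeAsync(String path, String key) {
        return CompletableFuture.runAsync(() -> {
            if (path == null) {
                throw new IllegalArgumentException("path must not be null");
            }
            // A real implementation would write the model here, encrypted with key when key != null.
        });
    }

    /** Blocking version: start the async call and wait; failures arrive wrapped in ExecutionException. */
    static void store(String path, String key) throws ExecutionException, InterruptedException {
        storeAsync(path, key).get();
    }
}
```

In this sketch, calling store(null, null) throws an ExecutionException. Its getCause() is the IllegalArgumentException raised inside the asynchronous task, which mirrors the documented behavior that "the actual exception will be nested".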
public PgxFuture<java.lang.Void> storeAsync(java.lang.String path, java.lang.String key, boolean overwrite)
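Stores the GraphWise model in the specified path, encrypting it with the given key unless the key is null.
Parameters:
path - path to store the model
key - the encryption key, or null if no encryption should be used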