Package oracle.pgx.api.mllib
Class SupervisedGraphWiseModelBuilder
java.lang.Object
    oracle.pgx.api.mllib.WiseModelBuilder<Config,Self>
        oracle.pgx.api.mllib.GraphWiseModelBuilder<SupervisedGraphWiseModel,SupervisedGraphWiseModelConfig,SupervisedGraphWiseModelBuilder>
            oracle.pgx.api.mllib.SupervisedGraphWiseModelBuilder
public class SupervisedGraphWiseModelBuilder extends GraphWiseModelBuilder<SupervisedGraphWiseModel,SupervisedGraphWiseModelConfig,SupervisedGraphWiseModelBuilder>
Builder for SupervisedGraphWiseModel.
The builder is used to set the configuration of the model and to create the model object.
Since:
    19.4
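The class description above amounts to a set-then-build pattern: each setter stores one configuration value and returns the builder, and build() turns the collected settings into a model object. The standalone Java below is a minimal sketch of that pattern, assuming hypothetical ToyModelBuilder and ToyModelConfig classes. They are not part of oracle.pgx.api.mllib, they only borrow a few setter names, and their defaults are illustrative rather than PGX defaults.

```java
import java.util.Objects;

/** Hypothetical immutable configuration produced by the sketch builder (not a PGX class). */
final class ToyModelConfig {
    final String targetProperty;
    final int numEpochs;
    final double learningRate;

    ToyModelConfig(String targetProperty, int numEpochs, double learningRate) {
        this.targetProperty = targetProperty;
        this.numEpochs = numEpochs;
        this.learningRate = learningRate;
    }
}

/** Hypothetical fluent builder illustrating set-then-build; not the PGX API. */
final class ToyModelBuilder {
    // Illustrative defaults only; these are NOT the PGX defaults.
    private String targetProperty;
    private int numEpochs = 10;
    private double learningRate = 0.001;

    ToyModelBuilder setVertexTargetPropertyName(String propertyName) {
        this.targetProperty = propertyName;
        return this; // returning this enables call chaining
    }

    ToyModelBuilder setNumEpochs(int numEpochs) {
        if (numEpochs <= 0) {
            throw new IllegalArgumentException("numEpochs must be positive");
        }
        this.numEpochs = numEpochs;
        return this;
    }

    ToyModelBuilder setLearningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    /** Validates the collected settings and freezes them into an immutable config. */
    ToyModelConfig build() {
        Objects.requireNonNull(targetProperty, "a target property must be set before build()");
        return new ToyModelConfig(targetProperty, numEpochs, learningRate);
    }
}
```

Because every setter returns this, calls chain, for example new ToyModelBuilder().setVertexTargetPropertyName("label").setNumEpochs(5).build(). Deferring the required-field check to build() lets the setters be called in any order.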
Constructor Summary
Constructor
SupervisedGraphWiseModelBuilder(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor)
Method Summary
Modifier and Type                  Method and Description
SupervisedGraphWiseModel           build()
                                   Builds the SupervisedGraphWise model with the specified parameters.
SupervisedGraphWiseModelBuilder    setBatchGenerator(BatchGenerator batchGenerator)
                                   Sets the batch generator.
SupervisedGraphWiseModelBuilder    setClassWeights(java.util.Map<?,java.lang.Float> classWeights)
                                   Sets the class weights to be used in the loss function.
SupervisedGraphWiseModelBuilder    setLossFunction(LossFunction lossFunction)
                                   Sets the loss function for the algorithm.
SupervisedGraphWiseModelBuilder    setPredictionLayerConfigs(GraphWisePredictionLayerConfig... layerConfigs)
                                   Sets the prediction layer configurations (see GraphWisePredictionLayerConfig).
SupervisedGraphWiseModelBuilder    setVertexTargetPropertyName(java.lang.String propertyName)
                                   Sets the target (labels) for the algorithm in the form of a property name of the graph.
Methods inherited from class oracle.pgx.api.mllib.GraphWiseModelBuilder
setTargetVertexLabels, setTargetVertexLabels

Methods inherited from class oracle.pgx.api.mllib.WiseModelBuilder
setBatchSize, setConvLayerConfigs, setEdgeInputPropertyConfigs, setEdgeInputPropertyNames, setEdgeInputPropertyNames, setEmbeddingDim, setEnableAccelerator, setLearningRate, setNormalize, setNumEpochs, setSeed, setShuffle, setStandardize, setVertexInputPropertyConfigs, setVertexInputPropertyNames, setVertexInputPropertyNames, setWeightDecay
Constructor Detail
SupervisedGraphWiseModelBuilder
public SupervisedGraphWiseModelBuilder(PgxSession session, oracle.pgx.api.internal.Core core, java.util.function.Supplier<java.lang.String> keystorePathSupplier, java.util.function.Supplier<char[]> keystorePasswordSupplier, java.util.function.BiFunction<PgxSession,oracle.pgx.api.internal.Graph,PgxGraph> graphConstructor)
Method Detail
setVertexTargetPropertyName
public SupervisedGraphWiseModelBuilder setVertexTargetPropertyName(java.lang.String propertyName)
Sets the target (labels) for the algorithm, given as the name of a vertex property of the graph. Supported property types are listed in SupervisedGraphWiseModelConfig.SUPPORTED_LABEL_TYPES.
Parameters:
    propertyName - name of the vertex property that holds the labels
Returns:
    this
Since:
    19.4
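The labels read from the target property become the classes the model predicts, and the output layer described under setPredictionLayerConfigs has one unit per class. Assuming each distinct label value forms one class (an assumption; the exact encoding PGX uses is not described on this page), the sketch below maps label values to dense class indices with a hypothetical LabelEncoding helper that is not part of the PGX API.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;

/** Hypothetical label encoder; illustrates one way labels could become class indices. */
final class LabelEncoding {
    private LabelEncoding() {}

    /** Assigns each distinct label a dense index 0..k-1, in sorted label order. */
    static <T extends Comparable<T>> Map<T, Integer> classIndex(Collection<T> labels) {
        Map<T, Integer> index = new TreeMap<>();
        for (T label : new TreeSet<>(labels)) {
            index.put(label, index.size()); // size() is read before the put, so indices start at 0
        }
        return index;
    }

    /** Replaces each per-vertex label value with its class index. */
    static <T extends Comparable<T>> List<Integer> encode(List<T> labels, Map<T, Integer> index) {
        List<Integer> encoded = new ArrayList<>(labels.size());
        for (T label : labels) {
            Integer idx = index.get(label);
            if (idx == null) {
                throw new IllegalArgumentException("unknown label: " + label);
            }
            encoded.add(idx);
        }
        return encoded;
    }
}
```

The size of the returned map is the number of classes, which in this sketch would size the final output layer.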
setPredictionLayerConfigs
public SupervisedGraphWiseModelBuilder setPredictionLayerConfigs(GraphWisePredictionLayerConfig... layerConfigs)
Sets the prediction layer configurations (see GraphWisePredictionLayerConfig). You must pass at least one prediction layer config, or leave the default in place. Note that an additional layer is appended at the end, for which:
- the activation function will be replaced with the activation function of the loss function, e.g. softmax or sigmoid
- the hidden dimension will be equal to the number of classes
- the weight initialization scheme will be copied from the previous layer
Default:
    SupervisedGraphWiseModelConfig.DEFAULT_PREDICTION_LAYER_CONFIGS
Parameters:
    layerConfigs - layer configs
Returns:
    this
Since:
    19.4
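The extra output layer described above can be illustrated with a small standalone sketch. LayerSpec and PredictionStack are hypothetical stand-ins, not GraphWisePredictionLayerConfig, and activations and weight-initialization schemes are plain strings for brevity. The code enforces the at-least-one-layer rule and appends a layer whose hidden dimension is the number of classes, whose activation is the loss function's activation, and whose weight initialization is copied from the previous layer.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical prediction layer description; not the PGX GraphWisePredictionLayerConfig. */
final class LayerSpec {
    final int hiddenDim;
    final String activation;
    final String weightInit;

    LayerSpec(int hiddenDim, String activation, String weightInit) {
        this.hiddenDim = hiddenDim;
        this.activation = activation;
        this.weightInit = weightInit;
    }

    @Override
    public String toString() {
        return hiddenDim + "/" + activation + "/" + weightInit;
    }
}

/** Hypothetical helper that mimics the documented output-layer rule. */
final class PredictionStack {
    private PredictionStack() {}

    static List<LayerSpec> withOutputLayer(List<LayerSpec> userLayers, int numClasses, String lossActivation) {
        if (userLayers.isEmpty()) {
            throw new IllegalArgumentException("at least one prediction layer config is required");
        }
        List<LayerSpec> layers = new ArrayList<>(userLayers);
        LayerSpec previous = layers.get(layers.size() - 1);
        // One unit per class, the loss function's activation, and the previous layer's weight init.
        layers.add(new LayerSpec(numClasses, lossActivation, previous.weightInit));
        return Collections.unmodifiableList(layers);
    }
}
```

For a single user layer of 32 units and three classes, the resulting stack has two layers, the second being 3 units wide with the loss function's activation.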
setClassWeights
public SupervisedGraphWiseModelBuilder setClassWeights(java.util.Map<?,java.lang.Float> classWeights)
Sets the class weights to be used in the loss function. The loss for each class is multiplied by the weight given for that class in this map. If null, uniform class weights are used.
Parameters:
    classWeights - map from classes to weights
Returns:
    this
Since:
    19.4
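The weighting rule above can be sketched as follows, assuming per-sample losses are already computed and the batch loss is their plain mean (the aggregation PGX actually uses is an assumption here). ClassWeighting is a hypothetical helper. A null map yields uniform weights of 1.0; a class missing from a non-null map also gets 1.0, which is this sketch's choice rather than documented behavior.

```java
import java.util.Map;

/** Hypothetical helper showing class-weighted loss; not part of the PGX API. */
final class ClassWeighting {
    private ClassWeighting() {}

    /** Weight for a sample's true class; a null map means uniform weights. */
    static float weightFor(Map<?, Float> classWeights, Object label) {
        if (classWeights == null) {
            return 1.0f;
        }
        Float w = classWeights.get(label);
        return w == null ? 1.0f : w; // assumption: unlisted classes keep weight 1.0
    }

    /** Mean of per-sample losses, each multiplied by the weight of its true class. */
    static double weightedMeanLoss(double[] perSampleLoss, Object[] labels, Map<?, Float> classWeights) {
        if (perSampleLoss.length == 0 || perSampleLoss.length != labels.length) {
            throw new IllegalArgumentException("need one non-empty label per loss value");
        }
        double total = 0.0;
        for (int i = 0; i < perSampleLoss.length; i++) {
            total += weightFor(classWeights, labels[i]) * perSampleLoss[i];
        }
        return total / perSampleLoss.length;
    }
}
```

Giving a rare class a weight above 1.0 makes its misclassifications cost more, which is the usual reason to set class weights on imbalanced labels.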
setLossFunction
public SupervisedGraphWiseModelBuilder setLossFunction(LossFunction lossFunction)
Sets the loss function for the algorithm (see LossFunction).
Default:
    SoftmaxCrossEntropyLoss
Parameters:
    lossFunction - loss function
Returns:
    this
Since:
    21.3
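The default, SoftmaxCrossEntropyLoss, scores one sample with logits z and true class y as -log(softmax(z)[y]), which equals log(sum_j exp(z_j)) - z_y. The standalone sketch below computes that quantity with the usual max-shift for numerical stability. SoftmaxCrossEntropy is a hypothetical helper, not the PGX class; it only shows what the default loss measures.

```java
/** Hypothetical helper illustrating softmax cross-entropy; not the PGX SoftmaxCrossEntropyLoss. */
final class SoftmaxCrossEntropy {
    private SoftmaxCrossEntropy() {}

    private static double max(double[] values) {
        double m = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            m = Math.max(m, v);
        }
        return m;
    }

    /** Softmax probabilities, shifted by the largest logit so exp() cannot overflow. */
    static double[] softmax(double[] logits) {
        double m = max(logits);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - m);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    /** Loss for one sample: log(sum_j exp(z_j)) - z_target, via log-sum-exp. */
    static double loss(double[] logits, int target) {
        if (target < 0 || target >= logits.length) {
            throw new IllegalArgumentException("target class out of range");
        }
        double m = max(logits);
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - m);
        }
        return m + Math.log(sumExp) - logits[target];
    }
}
```

With two equal logits the model is maximally unsure between two classes, and the loss is log(2).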
setBatchGenerator
public SupervisedGraphWiseModelBuilder setBatchGenerator(BatchGenerator batchGenerator)
Sets the batch generator (see BatchGenerator).
Default:
    StandardBatchGenerator
Parameters:
    batchGenerator - batch generator
Returns:
    this
Since:
    21.3
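A batch generator decides which vertices make up each training batch. Assuming StandardBatchGenerator amounts to plain uniform batching (an assumption; its internals are not described here), the idea can be sketched with a hypothetical UniformBatcher. It optionally shuffles with a fixed seed and cuts the vertex IDs into consecutive batches of batchSize, in the spirit of the inherited setBatchSize, setShuffle and setSeed settings.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical uniform batcher; not the PGX StandardBatchGenerator. */
final class UniformBatcher {
    private UniformBatcher() {}

    /** Optionally shuffles the IDs, then splits them into batches of batchSize (the last may be smaller). */
    static <T> List<List<T>> batches(List<T> ids, int batchSize, boolean shuffle, long seed) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<T> order = new ArrayList<>(ids);
        if (shuffle) {
            Collections.shuffle(order, new Random(seed)); // same seed gives the same order
        }
        List<List<T>> result = new ArrayList<>();
        for (int start = 0; start < order.size(); start += batchSize) {
            int end = Math.min(start + batchSize, order.size());
            result.add(new ArrayList<>(order.subList(start, end)));
        }
        return result;
    }
}
```

A custom generator would replace the uniform split with a different sampling rule, for example one that changes how often each class appears per batch.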
build
public SupervisedGraphWiseModel build() throws java.lang.InterruptedException, java.util.concurrent.ExecutionException
Builds the SupervisedGraphWise model with the specified parameters.
Specified by:
    build in class GraphWiseModelBuilder<SupervisedGraphWiseModel,SupervisedGraphWiseModelConfig,SupervisedGraphWiseModelBuilder>
Returns:
    SupervisedGraphWise model
Throws:
    java.lang.InterruptedException
    java.util.concurrent.ExecutionException
Since:
    19.4
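Because build() declares the checked exceptions InterruptedException and ExecutionException, callers must handle both. A common Java idiom, sketched below with a hypothetical BuildCalls helper (not part of the PGX API), restores the thread's interrupt flag on InterruptedException and unwraps the underlying cause of an ExecutionException before rethrowing it as an unchecked exception.

```java
import java.util.concurrent.ExecutionException;

/** Hypothetical caller-side helper for a build() with build()'s checked exceptions. */
final class BuildCalls {
    private BuildCalls() {}

    /** Functional type matching a build() that throws the two documented checked exceptions. */
    @FunctionalInterface
    interface CheckedBuild<T> {
        T build() throws InterruptedException, ExecutionException;
    }

    static <T> T buildOrThrow(CheckedBuild<T> builder) {
        try {
            return builder.build();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve interrupt status for code further up
            throw new IllegalStateException("model build was interrupted", e);
        } catch (ExecutionException e) {
            // The ExecutionException only wraps the real failure; surface that cause instead.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            throw new IllegalStateException("model build failed: " + cause.getMessage(), cause);
        }
    }
}
```

With a real builder the call would read SupervisedGraphWiseModel model = BuildCalls.buildOrThrow(builder::build);, since build() matches the CheckedBuild signature. Wrapping in IllegalStateException is a choice made for this sketch; declaring the checked exceptions on the calling method is equally valid.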