public class SupervisedGraphWiseModelConfig extends GraphWiseModelConfig
Model configuration for SupervisedGraphWiseModel. See SupervisedGraphWiseModel
for a description of the hyperparameters.
Modifier and Type | Class and Description |
---|---|
static class |
oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction
Deprecated.
|
Nested classes inherited from GraphWiseModelConfig: GraphWiseModelConfig.Backend, GraphWiseModelConfig.GraphConvModelVariant
Modifier and Type | Field and Description |
---|---|
static BatchGenerator |
DEFAULT_BATCH_GENERATOR
|
static java.util.Map<?,java.lang.Float> |
DEFAULT_CLASS_MAP
null
|
static oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction |
DEFAULT_LOSS_FUNCTION
Deprecated.
|
static LossFunction |
DEFAULT_LOSS_FUNCTION_CLASS
|
static boolean |
DEFAULT_NORMALIZE
true
|
static GraphWisePredictionLayerConfig[] |
DEFAULT_PREDICTION_LAYER_CONFIGS
one default initialized config (See
GraphWisePredictionLayerConfig) |
static java.util.EnumSet<PropertyType> |
SUPPORTED_LABEL_TYPES
INTEGER, STRING, BOOLEAN, LONG
|
Inherited fields: DEFAULT_BACKEND, DEFAULT_BATCH_SIZE, DEFAULT_CONV_LAYER_CONFIGS, DEFAULT_EMBEDDING_DIM, DEFAULT_LEARNING_RATE, DEFAULT_MODE, DEFAULT_NUM_EPOCHS, DEFAULT_SEED, DEFAULT_SHUFFLE, DEFAULT_STANDARDIZE, DEFAULT_WEIGHT_DECAY, SUPPORTED_INPUT_TYPES
Constructor and Description |
---|
SupervisedGraphWiseModelConfig() |
SupervisedGraphWiseModelConfig(int batchSize, int numEpochs, double learningRate, double weightDecay, int embeddingDim, java.lang.Integer seed, GraphWiseConvLayerConfig[] convLayerConfigs, boolean standardize, boolean shuffle, java.util.List<java.lang.String> vertexInputPropertyNames, java.util.List<java.lang.String> edgeInputPropertyNames, java.util.List<java.util.Set<java.lang.String>> targetVertexLabelSets, boolean fitted, double trainingLoss, int inputFeatureDim, int edgeInputFeatureDim, oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction lossFunction, LossFunction lossFunctionClass, BatchGenerator batchGenerator, GraphWisePredictionLayerConfig[] predictionLayerConfigs, boolean normalize, java.lang.String vertexTargetPropertyName, LabelMaps labelMaps, GraphWiseModelConfig.Backend backend, GraphWiseModelConfig.GraphConvModelVariant variant) |
SupervisedGraphWiseModelConfig(SupervisedGraphWiseModelConfig source) |
Modifier and Type | Method and Description |
---|---|
BatchGenerator |
getBatchGenerator() |
java.util.Map<?,java.lang.Integer> |
getClassMap() |
java.util.Map<?,java.lang.Float> |
getClassWeights() |
LabelMaps |
getLabelMaps() |
PropertyType |
getLabelType() |
oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction |
getLossFunction()
Deprecated.
|
LossFunction |
getLossFunctionClass() |
int |
getNumClasses() |
GraphWisePredictionLayerConfig[] |
getPredictionLayerConfigs() |
java.lang.String |
getVertexTargetPropertyName() |
boolean |
isNormalize() |
void |
setBatchGenerator(BatchGenerator batchGenerator) |
void |
setClassMap(java.util.Map<?,java.lang.Integer> classMap) |
void |
setClassWeights(java.util.Map<?,java.lang.Float> classWeights) |
void |
setLabelMaps(LabelMaps labelMaps) |
void |
setLabelType(PropertyType labelType) |
void |
setLossFunction(oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction lossFunction)
Deprecated.
|
void |
setLossFunctionClass(LossFunction lossFunction) |
void |
setNormalize(boolean normalize) |
void |
setPredictionLayerConfigs(GraphWisePredictionLayerConfig... predictionLayerConfigs) |
void |
setVertexTargetPropertyName(java.lang.String vertexTargetPropertyName) |
Inherited methods: getBackend, getBatchSize, getConvLayerConfigs, getEdgeInputFeatureDim, getEdgeInputPropertyNames, getEmbeddingDim, getInputFeatureDim, getLearningRate, getNumEpochs, getSeed, getTargetVertexLabelSets, getTrainingLoss, getVariant, getVertexInputPropertyNames, getWeightDecay, isFitted, isShuffle, isStandardize, setBatchSize, setConvLayerConfigs, setEdgeInputFeatureDim, setEdgeInputPropertyNames, setEmbeddingDim, setFitted, setInputFeatureDim, setLearningRate, setNumEpochs, setSeed, setShuffle, setStandardize, setTargetVertexLabels, setTargetVertexLabelSets, setTrainingLoss, setVariant, setVertexInputPropertyNames, setWeightDecay
public static final BatchGenerator DEFAULT_BATCH_GENERATOR
public static final java.util.Map<?,java.lang.Float> DEFAULT_CLASS_MAP
@Deprecated public static final oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction DEFAULT_LOSS_FUNCTION
SupervisedGraphWiseModelConfig.LossFunction.SOFTMAX_CROSS_ENTROPY
public static final LossFunction DEFAULT_LOSS_FUNCTION_CLASS
public static final boolean DEFAULT_NORMALIZE
public static final GraphWisePredictionLayerConfig[] DEFAULT_PREDICTION_LAYER_CONFIGS
one default initialized config (See GraphWisePredictionLayerConfig)
public static final java.util.EnumSet<PropertyType> SUPPORTED_LABEL_TYPES
public SupervisedGraphWiseModelConfig()
public SupervisedGraphWiseModelConfig(int batchSize, int numEpochs, double learningRate, double weightDecay, int embeddingDim, java.lang.Integer seed, GraphWiseConvLayerConfig[] convLayerConfigs, boolean standardize, boolean shuffle, java.util.List<java.lang.String> vertexInputPropertyNames, java.util.List<java.lang.String> edgeInputPropertyNames, java.util.List<java.util.Set<java.lang.String>> targetVertexLabelSets, boolean fitted, double trainingLoss, int inputFeatureDim, int edgeInputFeatureDim, oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction lossFunction, LossFunction lossFunctionClass, BatchGenerator batchGenerator, GraphWisePredictionLayerConfig[] predictionLayerConfigs, boolean normalize, java.lang.String vertexTargetPropertyName, LabelMaps labelMaps, GraphWiseModelConfig.Backend backend, GraphWiseModelConfig.GraphConvModelVariant variant)
public SupervisedGraphWiseModelConfig(SupervisedGraphWiseModelConfig source)
public BatchGenerator getBatchGenerator()
public java.util.Map<?,java.lang.Integer> getClassMap()
public java.util.Map<?,java.lang.Float> getClassWeights()
public LabelMaps getLabelMaps()
public PropertyType getLabelType()
@Deprecated public oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction getLossFunction()
public LossFunction getLossFunctionClass()
public int getNumClasses()
public GraphWisePredictionLayerConfig[] getPredictionLayerConfigs()
public java.lang.String getVertexTargetPropertyName()
public boolean isNormalize()
public final void setBatchGenerator(BatchGenerator batchGenerator)
public final void setClassMap(java.util.Map<?,java.lang.Integer> classMap)
public final void setClassWeights(java.util.Map<?,java.lang.Float> classWeights)
public final void setLabelMaps(LabelMaps labelMaps)
public final void setLabelType(PropertyType labelType)
@Deprecated public final void setLossFunction(oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig.LossFunction lossFunction)
public final void setLossFunctionClass(LossFunction lossFunction)
public final void setNormalize(boolean normalize)
public final void setPredictionLayerConfigs(GraphWisePredictionLayerConfig... predictionLayerConfigs)
public final void setVertexTargetPropertyName(java.lang.String vertexTargetPropertyName)