17.3.18 例: Movielensデータセットの評価の予測

この項では、Movielensグラフを例として使用し、グラフサーバー(PGX)でのSupervisedEdgeWiseの使用方法について説明します。

このデータ・セットは、1682本の映画に対する943人のユーザーからの100,000の評価(1-5)で構成され、ユーザー(年齢、性別、職業)および映画(年、増悪、ジャンル)に関する単純な統計情報があります。ユーザーおよび映画は頂点ですが、映画に対するユーザーの評価は評価特徴を持つエッジです。

次の例では、SupervisedEdgeWiseモデルを使用して評価を予測します。まずモデルが構築された後、trainGraphに適合されます。

opg4j> import oracle.pgx.config.mllib.loss.LossFunctions
opg4j> var convLayer = analyst.graphWiseConvLayerConfigBuilder().
        setNumSampledNeighbors(10).
        build()
opg4j> var predictionLayer = analyst.graphWisePredictionLayerConfigBuilder().
        setHiddenDimension(16).
        build()
opg4j> var model = analyst.supervisedEdgeWiseModelBuilder().
        setVertexInputPropertyNames("movie_year", "avg_rating", "movie_genres", // Movies features
            "user_occupation_label", "user_gender", "raw_user_age"). // Users features
        setEdgeTargetPropertyName("user_rating").
        setConvLayerConfigs(convLayer).
        setPredictionLayerConfigs(predictionLayer).
        setNumEpochs(10).
        setEmbeddingDim(32).
        setLearningRate(0.003).
        setStandardize(true).
        setNormalize(true).
        setSeed(0).
        setLossFunction(LossFunctions.MSE_LOSS).
        build()
opg4j> model.fit(trainGraph)
import oracle.pgx.config.mllib.loss.LossFunctions;
GraphWiseConvLayerConfig convLayer = analyst.graphWiseConvLayerConfigBuilder()
        .setNumSampledNeighbors(10)
        .build();

GraphWisePredictionLayerConfig predictionLayer = analyst.graphWisePredictionLayerConfigBuilder()
      .setHiddenDimension(16)
      .build();

SupervisedEdgeWiseModel model = analyst.supervisedEdgeWiseModelBuilder()
        .setVertexInputPropertyNames("movie_year", "avg_rating", "movie_genres", // Movies features
            "user_occupation_label", "user_gender", "raw_user_age") // Users features
        .setEdgeTargetPropertyName("user_rating")
        .setConvLayerConfigs(convLayer)
        .setPredictionLayerConfigs(predictionLayer)
        .setNumEpochs(10)
        .setEmbeddingDim(32)
        .setLearningRate(0.003)
        .setStandardize(true)
        .setNormalize(true)
        .setSeed(0)
        .setLossFunction(LossFunctions.MSE_LOSS)
        .build();

model.fit(trainGraph);
from pypgx.api.mllib import MSELoss
conv_layer_config = dict(num_sampled_neighbors=10)

conv_layer = analyst.graphwise_conv_layer_config(**conv_layer_config)

pred_layer_config = dict(hidden_dim=16)

pred_layer = analyst.graphwise_pred_layer_config(**pred_layer_config)

params = dict(edge_target_property_name="labels",
              conv_layer_config=[conv_layer],
              pred_layer_config=[pred_layer],
              vertex_input_property_names=["movie_year", "avg_rating", "movie_genres",
                "user_occupation_label", "user_gender", "raw_user_age"],
              edge_input_property_names=["user_rating"],
              num_epochs=10,
              layer_size=32,
              learning_rate=0.003,
              normalize=true,
              loss_fn=MSELoss(),
              seed=0)

model = analyst.supervised_edgewise_builder(**params)

model.fit(train_graph)

EdgeWiseは帰納的であるため、表示されないエッジの評価を推測できます:

opg4j> var labels = model.infer(fullGraph, testEdges)
opg4j> labels.head().print()
PgxFrame labels = model.infer(fullGraph, testEdges);
labels.head().print();
labels = model.infer(full_graph, test_edges)
labels.print()

これにより、エッジの評価予測が次のように返されます:

+-----------------------------+
| edgeId | value              |
+-----------------------------+
| 68472  | 3.844510078430176  |
| 53436  | 3.5453758239746094 |
| 73364  | 3.688265085220337  |
| 12096  | 3.8873679637908936 |
| 78740  | 3.3845553398132324 |
| 27664  | 2.6601722240448    |
| 34844  | 4.108948230743408  |
| 74224  | 3.7714107036590576 |
| 33744  | 3.2331383228302    |
| 32812  | 3.8763082027435303 |
+-----------------------------+

また、モデルのパフォーマンスを評価することもできます:

opg4j> model.evaluate(fullGraph, testEdges).print()
model.evaluate(fullGraph,testEdges).print();
model.evaluate(full_graph,test_edges).print()

これにより、次の出力が返されます:

+--------------------+
| MSE                |
+--------------------+
| 0.9573243436116953 |
+--------------------+