17.3.11 Supervised EdgeWiseモデルのエッジ・ラベルの推測
任意のグラフ(トレーニング中に表示されなかったエッジまたはグラフを含む)のエッジ・ラベルを推測できます。
opg4j> var labels = model.infer(fullGraph, testEdges)
opg4j> labels.head().print()
PgxFrame labels = model.infer(fullGraph, testEdges);
labels.head().print();
labels = model.infer(full_graph,test_edges)
labels.print()
lossがSigmoidCrossEntropy
またはDevNetLoss
の場合、追加パラメータとして追加することで、ロジットに適用される決定しきい値を設定することもできます(デフォルトは0)。
opg4j> var labels = model.infer(fullGraph, testEdges, 6f)
opg4j> labels.head().print()
PgxFrame labels = model.infer(fullGraph,testEdges,6f);
labels.head().print();
labels = model.infer(full_graph, full_graph.get_edges(), 6)
labels.print()
出力は、次の出力例のようになります。
+-----------------------------+
| edgeId | value |
+-----------------------------+
| 68472 | 2.2346956729888916 |
| 53436 | 2.1515913009643555 |
| 73364 | 1.9499346017837524 |
| 12096 | 2.1704165935516357 |
| 78740 | 2.1174447536468506 |
| 27664 | 2.1041007041931152 |
| 34844 | 2.148571491241455 |
| 74224 | 2.089123010635376 |
| 33744 | 2.0866644382476807 |
| 32812 | 2.0604987144470215 |
+-----------------------------+
同様に、タスクが分類タスクの場合、予測ロジットを推測することで、各クラスのモデルの信頼度を取得できます:
opg4j> var logits = model.inferLogits(fullGraph, testEdges)
opg4j> logits.head().print()
PgxFrame logits = model.inferLogits(fullGraph,testEdges);
logits.head().print();
logits = model.infer_logits(full_graph, test_edges)
logits.print()
モデルが分類モデルの場合、inferLabels
メソッドも使用可能で、infer
メソッドと同等です。