ONNXパイプライン・モデル: 再ランク付けパイプライン

ONNXパイプライン・モデルは、テキストの特定のペアの類似度スコアを計算する再ランク付けパイプラインを提供します。

再ランク付けパイプライン

再ランク付けモデル(クロス・エンコーダまたは再ランカとも呼ばれる)は、特定のテキスト・ペアの類似度スコアを計算します。テキスト埋込みモデルのように入力テキストを固定長ベクトルにエンコードするかわりに、再ランカーは2つの入力テキストを同時にエンコードし、ペアの類似度スコアを生成します。デフォルトでは、sigmoidアクティベーション関数を適用することによって0から1までの数値に変換できるロジットが出力されます。埋込みモデルを使用して2つのテキスト間の類似度スコアを計算するには、まずテキストの埋込みを計算してからコサイン類似度を適用しますが、通常、再ランカー・モデルではより優れたパフォーマンスが得られ、埋込みモデルから上位Kの結果を再ランク付けするために使用できます。

  1. 入力: 再ランク付けパイプラインへのfirst_inputおよびsecond_inputという名前の2つの入力があります。それぞれに1つ以上のテキスト文字列の配列が含まれます。2つの入力には、同じ数のテキスト文字列がある必要があります。配列は、first_inputおよびsecond_inputのテキストのペアごとに生成されます。たとえば、次の入力ペアの2つの配列が生成されます。1つはペア'hi''hi'のもので、もう1つはペア'halloween''hello'のものです。

    {'first_input':['hi','halloween'],'second_input':['hi','hello']}

  2. 前処理:

    再ランク付けパイプラインの前処理には、トークン化が含まれます。トークナイザ・クラス: transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizerは、モデルの再ランク付けのためにOML4Py 2.1でサポートされています。

  3. 出力: 再ランク付けパイプラインの出力は、類似度スコアの配列であり、入力テキスト・ペアごとに1つずつ生成されます。前述の入力セクションの例の場合、出力の類似度スコアは、ペア'hi' および 'hi'が9.290878、ペア'halloween'および'hello'が4.3913193です。

再ランク付けパイプラインの例

  1. 再ランク付けパイプラインを生成します:
    mn = 'BAAI/bge-reranker-base'
    em = ONNXPipeline(mn, function=MiningFunction.REGRESSION)
    mf = 'bge-reranker-base'
    em.export2file(mf)
  2. 再ランカ・モデルをデータベースにインポートします:
    BEGIN 
    DBMS_VECTOR.LOAD_ONNX_MODEL('DM_DUMP','bge-reranker-base.onnx','doc_model', JSON('{"function" : "regression", 
                                "regressionOutput" : "output", "input":{"first_input": ["DATA1"],"second_input": ["DATA2"]}}'));
    END;
  3. データベースで類似度スコアを取得します:
    SELECT prediction(doc_model USING 'what is panda?' as DATA1, 'hi' as DATA2) from dual;

    前述の例では、ステップ2でインポートした再ランカーONNXパイプラインを使用して、2つのテキスト文字列間の類似度スコアが生成さます。スコアの範囲は、負の無限大から正の無限大までです。正の無限大に近いスコアは2つの文字列間の類似度が高いことを示し、負の無限大に近いスコアは文字列間の類似度が低いことを示します。sigmoid関数は、類似度スコアを[0,1]の範囲の浮動小数点値にマップできます。前述の例に対応する出力スコアの例を次に示します:

    PREDICTION(DOC_MODELUSING'WHATISPANDA?'ASDATA1,'HI'ASDATA2)
    -----------------------------------------------------------
                        -7.73E+000