モデルの説明
slate.125m.english.rtrvrモデルは、バイエンコーダをベースにした標準的な文変換モデルである。 このモデルは、指定された入力 (照会、パッセージ、文書など) の組み込みを生成します。 大まかには、テキスト A (照会テキスト) とテキスト B (パッセージ・テキスト) など、テキストの 2 つの入力部分の間のコサインの類似性を最大化するようにモデルがトレーニングされています。これにより、センテンス埋め込み q と p が生成されます。 これらの文の埋め込みは、コサインの類似性を使用して比較することができます。
図1: 取得のための bi-encoder 埋め込みモデル
基本言語モデル
組み込みのための基礎となる言語モデル (LM) は、 slate.125m.english (旧称 WatBERT) です。 これは、 RoBERTa 基本変圧器モデルと同じアーキテクチャーを持ち、約 1 億 2500 万個のパラメーターと 768 の埋め込みディメンションを持ちます。 最後のモデルは “slate.125m.english.rtrvr” と呼ばれます。末尾の接尾部は、取得ベースのタスク用に基礎となるモデル・アーキテクチャーを微調整することを示しています。
トレーニング・アルゴリズム
最先端の、あるいはMTEBリーダーボードの上位に位置するエンベッディング・モデルのほとんどは、通常3段階でトレーニングされる:
- タスク固有 (取得ベース) のプリトレーニング
- マイニングされたペアに対するタスク固有の微調整
- 監視対象ペアの微調整。
異なるトレーニング済みモデルの重みを平均化することで、同じアプローチに従い、最終的にモデル融合を実行します。
slate.125m.english.rtrvr は、「モデル融合」を実行することによって生成されます。これは、以下のモデルの重みを平均化します。両方とも段階的にトレーニングされますが、以下のバリエーションがあります。
- 大規模な非監視データを使用してモデル 1 を微調整
- モデル 2 は、監視対象データのより小さいサブセットを使用して微調整されます。
タスク固有の事前トレーニング
このステージでは、 RetroMAE フレームワークを使用して、基礎となる LM をより取得指向にします。 slate.125m.english を使用してベース LM を初期化し、表 1 のデータを使用して RetroMAE の事前トレーニングを続行します。 IBM のハイパーパラメーターは、学習速度: 2e-5、ステップ数: 190000、GPU: 24 A100 40GBです。 注: これは、以下の 2 つのステージの基本 LM です。
Model1: 大規模で監視されていないデータを使用した調整
このモデルは、事前にトレーニングされた RetroMAE モデルで初期化され、2 段階でトレーニングされます。
ステージ 1: 教師なしの微調整
図 1 のように、組み込みモデルをトレーニングするためにバイエンコーダー・フレームワークを使用します。 RetroMAE の事前トレーニングされた LM は、対比的損失目標を使用して、 <query, passage>
テキスト・ペアで微調整されます。 表 2 に示すように、さまざまな領域から大規模なペアをマイニングします。 このモデルは、前提と対応する仮説を突き合わせることで構成される NLI (自然言語推論) などの分類タスクを含む、さまざまなペアを使用してトレーニングされます。 ハイパーパラメーターは次のとおりです。学習速度: 2e-5; ステップ数: 140000; GPU: 8 A100_80GB、有効なバッチ・サイズ: 4096 ペア
ステージ 2: 教師あり学習の微調整
最後に、以下のデータセットにおいて、検索タスクのための高品質な教師あり訓練ペアを用いてモデルを微調整する:SQuAD, Natural Questions、Specter、Stack Exchange (Title, Body)のペア、S2ORC、SearchQA, HotpotQA、Fever。 トレーニング・ハイパーパラメーターの学習速度は、 2e-5; ステップ数: 10000; GPU: 8 A100_80GB、有効なバッチ・サイズ: 4096 ペアです。
モデル 2: よりタスク中心のサブセットを使用した調整
このステージでは、 RetroMAE の事前トレーニング・モデルは、ハード・ネガティブ・マイニングによる監督により、 Table2 の小さいサブセットを使用して、監視された微調整を受けます。 中間モデル・チェックポイントは、データ・セット固有のハード・ネガティブをマイニングするために繰り返し使用され、その後、監視された詳細化に使用されます。 このプロセスは、モデルが自身の誤りから学習できるようにすることで、モデルをより堅固にすることを目的としており、はるかに小さいデータで安定化するのに役立ちます。
Table2に挙げたデータセットのサブセット(保留されたデータセットで検証実験を行うことで発見された)を使ってモデルを微調整する:AllNLI, Squad、Stackexchange、NQ、HotpotQA, Fever、および5M Specter、S2orc 、WikiAnswers の各サブセットです。
トレーニング・ハイパーパラメーターの学習速度: 2e-5; 最大照会長: 512、最大パッセージ長: 512、エポック: 2、実効バッチ・サイズ: 384 トリプル、GPU: 24 A100_80GB。
最終モデル: slate.125m.english.rtrvr: モデル融合
大規模な教師なしデータで学習したモデルの重みと、より小さな教師ありデータのサブセットで学習したモデルの重みを平均化することで、モデル・フュージョンを実行する。
開発セット (https://huggingface.co/datasets/colbertv2/lotte) を使用し、グリッド検索を実行して、これらのモデルの最適な重みの組み合わせを取得します。 最適なパラメーターに基づいてモデルの重みを平均します。 Model1 の場合は 0.7 、 Model2の場合は 0.3 です。
データの学習
データ・セット | パッセージ |
---|---|
ウィキペディア | 36396918 |
ブック・コーパス | 3401308 |
スタック交換 | 15999837 |
データ・セット | ペア |
---|---|
SPECTER サイテーション・トリプレット | 684100 |
スタック交換の重複する質問 (タイトル) | 304525 |
AllNLI (SNLI および MultiNLI) | 277230 |
スタック交換の重複する質問 (本文) | 250519 |
スタック交換の重複する質問 (タイトル + 本文) | 250460 |
自然問題 (NQ) | 100231 |
SQuAD2.0 | 87599 |
PAQ (質問、回答) ペア | 64371441 |
スタック交換 (タイトル、回答) のペア | 4067139 |
スタック交換 (タイトル、本文) のペア | 23978013 |
スタック交換 (タイトル + 本文、回答) のペア | 187195 |
S2ORC 引用のペア (タイトル) | 52603982 |
S2ORC (タイトル、要約) | 41769185 |
S2ORC_citations_abstracts | 52603982 |
WikiAnswers 重複する質問のペア | 77427422 |
SearchQA | 582261 |
HotpotQA | 85000 |
熱 | 109810 |
アル14号 | 2358545 |
ウィキペディア | 20745403 |
PubMed | 20000000 |
使用法
# Make sure the sentence transformers is installed.
pip install -U sentence-transformers
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('path_to_slate_model')
input_queries = [
'Who made the son My achy breaky heart ?',
'summit define']
input_passages = [
"Achy Breaky Heart is a country song written by Don Von Tress. Originally titled Don't Tell My Heart and performed by The Marcy Brothers in 1991",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."]
query_embeddings = model.encode(input_queries)
passage_embeddings = model.encode(input_passages)
print(util.cos_sim(query_embeddings, passage_embeddings))
このモデルの最大シーケンス長は 512 トークンです。
評価
ベースライン
公平な比較のために、以下のベースラインと比較します。
- BM25 (tf-idf に基づく従来型のモデル)。
- ELSER (Elastic が提供する商用検索アルゴリズム)。
- all-MiniLM-l6-v2: 一般的なオープン・ソースのセンテンス・トランスフォーマー・モデル。 このモデルは、 slate.125m.english.rtvrと同じアーキテクチャーを共有しており、埋め込みディメンションが小さく、商用ライセンスなしでより多くのデータについてトレーニングされています。 詳しくは、Huggingface モデル・カード (https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) を参照してください。
- E5-base: BEIR ベンチマークで非常に優れたパフォーマンスを発揮する、最近のオープン・ソースのトランスフォーマー・モデル。 これは、 slate.125m.english.rtvrと同じアーキテクチャーを持つ基本サイズのモデルです。 [参照: Wang et.al., 2022: Text Embedding by Weakly-Supervised Contrastive Pre-training]。 ハギングフェイス・モデル・カード (https://huggingface.co/intfloat/e5-base)。
- E5-small: オープン・ソースの E5 ファミリー内の小規模なモデル。 このモデルの埋め込みディメンションは、 all-minilm-l6-v2 (384) の埋め込みディメンションと一致しますが、12 個のレイヤーがあるため、より大きく、少し遅くなります。 [参照: Wang et.al., 2022: Text Embedding by Weakly-Supervised Contrastive Pre-training]。 ハギングフェイス・モデル・カード (https://huggingface.co/intfloat/e5-small)。
- BGE-base: 768 埋め込みサイズ ( 01.20.2024現在) の BEIR ベンチマークで最高のパフォーマンスを発揮する、最近のオープン・ソース変換プログラム・モデル。 ハギングフェイス・モデル・カード (https://huggingface.co/BAAI/bge-base-en-v1.5)。
評価ベンチマーク: BEIR (MTEB の取得タブ)
BIR ベンチマークには、ファクト・チェック、サイテーション予測、重複質問の取得、引数の取得、ニュースの取得、質問の回答、ツイートの取得、bio-メディカル IR、およびエンティティーの取得という 9 つの異なる取得タスクを含む、さまざまなドメインに焦点を当てた 15 のオープン・ソース取得タスクが含まれています。 さらに、多様なテキスト・ドメインからのデータ・セット、広範なトピック (Wikipedia など) と特殊なトピック ( COVID-19 資料など) をカバーするデータ・セット、さまざまなテキスト・タイプ (ニュース記事とツイート)、さまざまなサイズのデータ・セット (3.6k - 15M 文書間の平均長さ)、およびデータ・セットも含まれます。照会の長さは 11 語間 (平均長さ)。 以下の表に、すべてのモデルのパフォーマンスを示します。 BEIR は、正規化された累積割引ゲイン (特に nDCG@10) メトリックを評価に使用します。 これは、 HuggingFace MTEB のリーダーシップ・ボードで使用される評価と同じですが、主に取得タスクに重点が置かれています。
長い NQ
Long NQ は、 NaturalQuestions データ・セットのサブセットに基づいて RAG パイプライン全体を評価するために設計された IBM データ・セットです。 この開発セットには、2,345 件のウィキペディアの文書から 178,891 節のコーパスを持つ 300 の回答可能な質問が含まれています。 Long NQ では、各質問に関連するゴールド・ウィキペディア・パッセージも提供されています。 取得時には、すべての質問に対してコーパスから関連するゴールド・パッセージを取得する必要があります。
結果
モデル | BEIR-15 (NDCG@10) |
---|---|
BM25 | 42.02 |
エルサー | 49.01 |
all-miniLM-L6-v2 | 41.95 |
ES-小文字 | 46.01 |
ES ベース | 48.75 |
BGE ベース | 53.25 |
slate.125m.english.rtrvr | 49.37 |
図2: BIR ベンチマークでのパフォーマンス比較 (「MTEB 取得 (MTEB retrieval)」タブ)
モデル | LONGNQ (NDCG@10) |
---|---|
all-miniLM-L6-v2 | 58.10 |
ES-小文字 | 66.87 |
ES ベース | 63.95 |
BGE ベース | 61.29 |
slate.125m.english.rtrvr | 65.01 |
図3: 長い NQ データ・セットでのパフォーマンスの比較
ランタイム・パフォーマンス
パフォーマンス・ランタイムは、466 個の照会を持つ再ランキング・タスクで測定されます。 照会ごとに、 BM25 によって取得された top-100 個のパッセージを再ランク付けし、すべての照会の平均時間を報告します。 再ランキングは A100_40GB GPU で実行されました。
モデル | 時刻/照会 |
---|---|
all-miniLM-L6-v2 | 0.18 秒 |
E5-small | 0.33 秒 |
E5-base | 0.75 秒 |
BGE ベース | 0.75 秒 |
slate.125m.english.rtrvr | 0.71 秒 |