Modellbeschreibung
Das Modell slate.30m.english.rtrvr ist ein Standard-Satzumwandlungsmodell, das auf Bi-Codierern basiert. Das Modell erzeugt eine Einbettung für eine bestimmte Eingabe, z. B. Abfrage, Passage, Dokument usw. Auf hoher Ebene wird unser Modell trainiert, um die Kosinus-Ähnlichkeit zwischen zwei Eingabetexten zu maximieren, z. B. Text A (Abfragetext) und Text B (Passagen), die zu den Satzeinbettungen q und p führen. Diese Satzeinbettungen können dann mithilfe der Kosinusähnlichkeit verglichen werden.
Abb. 1. Bi-Encoder-Einbettungsmodell für Abruf
Basissprachmodell
Das zugrunde liegende Sprachmodell für unsere Einbettungen ist slate.30m.english. Es hat die gleiche Architektur wie ein small-RoBERTa-Basistransformatormodell (6 Schichten) und hat ~30 Millionen Parameter und eine Einbettungsdimension von 384. “slate.30m.english” wurde aus “slate.125m.english” (früher WatBERT) destilliert. Unser finales Modell heißt “slate.30m.english.rtrvr” -Beachten Sie das Suffix am Ende, das angibt, dass die zugrunde liegende Modellarchitektur für abrufbasierte Tasks optimiert wird.
Trainingsalgorithmus
Die meisten Einbettungsmodelle, die entweder dem neuesten Stand der Technik entsprechen oder an der Spitze der MTEB-Rangliste stehen, werden in der Regel in drei Stufen trainiert:
- Taskspezifisches (abruffasiertes) Vortraining
- Taskspezifische Optimierung für Mining-Paare
- Feinabstimmung für überwachte Paare
Wir folgen einem ähnlichen Ansatz und kombinieren die letzten beiden Stufen zu einem einzigen Feinsteuerungsschritt.
slate.30m.english.rtrvr wird durch Destillation aus dem Modell “slate.125m.english.rtrvr” im Feinsteuerungsschritt erzeugt. Knowledge Destillation überträgt das Wissen aus einem leistungsstarken Lehrermodell in ein kleineres Schülermodell, indem die Ausgabewahrscheinlichkeitsverteilung des Schülers so weit wie möglich an die des Lehrers angepasst wird, wodurch die Leistung des Schülers im Vergleich zur eigenständigen Finetuning verbessert wird.
Aufgabenspezifisches Vortraining
In dieser Phase wird das RetroMAE -Framework verwendet, um den zugrunde liegenden LM stärker abruforientiert zu machen. Wir initialisieren unseren Basis-LM mit slate.30m.english und fahren mit RetroMAE vor dem Training fort, wobei die Daten in Tabelle 1 verwendet werden. Unsere Hyperparameter sind: Lernrate: 2e-5, Anzahl der Schritte: 435000, GPUs: 8 A100 (80GB) GPUs. Hinweis: Dies ist unser Basis-LM für die folgenden 2 Phasen.
Destillation unter Verwendung von nicht überwachten und überwachten Paaren
Wir verwenden ein Bi-Encoder-Framework zum Trainieren eines Einbettungsmodells, wie in Abbildung 1. Die Initialisierung erfolgt mit dem RetroMAE -Modell, das vorab trainiert wurde, und die weitere Verwendung von Knowledge Destillation mit <query, passage>
-Textpaaren unter Verwendung eines kontrastierenden Verlustziels mit Negativen in der Stapelverarbeitung. Knowledge Destillation trainiert die Verteilung der Ausgabewahrscheinlichkeit des Schülers so weit wie möglich an die des Lehrers. Im Kontext von Abruffunktionsmodellen ist die Ausgabeverteilung die Ähnlichkeitsbewertung zwischen Textpaaren. Insbesondere wird für jedes Paar von Sätzen <query, passage>
die Verteilung der Scores des Lehrers zwischen den Einbettungen von Abfrage und Passage, d. h. die Kosinusähnlichkeit zwischen den Einbettungen, in den Schüler destilliert.
Der für die Destillation verwendete Lehrer ist das Modell "slate.125m.english.rtrvr” , das mit denselben unten genannten Daten trainiert wurde. Der Lehrer wird mithilfe der Modellfusion zwischen zwei Modellen gebildet, von denen jedes RetroMAE -Vortraining und Finetuning verwendet hat, sich jedoch in den Finetuning-Daten unterscheidet. Weitere Details finden Sie auf der Modellkarte für slate.125m.english.rtrvr. Der Fluss des Wissenstransfers ist in Abbildung 2 bildhaft dargestellt.
Abbildung 2: Wissensdestillation
Wir filtern große Paare aus verschiedenen Bereichen, wie im Abschnitt "Trainingsdaten" angegeben. Darüber hinaus haben wir auch hochwertige Paare für die Retrieval-Aufgabe auf den folgenden Datensätzen: SQuAD, Natural Questions, Specter, Stack Exchange (Title, Body) Paare, S2ORC, SearchQA, HotpotQA und Fever. Destillationshyperparameter sind: Learning rate:7e-4, Anzahl Schritte: 400000, effektive Stapelgröße: 2048, GPUs: 2 A100_80GB GPUs.
Trainingsdaten
Dataset | Passagen |
---|---|
Wikipedia, | 36396918 |
Bücherkorpus | 3401308 |
Stackaustausch | 15999837 |
Dataset | Paare |
---|---|
SPECTER-Zitattriplets | 684100 |
Stack Exchange-Doppelte Fragen (Titel) | 304525 |
AllNLI (SNLI und MultiNLI) | 277230 |
Stack Exchange-Doppelte Fragen (Hauptteile) | 250519 |
Stack Exchange Doppelte Fragen (Titel + Hauptteile) | 250460 |
Natürliche Fragen (NQ) | 100231 |
SQuAD2.0 | 87599 |
PAQ-Paare (Frage, Antwort) | 64371441 |
Stack Exchange-Paare (Titel, Antwort) | 4067139 |
Stack Exchange-Paare (Titel, Hauptteil) | 23978013 |
Stack Exchange-Paare (Titel + Hauptteil, Antwort) | 187195 |
S2ORC Zitatpaare (Titel) | 52603982 |
S2ORC (Titel, Kurzdarstellung) | 41769185 |
S2ORC_citations_abstracts | 52603982 |
WikiAnswers Doppelte Fragenpaare | 77427422 |
SearchQA | 582261 |
HotpotQA | 85000 |
Fieber | 109810 |
Arxiv | 2358545 |
Wikipedia, | 20745403 |
PubMed | 20000000 |
Verwendung
# Make sure the sentence transformers is installed.
pip install -U sentence-transformers
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('path_to_slate_model')
input_queries = [
'Who made the son My achy breaky heart ?',
'summit define']
input_passages = [
"Achy Breaky Heart is a country song written by Don Von Tress. Originally titled Don't Tell My Heart and performed by The Marcy Brothers in 1991",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."]
query_embeddings = model.encode(input_queries)
passage_embeddings = model.encode(input_passages)
print(util.cos_sim(query_embeddings, passage_embeddings))
Die maximale Sequenzlänge dieses Modells beträgt 512 Token.
Evaluierung
Grundlinien
Für einen fairen Vergleich vergleichen wir mit den folgenden Baselines:
- BM25 (ein traditionelles Modell, das auf tf-idf basiert).
- ELSER (ein von Elastic bereitgestellter kommerzieller Suchalgorithmus).
- all-MiniLM-l6-v2: : Ein gängiges Open-Source-Modell für Satztransformatoren. Dieses Modell verwendet dieselbe Architektur wie slate.125m.english.rtvrmit einer kleineren Einbettungsdimension (384) und wurde ohne kommerzielle Lizenzen auf mehr Daten trainiert. Weitere Einzelheiten finden Sie auf der Modellkarte Hugging Facehttps://huggingface.co/sentence-transformers/all-MiniLM-L6-v2.
- E5-base: Ein neues Open-Source-Transformatormodell mit sehr guter Leistung für die BEIR-Benchmark. Dies ist ein Basismodell, das dieselbe Architektur wie slate.125m.english.rtvrhat. [ Referenz: Wang et.al. 2022: Texteinbettung durch schwach beaufsichtigte kontrastive Vorschulung]. Huggingface-Modellkarte (https://huggingface.co/intfloat/e5-base)
- E5-small: Ein kleineres Modell in der Open-Source- E5 -Familie. Die Einbettungsdimension dieses Modells entspricht der von all-minilm-l6-v2 (384), hat jedoch 12 Schichten und ist daher größer und etwas langsamer. [ Referenz: Wang et.al. 2022: Texteinbettung durch schwach beaufsichtigte kontrastive Vorschulung]. Huggingface-Modellkarte (https://huggingface.co/intfloat/e5-small)
- BGE-Basis: ein neues Open-Source-Transformatormodell mit der besten Leistung für die BEIR-Benchmark für die Einbettungsgröße 768 (ab 01.20.2024). Huggingface-Modellkarte (https://huggingface.co/BAAI/bge-base-en-v1.5).
Unsere Evaluation Benchmark: BEIR (MTEB's retrieval tab)
Die BEIR-Benchmark enthält 15 Open-Source-Abruftasks, die unter einer Nullpunkteinstellung ausgewertet werden. BEIR konzentrierte sich auf Diversität, darunter neun verschiedene Abruftasks: Faktenprüfung, Zitatvorhersage, Abfrage doppelter Fragen, Argumentabruf, Nachrichtenabruf, Beantwortung von Fragen, Tweetabruf, biomedizinische IR und Entitätsabruf. Darüber hinaus enthält es Datasets aus verschiedenen Textdomänen, Datasets, die breite Themen (wie Wikipedia) und spezialisierte Themen (wie COVID-19 -Veröffentlichungen) abdecken, verschiedene Texttypen (Nachrichtenartikel vs. Tweets), Datasets verschiedener Größen (3.6k - 15M -Dokumente) und Datasets mit unterschiedlichen Abfragelängen (durchschnittliche Abfragelänge zwischen 3 und 192 Wörtern) und Dokumentlängen (durchschnittliche Dokumentlänge zwischen 11 und 635 Wörtern). BEIR verwendet die Metrik "Normalisierter kumulativer Rabattgewinn" (insbesondere nDCG@10) zur Auswertung.
Lange NQ
Long NQ ist ein IBM Dataset, das zur Auswertung der vollständigen RAG-Pipeline auf der Basis einer Untergruppe des Datasets NaturalQuestions entwickelt wurde. Das dev-Set hat 300 beantwortete Fragen mit einem Korpus von 178.891 Passagen aus 2.345 Wikipedia-Dokumenten. Long NQ bietet auch goldene Wikipedia-Passagen, die für jede Frage relevant sind. Während des Abrufens besteht die Aufgabe darin, die relevante Goldpassage aus dem Korpus für jede Frage abzurufen.
Ergebnisse
Modell | BEIR-15 (NDCG@10) |
---|---|
BM25 | 42.02 |
ELSER | 49.01 |
all-miniLM-L6-v2 | 41.95 |
ES-klein | 46.01 |
ES-Basis | 48.75 |
BGE-Basis | 53.25 |
slate.30m.english.rtrvr | 49.37 |
slate.30m.english.rtrvr | 46.91 |
Abb. 3 Leistungsvergleich für BEIR-Benchmark (Registerkarte "MTEB-Abruf")
Modell | LONGNQ (NDCG@10) |
---|---|
all-miniLM-L6-v2 | 58.10 |
ES-klein | 66.87 |
ES-Basis | 63.95 |
BGE-Basis | 61.29 |
slate.30m.english.rtrvr | 65.01 |
slate.30m.english.rtrvr | 59.94 |
Abbildung 4. Leistungsvergleich für das Dataset 'Long NQ'
Laufzeitleistung
Die Leistungslaufzeit wird für eine neu rangierte Task mit 466 Abfragen gemessen. Für jede Abfrage werden die top-100 Passagen, die von BM25 abgerufen wurden, neu eingestuft und die durchschnittliche Zeit über alle Abfragen berichtet. Die erneute Einstufung wurde auf einer A100_40GB -GPU ausgeführt.
Modell | Zeit/Abfrage |
---|---|
all-miniLM-L6-v2 | 0.18 Sek |
E5-small | 0.33 Sek. |
E5-base | 0.75 Sek |
BGE-Basis | 0.75 Sek |
slate.125m.english.rtrvr | 0.71 Sek. |
slate.30m.english.rtrvr | 0.20 Sek |