Modellbeschreibung
Das Modell slate.125m.english.rtrvr ist ein Standard-Satzumwandlungsmodell, das auf Bi-Codierern basiert. Das Modell erzeugt eine Einbettung für eine bestimmte Eingabe, z. B. Abfrage, Passage, Dokument usw. Auf hoher Ebene wird unser Modell trainiert, um die Kosinus-Ähnlichkeit zwischen zwei Eingabetexten zu maximieren, z. B. Text A (Abfragetext) und Text B (Passagen), die zu den Satzeinbettungen q und p führen. Diese Satzeinbettungen können dann mithilfe der Kosinusähnlichkeit verglichen werden.
Abb. 1. Bi-Encoder-Einbettungsmodell für Abruf
Basissprachmodell
Das zugrunde liegende Sprachmodell (LM) für unsere Einbettung ist slate.125m.english (früher bekannt als WatBERT). Es hat dieselbe Architektur wie ein RoBERTa -Basistransformatormodell und hat ~ 125 Millionen Parameter und eine Einbettungsdimension von 768. Unser finales Modell heißt “slate.125m.english.rtrvr” -Beachten Sie das Suffix am Ende, das angibt, dass wir die zugrunde liegende Modellarchitektur für abrufbasierte Tasks optimieren.
Trainingsalgorithmus
Die meisten Einbettungsmodelle, die entweder dem neuesten Stand der Technik entsprechen oder an der Spitze der MTEB-Rangliste stehen, werden in der Regel in drei Stufen trainiert:
- Taskspezifisches (abruffasiertes) Vortraining
- Taskspezifische Optimierung für Mining-Paare
- Feinabstimmung für überwachte Paare.
Wir folgen demselben Ansatz und führen schließlich eine Modellfusion durch, indem wir die Gewichtungen verschiedener trainierter Modelle gemittelt haben.
slate.125m.english.rtrvr wird durch Ausführen von "Modellfusion" erzeugt. Dabei werden die Gewichtungen der folgenden Modelle gemittelt, die beide in Phasen trainiert wurden, aber die folgenden Variationen aufweisen:
- Modell 1 mit großen, nicht überwachten Daten optimiert
- Modell 2 mit einer kleineren Untergruppe überwachter Daten optimiert
Aufgabenspezifisches Vortraining
In dieser Phase wird das RetroMAE -Framework verwendet, um den zugrunde liegenden LM stärker abruforientiert zu machen. Wir initialisieren unseren Basis-LM mit slate.125m.english und fahren mit RetroMAE vor dem Training fort, wobei die Daten in Tabelle 1 verwendet werden. Unsere Hyperparameter sind: Lernrate: 2e-5, Anzahl Schritte: 190000, GPUs: 24 A100 40GB. Hinweis: Dies ist unser Basis-LM für die folgenden 2 Phasen.
Model1: Feinabstimmung mit großen nicht überwachten Daten
Dieses Modell wird mit dem vorab trainierten Modell RetroMAE initialisiert und in zwei Phasen trainiert.
Stufe 1: Nicht überwachte Optimierung
Wir verwenden ein Bi-Encoder-Framework zum Trainieren eines Einbettungsmodells, wie in Abbildung 1. Der vortrainierte LM RetroMAE wird mit <query, passage>
-Textpaaren unter Verwendung eines kontrastierenden Verlustziels optimiert. Wir filtern großflächige Paare aus verschiedenen Domänen, wie in Tabelle 2 angegeben. Das Modell wird mit verschiedenen Paaren trainiert, einschließlich Klassifikationsaufgaben wie NLI (Natural Language Inference), die aus dem Abgleich einer Prämisse mit der entsprechenden Hypothese besteht. Unsere Hyperparameter sind: Lernrate: 2e-5; Anzahl Schritte: 140000; GPUs: 8 A100_80GB, effektive Stapelgröße: 4096 Paare
Phase 2: Überwachte Optimierung
Schließlich wird das Modell mit hochwertigen überwachten Trainingspaaren für die Retrievalaufgabe auf den folgenden Datensätzen feinabgestimmt: SQuAD, Natural Questions, Specter, Stack Exchange (Title, Body) Paare, S2ORC, SearchQA, HotpotQA und Fever. Traininghyper-Parameter sind Lernrate: 2e-5; Anzahl der Schritte: 10000; GPUs: 8 A100_80GB, effektive Stapelgröße: 4096 Paare.
Modell 2: Feinabstimmung mit einem stärker aufgabenfokussierten Subset
In dieser Phase durchläuft das vortrainierte Modell RetroMAE eine überwachte Finetuning mit einer kleineren Untergruppe von Table2 , wobei die Überwachung vom harten negativen Mining stammt. Die Zwischenmodellprüfpunkte werden iterativ verwendet, um datasetspezifische harte Negative zu filtern, die dann für das überwachte Finetuning verwendet werden. Dieser Prozess zielt darauf ab, das Modell robuster zu machen, indem es aus eigenen Fehlern lernen lässt und bei der Stabilisierung mit viel kleineren Daten hilft.
Zur Feinabstimmung des Modells verwenden wir eine Teilmenge der in Table2 genannten Datensätze (die wir durch Validierungsexperimente mit einem zurückgehaltenen Datensatz gefunden haben), nämlich die folgenden: AllNLI, Squad, Stackexchange, NQ, HotpotQA, Fever und 5M Teilmenge von Specter, S2orc, WikiAnswers.
Traininghyper-Parameter sind Lernrate: 2e-5; maximale Abfragelänge: 512; maximale Durchgriffslänge: 512; Epochen: 2; effektive Stapelgröße: 384 Tripel; GPU: 24 A100_80GB.
Unser finales Modell: slate.125m.english.rtrvr:
Wir führen eine Modellfusion durch, indem wir die (oben genannten) Modellgewichte, die mit großen unbeaufsichtigten Daten trainiert wurden, und das Modell, das mit der kleineren Teilmenge der überwachten Daten trainiert wurde, mitteln.
Wir verwenden ein Einheitenset (https://huggingface.co/datasets/colbertv2/lotte) und führen eine Rastersuche durch, um die optimale Gewichtungskombination für diese Modelle zu erhalten. Die Modellgewichtungen werden basierend auf dem optimalen Parameter berechnet: 0.7 für Model1 und 0.3 für Model2.
Trainingsdaten
Dataset | Passagen |
---|---|
Wikipedia, | 36396918 |
Bücherkorpus | 3401308 |
Stackaustausch | 15999837 |
Dataset | Paare |
---|---|
SPECTER-Zitattriplets | 684100 |
Stack Exchange-Doppelte Fragen (Titel) | 304525 |
AllNLI (SNLI und MultiNLI) | 277230 |
Stack Exchange-Doppelte Fragen (Hauptteile) | 250519 |
Stack Exchange Doppelte Fragen (Titel + Hauptteile) | 250460 |
Natürliche Fragen (NQ) | 100231 |
SQuAD2.0 | 87599 |
PAQ-Paare (Frage, Antwort) | 64371441 |
Stack Exchange-Paare (Titel, Antwort) | 4067139 |
Stack Exchange-Paare (Titel, Hauptteil) | 23978013 |
Stack Exchange-Paare (Titel + Hauptteil, Antwort) | 187195 |
S2ORC Zitatpaare (Titel) | 52603982 |
S2ORC (Titel, Kurzdarstellung) | 41769185 |
S2ORC_citations_abstracts | 52603982 |
WikiAnswers Doppelte Fragenpaare | 77427422 |
SearchQA | 582261 |
HotpotQA | 85000 |
Fieber | 109810 |
Arxiv | 2358545 |
Wikipedia, | 20745403 |
PubMed | 20000000 |
Verwendung
# Make sure the sentence transformers is installed.
pip install -U sentence-transformers
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('path_to_slate_model')
input_queries = [
'Who made the son My achy breaky heart ?',
'summit define']
input_passages = [
"Achy Breaky Heart is a country song written by Don Von Tress. Originally titled Don't Tell My Heart and performed by The Marcy Brothers in 1991",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."]
query_embeddings = model.encode(input_queries)
passage_embeddings = model.encode(input_passages)
print(util.cos_sim(query_embeddings, passage_embeddings))
Die maximale Sequenzlänge dieses Modells beträgt 512 Token.
Evaluierung
Grundlinien
Für einen fairen Vergleich vergleichen wir mit den folgenden Baselines:
- BM25 (ein traditionelles Modell, das auf tf-idf basiert).
- ELSER (ein von Elastic bereitgestellter kommerzieller Suchalgorithmus).
- all-MiniLM-l6-v2: : Ein gängiges Open-Source-Modell für Satztransformatoren. Dieses Modell hat dieselbe Architektur wie slate.125m.english.rtvrmit einer kleineren Einbettungsdimension und wurde ohne kommerzielle Lizenzen auf mehr Daten trainiert. Huggingface-Modellkarte (https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) für weitere Details
- E5-base: Ein neues Open-Source-Transformatormodell mit sehr guter Leistung für die BEIR-Benchmark. Dies ist ein Basismodell, das dieselbe Architektur wie slate.125m.english.rtvrhat. [ Referenz: Wang et.al. 2022: Texteinbettung durch schwach beaufsichtigte kontrastive Vorschulung]. Huggingface-Modellkarte (https://huggingface.co/intfloat/e5-base)
- E5-small: Ein kleineres Modell in der Open-Source- E5 -Familie. Die Einbettungsdimension dieses Modells entspricht der von all-minilm-l6-v2 (384), hat jedoch 12 Schichten und ist daher größer und etwas langsamer. [ Referenz: Wang et.al. 2022: Texteinbettung durch schwach beaufsichtigte kontrastive Vorschulung]. Huggingface-Modellkarte (https://huggingface.co/intfloat/e5-small)
- BGE-Basis: ein neues Open-Source-Transformatormodell mit der besten Leistung für die BEIR-Benchmark für die Einbettungsgröße 768 (ab 01.20.2024). Huggingface-Modellkarte (https://huggingface.co/BAAI/bge-base-en-v1.5).
Unsere Evaluation Benchmark: BEIR (MTEB's retrieval tab)
Die BEIR-Benchmark enthält 15 Open-Source-Abruftasks, die sich auf verschiedene Domänen konzentrieren, darunter neun verschiedene Abruftasks: Faktprüfung, Zitatvorhersage, Abfrage doppelter Fragen, Argumentabruf, Nachrichtenabruf, Beantwortung von Fragen, Tweetabruf, biomedizinische IR und Entitätsabruf. Darüber hinaus enthält es Datasets aus verschiedenen Textdomänen, Datasets, die breite Themen (wie Wikipedia) und spezialisierte Themen (wie COVID-19 -Veröffentlichungen) abdecken, verschiedene Texttypen (Nachrichtenartikel vs. Tweets), Datasets verschiedener Größen (3.6k - 15M -Dokumente) und Datasets mit unterschiedlichen Abfragelängen (durchschnittliche Abfragelänge zwischen 3 und 192 Wörtern) und Dokumentlängen (durchschnittliche Dokumentlänge zwischen 11 und 635 Wörtern). Die Leistung aller Modelle ist in der folgenden Tabelle angegeben. BEIR verwendet die Metrik "Normalisierter kumulativer Rabattgewinn" (insbesondere nDCG@10) zur Auswertung. Dies ist dieselbe Bewertung, die in der MTEB-Bestenliste HuggingFace verwendet wird, aber hauptsächlich auf Abruftasks konzentriert ist.
Lange NQ
Long NQ ist ein IBM Dataset, das zur Auswertung der vollständigen RAG-Pipeline auf der Basis einer Untergruppe des Datasets NaturalQuestions entwickelt wurde. Das dev-Set hat 300 beantwortete Fragen mit einem Korpus von 178.891 Passagen aus 2.345 Wikipedia-Dokumenten. Long NQ bietet auch goldene Wikipedia-Passagen, die für jede Frage relevant sind. Während des Abrufens besteht die Aufgabe darin, die relevante Goldpassage aus dem Korpus für jede Frage abzurufen.
Ergebnisse
Modell | BEIR-15 (NDCG@10) |
---|---|
BM25 | 42.02 |
ELSER | 49.01 |
all-miniLM-L6-v2 | 41.95 |
ES-klein | 46.01 |
ES-Basis | 48.75 |
BGE-Basis | 53.25 |
slate.125m.english.rtrvr | 49.37 |
Abbildung 2: Leistungsvergleich für BEIR-Benchmark (Registerkarte "MTEB-Abruf")
Modell | LONGNQ (NDCG@10) |
---|---|
all-miniLM-L6-v2 | 58.10 |
ES-klein | 66.87 |
ES-Basis | 63.95 |
BGE-Basis | 61.29 |
slate.125m.english.rtrvr | 65.01 |
Abb. 3 Leistungsvergleich für das Dataset 'Long NQ'
Laufzeitleistung
Die Leistungslaufzeit wird für eine neu rangierte Task mit 466 Abfragen gemessen. Für jede Abfrage werden die top-100 Passagen, die von BM25 abgerufen wurden, neu eingestuft und die durchschnittliche Zeit über alle Abfragen berichtet. Die erneute Einstufung wurde auf einer A100_40GB -GPU ausgeführt.
Modell | Zeit/Abfrage |
---|---|
all-miniLM-L6-v2 | 0.18 Sek |
E5-small | 0.33 Sek. |
E5-base | 0.75 Sek |
BGE-Basis | 0.75 Sek |
slate.125m.english.rtrvr | 0.71 Sek. |