Carte modèle IBM slate-30m-english-rtrvr
Description de modèle
Le modèle slate.30m.english.rtrvr est un modèle standard de transformateurs de phrases basé sur des bi-encodeurs. Le modèle produit une imbrication pour une entrée donnée, par exemple une requête, un passage, un document, etc. A un niveau élevé, notre modèle est entraîné pour maximiser la similarité en cosinus entre deux éléments de texte d'entrée, par exemple le texte A (texte de requête) et le texte B (texte de passage), ce qui entraîne l'incorporation des phrases q et p. Ces plongements de phrases peuvent ensuite être comparés à l'aide de la similarité cosinus.
Figure 1 : Modèle d'imbrication à deux codeurs pour la récupération
Modèle de langue de base
Le modèle de langage sous-jacent (LM) pour nos imbrication est slate.30m.english. Il a la même architecture qu'un modèle de transformateur de base small-RoBERTa (6 couches) et a ~30 millions de paramètres et une dimension d'intégration de 384. Plus précisément, “slate.30m.english” a été distillé à partir de “slate.125m.english” (anciennement WatBERT). Notre modèle final est appelé “slate.30m.english.rtrvr” -notez le suffixe à la fin indiquant que nous affinons l'architecture de modèle sous-jacente pour les tâches basées sur l'extraction.
Algorithme d'apprentissage
La plupart des modèles d'intégration qui sont soit à la pointe de la technologie, soit en tête du classement MTEB, sont généralement formés en trois étapes :
- Pré-entraînement spécifique à la tâche (basé sur l'extraction)
- Optimisation spécifique à la tâche sur les paires explorées
- Optimisation des paires surveillées
Nous adoptons une approche similaire, combinant les deux dernières étapes en une seule étape de réglage fin.
slate.30m.english.rtrvr est produit en distillant à partir du modèle “slate.125m.english.rtrvr” dans l'étape de réglage fin. La distillation des connaissances permet de transférer les connaissances d'un modèle d'enseignant très performant vers un modèle d'étudiant plus petit en formant la distribution de la probabilité de sortie de l'étudiant pour qu'elle corresponde le plus possible à celle de l'enseignant, ce qui améliore la performance de l'étudiant par rapport à la mise au point autonome.
Pré-formation spécifique à la tâche
Cette étape utilise l'infrastructure RetroMAE pour rendre notre LM sous-jacent plus orienté vers l'extraction. Nous initialisons notre module LM de base avec slate.30m.english et continuons avec le pré-entraînement RetroMAE , en utilisant les données du tableau 1. Nos hyperparamètres sont: taux d'apprentissage: 2e-5, nombre d'étapes: 435000, GPU: 8 A100 (80GB) GPU. Remarque: il s'agit de notre base LM pour les 2 étapes suivantes.
Distillation à l'aide de paires non surveillées et surveillées
Nous utilisons une structure à deux codeurs pour l'entraînement d'un modèle d'intégration, comme dans la figure 1. Nous nous initialisons avec le modèle pré-entraîné RetroMAE et employons davantage Knowledge Distillation avec des paires de texte <query, passage>
à l'aide d'un objectif de perte contrastif avec des négatifs en lot. Connaissance Distillation entraîne la distribution de la probabilité de sortie de l'élève pour qu'elle corresponde le plus étroitement possible à celle de l'enseignant. Dans le contexte des modèles d'extracteur, la distribution de sortie correspond aux scores de similarité entre les paires de texte. Plus précisément, pour chaque paire de phrases <query, passage>
, la distribution des scores de l'enseignant entre les incorporations de requête et de passage, c'est-à-dire la similarité cosinus entre les incorporations, est distillée dans l'élève.
L'enseignant utilisé pour la distillation est le modèle "slate.125m.english.rtrvr” entraîné sur les mêmes données mentionnées ci-dessous. Le professeur est formé à l'aide de la fusion de modèles entre deux modèles, chacun utilisant le préapprentissage et le finetuning RetroMAE , mais différant dans les données de finetuning. Pour plus de détails, reportez-vous à la carte modèle pour slate.125m.english.rtrvr. Le flux du transfert de connaissances est illustré de façon pictoriale à la figure 2.
Figure 2. Distillation des connaissances
Nous exploitons des paires à grande échelle provenant de divers domaines, comme indiqué dans la section Données d'apprentissage. En outre, nous incluons également des paires de haute qualité pour la tâche d'extraction sur les ensembles de données suivants : SQuAD, paires Natural Questions, Specter, Stack Exchange (Title, Body), S2ORC, SearchQA, HotpotQA et Fever. Les hyperparamètres de distillation sont: apprentissage rate:7e-4, nombre d'étapes: 400000, taille de lot effective: 2048, GPU: 2 GPU A100_80GB .
Données d'entraînement
Ensemble de données | Passages |
---|---|
Wikipédia | 36396918 |
Corpus de livres | 3401308 |
Echange de pile | 15999837 |
Ensemble de données | Paires |
---|---|
Triplets de citation SPECTER | 684100 |
Empiler les questions en double (titres) | 304525 |
AllNLI (SNLI et MultiNLI) | 277230 |
Empiler les questions en double (corps) | 250519 |
Empiler les questions en double (titres + corps) | 250460 |
Questions naturelles (NQ) | 100231 |
SQuAD2.0 | 87599 |
Paires de PAQ (question, réponse) | 64371441 |
Paires d'échange de pile (titre, réponse) | 4067139 |
Paires d'échange de pile (titre, corps) | 23978013 |
Paires d'échange de pile (titre + corps, réponse) | 187195 |
S2ORC Paires de citation (titres) | 52603982 |
S2ORC (Titre, Résumé) | 41769185 |
S2ORC_citations_abstracts | 52603982 |
WikiAnswers Dupliquer les paires de questions | 77427422 |
SearchQA | 582261 |
HotpotQA | 85000 |
Fièvre | 109810 |
Arxiv | 2358545 |
Wikipédia | 20745403 |
PubMed | 20000000 |
Utilisation
# Make sure the sentence transformers is installed.
pip install -U sentence-transformers
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('path_to_slate_model')
input_queries = [
'Who made the son My achy breaky heart ?',
'summit define']
input_passages = [
"Achy Breaky Heart is a country song written by Don Von Tress. Originally titled Don't Tell My Heart and performed by The Marcy Brothers in 1991",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."]
query_embeddings = model.encode(input_queries)
passage_embeddings = model.encode(input_passages)
print(util.cos_sim(query_embeddings, passage_embeddings))
La longueur de séquence maximale de ce modèle est de 512 jetons.
Évaluation
Lignes de base
Pour une comparaison équitable, nous comparons avec les versions de référence suivantes:
- BM25 (modèle traditionnel basé sur tf-idf).
- ELSER (un algorithme de recherche commerciale fourni par Elastic).
- all-MiniLM-l6-v2: un modèle populaire de transformateurs de phrases open source. Ce modèle partage la même architecture que slate.125m.english.rtvr, avec une plus petite dimension d'imbrication (384) et a été entraîné sur plus de données sans licences commerciales. Pour plus de détails, voir la carte modèle " Hugging Facehttps://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
- E5-base: un modèle récent de transformateur open source avec de très bonnes performances sur le test de performances BEIR. Il s'agit d'un modèle de taille de base, qui possède la même architecture que slate.125m.english.rtvr. [ Référence: Wang et.al., 2022: Texte Embeddings by ???-??? Contrastive Pre-training ]. Carte modèle Huggingface (https://huggingface.co/intfloat/e5-base).
- E5-small: modèle plus petit au sein de la famille open source E5 . La dimension d'imbrication de ce modèle correspond à celle de all-minilm-l6-v2 (384), mais elle comporte 12 couches et est donc plus grande et légèrement plus lente. [ Référence: Wang et.al., 2022: Texte Embeddings by ???-??? Contrastive Pre-training ]. Carte modèle Huggingface (https://huggingface.co/intfloat/e5-small).
- BGE-base: un modèle récent de transformateur open source avec les meilleures performances sur le test de performances BEIR pour la taille d'imbrication 768 (à partir du 01.20.2024). Carte modèle Huggingface (https://huggingface.co/BAAI/bge-base-en-v1.5).
Notre point de repère d'évaluation: BEIR (onglet d'extraction de MTEB)
Le test de référence BEIR contient 15 tâches d'extraction open source évaluées dans un cadre zéro. BEIR s'est concentré sur la diversité, y compris neuf tâches d'extraction différentes: la vérification des faits, la prédiction des citations, la récupération des questions en double, la récupération des arguments, la récupération des nouvelles, la réponse aux questions, la récupération des tweets, l'IR bio-médicale et la récupération des entités. En outre, il comprend des jeux de données provenant de divers domaines de texte, des jeux de données qui couvrent de vastes sujets (comme Wikipedia) et des sujets spécialisés (comme les publications COVID-19 ), différents types de texte (articles de presse et tweets), des jeux de données de différentes tailles (3.6k - 15M documents) et des jeux de données avec des longueurs de requête différentes (longueur moyenne de requête comprise entre 3 et 192 mots) et des longueurs de document (longueur moyenne de document comprise entre 11 et 635 mots). BEIR utilise l'indicateur Normalized Cumulative Discount Gain (en particulier, nDCG@10) pour l'évaluation.
NQ longue
Long NQ est un jeu de données IBM conçu pour évaluer le pipeline RAG complet, basé sur un sous-ensemble de l'ensemble de données NaturalQuestions . L'ensemble de développement a 300 questions à répondre avec un corpus de 178 891 passages de 2 345 documents Wikipédia. Long NQ fournit également des passages de Wikipédia en or qui sont pertinents pour chaque question. Lors de la récupération, la tâche consiste à obtenir le passage d'or pertinent du corpus pour chaque question.
Résultats
Modèle | BEIR-15 (NDCG@10) |
---|---|
BM25 | 42.02 |
ELSER | 49.01 |
all-miniLM-L6-v2 | 41.95 |
ES-petit | 46.01 |
ES-base | 48.75 |
Base BGE | 53.25 |
slate.30m.english.rtrvr | 49.37 |
slate.30m.english.rtrvr | 46.91 |
Figure 3 Comparaison des performances sur l'indice de référence BEIR (onglet Extraction MTEB)
Modèle | LONGNQ (NDCG@10) |
---|---|
all-miniLM-L6-v2 | 58.10 |
ES-petit | 66.87 |
ES-base | 63.95 |
Base BGE | 61.29 |
slate.30m.english.rtrvr | 65.01 |
slate.30m.english.rtrvr | 59.94 |
Figure 4 Comparaison des performances sur le jeu de données NQ long
Performances d'exécution
L'exécution des performances est mesurée sur une tâche de reclassement avec 466 requêtes. Pour chaque requête, nous reclassons les top-100 passages obtenus par BM25 et nous rapportons le temps moyen sur toutes les requêtes. Le nouveau classement a été effectué sur un processeur graphique A100_40GB .
Modèle | Heure / requête |
---|---|
all-miniLM-L6-v2 | 0.18 sec |
E5-small | 0.33 sec |
E5-base | 0.75 sec |
Base BGE | 0.75 sec |
slate.125m.english.rtrvr | 0.71 sec |
slate.30m.english.rtrvr | 0.20 sec |