0 / 0
Volver a la versión inglesa de la documentación

Tarjeta de modelo IBM slate-125m-english-rtrvr

Última actualización: 28 nov 2024
Tarjeta de modelo IBM slate-125m-english-rtrvr

Modelo, descripción

El modelo slate.125m.english.rtrvr es un modelo estándar de transformadores de frases basado en bicodificadores. El modelo produce una incrustación para una entrada dada, por ejemplo, consulta, pasaje, documento, etc. En un nivel alto, nuestro modelo está entrenado para maximizar la similitud del coseno entre dos piezas de texto de entrada, por ejemplo, el texto A (texto de consulta) y el texto B (texto de pasaje), lo que da como resultado la inserción de la frase q y p. Estas incrustaciones de frases se pueden comparar utilizando la similitud de coseno.

Diagrama que compara el texto de consulta codificado de Slate con el texto de pasaje codificado

Figura 1. Modelo de incorporaciones bicodificador para recuperación

Modelo de lenguaje base

El modelo de lenguaje subyacente (LM) para nuestras incorporaciones es slate.125m.english (anteriormente, conocido como WatBERT). Tiene la misma arquitectura que un modelo de transformador base RoBERTa y tiene ~ 125 millones de parámetros y una dimensión de inclusión de 768. Nuestro modelo final se denomina “slate.125m.english.rtrvr” -observe el sufijo al final que indica que ajustamos la arquitectura del modelo subyacente para las tareas basadas en la recuperación.

Algoritmo de entrenamiento

La mayoría de los modelos de incrustación más avanzados o que ocupan los primeros puestos de la clasificación MTEB suelen entrenarse en tres fases:

  1. Formación previa específica de tarea (basada en recuperación)
  2. Ajuste específico de tarea en pares minados
  3. Ajuste preciso en pares supervisados.

Seguimos el mismo enfoque y finalmente realizamos una fusión de modelos promediando los pesos de diferentes modelos entrenados.

slate.125m.english.rtrvr se produce realizando "fusión de modelos"-promediando las ponderaciones de los modelos siguientes, ambos entrenados en etapas pero con las variaciones siguientes:

  • Modelo 1 ajustado con datos no supervisados a gran escala
  • Modelo 2 ajustado con un subconjunto más pequeño de datos supervisados

Formación previa específica de la tarea

Esta etapa utiliza la infraestructura RetroMAE , para que el LM subyacente esté más orientado a la recuperación. Inicializamos nuestro LM base con slate.125m.english y continuamos con el entrenamiento previo RetroMAE , utilizando los datos de la Tabla 1. Nuestros hiperparámetros son: tasa de aprendizaje: 2e-5, número de pasos: 190000, GPU: 24 A100 40GB. Nota: esta es nuestra LM base para las siguientes 2 etapas.

Model1: Ajuste fino con datos no supervisados a gran escala

Este modelo se inicializa con el modelo preentrenado RetroMAE y se entrena en 2 etapas.

Etapa 1: Ajuste preciso no supervisado

Utilizamos un marco bi-codificador para entrenar un modelo de incrustación, como en la Figura 1. El LM entrenado previamente RetroMAE se ajusta con pares de texto <query, passage> utilizando un objetivo de pérdida contrastiva. Se extraen pares a gran escala de varios dominios, como se indica en la Tabla 2. El modelo se entrena con diversos pares, incluyendo tareas de clasificación como NLI (Natural Language Inference) que consiste en emparejar una premisa con la hipótesis correspondiente. Nuestros hiperparámetros son: tasa de aprendizaje: 2e-5; número de pasos: 140000; GPU: 8 A100_80GB, tamaño de lote efectivo: 4096 pares

Fase 2: Ajuste detallado supervisado

Por último, el modelo se pone a punto con pares de entrenamiento supervisado de alta calidad para la tarea de recuperación en los siguientes conjuntos de datos: SQuAD, Natural Questions, Specter, Stack Exchange (Title, Body) pairs, S2ORC, SearchQA, HotpotQA y Fever. Los hiperparámetros de entrenamiento son la tasa de aprendizaje: 2e-5; número de pasos: 10000; GPU: 8 A100_80GB, tamaño de lote efectivo: 4096 pares.

Modelo 2: Ajuste fino con un subconjunto más centrado en las tareas

En esta etapa, el modelo entrenado previamente RetroMAE se somete a una finetación supervisada con un subconjunto más pequeño de Table2 con supervisión procedente de la minería negativa dura. Los puntos de comprobación de modelo intermedios se utilizan de forma iterativa para minar los negativos fijos específicos del conjunto de datos, que se utilizan después para la obtención de información supervisada. Este proceso tiene como objetivo hacer que el modelo sea más robusto dejándolo aprender de sus propios errores y ayuda en la estabilización con datos mucho más pequeños.

Afinamos el modelo utilizando un subconjunto de conjuntos de datos (encontrados mediante la realización de experimentos de validación en un conjunto de datos retenido) mencionados en Table2 que son los siguientes: AllNLI, Squad, Stackexchange, NQ, HotpotQA, Fever y 5M subconjunto de cada uno de Specter, S2orc, WikiAnswers.

Los hiperparámetros de entrenamiento son la tasa de aprendizaje: 2e-5; longitud máxima de consulta: 512; longitud máxima de paso: 512; epochs: 2; tamaño de lote efectivo: 384 triples; GPU: 24 A100_80GB.

Nuestro modelo final: slate.125m.english.rtrvr: fusión de modelos

Realizamos la fusión de modelos promediando las ponderaciones del modelo (de lo anterior) entrenado con datos no supervisados a gran escala y el modelo entrenado con el subconjunto más pequeño de datos supervisados.

Utilizamos un conjunto de desarrollo (https://huggingface.co/datasets/colbertv2/lotte) y realizamos una búsqueda de cuadrícula para obtener la combinación de ponderación óptima para estos modelos. El promedio de las ponderaciones del modelo se basa en el parámetro óptimo: 0.7 para Model1 y 0.3 para Model2.

Datos de entrenamiento

Tabla 1. Datos previos al entrenamiento
Conjunto de datos Pasajes
Wikipedia 36396918
Corpus de libros 3401308
Intercambio de pila 15999837
Tabla 2. Datos de ajuste preciso
Conjunto de datos Pares
Tripletes de citación SPECTER 684100
Preguntas duplicadas de Stack Exchange (títulos) 304525
AllNLI (SNLI y MultiNLI) 277230
Preguntas duplicadas de Stack Exchange (cuerpos) 250519
Preguntas duplicadas de intercambio de pila (títulos + cuerpos) 250460
Preguntas naturales (NQ) 100231
SQuAD2.0 87599
Pares PAQ (Pregunta, Respuesta) 64371441
Pares de intercambio de pila (Título, Respuesta) 4067139
Pares de intercambio de pila (título, cuerpo) 23978013
Pares de intercambio de pila (Título + Cuerpo, Respuesta) 187195
S2ORC Pares de agitación (Títulos) 52603982
S2ORC (Título, Resumen) 41769185
S2ORC_citations_abstracts 52603982
WikiAnswers Duplicar pares de preguntas 77427422
SearchQA 582261
HotpotQA 85000
Fiebre 109810
Arxiv 2358545
Wikipedia 20745403
PubMed 20000000

Uso

# Make sure the sentence transformers is installed.
pip install -U sentence-transformers 

from sentence_transformers import SentenceTransformer, util 
model = SentenceTransformer('path_to_slate_model') 

input_queries = [ 
  'Who made the son My achy breaky heart ?', 
  'summit define'] 
input_passages = [ 
  "Achy Breaky Heart is a country song written by Don Von Tress. Originally titled Don't Tell My Heart and performed by The Marcy Brothers in 1991", 
  "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."] 

query_embeddings = model.encode(input_queries) 
passage_embeddings = model.encode(input_passages) 
print(util.cos_sim(query_embeddings, passage_embeddings))

La longitud máxima de secuencia de este modelo es 512 señales.

Evaluación

Líneas base

Para una comparación justa, comparamos con las líneas base siguientes:

  1. BM25 (un modelo tradicional basado en tf-idf).
  2. ELSER (un algoritmo de búsqueda comercial proporcionado por Elastic).
  3. all-MiniLM-l6-v2: un popular modelo de transformadores de frases de código abierto. Este modelo comparte la misma arquitectura que slate.125m.english.rtvr, con una dimensión de integración más pequeña y se ha entrenado en más datos sin licencias compatibles con el comercio. Tarjeta de modelo Huggingface (https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) para obtener más detalles.
  4. E5-base: un modelo de transformador de código abierto reciente con un rendimiento muy bueno en la prueba de referencia BEIR. Se trata de un modelo de tamaño base, que tiene la misma arquitectura que slate.125m.english.rtvr. [Referencia: Wang et.al., 2022: Text Embeddings de Débil-Supervisado Contrastive Pre-training]. Tarjeta modelo Huggingface (https://huggingface.co/intfloat/e5-base).
  5. E5-small: un modelo más pequeño dentro de la familia E5 de código abierto. La dimensión de inclusión de este modelo coincide con la de all-minilm-l6-v2 (384), sin embargo tiene 12 capas y, por lo tanto, es más grande y ligeramente más lenta. [Referencia: Wang et.al., 2022: Text Embeddings de Débil-Supervisado Contrastive Pre-training]. Tarjeta modelo Huggingface (https://huggingface.co/intfloat/e5-small).
  6. BGE-base: un modelo de transformador de código abierto reciente con el mejor rendimiento en la referencia BEIR para el tamaño de incorporación 768 (a partir de 01.20.2024). Tarjeta modelo Huggingface (https://huggingface.co/BAAI/bge-base-en-v1.5).

Nuestro índice de referencia de evaluación: BEIR (pestaña de recuperación de MTEB)

La referencia BEIR contiene 15 tareas de recuperación de código abierto enfocadas en diferentes dominios incluyendo nueve tareas de recuperación diferentes: Comprobación de hechos, predicción de citas, recuperación de preguntas duplicadas, recuperación de argumentos, recuperación de noticias, respuesta a preguntas, recuperación de tuits, IR bio-médica y recuperación de entidades. Además, incluye conjuntos de datos de diversos dominios de texto, conjuntos de datos que cubren temas amplios (como Wikipedia) y temas especializados (como publicaciones COVID-19 ), diferentes tipos de texto (artículos de noticias frente a Tweets), conjuntos de datos de varios tamaños (documentos3.6k - 15M ) y conjuntos de datos con diferentes longitudes de consulta (longitud media de consulta entre 3 y 192 palabras) y longitudes de documento (longitud media de documento entre 11 y 635 palabras). El rendimiento de todos los modelos se indica en la tabla siguiente. BEIR utiliza la métrica de ganancia de descuento acumulativo normalizado (específicamente, nDCG@10) para la evaluación. Esta es la misma evaluación que se utiliza en el marcador MTEB HuggingFace , pero principalmente se centra en las tareas de recuperación.

NQ largo

Long NQ es un conjunto de datos de IBM diseñado para evaluar la interconexión RAG completa, basándose en un subconjunto del conjunto de datos NaturalQuestions . El conjunto de desarrolladores tiene 300 preguntas respondibles con un corpus de 178.891 pasajes de 2.345 documentos de Wikipedia. Long NQ también proporciona pasajes de Wikipedia en oro que son relevantes para cada pregunta. Durante la recuperación, la tarea es obtener el pasaje de oro relevante del corpus para cada pregunta.

Resultados

Comparación de rendimiento en la prueba de referencia BEIR (pestaña Recuperación de MTEB)
Modelo BEIR-15 (NDCG@10)
BM25 42.02
ELSER 49.01
all-miniLM-L6-v2 41.95
ES-pequeño 46.01
ES-base 48.75
BGE-base 53.25
slate.125m.english.rtrvr 49.37

Gráfico que muestra los resultados de Slate y otros modelos

Figura 2. Comparación de rendimiento en la prueba de referencia BEIR (pestaña Recuperación de MTEB)

Comparación de rendimiento en el conjunto de datos Long NQ
Modelo LONGNQ (NDCG@10)
all-miniLM-L6-v2 58.10
ES-pequeño 66.87
ES-base 63.95
BGE-base 61.29
slate.125m.english.rtrvr 65.01

Gráfico que muestra los resultados de Slate y otros modelos

Figura 3. Comparación de rendimiento en el conjunto de datos Long NQ

Rendimiento de tiempo de ejecución

El tiempo de ejecución de rendimiento se mide en una tarea de reclasificación con 466 consultas. Para cada consulta se vuelven a clasificar los pasajes top-100 obtenidos por BM25 y se informa del tiempo promedio de todas las consultas. La reclasificación se ha realizado en una GPU A100_40GB .

Tabla 3. Rendimiento en tiempo de ejecución al volver a clasificar
Modelo Hora/consulta
all-miniLM-L6-v2 0.18 seg
E5-small 0.33 seg
E5-base 0.75 seg
BGE-base 0.75 seg
slate.125m.english.rtrvr 0.71 seg