Ajuste fino de Llama 2 con DPO

'Fine-tuning Llama 2 with DPO'.

Introducción

El Aprendizaje por Reforzamiento a partir de Retroalimentación Humana (RLHF) se ha convertido en el último paso de entrenamiento por defecto de LLMs como GPT-4 o Claude para asegurarse de que las salidas del modelo de lenguaje estén alineadas con las expectativas humanas, como la capacidad de mantener una conversación o características de seguridad. Sin embargo, esto introduce parte de la complejidad del RL en el PLN: necesitamos construir una buena función de recompensa, entrenar el modelo para estimar el valor de un estado y, al mismo tiempo, tener cuidado de no alejarnos demasiado del modelo original y producir texto incoherente en lugar de texto sensato. Este proceso es bastante complejo y requiere varios componentes móviles donde no siempre es fácil hacer las cosas correctamente.

El reciente artículo “Optimización Directa de Preferencias” de Rafailov, Sharma, Mitchell et al. propone convertir el objetivo basado en RL utilizado por los métodos existentes en un objetivo que se puede optimizar directamente mediante una simple pérdida de entropía cruzada binaria, lo que simplifica en gran medida este proceso de mejora de los LLMs.

Esta entrada de blog presenta el método de Optimización Directa de Preferencias (DPO) que está ahora disponible en la biblioteca TRL y muestra cómo se puede ajustar finamente el modelo Llama v2 de 7 mil millones de parámetros en el conjunto de datos de preferencias de stack-exchange, que contiene respuestas clasificadas a preguntas en los diversos portales de stack-exchange.

DPO vs PPO

En el modelo tradicional de optimización de preferencias derivadas de los humanos a través del RL, el método utilizado ha sido utilizar un modelo de recompensa auxiliar y ajustar el modelo de interés para que maximice esta recompensa dada a través de la maquinaria del RL. Intuitivamente, usamos el modelo de recompensa para proporcionar retroalimentación al modelo que estamos optimizando para que genere muestras de alta recompensa con más frecuencia y muestras de baja recompensa con menos frecuencia. Al mismo tiempo, usamos un modelo de referencia congelado para asegurarnos de que lo que se genera no se desvíe demasiado y siga manteniendo la diversidad en la generación. Esto se hace típicamente añadiendo una penalización KL al objetivo de maximización de recompensa completa a través de un modelo de referencia, lo que evita que el modelo aprenda a hacer trampa o explotar el modelo de recompensa.

La formulación DPO omite el paso de modelado de recompensa y optimiza directamente el modelo de lenguaje en datos de preferencia mediante una idea clave: una correspondencia analítica desde la función de recompensa hasta la política RL óptima que permite a los autores transformar la pérdida RL sobre los modelos de recompensa y referencia en una pérdida solo sobre el modelo de referencia. Esta correspondencia mide intuitivamente qué tan bien una función de recompensa dada se alinea con los datos de preferencia dados. ¡DPO comienza así con la solución óptima para la pérdida RLHF y, mediante un cambio de variables, deriva una pérdida solo sobre el modelo de referencia!

Por lo tanto, este objetivo de verosimilitud directa se puede optimizar sin necesidad de un modelo de recompensa o de realizar la optimización basada en RL potencialmente complicada.

Cómo entrenar con TRL

Como se mencionó, típicamente el proceso RLHF consiste en estas partes distintas:

  1. un paso de ajuste de supervisión (SFT)
  2. el proceso de anotar datos con etiquetas de preferencia
  3. entrenar un modelo de recompensa en los datos de preferencia
  4. y el paso de optimización RL

La biblioteca TRL incluye ayudantes para todas estas partes, sin embargo, el entrenamiento DPO elimina la tarea de modelado de recompensa y RL (pasos 3 y 4) y optimiza directamente el objeto DPO en datos anotados con preferencias.

En este sentido, todavía necesitaríamos realizar el paso 1, pero en lugar de los pasos 3 y 4, necesitamos proporcionar el DPOTrainer en TRL con datos de preferencia del paso 2 que tienen un formato muy específico, es decir, un diccionario con las siguientes tres claves:

  • prompt esto consiste en la indicación de contexto que se le da a un modelo al momento de la inferencia para la generación de texto
  • chosen contiene la respuesta generada preferida para la indicación correspondiente
  • rejected contiene la respuesta que no es preferida o no debe ser la respuesta muestreada con respecto a la indicación dada

Como ejemplo, para el conjunto de datos de pares de preferencias de stack-exchange, podemos asignar las entradas del conjunto de datos para que devuelvan el diccionario deseado a través del siguiente ayudante y eliminar todas las columnas originales:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Pregunta: " + question + "\n\nRespuesta: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],   # mejor clasificada que k
        "rejected": samples["response_k"], # peor clasificada que j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)

Una vez que tenemos el conjunto de datos ordenado, la pérdida de DPO es esencialmente una pérdida supervisada que obtiene una recompensa implícita a través de un modelo de referencia y, por lo tanto, a un nivel alto, el DPOTrainer requiere el modelo base que deseamos optimizar, así como un modelo de referencia:

dpo_trainer = DPOTrainer(
    modelo,                 # modelo base del pipeline SFT
    modelo_ref,             # típicamente una copia del modelo base entrenado SFT
    beta=0.1,              # hiperparámetro de temperatura de DPO
    train_dataset=conjunto_de_datos, # conjunto de datos preparado anteriormente
    tokenizer=tokenizer,   # tokenizador
    args=args_entrenamiento,    # argumentos de entrenamiento como tamaño de lote, lr, etc.
)

donde el hiperparámetro beta es el parámetro de temperatura para la pérdida de DPO, típicamente en el rango de 0.1 a 0.5. Esto controla cuánta atención prestamos al modelo de referencia en el sentido de que a medida que beta se vuelve más pequeño, más ignoramos el modelo de referencia. Una vez que tenemos nuestro entrenador inicializado, podemos entrenarlo en el conjunto de datos con los training_args dados simplemente llamando a:

dpo_trainer.train()

Experimentar con Llama v2

El beneficio de implementar el entrenador de DPO en TRL es que uno puede aprovechar todas las características adicionales de entrenar grandes LLMs que vienen con TRL y sus bibliotecas dependientes como Peft y Accelerate. Con estas bibliotecas, incluso podemos entrenar un modelo Llama v2 utilizando la técnica QLoRA proporcionada por la biblioteca bitsandbytes.

Ajuste fino supervisado

El proceso presentado anteriormente implica el paso de ajuste fino supervisado utilizando QLoRA en el modelo Llama v2 de 7B en la división SFT de los datos a través de SFTTrainer de TRL:

# cargar el modelo base en cuantización de 4 bits
config_bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

modelo_base = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,        # "meta-llama/Llama-2-7b-hf"
    quantization_config=config_bnb,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
modelo_base.config.use_cache = False

# agregar capas LoRA encima del modelo base cuantizado
config_peft = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
entrenador = SFTTrainer(
    model=modelo_base,
    train_dataset=conjunto_de_datos_entrenamiento,
    eval_dataset=conjunto_de_datos_evaluacion,
    peft_config=config_peft,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=args_entrenamiento,         # argumentos de entrenamiento de HF Trainer
)
entrenador.train()

Entrenamiento de DPO

Una vez que el SFT ha terminado, podemos guardar el modelo resultante y pasar al entrenamiento de DPO. Como se hace típicamente, utilizaremos el modelo guardado del paso SFT anterior tanto como el modelo base como el modelo de referencia de DPO. Luego podemos usar estos para entrenar el modelo con el objetivo de DPO en los datos de preferencia de stack-exchange mostrados anteriormente. Dado que los modelos fueron entrenados mediante adaptadores LoRa, cargamos los modelos a través de las funciones auxiliares AutoPeftModelForCausalLM de Peft:

modelo = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # ubicación del modelo SFT guardado
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
modelo_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,  # mismo modelo que el principal
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    modelo,
    modelo_ref,
    args=args_entrenamiento,
    beta=script_args.beta,
    train_dataset=conjunto_de_datos_entrenamiento,
    eval_dataset=conjunto_de_datos_evaluacion,
    tokenizer=tokenizer,
    peft_config=config_peft,
)
dpo_trainer.train()
dpo_trainer.save_model()

Como se puede observar, cargamos el modelo en la configuración de 4 bits y luego lo entrenamos utilizando el método QLora a través de los argumentos peft_config. El entrenador también evaluará el progreso durante el entrenamiento con respecto al conjunto de datos de evaluación y reportará una serie de métricas clave como la recompensa implícita, que se puede registrar y mostrar a través de WandB, por ejemplo. Luego podemos enviar el modelo entrenado final al HuggingFace Hub.

Conclusión

El código fuente completo de los scripts de entrenamiento para SFT y DPO están disponibles en el siguiente directorio examples/stack_llama_2 y el modelo entrenado con los adaptadores fusionados se puede encontrar en HF Hub aquí.

Los registros de WandB para la ejecución de entrenamiento de DPO se pueden encontrar aquí, donde durante el entrenamiento y la evaluación, el DPOTrainer registra las siguientes métricas de recompensa:

  • rewards/chosen: la diferencia promedio entre las probabilidades logarítmicas del modelo de política y el modelo de referencia para las respuestas seleccionadas, escaladas por beta
  • rewards/rejected: la diferencia promedio entre las probabilidades logarítmicas del modelo de política y el modelo de referencia para las respuestas rechazadas, escaladas por beta
  • rewards/accuracies: media de la frecuencia con la que las recompensas seleccionadas son mayores que las recompensas rechazadas correspondientes
  • rewards/margins: la diferencia promedio entre las recompensas seleccionadas y las recompensas rechazadas correspondientes

De manera intuitiva, durante el entrenamiento queremos que los márgenes aumenten y las precisiones se acerquen a 1.0, o en otras palabras, que la recompensa seleccionada sea mayor que la recompensa rechazada (o el margen sea mayor que cero). Estas métricas luego se pueden calcular sobre algún conjunto de datos de evaluación.

Esperamos que con la publicación del código se reduzca la barrera de entrada para que los lectores prueben este método de alineación de modelos de lenguaje grandes en sus propios conjuntos de datos y ¡no podemos esperar a ver qué construyen! Y si quieres probar el modelo tú mismo, puedes hacerlo aquí: trl-lib/stack-llama.

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Inteligencia Artificial

Conoce a MPT-7B un nuevo modelo de lenguaje de código abierto entrenado en 1T tokens de texto y código seleccionados por MosaicML.

MosaicML ha lanzado recientemente una herramienta revolucionaria, MPT-7B, para transformar la forma en que las empres...

Inteligencia Artificial

OpenAI se hace cargo de la Iluminación Global; Celebra su primera adquisición empresarial

En un movimiento que provoca repercusiones en el mundo tecnológico, OpenAI, la fuerza pionera en inteligencia artific...

Inteligencia Artificial

Nuevo estudio revela vulnerabilidades ocultas en la inteligencia artificial

En el panorama de rápido desarrollo de la IA, la promesa de cambios transformadores abarca un sinfín de campos, desde...