Historia de optimización Inferencia de Bloom

Optimización de la Inferencia de Bloom

Este artículo te brinda información sobre cómo creamos un servidor de inferencia eficiente que alimenta a bloom, un servidor de inferencia que alimenta https://huggingface.co/bigscience/bloom.

Logramos reducir la latencia en un factor de 5 durante varias semanas (y aumentar el rendimiento en un factor de 50). Queríamos compartir todas las dificultades y logros épicos que atravesamos para lograr estas mejoras de velocidad.

Participaron muchas personas diferentes en diversas etapas, por lo que no se cubrirá todo aquí. Además, ten en cuenta que parte del contenido puede estar desactualizado o ser incorrecto, ya que todavía estamos aprendiendo cómo optimizar modelos extremadamente grandes y cómo aprovechar las nuevas características y contenidos de hardware que se lanzan regularmente.

Si tu optimización favorita no se discute o se representa incorrectamente, lo lamentamos. Por favor, compártela con nosotros, estaremos encantados de probar cosas nuevas y corregir nuestros errores.

Esto va sin decir, pero sin el acceso al modelo grande en primer lugar, no habría razones reales para optimizar la inferencia. Este fue un esfuerzo increíble liderado por muchas personas diferentes.

Para maximizar la GPU durante el entrenamiento, se exploraron varias soluciones y, al final, se eligió Megatron-Deepspeed para entrenar el modelo final. Esto significaba que el código tal como estaba no era necesariamente compatible con la biblioteca transformers.

Debido al código de entrenamiento original, nos propusimos hacer algo que hacemos regularmente: adaptar un modelo existente a transformers. El objetivo era extraer las partes relevantes del código de entrenamiento y implementarlas dentro de transformers. Younes se encargó de este esfuerzo, que no fue para nada pequeño, ya que tomó casi un mes y 200 confirmaciones para lograrlo.

Hay varias cosas a tener en cuenta que se mencionarán más adelante:

Necesitábamos tener modelos más pequeños, como bigscience/bigscience-small-testing y bigscience/bloom-560m. Esto es extremadamente importante porque al ser más pequeños, todo es más rápido al trabajar con ellos.

En primer lugar, debes abandonar toda esperanza de obtener exactamente los mismos logits al final, hasta los bytes. Las versiones de PyTorch pueden cambiar los núcleos e introducir diferencias sutiles, y diferentes hardware pueden producir resultados diferentes debido a diferentes arquitecturas (y probablemente no quieras desarrollar todo el tiempo en una GPU A100 por motivos de coste).

Tener un conjunto de pruebas estrictas y sólidas es realmente importante para todos los modelos

La mejor prueba que encontramos fue tener un conjunto fijo de indicaciones. Conoces la indicación, conoces la finalización que debe ser determinista, por lo que se utiliza el método greedy. Si dos generaciones son idénticas, básicamente puedes ignorar las pequeñas diferencias en los logits. Si ves una desviación, debes investigar. Puede ser que tu código no esté haciendo lo que debería o que estés fuera del dominio de ese modelo y, por lo tanto, el modelo sea más sensible al ruido. Si tienes varias indicaciones y suficientemente largas, es menos probable que se active eso accidentalmente para todas las indicaciones. Cuantas más indicaciones, mejor; cuanto más largas, mejor.

El primer modelo (small-testing) está en bfloat16 como el gran bloom, por lo que todo debería ser muy similar, pero no se entrenó mucho o simplemente no tiene un buen rendimiento, por lo que fluctúa mucho en las salidas. Eso significa que tuvimos problemas con esas pruebas de generación. El segundo modelo es más estable, pero se entrenó y guardó en float16 en lugar de bfloat16. Eso deja más margen de error entre los dos.

Para ser completamente justos, la conversión de bfloat16 a float16 parecía estar bien en el modo de inferencia (bfloat16 existe principalmente para manejar grandes gradientes, que no existen en la inferencia).

En ese paso, se descubrió e implementó un importante compromiso. Debido a que bloom se entrenó en un entorno distribuido, parte del código estaba realizando paralelismo de tensores en una capa lineal, lo que significa que ejecutar la misma operación como una única operación en una sola GPU daba resultados diferentes. Esto tomó un tiempo para identificarlo y, o bien optábamos por una compatibilidad del 100% y el modelo era mucho más lento, o aceptábamos una pequeña diferencia en la generación pero era mucho más rápido de ejecutar y el código era más simple. Optamos por una bandera configurable.

Nota: en este contexto, el paralelismo de canalización (PP) significa que cada GPU tendrá 
algunas capas, por lo que cada GPU trabajará en un fragmento de datos determinado antes de 
pasarlo a la siguiente GPU.

Ahora tenemos una versión limpia de transformers en la que podemos empezar a trabajar.

Bloom es un modelo de 352 GB (176B parámetros en bf16), por lo que necesitamos al menos esa cantidad de RAM de GPU para que quepa. Exploramos brevemente la posibilidad de trasladarlo a la CPU en máquinas más pequeñas, pero la velocidad de inferencia era mucho más lenta, por lo que lo descartamos.

Luego quisimos básicamente utilizar el pipeline. Así que esto es dogfooding y esto es lo que la API utiliza debajo del capó todo el tiempo.

Sin embargo, los pipelines no son conscientes de la distribución (no es su objetivo). Después de discutir brevemente las opciones, terminamos usando accelerate, recién creado, device_map=”auto” para gestionar la fragmentación del modelo. Tuvimos que solucionar algunos errores y arreglar un poco el código de transformers para ayudar a accelerate a hacer el trabajo correcto.

Funciona dividiendo las diversas capas de los transformers y dando parte del modelo a cada GPU. Así que la GPU0 comienza a trabajar, luego se la pasa a la GPU1 y así sucesivamente.

Al final, con un pequeño servidor HTTP encima, pudimos empezar a servir bloom (el modelo grande) !!

Pero aún ni siquiera hemos comenzado a discutir las optimizaciones!

En realidad, tenemos bastantes, todo este proceso es un castillo de naipes. Durante las optimizaciones vamos a hacer modificaciones en el código subyacente, asegurarse de que no estás matando el modelo de una forma u otra es realmente importante y más fácil de lo que piensas.

Así que ahora estamos en el primer paso de las optimizaciones y necesitamos comenzar a medir y seguir midiendo el rendimiento. Entonces necesitamos considerar en qué nos importa. Para un servidor de inferencia abierto que admite muchas opciones, esperamos que los usuarios envíen muchas consultas con diferentes parámetros y lo que nos importa es:

La cantidad de usuarios que podemos servir al mismo tiempo (rendimiento) ¿Cuánto tiempo tarda en atenderse un usuario promedio (latencia)?

Hicimos un script de prueba en locust que es exactamente esto:

from locust import HttpUser, between, task
from random import randrange, random


class QuickstartUser(HttpUser):
    wait_time = between(1, 5)

    @task
    def bloom_small(self):
        sentence = "Traducir al chino. EN: Me gusta la sopa. CN: "
        self.client.post(
            "/generate",
            json={
                "inputs": sentence[: randrange(1, len(sentence))],
                "parameters": {"max_new_tokens": 20, "seed": random()},
            },
        )

    @task
    def bloom_small(self):
        sentence = "Traducir al chino. EN: Me gusta la sopa. CN: "
        self.client.post(
            "/generate",
            json={
                "inputs": sentence[: randrange(1, len(sentence))],
                "parameters": {
                    "max_new_tokens": 20,
                    "do_sample": True,
                    "top_p": 0.9,
                    "seed": random(),
                },
            },
        )

**Nota: Esta no es la mejor ni la única prueba de carga que usamos, pero siempre fue la primera en ejecutarse para poder comparar de manera justa entre enfoques. Ser el mejor en este punto de referencia NO significa que sea la mejor solución. Además del rendimiento del mundo real, también se tuvieron que usar escenarios más complejos. **

Queríamos observar la rampa de subida para varias implementaciones y también asegurarnos de que bajo carga el servidor cerrara adecuadamente. El cierre del circuito significa que el servidor puede responder (rápidamente) que no responderá a su consulta porque demasiadas personas están tratando de usarlo al mismo tiempo. Es extremadamente importante evitar el abrazo de la muerte.

En esta prueba de referencia, el rendimiento inicial fue (en 16xA100 40Go en GCP, que es la máquina utilizada en todo el proceso):

Peticiones/s: 0.3 (rendimiento) Latencia: 350ms/token (latencia)

Esos números no son tan buenos. Antes de ponernos manos a la obra, estimemos lo mejor que podemos imaginar lograr. La fórmula para la cantidad de operaciones es 24Bsh^2 + 4𝐵s^2h24Bsh^2 + 4𝐵s^2h donde B es el tamaño del lote, s la longitud de la secuencia y h la dimensión oculta.

Hagamos los cálculos y estamos obteniendo 17 TFlop para un solo pase hacia adelante. Mirando las especificaciones de A100, afirma 312 TFLOPS para una sola tarjeta. Eso significa que una sola GPU podría ejecutarse potencialmente a 17 / 312 = 54ms/token. Estamos usando 16 de esas, por lo que en la máquina general sería 3ms/token. Toma todos estos números con una buena dosis de escepticismo, nunca es posible alcanzar esos números y el rendimiento en la vida real rara vez coincide con las especificaciones. Además, si el cálculo no es su factor limitante, entonces esto no es lo más bajo que puede obtener. Es solo una buena práctica saber qué tan lejos está de su objetivo. En este caso, estamos a 2 órdenes de magnitud, por lo que estamos bastante lejos. Además, esta estimación pone todas las operaciones al servicio de la latencia, lo que significa que solo una solicitud puede ir a la vez (está bien porque estás maximizando tu máquina, por lo que no hay mucho más que hacer, pero podemos tener una latencia más alta y recuperar el rendimiento a través del agrupamiento con mayor facilidad).

Nota: El paralelismo de tensores (TP) significa en este contexto que cada GPU poseerá una parte de los pesos, por lo que TODAS las GPUs están activas todo el tiempo y hacen menos trabajo. Por lo general, esto conlleva una sobrecarga muy ligera en la que se duplica parte del trabajo y, lo que es más importante, las GPUs deben comunicarse regularmente entre sí para continuar la computación.

Ahora que tenemos una buena comprensión de nuestra situación, es hora de ponerse a trabajar.

Probamos muchas cosas diferentes basadas en las personas y nuestros diversos conocimientos.

TODOS los esfuerzos merecen su propia publicación de blog, así que solo los enumeraré, explicaré los pocos aprendizajes finales y me adentraré en los detalles solo de lo que se incluyó en el servidor actual. Pasar del Paralelismo de Canalización (PP) al Paralelismo de Tensores (TP) es un gran cambio interesante para la latencia. Cada GPU poseerá una parte de los parámetros y todos trabajarán al mismo tiempo. Por lo tanto, la latencia debería disminuir drásticamente, pero el precio a pagar es la sobrecarga de comunicación, ya que necesitan comunicarse regularmente entre sí acerca de sus resultados.

Cabe destacar que este es un rango muy amplio de enfoques y la intención fue aprender más sobre cada herramienta y cómo podría encajar en futuros proyectos.

Portar el código JAX/Flax para ejecutar en TPUs:

  • Se esperaba que fuera más fácil elegir el tipo de paralelismo, por lo que TP debería ser más fácil de probar. Es una de las ventajas del diseño de Jax.
  • Más limitado en hardware, es probable que el rendimiento en TPU sea superior al de la GPU y hay menos opciones de proveedores para TPU.
  • Desventajas, se necesita otro puerto. Pero de todos modos sería bienvenido en nuestras bibliotecas.

Resultados:

  • El portar no fue una tarea fácil, ya que algunas condiciones y núcleos eran difíciles de reproducir lo suficientemente bien. Aún así, se pudo manejar.
  • El paralelismo fue bastante fácil de obtener una vez que se realizó el portado. Felicitaciones a Jax, la afirmación es cierta.
  • Ray/la comunicación con los trabajadores de TPU resultó ser un verdadero dolor de cabeza para nosotros. No sabemos si es la herramienta, la red o simplemente nuestra falta de conocimiento, pero ralentizó los experimentos y el trabajo mucho más de lo que anticipamos. Lanzábamos un experimento que tardaba 5 minutos en ejecutarse, esperábamos 5 minutos y no había sucedido nada, 10 minutos después todavía nada, resultó que algún trabajador estaba caído/no respondía, teníamos que ingresar manualmente, averiguar qué sucedió, solucionarlo, reiniciar algo y volver a lanzarlo y habíamos perdido media hora. Repite eso suficientes veces y los días perdidos se acumulan rápidamente. Queremos enfatizar que esto no es necesariamente una crítica a las herramientas que usamos, sino la experiencia subjetiva que tuvimos.
  • No hay control sobre la compilación. Una vez que hicimos que las cosas funcionaran, probamos varias configuraciones para descubrir cuál se adaptaba mejor a la inferencia que teníamos en mente y resultó ser muy difícil predecir la latencia/rendimiento según la configuración. Por ejemplo, teníamos una tasa de solicitudes por segundo (rps) de 0.3 con un tamaño de lote de 1 (por lo que cada solicitud/usuario está por separado) con una latencia de 15 ms/token (no lo compares demasiado con otros números en este artículo, es en una máquina diferente con un perfil muy diferente), lo cual es genial, pero el rendimiento general no es mucho mejor que lo que teníamos con el código antiguo. Así que decidimos agregar el procesamiento por lotes y con un tamaño de lote de 2, la latencia aumentó 5 veces, con solo el doble de rendimiento… Después de investigar más a fondo, resultó que hasta el tamaño de lote de 16, cada tamaño de lote tenía el mismo perfil de latencia. Por lo tanto, podríamos haber tenido 16 veces más rendimiento a un costo de latencia 5 veces mayor. No está mal, pero al observar los números realmente hubiéramos preferido un control más detallado. Los números a los que apuntábamos se basan en la regla de los 100 ms, 1 s, 10 s, 1 min.

Usando ONNX/TRT u otros enfoques compilados

  • Se supone que manejan la mayor parte del trabajo de optimización.
  • Desventaja, por lo general el paralelismo debe ser manejado manualmente.

Resultados:

  • Resultó que, para poder rastrear/jit/exportar cosas, necesitábamos reestructurar parte de PyTorch para que se fusionara fácilmente con el enfoque puro de PyTorch. En general, descubrimos que podríamos tener la mayoría de las optimizaciones que deseábamos al permanecer dentro del mundo de PyTorch, lo que nos permitiría mantener la flexibilidad sin tener que hacer demasiado esfuerzo de codificación. Otra cosa a tener en cuenta, dado que estamos ejecutando en GPU y la generación de texto tiene muchas pasadas hacia adelante, necesitamos que los tensores permanezcan en la GPU y a veces es difícil enviar tus tensores a alguna librería, recibir el resultado, realizar el cálculo de los logits (como argmax o muestreo) y volver a alimentarlo. Poner el bucle dentro de la librería externa significa perder flexibilidad, al igual que Jax, por lo que no se contempló en nuestro caso de uso.

DeepSpeed

  • Esta es la tecnología que impulsó el entrenamiento, parecía justo usarla para la inferencia
  • Contras, nunca se usó/preparó para la inferencia antes.

Resultados:

  • Tuvimos resultados realmente impresionantes rápidamente, que son aproximadamente los mismos que la última iteración que estamos ejecutando actualmente.
  • Tuvimos que inventar una forma de poner un servidor web (lidiando con la concurrencia) sobre DeepSpeed, que también tiene varios procesos (uno para cada GPU). Dado que existe una excelente biblioteca Mii. No se ajusta a los objetivos extremadamente flexibles que teníamos en mente, pero probablemente habríamos comenzado a trabajar sobre ella ahora. (La solución actual se discute más adelante).
  • La mayor advertencia que encontramos con DeepSpeed fue la falta de estabilidad. Tuvimos problemas al ejecutarlo en CUDA 11.4, donde el código estaba construido para 11.6. Y el problema de larga data que nunca pudimos solucionar realmente es que habría bloqueos de kernel regulares (acceso ilegal de Cuda, discrepancia de dimensiones, etc.). Arreglamos muchos de estos, pero nunca pudimos lograr la estabilidad bajo el estrés de nuestro servidor web. A pesar de eso, quiero agradecer a las personas de Microsoft que nos ayudaron, tuvimos una conversación realmente buena que mejoró nuestra comprensión de lo que estaba sucediendo y nos dio ideas reales para hacer algunos trabajos de seguimiento.
  • Uno de los puntos dolorosos que siento es que nuestro equipo está en su mayoría en Europa, mientras que Microsoft está en California, por lo que la colaboración fue complicada en términos de tiempo y probablemente perdimos una gran cantidad de tiempo debido a esto. Esto no tiene nada que ver con la parte técnica, pero es bueno reconocer que la parte organizativa de trabajar juntos también es muy importante.
  • Otra cosa a tener en cuenta es que DeepSpeed se basa en transformers para inyectar su optimización, y dado que estábamos actualizando nuestro código de manera bastante constante, fue difícil para el equipo de DeepSpeed mantener las cosas funcionando en nuestra rama principal. Lamentamos haberlo dificultado, supongo que por eso se llama vanguardia.

Ideas para el servidor web

  • Dado que vamos a ejecutar un servidor gratuito donde los usuarios van a enviar texto largo, texto corto, querer unos pocos tokens o una receta completa, algo tenía que hacerse aquí.

Resultados:

  • Recodificamos todo en Rust con las excelentes librerías tch-rs. Rust no tenía como objetivo obtener ganancias de rendimiento, sino simplemente tener un control mucho más fino sobre el paralelismo (hilos/procesos) y jugar de manera más detallada en la concurrencia del servidor web y de PyTorch. Python es famosamente difícil de manejar en detalles de bajo nivel gracias al GIL.
  • Resultó que la mayor parte del dolor venía del puerto, y después de eso, la experimentación fue muy fácil. Y descubrimos que con suficiente control sobre los bucles podríamos tener un gran rendimiento para todos, incluso en el contexto de una amplia variedad de solicitudes con diferentes propiedades. Código para los curiosos, pero no viene con ningún soporte o documentación agradable.
  • Se convirtió en producción durante algunas semanas porque era más tolerante con el paralelismo, podíamos usar las GPUs de manera más eficiente (usando GPU0 para la solicitud 1 mientras GPU1 está tratando la solicitud 0). Y pasamos de 0.3 RPS a ~2.5 RPS con la misma latencia. El caso óptimo habría sido aumentar el rendimiento en 16 veces, pero los números mostrados aquí son mediciones de cargas de trabajo reales, por lo que no está mal del todo.

PyTorch puro

  • Modificar puramente el código existente para hacerlo más rápido eliminando operaciones como reshape, utilizando kernels mejor optimizados, y así sucesivamente.
  • Con, tenemos que codificar el TP nosotros mismos y tenemos la restricción de que el código aún se ajuste a nuestra biblioteca (principalmente).

Resultados:

  • Siguiente capítulo.

Escribir PyTorch más eficiente

El primer elemento de la lista fue eliminar operaciones innecesarias en las primeras implementaciones. Algunas se pueden ver simplemente mirando el código y descubriendo fallos obvios:

  • Alibi se usaba en Bloom para agregar incrustaciones de posición y se calculaba en demasiados lugares, solo podíamos calcularlo una vez y de manera más eficiente.

El código antiguo: enlace. El nuevo código: enlace.

Esto es una aceleración de 10 veces y la última versión incluye acolchado también. Dado que este paso se calcula solo una vez, la velocidad real no es importante, pero en general, reducir el número de operaciones y creación de tensores es una buena dirección.

Otras partes se ven más claras cuando se comienza el perfilado y utilizamos de manera bastante extensa la extensión de tensorboard

Esto proporciona este tipo de imagen que brinda información:

La atención lleva mucho tiempo, ten cuidado, esta es una vista de la CPU, por lo que las barras largas no significan largas, sino que la CPU está esperando los resultados de la GPU del paso anterior. Vemos muchas operaciones “cat” antes de “baddbmm”.

Al eliminar muchas reshape/transpose, por ejemplo, descubrimos que: – La atención es el camino crítico (se espera, pero siempre es bueno verificarlo). – En la atención, muchos kernels eran copias reales debido a la gran cantidad de reshape – Podríamos eliminar los reshape reconfigurando los pesos mismos y el pasado. Esto es un cambio drástico, ¡pero mejoró bastante el rendimiento!

Soporte para TP

Vale, hemos eliminado la mayoría de los frutos de bajo colgante, ahora hemos reducido aproximadamente de 350ms/token de latencia a 300ms/token en PP. Eso es una reducción del 15% en la latencia, pero en realidad proporcionó más que eso, pero al principio no fuimos extremadamente rigurosos en nuestras mediciones, así que vamos a quedarnos con esa cifra.

Luego pasamos a proporcionar una implementación de TP. Resultó ser mucho más rápido de lo que anticipábamos, la implementación tomó medio día de un solo desarrollador (experimentado). El resultado está aquí . También pudimos reutilizar código de otros proyectos, lo cual ayudó.

La latencia pasó directamente de 300ms/token a 91ms/token, lo cual es una mejora enorme en la experiencia del usuario. Una solicitud simple de 20 tokens pasó de 6 segundos a 2 segundos, lo cual pasó de una experiencia “lenta” a ligeramente retrasada.

Además, el rendimiento aumentó mucho a 10RPS. El rendimiento proviene del hecho de que ejecutar una consulta con batch_size=1 lleva el mismo tiempo que batch_size=32 y el rendimiento se vuelve prácticamente gratuito en costo de latencia en este punto.

Frutos de bajo colgante

Ahora que teníamos una implementación de TP, pudimos comenzar a perfilar y optimizar nuevamente. Es un cambio lo suficientemente significativo como para tener que empezar de cero de nuevo.

Lo primero que destacó es que la sincronización (ncclAllReduce) comienza a convertirse en una parte preponderante de la carga, lo cual es esperado, esta es la parte de sincronización y sí lleva algo de tiempo. Nunca intentamos buscar y optimizar esto, ya que ya está utilizando nccl, pero aún podría haber margen de mejora allí. Supusimos que sería difícil hacer mucho mejor.

Lo segundo es que el operador Gelu lanzaba muchos kernels elementwise y en general estaba ocupando una mayor parte de la computación de lo esperado.

Hicimos el cambio de:

def bloom_gelu_forward(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

a

@torch.jit.script
def bloom_gelu_forward(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

¡Esto transforma las operaciones de múltiples pequeños kernels element-wise (y por lo tanto copias de tensores) a una única operación kernel!

Esto proporcionó una mejora de latencia del 10% de 91ms/token a 81ms/token, ¡justo ahí!

Ten cuidado, esto no es una caja negra mágica que puedas usar en todas partes, la fusión de kernels no necesariamente ocurrirá o las operaciones previamente utilizadas ya son extremadamente eficientes.

Lugares donde encontramos que funcionaba bien:

  • Tienes muchas operaciones pequeñas/elementwise
  • Tienes un punto crítico con pocos reshape difíciles de eliminar, copias en general
  • Cuando ocurre la fusión.

Fracaso épico

También tuvimos algunos puntos, durante nuestros períodos de prueba, donde terminamos viendo una latencia consistentemente un 25% menor para el servidor Rust en comparación con el de Python. Esto fue bastante extraño, pero porque se midió de manera constante y porque eliminar kernels proporcionó una aceleración, tuvimos la impresión de que tal vez eliminar la sobrecarga de Python podría proporcionar un impulso agradable.

Comenzamos un trabajo de 3 días para reimplementar las partes necesarias de torch.distributed para poder funcionar en el mundo Rust con nccl-rs. Teníamos la versión funcionando pero algo no estaba bien en las generaciones en comparación con su contraparte en Python. Durante la investigación de los problemas, nos dimos cuenta… de que habíamos olvidado eliminar el perfilador en las mediciones de Pytorch

Eso fue un fracaso épico porque al eliminarlo recuperamos el 25% y ambos códigos se ejecutaron igual de rápido. Esto es lo que esperábamos inicialmente, que Python no debería afectar el rendimiento, ya que en su mayoría ejecuta código cpp de torch. Al final, 3 días no es el fin del mundo, y podría ser útil en el futuro, pero sigue siendo bastante malo. Esto es bastante común cuando se hacen optimizaciones y se obtienen mediciones incorrectas o engañosas que resultan decepcionantes o incluso perjudiciales para el producto en general. Por eso, dar pequeños pasos y tener expectativas sobre los resultados lo antes posible ayuda a controlar ese riesgo.

Otro lugar donde tuvimos que tener mucho cuidado fue en el pase inicial hacia adelante (sin pasado) y los pases hacia adelante posteriores (con pasado). Si optimizas el primero, es muy probable que ralentes los posteriores, que son mucho más importantes y representan la mayor parte del tiempo de ejecución. Otra causa común es medir tiempos que son tiempos de CPU y no tiempos CUDA reales, por lo que debes utilizar torch.cuda.synchronize() al realizar ejecuciones para asegurarte de que los kernels se completen.

Kernel personalizado

Hasta ahora, ¡habíamos logrado un rendimiento cercano al de DeepSpeed sin ningún código personalizado fuera de PyTorch! Muy bien. ¡Tampoco tuvimos que hacer ningún compromiso en cuanto a la flexibilidad del tamaño de lote en tiempo de ejecución!

Pero dado la experiencia de DeepSpeed, queríamos intentar escribir un kernel personalizado para fusionar algunas operaciones en la ruta crítica donde torch.jit.script no pudo hacerlo por nosotros. Básicamente, las siguientes dos líneas:

attn_weights = attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

El primer masked fill está creando un nuevo tensor, que solo se utiliza para indicarle al operador softmax que ignore esos valores. Además, el softmax debe calcularse en float32 (para estabilidad), pero dentro de un kernel personalizado podríamos limitar la cantidad de conversión necesaria, limitándola a las sumas y acumulaciones reales necesarias.

El código se puede encontrar aquí. Ten en cuenta que teníamos una única arquitectura de GPU a la que apuntar, por lo que nos pudimos centrar en esto y no somos expertos (aún) en escribir kernels, por lo que podría haber mejores formas de hacerlo.

Este kernel personalizado proporcionó un aumento adicional de latencia del 10%, disminuyendo de una latencia de 81ms/token a 71ms/token. Y todo esto manteniendo nuestra flexibilidad.

Después de eso, investigamos y exploramos otras cosas, como fusionar más operadores eliminando otras remodelaciones o colocándolas en otros lugares. Pero ningún intento tuvo un impacto lo suficientemente significativo como para llegar a las versiones finales.

Parte del servidor web

Al igual que su contraparte en Rust, tuvimos que implementar el agrupamiento de solicitudes con diferentes parámetros. Dado que estábamos en el mundo de PyTorch, tenemos bastante control sobre lo que sucede. Sin embargo, al estar en Python, tenemos el factor limitante de que torch.distributed debe ejecutarse en varios procesos en lugar de hilos, lo que significa que es un poco más difícil comunicarse entre procesos. Al final, optamos por comunicarnos mediante cadenas sin procesar a través de Redis pub/sub para distribuir las solicitudes a todos los procesos a la vez. Dado que estamos en diferentes procesos, es más fácil hacerlo de esta manera que comunicarnos mediante tensores (que son mucho más grandes), por ejemplo.

Luego tuvimos que dejar de usar generate ya que esto aplica los parámetros a todos los miembros del lote, y en realidad queríamos aplicar un conjunto diferente de parámetros. Afortunadamente, podemos reutilizar elementos de nivel inferior como el LogitsProcessor para ahorrarnos mucho trabajo.

Así que reconstruimos una función generate que recibe una lista de parámetros y los aplica a cada miembro del lote.

Otro aspecto realmente importante de la experiencia de usuario final es la latencia. Dado que tenemos diferentes conjuntos de parámetros para diferentes solicitudes, podríamos tener

1 solicitud para 20 tokens y otra para 250 tokens. Dado que la latencia es de 75ms/token, una solicitud tarda 1.5s y la otra 18s. Si agrupáramos todo el camino, estaríamos haciendo esperar al usuario que solicitó durante 18s y haciéndole parecer que estamos funcionando a 900ms/token, ¡lo cual es bastante lento!

Dado que estamos en un mundo de PyTorch con una flexibilidad extrema, lo que podemos hacer en su lugar es extraer del lote la primera solicitud tan pronto como generemos los primeros 20 tokens, ¡y regresar a ese usuario en el tiempo solicitado de 1.5s! También logramos ahorrar 230 tokens en cálculos.

Por lo tanto, la flexibilidad es importante para obtener la mejor latencia posible.

La optimización es un trabajo interminable, y como cualquier otro proyecto, el 20% del trabajo generalmente produce el 80% de los resultados. En algún momento, comenzamos a tener una pequeña estrategia de prueba para determinar los posibles resultados de alguna idea que tuviéramos, y si las pruebas no producían resultados significativos, entonces descartábamos la idea. Un día para un aumento del 10% es lo suficientemente valioso, 2 semanas para un aumento de 10X es lo suficientemente valioso. 2 semanas para un aumento del 10% no es tan interesante.

¿Has intentado …?

Las cosas que sabemos que existen y no hemos utilizado por diversas razones. Podría ser que parecía que no se adaptaba a nuestro caso de uso, que era mucho trabajo, que los resultados no eran lo suficientemente prometedores, o incluso simplemente teníamos demasiadas opciones para probar y descartamos algunas sin razones particulares y solo por falta de tiempo. Los siguientes no están en ningún orden particular:

  • Gráficos de Cuda
  • nvFuser (Esto es lo que alimenta torch.jit.script así que sí lo utilizamos.)
  • FasterTransformer
  • Triton de Nvidia
  • XLA (¡Jax también utiliza xla!)
  • torch.fx
  • TensorRT

No dudes en ponerte en contacto si tu herramienta favorita falta aquí o si crees que nos hemos perdido algo importante que podría resultar útil.

Atención flash

Hemos examinado brevemente la integración de la atención flash, y aunque funciona extremadamente bien en el primer pase hacia adelante (sin past_key_values), no produjo grandes mejoras al ejecutarse cuando se utiliza past_key_values. Dado que necesitábamos adaptarlo para incluir el tensor alibi en el cálculo, decidimos no hacerlo (al menos por ahora).

OpenAI Triton

Triton es un gran marco para construir núcleos personalizados en Python. Queremos usarlo más, pero hasta ahora no lo hemos hecho. Nos gustaría ver si funciona mejor que nuestro núcleo de Cuda. Escribir directamente en Cuda parecía el camino más corto para nuestro objetivo cuando consideramos nuestras opciones para esa parte.

Rellenado y remodelado

Como se menciona a lo largo de este artículo, cada copia de tensor tiene un costo y otro costo oculto de ejecución de producción es el rellenado. Cuando llegan dos consultas con longitudes muy diferentes, hay que rellenar (usar un token ficticio) para hacer que encajen en un cuadro. Esto puede llevar a muchos cálculos innecesarios. Más información.

Idealmente, podríamos evitar por completo esos cálculos y nunca tener remodelados. Tensorflow tiene el concepto de RaggedTensor y PyTorch tiene tensores anidados. Ambos parecen no ser tan eficientes como los tensores regulares, pero podrían permitirnos hacer menos cálculos, lo cual siempre es una ventaja.

En un mundo ideal, toda la inferencia estaría escrita en CUDA o en una implementación pura de GPU. Teniendo en cuenta las mejoras de rendimiento obtenidas cuando pudimos fusionar operaciones, esto parece deseable. Pero no tenemos idea de hasta qué punto esto sería efectivo. ¡Si las personas más inteligentes en GPU tienen ideas, estamos escuchando!

Todo este trabajo es resultado de la colaboración de muchos miembros del equipo de HF. En ningún orden en particular, @ThomasWang @stas @Nouamane @Suraj @Sanchit @Patrick @Younes @Sylvain @Jeff (Microsoft) @Reza y toda la organización de BigScience.

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Inteligencia Artificial

Este boletín de inteligencia artificial es todo lo que necesitas #59

Esta semana los cambios en los términos de servicio de Zoom (desde marzo) se pusieron en foco después de los temores ...

Inteligencia Artificial

Conoce DreamSync un nuevo marco de inteligencia artificial para mejorar la síntesis de texto a imagen (T2I) con comentarios de modelos de comprensión de imagen

Investigadores de la Universidad del Sur de California, la Universidad de Washington, la Universidad Bar-Ilan y Googl...

Inteligencia Artificial

Cómo la inteligencia artificial ayuda a combatir los incendios forestales en California

California tiene un nuevo arma contra los incendios forestales que han devastado el estado: la inteligencia artificia...

Ciencia de Datos

Aprendizaje Profundo en Sistemas de Recomendación Una introducción.

Los sistemas de recomendación se encuentran entre las aplicaciones de Aprendizaje Automático industrial de más rápido...