Utilizando JAX para acelerar nuestra investigación

'Using JAX to accelerate our research'

Los ingenieros de DeepMind aceleran nuestra investigación construyendo herramientas, escalando algoritmos y creando mundos virtuales y físicos desafiantes para el entrenamiento y prueba de sistemas de inteligencia artificial (IA). Como parte de este trabajo, evaluamos constantemente nuevas bibliotecas y frameworks de aprendizaje automático.

Recientemente, hemos descubierto que un número creciente de proyectos se benefician de JAX, un framework de aprendizaje automático desarrollado por los equipos de investigación de Google. JAX se ajusta bien a nuestra filosofía de ingeniería y ha sido ampliamente adoptado por nuestra comunidad de investigación en el último año. Aquí compartimos nuestra experiencia trabajando con JAX, explicamos por qué lo encontramos útil para nuestra investigación de IA y ofrecemos una visión general del ecosistema que estamos construyendo para apoyar a los investigadores en todas partes.

¿Por qué JAX?

JAX es una biblioteca de Python diseñada para el cómputo numérico de alto rendimiento, especialmente para la investigación en aprendizaje automático. Su API para funciones numéricas se basa en NumPy, una colección de funciones utilizadas en computación científica. Tanto Python como NumPy son ampliamente utilizados y conocidos, lo que hace que JAX sea simple, flexible y fácil de adoptar.

Además de su API de NumPy, JAX incluye un sistema extensible de transformaciones de funciones componibles que ayudan a la investigación en aprendizaje automático, incluyendo:

  • Diferenciación: La optimización basada en gradientes es fundamental para el aprendizaje automático. JAX admite nativamente la diferenciación automática de modo directo e inverso de funciones numéricas arbitrarias, a través de transformaciones de funciones como grad, hessian, jacfwd y jacrev.
  • Vectorización: En la investigación en aprendizaje automático, a menudo aplicamos una sola función a muchos datos, por ejemplo, calcular la pérdida en un lote o evaluar gradientes por ejemplo para el aprendizaje diferencial privado. JAX proporciona vectorización automática a través de la transformación vmap, que simplifica este tipo de programación. Por ejemplo, los investigadores no necesitan preocuparse por el agrupamiento al implementar nuevos algoritmos. JAX también admite el paralelismo de datos a gran escala a través de la transformación relacionada pmap, distribuyendo de manera elegante datos que son demasiado grandes para la memoria de un solo acelerador.
  • Compilación JIT: XLA se utiliza para compilar en tiempo real (JIT) y ejecutar programas JAX en aceleradores GPU y Cloud TPU. La compilación JIT, junto con la API consistente con NumPy de JAX, permite a los investigadores sin experiencia previa en computación de alto rendimiento escalar fácilmente a uno o varios aceleradores.

Hemos encontrado que JAX ha permitido una experimentación rápida con algoritmos y arquitecturas novedosos y ahora forma la base de muchas de nuestras publicaciones recientes. Para obtener más información, considere unirse a nuestra Mesa Redonda de JAX, el miércoles 9 de diciembre a las 7:00 pm GMT, en la conferencia virtual NeurIPS.

JAX en DeepMind

Apoyar la investigación en IA de vanguardia significa equilibrar la creación rápida de prototipos y la iteración rápida con la capacidad de implementar experimentos a una escala tradicionalmente asociada con los sistemas de producción. Lo que hace que este tipo de proyectos sea particularmente desafiante es que el panorama de la investigación evoluciona rápidamente y es difícil de predecir. En cualquier momento, un nuevo avance en la investigación puede, y regularmente lo hace, cambiar la trayectoria y los requisitos de equipos enteros. Dentro de este panorama en constante cambio, una responsabilidad fundamental de nuestro equipo de ingeniería es asegurarse de que las lecciones aprendidas y el código escrito para un proyecto de investigación se reutilicen de manera efectiva en el siguiente.

Un enfoque que ha demostrado ser exitoso es la modularización: extraemos los bloques de construcción más importantes y críticos desarrollados en cada proyecto de investigación en componentes bien probados y eficientes. Esto permite a los investigadores centrarse en su investigación al mismo tiempo que se benefician de la reutilización de código, correcciones de errores y mejoras de rendimiento en los ingredientes algorítmicos implementados por nuestras bibliotecas principales. También hemos encontrado que es importante asegurarse de que cada biblioteca tenga un alcance claramente definido y asegurarse de que sean interoperables pero independientes. La participación incremental, la capacidad de elegir características sin estar limitado a otras, es fundamental para proporcionar la máxima flexibilidad a los investigadores y siempre apoyarlos en la elección de la herramienta adecuada para el trabajo.

Otras consideraciones que se han tenido en cuenta en el desarrollo de nuestro Ecosistema JAX incluyen asegurarse de que permanezca consistente (en la medida de lo posible) con el diseño de nuestras bibliotecas existentes de TensorFlow (por ejemplo, Sonnet y TRFL). También hemos tratado de construir componentes que (cuando corresponda) se ajusten lo más posible a las matemáticas subyacentes, para que sean autoexplicativos y minimizar los saltos mentales “del papel al código”. Finalmente, hemos elegido liberar nuestras bibliotecas como código abierto para facilitar el intercambio de resultados de investigación y fomentar que la comunidad en general explore el Ecosistema JAX.

Nuestro ecosistema hoy

Haiku ‍

El modelo de programación JAX de transformaciones de funciones componibles puede complicar el manejo de objetos con estado, como las redes neuronales con parámetros entrenables. Haiku es una biblioteca de redes neuronales que permite a los usuarios utilizar modelos de programación orientados a objetos familiares al tiempo que aprovechan el poder y la simplicidad del paradigma funcional puro de JAX.

Haiku es utilizado activamente por cientos de investigadores en DeepMind y Google, y ya ha sido adoptado en varios proyectos externos (por ejemplo, Coax , DeepChem , NumPyro ). Se basa en la API de Sonnet , nuestro modelo de programación basado en módulos para redes neuronales en TensorFlow, y hemos tratado de hacer que la migración de Sonnet a Haiku sea lo más sencilla posible.

Más información en GitHub

Optax

La optimización basada en gradientes es fundamental para el aprendizaje automático. Optax proporciona una biblioteca de transformaciones de gradientes, junto con operadores de composición (por ejemplo, chain) que permiten implementar muchos optimizadores estándar (por ejemplo, RMSProp o Adam) en una sola línea de código.

La naturaleza compositiva de Optax admite naturalmente la recombinación de los mismos ingredientes básicos en optimizadores personalizados. Además, ofrece varias utilidades para la estimación de gradientes estocásticos y la optimización de segundo orden.

Muchos usuarios de Optax han adoptado Haiku, pero de acuerdo con nuestra filosofía de adquisición incremental, se admiten bibliotecas que representen parámetros como estructuras de árbol JAX (por ejemplo, Elegy , Flax y Stax ). Consulte aquí para obtener más información sobre este rico ecosistema de bibliotecas de JAX.

Más información en GitHub

RLax

Muchos de nuestros proyectos más exitosos se encuentran en la intersección del aprendizaje profundo y el aprendizaje por refuerzo (RL), también conocido como aprendizaje por refuerzo profundo. RLax es una biblioteca que proporciona bloques de construcción útiles para la construcción de agentes de RL.

Los componentes en RLax cubren un amplio espectro de algoritmos e ideas: aprendizaje TD, gradientes de políticas, criticas de actores, MAP, optimización de políticas proximales, transformación no lineal de valores, funciones de valor general y varios métodos de exploración.

Aunque se proporcionan algunos ejemplos introductorios de agentes, RLax no está destinado como un marco para construir y desplegar sistemas de agentes de RL completos. Un ejemplo de un marco de agente totalmente funcional que se basa en los componentes de RLax es Acme .

Más información en GitHub

Chex

Las pruebas son fundamentales para la confiabilidad del software y el código de investigación no es una excepción. Llegar a conclusiones científicas a partir de experimentos de investigación requiere tener confianza en la corrección de su código. Chex es una colección de utilidades de prueba utilizadas por los autores de bibliotecas para verificar que los componentes comunes de construcción sean correctos y robustos, y por los usuarios finales para verificar su código experimental.

Chex proporciona una variedad de utilidades, incluyendo pruebas de unidad con soporte para JAX, afirmaciones de propiedades de los tipos de datos de JAX, simulacros y falsos, y entornos de prueba multidispositivo. Chex se utiliza en todo el Ecosistema JAX de DeepMind y en proyectos externos como Coax y MineRL .

Más información en GitHub

Jraph

Las redes neuronales de grafos (GNNs) son un área emocionante de investigación con muchas aplicaciones prometedoras. Vea, por ejemplo, nuestro trabajo reciente sobre predicción de tráfico en Google Maps y nuestro trabajo sobre simulación de física . Jraph (pronunciado “jirafa”) es una biblioteca ligera para trabajar con GNNs en JAX.

Jraph proporciona una estructura de datos estandarizada para grafos, un conjunto de utilidades para trabajar con grafos y un ‘zoo’ de modelos de redes neuronales gráficas fácilmente bifurcables y extensibles. Otras características clave incluyen: agrupación de GraphTuples que aprovecha eficientemente los aceleradores de hardware, compatibilidad con compilación JIT de grafos de forma variable mediante relleno y enmascaramiento, y pérdidas definidas sobre particiones de entrada. Al igual que Optax y nuestras otras bibliotecas, Jraph no impone restricciones en la elección de la biblioteca de redes neuronales por parte del usuario.

Aprenda más sobre cómo utilizar la biblioteca en nuestra amplia colección de ejemplos.

Obtenga más información en GitHub

Nuestro ecosistema JAX está en constante evolución y animamos a la comunidad de investigación de ML a explorar nuestras bibliotecas y el potencial de JAX para acelerar su propia investigación.

We will continue to update Zepes; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

Inteligencia Artificial

Investigadores de Microsoft presentan FP8 Mixed-Precision Training Framework Potenciando la eficiencia del entrenamiento de modelos de lenguaje grandes

Los modelos de lenguaje grandes han demostrado una destreza sin precedentes en la creación y comprensión del lenguaje...

Ciencia de Datos

Celebrando el impacto de IDSS

Una conferencia de dos días en MIT reflexionó sobre el impacto del Instituto de Datos, Sistemas y Sociedad desde su l...

Inteligencia Artificial

El mundo natural potencia el futuro de la visión por computadora

Un sistema de software de código abierto tiene como objetivo mejorar el entrenamiento de sistemas de visión por compu...

Inteligencia Artificial

¡No, no, no lo pongamos ahí! Este método de IA puede realizar edición de diseño continua con modelos de difusión

En este punto, todos están familiarizados con los modelos de texto a imagen. Se hicieron conocidos con el lanzamiento...