Tres marcos ampliamente utilizados están liderando el camino en la investigación y producción de aprendizaje profundo en la actualidad. Uno es famoso por su facilidad de uso, otro por sus características y madurez, y otro por su inmensa escalabilidad.
El aprendizaje profundo (deep learning en su expresión inglesa) está cambiando nuestras vidas a diario. Ya sea que Siri o Alexa sigan nuestros comandos de voz, las aplicaciones de traducción en tiempo real en los teléfonos o la tecnología de visión artificial que habilitan vehículos inteligentes o robots de almacén, entre otros casos de uso. Y casi todos estos avances están escritos en uno de los tras marcos o códigos más famosos: TensorFlow, PyTorch y JAX. Pero, ¿cómo es la comparación entre ellos?
¿Se debería usar TensorFlow?
“Nunca nadie fue despedido por comprar IBM” fue el grito de guerra del sector de la informática en las décadas de 1970 y 1980, y lo mismo podría decirse sobre el uso de TensorFlow en la década de 2010 para el deep learning. Como es sabido, el Gigante Azul se quedó en el camino cuando entramos en los 90. ¿TensorFlow sigue siendo competitivo en este nuevo decenio, siete años después de su lanzamiento inicial en 2015?
Lo cierto es que este marco no se ha detenido en todo este tiempo.TensorFlow 1.x tenía que ver con la creación de gráficos estáticos de una manera muy diferente a Python, pero con la línea TensorFlow 2.x también se pueden crear modelos para la evaluación inmediata de las operaciones con un nivel similar a PyTorch. En el nivel alto, TensorFlow brinda optimización de velocidad, conocido como XLA, que aumenta rendimiento de las GPU y es el método principal para aprovechar la potencia de las unidades de procesamiento de tensor (TPU) de Google para entrenar modelos a gran escala.
Por otra parte, este marco ha estado haciendo bien cosas durante estos años. Por ejemplo, está Swing para servir modelos repetidos en plataformas maduras; y TensorFlow.js y Lite para reorientar las implementaciones de modelos para web, computación de bajo consumo como smartphones o para el Internet de las Cosas (IoT, de sus siglas inglesas). Además, teniendo en cuenta que Google todavía ejecuta el 100% de sus implementaciones de producción con TensorFlow, es fácil manejar la escalabilidad.
En cualquier caso, también ha existido una cierta falta de energía en torno al proyecto que no se ha de ignorar. La actualización a TensorFlow 2.x fue brutal. Algunas empresas analizaron el esfuerzo necesario para actualizar su código y decidieron transportarlo a PyTorch. Por otra parte, también perdió fuerza en la comunidad de investigación, que comenzó a preferir la flexibilidad de este último rival.
El caso Keras tampoco ha ayudado. Se convirtió en una parte integrada de los lanzamientos del marco hace dos años, pero recientemente se retiró a una biblioteca separada con su propio calendario de lanzamiento. Dividir Keras no es algo que afecte a la vida cotidiana de un desarrollador, pero una reversión de tan alto perfil en una revisión menor del marco no inspira confianza.
Dicho todo esto, TensorFlow es confiable y alberga un extenso ecosistema para el aprendizaje profundo. Pude crear aplicaciones y modelos que funcionen en todas las escalas.
¿Se debería usar PyTorch?
Ya no es el advenedizo que le pisa los talones a TensorFlow, sino que es una fuerza importante en el mundo del deep learning hoy en día, quizás principalmente para la investigación, pero también, y cada vez más, en aplicaciones de producción. Y dado que el modo entusiasta se ha convertido en el método predeterminado de desarrollo en TensorFlow y PyTorch, el enfoque que ofrece la diferenciación automática de PyTorch parece haber ganado la guerra contra los gráficos estáticos.
A diferencia de TensorFlow, PyTorch no ha experimentado rupturas importantes en el código central desde la desaprobación de la API variable en la versión 0.4. Pero eso no quiere decir que no haya habido algunos errores. Por ejemplo, si se ha estado usando PyTorch para entrenar en varias GPU, es probable que haya habido problemas con las diferencias entre DataParallel y el DistributedDataParallel más nuevo. Casi siempre se debe usar DistributedDataParallel, pero DataParallel en realidad no está obsoleto.
Aunque se ha quedado rezagado con respecto a TensorFlow y JAX en la compatibilidad con XLA/TPU, la situación ha mejorado mucho a partir de 2022. PyTorch ahora admite el acceso a máquinas virtuales de TPU, así como el estilo anterior de compatibilidad con nodos de TPU, junto con una implementación sencilla de la línea de comandos. Y para no lidiar con parte del código repetitivo que PyTorch a menudo hace escribir, se puede recurrir a adiciones de mayor nivel como PyTorch Lightning. En el lado negativo, aunque el trabajo continúa en PyTorch Mobile, todavía está mucho menos maduro que TensorFlow Lite.
En términos de producción, PyTorch ahora tiene integraciones con plataformas independientes del marco como Kubeflow, mientras que el proyecto TorchServe puede manejar detalles de implementación como escalado, métricas e inferencia por lotes.
¿PyTorch escala? Meta lo ha estado ejecutando en producción durante años. Aún así, se puede argumentar que PyTorch podría no ser tan amigable como JAX para las ejecuciones de entrenamiento muy grandes que requieren bancos de GPU o TPU.
¿Se debería usar JAX?
JAX es un marco de aprendizaje profundo creado, mantenido y utilizado por Google, pero no es oficialmente un producto de Google. Sin embargo, gran parte de la investigación de Google se ha trasladado a JAX.
¿Qué es JAX, exactamente? Una manera fácil de pensar en el marco es imaginar una versión de NumPy acelerada por GPU/TPU que pueda, con un movimiento, vectorizar una función de Python y manejar todos los cálculos derivados en dichas funciones. Finalmente, tiene un componente JIT (Just-In-Time) que toma su código y lo optimiza para el compilador XLA, lo que resulta en mejoras de rendimiento significativas sobre TensorFlow y PyTorch.
Dado que JAX funciona en el nivel NumPy, su código está escrito en un nivel mucho más bajo que TensorFlow/Keras e incluso PyTorch. Afortunadamente, hay un ecosistema pequeño pero creciente de proyectos circundantes que agregan bits adicionales. Hay Flax de Google y Haiku de DeepMind (también Google). Hay Optax para todas sus necesidades de optimización y PIX para el procesamiento de imágenes y mucho más. Una vez que está trabajando con algo como Flax, construir redes neuronales se vuelve relativamente fácil de manejar. Solo tenga en cuenta que todavía hay algunas asperezas. Los veteranos hablan mucho sobre cómo JAX maneja los números aleatorios de manera diferente a muchos otros marcos, por ejemplo.
-Ian Pointer, cio.com