Google Brain's Jax and Flax

Google's AI division, Google Brain, has two main products for deep learning: TensorFlow and Jax. While TensorFlow is best known, Jax can be thought of as a higher-level language for specifying deep learning algorithms while automatically eliding code that doesn't need to run as part of the model.

Jax evolved from Autograd, and is a combination of Autograd and XLA. Autograd "can automatically differentiate native Python and Numpy code. It can handle a large subset of Python's features, including loops, ifs, recursion and closures, and it can even take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation), which means it can efficiently take gradients of scalar-valued functions with respect to array-valued arguments, as well as forward-mode differentiation, and the two can be composed arbitrarily. The main intended application of Autograd is gradient-based optimization."

Flax is then built on top of Jax, and allows for easier customization of existing models.

What do you see as the future of domain specific languages for AI?


This entry originally appeared at lambda-the-ultimate.org/google-brain-jax, and may be a summary or abridged version.