I came across two interesting, related Python libraries recently: JAX and Flax. JAX (“just after execution”) is essentially an enhanced NumPy (numerical Python) library. JAX adds support for numeric arrays on GPU and TPU hardware, plus automatic gradient calculation. Flax is a neural network code library, somewhat similar to PyTorch or TensorFlow, that is built on top of JAX.
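To show what “automatic gradient calculation” means, here’s a tiny sketch of my own (not from the JAX documentation) that differentiates a simple function using jax.grad:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(x ** 2)  # f(x) = x1^2 + x2^2 + . . .

grad_f = jax.grad(f)  # a new function that computes df/dx = 2x
print(grad_f(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]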
I asked on the Flax GitHub Discussions board whether Flax is an acronym. According to two of the main contributors, the name stands for both “functional layers for JAX” (in early versions) and “flexible JAX” (from the design principles).
Here’s a code snippet from the Flax documentation that creates a 10-[12-8]-4 neural network: 10 input nodes, hidden layers with 12 and 8 nodes, and 4 output nodes.
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)
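The variables object returned by model.init() is a nested dictionary of weight and bias arrays. As a quick sanity check (my own follow-up code, assuming the snippet above has run), you can map each array to its shape:

print(jax.tree_util.tree_map(lambda a: a.shape, variables))
# roughly: {'params': {'Dense_0': {'bias': (12,), 'kernel': (10, 12)},
#                      'Dense_1': {'bias': (8,), 'kernel': (12, 8)},
#                      'Dense_2': {'bias': (4,), 'kernel': (8, 4)}}}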
Even in a tiny example like this, there is a lot going on. In the import statements, notice that Flax has dependencies on jax and jax.numpy, and that’s a big topic by itself.
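One practical consequence of that dependency is that a Flax model is a pure function of its parameters, so JAX transformations apply to it directly. For example, here’s a sketch of my own (assuming the model and variables from the snippet above) that compiles the forward pass with jax.jit:

@jax.jit
def predict(variables, batch):
  return model.apply(variables, batch)  # pure function: parameters in, output out

output = predict(variables, batch)  # traced and compiled by XLA on the first call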
Two pages from the Flax documentation Web site. It looks like getting up to speed with Flax would take many weeks of dedicated effort.
The code snippet uses the typing library so that the features field can be declared as Sequence[int], a sequence holding int values, rather than as a plain list of arbitrary types.
The @nn.compact decorator is a syntax mechanism that lets simple neural networks define their layers inline in __call__, skipping a separate setup() method.
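For comparison, here is a sketch of my own (the class name MLPExplicit is mine) showing roughly what the same network looks like with an explicit setup() method instead of @nn.compact:

class MLPExplicit(nn.Module):
  features: Sequence[int]

  def setup(self):
    # with setup(), submodules are created here rather than inline in __call__
    self.layers = [nn.Dense(feat) for feat in self.features]

  def __call__(self, x):
    for i, layer in enumerate(self.layers):
      x = layer(x)
      if i != len(self.layers) - 1:  # no activation after the output layer
        x = nn.relu(x)
    return x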
The JAX library uses a somewhat unusual API for random number generation. There is no hidden global seed; random functions are stateless and take an explicit PRNG key as an argument, which is why model.init() above is passed jax.random.PRNGKey(0).
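The basic pattern looks like this (a minimal sketch of my own): create a key from a seed, split the key, and hand a fresh subkey to each random call:

key = jax.random.PRNGKey(0)          # explicit key from a seed; no global state
key, subkey = jax.random.split(key)  # split so each call gets a fresh key
vals = jax.random.normal(subkey, shape=(3,))  # reusing a key repeats the values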
I’m intrigued by Flax. I am very familiar with the PyTorch and Keras libraries, and part of me is thinking, “PyTorch works perfectly well for the work I do, and PyTorch has a large, fairly stable ecosystem. So why should I spend valuable time learning a new neural network library?”
But another part of me is thinking, “Flax and JAX look very well thought out, and their designers have probably learned a lot from the lessons of PyTorch. Maybe Flax represents a big step forward.”
Well, this is the good and the bad of machine learning — there’s always something new and interesting.
Left: Flax is a plant. Linen is made from flax. Ancient Egyptians made linen clothes. Center: A painting by contemporary artist Louise Flax. Right: “Field of Flax” (1892) by Edgar Degas.