The Flax Neural Network Library

I came across two interesting, related Python libraries recently: JAX and Flax. JAX ("just after execution") is sort of an enhanced NumPy (numerical Python) library. JAX adds the ability to run array computations on GPU and TPU hardware, plus automatic gradient calculation. Flax is a neural network library, somewhat similar to PyTorch or TensorFlow, that is built on top of JAX.
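To make the "enhanced NumPy with automatic gradients" description concrete, here is a minimal sketch that uses only standard JAX calls (jax.numpy, jax.grad, jax.devices). The function name sum_of_squares is a hypothetical example, not something from the Flax documentation. jax.grad takes a Python function and returns a new function that computes its gradient. For f(x) = sum(x^2) the gradient is 2x, which is easy to check by hand.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Written just like NumPy code, but using jax.numpy arrays
    return jnp.sum(x ** 2)


# jax.grad returns a new function that computes d(sum_of_squares)/dx
grad_fn = jax.grad(sum_of_squares)

x = jnp.array([1.0, 2.0, 3.0])
g = grad_fn(x)  # analytically, the gradient of sum(x^2) is 2x

print(jax.devices())  # the CPU, GPU, or TPU devices JAX found
print(g)
```

The same code runs unchanged on a machine with a GPU or TPU; JAX places the arrays on its default device.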

I asked on the Flax GitHub Discussions board whether Flax is an acronym. According to two of the main contributors, the name stands for both "functional layers for JAX" (in early versions) and "flexible JAX" (from the design principles).

Here's a code snippet from the Flax documentation that creates a 10-[12-8]-4 neural network: 10 input values, two hidden layers with 12 and 8 nodes, and 4 output values.

```python
from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]  # number of nodes in each layer

  @nn.compact
  def __call__(self, x):
    # hidden layers: Dense followed by ReLU activation
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    # output layer: Dense with no activation
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))  # 32 items, each with 10 input values
variables = model.init(jax.random.PRNGKey(0), batch)  # initialize weights
output = model.apply(variables, batch)  # forward pass, shape (32, 4)
```

Even a tiny example like this has a lot going on. In the import statements, notice that Flax depends on jax and jax.numpy, which is a big topic by itself.



Two pages from the Flax documentation Web site. It looks like getting up to speed with Flax would take many weeks of dedicated effort.


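The code snippet uses Python's standard typing module to declare that features is a Sequence of int values, rather than a list of arbitrary types. The annotation is more than documentation: Flax modules are Python dataclasses, so each attribute such as features must be declared with a type annotation. Python does not enforce the int element type at runtime, though.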

The @nn.compact decorator lets you create submodules such as nn.Dense inline, inside __call__(), so simple networks can skip writing a separate setup() method.

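The JAX library uses a somewhat unusual API for random number generation. Instead of a hidden global state, as with NumPy's np.random.seed(), JAX requires an explicit PRNG key, such as jax.random.PRNGKey(0) in the snippet, which is passed to each random function and split whenever fresh randomness is needed.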
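Here is a minimal sketch of the explicit-key style, using jax.random.PRNGKey, jax.random.split, and jax.random.normal. Reusing a key reproduces exactly the same numbers, which is why the usual pattern is to split a key into fresh subkeys before each use.

```python
import jax

# An explicit key replaces NumPy's hidden global random state
key = jax.random.PRNGKey(0)

# split() derives new, independent keys; the old key should not be reused
key, subkey1, subkey2 = jax.random.split(key, 3)

a = jax.random.normal(subkey1, (3,))
a_again = jax.random.normal(subkey1, (3,))  # same key -> same numbers
b = jax.random.normal(subkey2, (3,))        # different key -> different numbers

print(a)
print(a_again)
print(b)
```

The explicit keys make results reproducible and safe to use inside transformed or parallel code, at the cost of some bookkeeping.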

I'm intrigued by Flax. I am very familiar with the PyTorch and Keras libraries, and part of me is thinking, "PyTorch works perfectly well for the work I do, and PyTorch has a large, fairly stable ecosystem. So why should I spend valuable time learning a new neural network library?"

But another part of me is thinking, "Flax and JAX look very well thought out, and their designers probably learned a lot from PyTorch's experience. Maybe Flax represents a big step forward."

Well, this is the good and the bad of machine learning — there’s always something new and interesting.



Left: Flax is a plant. Linen is made from flax. Ancient Egyptians made linen clothes. Center: A painting by contemporary artist Louise Flax. Right: “Field of Flax” (1892) by Edgar Degas.


This entry was posted in Machine Learning.