Distributed Training with JAX and Flax NNX: A Practical Guide to Sharding
Training large machine learning models often pushes beyond the limits of a single GPU or TPU. Scaling efficiently requires distributing computation and memory across multiple devices. JAX, with its powerful jit compiler and explicit sharding capabilities, offers a fantastic toolkit for this.
Recently, the Flax team introduced NNX, a new API designed for more explicit state management in neural networks. This post walks through a practical example (based on an official Flax NNX example) to demonstrate how JAX's sharding features can be combined with Flax NNX to implement distributed training strategies like Fully Sharded Data Parallelism (FSDP).
Our goal is to break down the code step-by-step, making the concepts of JAX sharding and its integration with the new Flax NNX API more accessible.
Note: The code presented here is derived from the official Flax examples repository. You can find the original source at: https://github.com/google/flax/blob/f7d3873b203ac0f3c6859738b1d48c2385359ca0/examples/nnx_toy_examples/10_fsdp_and_optimizer.py. This blog aims to provide a detailed explanation to aid understanding.
Let's dive in!
Setup: Imports and Simulating Devices
First, we need the necessary imports. We'll use jax, flax.nnx, numpy, and JAX's sharding utilities. A key trick for development is simulating multiple devices using the XLA_FLAGS environment variable. This lets us test our sharding logic even on a single CPU or GPU machine before deploying to a larger cluster.
import dataclasses
import os
# Forces JAX to behave as if 8 devices (e.g., CPU cores) are available,
# even if running on a machine with fewer physical accelerators.
# Useful for testing sharding logic without multi-GPU/TPU hardware.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
from matplotlib import pyplot as plt # For plotting results
# Utilities for creating device meshes easily
from jax.experimental import mesh_utils
# Core JAX sharding components: Mesh defines the device grid,
# PartitionSpec defines how tensor axes map to mesh axes,
# NamedSharding links PartitionSpec to a Mesh with named axes.
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jax
import jax.numpy as jnp # JAX's accelerated NumPy
import numpy as np # Standard NumPy for data generation
# Import the Flax NNX API components
from flax import nnx
import typing as tp # For type hints
- Imports & Environment Setup: Standard imports for the various functionalities we need.
- os.environ['XLA_FLAGS'] = ...: This tells JAX's underlying XLA compiler to simulate 8 devices, making multi-device development accessible (a quick check follows below).
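As a quick sanity check (not part of the original example), you can ask JAX how many devices it sees; with the flag set before jax is imported, it should report 8 host-platform devices:
# Sanity check: with XLA_FLAGS set before importing jax, JAX simulates 8 devices.
print(jax.device_count())  # expected: 8
print(jax.devices())       # e.g. [CpuDevice(id=0), ..., CpuDevice(id=7)] on a CPU-only machine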
Defining the Parallelism Strategy: The Device Mesh
The core concept for explicit sharding in JAX is the Mesh. It represents a logical grid of your devices (real or simulated). We associate names with the axes of this grid to define our parallelism strategy. Here, we create a 2x4 mesh (8 devices) with axes named 'data' and 'model'.
# Create a 2D mesh (grid) of devices with shape (2, 4), meaning 8 devices total.
# Assign logical names 'data' and 'model' to the axes of this grid.
# The first dimension (size 2) is named 'data'.
# The second dimension (size 4) is named 'model'.
mesh = jax.sharding.Mesh(
mesh_utils.create_device_mesh((2, 4)),
('data', 'model'),
)
- Mesh Definition: mesh_utils.create_device_mesh((2, 4)) arranges the available devices into a 2x4 grid. jax.sharding.Mesh(...) creates the logical grid with named axes. This allows us to say "split data along the 'data' axis" and "split parameters along the 'model' axis." A small illustration of this mesh follows below.
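To build intuition for what the mesh means, here is a minimal sketch (not from the original example) that places a small array on the mesh and visualizes the resulting 2x4 tiling:
# Illustration only: shard a toy (4, 8) array with dim 0 split along 'data'
# (2 ways) and dim 1 split along 'model' (4 ways).
toy = jax.device_put(
    jnp.arange(32.0).reshape(4, 8),
    NamedSharding(mesh, P('data', 'model')),
)
jax.debug.visualize_array_sharding(toy)  # shows one tile per device in a 2x4 grid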
Sharding Helpers: named_sharding and MeshRules
To make defining sharding specifications easier and more organized, the example uses a helper function and a dataclass.
# A helper function to quickly create a NamedSharding object
# using the globally defined 'mesh'.
def named_sharding(*names: str | None) -> NamedSharding:
# P(*names) creates a PartitionSpec, e.g., P('data', None)
# NamedSharding binds this PartitionSpec to the 'mesh'.
return NamedSharding(mesh, P(*names))
- named_sharding Helper: Simplifies creating NamedSharding objects, which link a PartitionSpec (how dimensions map to mesh axes, e.g., P('data', None) shards dim 0 along 'data' and replicates dim 1) to our specific mesh. A usage sketch follows.
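For example (a usage sketch, not from the original code), the helper turns axis names straight into a NamedSharding bound to our mesh:
# P('data', None): shard dim 0 along the 'data' axis, replicate dim 1.
print(named_sharding('data', None))
# P(None, 'model'): replicate dim 0, shard dim 1 along the 'model' axis.
print(named_sharding(None, 'model'))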
# A dataclass to hold sharding rules for different parts of the model/data.
# Makes it easy to manage and change sharding strategies.
@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
embed: str | None = None # Sharding rule for embedding-like dimensions
mlp: str | None = None # Sharding rule for MLP layers dimensions
data: str | None = None # Sharding rule for the data batch dimension
# Allows calling the instance like `mesh_rules('embed', 'mlp')`
# to get a tuple of the corresponding sharding rules.
def __call__(self, *keys: str) -> tuple[str, ...]:
return tuple(getattr(self, key) for key in keys)
# Create an instance of MeshRules defining the specific strategy:
# - 'embed' dimensions will be replicated (None).
# - 'mlp' dimensions will be sharded along the 'model' mesh axis.
# - 'data' dimensions will be sharded along the 'data' mesh axis.
mesh_rules = MeshRules(
embed=None,
mlp='model',
data='data',
)
- MeshRules Dataclass: Provides a structured way to define and retrieve the desired sharding axis names ('data', 'model', or None for replication) for logical parts of the model (embed, mlp) and the data; a usage sketch follows below.
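Calling the instance simply looks up the rule for each logical name, so (as a quick usage sketch):
print(mesh_rules('embed', 'mlp'))  # (None, 'model')
print(mesh_rules('mlp'))           # ('model',)
print(mesh_rules('data'))          # ('data',)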
Building the Sharded Model with Flax NNX
Now, let's define our MLP using the Flax NNX API. A key feature of NNX is its explicit state management. Notice how sharding intentions are specified directly when creating parameters using nnx.Param.
# Define the MLP using Flax NNX API.
class MLP(nnx.Module):
# Constructor takes input/hidden/output dimensions and an NNX Rngs object.
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
# Define the first weight matrix as an nnx.Param.
self.w1 = nnx.Param(
# Initialize with lecun_normal initializer using a key from rngs.
nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
# CRITICAL: Specify the desired sharding using MeshRules.
# ('embed', 'mlp') -> (None, 'model') -> Replicate dim 0, shard dim 1 along 'model' axis.
sharding=mesh_rules('embed', 'mlp'),
)
# Define the first bias vector as an nnx.Param.
self.b1 = nnx.Param(
jnp.zeros((dmid,)), # Initialize with zeros.
# Sharding: ('mlp',) -> ('model',) -> Shard dim 0 along 'model' axis.
sharding=mesh_rules('mlp'),
)
# Define the second weight matrix as an nnx.Param.
self.w2 = nnx.Param(
nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
# Sharding: ('embed', 'mlp') -> (None, 'model') -> Replicate dim 0, shard dim 1 along 'model' axis.
sharding=mesh_rules('embed', 'mlp'),
)
# Note: No second bias b2 is defined in this simple example.
# The forward pass of the MLP.
def __call__(self, x: jax.Array):
# Standard MLP calculation: (x @ W1 + b1) -> ReLU -> @ W2
# NNX automatically accesses the .value attribute of nnx.Param objects.
return nnx.relu(x @ self.w1 + self.b1) @ self.w2
- MLP NNX Module:
  - nnx.Param: Defines trainable parameters, wrapping the JAX array.
  - sharding=mesh_rules(...): This is the crucial NNX integration point. We attach metadata (a tuple like (None, 'model')) directly to the parameter, indicating how it should be sharded across the mesh axes based on our MeshRules (see the inspection sketch below).
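If you want to confirm the metadata is attached, here is a small sketch (not in the original example) that instantiates the model eagerly and reads the sharding tuples back; it assumes the metadata passed to nnx.Param is readable as an attribute on the variable, mirroring the value.sharding access used later in create_model:
# Inspect the sharding metadata attached to the parameters (assumption: the
# metadata kwargs are exposed as attributes on the nnx variables).
probe = MLP(1, 32, 1, rngs=nnx.Rngs(0))
print(probe.w1.value.shape)  # (1, 32)
print(probe.w1.sharding)     # (None, 'model')
print(probe.b1.sharding)     # ('model',)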
Handling Sharded Optimizer State
Parameters aren't the only state we need to manage; optimizer states (like momentum) must also be sharded consistently with their corresponding parameters. NNX's explicit variable system handles this elegantly.
# Define a custom type for SGD momentum state, inheriting from nnx.Variable.
# This allows it to be tracked as part of the NNX state tree.
class SGDState(nnx.Variable):
pass
# Define the SGD optimizer using NNX API.
class SGD(nnx.Object):
# Constructor takes the model parameters (as nnx.State), learning rate, and decay.
def __init__(self, params: nnx.State, lr, decay=0.9):
# Helper function to initialize momentum buffer for a given parameter.
def init_optimizer_state(variable: nnx.Variable):
# Create momentum state with zeros, same shape and metadata (incl. sharding)
# as the parameter it corresponds to.
return SGDState(
jnp.zeros_like(variable.value), **variable.get_metadata()
)
self.lr = lr
# Store a reference to the parameter State tree.
self.params = params
# Create the momentum state tree, mirroring the structure of 'params',
# using the helper function. Momentum will have the same sharding as params.
self.momentum = jax.tree.map(init_optimizer_state, self.params)
self.decay = decay
# Method to update parameters based on gradients.
def update(self, grads: nnx.State):
# Define the update logic for a single parameter/momentum/gradient triple.
def update_fn(
params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
):
# Standard SGD with momentum update rule.
# v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
momentum.value = self.decay * momentum.value + (1 - self.decay) * grad.value
# θ_{t+1} = θ_t - α * v_t
params.value -= self.lr * momentum.value # NOTE: Direct mutation of param value!
# Apply the update function across the parameter, momentum, and gradient trees.
# This performs the update in-place on the parameter values referenced by self.params.
jax.tree.map(update_fn, self.params, self.momentum, grads)
- SGD NNX Optimizer:
  - SGDState(nnx.Variable): Custom type for momentum, making it part of the NNX state system.
  - init_optimizer_state: Creates momentum buffers matching the shape and, crucially, inheriting the metadata (including the sharding tuple) from the corresponding parameter via **variable.get_metadata().
  - update: Applies gradients in-place to the sharded parameters (referenced via self.params) and momentum buffers. A numeric sketch of the update rule follows.
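To make the update rule concrete, here is a tiny numeric sketch of the same math in plain jnp (no NNX involved):
# v_t = decay * v_{t-1} + (1 - decay) * g_t;  theta_{t+1} = theta_t - lr * v_t
theta, v = jnp.array(1.0), jnp.array(0.0)
g, lr, decay = jnp.array(0.5), 0.01, 0.9
v = decay * v + (1 - decay) * g   # 0.9 * 0.0 + 0.1 * 0.5 = 0.05
theta = theta - lr * v            # 1.0 - 0.01 * 0.05 = 0.9995
print(theta, v)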
Applying and Enforcing Sharding: The create_model Function
We've defined how things should be sharded via metadata. Now we need to tell JAX to actually enforce this sharding layout during computation. This happens in the create_model function, using jax.lax.with_sharding_constraint and nnx.update.
# JIT-compile the model and optimizer creation function.
@nnx.jit
def create_model():
# Instantiate the MLP model. rngs=nnx.Rngs(0) provides PRNG keys.
model = MLP(1, 32, 1, rngs=nnx.Rngs(0))
# Create the optimizer. nnx.variables(model, nnx.Param) extracts
# only the nnx.Param state variables from the model object.
optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)
# === Explicit Sharding Application ===
# 1. Extract ALL state (model params + optimizer momentum) into a flat State pytree.
state = nnx.state(optimizer)
# 2. Define the target sharding for the state pytree.
# This function maps state paths to NamedSharding objects based on stored metadata.
def get_named_shardings(path: tuple, value: nnx.VariableState):
# Assumes params and momentum use the sharding defined in their metadata.
if path[0] in ('params', 'momentum'):
# value.sharding contains the tuple like ('model',) or (None, 'model')
# stored during Param/SGDState creation.
return value.replace(NamedSharding(mesh, P(*value.sharding)))
else:
# Handle other state if necessary (e.g., learning rate if it were a Variable)
raise ValueError(f'Unknown path: {path}')
# Create the pytree of NamedSharding objects.
named_shardings = state.map(get_named_shardings)
# 3. Apply sharding constraint. This tells JAX how the 'state' pytree
# SHOULD be sharded when computations involving it are run under jit/pjit.
# It doesn't immediately move data but sets up the constraint for the compiler.
sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
# 4. Update the original objects (model params, optimizer momentum)
# with the constrained state values. This step makes the sharding
# "stick" to the objects themselves for subsequent use outside this function.
nnx.update(optimizer, sharded_state)
# Return the model and optimizer objects, now containing sharded state variables.
return model, optimizer
# Call the function to create the sharded model and optimizer.
model, optimizer = create_model()
# Visualize the sharding of the first weight's parameter tensor.
jax.debug.visualize_array_sharding(model.w1.value)
# Visualize the sharding of the first weight's momentum tensor.
jax.debug.visualize_array_sharding(optimizer.momentum.w1.value)
- create_model Function:
  - nnx.state(optimizer): Collects all nnx.Variable instances (parameters and momentum) into a single nnx.State pytree.
  - state.map(get_named_shardings): Creates a parallel pytree of NamedSharding specifications, reading the sharding metadata attached to each variable.
  - jax.lax.with_sharding_constraint: The core JAX primitive. It tells the JIT compiler that the state pytree must conform to the named_shardings layout.
  - nnx.update(optimizer, sharded_state): This NNX function pushes the constrained state back into the original model and optimizer objects, making the sharding effective for future use.
- Visualization: jax.debug.visualize_array_sharding confirms the parameter and momentum tensors are distributed across the device mesh as intended (see also the sharding check below).
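Beyond the visualizer, every jax.Array carries its sharding, so a quick check (a sketch, not in the original example) after create_model() has run looks like this:
# The arrays returned by create_model() should now carry NamedSharding objects.
print(model.w1.value.sharding)               # expected: PartitionSpec(None, 'model') on our mesh
print(optimizer.momentum.w1.value.sharding)  # momentum mirrors the parameter's sharding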
The Distributed Training Step
With the model and optimizer state correctly sharded, defining the JIT-compiled training step is straightforward. JAX automatically handles the necessary communication (like gradient aggregation) based on the sharding constraints we've established.
# JIT-compile the training step function.
@nnx.jit
def train_step(model: MLP, optimizer: SGD, x, y):
# Define the loss function (Mean Squared Error).
# Takes the model object as input, consistent with nnx.value_and_grad.
def loss_fn(model):
y_pred = model(x) # Forward pass
loss = jnp.mean((y - y_pred) ** 2)
return loss
# Calculate loss and gradients w.r.t the model's state (its nnx.Param variables).
# 'grad' will be an nnx.State object mirroring model's Param structure.
loss, grad = nnx.value_and_grad(loss_fn)(model)
# Call the optimizer's update method to apply gradients.
# This updates the model parameters in-place.
optimizer.update(grad)
# Return the calculated loss.
return loss
- train_step Function:
  - @nnx.jit: Compiles the function. JAX infers the distributed execution plan.
  - nnx.value_and_grad: Computes the loss and gradients with respect to the sharded variables within model.
  - optimizer.update(grad): Applies the (implicitly sharded) gradients to the sharded state. An evaluation-step sketch follows.
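The same pattern extends naturally to evaluation. As a sketch (an eval_step is not part of the original example), a gradient-free step looks like this:
# Evaluation step: same sharded forward pass, but no gradients or parameter updates.
@nnx.jit
def eval_step(model: MLP, x, y):
  y_pred = model(x)
  return jnp.mean((y - y_pred) ** 2)
Calling it with a batch placed via named_sharding('data') reuses the same distributed layout as training.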
Data Loading and the Training Loop
Finally, the training loop generates data and feeds it to the train_step. The crucial part here is sharding the input data batch along the 'data' axis using jax.device_put before each step.
# Generate synthetic dataset: y = 0.8*x^2 + 0.1 + noise
X = np.linspace(-2, 2, 100)[:, None] # Input features
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) # Target values
# A generator function to yield batches of data for training.
def dataset(batch_size, num_steps):
for _ in range(num_steps):
# Randomly sample indices for the batch.
idx = np.random.choice(len(X), size=batch_size)
# Yield the corresponding input and target pairs.
yield X[idx], Y[idx]
# --- Training Loop ---
losses = [] # To store loss values for plotting
# Iterate through the dataset generator for 10,000 steps.
for step, (x_batch, y_batch) in enumerate(
dataset(batch_size=32, num_steps=10_000)
):
# CRITICAL: Place the NumPy data onto JAX devices AND apply sharding.
# named_sharding('data') -> Shard along the 'data' mesh axis (first dim, size 2).
# Each device along the 'data' axis gets a slice of the batch.
x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data'))
# Execute the JIT-compiled training step with the sharded model, optimizer, and data.
loss = train_step(model, optimizer, x_batch, y_batch)
# Record the loss (move scalar loss back to host CPU).
losses.append(float(loss))
# Log progress periodically.
if step % 1000 == 0:
print(f'Step {step}: Loss = {loss}')
# --- Plotting Results ---
plt.figure()
plt.title("Training Loss")
plt.plot(losses[20:]) # Plot loss, skipping initial noisy steps
plt.xlabel("Step")
plt.ylabel("MSE Loss")
# Get model predictions on the full dataset (X is on host CPU).
# The forward pass may execute on the devices; the result is brought back to the host implicitly for plotting.
y_pred = model(X)
plt.figure()
plt.title("Model Fit")
plt.scatter(X, Y, color='blue', label='Data') # Original data
plt.plot(X, y_pred, color='black', label='Prediction') # Model's predictions
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
plt.show() # Display the plots
- Data Generation & Dataset: Standard data setup.
- Training Loop:
  - jax.device_put(..., named_sharding('data')): This implements data parallelism. It sends the NumPy batch to the devices and splits it along the 'data' mesh axis.
  - loss = train_step(...): Executes the distributed training step.
- Plotting: Visualizes the training progress and the final model fit. A sketch for verifying the batch sharding follows.
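If you want to verify the data-parallel split, here is a short sketch (not in the original example) that shards a dummy batch and inspects the per-device pieces:
# A batch of 32 sharded along 'data' (2-way) gives a (16, 1) slice per data
# group; devices along the 'model' axis hold replicas of those slices.
x_probe = jax.device_put(jnp.zeros((32, 1)), named_sharding('data'))
print(x_probe.sharding)
for shard in x_probe.addressable_shards:
  print(shard.device, shard.data.shape)  # each shard should be (16, 1)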
Summary
This example demonstrates a powerful pattern for scalable ML training:
- Define a Mesh: Map logical names ('data', 'model') to device axes.
- Use Flax NNX: Define parameters (nnx.Param) and other state (nnx.Variable) explicitly.
- Attach Sharding Metadata: Specify the desired sharding tuple directly when creating NNX variables.
- Enforce Sharding: Use nnx.state, jax.lax.with_sharding_constraint, and nnx.update within a JIT context to apply the constraints.
- Shard Data: Use jax.device_put with NamedSharding to distribute input batches.
- JIT the Training Step: Let JAX compile the distributed execution plan.
This combination of JAX's explicit sharding control and Flax NNX's explicit state management provides a clear and flexible way to implement complex parallelism strategies like FSDP, enabling the training of larger and more capable models.
We hope this detailed walkthrough of the official Flax example helps clarify how these powerful tools work together!