[nnx] add nnx.variables + FSDP toy example
cgarciae committed Sep 30, 2024
1 parent c057337 commit 142402c
Showing 17 changed files with 287 additions and 49 deletions.
158 changes: 158 additions & 0 deletions examples/nnx_toy_examples/10_fsdp_and_optimizer.py
@@ -0,0 +1,158 @@
# %%
import dataclasses
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

from matplotlib import pyplot as plt
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import typing as tp

mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((2, 4)),
  ('data', 'model'),
)


def named_sharding(*names: str | None) -> NamedSharding:
  return NamedSharding(mesh, P(*names))


@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
  embed: str | None = None
  mlp: str | None = None
  data: str | None = None

  def __call__(self, *keys: str) -> tuple[str, ...]:
    return tuple(getattr(self, key) for key in keys)


mesh_rules = MeshRules(
  embed=None,
  mlp='model',
  data='data',
)


class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.w1 = nnx.Param(
      nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
      sharding=mesh_rules('embed', 'mlp'),
    )
    self.b1 = nnx.Param(
      jnp.zeros((dmid,)),
      sharding=mesh_rules('mlp'),
    )
    self.w2 = nnx.Param(
      nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
      sharding=mesh_rules('embed', 'mlp'),
    )

  def __call__(self, x: jax.Array):
    return nnx.relu(x @ self.w1 + self.b1) @ self.w2


class SGDState(nnx.Variable):
  pass


class SGD(nnx.Object):
  def __init__(self, params: nnx.State, lr, decay=0.9):
    def init_optimizer_state(variable: nnx.Variable):
      return SGDState(
        jnp.zeros_like(variable.value), **variable.get_metadata()
      )

    self.lr = lr
    self.params = params
    self.momentum = jax.tree.map(init_optimizer_state, self.params)
    self.decay = decay

  def update(self, grads: nnx.State):
    def update_fn(
      params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
    ):
      # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
      momentum.value = self.decay * momentum + (1 - self.decay) * grad.value
      # θ_{t+1} = θ_t - α * v_t
      params.value -= self.lr * momentum

    jax.tree.map(update_fn, self.params, self.momentum, grads)


@nnx.jit
def create_model():
  model = MLP(1, 32, 1, rngs=nnx.Rngs(0))
  optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)
  state = nnx.state(optimizer)
  sharded_state = jax.lax.with_sharding_constraint(
    state, nnx.get_named_sharding(state, mesh)
  )

  def get_named_shardings(path: tuple, value: nnx.VariableState):
    if path[0] == 'params':
      return value.replace(NamedSharding(mesh, P(*value.sharding)))
    elif path[0] == 'momentum':
      # currently the same as above but in general it could be different
      return value.replace(NamedSharding(mesh, P(*value.sharding)))
    else:
      raise ValueError(f'Unknown path: {path}')

  named_shardings = state.map(get_named_shardings)
  sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
  nnx.update(optimizer, sharded_state)
  return model, optimizer


model, optimizer = create_model()

jax.debug.visualize_array_sharding(model.w1.value)
jax.debug.visualize_array_sharding(optimizer.momentum.w1.value)


@nnx.jit
def train_step(model: MLP, optimizer: SGD, x, y):
  def loss_fn(model):
    y_pred = model(x)
    loss = jnp.mean((y - y_pred) ** 2)
    return loss

  loss, grad = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grad)
  return loss


X = np.linspace(-2, 2, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)


def dataset(batch_size, num_steps):
  for _ in range(num_steps):
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]


losses = []
for step, (x_batch, y_batch) in enumerate(
  dataset(batch_size=32, num_steps=10_000)
):
  x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data'))
  loss = train_step(model, optimizer, x_batch, y_batch)
  losses.append(float(loss))
  if step % 1000 == 0:
    print(f'Step {step}: Loss = {loss}')

plt.figure()
plt.plot(losses[20:])

y_pred = model(X)
plt.figure()
plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()
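A quick sketch of the sharding plumbing in the example above, not part of the committed file: MeshRules maps logical parameter axes to mesh axis names, those names become PartitionSpec entries inside create_model, and SGD propagates them to its momentum because SGDState is built with **variable.get_metadata(). Assuming the example has already run, the snippet below shows what these helpers evaluate to (expected values are in the comments).

# Illustrative only, evaluated after create_model() above.
print(mesh_rules('embed', 'mlp'))       # (None, 'model'): metadata attached to w1/w2
print(named_sharding('data'))           # NamedSharding(mesh, PartitionSpec('data')) for input batches
print(model.w1.sharding)                # (None, 'model')
print(optimizer.momentum.w1.sharding)   # (None, 'model'), copied from w1 via get_metadata()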
19 changes: 10 additions & 9 deletions flax/nnx/__init__.py
@@ -55,6 +55,7 @@
from .graph import split_context as split_context
from .graph import MergeContext as MergeContext
from .graph import merge_context as merge_context
from .graph import variables as variables
from .nn import initializers as initializers
from .nn.activations import celu as celu
from .nn.activations import elu as elu
@@ -116,7 +117,7 @@
from .spmd import with_sharding_constraint as with_sharding_constraint
from .statelib import State as State
from .training import metrics as metrics
from .variables import (
from .variablelib import (
  Param as Param,
)
# this needs to be imported before optimizer to prevent circular import
@@ -143,14 +144,14 @@
from .transforms.transforms import eval_shape as eval_shape
from .transforms.transforms import cond as cond
from .transforms.iteration import StateAxes as StateAxes
from .variables import A as A
from .variables import BatchStat as BatchStat
from .variables import Cache as Cache
from .variables import Intermediate as Intermediate
from .variables import Variable as Variable
from .variables import VariableState as VariableState
from .variables import VariableMetadata as VariableMetadata
from .variables import with_metadata as with_metadata
from .variablelib import A as A
from .variablelib import BatchStat as BatchStat
from .variablelib import Cache as Cache
from .variablelib import Intermediate as Intermediate
from .variablelib import Variable as Variable
from .variablelib import VariableState as VariableState
from .variablelib import VariableMetadata as VariableMetadata
from .variablelib import with_metadata as with_metadata
from .visualization import display as display
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
2 changes: 1 addition & 1 deletion flax/nnx/bridge/variables.py
@@ -20,7 +20,7 @@
from flax.core import meta
from flax.nnx import spmd
from flax.nnx import traversals
from flax.nnx import variables as variableslib
from flax.nnx import variablelib as variableslib
from flax.nnx.module import GraphDef
import typing as tp

10 changes: 10 additions & 0 deletions flax/nnx/filterlib.py
@@ -54,6 +54,16 @@ def to_predicate(filter: Filter) -> Predicate:
  else:
    raise TypeError(f'Invalid collection filter: {filter:!r}. ')

def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
  for i, filter_ in enumerate(filters):
    if filter_ in (..., True) and i != len(filters) - 1:
      remaining_filters = filters[i + 1 :]
      if not all(f in (..., True) for f in remaining_filters):
        raise ValueError(
          '`...` or `True` can only be used as the last filters, '
          f'got {filter_} at index {i}.'
        )
  return tuple(map(to_predicate, filters))

@dataclasses.dataclass(frozen=True)
class WithTag:
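A small sketch of the ordering rule that the new filters_to_predicates helper enforces: `...` (or True) matches everything, so it is only allowed in the trailing positions. The snippet is illustrative and assumes the helper is reached via flax.nnx.filterlib.

from flax import nnx
from flax.nnx import filterlib

# Accepted: the ellipsis comes last, acting as a catch-all group.
predicates = filterlib.filters_to_predicates((nnx.Param, nnx.BatchStat, ...))

# Rejected: a leading `...` would shadow every later filter.
try:
  filterlib.filters_to_predicates((..., nnx.Param))
except ValueError as e:
  print(e)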
43 changes: 37 additions & 6 deletions flax/nnx/graph.py
@@ -32,7 +32,8 @@
  DelayedAccessor,
)
from flax.nnx.statelib import FlatState, State
from flax.nnx.variables import Variable, VariableState
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts

A = tp.TypeVar('A')
@@ -1325,15 +1326,47 @@ def update(node, state: State, /, *states: State) -> None:

  _graph_update_dynamic(node, state.raw_mapping)

def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]:
  for path, value in iter_graph(node):
    if isinstance(value, Variable):
      yield path, value


@tp.overload
def state(node, /) -> GraphState: ...
def variables(node, /) -> State[Key, Variable]: ...
@tp.overload
def variables(node, first: filterlib.Filter, /) -> State[Key, Variable]: ...
@tp.overload
def variables(
  node,
  first: filterlib.Filter,
  second: filterlib.Filter,
  /,
  *filters: filterlib.Filter,
) -> tuple[State[Key, Variable], ...]: ...
def variables(
  node,
  *filters: filterlib.Filter,
) -> tp.Union[State[Key, Variable], tuple[State[Key, Variable], ...]]:
  num_filters = len(filters)
  if num_filters == 0:
    filters = (..., ...)
  else:
    filters = (*filters, ...)

  variables_iterable = _variables_generator(node)
  flat_states = variablelib.split_flat_state(
    variables_iterable, (*filters, ...)
  )
  states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
  if num_filters < 2:
    return states[0]
  return states

@tp.overload
def state(node, /) -> GraphState: ...
@tp.overload
def state(node, first: filterlib.Filter, /) -> GraphState: ...


@tp.overload
def state(
  node,
@@ -1342,8 +1375,6 @@ def state(
  /,
  *filters: filterlib.Filter,
) -> tuple[GraphState, ...]: ...


def state(
  node,
  *filters: filterlib.Filter,
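For reference, a minimal usage sketch of the nnx.variables API defined above, mirroring how the FSDP toy example consumes it; the module below is illustrative and not part of this commit.

from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# A single filter returns one State whose leaves are the live Variable
# objects (not copies), so an optimizer can mutate them in place, as the
# toy SGD above does with nnx.variables(model, nnx.Param).
params = nnx.variables(model, nnx.Param)
print(params['kernel'].value.shape)  # (2, 3)

# With no filter, every Variable reachable from the node is returned.
all_variables = nnx.variables(model)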
2 changes: 1 addition & 1 deletion flax/nnx/module.py
@@ -23,7 +23,7 @@
  filterlib,
  graph,
)
from flax.nnx import variables as variableslib
from flax.nnx import variablelib as variableslib
from flax.nnx.graph import GraphDef
from flax.nnx.object import Object, ObjectMeta
from flax.nnx.graph import GraphState, StateLeaf
6 changes: 3 additions & 3 deletions flax/nnx/nn/linear.py
@@ -23,7 +23,7 @@

from flax.core.frozen_dict import FrozenDict
from flax import nnx
from flax.nnx import rnglib, variables
from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import dtypes, initializers
from flax.typing import (
@@ -193,7 +193,7 @@ def kernel_init_wrap(rng, shape, dtype):
      )
      flat_shape = jax.tree.map(int, flat_shape)
      kernel = self.kernel_init(rng, flat_shape, dtype)
      if isinstance(kernel, variables.VariableMetadata):
      if isinstance(kernel, variablelib.VariableMetadata):
        kernel.raw_value = jnp.reshape(kernel.raw_value, shape)
      else:
        kernel = jnp.reshape(kernel, shape)
@@ -215,7 +215,7 @@ def kernel_init_wrap(rng, shape, dtype):
    def bias_init_wrap(rng, shape, dtype):
      flat_shape = (int(np.prod(shape)),)
      bias = self.bias_init(rng, flat_shape, dtype)
      if isinstance(bias, variables.VariableMetadata):
      if isinstance(bias, variablelib.VariableMetadata):
        bias.raw_value = jnp.reshape(bias.raw_value, shape)
      else:
        bias = jnp.reshape(bias, shape)
8 changes: 4 additions & 4 deletions flax/nnx/nn/lora.py
@@ -18,7 +18,7 @@
import jax
import jax.numpy as jnp

from flax.nnx import rnglib, variables
from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module
from flax.nnx.nn import initializers
from flax.nnx.nn.linear import Linear
Expand All @@ -32,7 +32,7 @@
default_kernel_init = initializers.lecun_normal()


class LoRAParam(variables.Param[A]):
class LoRAParam(variablelib.Param[A]):
  pass


@@ -84,7 +84,7 @@ def __init__(
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    kernel_init: Initializer = default_kernel_init,
    lora_param_type: tp.Type[variables.Variable] = LoRAParam,
    lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
    rngs: rnglib.Rngs,
  ):
    self.in_features = in_features
@@ -155,7 +155,7 @@ def __init__(
    lora_dtype: tp.Optional[Dtype] = None,
    lora_param_dtype: Dtype = jnp.float32,
    lora_kernel_init: Initializer = default_kernel_init,
    lora_param_type: tp.Type[variables.Variable] = LoRAParam,
    lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
    rngs: rnglib.Rngs,
    **kwargs,
  ):
2 changes: 1 addition & 1 deletion flax/nnx/object.py
@@ -29,7 +29,7 @@
  tracers,
)
from flax.nnx import graph
from flax.nnx.variables import Variable, VariableState
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key
from flax import errors

2 changes: 1 addition & 1 deletion flax/nnx/rnglib.py
@@ -23,7 +23,7 @@
from flax import struct
from flax.nnx import graph
from flax.nnx.statelib import State
from flax.nnx.variables import Variable
from flax.nnx.variablelib import Variable
from flax.nnx import filterlib
from flax.nnx.filterlib import All
from flax.nnx.object import Object