Add NNX checkpointing guide #4249

Merged 1 commit on Oct 7, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -23,7 +23,7 @@ repos:
hooks:
- id: check-toml
- id: trailing-whitespace
exclude: ^docs*/.*\.md$
exclude: ^docs.*\.md$
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
382 changes: 382 additions & 0 deletions docs_nnx/guides/checkpointing.ipynb

Large diffs are not rendered by default.

201 changes: 201 additions & 0 deletions docs_nnx/guides/checkpointing.md

Large diffs are not rendered by default.
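
Since the new guide's notebook and Markdown diffs are not rendered here, the following is only an illustrative sketch of the kind of NNX save/restore flow such a guide covers, combining `nnx.split`, `nnx.eval_shape`, and `nnx.merge` (as used in the surgery-guide changes below) with `orbax.checkpoint.StandardCheckpointer`. The module, checkpoint path, and exact call sequence are assumptions for illustration, not the actual contents of `docs_nnx/guides/checkpointing.md`.

```python
# Illustrative sketch only -- not the text of the new checkpointing guide.
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, dim: int, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    return self.linear2(nnx.relu(self.linear1(x)))

# Hypothetical checkpoint location; Orbax needs an empty target directory.
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/nnx-checkpoint-sketch/')

model = MLP(4, rngs=nnx.Rngs(0))
_, state = nnx.split(model)

# Save the model's state pytree with Orbax.
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'state', state)

# Restore: build an abstract model with the same structure (no real arrays),
# restore the state against it, then merge back into a live module.
abs_model = nnx.eval_shape(lambda: MLP(4, rngs=nnx.Rngs(0)))
graph_def, abs_state = nnx.split(abs_model)
restored_state = checkpointer.restore(ckpt_dir / 'state', abs_state)
restored_model = nnx.merge(graph_def, restored_state)

assert jnp.allclose(restored_model(jnp.ones((3, 4))), model(jnp.ones((3, 4))))
```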

97 changes: 52 additions & 45 deletions docs_nnx/guides/surgery.ipynb
@@ -221,15 +221,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"This will throw error: <class 'KeyError'>: 'layer1'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ivyzheng/envs/py310/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1401: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
" warnings.warn(\n"
"This will throw error: <class 'ValueError'>: Dict key mismatch; expected keys: ['linear1', 'linear2']; dict: {'layer1': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}, 'layer2': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}}.\n"
]
}
],
@@ -267,45 +259,46 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'linear1': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},\n",
" 'kernel': {'raw_value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n",
"{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n",
" 'kernel': {'value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n",
" [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],\n",
" [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],\n",
" [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)}},\n",
" 'linear2': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},\n",
" 'kernel': {'raw_value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n",
" 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n",
" 'kernel': {'value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n",
" [ 0.41914317, 0.84359694, -0.47937787, -0.49135214],\n",
" [-0.46072108, 0.4630125 , 0.39276958, -0.9441406 ],\n",
" [-0.6690758 , -0.18474789, -0.57622856, 0.4821079 ]], dtype=float32)}}}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
" warnings.warn(\n"
]
}
],
"source": [
"def module_from_variables_dict(module_factory, variables, map_key_fn):\n",
" if map_key_fn is None:\n",
" map_key_fn = lambda path: path\n",
" mdl = nnx.eval_shape(module_factory)\n",
" graph_def, state = nnx.split(mdl)\n",
" state = state.flat_state()\n",
" for path, val in flax.traverse_util.flatten_dict(variables).items():\n",
" mapped_path = map_key_fn(path)\n",
" if mapped_path not in state:\n",
" raise ValueError(f\"{mapped_path} doesn't exist in {state.keys()}\")\n",
" state[mapped_path].value = val\n",
" state = nnx.State.from_flat_path(state)\n",
" return nnx.merge(graph_def, state)\n",
"\n",
"# Make your local change on the checkpoint.\n",
"raw = checkpointer.restore('/tmp/nnx-surgery-state')\n",
"pprint(raw)\n",
"raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']\n",
"del raw['linear1'], raw['linear2']\n",
"\n",
"restored_model = module_from_variables_dict(\n",
" lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))),\n",
" raw,\n",
" lambda path: path[:-1] if path[-1] == 'raw_value' else path\n",
")\n",
"def process_raw_dict(raw_state_dict):\n",
" flattened = nnx.traversals.flatten_mapping(raw_state_dict)\n",
" # Cut off the '.value' postfix on every leaf path.\n",
" flattened = {(path[:-1] if path[-1] == 'value' else path): value\n",
" for path, value in flattened.items()}\n",
" return nnx.traversals.unflatten_mapping(flattened)\n",
"\n",
"# Make your local change on the checkpoint dictionary.\n",
"raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')\n",
"pprint(raw_dict)\n",
"raw_dict['layer1'] = raw_dict.pop('linear1')\n",
"raw_dict['layer2'] = raw_dict.pop('linear2')\n",
"\n",
"# Fit it into the model state.\n",
"abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graph_def, state = nnx.split(abs_model)\n",
"state.replace_by_pure_dict(process_raw_dict(raw_dict))\n",
"restored_model = nnx.merge(graph_def, state)\n",
"\n",
"np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))"
]
@@ -339,9 +332,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Number of jax arrays in memory at start: 34\n",
"Number of jax arrays in memory midway: 38 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n",
"Number of jax arrays in memory at end: 36 (2 discarded - only lora_a & lora_b are used in model)\n"
"Number of jax arrays in memory at start: 38\n",
"Number of jax arrays in memory midway: 42 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n",
"Number of jax arrays in memory at end: 40 (2 discarded - only lora_a & lora_b are used in model)\n"
]
}
],
@@ -379,8 +372,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Number of jax arrays in memory at start: 40\n",
"Number of jax arrays in memory at end: 42 (2 new created - lora_a and lora_b)\n"
"Number of jax arrays in memory at start: 44\n",
"Number of jax arrays in memory at end: 46 (2 new created - lora_a and lora_b)\n"
]
}
],
@@ -389,7 +382,7 @@
"old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"\n",
"# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!\n",
"@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)\n",
"@nnx.jit(donate_argnums=0)\n",
"def partial_init(old_state, rngs):\n",
" model = TwoLayerMLP(4, rngs=rngs)\n",
" # Create a new state.\n",
@@ -404,6 +397,20 @@
"print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'\n",
" ' (2 new created - lora_a and lora_b)')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -420,7 +427,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.9"
}
},
"nbformat": 4,
53 changes: 27 additions & 26 deletions docs_nnx/guides/surgery.md
@@ -152,31 +152,24 @@ except Exception as e:
But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition.

```{code-cell} ipython3
def module_from_variables_dict(module_factory, variables, map_key_fn):
if map_key_fn is None:
map_key_fn = lambda path: path
mdl = nnx.eval_shape(module_factory)
graph_def, state = nnx.split(mdl)
state = state.flat_state()
for path, val in flax.traverse_util.flatten_dict(variables).items():
mapped_path = map_key_fn(path)
if mapped_path not in state:
raise ValueError(f"{mapped_path} doesn't exist in {state.keys()}")
state[mapped_path].value = val
state = nnx.State.from_flat_path(state)
return nnx.merge(graph_def, state)

# Make your local change on the checkpoint.
raw = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw)
raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']
del raw['linear1'], raw['linear2']

restored_model = module_from_variables_dict(
lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))),
raw,
lambda path: path[:-1] if path[-1] == 'raw_value' else path
)
def process_raw_dict(raw_state_dict):
flattened = nnx.traversals.flatten_mapping(raw_state_dict)
# Cut off the '.value' postfix on every leaf path.
flattened = {(path[:-1] if path[-1] == 'value' else path): value
for path, value in flattened.items()}
return nnx.traversals.unflatten_mapping(flattened)

# Make your local change on the checkpoint dictionary.
raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw_dict)
raw_dict['layer1'] = raw_dict.pop('linear1')
raw_dict['layer2'] = raw_dict.pop('linear2')

# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
state.replace_by_pure_dict(process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)

np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
```
@@ -218,7 +211,7 @@ Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!
@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)
@nnx.jit(donate_argnums=0)
def partial_init(old_state, rngs):
model = TwoLayerMLP(4, rngs=rngs)
# Create a new state.
@@ -233,3 +226,11 @@ good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
' (2 new created - lora_a and lora_b)')
```

```{code-cell} ipython3

```

```{code-cell} ipython3

```
1 change: 1 addition & 0 deletions flax/nnx/nn/linear.py
@@ -370,6 +370,7 @@ def __call__(self, inputs: Array) -> Array:
(((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision,
)
assert self.use_bias == (bias is not None)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
return y