diff --git a/docs/conf.py b/docs/conf.py index bc0d98416b..e65c142bd4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -137,7 +137,7 @@ # -- Options for myst ---------------------------------------------- # uncomment line below to avoid running notebooks during development -nb_execution_mode = 'off' +# nb_execution_mode = 'off' # Notebook cell execution timeout; defaults to 30. nb_execution_timeout = 100 # List of patterns, relative to source directory, that match notebook @@ -147,6 +147,8 @@ 'quick_start.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 'flax/nnx', # exclude nnx + 'guides/quantization/fp8_basics.ipynb', + 'guides/training_techniques/use_checkpointing.ipynb', # TODO(IvyZX): needs to be updated ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 641080c28e..344010ac8b 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -137,7 +137,7 @@ # -- Options for myst ---------------------------------------------- # uncomment line below to avoid running notebooks during development -nb_execution_mode = 'off' +# nb_execution_mode = 'off' # Notebook cell execution timeout; defaults to 30. 
nb_execution_timeout = 100 # List of patterns, relative to source directory, that match notebook @@ -147,6 +147,10 @@ 'quick_start.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 'flax/nnx', # exclude nnx + 'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update + 'guides/why.ipynb', # TODO(cgarciae): broken, remove in favor of the new guide + 'guides/flax_gspmd.ipynb', # TODO(IvyZX): broken, needs to be updated + 'guides/surgery.ipynb', # TODO(IvyZX): broken, needs to be updated ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False diff --git a/docs_nnx/guides/demo.ipynb b/docs_nnx/guides/demo.ipynb index a2521ef10f..acf77951f9 100644 --- a/docs_nnx/guides/demo.ipynb +++ b/docs_nnx/guides/demo.ipynb @@ -88,7 +88,7 @@ " pass\n", "\n", "model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method\n", - "model.set_attributes(deterministic=False, use_running_average=False) # set flags\n", + "model.set_attributes(use_running_average=False) # set flags\n", "y = model(jnp.ones((2, 4))) # call methods directly\n", "\n", "print(f'{model = }'[:500] + '\\n...')" diff --git a/docs_nnx/guides/demo.md b/docs_nnx/guides/demo.md index f507f9c482..1f423a77eb 100644 --- a/docs_nnx/guides/demo.md +++ b/docs_nnx/guides/demo.md @@ -48,7 +48,7 @@ class Count(nnx.Variable): # custom Variable types define the "collections" pass model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method -model.set_attributes(deterministic=False, use_running_average=False) # set flags +model.set_attributes(use_running_average=False) # set flags y = model(jnp.ones((2, 4))) # call methods directly print(f'{model = }'[:500] + '\n...') diff --git a/docs_nnx/guides/filters_guide.ipynb b/docs_nnx/guides/filters_guide.ipynb index 5f63191bbf..ed37ad8731 100644 --- a/docs_nnx/guides/filters_guide.ipynb +++ b/docs_nnx/guides/filters_guide.ipynb @@ -172,7 +172,8 @@ "\n", "Let see the DSL in action with a `nnx.vmap` example. 
Lets say we want vectorized all parameters\n", "and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can\n", - "use the following filters:" + "use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`\n", + "to specify how `model`'s various substates should be vectorized:" ] }, { @@ -182,9 +183,9 @@ "metadata": {}, "outputs": [], "source": [ - "from functools import partial\n", + "state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})\n", "\n", - "@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})\n", + "@nnx.vmap(in_axes=(state_axes, 0))\n", "def forward(model, x):\n", " ..." ] @@ -275,7 +276,7 @@ "KeyPath = tuple[nnx.graph.Key, ...]\n", "\n", "def split(node, *filters):\n", - " graphdef, state, _ = nnx.graph.flatten(node)\n", + " graphdef, state = nnx.graph.flatten(node)\n", " predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n", " flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n", "\n", diff --git a/docs_nnx/guides/filters_guide.md b/docs_nnx/guides/filters_guide.md index c403451649..97ff439ce2 100644 --- a/docs_nnx/guides/filters_guide.md +++ b/docs_nnx/guides/filters_guide.md @@ -98,12 +98,13 @@ Here is a list of all the callable Filters included in Flax NNX and their DSL li Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. 
To do so we can -use the following filters: +use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes` +to specify how `model`'s various substates should be vectorized: ```{code-cell} ipython3 -from functools import partial +state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None}) -@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None}) +@nnx.vmap(in_axes=(state_axes, 0)) def forward(model, x): ... ``` @@ -140,7 +141,7 @@ from typing import Any KeyPath = tuple[nnx.graph.Key, ...] def split(node, *filters): - graphdef, state, _ = nnx.graph.flatten(node) + graphdef, state = nnx.graph.flatten(node) predicates = [nnx.filterlib.to_predicate(f) for f in filters] flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]