Skip to content

Commit

Permalink
Merge pull request #4250 from google:enable-notebook-doctest
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 682363934
  • Loading branch information
Flax Authors committed Oct 4, 2024
2 parents 2d64500 + 146d2ff commit 5d31452
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 12 deletions.
4 changes: 3 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@

# -- Options for myst ----------------------------------------------
# uncomment line below to avoid running notebooks during development
nb_execution_mode = 'off'
# nb_execution_mode = 'off'
# Notebook cell execution timeout; defaults to 30.
nb_execution_timeout = 100
# List of patterns, relative to source directory, that match notebook
Expand All @@ -147,6 +147,8 @@
'quick_start.ipynb', # <-- times out
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
'flax/nnx', # exclude nnx
'guides/quantization/fp8_basics.ipynb',
'guides/training_techniques/use_checkpointing.ipynb', # TODO(IvyZX): needs to be updated
]
# raise exceptions on execution so CI can catch errors
nb_execution_allow_errors = False
Expand Down
6 changes: 5 additions & 1 deletion docs_nnx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@

# -- Options for myst ----------------------------------------------
# uncomment line below to avoid running notebooks during development
nb_execution_mode = 'off'
# nb_execution_mode = 'off'
# Notebook cell execution timeout; defaults to 30.
nb_execution_timeout = 100
# List of patterns, relative to source directory, that match notebook
Expand All @@ -147,6 +147,10 @@
'quick_start.ipynb', # <-- times out
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
'flax/nnx', # exclude nnx
'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update
'guides/why.ipynb', # TODO(cgarciae): broken, remove in favor on the new guide
'guides/flax_gspmd.ipynb', # TODO(IvyZX): broken, needs to be updated
'guides/surgery.ipynb', # TODO(IvyZX): broken, needs to be updated
]
# raise exceptions on execution so CI can catch errors
nb_execution_allow_errors = False
Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/guides/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
" pass\n",
"\n",
"model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method\n",
"model.set_attributes(deterministic=False, use_running_average=False) # set flags\n",
"model.set_attributes(use_running_average=False) # set flags\n",
"y = model(jnp.ones((2, 4))) # call methods directly\n",
"\n",
"print(f'{model = }'[:500] + '\\n...')"
Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/guides/demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Count(nnx.Variable): # custom Variable types define the "collections"
pass
model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method
model.set_attributes(deterministic=False, use_running_average=False) # set flags
model.set_attributes(use_running_average=False) # set flags
y = model(jnp.ones((2, 4))) # call methods directly
print(f'{model = }'[:500] + '\n...')
Expand Down
9 changes: 5 additions & 4 deletions docs_nnx/guides/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@
"\n",
"Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n",
"and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can\n",
"use the following filters:"
"use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`\n",
"to specify how `model`'s various substates should be vectorized:"
]
},
{
Expand All @@ -182,9 +183,9 @@
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})\n",
"\n",
"@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})\n",
"@nnx.vmap(in_axes=(state_axes, 0))\n",
"def forward(model, x):\n",
" ..."
]
Expand Down Expand Up @@ -275,7 +276,7 @@
"KeyPath = tuple[nnx.graph.Key, ...]\n",
"\n",
"def split(node, *filters):\n",
" graphdef, state, _ = nnx.graph.flatten(node)\n",
" graphdef, state = nnx.graph.flatten(node)\n",
" predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
" flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
"\n",
Expand Down
9 changes: 5 additions & 4 deletions docs_nnx/guides/filters_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ Here is a list of all the callable Filters included in Flax NNX and their DSL li

Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters
and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can
use the following filters:
use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`
to specify how `model`'s various substates should be vectorized:

```{code-cell} ipython3
from functools import partial
state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})
@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})
@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
...
```
Expand Down Expand Up @@ -140,7 +141,7 @@ from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]
def split(node, *filters):
graphdef, state, _ = nnx.graph.flatten(node)
graphdef, state = nnx.graph.flatten(node)
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
Expand Down

0 comments on commit 5d31452

Please sign in to comment.