diff --git a/docs/api_reference/flax.nnx/filterlib.rst b/docs/api_reference/flax.nnx/filterlib.rst new file mode 100644 index 0000000000..09dffe4865 --- /dev/null +++ b/docs/api_reference/flax.nnx/filterlib.rst @@ -0,0 +1,16 @@ +filterlib +------------------------ + +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx + + +.. autofunction:: flax.nnx.filterlib.to_predicate +.. autoclass:: WithTag +.. autoclass:: PathContains +.. autoclass:: OfType +.. autoclass:: Any +.. autoclass:: All +.. autoclass:: Not +.. autoclass:: Everything +.. autoclass:: Nothing \ No newline at end of file diff --git a/docs/api_reference/flax.nnx/index.rst b/docs/api_reference/flax.nnx/index.rst index 957c99567d..98b05093f7 100644 --- a/docs/api_reference/flax.nnx/index.rst +++ b/docs/api_reference/flax.nnx/index.rst @@ -17,4 +17,5 @@ Experimental API. See the `NNX page bool\n", + "\n", + "```\n", + "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n", + "\n", + "Types are obviously not functions of this form, so the reason why they are treated as Filters \n", + "is because, as we will see next, types and some other literals are converted to predicates. 
For example, \n", + "`Param` is roughly converted to a predicate like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "30f4c868", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "is_param((), nnx.Param(0)) = True\n", + "is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True\n" + ] + } + ], + "source": [ + "def is_param(path, value) -> bool:\n", + " return isinstance(value, nnx.Param) or (\n", + " hasattr(value, 'type') and issubclass(value.type, nnx.Param)\n", + " )\n", + "\n", + "print(f'{is_param((), nnx.Param(0)) = }')\n", + "print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')" + ] + }, + { + "cell_type": "markdown", + "id": "a8a2641e", + "metadata": {}, + "source": [ + "Such a function matches any value that is an instance of `Param` or any value that has a \n", + "`type` attribute that is a subclass of `Param`. Internally, NNX uses `OfType`, which defines \n", + "a callable of this form for a given type:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b3095221", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "is_param((), nnx.Param(0)) = True\n", + "is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True\n" + ] + } + ], + "source": [ + "is_param = nnx.OfType(nnx.Param)\n", + "\n", + "print(f'{is_param((), nnx.Param(0)) = }')\n", + "print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')" + ] + }, + { + "cell_type": "markdown", + "id": "87c06e39", + "metadata": {}, + "source": [ + "## The Filter DSL\n", + "\n", + "To avoid users having to create these functions, NNX exposes a small DSL, formalized \n", + "as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, \n", + "tuples/lists, etc., and converts them to the appropriate predicate internally.\n", + "\n", + "Here is a list of all the callable Filters included in NNX and their DSL 
literals \n", + "(when available):\n", + "\n", + "\n", + "| Literal | Callable | Description |\n", + "|--------|----------------------|-------------|\n", + "| `...` | `Everything()` | Matches all values |\n", + "| `None` | `Nothing()` | Matches no values |\n", + "| `True` | `Everything()` | Matches all values |\n", + "| `False` | `Nothing()` | Matches no values |\n", + "| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is a subclass of `type` |\n", + "| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |\n", + "| `'{filter}'` | `WithTag('{filter}')` | Matches values that have a string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. |\n", + "| `(*filters)` tuple or `[*filters]` list | `Any(*filters)` | Matches values that match any of the inner `filters` |\n", + "| | `All(*filters)` | Matches values that match all of the inner `filters` |\n", + "| | `Not(filter)` | Matches values that do not match the inner `filter` |\n", + "\n", + "Let's see the DSL in action with an `nnx.vmap` example. Let's say we want to vectorize all parameters\n", + "and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcast the rest. To do so we can\n", + "use the following filters:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d38b7694", + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})\n", + "def forward(model, x):\n", + " ..." 
+ ] + }, + { + "cell_type": "markdown", + "id": "bd60f0e1", + "metadata": {}, + "source": [ + "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n", + "expands to `Everything()`.\n", + "\n", + "If you wish to manually convert a literal into a predicate, you can use `nnx.filterlib.to_predicate`:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7e065fa9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "is_param = OfType()\n", + "everything = Everything()\n", + "nothing = Nothing()\n", + "params_or_dropout = Any(OfType(), WithTag('dropout'))\n" + ] + } + ], + "source": [ + "is_param = nnx.filterlib.to_predicate(nnx.Param)\n", + "everything = nnx.filterlib.to_predicate(...)\n", + "nothing = nnx.filterlib.to_predicate(False)\n", + "params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))\n", + "\n", + "print(f'{is_param = }')\n", + "print(f'{everything = }')\n", + "print(f'{nothing = }')\n", + "print(f'{params_or_dropout = }')" + ] + }, + { + "cell_type": "markdown", + "id": "db9b4cf3", + "metadata": {}, + "source": [ + "## Grouping States\n", + "\n", + "With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:\n", + "\n", + "* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.\n", + "* Convert all the filters to predicates.\n", + "* Use `State.flat_state` to get the flat representation of the state.\n", + "* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n", + "* Use `State.from_flat_path` to convert the flat states to nested `State`s." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "068208fc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "params = State({\n", + " 'a': VariableState(\n", + " type=Param,\n", + " value=0\n", + " )\n", + "})\n", + "batch_stats = State({\n", + " 'b': VariableState(\n", + " type=BatchStat,\n", + " value=True\n", + " )\n", + "})\n" + ] + } + ], + "source": [ + "from typing import Any\n", + "KeyPath = tuple[nnx.graph.Key, ...]\n", + "\n", + "def split(node, *filters):\n", + " graphdef, state, _ = nnx.graph.flatten(node)\n", + " predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n", + " flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n", + "\n", + " for path, value in state.flat_state().items():\n", + " for i, predicate in enumerate(predicates):\n", + " if predicate(path, value):\n", + " flat_states[i][path] = value\n", + " break\n", + " else:\n", + " raise ValueError(f'No filter matched {path = } {value = }')\n", + " \n", + " states: tuple[nnx.GraphState, ...] = tuple(\n", + " nnx.State.from_flat_path(flat_state) for flat_state in flat_states\n", + " )\n", + " return graphdef, *states\n", + "\n", + "# lets test it...\n", + "foo = Foo()\n", + "\n", + "graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)\n", + "\n", + "print(f'{params = }')\n", + "print(f'{batch_stats = }')" + ] + }, + { + "cell_type": "markdown", + "id": "7b3aeac8", + "metadata": {}, + "source": [ + "One very important thing to note is that **filtering is order-dependent**. The first filter that \n", + "matches a value will keep it, therefore you should place more specific filters before more general \n", + "filters. 
For example, if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` \n", + "object that contains both types of parameters, then splitting the `Param`s before the \n", + "`SpecialParam`s places all the values in the `Param` group and leaves the `SpecialParam` group \n", + "empty, because all `SpecialParam`s are also `Param`s:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "014da4d4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "params = State({\n", + " 'a': VariableState(\n", + " type=Param,\n", + " value=0\n", + " ),\n", + " 'b': VariableState(\n", + " type=SpecialParam,\n", + " value=0\n", + " )\n", + "})\n", + "special_params = State({})\n" + ] + } + ], + "source": [ + "class SpecialParam(nnx.Param):\n", + " pass\n", + "\n", + "class Bar(nnx.Module):\n", + " def __init__(self):\n", + " self.a = nnx.Param(0)\n", + " self.b = SpecialParam(0)\n", + "\n", + "bar = Bar()\n", + "\n", + "graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!\n", + "print(f'{params = }')\n", + "print(f'{special_params = }')" + ] + }, + { + "cell_type": "markdown", + "id": "a9f0b7b8", + "metadata": {}, + "source": [ + "Reversing the order will make sure that the `SpecialParam`s are captured first:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a2ebf5b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "params = State({\n", + " 'a': VariableState(\n", + " type=Param,\n", + " value=0\n", + " )\n", + "})\n", + "special_params = State({\n", + " 'b': VariableState(\n", + " type=SpecialParam,\n", + " value=0\n", + " )\n", + "})\n" + ] + } + ], + "source": [ + "graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!\n", + "print(f'{params = }')\n", + "print(f'{special_params = }')" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/nnx/filters_guide.md b/docs/nnx/filters_guide.md new file mode 100644 index 0000000000..88145fb7a8 --- /dev/null +++ b/docs/nnx/filters_guide.md @@ -0,0 +1,198 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Using Filters + +Filters are used extensively in NNX as a way to create `State` groups in APIs +such as `nnx.split`, `nnx.state`, and many of the NNX transforms. For example: + +```{code-cell} ipython3 +from flax import nnx + +class Foo(nnx.Module): + def __init__(self): + self.a = nnx.Param(0) + self.b = nnx.BatchStat(True) + +foo = Foo() + +graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat) + +print(f'{params = }') +print(f'{batch_stats = }') +``` + +Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions: + +* What is a Filter? +* Why are types, such as `Param` or `BatchStat`, Filters? +* How is `State` grouped / filtered? + ++++ + +## The Filter Protocol + +In general Filter are predicate functions of the form: + +```python + +(path: tuple[Key, ...], value: Any) -> bool + +``` +where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise. 
+ +Types are obviously not functions of this form, so the reason why they are treated as Filters +is because, as we will see next, types and some other literals are converted to predicates. For example, +`Param` is roughly converted to a predicate like this: + +```{code-cell} ipython3 +def is_param(path, value) -> bool: + return isinstance(value, nnx.Param) or ( + hasattr(value, 'type') and issubclass(value.type, nnx.Param) + ) + +print(f'{is_param((), nnx.Param(0)) = }') +print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') +``` + +Such a function matches any value that is an instance of `Param` or any value that has a +`type` attribute that is a subclass of `Param`. Internally, NNX uses `OfType`, which defines +a callable of this form for a given type: + +```{code-cell} ipython3 +is_param = nnx.OfType(nnx.Param) + +print(f'{is_param((), nnx.Param(0)) = }') +print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') +``` + +## The Filter DSL + +To avoid users having to create these functions, NNX exposes a small DSL, formalized +as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, +tuples/lists, etc., and converts them to the appropriate predicate internally. + +Here is a list of all the callable Filters included in NNX and their DSL literals +(when available): + + +| Literal | Callable | Description | +|--------|----------------------|-------------| +| `...` | `Everything()` | Matches all values | +| `None` | `Nothing()` | Matches no values | +| `True` | `Everything()` | Matches all values | +| `False` | `Nothing()` | Matches no values | +| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is a subclass of `type` | +| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` | +| `'{filter}'` | `WithTag('{filter}')` | Matches values that have a string `tag` attribute equal to `'{filter}'`. 
Used by `RngKey` and `RngCount`. | +| `(*filters)` tuple or `[*filters]` list | `Any(*filters)` | Matches values that match any of the inner `filters` | +| | `All(*filters)` | Matches values that match all of the inner `filters` | +| | `Not(filter)` | Matches values that do not match the inner `filter` | + +Let's see the DSL in action with an `nnx.vmap` example. Let's say we want to vectorize all parameters +and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcast the rest. To do so we can +use the following filters: + +```{code-cell} ipython3 +from functools import partial + +@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None}) +def forward(model, x): + ... +``` + +Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` +expands to `Everything()`. + +If you wish to manually convert a literal into a predicate, you can use `nnx.filterlib.to_predicate`: + +```{code-cell} ipython3 +is_param = nnx.filterlib.to_predicate(nnx.Param) +everything = nnx.filterlib.to_predicate(...) +nothing = nnx.filterlib.to_predicate(False) +params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout')) + +print(f'{is_param = }') +print(f'{everything = }') +print(f'{nothing = }') +print(f'{params_or_dropout = }') +``` + +## Grouping States + +With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas: + +* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node. +* Convert all the filters to predicates. +* Use `State.flat_state` to get the flat representation of the state. +* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates. +* Use `State.from_flat_path` to convert the flat states to nested `State`s. + +```{code-cell} ipython3 +from typing import Any +KeyPath = tuple[nnx.graph.Key, ...] 
+ +def split(node, *filters): + graphdef, state, _ = nnx.graph.flatten(node) + predicates = [nnx.filterlib.to_predicate(f) for f in filters] + flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates] + + for path, value in state.flat_state().items(): + for i, predicate in enumerate(predicates): + if predicate(path, value): + flat_states[i][path] = value + break + else: + raise ValueError(f'No filter matched {path = } {value = }') + + states: tuple[nnx.GraphState, ...] = tuple( + nnx.State.from_flat_path(flat_state) for flat_state in flat_states + ) + return graphdef, *states + +# lets test it... +foo = Foo() + +graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat) + +print(f'{params = }') +print(f'{batch_stats = }') +``` + +One very important thing to note is that **filtering is order-dependent**. The first filter that +matches a value will keep it, therefore you should place more specific filters before more general +filters. For example, if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` +object that contains both types of parameters, then splitting the `Param`s before the +`SpecialParam`s places all the values in the `Param` group and leaves the `SpecialParam` group +empty, because all `SpecialParam`s are also `Param`s: + +```{code-cell} ipython3 +class SpecialParam(nnx.Param): + pass + +class Bar(nnx.Module): + def __init__(self): + self.a = nnx.Param(0) + self.b = SpecialParam(0) + +bar = Bar() + +graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong! +print(f'{params = }') +print(f'{special_params = }') +``` + +Reversing the order will make sure that the `SpecialParam`s are captured first: + +```{code-cell} ipython3 +graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct! 
+print(f'{params = }') +print(f'{special_params = }') +``` diff --git a/docs/nnx/index.rst b/docs/nnx/index.rst index 421e811d82..d85a5a15bf 100644 --- a/docs/nnx/index.rst +++ b/docs/nnx/index.rst @@ -175,4 +175,5 @@ Learn more mnist_tutorial transforms haiku_linen_vs_nnx - surgery \ No newline at end of file + filters_guide + surgery diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 555f5dc66e..0ed7392b5a 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -24,9 +24,15 @@ from .nnx import helpers as helpers from .nnx import compat as compat from .nnx import traversals as traversals -from .nnx.filterlib import All as All +from .nnx import filterlib as filterlib +from .nnx.filterlib import WithTag as WithTag from .nnx.filterlib import PathContains as PathContains +from .nnx.filterlib import OfType as OfType +from .nnx.filterlib import Any as Any +from .nnx.filterlib import All as All from .nnx.filterlib import Not as Not +from .nnx.filterlib import Everything as Everything +from .nnx.filterlib import Nothing as Nothing from .nnx.graph import GraphDef as GraphDef from .nnx.graph import GraphState as GraphState from .nnx.object import Object as Object diff --git a/flax/nnx/nnx/filterlib.py b/flax/nnx/nnx/filterlib.py index 97b4765afb..9113f12a7f 100644 --- a/flax/nnx/nnx/filterlib.py +++ b/flax/nnx/nnx/filterlib.py @@ -38,6 +38,10 @@ class _HasType(tp.Protocol): def to_predicate(filter: Filter) -> Predicate: + """Converts a Filter to a predicate function. + See `Using Filters `__. + """ + if isinstance(filter, str): return WithTag(filter) elif isinstance(filter, type): @@ -59,22 +63,29 @@ def to_predicate(filter: Filter) -> Predicate: raise TypeError(f'Invalid collection filter: {filter:!r}. 
') -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class WithTag: tag: str def __call__(self, path: PathParts, x: tp.Any): return isinstance(x, _HasTag) and x.tag == self.tag -@dataclasses.dataclass + def __repr__(self): + return f'WithTag({self.tag!r})' + + +@dataclasses.dataclass(frozen=True) class PathContains: key: Key def __call__(self, path: PathParts, x: tp.Any): return self.key in path + def __repr__(self): + return f'PathContains({self.key!r})' -@dataclasses.dataclass + +@dataclasses.dataclass(frozen=True) class OfType: type: type @@ -83,6 +94,9 @@ def __call__(self, path: PathParts, x: tp.Any): isinstance(x, _HasType) and issubclass(x.type, self.type) ) + def __repr__(self): + return f'OfType({self.type!r})' + class Any: def __init__(self, *filters: Filter): @@ -93,6 +107,15 @@ def __init__(self, *filters: Filter): def __call__(self, path: PathParts, x: tp.Any): return any(predicate(path, x) for predicate in self.predicates) + def __repr__(self): + return f'Any({", ".join(map(repr, self.predicates))})' + + def __eq__(self, other): + return isinstance(other, Any) and self.predicates == other.predicates + + def __hash__(self): + return hash(self.predicates) + class All: def __init__(self, *filters: Filter): @@ -103,6 +126,15 @@ def __init__(self, *filters: Filter): def __call__(self, path: PathParts, x: tp.Any): return all(predicate(path, x) for predicate in self.predicates) + def __repr__(self): + return f'All({", ".join(map(repr, self.predicates))})' + + def __eq__(self, other): + return isinstance(other, All) and self.predicates == other.predicates + + def __hash__(self): + return hash(self.predicates) + class Not: def __init__(self, collection_filter: Filter, /): @@ -111,12 +143,39 @@ def __init__(self, collection_filter: Filter, /): def __call__(self, path: PathParts, x: tp.Any): return not self.predicate(path, x) + def __repr__(self): + return f'Not({self.predicate!r})' + + def __eq__(self, other): + return isinstance(other, Not) and 
self.predicate == other.predicate + + def __hash__(self): + return hash(self.predicate) + class Everything: def __call__(self, path: PathParts, x: tp.Any): return True + def __repr__(self): + return 'Everything()' + + def __eq__(self, other): + return isinstance(other, Everything) + + def __hash__(self): + return hash(Everything) + class Nothing: def __call__(self, path: PathParts, x: tp.Any): return False + + def __repr__(self): + return 'Nothing()' + + def __eq__(self, other): + return isinstance(other, Nothing) + + def __hash__(self): + return hash(Nothing)
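
Note on the `__eq__`/`__hash__` additions in `filterlib.py`: the sketch below is a standalone illustration (plain Python, no Flax import) of what value-based equality and hashing give you when filters end up as dictionary keys, as in the `state_axes={...}` example from the guide. The classes are simplified stand-ins that borrow the names `WithTag` and `Any` from the diff; they are assumptions made for illustration, not the library's implementation, and the tag check is reduced to a `getattr` lookup.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class WithTag:
  # Stand-in for the frozen dataclass in the diff. frozen=True combined with
  # the default eq=True makes instances hashable by field value.
  tag: str

  def __call__(self, path, x):
    # Simplified check; the real class tests against a protocol instead.
    return getattr(x, 'tag', None) == self.tag


class Any:
  # Stand-in combinator with hand-written __eq__/__hash__, mirroring the diff.
  def __init__(self, *predicates):
    self.predicates = tuple(predicates)

  def __call__(self, path, x):
    return any(p(path, x) for p in self.predicates)

  def __eq__(self, other):
    return isinstance(other, Any) and self.predicates == other.predicates

  def __hash__(self):
    return hash(self.predicates)


# Two separately constructed but equivalent filters act as the same dict key.
axes = {Any(WithTag('dropout'), WithTag('params')): 0}
lookup = Any(WithTag('dropout'), WithTag('params'))
same_key = lookup in axes
axis = axes[lookup]
print(same_key, axis)
```

Without the custom `__eq__`/`__hash__`, `Any` would fall back to identity-based comparison, so the second, equivalent filter would not find the entry stored under the first one.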