From 94b49d0fcd72c1c46b5f3b5c3bc7c3c3f6023206 Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Tue, 11 Jun 2024 21:36:07 +0800
Subject: [PATCH 1/2] rewrite ``_make_jaxpr()`` function to be compatible with
 ``jax==0.4.29``

---
 brainstate/functional/__init__.py   |   5 +-
 brainstate/functional/_others.py    |  47 ++++++++
 brainstate/transform/_make_jaxpr.py | 168 +++++++++++++++++++++++++++-
 3 files changed, 215 insertions(+), 5 deletions(-)
 create mode 100644 brainstate/functional/_others.py

diff --git a/brainstate/functional/__init__.py b/brainstate/functional/__init__.py
index 3dbf984..cb0516e 100644
--- a/brainstate/functional/__init__.py
+++ b/brainstate/functional/__init__.py
@@ -20,6 +20,7 @@
 from ._normalization import __all__ as __others_all__
 from ._spikes import *
 from ._spikes import __all__ as __misc_all__
 
-__all__ = __spikes_all__ + __others_all__ + __activations_all__
-
+__all__ = __spikes_all__ + __others_all__ + __activations_all__ + __misc_all__
diff --git a/brainstate/functional/_others.py b/brainstate/functional/_others.py
new file mode 100644
index 0000000..b652467
--- /dev/null
+++ b/brainstate/functional/_others.py
@@ -0,0 +1,47 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from functools import partial
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+
+PyTree = Any
+
+__all__ = [
+  'clip_grad_norm',
+]
+
+
+def clip_grad_norm(
+    grad: PyTree,
+    max_norm: float | jax.Array,
+    norm_type: int | str | None = None
+):
+  """
+  Clip the global norm of a pytree of gradients.
+
+  The norm is computed over all gradient leaves together, as if they were
+  concatenated into a single vector. A rescaled copy of ``grad`` is returned.
+
+  Args:
+    grad (PyTree): a pytree of arrays (or a single array) holding the gradients to clip.
+    max_norm (float): the maximum allowed norm of the gradients.
+    norm_type (int, str, None): the p-norm to use, e.g. ``jnp.inf`` for the infinity norm; ``None`` means the 2-norm.
+  """
+  norm_fn = partial(jnp.linalg.norm, ord=norm_type)
+  norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
+  return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
diff --git a/brainstate/transform/_make_jaxpr.py b/brainstate/transform/_make_jaxpr.py
index 08ae714..8fffe5e 100644
--- a/brainstate/transform/_make_jaxpr.py
+++ b/brainstate/transform/_make_jaxpr.py
@@ -54,20 +54,27 @@
 from __future__ import annotations
 
 import functools
+import inspect
 import operator
 from collections.abc import Hashable, Iterable, Sequence
+from contextlib import ExitStack
 from typing import Any, Callable, Tuple, Union, Dict, Optional
 
 import jax
 from jax._src import source_info_util
+from jax._src.linear_util import annotate
+from jax._src.traceback_util import api_boundary
+from jax.extend.linear_util import transformation_with_aux, wrap_init
 from jax.interpreters import partial_eval as pe
-from jax.util import wraps
 from jax.interpreters.xla import abstractify
+from jax.util import wraps
 
 from brainstate._state import State, StateTrace
 from brainstate._utils import set_module_as
 
 PyTree = Any
+AxisName = Hashable
+
 
 __all__ = [
   "StatefulFunction",
@@ -393,7 +400,8 @@ def make_jaxpr(self, *args, **kwargs):
     if cache_key not in self._state_trace:
       try:
         # jaxpr
-        jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        # jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
           functools.partial(self._wrapped_fun_to_eval, cache_key),
           static_argnums=self.static_argnums,
           axis_env=self.axis_env,
@@ -474,7 +482,8 @@ def make_jaxpr(
     state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
 ) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
                     Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
-  """Creates a function that produces its jaxpr given example args.
+  """
+  Creates a function that produces its jaxpr given example args.
 
   Args:
     fun: The function whose ``jaxpr`` is to be computed. Its positional
@@ -571,3 +580,156 @@ def make_jaxpr_f(*args, **kwargs):
   if hasattr(fun, "__name__"):
     make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
   return make_jaxpr_f
+
+
+def _check_callable(fun):
+  # In Python 3.10+, the only thing stopping us from supporting staticmethods
+  # is that we can't take weak references to them, which the C++ JIT requires.
+  if isinstance(fun, staticmethod):
+    raise TypeError(f"staticmethod arguments are not supported, got {fun}")
+  if not callable(fun):
+    raise TypeError(f"Expected a callable value, got {fun}")
+  if inspect.isgeneratorfunction(fun):
+    raise TypeError(f"Expected a function, got a generator function: {fun}")
+
+
+def _broadcast_prefix(
+    prefix_tree: Any,
+    full_tree: Any,
+    is_leaf: Callable[[Any], bool] | None = None
+) -> list[Any]:
+  # If prefix_tree is not a tree prefix of full_tree, this code can raise a
+  # ValueError; use prefix_errors to find disagreements and raise more precise
+  # error messages.
+  result = []
+  num_leaves = lambda t: jax.tree.structure(t).num_leaves
+  add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
+  jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
+  return result
+
+
+def _flat_axes_specs(
+    abstracted_axes, *args, **kwargs
+) -> list[pe.AbstractedAxesSpec]:
+  if kwargs:
+    raise NotImplementedError
+
+  def ax_leaf(l):
+    return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
+            isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
+
+  return _broadcast_prefix(abstracted_axes, args, ax_leaf)
+
+
+@transformation_with_aux
+def _flatten_fun(in_tree, *args_flat):
+  py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
+  ans = yield py_args, py_kwargs
+  yield jax.tree.flatten(ans)
+
+
+def _make_jaxpr(
+    fun: Callable,
+    static_argnums: int | Iterable[int] = (),
+    axis_env: Sequence[tuple[AxisName, int]] | None = None,
+    return_shape: bool = False,
+    abstracted_axes: Any | None = None,
+) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
+  """Creates a function that produces its jaxpr given example args.
+
+  Args:
+    fun: The function whose ``jaxpr`` is to be computed. Its positional
+      arguments and return value should be arrays, scalars, or standard Python
+      containers (tuple/list/dict) thereof.
+    static_argnums: See the :py:func:`jax.jit` docstring.
+    axis_env: Optional, a sequence of pairs where the first element is an axis
+      name and the second element is a positive integer representing the size of
+      the mapped axis with that name. This parameter is useful when lowering
+      functions that involve parallel communication collectives, and it
+      specifies the axis name/size environment that would be set up by
+      applications of :py:func:`jax.pmap`.
+    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
+      wrapped function returns a pair where the first element is the
+      ``ClosedJaxpr`` representation of ``fun`` and the second element is a
+      pytree with the same structure as the output of ``fun`` and where the
+      leaves are objects with ``shape``, ``dtype``, and ``named_shape``
+      attributes representing the corresponding types of the output leaves.
+
+  Returns:
+    A wrapped version of ``fun`` that when applied to example arguments returns
+    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
+    argument ``return_shape`` is ``True``, then the returned function instead
+    returns a pair where the first element is the ``ClosedJaxpr``
+    representation of ``fun`` and the second element is a pytree representing
+    the structure, shape, dtypes, and named shapes of the output of ``fun``.
+
+  A ``jaxpr`` is JAX's intermediate representation for program traces. The
+  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
+  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
+  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
+  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
+  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
+
+  We do not describe the semantics of the ``jaxpr`` language in detail here, but
+  instead give a few examples.
+
+  >>> import jax
+  >>>
+  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
+  >>> print(f(3.0))
+  -0.83602
+  >>> _make_jaxpr(f)(3.0)
+  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
+  >>> _make_jaxpr(jax.grad(f))(3.0)
+  { lambda ; a:f32[]. let
+      b:f32[] = cos a
+      c:f32[] = sin a
+      _:f32[] = sin b
+      d:f32[] = cos b
+      e:f32[] = mul 1.0 d
+      f:f32[] = neg e
+      g:f32[] = mul f c
+    in (g,) }
+  """
+  _check_callable(fun)
+  static_argnums = _ensure_index_tuple(static_argnums)
+
+  def _abstractify(args, kwargs):
+    flat_args, in_tree = jax.tree.flatten((args, kwargs))
+    if abstracted_axes is None:
+      return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
+    else:
+      axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
+      in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
+      in_avals, keep_inputs = jax.util.unzip2(in_type)
+      return in_avals, in_tree, keep_inputs
+
+  @wraps(fun)
+  @api_boundary
+  def make_jaxpr_f(*args, **kwargs):
+    f = wrap_init(fun)
+    if static_argnums:
+      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
+      f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
+    in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
+    in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
+    f, out_tree = _flatten_fun(f, in_tree)
+    f = annotate(f, in_type)
+    debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
+    with ExitStack() as stack:
+      for axis_name, size in axis_env or []:
+        stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
+      jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
+    closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
+    if return_shape:
+      out_avals, _ = jax.util.unzip2(out_type)
+      out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
+      return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
+    return closed_jaxpr
+
+  make_jaxpr_f.__module__ = "brainstate.transform"
+  if hasattr(fun, "__qualname__"):
+    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
+  if hasattr(fun, "__name__"):
+    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
+  return make_jaxpr_f

From 96f9d24eaf20ccd79d33bb14b743ff2ca02cc8f9 Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Tue, 11 Jun 2024 21:39:15 +0800
Subject: [PATCH 2/2] fix typing annotation error in python<=3.9

---
 brainstate/functional/_others.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/brainstate/functional/_others.py b/brainstate/functional/_others.py
index b652467..5ca52ce 100644
--- a/brainstate/functional/_others.py
+++ b/brainstate/functional/_others.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
+from __future__ import annotations
+
 from functools import partial
 from typing import Any
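
---

A quick usage sketch for the new ``clip_grad_norm`` helper (not part of the
patches): the toy parameter pytree and loss function below are made up for
illustration, and the import path assumes the re-export added in
``brainstate/functional/__init__.py``:

    import jax
    import jax.numpy as jnp
    from brainstate.functional import clip_grad_norm

    # A toy parameter pytree; jax.grad returns a gradient pytree of the same shape.
    params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}

    def loss_fn(p):
        # Scalar loss, so jax.grad yields one gradient per parameter leaf.
        return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

    grads = jax.grad(loss_fn)(params)

    # Rescale all leaves together so the global 2-norm (norm_type=None)
    # does not exceed 1.0; a new pytree is returned, `grads` is unchanged.
    clipped = clip_grad_norm(grads, max_norm=1.0)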
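And a minimal sketch of calling the rewritten ``_make_jaxpr`` directly (the
private module path is assumed from this patch; the behavior shown mirrors
the docstring's own doctests and ``jax.make_jaxpr`` on ``jax==0.4.29``):

    import jax.numpy as jnp
    from brainstate.transform._make_jaxpr import _make_jaxpr

    def f(x):
        return jnp.sin(jnp.cos(x))

    # Trace `f` at the ShapedArray level; no real computation is executed.
    closed_jaxpr = _make_jaxpr(f)(3.0)
    print(closed_jaxpr)  # { lambda ; a:f32[]. let ... in (c,) }

    # With return_shape=True the wrapper also reports the output structure.
    closed_jaxpr, out_shape = _make_jaxpr(f, return_shape=True)(3.0)
    print(out_shape.shape, out_shape.dtype)  # () float32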