Commit

rewrite ``_make_jaxpr()`` function to be compatible with ``jax==0.4.29`` (#3)

* rewrite ``_make_jaxpr()`` function to be compatible with ``jax==0.4.29``

* fix typing annotation error in python<=3.9
chaoming0625 authored Jun 11, 2024
1 parent b430dea commit aea868f
Showing 3 changed files with 217 additions and 5 deletions.
5 changes: 3 additions & 2 deletions brainstate/functional/__init__.py
@@ -20,6 +20,7 @@
-from ._normalization import __all__ as __others_all__
 from ._spikes import *
 from ._spikes import __all__ as __spikes_all__
+from ._others import *
+from ._others import __all__ as __others_all__
 
-__all__ = __spikes_all__ + __others_all__ + __activations_all__
+__all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
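The net effect is that the new ``_others`` module is re-exported through the package namespace. A minimal sketch of the resulting API, assuming ``brainstate.functional`` imports cleanly at this commit:

```python
from brainstate import functional

# ``from ._others import *`` re-exports the new helper at package level.
assert functional.clip_grad_norm is functional._others.clip_grad_norm
```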
49 changes: 49 additions & 0 deletions brainstate/functional/_others.py
@@ -0,0 +1,49 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from functools import partial
from typing import Any

import jax
import jax.numpy as jnp

PyTree = Any

__all__ = [
'clip_grad_norm',
]


def clip_grad_norm(
    grad: PyTree,
    max_norm: float | jax.Array,
    norm_type: int | str | None = None
):
  """
  Clips the gradient norm of a pytree of parameters.

  The norm is computed over all gradients together, as if they were
  concatenated into a single vector. Since JAX arrays are immutable, the
  clipped gradients are returned as a new pytree rather than modified
  in place.

  Args:
    grad (PyTree): a pytree of arrays, or a single array, whose gradients
      will be normalized.
    max_norm (float): maximum norm of the gradients.
    norm_type (int, str, None): type of the p-norm to use. Can be
      ``jnp.inf`` for the infinity norm.
  """
  norm_fn = partial(jnp.linalg.norm, ord=norm_type)
  norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
  return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
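A minimal usage sketch of ``clip_grad_norm`` (the gradient values below are hypothetical):

```python
import jax
import jax.numpy as jnp

# A hypothetical gradient pytree, e.g. as returned by jax.grad.
grads = {'w': jnp.full((3, 3), 10.0), 'b': jnp.full((3,), 10.0)}

# Rescale the whole pytree so its global L2 norm is at most 1.0.
clipped = clip_grad_norm(grads, max_norm=1.0)
global_norm = jnp.linalg.norm(jnp.asarray([jnp.linalg.norm(g) for g in jax.tree.leaves(clipped)]))
print(global_norm)  # ~1.0
```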
168 changes: 165 additions & 3 deletions brainstate/transform/_make_jaxpr.py
@@ -54,20 +54,27 @@
 from __future__ import annotations
 
 import functools
+import inspect
+import operator
+from collections.abc import Hashable, Iterable, Sequence
+from contextlib import ExitStack
 from typing import Any, Callable, Tuple, Union, Dict, Optional
 
 import jax
 from jax._src import source_info_util
 from jax._src.linear_util import annotate
 from jax._src.traceback_util import api_boundary
+from jax.extend.linear_util import transformation_with_aux, wrap_init
 from jax.interpreters import partial_eval as pe
-from jax.util import wraps
+from jax.interpreters.xla import abstractify
+from jax.util import wraps
 
 from brainstate._state import State, StateTrace
 from brainstate._utils import set_module_as
 
 PyTree = Any
+AxisName = Hashable

__all__ = [
"StatefulFunction",
@@ -393,7 +400,8 @@ def make_jaxpr(self, *args, **kwargs):
     if cache_key not in self._state_trace:
       try:
         # jaxpr
-        jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        # jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
           functools.partial(self._wrapped_fun_to_eval, cache_key),
           static_argnums=self.static_argnums,
           axis_env=self.axis_env,
@@ -474,7 +482,8 @@ def make_jaxpr(
     state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
 ) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
                     Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
-  """Creates a function that produces its jaxpr given example args.
+  """
+  Creates a function that produces its jaxpr given example args.
 
   Args:
     fun: The function whose ``jaxpr`` is to be computed. Its positional
@@ -571,3 +580,156 @@ def make_jaxpr_f(*args, **kwargs):
  if hasattr(fun, "__name__"):
    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
  return make_jaxpr_f


def _check_callable(fun):
  # In Python 3.10+, the only thing stopping us from supporting staticmethods
  # is that we can't take weak references to them, which the C++ JIT requires.
  if isinstance(fun, staticmethod):
    raise TypeError(f"staticmethod arguments are not supported, got {fun}")
  if not callable(fun):
    raise TypeError(f"Expected a callable value, got {fun}")
  if inspect.isgeneratorfunction(fun):
    raise TypeError(f"Expected a function, got a generator function: {fun}")


def _broadcast_prefix(
    prefix_tree: Any,
    full_tree: Any,
    is_leaf: Callable[[Any], bool] | None = None
) -> list[Any]:
  # If prefix_tree is not a tree prefix of full_tree, this code can raise a
  # ValueError; use prefix_errors to find disagreements and raise more precise
  # error messages.
  result = []
  num_leaves = lambda t: jax.tree.structure(t).num_leaves
  add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
  jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
  return result
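For intuition, a small sketch of what the prefix broadcast produces (the example values are hypothetical):

```python
# Each leaf of the prefix tree is repeated once per leaf of the
# corresponding subtree of the full tree.
prefix = {'a': 0, 'b': 1}
full = {'a': (1.0, 2.0), 'b': [3.0]}
print(_broadcast_prefix(prefix, full))  # [0, 0, 1]
```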


def _flat_axes_specs(
    abstracted_axes, *args, **kwargs
) -> list[pe.AbstractedAxesSpec]:
  if kwargs:
    raise NotImplementedError

  def ax_leaf(l):
    return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
            isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))

  return _broadcast_prefix(abstracted_axes, args, ax_leaf)


@transformation_with_aux
def _flatten_fun(in_tree, *args_flat):
  py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
  ans = yield py_args, py_kwargs
  yield jax.tree.flatten(ans)
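A rough sketch of how this generator-style transformation composes with ``wrap_init``, assuming the ``jax==0.4.29`` ``linear_util`` API (the function ``g`` is hypothetical):

```python
import jax
from jax.extend.linear_util import wrap_init

def g(d):
    return {'y': d['x'] * 2.0}

args, kwargs = ({'x': 1.0},), {}
flat_args, in_tree = jax.tree.flatten((args, kwargs))
f, out_tree = _flatten_fun(wrap_init(g), in_tree)
out_flat = f.call_wrapped(*flat_args)            # flat list of output leaves
print(jax.tree.unflatten(out_tree(), out_flat))  # {'y': 2.0}
```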


def _make_jaxpr(
    fun: Callable,
    static_argnums: int | Iterable[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
    abstracted_axes: Any | None = None,
) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
  """Creates a function that produces its jaxpr given example args.

  Args:
    fun: The function whose ``jaxpr`` is to be computed. Its positional
      arguments and return value should be arrays, scalars, or standard Python
      containers (tuple/list/dict) thereof.
    static_argnums: See the :py:func:`jax.jit` docstring.
    axis_env: Optional, a sequence of pairs where the first element is an axis
      name and the second element is a positive integer representing the size of
      the mapped axis with that name. This parameter is useful when lowering
      functions that involve parallel communication collectives, and it
      specifies the axis name/size environment that would be set up by
      applications of :py:func:`jax.pmap`.
    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
      wrapped function returns a pair where the first element is the
      ``ClosedJaxpr`` representation of ``fun`` and the second element is a
      pytree with the same structure as the output of ``fun`` and where the
      leaves are objects with ``shape``, ``dtype``, and ``named_shape``
      attributes representing the corresponding types of the output leaves.

  Returns:
    A wrapped version of ``fun`` that when applied to example arguments returns
    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
    argument ``return_shape`` is ``True``, then the returned function instead
    returns a pair where the first element is the ``ClosedJaxpr``
    representation of ``fun`` and the second element is a pytree representing
    the structure, shape, dtypes, and named shapes of the output of ``fun``.

  A ``jaxpr`` is JAX's intermediate representation for program traces. The
  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

  We do not describe the semantics of the ``jaxpr`` language in detail here, but
  instead give a few examples.

  >>> import jax
  >>>
  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
  >>> print(f(3.0))
  -0.83602
  >>> _make_jaxpr(f)(3.0)
  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
  >>> _make_jaxpr(jax.grad(f))(3.0)
  { lambda ; a:f32[]. let
      b:f32[] = cos a
      c:f32[] = sin a
      _:f32[] = sin b
      d:f32[] = cos b
      e:f32[] = mul 1.0 d
      f:f32[] = neg e
      g:f32[] = mul f c
    in (g,) }
  """
  _check_callable(fun)
  static_argnums = _ensure_index_tuple(static_argnums)

  def _abstractify(args, kwargs):
    flat_args, in_tree = jax.tree.flatten((args, kwargs))
    if abstracted_axes is None:
      return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
    else:
      axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
      in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
      in_avals, keep_inputs = jax.util.unzip2(in_type)
      return in_avals, in_tree, keep_inputs

  @wraps(fun)
  @api_boundary
  def make_jaxpr_f(*args, **kwargs):
    f = wrap_init(fun)
    if static_argnums:
      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
      f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
    in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
    in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
    f, out_tree = _flatten_fun(f, in_tree)
    f = annotate(f, in_type)
    debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
    with ExitStack() as stack:
      for axis_name, size in axis_env or []:
        stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
      jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
    closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
    if return_shape:
      out_avals, _ = jax.util.unzip2(out_type)
      out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
      return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
    return closed_jaxpr

  make_jaxpr_f.__module__ = "brainstate.transform"
  if hasattr(fun, "__qualname__"):
    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
  if hasattr(fun, "__name__"):
    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
  return make_jaxpr_f
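A quick sanity check of the ``return_shape`` path, as a hedged sketch (the function ``f`` and the input shape are hypothetical):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x), jnp.cos(x)

closed_jaxpr, out_shapes = _make_jaxpr(f, return_shape=True)(jnp.zeros((4,)))
print(closed_jaxpr)  # { lambda ; a:f32[4]. let b:f32[4] = sin a; c:f32[4] = cos a in (b, c) }
print(out_shapes)    # (ShapeDtypeStruct(shape=(4,), dtype=float32), ...)
```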
