From 3b1efa33fce3629125a1bf55df511e8084ec032a Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Mon, 9 Dec 2024 10:46:00 +0800 Subject: [PATCH 1/2] compatible with jax==0.4.36 --- brainstate/compile/_loop_collect_return.py | 6 ++- brainstate/compile/_make_jaxpr.py | 55 ++++++++++++---------- brainstate/compile/_progress_bar.py | 42 ++++++++++++----- brainstate/nn/_collective_ops.py | 33 ++++++------- brainstate/util/_tracers.py | 7 --- examples/100_hh_neuron_model.py | 11 +++-- examples/106_COBA_HH_2007.py | 3 +- 7 files changed, 88 insertions(+), 69 deletions(-) diff --git a/brainstate/compile/_loop_collect_return.py b/brainstate/compile/_loop_collect_return.py index 68658d4..151c344 100644 --- a/brainstate/compile/_loop_collect_return.py +++ b/brainstate/compile/_loop_collect_return.py @@ -211,7 +211,11 @@ def scan(f, init, xs, length=None): # scan init = (all_writen_state_vals, init) - (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f, init, xs, length=length, reverse=reverse, + (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f, + init, + xs, + length=length, + reverse=reverse, unroll=unroll) # assign the written state values and restore the read state values write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals) diff --git a/brainstate/compile/_make_jaxpr.py b/brainstate/compile/_make_jaxpr.py index 742a0c7..aa10496 100644 --- a/brainstate/compile/_make_jaxpr.py +++ b/brainstate/compile/_make_jaxpr.py @@ -72,7 +72,6 @@ from brainstate._state import State, StateTraceStack from brainstate._utils import set_module_as from brainstate.typing import PyTree -from brainstate.util._tracers import new_jax_trace AxisName = Hashable @@ -112,28 +111,27 @@ def _new_arg_fn(frame, trace, aval): return tracer -def _init_state_trace() -> StateTraceStack: - # Should be within the calling of ``jax.make_jaxpr()`` - frame, trace = new_jax_trace() +def _new_jax_trace(): + main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1] + frame = main.jaxpr_stack[-1] + trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel()) + return frame, trace + + +def _init_state_trace_stack() -> StateTraceStack: state_trace: StateTraceStack = StateTraceStack() - # Set the function to transform the new argument to a tracer - state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace)) - return state_trace + if jax.__version_info__ < (0, 4, 36): + # Should be within the calling of ``jax.make_jaxpr()`` + frame, trace = _new_jax_trace() + # Set the function to transform the new argument to a tracer + state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace)) + return state_trace -# def wrapped_abstractify(x: Any) -> Any: -# """ -# Abstractify the input. -# -# Args: -# x: The input. -# -# Returns: -# The abstractified input. -# """ -# if isinstance(x, pe.DynamicJaxprTracer): -# return jax.core.ShapedArray(x.aval.shape, x.aval.dtype, weak_type=x.aval.weak_type) -# return shaped_abstractify(x) + else: + trace = jax.core.trace_ctx.trace + state_trace.set_new_arg(trace.new_arg) + return state_trace class StatefulFunction(object): @@ -383,12 +381,15 @@ def _wrapped_fun_to_eval( A tuple of the states that are read and written by the function and the output of the function. 
""" # state trace - state_trace = _init_state_trace() + state_trace = _init_state_trace_stack() self._cached_state_trace[cache_key] = state_trace with state_trace: out = self.fun(*args, **kwargs) - state_values = state_trace.get_write_state_values( - True) if return_only_write else state_trace.get_state_values() + state_values = ( + state_trace.get_write_state_values(True) + if return_only_write else + state_trace.get_state_values() + ) state_trace.recovery_original_values() # State instance as functional returns is not allowed. @@ -419,17 +420,21 @@ def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs): try: # jaxpr jaxpr, (out_shapes, state_shapes) = _make_jaxpr( - functools.partial(self._wrapped_fun_to_eval, cache_key, return_only_write=return_only_write), + functools.partial( + self._wrapped_fun_to_eval, + cache_key, + return_only_write=return_only_write + ), static_argnums=self.static_argnums, axis_env=self.axis_env, return_shape=True, abstracted_axes=self.abstracted_axes )(*args, **kwargs) - # returns self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes)) self._cached_out_shapes[cache_key] = (out_shapes, state_shapes) self._cached_jaxpr[cache_key] = jaxpr + except Exception as e: try: self._cached_state_trace.pop(cache_key) diff --git a/brainstate/compile/_progress_bar.py b/brainstate/compile/_progress_bar.py index 39cb572..4c74e0c 100644 --- a/brainstate/compile/_progress_bar.py +++ b/brainstate/compile/_progress_bar.py @@ -93,19 +93,37 @@ def _close_tqdm(self): self.tqdm_bars[0].update(self.remainder) self.tqdm_bars[0].close() + def _tqdm(self, is_init, is_print, is_final): + if is_init: + self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs) + self.tqdm_bars[0].set_description(self.message, refresh=False) + if is_print: + self.tqdm_bars[0].update(self.print_freq) + if is_final: + if self.remainder > 0: + self.tqdm_bars[0].update(self.remainder) + self.tqdm_bars[0].close() + def __call__(self, iter_num, *args, **kwargs): - _ = jax.lax.cond( + jax.debug.callback( + self._tqdm, iter_num == 0, - lambda: jax.debug.callback(self._define_tqdm), - lambda: None, - ) - _ = jax.lax.cond( (iter_num + 1) % self.print_freq == 0, - lambda: jax.debug.callback(self._update_tqdm), - lambda: None, - ) - _ = jax.lax.cond( - iter_num == self.n - 1, - lambda: jax.debug.callback(self._close_tqdm), - lambda: None, + iter_num == self.n - 1 ) + + # _ = jax.lax.cond( + # iter_num == 0, + # lambda: jax.debug.callback(self._define_tqdm, ordered=True), + # lambda: None, + # ) + # _ = jax.lax.cond( + # (iter_num + 1) % self.print_freq == 0, + # lambda: jax.debug.callback(self._update_tqdm, ordered=True), + # lambda: None, + # ) + # _ = jax.lax.cond( + # iter_num == self.n - 1, + # lambda: jax.debug.callback(self._close_tqdm, ordered=True), + # lambda: None, + # ) diff --git a/brainstate/nn/_collective_ops.py b/brainstate/nn/_collective_ops.py index 76653f7..e7e3309 100644 --- a/brainstate/nn/_collective_ops.py +++ b/brainstate/nn/_collective_ops.py @@ -80,7 +80,6 @@ def init_all_states( target: T, *args, exclude: Filter = None, - tag: str = None, **kwargs ) -> T: """ @@ -99,27 +98,25 @@ def init_all_states( The target Module. 
""" - with catch_new_states(tag=tag): - - # node that has `call_order` decorated - nodes_with_order = [] - - nodes_ = nodes(target).filter(Module) - if exclude is not None: - nodes_ = nodes_ - nodes_.filter(exclude) + # node that has `call_order` decorated + nodes_with_order = [] - # reset node whose `init_state` has no `call_order` - for node in list(nodes_.values()): - if hasattr(node.init_state, 'call_order'): - nodes_with_order.append(node) - else: - node.init_state(*args, **kwargs) + nodes_ = nodes(target).filter(Module) + if exclude is not None: + nodes_ = nodes_ - nodes_.filter(exclude) - # reset the node's states with `call_order` - for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order): + # reset node whose `init_state` has no `call_order` + for node in list(nodes_.values()): + if hasattr(node.init_state, 'call_order'): + nodes_with_order.append(node) + else: node.init_state(*args, **kwargs) - return target + # reset the node's states with `call_order` + for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order): + node.init_state(*args, **kwargs) + + return target @set_module_as('brainstate.nn') diff --git a/brainstate/util/_tracers.py b/brainstate/util/_tracers.py index 5b79de8..7488047 100644 --- a/brainstate/util/_tracers.py +++ b/brainstate/util/_tracers.py @@ -16,7 +16,6 @@ import jax import jax.core -from jax.interpreters import partial_eval as pe from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr @@ -25,12 +24,6 @@ ] -def new_jax_trace(): - main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1] - frame = main.jaxpr_stack[-1] - trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel()) - return frame, trace - def current_jax_trace(): """Returns the Jax tracing state.""" diff --git a/examples/100_hh_neuron_model.py b/examples/100_hh_neuron_model.py index 2f6a4ed..68a0f80 100644 --- a/examples/100_hh_neuron_model.py +++ b/examples/100_hh_neuron_model.py @@ -92,6 +92,7 @@ def update(self, x=0. * u.mA / u.cm ** 2): return spike + hh = HH(10) bst.nn.init_all_states(hh) dt = 0.01 * u.ms @@ -104,10 +105,12 @@ def run(t, inp): times = u.math.arange(0. * u.ms, 100. 
* u.ms, dt) -vs = bst.compile.for_loop(run, - # times, random inputs - times, bst.random.uniform(1., 10., times.shape) * u.uA / u.cm ** 2, - pbar=bst.compile.ProgressBar(count=100)) +vs = bst.compile.for_loop( + run, + # times, random inputs + times, bst.random.uniform(1., 10., times.shape) * u.uA / u.cm ** 2, + pbar=bst.compile.ProgressBar(count=100) +) plt.plot(times.to_decimal(u.ms), vs.to_decimal(u.mV)) plt.show() diff --git a/examples/106_COBA_HH_2007.py b/examples/106_COBA_HH_2007.py index dc2621a..9f5aabf 100644 --- a/examples/106_COBA_HH_2007.py +++ b/examples/106_COBA_HH_2007.py @@ -21,7 +21,6 @@ # import brainunit as u -import dendritex as dx import matplotlib.pyplot as plt import brainstate as bst @@ -157,7 +156,7 @@ def update(self, t): # network net = EINet() -bst.nn.init_all_states(net, exclude=dx.IonChannel) +bst.nn.init_all_states(net) # simulation with bst.environ.context(dt=0.1 * u.ms): From d391ed7cc363f5bba2886d68f4eb51159489160e Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Mon, 9 Dec 2024 10:52:27 +0800 Subject: [PATCH 2/2] fix tests --- brainstate/functional/_activations.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/brainstate/functional/_activations.py b/brainstate/functional/_activations.py index 6aeb061..f1e7c3d 100644 --- a/brainstate/functional/_activations.py +++ b/brainstate/functional/_activations.py @@ -588,8 +588,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]: def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]: + where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -604,8 +603,6 @@ def log_softmax(x: ArrayLike, axis: the axis or axes along which the :code:`log_softmax` should be computed. Either an integer or a tuple of integers. where: Elements to include in the :code:`log_softmax`. - initial: The minimum value used to shift the input array. Must be present - when :code:`where` is not None. Returns: An array. @@ -613,15 +610,12 @@ def log_softmax(x: ArrayLike, See also: :func:`softmax` """ - if initial is not None: - initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa - return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial) + return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where) def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]: + where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -645,9 +639,7 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ - if initial is not None: - initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa - return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial) + return _keep_unit(jax.nn.softmax, x, axis=axis, where=where) def standardize(x: ArrayLike,
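
Note on the compatibility pattern above (reviewer annotation, not part of the patches): JAX 0.4.36 reworked its tracing machinery, removing `jax.core.thread_local_state` and the per-thread trace stack that the old `new_jax_trace()` helper relied on, and exposing the active trace through `jax.core.trace_ctx` instead; the same release drops the `initial` argument of `jax.nn.softmax`/`jax.nn.log_softmax`, which is why patch 2/2 removes it from the brainstate wrappers. The sketch below condenses the version gate that patch 1/2 adds in `brainstate/compile/_make_jaxpr.py`; the helper name `current_dynamic_trace` is illustrative only, and it has to run while `jax.make_jaxpr` is actively tracing, otherwise there is no jaxpr frame to attach to.

```python
import jax
from jax.interpreters import partial_eval as pe


def current_dynamic_trace():
    """Locate the DynamicJaxprTrace used by the enclosing jax.make_jaxpr call."""
    if jax.__version_info__ < (0, 4, 36):
        # Pre-0.4.36: rebuild the trace from the thread-local trace stack
        # (these internals were removed in 0.4.36).
        main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
        return pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
    # 0.4.36+: the active trace is exposed on the global trace context;
    # new abstract arguments are created with ``trace.new_arg(aval)``.
    return jax.core.trace_ctx.trace
```

The progress-bar change in `_progress_bar.py` follows the same motivation: instead of three `jax.lax.cond` branches each wrapping its own `jax.debug.callback`, the traced flags (`iter_num == 0`, the print-frequency condition, and `iter_num == self.n - 1`) are passed to a single callback and the branching happens in host Python inside `_tqdm`, so only one callback is staged per loop step.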