From 1f7c789c374f8e309dca0ffbeed0cd6ad076bffb Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Sat, 8 Jun 2024 19:46:20 +0800
Subject: [PATCH] updates

---
 brainstate/environ.py                   |  3 +-
 brainstate/functional/_activations.py   |  2 +-
 brainstate/functional/_normalization.py |  3 ++
 brainstate/optim/_lr_scheduler_test.py  | 13 +++++
 brainstate/transform/_jit.py            | 68 +++++++++++++++++--------
 5 files changed, 66 insertions(+), 23 deletions(-)

diff --git a/brainstate/environ.py b/brainstate/environ.py
index 0047c3a..60f011e 100644
--- a/brainstate/environ.py
+++ b/brainstate/environ.py
@@ -18,7 +18,8 @@
 __all__ = [
   'set', 'context', 'get', 'all', 'set_host_device_count', 'set_platform',
-  'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
+  'get_host_device_count', 'get_platform',
+  'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
   'tolerance', 'dftype', 'ditype', 'dutype', 'dctype',
 ]
diff --git a/brainstate/functional/_activations.py b/brainstate/functional/_activations.py
index bafe5c6..39fa118 100644
--- a/brainstate/functional/_activations.py
+++ b/brainstate/functional/_activations.py
@@ -27,7 +27,7 @@
 from jax.scipy.special import logsumexp
 from jax.typing import ArrayLike
 
-from brainstate import math, random
+from .. import math, random
 
 __all__ = [
   "tanh",
diff --git a/brainstate/functional/_normalization.py b/brainstate/functional/_normalization.py
index 39b23e9..7c52610 100644
--- a/brainstate/functional/_normalization.py
+++ b/brainstate/functional/_normalization.py
@@ -20,11 +20,14 @@
 import jax
 import jax.numpy as jnp
 
+from .._utils import set_module_as
+
 __all__ = [
   'weight_standardization',
 ]
 
 
+@set_module_as('brainstate.functional')
 def weight_standardization(
     w: jax.typing.ArrayLike,
     eps: float = 1e-4,
diff --git a/brainstate/optim/_lr_scheduler_test.py b/brainstate/optim/_lr_scheduler_test.py
index 01a9988..10cdbed 100644
--- a/brainstate/optim/_lr_scheduler_test.py
+++ b/brainstate/optim/_lr_scheduler_test.py
@@ -34,3 +34,16 @@ def test1(self):
         self.assertTrue(jnp.allclose(r, 0.001))
       else:
         self.assertTrue(jnp.allclose(r, 0.0001))
+
+  def test2(self):
+    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+    for i in range(40):
+      r = lr(i)
+      if i < 10:
+        self.assertEqual(r, 0.1)
+      elif i < 20:
+        self.assertTrue(jnp.allclose(r, 0.01))
+      elif i < 30:
+        self.assertTrue(jnp.allclose(r, 0.001))
+      else:
+        self.assertTrue(jnp.allclose(r, 0.0001))
diff --git a/brainstate/transform/_jit.py b/brainstate/transform/_jit.py
index 3330eaf..7573309 100644
--- a/brainstate/transform/_jit.py
+++ b/brainstate/transform/_jit.py
@@ -23,8 +23,8 @@
 from jax._src import sharding_impls
 from jax.lib import xla_client as xc
 
-from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 from brainstate._utils import set_module_as
+from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 
 __all__ = ['jit']
 
@@ -33,10 +33,13 @@
 class JittedFunction(Callable):
   """
   A wrapped version of ``fun``, set up for just-in-time compilation.
   """
-  origin_fun: Callable   # the original function
+  origin_fun: Callable  # the original function
   stateful_fun: StatefulFunction  # the stateful function for extracting states
   jitted_fun: jax.stages.Wrapped  # the jitted function
-  clear_cache: Callable   # clear the cache of the jitted function
+  clear_cache: Callable  # clear the cache of the jitted function
+
+  def __call__(self, *args, **kwargs):
+    pass
 
 def _get_jitted_fun(
@@ -85,12 +88,16 @@
   def clear_cache():
     jit_fun.clear_cache()
 
   jitted_fun: JittedFunction
+  # the original function
   jitted_fun.origin_fun = fun.fun
+  # the stateful function for extracting states
   jitted_fun.stateful_fun = fun
+  # the jitted function
   jitted_fun.jitted_fun = jit_fun
+  # clear cache
   jitted_fun.clear_cache = clear_cache
@@ -99,18 +106,18 @@
 
 @set_module_as('brainstate.transform')
 def jit(
-  fun: Callable = None,
-  in_shardings=sharding_impls.UNSPECIFIED,
-  out_shardings=sharding_impls.UNSPECIFIED,
-  static_argnums: int | Sequence[int] | None = None,
-  donate_argnums: int | Sequence[int] | None = None,
-  donate_argnames: str | Iterable[str] | None = None,
-  keep_unused: bool = False,
-  device: xc.Device | None = None,
-  backend: str | None = None,
-  inline: bool = False,
-  abstracted_axes: Any | None = None,
-  **kwargs
+    fun: Callable = None,
+    in_shardings=sharding_impls.UNSPECIFIED,
+    out_shardings=sharding_impls.UNSPECIFIED,
+    static_argnums: int | Sequence[int] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,
+    **kwargs
 ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
   """
   Sets up ``fun`` for just-in-time compilation with XLA.
@@ -228,12 +235,31 @@
   if fun is None:
 
     def wrapper(fun_again: Callable) -> JittedFunction:
-      return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums,
-                             donate_argnums, donate_argnames, keep_unused,
-                             device, backend, inline, abstracted_axes, **kwargs)
+      return _get_jitted_fun(fun_again,
+                             in_shardings,
+                             out_shardings,
+                             static_argnums,
+                             donate_argnums,
+                             donate_argnames,
+                             keep_unused,
+                             device,
+                             backend,
+                             inline,
+                             abstracted_axes,
+                             **kwargs)
+
     return wrapper
 
   else:
-    return _get_jitted_fun(fun, in_shardings, out_shardings, static_argnums,
-                           donate_argnums, donate_argnames, keep_unused,
-                           device, backend, inline, abstracted_axes, **kwargs)
+    return _get_jitted_fun(fun,
+                           in_shardings,
+                           out_shardings,
+                           static_argnums,
+                           donate_argnums,
+                           donate_argnames,
+                           keep_unused,
+                           device,
+                           backend,
+                           inline,
+                           abstracted_axes,
+                           **kwargs)