Commit 1f7c789

updates
chaoming0625 committed Jun 8, 2024
1 parent ab7b5c2 commit 1f7c789
Showing 5 changed files with 66 additions and 23 deletions.
3 changes: 2 additions & 1 deletion brainstate/environ.py
@@ -18,7 +18,8 @@
 __all__ = [
   'set', 'context', 'get', 'all',
   'set_host_device_count', 'set_platform',
-  'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
+  'get_host_device_count', 'get_platform',
+  'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
   'tolerance',
   'dftype', 'ditype', 'dutype', 'dctype',
 ]
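The names exported here imply a context-style environment API: global defaults plus scoped overrides. A hedged usage sketch; the signatures below are assumptions inferred from ``__all__``, not shown in this diff:

import brainstate as bst

bst.environ.set(dt=0.1)              # assumed: set a global default such as the integration step
with bst.environ.context(dt=0.01):   # assumed: temporary override within a scope
  print(bst.environ.get_dt())        # -> 0.01 inside the context
print(bst.environ.get_dt())          # -> 0.1 again outside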
2 changes: 1 addition & 1 deletion brainstate/functional/_activations.py
@@ -27,7 +27,7 @@
 from jax.scipy.special import logsumexp
 from jax.typing import ArrayLike
 
-from brainstate import math, random
+from .. import math, random
 
 __all__ = [
   "tanh",
3 changes: 3 additions & 0 deletions brainstate/functional/_normalization.py
@@ -20,11 +20,14 @@
 import jax
 import jax.numpy as jnp
 
+from .._utils import set_module_as
+
 __all__ = [
   'weight_standardization',
 ]
 
 
+@set_module_as('brainstate.functional')
 def weight_standardization(
     w: jax.typing.ArrayLike,
     eps: float = 1e-4,
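``set_module_as`` comes from a private helper that this diff does not show. A minimal sketch of what such a decorator typically does, as an assumption rather than the actual ``brainstate._utils`` implementation:

from typing import Callable

def set_module_as(module: str) -> Callable:
  # Rewrite a function's __module__ so docs and reprs report the public
  # package path (e.g. 'brainstate.functional') instead of the private
  # file the function is defined in.
  def decorator(fun: Callable) -> Callable:
    fun.__module__ = module
    return fun
  return decorator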
13 changes: 13 additions & 0 deletions brainstate/optim/_lr_scheduler_test.py
@@ -34,3 +34,16 @@ def test1(self):
         self.assertTrue(jnp.allclose(r, 0.001))
       else:
         self.assertTrue(jnp.allclose(r, 0.0001))
+
+  def test2(self):
+    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+    for i in range(40):
+      r = lr(i)
+      if i < 10:
+        self.assertEqual(r, 0.1)
+      elif i < 20:
+        self.assertTrue(jnp.allclose(r, 0.01))
+      elif i < 30:
+        self.assertTrue(jnp.allclose(r, 0.001))
+      else:
+        self.assertTrue(jnp.allclose(r, 0.0001))
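The new test checks that ``MultiStepLR`` keeps the usual milestone-decay semantics after being wrapped in ``bst.transform.jit``: the rate is multiplied by ``gamma`` once per milestone passed. A plain-Python sketch of that rule, assuming PyTorch-style semantics (the helper name is hypothetical):

def multistep_lr(base_lr, milestones, gamma, epoch):
  # count how many milestones the current epoch has already passed
  n_passed = sum(epoch >= m for m in milestones)
  return base_lr * gamma ** n_passed

assert multistep_lr(0.1, [10, 20, 30], 0.1, 5) == 0.1                 # no milestone passed
assert abs(multistep_lr(0.1, [10, 20, 30], 0.1, 25) - 0.001) < 1e-12  # passed 10 and 20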
68 changes: 47 additions & 21 deletions brainstate/transform/_jit.py
@@ -23,8 +23,8 @@
 from jax._src import sharding_impls
 from jax.lib import xla_client as xc
 
-from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 from brainstate._utils import set_module_as
+from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 
 __all__ = ['jit']
@@ -33,10 +33,13 @@ class JittedFunction(Callable):
   """
   A wrapped version of ``fun``, set up for just-in-time compilation.
   """
-  origin_fun: Callable # the original function
+  origin_fun: Callable  # the original function
   stateful_fun: StatefulFunction  # the stateful function for extracting states
   jitted_fun: jax.stages.Wrapped  # the jitted function
-  clear_cache: Callable # clear the cache of the jitted function
+  clear_cache: Callable  # clear the cache of the jitted function
+
+  def __call__(self, *args, **kwargs):
+    pass
 
 
 def _get_jitted_fun(
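The ``__call__`` stub added above lets ``JittedFunction`` serve as a callable interface for type checkers; concrete instances are ordinary functions that simply get these attributes attached (see the assignments in ``_get_jitted_fun`` below). A sketch of that pattern with hypothetical names:

from typing import Callable

def make_wrapper(fun: Callable):
  # a plain closure acquires the JittedFunction attributes by assignment
  def wrapper(*args, **kwargs):
    return fun(*args, **kwargs)
  wrapper.origin_fun = fun            # the original function
  wrapper.clear_cache = lambda: None  # placeholder; the real one drops compiled code
  return wrapper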
@@ -85,12 +88,16 @@ def clear_cache():
     jit_fun.clear_cache()
 
   jitted_fun: JittedFunction
+
   # the original function
   jitted_fun.origin_fun = fun.fun
+
   # the stateful function for extracting states
   jitted_fun.stateful_fun = fun
+
   # the jitted function
   jitted_fun.jitted_fun = jit_fun
+
   # clear cache
   jitted_fun.clear_cache = clear_cache
 
@@ -99,18 +106,18 @@
 
 @set_module_as('brainstate.transform')
 def jit(
-  fun: Callable = None,
-  in_shardings=sharding_impls.UNSPECIFIED,
-  out_shardings=sharding_impls.UNSPECIFIED,
-  static_argnums: int | Sequence[int] | None = None,
-  donate_argnums: int | Sequence[int] | None = None,
-  donate_argnames: str | Iterable[str] | None = None,
-  keep_unused: bool = False,
-  device: xc.Device | None = None,
-  backend: str | None = None,
-  inline: bool = False,
-  abstracted_axes: Any | None = None,
-  **kwargs
+    fun: Callable = None,
+    in_shardings=sharding_impls.UNSPECIFIED,
+    out_shardings=sharding_impls.UNSPECIFIED,
+    static_argnums: int | Sequence[int] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,
+    **kwargs
 ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
   """
   Sets up ``fun`` for just-in-time compilation with XLA.
@@ -228,12 +235,31 @@ def jit(
 
   if fun is None:
     def wrapper(fun_again: Callable) -> JittedFunction:
-      return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums,
-                             donate_argnums, donate_argnames, keep_unused,
-                             device, backend, inline, abstracted_axes, **kwargs)
+      return _get_jitted_fun(fun_again,
+                             in_shardings,
+                             out_shardings,
+                             static_argnums,
+                             donate_argnums,
+                             donate_argnames,
+                             keep_unused,
+                             device,
+                             backend,
+                             inline,
+                             abstracted_axes,
+                             **kwargs)
 
     return wrapper
 
   else:
-    return _get_jitted_fun(fun, in_shardings, out_shardings, static_argnums,
-                           donate_argnums, donate_argnames, keep_unused,
-                           device, backend, inline, abstracted_axes, **kwargs)
+    return _get_jitted_fun(fun,
+                           in_shardings,
+                           out_shardings,
+                           static_argnums,
+                           donate_argnums,
+                           donate_argnames,
+                           keep_unused,
+                           device,
+                           backend,
+                           inline,
+                           abstracted_axes,
+                           **kwargs)
+
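The ``fun is None`` branch is what allows ``jit`` to be used both as a bare decorator and as a decorator factory. A hedged usage sketch of the two calling conventions this code supports:

import brainstate as bst

# 1) direct decoration: fun is passed, so _get_jitted_fun runs immediately
@bst.transform.jit
def f(x):
  return x + 1.

# 2) decoration with options: fun is None, so ``wrapper`` is returned and
#    then applied to the function below
@bst.transform.jit(static_argnums=(1,))
def g(x, n):
  return x * n

f.clear_cache()  # a JittedFunction also exposes cache clearing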
