diff --git a/brainstate/augment/_autograd.py b/brainstate/augment/_autograd.py index de89979..96eb049 100644 --- a/brainstate/augment/_autograd.py +++ b/brainstate/augment/_autograd.py @@ -45,7 +45,7 @@ from brainstate.util import PrettyType, PrettyAttr, PrettyRepr __all__ = [ - 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian', + 'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian', ] A = TypeVar('A') @@ -159,6 +159,9 @@ def jacfun(*args, **kwargs): return jacfun +TransformFn = Callable + + class GradientTransform(PrettyRepr): """ Automatic Differentiation Transformations for the ``State`` system. @@ -168,11 +171,11 @@ class GradientTransform(PrettyRepr): def __init__( self, target: Callable, - transform: Callable, - grad_states: Any, - argnums: Optional[Union[int, Sequence[int]]], - return_value: bool, - has_aux: bool, + transform: TransformFn, + grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None, + argnums: Optional[Union[int, Sequence[int]]] = None, + return_value: bool = False, + has_aux: bool = False, transform_params: Optional[Dict[str, Any]] = None, ): # gradient variables diff --git a/brainstate/event/__init__.py b/brainstate/event/__init__.py index b6f9a22..e2e7150 100644 --- a/brainstate/event/__init__.py +++ b/brainstate/event/__init__.py @@ -14,14 +14,14 @@ # ============================================================================== -from ._csr import * -from ._csr import __all__ as __all_csr -from ._fixed_probability import * -from ._fixed_probability import __all__ as __all_fixed_probability -from ._linear import * +from ._csr_mv import * +from ._csr_mv import __all__ as __all_csr +from ._fixedprob_mv import * +from ._fixedprob_mv import __all__ as __all_fixed_probability +from ._linear_mv import * from ._xla_custom_op import * from ._xla_custom_op import __all__ as __all_xla_custom_op -from ._linear import __all__ as __all_linear +from ._linear_mv import __all__ as __all_linear __all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op diff --git a/brainstate/event/_csr.py b/brainstate/event/_csr_mv.py similarity index 100% rename from brainstate/event/_csr.py rename to brainstate/event/_csr_mv.py diff --git a/brainstate/event/_csr_benchmark.py b/brainstate/event/_csr_mv_benchmark.py similarity index 100% rename from brainstate/event/_csr_benchmark.py rename to brainstate/event/_csr_mv_benchmark.py diff --git a/brainstate/event/_csr_test.py b/brainstate/event/_csr_mv_test.py similarity index 100% rename from brainstate/event/_csr_test.py rename to brainstate/event/_csr_mv_test.py diff --git a/brainstate/event/_fixed_probability.py b/brainstate/event/_fixedprob_mv.py similarity index 100% rename from brainstate/event/_fixed_probability.py rename to brainstate/event/_fixedprob_mv.py diff --git a/brainstate/event/_fixed_probability_benchmark.py b/brainstate/event/_fixedprob_mv_benchmark.py similarity index 100% rename from brainstate/event/_fixed_probability_benchmark.py rename to brainstate/event/_fixedprob_mv_benchmark.py diff --git a/brainstate/event/_fixed_probability_test.py b/brainstate/event/_fixedprob_mv_test.py similarity index 100% rename from brainstate/event/_fixed_probability_test.py rename to brainstate/event/_fixedprob_mv_test.py diff --git a/brainstate/event/_linear.py b/brainstate/event/_linear_mv.py similarity index 100% rename from brainstate/event/_linear.py rename to brainstate/event/_linear_mv.py diff --git a/brainstate/event/_linear_benckmark.py b/brainstate/event/_linear_mv_benckmark.py similarity index 100% rename from brainstate/event/_linear_benckmark.py rename to brainstate/event/_linear_mv_benckmark.py diff --git a/brainstate/event/_linear_test.py b/brainstate/event/_linear_mv_test.py similarity index 98% rename from brainstate/event/_linear_test.py rename to brainstate/event/_linear_mv_test.py index 2618d12..09b9b7e 100644 --- a/brainstate/event/_linear_test.py +++ b/brainstate/event/_linear_mv_test.py @@ -20,7 +20,7 @@ from absl.testing import parameterized import brainstate as bst -from brainstate.event._linear import Linear +from brainstate.event._linear_mv import Linear class TestEventLinear(parameterized.TestCase): diff --git a/brainstate/event/_xla_custom_op.py b/brainstate/event/_xla_custom_op.py index 750c3ce..6706dd4 100644 --- a/brainstate/event/_xla_custom_op.py +++ b/brainstate/event/_xla_custom_op.py @@ -17,14 +17,8 @@ numba_installed = importlib.util.find_spec('numba') is not None -if numba_installed: - import numba # pylint: disable=import-error - from numba import types, carray, cfunc # pylint: disable=import-error - from numba.core.dispatcher import Dispatcher # pylint: disable=import-error -else: - numba = None - __all__ = [ + 'defjvp', 'XLACustomOp', ] @@ -93,9 +87,12 @@ def _numba_mlir_cpu_translation_rule( *ins, **kwargs ): - if numba is None: + if not numba_installed: raise ImportError('Numba is required to compile the CPU kernel for the custom operator.') + from numba import types, carray, cfunc # pylint: disable=import-error + from numba.core.dispatcher import Dispatcher # pylint: disable=import-error + if not isinstance(kernel, Dispatcher): kernel = kernel(**kwargs) assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}' diff --git a/brainstate/optim/_optax_optimizer.py b/brainstate/optim/_optax_optimizer.py index 8575d1b..8c1e8ff 100644 --- a/brainstate/optim/_optax_optimizer.py +++ b/brainstate/optim/_optax_optimizer.py @@ -27,6 +27,7 @@ __all__ = [ 'OptaxOptimizer', + 'LBFGS', ] @@ -133,3 +134,21 @@ def update(self, grads: Dict[Hashable, PyTree]): for k, v in self.param_states.items(): v.value = new_params[k] self.opt_state.value = new_opt_state + + +class LBFGS(OptaxOptimizer): + def __init__( + self, + lr: float, + memory_size: int = 10, + scale_init_precond: bool = True, + ): + import optax # type: ignore[import-not-found,import-untyped] + super().__init__( + optax.lbfgs( + lr, + memory_size=memory_size, + scale_init_precond=scale_init_precond, + linesearch=None, + ) + )