Commit

update file names
chaoming0625 committed Nov 22, 2024
1 parent 78fc1a9 commit 7a8f8fe
Showing 13 changed files with 40 additions and 21 deletions.
15 changes: 9 additions & 6 deletions brainstate/augment/_autograd.py
@@ -45,7 +45,7 @@
 from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
 
 __all__ = [
-    'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
+    'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
 ]
 
 A = TypeVar('A')
@@ -159,6 +159,9 @@ def jacfun(*args, **kwargs):
     return jacfun
 
 
+TransformFn = Callable
+
+
 class GradientTransform(PrettyRepr):
     """
     Automatic Differentiation Transformations for the ``State`` system.
@@ -168,11 +171,11 @@ class GradientTransform(PrettyRepr):
     def __init__(
         self,
         target: Callable,
-        transform: Callable,
-        grad_states: Any,
-        argnums: Optional[Union[int, Sequence[int]]],
-        return_value: bool,
-        has_aux: bool,
+        transform: TransformFn,
+        grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+        argnums: Optional[Union[int, Sequence[int]]] = None,
+        return_value: bool = False,
+        has_aux: bool = False,
         transform_params: Optional[Dict[str, Any]] = None,
     ):
         # gradient variables
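
Note: ``GradientTransform`` is now re-exported from ``brainstate.augment`` and every constructor argument after ``transform`` has a default. A minimal sketch of constructing it directly; the call pattern, ``bst.ParamState``, and the dict-keyed gradient output are assumptions inferred from the signature above, not part of this diff:

# Hypothetical sketch only: direct use of the now-public GradientTransform.
import jax
import jax.numpy as jnp
import brainstate as bst

w = bst.ParamState(jnp.ones(3))

def loss():
    return jnp.sum(w.value ** 2)

# Wrap the target with a plain JAX transform; grad_states selects which
# State objects receive gradients (all other arguments keep their defaults).
tr = bst.augment.GradientTransform(
    target=loss,
    transform=jax.grad,
    grad_states={'w': w},
)
grads = tr()  # assumed to return gradients keyed like grad_states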
12 changes: 6 additions & 6 deletions brainstate/event/__init__.py
@@ -14,14 +14,14 @@
 # ==============================================================================
 
 
-from ._csr import *
-from ._csr import __all__ as __all_csr
-from ._fixed_probability import *
-from ._fixed_probability import __all__ as __all_fixed_probability
-from ._linear import *
+from ._csr_mv import *
+from ._csr_mv import __all__ as __all_csr
+from ._fixedprob_mv import *
+from ._fixedprob_mv import __all__ as __all_fixed_probability
+from ._linear_mv import *
 from ._xla_custom_op import *
 from ._xla_custom_op import __all__ as __all_xla_custom_op
-from ._linear import __all__ as __all_linear
+from ._linear_mv import __all__ as __all_linear
 
 __all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op
 del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
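
Note: the renamed modules keep the same aggregated ``__all__`` re-export pattern. A generic sketch of that pattern for a two-module package (module names here are illustrative, not brainstate code):

# pkg/__init__.py — illustrative re-export pattern.
from ._alpha import *
from ._alpha import __all__ as _all_alpha
from ._beta import *
from ._beta import __all__ as _all_beta

# The package's public API is the union of its submodules' public names.
__all__ = _all_alpha + _all_beta
del _all_alpha, _all_beta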
8 files renamed without changes.
@@ -20,7 +20,7 @@
 from absl.testing import parameterized
 
 import brainstate as bst
-from brainstate.event._linear import Linear
+from brainstate.event._linear_mv import Linear
 
 
 class TestEventLinear(parameterized.TestCase):
13 changes: 5 additions & 8 deletions brainstate/event/_xla_custom_op.py
@@ -17,14 +17,8 @@
 
 numba_installed = importlib.util.find_spec('numba') is not None
 
-if numba_installed:
-    import numba  # pylint: disable=import-error
-    from numba import types, carray, cfunc  # pylint: disable=import-error
-    from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
-else:
-    numba = None
 
 __all__ = [
     'defjvp',
     'XLACustomOp',
 ]
 
@@ -93,9 +87,12 @@ def _numba_mlir_cpu_translation_rule(
     *ins,
     **kwargs
 ):
-    if numba is None:
+    if not numba_installed:
         raise ImportError('Numba is required to compile the CPU kernel for the custom operator.')
 
+    from numba import types, carray, cfunc  # pylint: disable=import-error
+    from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
+
     if not isinstance(kernel, Dispatcher):
         kernel = kernel(**kwargs)
     assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}'
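
Note: the hunk above drops the module-level Numba imports in favor of a single ``find_spec`` check plus an import inside the translation rule that actually needs it. A generic sketch of this deferred optional-dependency pattern (``double_with_numba`` is a made-up example function, not part of brainstate):

# Generic sketch of the deferred optional-dependency pattern adopted above.
import importlib.util

numba_installed = importlib.util.find_spec('numba') is not None

def double_with_numba(x):
    if not numba_installed:
        raise ImportError('Numba is required for this code path.')
    # Import only when this code path runs, so importing the package itself
    # never requires the optional dependency.
    from numba import njit
    return njit(lambda v: v * 2)(x)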
19 changes: 19 additions & 0 deletions brainstate/optim/_optax_optimizer.py
@@ -27,6 +27,7 @@
 
 __all__ = [
     'OptaxOptimizer',
+    'LBFGS',
 ]
 
 
@@ -133,3 +134,21 @@ def update(self, grads: Dict[Hashable, PyTree]):
         for k, v in self.param_states.items():
             v.value = new_params[k]
         self.opt_state.value = new_opt_state
+
+
+class LBFGS(OptaxOptimizer):
+    def __init__(
+        self,
+        lr: float,
+        memory_size: int = 10,
+        scale_init_precond: bool = True,
+    ):
+        import optax  # type: ignore[import-not-found,import-untyped]
+        super().__init__(
+            optax.lbfgs(
+                lr,
+                memory_size=memory_size,
+                scale_init_precond=scale_init_precond,
+                linesearch=None,
+            )
+        )
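
Note: the new ``LBFGS`` wrapper simply hands ``optax.lbfgs`` (line search disabled, fixed step size) to ``OptaxOptimizer``. A plain-optax sketch of the transformation it constructs, shown without the brainstate wrapper because the ``OptaxOptimizer`` registration API is not part of this diff; the loss and parameters are illustrative only:

# Plain-optax sketch of what the LBFGS wrapper above builds.
import jax
import jax.numpy as jnp
import optax

def loss(params):
    return jnp.sum((params['w'] - 3.0) ** 2)

params = {'w': jnp.zeros(4)}
opt = optax.lbfgs(1e-1, memory_size=10, scale_init_precond=True, linesearch=None)
opt_state = opt.init(params)

for _ in range(20):
    grads = jax.grad(loss)(params)
    # scale_by_lbfgs needs the current params to update its curvature memory.
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)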
