Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update documtation #1

Merged
merged 3 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt

## See also the BDP ecosystem

- [``brainpy``](https://github.com/brainpy/BrainPy): The solution for the general-purpose brain dynamics programming.
- [``brainstate``](https://github.com/brainpy/brainstate): A ``State``-based transformation system for brain dynamics programming.

- [``brainstate``](https://github.com/brainpy/brainstate): The core system for the next generation of BrainPy framework.
- [``brainunit``](https://github.com/brainpy/brainunit): The unit system for brain dynamics programming.

- [``braintools``](https://github.com/brainpy/braintools): The tools for the brain dynamics simulation and analysis.
- [``braintaichi``](https://github.com/brainpy/braintaichi): Leveraging Taichi Lang to customize brain dynamics operators.

- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning for biological spiking neural networks.
- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning framework for biological neural networks.

- [``braintools``](https://github.com/brainpy/braintools): The toolbox for the brain dynamics simulation, training and analysis.


3 changes: 2 additions & 1 deletion brainstate/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
__all__ = [
'set', 'context', 'get', 'all',
'set_host_device_count', 'set_platform',
'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
'get_host_device_count', 'get_platform',
'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
'tolerance',
'dftype', 'ditype', 'dutype', 'dctype',
]
Expand Down
2 changes: 1 addition & 1 deletion brainstate/functional/_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

from brainstate import math, random
from .. import math, random

__all__ = [
"tanh",
Expand Down
3 changes: 3 additions & 0 deletions brainstate/functional/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import jax
import jax.numpy as jnp

from .._utils import set_module_as

__all__ = [
'weight_standardization',
]


@set_module_as('brainstate.functional')
def weight_standardization(
w: jax.typing.ArrayLike,
eps: float = 1e-4,
Expand Down
13 changes: 13 additions & 0 deletions brainstate/optim/_lr_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,16 @@ def test1(self):
self.assertTrue(jnp.allclose(r, 0.001))
else:
self.assertTrue(jnp.allclose(r, 0.0001))

def test2(self):
lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
for i in range(40):
r = lr(i)
if i < 10:
self.assertEqual(r, 0.1)
elif i < 20:
self.assertTrue(jnp.allclose(r, 0.01))
elif i < 30:
self.assertTrue(jnp.allclose(r, 0.001))
else:
self.assertTrue(jnp.allclose(r, 0.0001))
68 changes: 47 additions & 21 deletions brainstate/transform/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from jax._src import sharding_impls
from jax.lib import xla_client as xc

from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
from brainstate._utils import set_module_as
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values

__all__ = ['jit']

Expand All @@ -33,10 +33,13 @@ class JittedFunction(Callable):
"""
A wrapped version of ``fun``, set up for just-in-time compilation.
"""
origin_fun: Callable # the original function
origin_fun: Callable # the original function
stateful_fun: StatefulFunction # the stateful function for extracting states
jitted_fun: jax.stages.Wrapped # the jitted function
clear_cache: Callable # clear the cache of the jitted function
clear_cache: Callable # clear the cache of the jitted function

def __call__(self, *args, **kwargs):
pass


def _get_jitted_fun(
Expand Down Expand Up @@ -85,12 +88,16 @@ def clear_cache():
jit_fun.clear_cache()

jitted_fun: JittedFunction

# the original function
jitted_fun.origin_fun = fun.fun

# the stateful function for extracting states
jitted_fun.stateful_fun = fun

# the jitted function
jitted_fun.jitted_fun = jit_fun

# clear cache
jitted_fun.clear_cache = clear_cache

Expand All @@ -99,18 +106,18 @@ def clear_cache():

@set_module_as('brainstate.transform')
def jit(
fun: Callable = None,
in_shardings=sharding_impls.UNSPECIFIED,
out_shardings=sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
donate_argnums: int | Sequence[int] | None = None,
donate_argnames: str | Iterable[str] | None = None,
keep_unused: bool = False,
device: xc.Device | None = None,
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
**kwargs
fun: Callable = None,
in_shardings=sharding_impls.UNSPECIFIED,
out_shardings=sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
donate_argnums: int | Sequence[int] | None = None,
donate_argnames: str | Iterable[str] | None = None,
keep_unused: bool = False,
device: xc.Device | None = None,
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
**kwargs
) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
"""
Sets up ``fun`` for just-in-time compilation with XLA.
Expand Down Expand Up @@ -228,12 +235,31 @@ def jit(

if fun is None:
def wrapper(fun_again: Callable) -> JittedFunction:
return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums,
donate_argnums, donate_argnames, keep_unused,
device, backend, inline, abstracted_axes, **kwargs)
return _get_jitted_fun(fun_again,
in_shardings,
out_shardings,
static_argnums,
donate_argnums,
donate_argnames,
keep_unused,
device,
backend,
inline,
abstracted_axes,
**kwargs)

return wrapper

else:
return _get_jitted_fun(fun, in_shardings, out_shardings, static_argnums,
donate_argnums, donate_argnames, keep_unused,
device, backend, inline, abstracted_axes, **kwargs)
return _get_jitted_fun(fun,
in_shardings,
out_shardings,
static_argnums,
donate_argnums,
donate_argnames,
keep_unused,
device,
backend,
inline,
abstracted_axes,
**kwargs)
Binary file removed docs/_static/braincore.jpg
Binary file not shown.
6 changes: 5 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ API Documentation
:maxdepth: 1

apis/changelog.md
apis/braincore.rst
apis/brainstate.rst
apis/init.rst
apis/math.rst
apis/mixin.rst
apis/optim.rst
apis/random.rst
apis/surrogate.rst
apis/transform.rst
apis/util.rst
apis/brainstate.nn.rst
apis/brainstate.functional.rst
76 changes: 76 additions & 0 deletions docs/apis/brainstate.functional.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
``brainstate.functional`` module
================================

.. currentmodule:: brainstate.functional
.. automodule:: brainstate.functional

Activation Functions
--------------------

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

tanh
relu
squareplus
softplus
soft_sign
sigmoid
silu
swish
log_sigmoid
elu
leaky_relu
hard_tanh
celu
selu
gelu
glu
logsumexp
log_softmax
softmax
standardize
one_hot
relu6
hard_sigmoid
hard_silu
hard_swish
hard_shrink
rrelu
mish
soft_shrink
prelu
tanh_shrink
softmin


Normalization Functions
-----------------------

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

weight_standardization


Spiking Operations
------------------

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

spike_bitwise_or
spike_bitwise_and
spike_bitwise_iand
spike_bitwise_not
spike_bitwise_xor
spike_bitwise_ixor
spike_bitwise


Loading
Loading