
Commit

Update documentation (#1)
* updates

* update documentation

* update readme
chaoming0625 authored Jun 9, 2024
1 parent ab7b5c2 commit 2a1164d
Showing 21 changed files with 533 additions and 92 deletions.
10 changes: 6 additions & 4 deletions README.md
@@ -37,12 +37,14 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt

## See also the BDP ecosystem

- [``brainpy``](https://github.com/brainpy/BrainPy): The solution for the general-purpose brain dynamics programming.
-- [``brainstate``](https://github.com/brainpy/brainstate): A ``State``-based transformation system for brain dynamics programming.
+- [``brainstate``](https://github.com/brainpy/brainstate): The core system for the next generation of BrainPy framework.

- [``brainunit``](https://github.com/brainpy/brainunit): The unit system for brain dynamics programming.

-- [``braintools``](https://github.com/brainpy/braintools): The tools for the brain dynamics simulation and analysis.
+- [``braintaichi``](https://github.com/brainpy/braintaichi): Leveraging Taichi Lang to customize brain dynamics operators.

-- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning for biological spiking neural networks.
+- [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning framework for biological neural networks.

+- [``braintools``](https://github.com/brainpy/braintools): The toolbox for the brain dynamics simulation, training and analysis.


3 changes: 2 additions & 1 deletion brainstate/environ.py
@@ -18,7 +18,8 @@
__all__ = [
    'set', 'context', 'get', 'all',
    'set_host_device_count', 'set_platform',
-   'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
+   'get_host_device_count', 'get_platform',
+   'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
    'tolerance',
    'dftype', 'ditype', 'dutype', 'dctype',
]
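For orientation, a minimal usage sketch of the ``brainstate.environ`` settings API exported above; the exact call signatures are assumptions inferred from the names in ``__all__``, not confirmed by this diff:

    import brainstate as bst

    bst.environ.set(dt=0.1)             # assumed: register a global time step
    with bst.environ.context(dt=0.05):  # assumed: temporarily override settings
        print(bst.environ.get_dt())     # 0.05 inside the context
    print(bst.environ.get_platform())   # e.g. 'cpu'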
2 changes: 1 addition & 1 deletion brainstate/functional/_activations.py
@@ -27,7 +27,7 @@
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

-from brainstate import math, random
+from .. import math, random

__all__ = [
    "tanh",
3 changes: 3 additions & 0 deletions brainstate/functional/_normalization.py
@@ -20,11 +20,14 @@
import jax
import jax.numpy as jnp

+from .._utils import set_module_as
+
__all__ = [
    'weight_standardization',
]


+@set_module_as('brainstate.functional')
def weight_standardization(
    w: jax.typing.ArrayLike,
    eps: float = 1e-4,
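A minimal call sketch for the now-namespaced function; only ``w`` and ``eps`` are visible in this diff, and any further parameters are cut off by the fold:

    import jax.numpy as jnp
    import brainstate as bst

    w = jnp.ones((3, 3, 16))  # illustrative kernel shape, not from the diff
    w_std = bst.functional.weight_standardization(w, eps=1e-4)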
13 changes: 13 additions & 0 deletions brainstate/optim/_lr_scheduler_test.py
@@ -34,3 +34,16 @@ def test1(self):
        self.assertTrue(jnp.allclose(r, 0.001))
      else:
        self.assertTrue(jnp.allclose(r, 0.0001))
+
+  def test2(self):
+    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+    for i in range(40):
+      r = lr(i)
+      if i < 10:
+        self.assertEqual(r, 0.1)
+      elif i < 20:
+        self.assertTrue(jnp.allclose(r, 0.01))
+      elif i < 30:
+        self.assertTrue(jnp.allclose(r, 0.001))
+      else:
+        self.assertTrue(jnp.allclose(r, 0.0001))
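Outside the test harness, the same schedule reads as a plain sketch (constructor and expected values taken directly from the assertions above):

    import brainstate as bst

    lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)  # decay 10x at steps 10, 20, 30
    print(lr(5), lr(15), lr(25), lr(35))  # 0.1, 0.01, 0.001, 0.0001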
68 changes: 47 additions & 21 deletions brainstate/transform/_jit.py
@@ -23,8 +23,8 @@
from jax._src import sharding_impls
from jax.lib import xla_client as xc

-from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
from brainstate._utils import set_module_as
+from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values

__all__ = ['jit']

@@ -33,10 +33,13 @@ class JittedFunction(Callable):
"""
A wrapped version of ``fun``, set up for just-in-time compilation.
"""
origin_fun: Callable # the original function
origin_fun: Callable # the original function
stateful_fun: StatefulFunction # the stateful function for extracting states
jitted_fun: jax.stages.Wrapped # the jitted function
clear_cache: Callable # clear the cache of the jitted function
clear_cache: Callable # clear the cache of the jitted function

def __call__(self, *args, **kwargs):
pass


def _get_jitted_fun(
@@ -85,12 +88,16 @@ def clear_cache():
    jit_fun.clear_cache()

  jitted_fun: JittedFunction
+
  # the original function
  jitted_fun.origin_fun = fun.fun
+
  # the stateful function for extracting states
  jitted_fun.stateful_fun = fun
+
  # the jitted function
  jitted_fun.jitted_fun = jit_fun
+
  # clear cache
  jitted_fun.clear_cache = clear_cache

@@ -99,18 +106,18 @@

@set_module_as('brainstate.transform')
def jit(
    fun: Callable = None,
    in_shardings=sharding_impls.UNSPECIFIED,
    out_shardings=sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    abstracted_axes: Any | None = None,
    **kwargs
) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
"""
Sets up ``fun`` for just-in-time compilation with XLA.
@@ -228,12 +235,31 @@ def jit(

  if fun is None:
    def wrapper(fun_again: Callable) -> JittedFunction:
-     return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums,
-                            donate_argnums, donate_argnames, keep_unused,
-                            device, backend, inline, abstracted_axes, **kwargs)
+     return _get_jitted_fun(fun_again,
+                            in_shardings,
+                            out_shardings,
+                            static_argnums,
+                            donate_argnums,
+                            donate_argnames,
+                            keep_unused,
+                            device,
+                            backend,
+                            inline,
+                            abstracted_axes,
+                            **kwargs)

    return wrapper

  else:
-   return _get_jitted_fun(fun, in_shardings, out_shardings, static_argnums,
-                          donate_argnums, donate_argnames, keep_unused,
-                          device, backend, inline, abstracted_axes, **kwargs)
+   return _get_jitted_fun(fun,
+                          in_shardings,
+                          out_shardings,
+                          static_argnums,
+                          donate_argnums,
+                          donate_argnames,
+                          keep_unused,
+                          device,
+                          backend,
+                          inline,
+                          abstracted_axes,
+                          **kwargs)
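To make the new structure concrete, a minimal usage sketch of ``brainstate.transform.jit``; the ``double`` function is hypothetical, while the decorator form and the ``clear_cache`` attribute follow from the diff above:

    import brainstate as bst

    @bst.transform.jit
    def double(x):
        return x * 2.0

    y = double(1.0)       # compiled on first call, as with jax.jit
    double.clear_cache()  # drops this function's compilation cache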
Binary file removed docs/_static/braincore.jpg
6 changes: 5 additions & 1 deletion docs/api.rst
@@ -5,10 +5,14 @@ API Documentation
:maxdepth: 1

   apis/changelog.md
-  apis/braincore.rst
+  apis/brainstate.rst
   apis/init.rst
   apis/math.rst
   apis/mixin.rst
   apis/optim.rst
   apis/random.rst
   apis/surrogate.rst
   apis/transform.rst
   apis/util.rst
+  apis/brainstate.nn.rst
+  apis/brainstate.functional.rst
76 changes: 76 additions & 0 deletions docs/apis/brainstate.functional.rst
@@ -0,0 +1,76 @@
``brainstate.functional`` module
================================

.. currentmodule:: brainstate.functional
.. automodule:: brainstate.functional

Activation Functions
--------------------

.. autosummary::
   :toctree: generated/
   :nosignatures:
   :template: classtemplate.rst

   tanh
   relu
   squareplus
   softplus
   soft_sign
   sigmoid
   silu
   swish
   log_sigmoid
   elu
   leaky_relu
   hard_tanh
   celu
   selu
   gelu
   glu
   logsumexp
   log_softmax
   softmax
   standardize
   one_hot
   relu6
   hard_sigmoid
   hard_silu
   hard_swish
   hard_shrink
   rrelu
   mish
   soft_shrink
   prelu
   tanh_shrink
   softmin


Normalization Functions
-----------------------

.. autosummary::
   :toctree: generated/
   :nosignatures:
   :template: classtemplate.rst

   weight_standardization


Spiking Operations
------------------

.. autosummary::
   :toctree: generated/
   :nosignatures:
   :template: classtemplate.rst

   spike_bitwise_or
   spike_bitwise_and
   spike_bitwise_iand
   spike_bitwise_not
   spike_bitwise_xor
   spike_bitwise_ixor
   spike_bitwise

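As a reading aid, a hypothetical sketch of how the spike operators listed above might be called; the signatures and the {0, 1} spike encoding are assumptions from the names alone, since this page only documents the names:

    import jax.numpy as jnp
    import brainstate as bst

    a = jnp.array([0., 1., 1., 0.])  # assumed binary spike vectors
    b = jnp.array([0., 0., 1., 1.])
    c = bst.functional.spike_bitwise_or(a, b)  # assumed: elementwise OR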
