diff --git a/brainstate/environ.py b/brainstate/environ.py
index ff90cdb..2ffc7d9 100644
--- a/brainstate/environ.py
+++ b/brainstate/environ.py
@@ -31,7 +31,7 @@
 from jax.typing import DTypeLike
 
 from .mixin import Mode
-from .util import MemScaling, IdMemScaling
+from .util import MemScaling
 
 __all__ = [
     # functions for environment settings
@@ -555,4 +555,5 @@ def register_default_behavior(key: str, behavior: Callable, replace_if_exist: bo
     DFAULT.functions[key] = behavior
 
 
-set(dt=0.1, precision=32, mode=Mode(), mem_scaling=IdMemScaling())
+set(precision=32)
+
diff --git a/brainstate/environ_test.py b/brainstate/environ_test.py
index ebda7c4..54f1cc0 100644
--- a/brainstate/environ_test.py
+++ b/brainstate/environ_test.py
@@ -42,6 +42,8 @@ def test_platform(self):
         self.assertEqual(a.device(), 'cpu')
 
     def test_register_default_behavior(self):
+        bst.environ.set(dt=0.1)
+
         dt_ = 0.1
 
         def dt_behavior(dt):
diff --git a/brainstate/nn/_collective_ops.py b/brainstate/nn/_collective_ops.py
index b2f9bcc..3a18443 100644
--- a/brainstate/nn/_collective_ops.py
+++ b/brainstate/nn/_collective_ops.py
@@ -118,7 +118,7 @@ def reset_all_states(target: Module, *args, **kwargs) -> Module:
     nodes_with_order = []
 
     # reset node whose `init_state` has no `call_order`
-    for path, node in nodes(target).items():
+    for path, node in nodes(target).filter(Module).items():
         if hasattr(node.reset_state, 'call_order'):
             nodes_with_order.append(node)
         else:
diff --git a/brainstate/nn/_dyn_impl/_dynamics_neuron_test.py b/brainstate/nn/_dyn_impl/_dynamics_neuron_test.py
index 7cdbdf6..46a732d 100644
--- a/brainstate/nn/_dyn_impl/_dynamics_neuron_test.py
+++ b/brainstate/nn/_dyn_impl/_dynamics_neuron_test.py
@@ -157,4 +157,5 @@ def test_keep_size(self):
 
 
 if __name__ == '__main__':
-    unittest.main()
+    with bst.environ.context(dt=0.1):
+        unittest.main()
diff --git a/brainstate/nn/_dyn_impl/_dynamics_synapse_test.py b/brainstate/nn/_dyn_impl/_dynamics_synapse_test.py
index dc1498f..8264791 100644
--- a/brainstate/nn/_dyn_impl/_dynamics_synapse_test.py
+++ b/brainstate/nn/_dyn_impl/_dynamics_synapse_test.py
@@ -127,4 +127,5 @@ def test_keep_size(self):
 
 
 if __name__ == '__main__':
-    unittest.main()
+    with bst.environ.context(dt=0.1):
+        unittest.main()
diff --git a/brainstate/nn/_dyn_impl/_readout_test.py b/brainstate/nn/_dyn_impl/_readout_test.py
index daf63cf..4faa1f7 100644
--- a/brainstate/nn/_dyn_impl/_readout_test.py
+++ b/brainstate/nn/_dyn_impl/_readout_test.py
@@ -45,4 +45,5 @@ def test_LeakySpikeReadout(self):
 
 
 if __name__ == '__main__':
-    unittest.main()
+    with bst.environ.context(dt=0.1):
+        unittest.main()
diff --git a/brainstate/nn/_exp_euler_test.py b/brainstate/nn/_exp_euler_test.py
index b53c2eb..6634be2 100644
--- a/brainstate/nn/_exp_euler_test.py
+++ b/brainstate/nn/_exp_euler_test.py
@@ -26,8 +26,9 @@ def test1(self):
         def fun(x, tau):
             return -x / tau
 
-        with self.assertRaises(AssertionError):
-            r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
+        with bst.environ.context(dt=0.1):
+            with self.assertRaises(AssertionError):
+                r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
 
         with bst.environ.context(dt=1. * u.ms):
             r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
diff --git a/brainstate/nn/_interaction/_connections.py b/brainstate/nn/_interaction/_connections.py
index 4b3ef3d..f053e45 100644
--- a/brainstate/nn/_interaction/_connections.py
+++ b/brainstate/nn/_interaction/_connections.py
@@ -116,7 +116,7 @@ def update(self, x):
         weight = params['weight']
         if self.w_mask is not None:
             weight = weight * self.w_mask
-        y = jnp.dot(x, weight)
+        y = u.math.dot(x, weight)
         if 'bias' in params:
             y = y + params['bias']
         return y
diff --git a/brainstate/nn/_module_test.py b/brainstate/nn/_module_test.py
index 832eaa3..37e4c5d 100644
--- a/brainstate/nn/_module_test.py
+++ b/brainstate/nn/_module_test.py
@@ -202,3 +202,9 @@ def __init__(self):
         print(b.states())
         print(b.states(level=0))
         print(b.states(level=0))
+
+
+if __name__ == '__main__':
+    with bst.environ.context(dt=0.1):
+        unittest.main()
+
diff --git a/brainstate/optim/_sgd_optimizer.py b/brainstate/optim/_sgd_optimizer.py
index aa9a9bc..5b581d4 100644
--- a/brainstate/optim/_sgd_optimizer.py
+++ b/brainstate/optim/_sgd_optimizer.py
@@ -252,13 +252,10 @@ def __init__(
         momentum: float = 0.9,
         weight_decay: Optional[float] = None,
     ):
-        super(Momentum, self).__init__(lr=lr, weight_decay=weight_decay)
+        super().__init__(lr=lr, weight_decay=weight_decay)
         self.momentum = fcast(momentum)
         self.momentum_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", momentum={self.momentum}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -318,14 +315,11 @@ def __init__(
         weight_decay: Optional[float] = None,
         momentum: float = 0.9,
     ):
-        super(MomentumNesterov, self).__init__(lr=lr, weight_decay=weight_decay)
+        super().__init__(lr=lr, weight_decay=weight_decay)
 
         self.momentum = fcast(momentum)
         self.momentum_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", momentum={self.momentum}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -390,9 +384,6 @@ def __init__(
         self.epsilon = fcast(epsilon)
         self.cache_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", epsilon={self.epsilon}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -472,9 +463,6 @@ def __init__(
         self.cache_states = StateDictManager()
         self.delta_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", epsilon={self.epsilon}, rho={self.rho}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -539,15 +527,12 @@ def __init__(
         epsilon: float = 1e-6,
         rho: float = 0.9,
     ):
-        super(RMSProp, self).__init__(lr=lr, weight_decay=weight_decay)
+        super().__init__(lr=lr, weight_decay=weight_decay)
 
         self.epsilon = fcast(epsilon)
         self.rho = fcast(rho)
         self.cache_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", epsilon={self.epsilon}, rho={self.rho}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -604,7 +589,7 @@ def __init__(
         eps: float = 1e-8,
         weight_decay: Optional[float] = None,
     ):
-        super(Adam, self).__init__(lr=lr, weight_decay=weight_decay)
+        super().__init__(lr=lr, weight_decay=weight_decay)
 
         self.beta1 = fcast(beta1)
         self.beta2 = fcast(beta2)
@@ -612,9 +597,6 @@ def __init__(
         self.m1_states = StateDictManager()
         self.m2_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -685,7 +667,7 @@ def __init__(
         tc: float = 1e-3,
         eps: float = 1e-5,
     ):
-        super(LARS, self).__init__(lr=lr, weight_decay=weight_decay)
+        super().__init__(lr=lr, weight_decay=weight_decay)
         assert self.weight_decay is None, 'LARS does not support weight decay.'
 
         self.momentum = fcast(momentum)
@@ -693,9 +675,6 @@ def __init__(
         self.eps = fcast(eps)
         self.momentum_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", momentum={self.momentum}, tc={self.tc}, eps={self.eps}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -777,7 +756,7 @@ def __init__(
         weight_decay: float = 0.02,
         no_prox: bool = False,
     ):
-        super(Adan, self).__init__(lr=lr, weight_decay=weight_decay)
+        super().__init__(lr=lr, weight_decay=weight_decay)
 
         assert len(betas) == 3
         if eps < 0.:
@@ -797,9 +776,6 @@ def __init__(
         self.exp_avg_diff_states = StateDictManager()
         self.pre_grad_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", betas={self.betas}, eps={self.eps}, weight_decay={self.weight_decay}, no_prox={self.no_prox}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -925,7 +901,7 @@ def __init__(
         weight_decay: float = 1e-2,
         amsgrad: bool = False,
     ):
-        super(AdamW, self).__init__(lr=lr, weight_decay=weight_decay)
+        super().__init__(lr=lr, weight_decay=weight_decay)
 
         if eps < 0.:
             raise ValueError("Invalid epsilon value: {}".format(eps))
@@ -945,10 +921,6 @@ def __init__(
         if self.amsgrad:
             self.vmax_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return (f", beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}"
-                f", weight_decay={self.weight_decay}, amsgrad={self.amsgrad}")
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
@@ -1043,7 +1015,7 @@ def __init__(
         eps: float = 1e-30,
         weight_decay: Optional[float] = None,
     ):
-        super(SM3, self).__init__(lr=lr, weight_decay=weight_decay)
+        super().__init__(lr=lr, weight_decay=weight_decay)
 
         if not 0.0 <= momentum < 1.0:
             raise ValueError("Invalid momentum: {0}".format(momentum))
@@ -1057,9 +1029,6 @@ def __init__(
         self.momentum = fcast(momentum)
         self.memory_states = StateDictManager()
 
-    def extra_repr(self) -> str:
-        return f", beta={self.beta}, momentum={self.momentum}, eps={self.eps}"
-
     def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
         train_states = dict() if train_states is None else train_states
         assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
diff --git a/docs/index.rst b/docs/index.rst
index 36993b6..b4afc33 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -37,7 +37,7 @@ Features
 
    .. div:: sd-font-normal
 
-      ``BrainState`` enables `event-driven computation <./apis/event.rst>`__ for spiking neural networks,
+      ``BrainState`` enables `event-driven computation <./apis/event.html>`__ for spiking neural networks,
       and thus obtains unprecedented performance on CPU and GPU devices.
 
 
@@ -52,7 +52,7 @@ Features
 
    .. div:: sd-font-normal
 
-      ``BrainState`` supports `program compilation <./apis/compile.rst>`__ (such as just-in-time compilation) with its `state-based <./apis/brainstate.rst>`__ IR construction.
+      ``BrainState`` supports `program compilation <./apis/compile.html>`__ (such as just-in-time compilation) with its `state-based <./apis/brainstate.html>`__ IR construction.
 
 
@@ -66,7 +66,7 @@ Features
 
    .. div:: sd-font-normal
 
-      ``BrainState`` supports program `functionality augmentation <./apis/augment.rst>`__ (such batching) with its `graph-based <./apis/graph.rst>`__ Python objects.
+      ``BrainState`` supports program `functionality augmentation <./apis/augment.html>`__ (such as batching) with its `graph-based <./apis/graph.html>`__ Python objects.
 
 
diff --git a/examples/004_scan_over_layers.py b/examples/004_scan_over_layers.py
index bae1ab3..9708fdd 100644
--- a/examples/004_scan_over_layers.py
+++ b/examples/004_scan_over_layers.py
@@ -60,10 +60,12 @@ def loop_fn(block_tree):
             # Feed the output of the previous layer to the next layer
             block: Block = bst.graph.treefy_merge(graphdef, block_tree)
             activation.value = block(activation.value)
+            return bst.graph.treefy_split(block)[1]
 
         # Loop over each layer in the block tree
         graphdef, statetree = bst.graph.treefy_split(self.layers)
-        bst.compile.for_loop(loop_fn, statetree)
+        block_trees = bst.compile.for_loop(loop_fn, statetree)
+        bst.graph.update_states(self.layers, block_trees)
 
         return activation.value
diff --git a/examples/200_surrogate_grad_lif.py b/examples/200_surrogate_grad_lif.py
new file mode 100644
index 0000000..4ff1642
--- /dev/null
+++ b/examples/200_surrogate_grad_lif.py
@@ -0,0 +1,159 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""
+Reproduce the results of the ``spytorch`` tutorial 1:
+
+- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial1.ipynb
+
+"""
+
+import time
+
+import braintools as bts
+import brainunit as u
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+
+import brainstate as bst
+
+
+class SNN(bst.nn.DynamicsGroup):
+    def __init__(self, num_in, num_rec, num_out):
+        super(SNN, self).__init__()
+
+        # parameters
+        self.num_in = num_in
+        self.num_rec = num_rec
+        self.num_out = num_out
+
+        # synapse: i->r
+        self.i2r = bst.nn.Sequential(
+            bst.nn.Linear(
+                num_in, num_rec,
+                w_init=bst.init.KaimingNormal(scale=2., unit=u.mA),
+                b_init=bst.init.ZeroInit(unit=u.mA)
+            ),
+            bst.nn.Expon(num_rec, tau=10. * u.ms, g_initializer=bst.init.Constant(0. * u.mA))
+        )
+        # recurrent: r
+        self.r = bst.nn.LIF(
+            num_rec, tau=20 * u.ms, V_reset=0 * u.mV,
+            V_rest=0 * u.mV, V_th=1. * u.mV,
+            spk_fun=bst.surrogate.ReluGrad()
+        )
+        # synapse: r->o
+        self.r2o = bst.nn.Linear(num_rec, num_out, w_init=bst.init.KaimingNormal())
+        # # output: o
+        self.o = bst.nn.Expon(num_out, tau=10. * u.ms, g_initializer=bst.init.Constant(0.))
+
+    def update(self, spike):
+        return self.o(self.r2o(self.r(self.i2r(spike))))
+
+    def predict(self, spike):
+        rec_spikes = self.r(self.i2r(spike))
+        out = self.o(self.r2o(rec_spikes))
+        return self.r.V.value, rec_spikes, out
+
+
+def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5, show=True):
+    fig, gs = bts.visualize.get_figure(*dim, 3, 3)
+    if spk is not None:
+        mem[spk > 0.0] = spike_height
+    if isinstance(mem, u.Quantity):
+        mem = mem.to_decimal(u.mV)
+    for i in range(np.prod(dim)):
+        if i == 0:
+            a0 = ax = plt.subplot(gs[i])
+        else:
+            ax = plt.subplot(gs[i], sharey=a0)
+        ax.plot(mem[:, i])
+    if show:
+        plt.show()
+
+
+def print_classification_accuracy(output, target):
+    """ Dirty little helper function to compute classification accuracy. """
+    m = u.math.max(output, axis=0)  # max over time
+    am = u.math.argmax(m, axis=1)  # argmax over output units
+    acc = u.math.mean(target == am)  # compare to labels
+    print("Accuracy %.3f" % acc)
+
+
+def predict_and_visualize_net_activity(net):
+    bst.nn.init_all_states(net, batch_size=num_sample)
+    vs, spikes, outs = bst.compile.for_loop(net.predict, x_data, pbar=bst.compile.ProgressBar(10))
+    plot_voltage_traces(vs, spikes, spike_height=5 * u.mV, show=False)
+    plot_voltage_traces(outs)
+    print_classification_accuracy(outs, y_data)
+
+
+with bst.environ.context(dt=1.0 * u.ms):
+    # network
+    net = SNN(100, 20, 2)
+
+    # dataset
+    num_step = 1000
+    num_sample = 256
+    freq = 20 * u.Hz
+    x_data = bst.random.rand(num_step, num_sample, net.num_in) < freq * bst.environ.get_dt()
+    y_data = u.math.asarray(bst.random.rand(num_sample) < 0.5, dtype=int)
+
+    # Before training
+    predict_and_visualize_net_activity(net)
+
+    # brainstate optimizer
+    optimizer = bst.optim.Adam(lr=1e-3)
+    optimizer.register_trainable_weights(net.states(bst.ParamState))
+
+    # # optax optimizer
+    # import optax
+    # optimizer = bst.optim.OptaxOptimizer(net.states(bst.ParamState), optax.adam(1e-3))
+
+
+    def loss_fn():
+        predictions = bst.compile.for_loop(net.update, x_data)
+        predictions = u.math.mean(predictions, axis=0)  # [T, B, C] -> [B, C]
+        return bts.metric.softmax_cross_entropy_with_integer_labels(predictions, y_data).mean()
+
+
+    @bst.compile.jit
+    def train_fn():
+        bst.nn.init_all_states(net, batch_size=num_sample)
+        grads, l = bst.augment.grad(loss_fn, net.states(bst.ParamState), return_value=True)()
+        optimizer.update(grads)
+        return l
+
+
+    # train the network
+    train_losses = []
+    t0 = time.time()
+    for i in range(1, 1001):
+        loss = train_fn()
+        train_losses.append(loss)
+        if i % 100 == 0:
+            print(f'Train {i} epoch, loss = {loss:.4f}, used time {time.time() - t0:.4f} s')
+            t0 = time.time()
+
+    # visualize the training losses
+    plt.plot(np.asarray(jnp.asarray(train_losses)))
+    plt.xlabel("Epoch")
+    plt.ylabel("Training Loss")
+    plt.title("Training Loss vs Epoch")
+
+    # predict the output according to the input data
+    predict_and_visualize_net_activity(net)
diff --git a/examples/201_surrogate_grad_lif_fashion_mnist.py b/examples/201_surrogate_grad_lif_fashion_mnist.py
new file mode 100644
index 0000000..1dbc05d
--- /dev/null
+++ b/examples/201_surrogate_grad_lif_fashion_mnist.py
@@ -0,0 +1,219 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Reproduce the results of the ``spytorch`` tutorial 2 & 3:
+
+- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial2.ipynb
+- https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial3.ipynb
+
+"""
+
+import time
+
+import braintools as bts
+import brainunit as u
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+from datasets import load_dataset
+
+import brainstate as bst
+
+dataset = load_dataset("zalando-datasets/fashion_mnist")
+
+# images
+X_train = np.array(np.stack(dataset['train']['image']), dtype=np.uint8)
+X_test = np.array(np.stack(dataset['test']['image']), dtype=np.uint8)
+X_train = (X_train / 255).reshape(-1, 28 * 28).astype(jnp.float32)
+X_test = (X_test / 255).reshape(-1, 28 * 28).astype(jnp.float32)
+print(f'Training image shape: {X_train.shape}, testing image shape: {X_test.shape}')
+# labels
+Y_train = np.array(dataset['train']['label'], dtype=np.int32)
+Y_test = np.array(dataset['test']['label'], dtype=np.int32)
+
+
+class SNN(bst.nn.DynamicsGroup):
+    """
+    This class implements a spiking neural network model with three layers:
+
+        i >> r >> o
+
+    Each two layers are connected through the exponential synapse model.
+    """
+
+    def __init__(self, num_in, num_rec, num_out):
+        super().__init__()
+
+        # parameters
+        self.num_in = num_in
+        self.num_rec = num_rec
+        self.num_out = num_out
+
+        # synapse: i->r
+        self.i2r = bst.nn.Sequential(
+            bst.nn.Linear(num_in, num_rec, w_init=bst.init.KaimingNormal(scale=40.)),
+            bst.nn.Expon(num_rec, tau=10. * u.ms, g_initializer=bst.init.ZeroInit())
+        )
+        # recurrent: r
+        self.r = bst.nn.LIF(num_rec, tau=10 * u.ms, V_reset=0 * u.mV, V_rest=0 * u.mV, V_th=1. * u.mV)
+        # synapse: r->o
+        self.r2o = bst.nn.Sequential(
+            bst.nn.Linear(num_rec, num_out, w_init=bst.init.KaimingNormal(scale=2.)),
+            bst.nn.Expon(num_out, tau=10. * u.ms, g_initializer=bst.init.ZeroInit())
+        )
+
+    def update(self, spikes):
+        r_spikes = self.r(self.i2r(spikes) * u.mA)
+        out = self.r2o(r_spikes)
+        return out, r_spikes
+
+    def predict(self, spikes):
+        r_spikes = self.r(self.i2r(spikes) * u.mA)
+        out = self.r2o(r_spikes)
+        return out, r_spikes, self.r.V.value
+
+
+with bst.environ.context(dt=1.0 * u.ms):
+    # inputs
+    batch_size = 256
+
+    # spiking neural networks
+    net = SNN(num_in=X_train.shape[-1], num_rec=100, num_out=10)
+
+    # encoding inputs as spikes
+    encoder = bts.LatencyEncoder(tau=100 * u.ms)
+
+
+    @bst.compile.jit
+    def predict(xs):
+        bst.nn.init_all_states(net, xs.shape[0])
+        xs = encoder(xs)
+        outs, spikes, vs = bst.compile.for_loop(net.predict, xs)
+        return outs, spikes, vs
+
+
+    def visualize(xs):
+        # visualization function
+        outs, spikes, vs = predict(xs)
+        xs = np.asarray(encoder(xs))
+        vs = np.asarray(vs.to_decimal(u.mV))
+        # vs = np.where(spikes, vs, 5.0)
+        fig, gs = bts.visualize.get_figure(4, 4, 3., 4.)
+        for i in range(4):
+            ax = fig.add_subplot(gs[i, 0])
+            i_indice, n_indices = np.where(xs[:, i])
+            ax.plot(i_indice, n_indices, 'r.', markersize=1)
+            plt.title('Input spikes')
+            ax = fig.add_subplot(gs[i, 1])
+            i_indice, n_indices = np.where(spikes[:, i])
+            ax.plot(i_indice, n_indices, 'r.', markersize=1)
+            plt.title('Recurrent spikes')
+            ax = fig.add_subplot(gs[i, 2])
+            ax.plot(vs[:, i])
+            plt.title('Membrane potential')
+            ax = fig.add_subplot(gs[i, 3])
+            ax.plot(outs[:, i])
+            plt.title('Output')
+        plt.show()
+
+
+    # visualization of the spiking activity
+    visualize(X_test[:4])
+
+    # optimizer
+    optimizer = bst.optim.Adam(lr=1e-3)
+    optimizer.register_trainable_weights(net.states(bst.ParamState))
+
+
+    def loss_fun(xs, ys):
+        # initialize states
+        bst.nn.init_all_states(net, xs.shape[0])
+
+        # encode inputs
+        xs = encoder(xs)
+
+        # predictions
+        outs, r_spikes = bst.compile.for_loop(net.update, xs)
+
+        # Here we set up our regularize loss
+        # The strength parameters here are merely a guess and there should be ample
+        # room for improvement by tuning these parameters.
+        # l1_loss = 1e-5 * u.math.sum(r_spikes)  # L1 loss on total number of spikes
+        # l2_loss = 1e-5 * u.math.mean(u.math.sum(u.math.sum(r_spikes, axis=0), axis=0) ** 2)  # L2 loss on spikes per neuron
+
+        # predictions
+        predicts = u.math.max(outs, axis=0)  # max over time, [T, B, C] -> [B, C]
+        loss = bts.metric.softmax_cross_entropy_with_integer_labels(predicts, ys).mean()
+        correct_n = u.math.sum(ys == u.math.argmax(predicts, axis=1))  # compare to labels
+        # return loss + l2_loss + l1_loss, acc
+        return loss, correct_n
+
+
+    @bst.compile.jit
+    def train_fn(xs, ys):
+        grads, loss, correct_n = bst.augment.grad(loss_fun, net.states(bst.ParamState), has_aux=True, return_value=True)(xs, ys)
+        optimizer.update(grads)
+        return loss, correct_n
+
+
+    n_epoch = 20
+    train_losses, train_accs = [], []
+    indices = np.arange(X_train.shape[0])
+
+    for epoch_i in range(n_epoch):
+        indices = bst.random.shuffle(indices)
+
+        # training phase
+        t0 = time.time()
+        loss, train_acc = [], 0.
+        for i in range(0, X_train.shape[0], batch_size):
+            X = X_train[indices[i: i + batch_size]]
+            Y = Y_train[indices[i: i + batch_size]]
+            l, correct_num = train_fn(X, Y)
+            loss.append(l)
+            train_acc += correct_num
+        train_acc /= X_train.shape[0]
+        train_loss = jnp.mean(jnp.asarray(loss))
+        optimizer.lr.step_epoch()
+
+        # testing phase
+        loss, test_acc = [], 0.
+        for i in range(0, X_test.shape[0], batch_size):
+            X = X_test[i: i + batch_size]
+            Y = Y_test[i: i + batch_size]
+            l, correct_num = loss_fun(X, Y)
+            loss.append(l)
+            test_acc += correct_num
+        test_acc /= X_test.shape[0]
+        test_loss = jnp.mean(jnp.asarray(loss))
+
+        t = (time.time() - t0) / 60
+        print(f"Epoch {epoch_i}: train_loss={train_loss:.3f}, train_acc={train_acc:.3f}, "
+              f"test_loss={test_loss:.3f}, test_acc={test_acc:.3f}, time={t:.2f} min")
+        train_losses.append(train_loss)
+        train_accs.append(train_acc)
+
+    fig, gs = bts.visualize.get_figure(1, 2, 3, 4)
+    fig.add_subplot(gs[0])
+    plt.plot(np.asarray(train_losses))
+    plt.xlabel("Epoch")
+    plt.ylabel("Loss")
+    fig.add_subplot(gs[1])
+    plt.plot(np.asarray(train_accs))
+    plt.xlabel("Epoch")
+    plt.ylabel("Accuracy")
+
+    visualize(X_test[:4])
diff --git a/examples/202_mnist_lif_readout.py b/examples/202_mnist_lif_readout.py
new file mode 100644
index 0000000..3f55c80
--- /dev/null
+++ b/examples/202_mnist_lif_readout.py
@@ -0,0 +1,171 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import argparse
+import time
+
+import braintools as bts
+import brainunit as u
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+from datasets import load_dataset
+
+import brainstate as bst
+
+parser = argparse.ArgumentParser(description='LIF MNIST Training')
+parser.add_argument('-T', default=100, type=int, help='simulating time-steps')
+parser.add_argument('-platform', default='cpu', help='device')
+parser.add_argument('-batch', default=64, type=int, help='batch size')
+parser.add_argument('-epochs', default=15, type=int, metavar='N', help='number of total epochs to run')
+parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
+parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
+parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')
+args = parser.parse_args()
+print(args)
+
+
+class SNN(bst.nn.DynamicsGroup):
+    def __init__(self, tau):
+        super().__init__()
+        self.l1 = bst.nn.Linear(28 * 28, 10, b_init=None, w_init=bst.init.LecunNormal(scale=10., unit=u.mA))
+        self.l2 = bst.nn.LIF(10, V_rest=0. * u.mV, V_reset=0. * u.mV, V_th=1. * u.mV, tau=tau * u.ms)
+
+    def update(self, x):
+        return self.l2(self.l1(x))
+
+    def predict(self, x):
+        spikes = self.l2(self.l1(x))
+        return self.l2.V.value, spikes
+
+
+with bst.environ.context(dt=1.0 * u.ms):
+    net = SNN(args.tau)
+
+    dataset = load_dataset('mnist')
+    # images
+    X_train = np.array(np.stack(dataset['train']['image']), dtype=np.uint8)
+    X_test = np.array(np.stack(dataset['test']['image']), dtype=np.uint8)
+    X_train = (X_train / 255).reshape(-1, 28 * 28).astype(jnp.float32)
+    X_test = (X_test / 255).reshape(-1, 28 * 28).astype(jnp.float32)
+    # labels
+    Y_train = np.array(dataset['train']['label'], dtype=np.int32)
+    Y_test = np.array(dataset['test']['label'], dtype=np.int32)
+
+
+    @bst.compile.jit
+    def predict(xs):
+        bst.nn.init_all_states(net, xs.shape[0])
+        xs = (xs + 0.02)
+        xs = bst.random.rand(args.T, *xs.shape) < xs
+        vs, outs = bst.compile.for_loop(net.predict, xs)
+        return vs, outs
+
+
+    def visualize(xs):
+        vs, outs = predict(xs)
+        vs = np.asarray(vs.to_decimal(u.mV))
+        fig, gs = bts.visualize.get_figure(4, 2, 3., 6.)
+        for i in range(4):
+            ax = fig.add_subplot(gs[i, 0])
+            i_indice, n_indices = np.where(outs[:, i])
+            ax.plot(i_indice, n_indices, 'r.', markersize=1)
+            ax.set_xlim([0, args.T])
+            ax.set_ylim([0, net.l2.num])
+            ax = fig.add_subplot(gs[i, 1])
+            ax.plot(vs[:, i])
+            ax.set_xlim([0, args.T])
+        plt.show()
+
+
+    # visualization of the spiking activity
+    visualize(X_test[:4])
+
+
+    @bst.compile.jit
+    def loss_fun(xs, ys):
+        # initialize states
+        bst.nn.init_all_states(net, xs.shape[0])
+
+        # encoding inputs as spikes
+        xs = bst.random.rand(args.T, *xs.shape) < xs
+
+        # shared arguments for looping over time
+        outs = bst.compile.for_loop(net.update, xs)
+        out_fr = u.math.mean(outs, axis=0)  # [T, B, C] -> [B, C]
+        ys_onehot = bst.functional.one_hot(ys, 10, dtype=float)
+        l = bts.metric.squared_error(out_fr, ys_onehot).mean()
+        n = u.math.sum(out_fr.argmax(1) == ys)
+        return l, n
+
+
+    # gradient function
+    grad_fun = bst.augment.grad(loss_fun, net.states(bst.ParamState), has_aux=True, return_value=True)
+
+    # optimizer
+    optimizer = bst.optim.Adam(lr=args.lr)
+    optimizer.register_trainable_weights(net.states(bst.ParamState))
+
+
+    # train
+    @bst.compile.jit
+    def train(xs, ys):
+        grads, l, n = grad_fun(xs, ys)
+        optimizer.update(grads)
+        return l, n
+
+
+    # training loop
+    for epoch_i in range(args.epochs):
+        key = bst.random.split_key()
+        X_train = bst.random.shuffle(X_train, key=key)
+        Y_train = bst.random.shuffle(Y_train, key=key)
+
+        # training phase
+        t0 = time.time()
+        loss, train_acc = [], 0.
+        for i in range(0, X_train.shape[0], args.batch):
+            X = X_train[i: i + args.batch]
+            Y = Y_train[i: i + args.batch]
+            l, correct_num = train(X, Y)
+            loss.append(l)
+            train_acc += correct_num
+        train_acc /= X_train.shape[0]
+        train_loss = jnp.mean(jnp.asarray(loss))
+        optimizer.lr.step_epoch()
+
+        # testing phase
+        loss, test_acc = [], 0.
+        for i in range(0, X_test.shape[0], args.batch):
+            X = X_test[i: i + args.batch]
+            Y = Y_test[i: i + args.batch]
+            l, correct_num = loss_fun(X, Y)
+            loss.append(l)
+            test_acc += correct_num
+        test_acc /= X_test.shape[0]
+        test_loss = jnp.mean(jnp.asarray(loss))
+
+        t = (time.time() - t0) / 60
+        print(f'epoch {epoch_i}, used {t:.3f} min, '
+              f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, '
+              f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}')
+
+    # inference
+    correct_num = 0.
+    for i in range(0, X_test.shape[0], 512):
+        X = X_test[i: i + 512]
+        Y = Y_test[i: i + 512]
+        correct_num += loss_fun(X, Y)[1]
+    print('Max test accuracy: ', correct_num / X_test.shape[0])
diff --git a/examples/300_integrator_rnn.py b/examples/300_integrator_rnn.py
new file mode 100644
index 0000000..ceb1984
--- /dev/null
+++ b/examples/300_integrator_rnn.py
@@ -0,0 +1,160 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# -*- coding: utf-8 -*-
+
+from typing import Callable
+
+import braintools as bt
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+
+import brainstate as bst
+
+dt = 0.04
+num_step = int(1.0 / dt)
+num_batch = 512
+
+
+@bst.compile.jit(static_argnums=2)
+def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):
+    # Create the white noise input
+    sample = bst.random.normal(size=(1, batch_size, 1))
+    bias = mean * 2.0 * (sample - 0.5)
+    samples = bst.random.normal(size=(num_step, batch_size, 1))
+    noise_t = scale / dt ** 0.5 * samples
+    inputs = bias + noise_t
+    targets = jnp.cumsum(inputs, axis=0)
+    return inputs, targets
+
+
+def train_data():
+    for _ in range(500):
+        yield build_inputs_and_targets(0.025, 0.01, num_batch)
+
+
+class RNNCell(bst.nn.Module):
+    def __init__(
+        self,
+        num_in: int,
+        num_out: int,
+        state_initializer: Callable = bst.init.ZeroInit(),
+        w_initializer: Callable = bst.init.XavierNormal(),
+        b_initializer: Callable = bst.init.ZeroInit(),
+        activation: Callable = bst.functional.relu,
+        train_state: bool = False,
+    ):
+        super().__init__()
+
+        # parameters
+        self.num_out = num_out
+        self.train_state = train_state
+
+        # parameters
+        self.num_in = num_in
+
+        # initializers
+        self._state_initializer = state_initializer
+        self._w_initializer = w_initializer
+        self._b_initializer = b_initializer
+
+        # activation function
+        self.activation = activation
+
+        # weights
+        W = bst.init.param(self._w_initializer, (num_in + num_out, self.num_out))
+        b = bst.init.param(self._b_initializer, (self.num_out,))
+        self.W = bst.ParamState(W)
+        self.b = None if (b is None) else bst.ParamState(b)
+
+        # state
+        if train_state:
+            self.state2train = bst.ParamState(bst.init.param(bst.init.ZeroInit(), (self.num_out,), allow_none=False))
+
+    def init_state(self, batch_size=None, **kwargs):
+        self.state = bst.ShortTermState(bst.init.param(self._state_initializer, (self.num_out,), batch_size))
+        if self.train_state:
+            self.state.value = jnp.repeat(jnp.expand_dims(self.state2train.value, axis=0), batch_size, axis=0)
+
+    def update(self, x):
+        x = jnp.concat([x, self.state.value], axis=-1)
+        h = x @ self.W.value
+        if self.b is not None:
+            h += self.b.value
+        h = self.activation(h)
+        self.state.value = h
+        return h
+
+
+class RNN(bst.nn.Module):
+    def __init__(self, num_in, num_hidden):
+        super().__init__()
+        self.rnn = RNNCell(num_in, num_hidden, train_state=True)
+        self.out = bst.nn.Linear(num_hidden, 1)
+
+    def update(self, x):
+        return x >> self.rnn >> self.out
+
+
+model = RNN(1, 100)
+weights = model.states(bst.ParamState)
+
+
+@bst.compile.jit
+def f_predict(inputs):
+    bst.nn.init_all_states(model, batch_size=inputs.shape[1])
+    return bst.compile.for_loop(model.update, inputs)
+
+
+def f_loss(inputs, targets, l2_reg=2e-4):
+    predictions = f_predict(inputs)
+    mse = bt.metric.squared_error(predictions, targets).mean()
+    l2 = 0.0
+    for weight in weights.values():
+        for leaf in jax.tree.leaves(weight.value):
+            l2 += jnp.sum(leaf ** 2)
+    return mse + l2_reg * l2
+
+
+# define optimizer
+lr = bst.optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975)
+opt = bst.optim.Adam(lr=lr, eps=1e-1)
+opt.register_trainable_weights(weights)
+
+
+@bst.compile.jit
+def f_train(inputs, targets):
+    grads, l = bst.augment.grad(f_loss, weights, return_value=True)(inputs, targets)
+    opt.update(grads)
+    return l
+
+
+for i_epoch in range(5):
+    for i_batch, (inps, tars) in enumerate(train_data()):
+        loss = f_train(inps, tars)
+        if (i_batch + 1) % 100 == 0:
+            print(f'Epoch {i_epoch}, Batch {i_batch + 1:3d}, Loss {loss:.5f}')
+
+bst.nn.init_all_states(model, 1)
+x, y = build_inputs_and_targets(0.025, 0.01, 1)
+predicts = f_predict(x)
+
+plt.figure(figsize=(8, 2))
+plt.plot(np.asarray(y[:, 0]).flatten(), label='Ground Truth')
+plt.plot(np.asarray(predicts[:, 0]).flatten(), label='Prediction')
+plt.legend()
+plt.show()
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000..e9173f1
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,12 @@
+# Examples for ``brainstate`` library
+
+
+We provide several kinds of examples to demonstrate the usage of the ``brainstate`` library. The examples are organized into the following categories:
+
+- Files whose names start with ``0__`` are examples of deep neural networks.
+- Files whose names start with ``1__`` are examples of brain simulation models, especially spiking neural networks.
+- Files whose names start with ``2__`` are examples of brain-inspired computing models, especially training spiking neural networks.
+- Files whose names start with ``3__`` are examples of rate-based recurrent neural networks.
+
+