From 78fc1a9ab4be5c3430d7215d8ec531a5c24de682 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 22 Nov 2024 00:33:35 +0800 Subject: [PATCH] Event-driven operator updates (#36) * csr benchmark * fix xla custom op bugs * update examples * fix memory access error when n_conn is small * fix hashable bug * update examples * fix bug --- brainstate/event/_csr_benchmark.py | 14 +++++ brainstate/event/_fixed_probability.py | 53 +++++++++++-------- brainstate/event/_xla_custom_op.py | 11 ++-- examples/102_EI_net_1996.py | 2 +- examples/103_COBA_2005.py | 2 +- examples/104_CUBA_2005.py | 2 +- examples/105_COBA_HH_2007.py | 2 +- examples/106_COBA_HH_2007.py | 2 +- examples/107_gamma_oscillation_1996.py | 2 +- examples/108_synfire_chains_199.py | 2 +- examples/109_fast_global_oscillation.py | 2 +- ...usin_Destexhe_2021_gamma_oscillation_AI.py | 9 ++-- ...n_Destexhe_2021_gamma_oscillation_CHING.py | 4 -- ...sin_Destexhe_2021_gamma_oscillation_ING.py | 4 -- ...in_Destexhe_2021_gamma_oscillation_PING.py | 4 -- examples/200_surrogate_grad_lif.py | 3 +- .../Susin_Destexhe_2021_gamma_oscillation.py | 2 +- 17 files changed, 66 insertions(+), 54 deletions(-) create mode 100644 brainstate/event/_csr_benchmark.py diff --git a/brainstate/event/_csr_benchmark.py b/brainstate/event/_csr_benchmark.py new file mode 100644 index 0000000..23b09eb --- /dev/null +++ b/brainstate/event/_csr_benchmark.py @@ -0,0 +1,14 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/brainstate/event/_fixed_probability.py b/brainstate/event/_fixed_probability.py index ca81920..3b0cd30 100644 --- a/brainstate/event/_fixed_probability.py +++ b/brainstate/event/_fixed_probability.py @@ -253,7 +253,6 @@ def gpu_kernel_generator( weight_info: jax.ShapeDtypeStruct, **kwargs ): - # 对于具有形状 [n_event] 的 spikes 向量,以及形状 [n_event, n_conn] 的 indices 和 weights 矩阵, # 这个算子的计算逻辑为: # @@ -273,16 +272,27 @@ def _ell_mv_kernel_homo( mask = jnp.arange(block_size) + c_start < n_conn def body_fn(j, _): - def true_fn(): - ind = pl.load(ind_ref, (j, pl.dslice(c_start, block_size)), mask=mask) - y_ref[ind] += 1.0 - # ind = ind_ref[j, ...] - # pl.store(y_ref, ind, 1.0, mask=mask) - if sp_ref.dtype == jnp.bool_: + def true_fn(): + ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask) + pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask) + # y_ref[ind] += 1.0 + # ind = ind_ref[j, ...] + # pl.store(y_ref, ind, 1.0, mask=mask) + jax.lax.cond(sp_ref[j], true_fn, lambda: None) + + else: - jax.lax.cond(sp_ref[j] != 0., true_fn, lambda: None) + def true_fn(sp): + ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask) + if float_as_event: + pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype), mask=mask) + else: + pl.atomic_add(y_ref, ind, jnp.ones(block_size, dtype=weight_info.dtype) * sp, mask=mask) + + sp_ = sp_ref[j] + jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_) jax.lax.fori_loop(0, row_length, body_fn, None) @@ -305,7 +315,7 @@ def true_fn(): interpret=False ) return (lambda spikes, weight, indices: - kernel(spikes, indices, jnp.zeros(n_post, dtype=weight.dtype)) * weight) + [kernel(spikes, indices, jnp.zeros(n_post, dtype=weight.dtype))[0] * weight]) else: def _ell_mv_kernel_heter( @@ -323,19 +333,18 @@ def _ell_mv_kernel_heter( def body_fn(j, _): if sp_ref.dtype == jnp.bool_: def true_fn(): - ind = ind_ref[j, ...] - w = w_ref[j, ...] - pl.store(y_ref, ind, w, mask=mask) + ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask) + w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask) + pl.atomic_add(y_ref, ind, w, mask=mask) jax.lax.cond(sp_ref[j], true_fn, lambda: None) else: def true_fn(spk): - ind = ind_ref[j, ...] - if float_as_event: - w = w_ref[j, ...] - else: - w = w_ref[j, ...] * spk - pl.store(y_ref, ind, w, mask=mask) + ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask) + w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask) + if not float_as_event: + w = w * spk + pl.atomic_add(y_ref, ind, w, mask=mask) sp_ = sp_ref[j] jax.lax.cond(sp_ != 0., true_fn, lambda _: None, sp_) @@ -540,8 +549,8 @@ def _kernel( row_length = jnp.minimum(n_pre - r_pid * block_size, block_size) def body_fn(j, _): - y = vec_ref[j] * jnp.ones(block_size) - ind = ind_ref[j, ...] + y = vec_ref[j] * jnp.ones(block_size, dtype=weight_info.dtype) + ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask) pl.atomic_add(out_ref, ind, y, mask=mask) jax.lax.fori_loop(0, row_length, body_fn, None) @@ -585,9 +594,9 @@ def _kernel( row_length = jnp.minimum(n_pre - r_pid * block_size, block_size) def body_fn(j, _): - w = w_ref[j, ...] + w = pl.load(w_ref, (j, pl.dslice(None)), mask=mask) y = w * vec_ref[j] - ind = ind_ref[j, ...] + ind = pl.load(ind_ref, (j, pl.dslice(None)), mask=mask) pl.atomic_add(out_ref, ind, y, mask=mask) jax.lax.fori_loop(0, row_length, body_fn, None) diff --git a/brainstate/event/_xla_custom_op.py b/brainstate/event/_xla_custom_op.py index e3a051f..750c3ce 100644 --- a/brainstate/event/_xla_custom_op.py +++ b/brainstate/event/_xla_custom_op.py @@ -209,6 +209,7 @@ def __init__( # abstract evaluation self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) + self.primitive.def_abstract_eval(self._abstract_eval) # cpu kernel if cpu_kernel_generator is not None: @@ -228,11 +229,13 @@ def __init__( if transpose_translation is not None: ad.primitive_transposes[self.primitive] = transpose_translation + def _abstract_eval(self, *ins, outs: Sequence[ShapeDtype], **kwargs): + return tuple(outs) + def __call__(self, *ins, outs: Sequence[ShapeDtype], **kwargs): assert isinstance(outs, (tuple, list)), 'The `outs` should be a tuple or list of shape-dtype pairs.' outs = jax.tree.map(_transform_to_shapedarray, outs) - self.primitive.def_abstract_eval(functools.partial(_abstract_eval, outs)) - return self.primitive.bind(*ins, **kwargs) + return self.primitive.bind(*ins, **kwargs, outs=tuple(outs)) def def_cpu_kernel(self, kernel_generator: Callable): """ @@ -305,9 +308,5 @@ def def_mlir_lowering(self, platform, fun): mlir.register_lowering(self.primitive, fun, platform) -def _abstract_eval(outs, *args, **kwargs): - return [jax.core.ShapedArray(out.shape, out.dtype) for out in outs] - - def _transform_to_shapedarray(a): return jax.core.ShapedArray(a.shape, a.dtype) diff --git a/examples/102_EI_net_1996.py b/examples/102_EI_net_1996.py index 4ead905..b79740a 100644 --- a/examples/102_EI_net_1996.py +++ b/examples/102_EI_net_1996.py @@ -96,7 +96,7 @@ def update(self, inp): # visualization t_indices, n_indices = u.math.where(spikes) -plt.plot(times[t_indices], n_indices, 'k.', markersize=1) +plt.scatter(times[t_indices], n_indices, s=1) plt.xlabel('Time (ms)') plt.ylabel('Neuron index') plt.show() diff --git a/examples/103_COBA_2005.py b/examples/103_COBA_2005.py index a1d442e..b2dfd70 100644 --- a/examples/103_COBA_2005.py +++ b/examples/103_COBA_2005.py @@ -73,7 +73,7 @@ def update(self, t, inp): # visualization t_indices, n_indices = u.math.where(spikes) -plt.plot(times[t_indices], n_indices, 'k.', markersize=1) +plt.scatter(times[t_indices], n_indices, s=1) plt.xlabel('Time (ms)') plt.ylabel('Neuron index') plt.show() diff --git a/examples/104_CUBA_2005.py b/examples/104_CUBA_2005.py index 6079b18..d4e2168 100644 --- a/examples/104_CUBA_2005.py +++ b/examples/104_CUBA_2005.py @@ -75,7 +75,7 @@ def update(self, t, inp): # visualization t_indices, n_indices = u.math.where(spikes) -plt.plot(times[t_indices], n_indices, 'k.', markersze=1) +plt.scatter(times[t_indices], n_indices, s=1) plt.xlabel('Time (ms)') plt.ylabel('Neuron index') plt.show() diff --git a/examples/105_COBA_HH_2007.py b/examples/105_COBA_HH_2007.py index 8a242b7..9620fa5 100644 --- a/examples/105_COBA_HH_2007.py +++ b/examples/105_COBA_HH_2007.py @@ -96,7 +96,7 @@ def update(self, t): plt.show() t_indices, n_indices = u.math.where(spikes) -plt.plot(times[t_indices], n_indices, 'k.', markersize=1) +plt.scatter(times[t_indices], n_indices, s=1) plt.xlabel('Time (ms)') plt.ylabel('Neuron index') plt.show() diff --git a/examples/106_COBA_HH_2007.py b/examples/106_COBA_HH_2007.py index 42f223c..dc2621a 100644 --- a/examples/106_COBA_HH_2007.py +++ b/examples/106_COBA_HH_2007.py @@ -166,7 +166,7 @@ def update(self, t): # visualization t_indices, n_indices = u.math.where(spikes) -plt.plot(times[t_indices], n_indices, 'k.', markersize=1) +plt.scatter(times[t_indices], n_indices, s=1) plt.xlabel('Time (ms)') plt.ylabel('Neuron index') plt.show() diff --git a/examples/107_gamma_oscillation_1996.py b/examples/107_gamma_oscillation_1996.py index c8a0c57..4336b40 100644 --- a/examples/107_gamma_oscillation_1996.py +++ b/examples/107_gamma_oscillation_1996.py @@ -145,7 +145,7 @@ def update(self, t): fig.add_subplot(gs[0, 1]) t_indices, n_indices = u.math.where(spikes) -plt.plot(times[t_indices], n_indices, 'k.', markersize=1) +plt.plot(times[t_indices], n_indices, 'k.') plt.xlabel('Time (ms)') plt.ylabel('Neuron index') plt.show() diff --git a/examples/108_synfire_chains_199.py b/examples/108_synfire_chains_199.py index c562b6e..321d8d7 100644 --- a/examples/108_synfire_chains_199.py +++ b/examples/108_synfire_chains_199.py @@ -150,7 +150,7 @@ def run_network(spike_num: int, ax): # visualization times = times.to_decimal(u.ms) t_indices, n_indices = u.math.where(spikes) - ax.plot(times[t_indices], n_indices, 'k.', markersize=1) + ax.scatter(times[t_indices], n_indices, s=1) ax.set_xlabel('Time (ms)') ax.set_ylabel('Neuron index') diff --git a/examples/109_fast_global_oscillation.py b/examples/109_fast_global_oscillation.py index 7b23a36..38266ed 100644 --- a/examples/109_fast_global_oscillation.py +++ b/examples/109_fast_global_oscillation.py @@ -99,7 +99,7 @@ def update(self, t, i): # visualization times = times.to_decimal(u.ms) t_indices, n_indices = u.math.where(spikes) -plt.plot(times[t_indices], n_indices, 'k.', markersize=1) +plt.scatter(times[t_indices], n_indices, s=1) plt.xlabel('Time (ms)') plt.ylabel('Neuron index') plt.xlim([0, duration.to_decimal(u.ms)]) diff --git a/examples/110_Susin_Destexhe_2021_gamma_oscillation_AI.py b/examples/110_Susin_Destexhe_2021_gamma_oscillation_AI.py index 859bf43..37f620f 100644 --- a/examples/110_Susin_Destexhe_2021_gamma_oscillation_AI.py +++ b/examples/110_Susin_Destexhe_2021_gamma_oscillation_AI.py @@ -24,7 +24,8 @@ import os -os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' +os.environ['JAX_TRACEBACK_FILTERING'] = 'off' + import brainstate as bst import braintools as bts @@ -93,8 +94,8 @@ def __init__(self): FS_par_ = FS_par.copy() RS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV) FS_par_.update(Vth=-50 * u.mV, V_sp_th=-40 * u.mV) - self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **RS_par_) self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **FS_par_) + self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, **RS_par_) self.ext_pop = bst.nn.PoissonEncoder(self.num_exc) # Poisson inputs @@ -140,12 +141,12 @@ def update(self, i, t, freq): ext_spikes = self.ext_pop(freq) self.ext_to_FS(ext_spikes) self.ext_to_RS(ext_spikes) - self.RS_to_FS() self.RS_to_RS() + self.RS_to_FS() self.FS_to_FS() self.FS_to_RS() - self.fs_pop() self.rs_pop() + self.fs_pop() return { 'FS.V0': self.fs_pop.V.value[0], 'RS.V0': self.rs_pop.V.value[0], diff --git a/examples/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py b/examples/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py index 426d426..4e4fb99 100644 --- a/examples/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py +++ b/examples/111_Susin_Destexhe_2021_gamma_oscillation_CHING.py @@ -22,10 +22,6 @@ # CHING Network for Generating Gamma Oscillation -import os - -os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' - import brainunit as u import brainstate as bst diff --git a/examples/112_Susin_Destexhe_2021_gamma_oscillation_ING.py b/examples/112_Susin_Destexhe_2021_gamma_oscillation_ING.py index 559ce98..d1e0e16 100644 --- a/examples/112_Susin_Destexhe_2021_gamma_oscillation_ING.py +++ b/examples/112_Susin_Destexhe_2021_gamma_oscillation_ING.py @@ -23,10 +23,6 @@ # ING Network for Generating Gamma Oscillation -import os - -os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' - import brainunit as u import brainstate as bst diff --git a/examples/113_Susin_Destexhe_2021_gamma_oscillation_PING.py b/examples/113_Susin_Destexhe_2021_gamma_oscillation_PING.py index 310b339..412a40e 100644 --- a/examples/113_Susin_Destexhe_2021_gamma_oscillation_PING.py +++ b/examples/113_Susin_Destexhe_2021_gamma_oscillation_PING.py @@ -23,10 +23,6 @@ # PING Network for Generating Gamma Oscillation -import os - -os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' - import brainunit as u import brainstate as bst diff --git a/examples/200_surrogate_grad_lif.py b/examples/200_surrogate_grad_lif.py index 22a8201..baae188 100644 --- a/examples/200_surrogate_grad_lif.py +++ b/examples/200_surrogate_grad_lif.py @@ -42,10 +42,11 @@ def __init__(self, num_in, num_rec, num_out): self.num_out = num_out # synapse: i->r + scale = 7 * (1 - (u.math.exp(-bst.environ.get_dt() / (1 * u.ms)))) self.i2r = bst.nn.Sequential( bst.nn.Linear( num_in, num_rec, - w_init=bst.init.KaimingNormal(scale=7*(1-(u.math.exp(-bst.environ.get_dt()/(1*u.ms)))), unit=u.mA), + w_init=bst.init.KaimingNormal(scale=scale, unit=u.mA), b_init=bst.init.ZeroInit(unit=u.mA) ), bst.nn.Expon(num_rec, tau=5. * u.ms, g_initializer=bst.init.Constant(0. * u.mA)) diff --git a/examples/Susin_Destexhe_2021_gamma_oscillation.py b/examples/Susin_Destexhe_2021_gamma_oscillation.py index c1ef6e8..a1de1f2 100644 --- a/examples/Susin_Destexhe_2021_gamma_oscillation.py +++ b/examples/Susin_Destexhe_2021_gamma_oscillation.py @@ -227,7 +227,7 @@ def visualize_simulation_results( for key, (sp_matrix, sp_type) in spikes.items(): iis, sps = np.where(sp_matrix) tts = times[iis] - plt.plot(tts, sps + i, '.', markersize=1, label=key) + plt.scatter(tts, sps + i, s=1, label=key) y_ticks[0].append(i + sp_matrix.shape[1] / 2) y_ticks[1].append(key) i += sp_matrix.shape[1]