unify CSRLinear, CSCLinear, COOLinear using SparseLinear (#48)
* csr sparse event

* fix

* unify `CSRLinear`, `CSCLinear`, `COOLinear` using `SparseLinear`

* update docs

* fix tests

* fix tests
chaoming0625 authored Dec 14, 2024
1 parent 176ee7a commit 1c8446b
Showing 12 changed files with 255 additions and 416 deletions.
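For orientation (not part of the commit): a minimal usage sketch of the event-driven CSR layer this change touches, assuming the `CSRLinear(in_size, out_size, indptr, indices, weight)` constructor kept in `brainstate/event/_csr_mv.py` below; the connectivity and weight values are made up for illustration.

```python
import brainstate as bst
import jax.numpy as jnp

# 3 presynaptic units, 4 postsynaptic units, CSR connectivity:
# unit 0 -> {0, 2}, unit 1 -> {1}, unit 2 -> {0, 3}
indptr = jnp.asarray([0, 2, 3, 5])
indices = jnp.asarray([0, 2, 1, 0, 3])

layer = bst.event.CSRLinear(3, 4, indptr, indices, 1.5)  # homogeneous weight
spikes = jnp.asarray([True, False, True])                # boolean spike events
out = layer(spikes)                                      # shape (4,)
```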
6 changes: 3 additions & 3 deletions brainstate/compile/_progress_bar.py
@@ -114,17 +114,17 @@ def __call__(self, iter_num, *args, **kwargs):

_ = jax.lax.cond(
iter_num == 0,
lambda: jax.debug.callback(self._define_tqdm),
lambda: jax.debug.callback(self._define_tqdm, ordered=True),
lambda: None,
)
_ = jax.lax.cond(
iter_num % self.print_freq == (self.print_freq - 1),
lambda: jax.debug.callback(self._update_tqdm),
lambda: jax.debug.callback(self._update_tqdm, ordered=True),
lambda: None,
)
_ = jax.lax.cond(
iter_num == self.n - 1,
lambda: jax.debug.callback(self._close_tqdm),
lambda: jax.debug.callback(self._close_tqdm, ordered=True),
lambda: None,
)
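For context (not from this repository): `ordered=True` asks JAX to keep host callbacks in program order rather than letting the compiler reorder or elide them, which matters here because one callback creates the tqdm bar and later ones update and close it. A minimal sketch of the same idea with plain prints, using only the public `jax.debug.callback` API:

```python
import jax
import jax.numpy as jnp

def report(i):
    print(f"step {int(i)} done")

def body(carry, i):
    # Ordered callbacks are sequenced relative to each other, so the prints
    # appear in iteration order even though the loop body is compiled.
    jax.debug.callback(report, i, ordered=True)
    return carry + i, None

total, _ = jax.lax.scan(body, jnp.float32(0.0), jnp.arange(5))
```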

14 changes: 0 additions & 14 deletions brainstate/event/_csr_benchmark.py

This file was deleted.

37 changes: 12 additions & 25 deletions brainstate/event/_csr_mv.py
@@ -58,7 +58,6 @@ def __init__(
indices: ArrayLike,
weight: Union[Callable, ArrayLike],
name: Optional[str] = None,
grad_mode: str = 'vjp'
):
super().__init__(name=name)

@@ -68,17 +67,13 @@ def __init__(
self.n_pre = self.in_size[-1]
self.n_post = self.out_size[-1]

# gradient mode
assert grad_mode in ['vjp', 'jvp'], f"Unsupported grad_mode: {grad_mode}"
self.grad_mode = grad_mode

# CSR data structure
indptr = jnp.asarray(indptr)
indices = jnp.asarray(indices)
assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
with jax.ensure_compile_time_eval():
indptr = jnp.asarray(indptr)
indices = jnp.asarray(indices)
assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
assert indptr.size == self.n_pre + 1, f"indptr must have size {self.n_pre + 1}. Got: {indptr.size}"
self.indptr = u.math.asarray(indptr)
self.indices = u.math.asarray(indices)

@@ -101,21 +96,13 @@ def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
device_kind = jax.devices()[0].platform # spk.device.device_kind

# CPU implementation
if device_kind == 'cpu':
return cpu_event_csr(
u.math.asarray(spk),
self.indptr,
self.indices,
u.math.asarray(weight),
n_post=self.n_post, grad_mode=self.grad_mode
)

# GPU/TPU implementation
elif device_kind in ['gpu', 'tpu']:
raise NotImplementedError()

else:
raise ValueError(f"Unsupported device: {device_kind}")
return cpu_event_csr(
u.math.asarray(spk),
self.indptr,
self.indices,
u.math.asarray(weight),
n_post=self.n_post,
)


@set_module_as('brainstate.event')
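For context (not part of the diff): `jax.ensure_compile_time_eval()` makes the wrapped `jnp` operations execute eagerly while a function is being traced, so the CSR index arrays and their shape checks above are evaluated once as concrete values instead of being staged into the compiled graph. A minimal sketch of the mechanism, with made-up data:

```python
import jax
import jax.numpy as jnp

@jax.jit
def gather_first_half(x):
    with jax.ensure_compile_time_eval():
        # Evaluated eagerly during tracing: `idx` is a concrete array, so it
        # can be checked with a plain Python assert and is baked into the
        # compiled computation as a constant.
        idx = jnp.arange(x.shape[0] // 2)
        assert int(idx[-1]) < x.shape[0]
    return x[idx]

print(gather_first_half(jnp.arange(10.0)))  # [0. 1. 2. 3. 4.]
```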
152 changes: 76 additions & 76 deletions brainstate/event/_csr_mv_test.py
@@ -40,79 +40,79 @@ def true_fn(x, w, indices, indptr, n_out):
return post


class TestFixedProbCSR(parameterized.TestCase):
@parameterized.product(
homo_w=[True, False],
)
def test1(self, homo_w):
x = bst.random.rand(20) < 0.1
indptr, indices = _get_csr(20, 40, 0.1)
m = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
y = m(x)
y2 = true_fn(x, m.weight.value, indices, indptr, 40)
self.assertTrue(jnp.allclose(y, y2))

@parameterized.product(
bool_x=[True, False],
homo_w=[True, False]
)
def test_vjp(self, bool_x, homo_w):
n_in = 20
n_out = 30
if bool_x:
x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
else:
x = bst.random.rand(n_in)

indptr, indices = _get_csr(n_in, n_out, 0.1)
fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
w = fn.weight.value

def f(x, w):
fn.weight.value = w
return fn(x).sum()

r = jax.grad(f, argnums=(0, 1))(x, w)

# -------------------
# TRUE gradients

def f2(x, w):
return true_fn(x, w, indices, indptr, n_out).sum()

r2 = jax.grad(f2, argnums=(0, 1))(x, w)
self.assertTrue(jnp.allclose(r[0], r2[0]))
self.assertTrue(jnp.allclose(r[1], r2[1]))

@parameterized.product(
bool_x=[True, False],
homo_w=[True, False]
)
def test_jvp(self, bool_x, homo_w):
n_in = 20
n_out = 30
if bool_x:
x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
else:
x = bst.random.rand(n_in)

indptr, indices = _get_csr(n_in, n_out, 0.1)
fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
w = fn.weight.value

def f(x, w):
fn.weight.value = w
return fn(x)

o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))

# -------------------
# TRUE gradients

def f2(x, w):
return true_fn(x, w, indices, indptr, n_out)

o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
self.assertTrue(jnp.allclose(r1, r2))
self.assertTrue(jnp.allclose(o1, o2))
# class TestFixedProbCSR(parameterized.TestCase):
# @parameterized.product(
# homo_w=[True, False],
# )
# def test1(self, homo_w):
# x = bst.random.rand(20) < 0.1
# indptr, indices = _get_csr(20, 40, 0.1)
# m = bst.event.CSRLinear(20, 40, indptr, indices, 1.5 if homo_w else bst.init.Normal())
# y = m(x)
# y2 = true_fn(x, m.weight.value, indices, indptr, 40)
# self.assertTrue(jnp.allclose(y, y2))
#
# @parameterized.product(
# bool_x=[True, False],
# homo_w=[True, False]
# )
# def test_vjp(self, bool_x, homo_w):
# n_in = 20
# n_out = 30
# if bool_x:
# x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
# else:
# x = bst.random.rand(n_in)
#
# indptr, indices = _get_csr(n_in, n_out, 0.1)
# fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
# w = fn.weight.value
#
# def f(x, w):
# fn.weight.value = w
# return fn(x).sum()
#
# r = jax.grad(f, argnums=(0, 1))(x, w)
#
# # -------------------
# # TRUE gradients
#
# def f2(x, w):
# return true_fn(x, w, indices, indptr, n_out).sum()
#
# r2 = jax.grad(f2, argnums=(0, 1))(x, w)
# self.assertTrue(jnp.allclose(r[0], r2[0]))
# self.assertTrue(jnp.allclose(r[1], r2[1]))
#
# @parameterized.product(
# bool_x=[True, False],
# homo_w=[True, False]
# )
# def test_jvp(self, bool_x, homo_w):
# n_in = 20
# n_out = 30
# if bool_x:
# x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
# else:
# x = bst.random.rand(n_in)
#
# indptr, indices = _get_csr(n_in, n_out, 0.1)
# fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
# 1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
# w = fn.weight.value
#
# def f(x, w):
# fn.weight.value = w
# return fn(x)
#
# o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
#
# # -------------------
# # TRUE gradients
#
# def f2(x, w):
# return true_fn(x, w, indices, indptr, n_out)
#
# o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
# self.assertTrue(jnp.allclose(r1, r2))
# self.assertTrue(jnp.allclose(o1, o2))
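For reference (not part of the diff): the disabled tests above compared `CSRLinear` against a dense computation like the sketch below, assuming 0/1 spikes and a weight that is either a scalar (homogeneous) or one value per stored index.

```python
import jax.numpy as jnp

def csr_reference(x, w, indices, indptr, n_out):
    # Dense reference for the event-driven CSR matvec: row i of the CSR
    # structure lists the postsynaptic targets of presynaptic unit i.
    post = jnp.zeros(n_out)
    for i in range(x.shape[0]):
        lo, hi = int(indptr[i]), int(indptr[i + 1])
        row = indices[lo:hi]
        w_row = w if jnp.ndim(w) == 0 else w[lo:hi]  # homogeneous vs per-edge weight
        post = post.at[row].add(w_row * jnp.asarray(x[i], dtype=post.dtype))
    return post
```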
76 changes: 48 additions & 28 deletions brainstate/event/_fixedprob_mv.py
@@ -85,44 +85,52 @@ def __init__(
self.in_size = in_size
self.out_size = out_size
self.n_conn = int(self.out_size[-1] * prob)
if self.n_conn < 1:
raise ValueError(f"The number of connections must be at least 1. "
f"Got: int({self.out_size[-1]} * {prob}) = {self.n_conn}")
self.float_as_event = float_as_event
self.block_size = block_size

# indices of post connected neurons
with jax.ensure_compile_time_eval():
if allow_multi_conn:
rng = np.random.RandomState(seed)
self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
else:
rng = RandomState(seed)
if self.n_conn > 1:
# indices of post connected neurons
with jax.ensure_compile_time_eval():
if allow_multi_conn:
rng = np.random.RandomState(seed)
self.indices = rng.randint(0, self.out_size[-1], size=(self.in_size[-1], self.n_conn))
else:
rng = RandomState(seed)

@vmap(rngs=rng)
def rand_indices(key):
rng.set_key(key)
return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)
@vmap(rngs=rng)
def rand_indices(key):
rng.set_key(key)
return rng.choice(self.out_size[-1], size=(self.n_conn,), replace=False)

self.indices = rand_indices(rng.split_key(self.in_size[-1]))
self.indices = u.math.asarray(self.indices)
self.indices = rand_indices(rng.split_key(self.in_size[-1]))
self.indices = u.math.asarray(self.indices)

# maximum synaptic conductance
weight = param(weight, (self.in_size[-1], self.n_conn), allow_none=False)
self.weight = ParamState(weight)

def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
return event_fixed_prob(
spk,
self.weight.value,
self.indices,
n_post=self.out_size[-1],
block_size=self.block_size,
float_as_event=self.float_as_event
)
if self.n_conn > 1:
return event_fixed_prob(
spk,
self.weight.value,
self.indices,
n_post=self.out_size[-1],
block_size=self.block_size,
float_as_event=self.float_as_event
)
else:
weight = self.weight.value
unit = u.get_unit(weight)
r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
return u.maybe_decimal(u.Quantity(r, unit=unit))


def event_fixed_prob(spk, weight, indices, *, n_post, block_size, float_as_event):
def event_fixed_prob(
spk, weight, indices,
*,
n_post, block_size, float_as_event
):
"""
The FixedProb module implements a fixed probability connection with CSR sparse data structure.
@@ -374,7 +382,11 @@ def true_fn(spk):
kernel(spikes, indices, weight, jnp.zeros(n_post, dtype=weight_info.dtype)))


def jvp_spikes(spk_dot, spikes, weights, indices, *, n_post, block_size, **kwargs):
def jvp_spikes(
spk_dot, spikes, weights, indices,
*,
n_post, block_size, **kwargs
):
return ellmv_p_call(
spk_dot,
weights,
@@ -384,7 +396,11 @@ def jvp_weights(w_dot, spikes, weights, indices, *, float_as_event, block_size, n_post, **kwargs):
)


def jvp_weights(w_dot, spikes, weights, indices, *, float_as_event, block_size, n_post, **kwargs):
def jvp_weights(
w_dot, spikes, weights, indices,
*,
float_as_event, block_size, n_post, **kwargs
):
return event_ellmv_p_call(
spikes,
w_dot,
@@ -464,7 +480,11 @@ def map_fn(one_spk, one_ind):
event_ellmv_p.def_transpose_rule(transpose_rule)


def event_ellmv_p_call(spikes, weights, indices, *, n_post, block_size, float_as_event):
def event_ellmv_p_call(
spikes, weights, indices,
*,
n_post, block_size, float_as_event
):
n_conn = indices.shape[1]
if block_size is None:
if n_conn <= 16:
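For reference (not part of the diff): `event_fixed_prob` scatters each firing presynaptic unit's row of weights into the postsynaptic vector using the ELL-style `indices` table built in `__init__`. A dense sketch of that computation, assuming `indices` has shape `(n_pre, n_conn)` and `weight` is a scalar or a matching `(n_pre, n_conn)` array, and ignoring the `brainunit` unit handling:

```python
import jax.numpy as jnp

def event_fixed_prob_reference(spk, weight, indices, n_post):
    # indices[i, k] is the k-th postsynaptic target of presynaptic unit i.
    post = jnp.zeros(n_post, dtype=jnp.result_type(weight))
    spk_f = spk.astype(post.dtype)              # treat boolean spikes as 0/1
    w = jnp.broadcast_to(weight, indices.shape)
    contrib = w * spk_f[:, None]                # zero out rows without a spike
    return post.at[indices].add(contrib)
```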
9 changes: 9 additions & 0 deletions brainstate/nn/_dynamics/_projection_base.py
@@ -154,12 +154,21 @@ def __init__(
# checking synapse and output models
if is_instance(syn, ParamDescriber[AlignPost]):
if not is_instance(out, ParamDescriber[SynOut]):
if is_instance(out, ParamDescriber):
raise TypeError(
f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
f'the synapse is an instance of {AlignPost}, but got {out}.'
)
raise TypeError(
f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
f'the synapse is a describer, but we got {out}.'
)
merging = True
else:
if is_instance(syn, ParamDescriber):
raise TypeError(
f'The synapse should be an instance of describer {ParamDescriber[AlignPost]}, but got {syn}.'
)
if not is_instance(out, SynOut):
raise TypeError(
f'The output should be an instance of {SynOut} when the synapse is '