
Graded relu (#860)
* GradedReluVec process and tests.

* Changed test to use thresh, not 0.

* Removed duplicate docstring line.

* Bump tornado from 6.4 to 6.4.1 (#863)

Bumps [tornado](https://github.com/tornadoweb/tornado) from 6.4 to 6.4.1.
- [Changelog](https://github.com/tornadoweb/tornado/blob/master/docs/releases.rst)
- [Commits](tornadoweb/tornado@v6.4.0...v6.4.1)

---
updated-dependencies:
- dependency-name: tornado
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Fix: subthreshold dynamics equation of refractory LIF (#842)

* Fix: subthreshold dynamics equation of refractory LIF

* Fix: RefractoryLIF unit test to test the voltage dynamics

* Bump urllib3 from 2.2.1 to 2.2.2 (#865)

Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.2.1 to 2.2.2.
- [Release notes](https://github.com/urllib3/urllib3/releases)
- [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst)
- [Commits](urllib3/urllib3@2.2.1...2.2.2)

---
updated-dependencies:
- dependency-name: urllib3
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: PhilippPlank <[email protected]>

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: João Gil <[email protected]>
Co-authored-by: PhilippPlank <[email protected]>
Co-authored-by: Marcus G K Williams <[email protected]>
5 people authored Aug 5, 2024
1 parent a82abc1 commit d22e829
Showing 3 changed files with 166 additions and 3 deletions.
40 changes: 39 additions & 1 deletion src/lava/proc/graded/models.py
@@ -11,7 +11,8 @@
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel

-from lava.proc.graded.process import GradedVec, NormVecDelay, InvSqrt
+from lava.proc.graded.process import (GradedVec, GradedReluVec,
+                                      NormVecDelay, InvSqrt)


class AbstractGradedVecModel(PyLoihiProcessModel):
@@ -51,6 +52,43 @@ class PyGradedVecModelFixed(AbstractGradedVecModel):
    exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)


class AbstractGradedReluVecModel(PyLoihiProcessModel):
    """Implementation of GradedReluVec"""

    a_in = None
    s_out = None
    v = None
    vth = None
    exp = None

    def run_spk(self) -> None:
        """The run function that performs the actual computation during
        execution orchestrated by a PyLoihiProcessModel using the
        LoihiProtocol.
        """
        a_in_data = self.a_in.recv()
        self.v += a_in_data

        is_spike = self.v > self.vth
        sp_out = self.v * is_spike

        self.v[:] = 0

        self.s_out.send(sp_out)


@implements(proc=GradedReluVec, protocol=LoihiProtocol)
@requires(CPU)
@tag('fixed_pt')
class PyGradedReluVecModelFixed(AbstractGradedReluVecModel):
    """Fixed point implementation of GradedReluVec."""
    a_in = LavaPyType(PyInPort.VEC_DENSE, np.int32, precision=24)
    s_out = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24)
    vth: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
    v: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
    exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)


@implements(proc=NormVecDelay, protocol=LoihiProtocol)
@requires(CPU)
@tag('fixed_pt')
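Read outside the Lava runtime, the update in run_spk above is a thresholded pass-through with a per-step state reset. Below is a minimal NumPy sketch of one timestep; the function name and variables are illustrative only, not part of the Lava API:

import numpy as np

def graded_relu_step(v, a_in, vth):
    """One GradedReluVec timestep: integrate, threshold, emit, reset."""
    v = v + a_in               # accumulate the incoming activation
    s_out = v * (v > vth)      # pass v through only where it exceeds vth
    v[:] = 0                   # state is cleared every step (no temporal dynamics)
    return v, s_out

v = np.zeros(4, dtype=np.int32)
v, s = graded_relu_step(v, np.array([-3, 0, 2, 7]), vth=1)
print(s)  # [0 0 2 7] -- negative and subthreshold entries are suppressed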
39 changes: 39 additions & 0 deletions src/lava/proc/graded/process.py
@@ -34,6 +34,45 @@ class GradedVec(AbstractProcess):
    Graded spike vector layer. Transmits accumulated input as
    graded spike with no dynamics.

    v[t] = a_in
    s_out = v[t] * (|v[t]| > vth)

    Parameters
    ----------
    shape: tuple(int)
        number and topology of neurons
    vth: int
        threshold for spiking
    exp: int
        fixed point base
    """

    def __init__(
            self,
            shape: ty.Tuple[int, ...],
            vth: ty.Optional[int] = 1,
            exp: ty.Optional[int] = 0) -> None:

        super().__init__(shape=shape)

        self.a_in = InPort(shape=shape)
        self.s_out = OutPort(shape=shape)

        self.v = Var(shape=shape, init=0)
        self.vth = Var(shape=(1,), init=vth)
        self.exp = Var(shape=(1,), init=exp)

    @property
    def shape(self) -> ty.Tuple[int, ...]:
        """Return shape of the Process."""
        return self.proc_params['shape']


class GradedReluVec(AbstractProcess):
    """GradedReluVec
    Graded spike vector layer. Transmits accumulated input as
    graded spike with no dynamics.

    v[t] = a_in
    s_out = v[t] * (v[t] > vth)
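The only behavioral difference from GradedVec is the threshold test: GradedVec gates on |v[t]| > vth and therefore passes large negative values through, while GradedReluVec gates on v[t] > vth and suppresses everything at or below threshold, ReLU-style. A small NumPy comparison with illustrative values (not Lava API code):

import numpy as np

v = np.array([-8, -1, 0, 3, 9])
vth = 1

graded_out = v * (np.abs(v) > vth)  # GradedVec: two-sided gate keeps -8
relu_out = v * (v > vth)            # GradedReluVec: one-sided gate drops -8

print(graded_out)  # [-8  0  0  3  9]
print(relu_out)    # [ 0  0  0  3  9]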
90 changes: 88 additions & 2 deletions tests/lava/proc/graded/test_graded.py
@@ -6,7 +6,8 @@
import numpy as np
from scipy.sparse import csr_matrix

-from lava.proc.graded.process import GradedVec, NormVecDelay, InvSqrt
+from lava.proc.graded.process import (GradedVec, GradedReluVec,
+                                      NormVecDelay, InvSqrt)
from lava.proc.graded.models import inv_sqrt
from lava.proc.dense.process import Dense
from lava.proc.sparse.process import Sparse
@@ -59,7 +60,7 @@ def test_gradedvec_dot_dense(self):
        self.assertTrue(np.all(out_data[:, (3, 7)] == expected_out[:, (2, 6)]))

    def test_gradedvec_dot_sparse(self):
-        """Tests that GradedVec and Dense computes dot product."""
+        """Tests that GradedVec and Sparse compute the dot product."""
        num_steps = 10
        v_thresh = 1

@@ -99,6 +100,91 @@ def test_gradedvec_dot_sparse(self):
        self.assertTrue(np.all(out_data[:, (3, 7)] == expected_out[:, (2, 6)]))


class TestGradedReluVecProc(unittest.TestCase):
    """Tests for GradedReluVec."""

    def test_gradedreluvec_dot_dense(self):
        """Tests that GradedReluVec and Dense compute the dot product."""
        num_steps = 10
        v_thresh = 1

        weights1 = np.zeros((10, 1))
        weights1[:, 0] = (np.arange(10) - 5) * 0.2

        inp_data = np.zeros((weights1.shape[1], num_steps))
        inp_data[:, 2] = 1000
        inp_data[:, 6] = 20000

        weight_exp = 7
        weights1 *= 2**weight_exp
        weights1 = weights1.astype('int')

        dense1 = Dense(weights=weights1, num_message_bits=24,
                       weight_exp=-weight_exp)
        vec1 = GradedReluVec(shape=(weights1.shape[0],),
                             vth=v_thresh)

        generator = io.source.RingBuffer(data=inp_data)
        logger = io.sink.RingBuffer(shape=(weights1.shape[0],),
                                    buffer=num_steps)

        generator.s_out.connect(dense1.s_in)
        dense1.a_out.connect(vec1.a_in)
        vec1.s_out.connect(logger.a_in)

        vec1.run(condition=RunSteps(num_steps=num_steps),
                 run_cfg=Loihi2SimCfg(select_tag='fixed_pt'))
        out_data = logger.data.get().astype('int')
        vec1.stop()

        ww = np.floor(weights1 / 2) * 2
        expected_out = np.floor((ww @ inp_data) / 2**weight_exp)
        expected_out *= expected_out > v_thresh

        self.assertTrue(np.all(out_data[:, (3, 7)] == expected_out[:, (2, 6)]))

    def test_gradedreluvec_dot_sparse(self):
        """Tests that GradedReluVec and Sparse compute the dot product."""
        num_steps = 10
        v_thresh = 1

        weights1 = np.zeros((10, 1))
        weights1[:, 0] = (np.arange(10) - 5) * 0.2

        inp_data = np.zeros((weights1.shape[1], num_steps))
        inp_data[:, 2] = 1000
        inp_data[:, 6] = 20000

        weight_exp = 7
        weights1 *= 2**weight_exp
        weights1 = weights1.astype('int')

        sparse1 = Sparse(weights=csr_matrix(weights1),
                         num_message_bits=24,
                         weight_exp=-weight_exp)
        vec1 = GradedReluVec(shape=(weights1.shape[0],),
                             vth=v_thresh)

        generator = io.source.RingBuffer(data=inp_data)
        logger = io.sink.RingBuffer(shape=(weights1.shape[0],),
                                    buffer=num_steps)

        generator.s_out.connect(sparse1.s_in)
        sparse1.a_out.connect(vec1.a_in)
        vec1.s_out.connect(logger.a_in)

        vec1.run(condition=RunSteps(num_steps=num_steps),
                 run_cfg=Loihi2SimCfg(select_tag='fixed_pt'))
        out_data = logger.data.get().astype('int')
        vec1.stop()

        ww = np.floor(weights1 / 2) * 2
        expected_out = np.floor((ww @ inp_data) / 2**weight_exp)
        expected_out *= expected_out > v_thresh

        self.assertTrue(np.all(out_data[:, (3, 7)] == expected_out[:, (2, 6)]))


class TestInvSqrtProc(unittest.TestCase):
    """Tests for the inverse square root process."""

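A note on the reference computation in the GradedReluVec tests above: ww = np.floor(weights1 / 2) * 2 zeroes the least-significant weight bit before the matmul, which appears to mimic the weight quantization of the fixed-point synapse model, and dividing by 2**weight_exp undoes the scaling applied when the weights were converted to integers. The one-column offset in the assertions (output step 3 vs. input step 2) reflects the one-timestep transit delay through the Dense/Sparse stage. A standalone recomputation of the dense case, using the same constants as the test (the trailing print is illustrative only):

import numpy as np

num_steps = 10
weight_exp = 7
v_thresh = 1

# Same weights and input as the test above.
weights = np.zeros((10, 1))
weights[:, 0] = (np.arange(10) - 5) * 0.2
weights = (weights * 2**weight_exp).astype('int')

inp = np.zeros((1, num_steps))
inp[:, 2] = 1000
inp[:, 6] = 20000

ww = np.floor(weights / 2) * 2             # drop the LSB, matching fixed-pt weights
acc = ww @ inp                             # integer accumulation in the synapse
expected = np.floor(acc / 2**weight_exp)   # undo the weight scaling
expected *= expected > v_thresh            # one-sided ReLU threshold

print(expected[:, (2, 6)])                 # columns the test compares against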
