Skip to content

Commit

Permalink
Resfire (#787)
Browse files Browse the repository at this point in the history
* resfire process and fixed process model

* changed vth->uth in RFZero. Added tests.

* removed unused imports

* unused imports, copyright statement.

* bsd license on resfire models.py
  • Loading branch information
epaxon committed Oct 24, 2023
1 parent 98e16c7 commit 9b7eb1e
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/lava/proc/resfire/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import numpy as np
from lava.proc.resfire.process import RFZero

from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel


@implements(proc=RFZero, protocol=LoihiProtocol)
@requires(CPU)
@tag('fixed_pt')
class PyRFZeroModelFixed(PyLoihiProcessModel):
"""Fixed point implementation of RFZero"""
u_in = LavaPyType(PyInPort.VEC_DENSE, np.int32, precision=24)
v_in = LavaPyType(PyInPort.VEC_DENSE, np.int32, precision=24)
s_out = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24)

uth: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)

u: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
v: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)

lst: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
lct: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)

def run_spk(self) -> None:
u_in = self.u_in.recv()
v_in = self.v_in.recv()

new_u = ((self.u * self.lct) // (2**15)
- (self.v * self.lst) // (2**15) + u_in)

new_v = ((self.v * self.lct) // (2**15)
+ (self.u * self.lst) // (2**15) + v_in)

s_out = new_u * (new_u > self.uth) * (new_v >= 0) * (self.v < 0)

self.u = new_u
self.v = new_v

self.s_out.send(s_out)
65 changes: 65 additions & 0 deletions src/lava/proc/resfire/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (C) 2022-23 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np

import typing as ty

from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort


class RFZero(AbstractProcess):
def __init__(self,
shape: ty.Tuple[int, ...],
freqs: np.ndarray, # Hz
decay_tau: np.ndarray, # seconds
dt: ty.Optional[float] = 0.001, # seconds/timestep
uth: ty.Optional[int] = 1) -> None:
"""
RFZero
Resonate and fire neuron with spike trigger of threshold and
0-phase crossing. Graded spikes carry amplitude of oscillation.
Parameters
----------
shape : tuple(int)
Number and topology of RF neurons.
freqs : numpy.ndarray
Frequency for each neuron (Hz).
decay_tau : numpy.ndarray
Decay time constant (s).
dt : float, optional
Time per timestep. Default is 0.001 seconds.
uth : float, optional
Neuron threshold voltage.
Currently, only a single threshold can be set for the entire
population of neurons.
"""
super().__init__(shape=shape)

self.u_in = InPort(shape=shape)
self.v_in = InPort(shape=shape)

self.s_out = OutPort(shape=shape)

self.u = Var(shape=shape, init=0)
self.v = Var(shape=shape, init=0)

ll = -1 / decay_tau

lct = np.array(np.exp(dt * ll) * np.cos(dt * freqs * np.pi * 2))
lst = np.array(np.exp(dt * ll) * np.sin(dt * freqs * np.pi * 2))
lct = (lct * 2**15).astype(np.int32)
lst = (lst * 2**15).astype(np.int32)

self.lct = Var(shape=shape, init=lct)
self.lst = Var(shape=shape, init=lst)
self.uth = Var(shape=(1,), init=uth)

@property
def shape(self) -> ty.Tuple[int, ...]:
"""Return shape of the Process."""
return self.proc_params['shape']
Empty file.
84 changes: 84 additions & 0 deletions tests/lava/proc/resfire/test_resfire.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import unittest
import numpy as np

from lava.proc.resfire.process import RFZero
from lava.proc.dense.process import Dense
from lava.proc import io

from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi2SimCfg


class TestRFZeroProc(unittest.TestCase):
"""Tests for RFZero"""

def test_rfzero_impulse(self):
"""Tests for correct behavior of RFZero neurons from impulse input"""

num_steps = 50
num_neurons = 4
num_inputs = 1

# Create some weights
weightr = np.zeros((num_neurons, num_inputs))
weighti = np.zeros((num_neurons, num_inputs))

weightr[0, 0] = 50
weighti[0, 0] = -50

weightr[1, 0] = -70
weighti[1, 0] = -70

weightr[2, 0] = -90
weighti[2, 0] = 90

weightr[3, 0] = 110
weighti[3, 0] = 110

# Create inputs
inp_shape = (num_inputs,)
out_shape = (num_neurons,)

inp_data = np.zeros((inp_shape[0], num_steps))
inp_data[:, 3] = 10

# Create the procs
denser = Dense(weights=weightr, num_message_bits=24)
densei = Dense(weights=weighti, num_message_bits=24)

vec = RFZero(shape=out_shape, uth=1,
decay_tau=0.1, freqs=20)

generator1 = io.source.RingBuffer(data=inp_data)
generator2 = io.source.RingBuffer(data=inp_data)
logger = io.sink.RingBuffer(shape=out_shape, buffer=num_steps)

# Connect the procs
generator1.s_out.connect(denser.s_in)
generator2.s_out.connect(densei.s_in)

denser.a_out.connect(vec.u_in)
densei.a_out.connect(vec.v_in)

vec.s_out.connect(logger.a_in)

# Run
try:
vec.run(condition=RunSteps(num_steps=num_steps),
run_cfg=Loihi2SimCfg())
out_data = logger.data.get().astype(np.int32)
finally:
vec.stop()

expected_out = np.array([661, 833, 932, 1007])

self.assertTrue(
np.all(expected_out == out_data[[0, 1, 2, 3], [11, 23, 36, 48]]))


if __name__ == '__main__':
unittest.main()

0 comments on commit 9b7eb1e

Please sign in to comment.