Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plateau Neuron - Fixed Point Model #781

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
0263677
Merge pull request #6 from lava-nc/main
kds300 Aug 16, 2023
d70b6e7
plateau- fixed point implementation of Plateau neuron model
kds300 Sep 1, 2023
3f649d1
Merge branch 'main' into plateau-neuron
kds300 Sep 5, 2023
6647107
Merge branch 'main' into plateau-neuron
mgkwill Oct 18, 2023
5db90d7
Fix codacy issues
Oct 18, 2023
bab1fd1
fixed typo 'bgy' -> 'by' in PyPlateauModelFixed.run_spk docstring
kds300 Oct 19, 2023
5d8b10c
changed 'potential' to 'voltage' in the Plateau class docstring
kds300 Oct 19, 2023
1961780
changed variable types from float to int for dv_dend, dv_soma, vth_de…
kds300 Oct 19, 2023
4be4918
updated Plateau process unit test to specify integer inputs for dv_de…
kds300 Oct 19, 2023
70433d4
added input validation for dv_dend, dv_soma, vth_dend, vth_soma, up_dur
kds300 Oct 19, 2023
4b26480
fixed typos 'du_dend' -> 'dv_dend' and 'du_soma' -> 'dv_soma' in PyPl…
kds300 Oct 19, 2023
fb217ba
reduced decay_unity by 1 (4096 -> 4095) to agree with limits on dv_de…
kds300 Oct 19, 2023
d1c66f0
updated tests to agree with changes to dv_dend, dv_soma allowed range
kds300 Oct 19, 2023
c469bd8
Removed custom SpikeGen process and replaced with source RingBuffers
kds300 Oct 19, 2023
e8a8215
Merge branch 'plateau-neuron' into plateau-neuron
kds300 Oct 19, 2023
6ccbf74
Merge pull request #7 from mgkwill/plateau-neuron
kds300 Oct 19, 2023
cea6296
fixed codacy issues
kds300 Oct 22, 2023
6e3744d
Merge branch 'plateau-neuron' of https://github.com/kds300/lava into …
kds300 Oct 22, 2023
0c6bf24
Merge branch 'main' into plateau-neuron
mgkwill Nov 14, 2023
d2aed48
Merge branch 'main' into plateau-neuron
mgkwill Aug 8, 2024
6dd7e3b
Merge branch 'lava-nc:main' into plateau-neuron
kds300 Nov 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions src/lava/proc/plateau/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (C) 2023 Intel Corporation
mgkwill marked this conversation as resolved.
Show resolved Hide resolved
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import numpy as np
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.proc.plateau.process import Plateau


@implements(proc=Plateau, protocol=LoihiProtocol)
@requires(CPU)
@tag("fixed_pt")
class PyPlateauModelFixed(PyLoihiProcessModel):
""" Implementation of Plateau neuron process in fixed point precision.

Precisions of state variables

- dv_dend : unsigned 12-bit integer (0 to 4095)
- dv_soma : unsigned 12-bit integer (0 to 4095)
- vth_dend : unsigned 17-bit integer (0 to 131071)
- vth_soma : unsigned 17-bit integer (0 to 131071)
kds300 marked this conversation as resolved.
Show resolved Hide resolved
- up_dur : unsigned 8-bit integer (0 to 255)
"""

a_dend_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16)
a_soma_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, np.int16, precision=16)
s_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=24)
v_dend: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
v_soma: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=24)
dv_dend: int = LavaPyType(int, np.uint16, precision=12)
dv_soma: int = LavaPyType(int, np.uint16, precision=12)
vth_dend: int = LavaPyType(int, np.int32, precision=17)
vth_soma: int = LavaPyType(int, np.int32, precision=17)
up_dur: int = LavaPyType(int, np.uint16, precision=8)
up_state: int = LavaPyType(np.ndarray, np.uint16, precision=8)

def __init__(self, proc_params):
super(PyPlateauModelFixed, self).__init__(proc_params)
self._validate_inputs(proc_params)
self.uv_bitwidth = 24
self.max_uv_val = 2 ** (self.uv_bitwidth - 1)
self.decay_shift = 12
self.decay_unity = 2 ** self.decay_shift - 1
self.vth_shift = 6
self.act_shift = 6
self.isthrscaled = False
self.effective_vth_dend = None
self.effective_vth_soma = None
self.s_out_buff = None

def _validate_var(self, var, var_type, min_val, max_val, var_name):
if not isinstance(var, var_type):
raise ValueError(f"'{var_name}' must have type {var_type}")
if var < min_val or var > max_val:
raise ValueError(
f"'{var_name}' must be in range [{min_val}, {max_val}]"
)

def _validate_inputs(self, proc_params):
self._validate_var(proc_params['dv_dend'], int, 0, 4095, 'dv_dend')
self._validate_var(proc_params['dv_soma'], int, 0, 4095, 'dv_soma')
self._validate_var(proc_params['vth_dend'], int, 0, 131071, 'vth_dend')
self._validate_var(proc_params['vth_soma'], int, 0, 131071, 'vth_soma')
self._validate_var(proc_params['up_dur'], int, 0, 255, 'up_dur')

def scale_threshold(self):
self.effective_vth_dend = np.left_shift(self.vth_dend, self.vth_shift)
self.effective_vth_soma = np.left_shift(self.vth_soma, self.vth_shift)
self.isthrscaled = True

def subthr_dynamics(
self,
activation_dend_in: np.ndarray,
activation_soma_in: np.ndarray
):
"""Run the sub-threshold dynamics for both the dendrite and soma of the
neuron. Both use 'leaky integration'.
"""
for v, dv, a_in in [
(self.v_dend, self.dv_dend, activation_dend_in),
(self.v_soma, self.dv_soma, activation_soma_in),
]:
decayed_volt = np.int64(v) * (self.decay_unity - dv)
decayed_volt = np.sign(decayed_volt) * np.right_shift(
np.abs(decayed_volt), 12
)
decayed_volt = np.int32(decayed_volt)
updated_volt = decayed_volt + np.left_shift(a_in, self.act_shift)

neg_voltage_limit = -np.int32(self.max_uv_val) + 1
pos_voltage_limit = np.int32(self.max_uv_val) - 1

v[:] = np.clip(
updated_volt, neg_voltage_limit, pos_voltage_limit
)

def update_up_state(self):
"""Decrements the up state (if necessary) and checks v_dend to see if
up state needs to be (re)set. If up state is (re)set, then v_dend is
reset to 0.
"""
self.up_state[self.up_state > 0] -= 1
self.up_state[self.v_dend > self.effective_vth_dend] = self.up_dur
self.v_dend[self.v_dend > self.effective_vth_dend] = 0

def soma_spike_and_reset(self):
"""Check the spiking conditions for the plateau soma. Checks if:
v_soma > v_th_soma
up_state > 0

For any neurons n that satisfy both conditions, sets:
s_out_buff[n] = True
v_soma = 0
"""
s_out_buff = np.logical_and(
self.v_soma > self.effective_vth_soma,
self.up_state > 0
)
self.v_soma[s_out_buff] = 0

return s_out_buff

def run_spk(self):
"""The run function that performs the actual computation during
execution orchestrated by a PyLoihiProcessModel using the
LoihiProtocol.
"""

# Receive synaptic input
a_dend_in_data = self.a_dend_in.recv()
a_soma_in_data = self.a_soma_in.recv()

# Check threshold scaling
if not self.isthrscaled:
self.scale_threshold()

self.subthr_dynamics(a_dend_in_data, a_soma_in_data)

self.update_up_state()

self.s_out_buff = self.soma_spike_and_reset()

self.s_out.send(self.s_out_buff)
70 changes: 70 additions & 0 deletions src/lava/proc/plateau/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import typing as ty
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort


class Plateau(AbstractProcess):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a minor point. Since this is a LIF neuron, did you consider to let it inherit from the AbstractLIF process, and adding a 'LIF' in the class name? Not sure if it makes sense in this specific example, though. Up to you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally didn't do this since I thought of it as two combined LIF neurons instead of a modified LIF neuron. I'll look over the AbstractLIF process and see if it makes sense to inherit from that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the class should inherit from AbstrictLIF, since it does not have current or bias vars.

"""Plateau Neuron Process.

Couples two modified LIF dynamics. The neuron posesses two voltages,
v_dend and v_soma. Both follow sub-threshold LIF dynamics. When v_dend
crosses v_th_dend, it resets and sets the up_state to the value up_dur.
The supra-threshold behavior of v_soma depends on up_state:
if up_state == 0:
v_soma follows sub-threshold dynamics
if up_state > 0:
v_soma resets and the neuron sends out a spike

Parameters
----------
shape : tuple(int)
Number and topology of Plateau neurons.
dv_dend : int
Inverse of the decay time-constant for the dendrite voltage.
dv_soma : int
Inverse of the decay time-constant for the soma voltage.
vth_dend : int
Dendrite threshold voltage, exceeding which, the neuron will enter the
UP state.
vth_soma : int
Soma threshold voltage, exceeding which, the neuron will spike if it is
also in the UP state.
up_dur : int
The duration, in timesteps, of the UP state.
"""
def __init__(
self,
shape: ty.Tuple[int, ...],
dv_dend: int,
dv_soma: int,
vth_dend: int,
vth_soma: int,
up_dur: int,
name: ty.Optional[str] = None,
):
super().__init__(
shape=shape,
dv_dend=dv_dend,
dv_soma=dv_soma,
name=name,
up_dur=up_dur,
vth_dend=vth_dend,
vth_soma=vth_soma
)
self.a_dend_in = InPort(shape=shape)
self.a_soma_in = InPort(shape=shape)
self.s_out = OutPort(shape=shape)
self.v_dend = Var(shape=shape, init=0)
self.v_soma = Var(shape=shape, init=0)
self.dv_dend = Var(shape=(1,), init=dv_dend)
self.dv_soma = Var(shape=(1,), init=dv_soma)
self.vth_dend = Var(shape=(1,), init=vth_dend)
self.vth_soma = Var(shape=(1,), init=vth_soma)
self.up_dur = Var(shape=(1,), init=up_dur)
self.up_state = Var(shape=shape, init=0)
137 changes: 137 additions & 0 deletions tests/lava/proc/plateau/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (C) 2023 Intel Corporation
kds300 marked this conversation as resolved.
Show resolved Hide resolved
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import unittest
import numpy as np
from lava.proc.plateau.process import Plateau
from lava.proc.dense.process import Dense
from lava.proc.io.source import RingBuffer as Source
from lava.magma.core.run_configs import Loihi2SimCfg
from lava.magma.core.run_conditions import RunSteps
from lava.tests.lava.proc.lif.test_models import VecRecvProcess


def create_spike_source(spike_list, n_indices, n_timesteps):
"""Use list of spikes [(idx, timestep), ...] to create a RingBuffer source
with data shape (n_indices, n_timesteps) and spikes at all specified points
in the spike_list.
"""
data = np.zeros(shape=(n_indices, n_timesteps))
for idx, timestep in spike_list:
data[idx, timestep - 1] = 1
return Source(data=data)


class TestPlateauProcessModelsFixed(unittest.TestCase):
"""Tests for the fixed point Plateau process models."""
def test_fixed_max_decay(self):
"""
Tests fixed point Plateau with max voltage decays.
"""
shape = (3,)
num_steps = 20
spikes_in_dend = [(0, 5), (1, 5), (2, 5)]
spikes_in_soma = [(0, 3), (1, 10), (2, 17)]
sg_dend = create_spike_source(spikes_in_dend, shape[0], num_steps)
sg_soma = create_spike_source(spikes_in_soma, shape[0], num_steps)
dense_dend = Dense(weights=2 * np.diag(np.ones(shape=shape)))
dense_soma = Dense(weights=2 * np.diag(np.ones(shape=shape)))
plat = Plateau(
shape=shape,
dv_dend=4095,
dv_soma=4095,
vth_soma=1,
vth_dend=1,
up_dur=10
)
vr = VecRecvProcess(shape=(num_steps, shape[0]))
sg_dend.s_out.connect(dense_dend.s_in)
sg_soma.s_out.connect(dense_soma.s_in)
dense_dend.a_out.connect(plat.a_dend_in)
dense_soma.a_out.connect(plat.a_soma_in)
plat.s_out.connect(vr.s_in)
# run model
plat.run(RunSteps(num_steps), Loihi2SimCfg(select_tag='fixed_pt'))
test_spk_data = vr.spk_data.get()
plat.stop()
# Gold standard for the test
expected_spk_data = np.zeros((num_steps, shape[0]))
# Neuron 2 should spike when receiving soma input
expected_spk_data[10, 1] = 1
self.assertTrue(np.all(expected_spk_data == test_spk_data))

def test_up_dur(self):
"""
Tests that the UP state lasts for the time specified by the model.
Checks that up_state decreases by one each time step after activation.
"""
shape = (1,)
num_steps = 10
spikes_in_dend = [(0, 3)]
sg_dend = create_spike_source(spikes_in_dend, shape[0], num_steps)
dense_dend = Dense(weights=2 * (np.diag(np.ones(shape=shape))))
plat = Plateau(
shape=shape,
dv_dend=4095,
dv_soma=4095,
vth_soma=1,
vth_dend=1,
up_dur=5
)
sg_dend.s_out.connect(dense_dend.s_in)
dense_dend.a_out.connect(plat.a_dend_in)
# run model
test_up_state = []
for _ in range(num_steps):
plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt'))
test_up_state.append(plat.up_state.get().astype(int)[0])
plat.stop()
# Gold standard for the test
# UP state active time steps 4 - 9 (5 timesteps)
# this is delayed by one b.c. of the Dense process
expected_up_state = [0, 0, 0, 5, 4, 3, 2, 1, 0, 0]
self.assertListEqual(expected_up_state, test_up_state)

def test_fixed_dvs(self):
"""
Tests fixed point Plateau voltage decays.
"""
shape = (1,)
num_steps = 10
spikes_in = [(0, 1)]
sg_dend = create_spike_source(spikes_in, shape[0], num_steps)
sg_soma = create_spike_source(spikes_in, shape[0], num_steps)
dense_dend = Dense(weights=100 * np.diag(np.ones(shape=shape)))
dense_soma = Dense(weights=100 * np.diag(np.ones(shape=shape)))
plat = Plateau(
shape=shape,
dv_dend=2048,
dv_soma=1024,
vth_soma=100,
vth_dend=100,
up_dur=10
)
sg_dend.s_out.connect(dense_dend.s_in)
sg_soma.s_out.connect(dense_soma.s_in)
dense_dend.a_out.connect(plat.a_dend_in)
dense_soma.a_out.connect(plat.a_soma_in)
# run model
test_v_dend = []
test_v_soma = []
for _ in range(num_steps):
plat.run(RunSteps(1), Loihi2SimCfg(select_tag='fixed_pt'))
test_v_dend.append(plat.v_dend.get().astype(int)[0])
test_v_soma.append(plat.v_soma.get().astype(int)[0])
plat.stop()
# Gold standard for the test
# 100<<6 = 6400 -- initial value at time step 2
expected_v_dend = [
0, 6400, 3198, 1598, 798, 398, 198, 98, 48, 23
]
expected_v_soma = [
0, 6400, 4798, 3597, 2696, 2021, 1515, 1135, 850, 637
]
self.assertListEqual(expected_v_dend, test_v_dend)
self.assertListEqual(expected_v_soma, test_v_soma)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check if lintin passes?
'flakeheaven lint src/lava tests'

My guess is that there are a few points that must change, including missing lines at the end of files. Not functionally relevant, but important to keep a clean code base :-)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've run the linting, (flakeheaven and bandit) and they pass on my local copy of the code. Also, I have the empty line at the end of the files locally, but it doesn't seem to show up on the github versions. Does github just not show the empty line at the end?

29 changes: 29 additions & 0 deletions tests/lava/proc/plateau/test_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/


import unittest
from lava.proc.plateau.process import Plateau


class TestPlateauProcess(unittest.TestCase):
"""Tests for Plateau class"""
def test_init(self):
"""Tests instantiation of Plateau"""
plat = Plateau(
shape=(100,),
dv_dend=100,
dv_soma=1,
vth_dend=10,
vth_soma=1,
up_dur=10,
name="Plat"
)

self.assertEqual(plat.name, "Plat")
self.assertEqual(plat.dv_dend.init, 100)
self.assertEqual(plat.dv_soma.init, 1)
self.assertEqual(plat.vth_dend.init, 10)
self.assertEqual(plat.vth_soma.init, 1)
self.assertEqual(plat.up_dur.init, 10)