LeakyParallel Neuron Implementation #266

Merged (10 commits) on Nov 19, 2023
9 changes: 9 additions & 0 deletions docs/snn.neurons_leakyparallel.rst
@@ -0,0 +1,9 @@
===========================
snn.LeakyParallel
===========================


.. automodule:: snntorch._neurons.leakyparallel
:members:
:undoc-members:
:show-inheritance:
5 changes: 4 additions & 1 deletion docs/snntorch.rst
@@ -25,6 +25,9 @@ At present, the neurons available in :mod:`snntorch` are variants of the Leaky I
* **Lapicque** - Lapicque's RC Neuron Model
* **Alpha** - Alpha Membrane Model

Neuron models that accelerate training by processing all time steps in parallel are also available (see the example below):

* **LeakyParallel** - 1st Order Leaky Integrate-and-Fire Neuron
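
A minimal usage sketch (all time steps are passed in at once along the leading axis; the sizes and variable names below are illustrative)::

    import torch
    import snntorch as snn

    lif = snn.LeakyParallel(input_size=784, hidden_size=128)
    x = torch.rand(100, 32, 784)  # (num_steps, batch_size, num_inputs)
    spk = lif(x)                  # output spikes for every time step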

Additional models include spiking-LSTMs and spiking-ConvLSTMs:

* **SLSTM** - Spiking long short-term memory cell with state-thresholding
@@ -35,7 +38,7 @@ Additional models include spiking-LSTMs and spiking-ConvLSTMs:
How to use snnTorch's neuron models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The following arguments are common across all neuron models:
The following arguments are common across most neuron models:

* **threshold** - firing threshold of the neuron
* **spike_grad** - surrogate gradient function (see :mod:`snntorch.surrogate`)
3 changes: 2 additions & 1 deletion snntorch/_neurons/__init__.py
@@ -12,6 +12,7 @@
"alpha",
"lapicque",
"leaky",
"leakyparallel",
"rleaky",
"rsynaptic",
"synaptic",
@@ -32,4 +33,4 @@
from .sconv2dlstm import SConv2dLSTM
from .slstm import SLSTM

# from .slstm import SLSTM
from .leakyparallel import LeakyParallel
306 changes: 306 additions & 0 deletions snntorch/_neurons/leakyparallel.py
@@ -0,0 +1,306 @@
import torch
import torch.nn as nn

class LeakyParallel(nn.Module):
"""
A parallel implementation of the Leaky neuron with a fused input linear layer.
All time steps are passed to the input at once.
This implementation uses `torch.nn.RNN` to accelerate computation across time steps.

First-order leaky integrate-and-fire neuron model.
Input is assumed to be a current injection.
Membrane potential decays exponentially with rate beta.
For :math:`U[T] > U_{\\rm thr} ⇒ S[T+1] = 1`.

.. math::

U[t+1] = βU[t] + I_{\\rm in}[t+1]


* :math:`I_{\\rm in}` - Input current
* :math:`U` - Membrane potential
* :math:`U_{\\rm thr}` - Membrane threshold
* :math:`β` - Membrane potential decay rate

Several differences between `LeakyParallel` and `Leaky` include:

* Negative hidden states are clipped due to the forced ReLU operation in RNN
* Linear weights are included in addition to recurrent weights
* `beta` is clipped to [0, 1] and cloned to `weight_hh_l` only upon layer initialization; it is unused otherwise
* There is no explicit reset mechanism
* Several functions such as `init_hidden`, `output`, `inhibition`, and `state_quant` are unavailable in `LeakyParallel`
* Only the output spike is returned. Membrane potential is not accessible by default
* The RNN uses a hidden matrix of size (num_hidden, num_hidden) to transform the hidden state vector. This would 'leak' membrane potential between LIF neurons, so the hidden matrix is forced to be diagonal by default. This can be disabled by setting `weight_hh_enable=True` (see the sketch after the example below).

Example::

import torch
import torch.nn as nn
import snntorch as snn

beta = 0.5
num_inputs = 784
num_hidden = 128
num_outputs = 10
batch_size = 128
num_steps = 100  # arbitrary number of time steps for this example
x = torch.rand((num_steps, batch_size, num_inputs))

# Define Network
class Net(nn.Module):
def __init__(self):
super().__init__()

# initialize layers
self.lif1 = snn.LeakyParallel(input_size=num_inputs, hidden_size=num_hidden) # randomly initialize recurrent weights
self.lif2 = snn.LeakyParallel(input_size=num_hidden, hidden_size=num_outputs, beta=beta, learn_beta=True) # learnable recurrent weights initialized at beta

def forward(self, x):
spk1 = self.lif1(x)
spk2 = self.lif2(spk1)
return spk2
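
The diagonal constraint on the hidden weights and the `beta` mapping can be
checked directly. A small sketch, following on from the example above
(`beta=0.5` and the layer sizes are illustrative)::

    lif = snn.LeakyParallel(input_size=784, hidden_size=128, beta=0.5)
    w_hh = lif.rnn.weight_hh_l0.detach()

    # off-diagonal entries are zeroed out by default (weight_hh_enable=False)
    assert torch.equal(w_hh, torch.diag(torch.diagonal(w_hh)))

    # each diagonal entry holds the (clipped) beta value
    print(torch.diagonal(w_hh)[:3])  # tensor([0.5000, 0.5000, 0.5000])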


:param input_size: The number of expected features in the input `x`
:type input_size: int

:param hidden_size: The number of features in the hidden state `h`
:type hidden_size: int

:param beta: membrane potential decay rate. Clipped between 0 and 1
upon layer initialization. May be a single-valued tensor (i.e., equal
decay rate for all neurons in a layer), or multi-valued (one weight per
neuron). If left unspecified, then the decay rates will be randomly initialized based on PyTorch's initialization for RNN. Defaults to None
:type beta: float or torch.tensor, optional

:param bias: If `False`, then the layer does not use bias weights `b_ih` and `b_hh`. Defaults to True
:type bias: bool, optional

:param threshold: Threshold for :math:`mem` to reach in order to
generate a spike `S=1`. Defaults to 1
:type threshold: float, optional

:param dropout: If non-zero, introduces a Dropout layer on the RNN output with dropout probability equal to dropout. Defaults to 0
:type dropout: float, optional

:param spike_grad: Surrogate gradient for the term dS/dU. Defaults to
None, which corresponds to the ATan surrogate gradient (see
`snntorch.surrogate` for more options)
:type spike_grad: surrogate gradient function from snntorch.surrogate,
optional

:param surrogate_disable: Disables surrogate gradients regardless of
`spike_grad` argument. Useful for ONNX compatibility. Defaults
to False
:type surrogate_disable: bool, optional

:param learn_beta: Option to enable learnable beta. Defaults to False
:type learn_beta: bool, optional

:param learn_threshold: Option to enable learnable threshold. Defaults
to False
:type learn_threshold: bool, optional

:param weight_hh_enable: Option to set the hidden matrix to be dense or
diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works.
Dense (True) would allow the membrane potential of one LIF neuron to
influence all others, following the default RNN implementation. Defaults to False
:type weight_hh_enable: bool, optional


Inputs: \\input_
- **input_** of shape `(L, H_{in})` for unbatched input,
or `(L, N, H_{in})` for batched input, containing the features of the input sequence.

Outputs: spk
- **spk** of shape `(L, N, H_{out})`: tensor containing the
output spikes.

where:

`L = sequence length`

`N = batch size`

`H_{in} = input_size`

`H_{out} = hidden_size`

Learnable Parameters:
- **rnn.weight_ih_l** (torch.Tensor) - the learnable input-hidden weights, of shape (hidden_size, input_size)
- **rnn.weight_hh_l** (torch.Tensor) - the learnable hidden-hidden weights, initialized from `beta`, of shape (hidden_size, hidden_size)
- **bias_ih_l** - the learnable input-hidden bias, of shape (hidden_size)
- **bias_hh_l** - the learnable hidden-hidden bias, of shape (hidden_size)
- **threshold** (torch.Tensor) - optional learnable threshold; must be manually passed in, of shape `1` or `(hidden_size)`
- **graded_spikes_factor** (torch.Tensor) - optional learnable graded spike factor
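
A quick shape check, following the notation above (sizes are illustrative)::

    lif = snn.LeakyParallel(input_size=784, hidden_size=128)
    x = torch.rand(100, 16, 784)   # (L, N, H_in)
    spk = lif(x)
    print(spk.shape)               # torch.Size([100, 16, 128]) -> (L, N, H_out)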

"""

def __init__(
self,
input_size,
hidden_size,
beta=None,
bias=True,
threshold=1.0,
dropout=0.0,
spike_grad=None,
surrogate_disable=False,
learn_beta=False,
learn_threshold=False,
graded_spikes_factor=1.0,
learn_graded_spikes_factor=False,
weight_hh_enable=False,
device=None,
dtype=None,
):
super(LeakyParallel, self).__init__()

self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu',
bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype)

self._beta_buffer(beta, learn_beta)
self.hidden_size = hidden_size

if self.beta is not None:
self.beta = self.beta.clamp(0, 1)

if spike_grad is None:
self.spike_grad = self.ATan.apply
else:
self.spike_grad = spike_grad

self._beta_to_weight_hh()
if weight_hh_enable is False:
# Initial gradient and weights of w_hh are made diagonal
self.weight_hh_enable()
# Register a gradient hook to clamp out non-diagonal matrices in backward pass
if learn_beta:
self.rnn.weight_hh_l0.register_hook(self.grad_hook)

if not learn_beta:
# Make the weights non-learnable
self.rnn.weight_hh_l0.requires_grad_(False)

self._threshold_buffer(threshold, learn_threshold)
self._graded_spikes_buffer(
graded_spikes_factor, learn_graded_spikes_factor
)

self.surrogate_disable = surrogate_disable
if self.surrogate_disable:
self.spike_grad = self._surrogate_bypass

def forward(self, input_):
mem = self.rnn(input_)
# mem[0] holds the sequence of (ReLU'd) hidden states, i.e. membrane potentials;
# mem[1] holds the final hidden state, which is not needed here
mem_shift = mem[0] - self.threshold  # shift by threshold before spike generation
spk = self.spike_grad(mem_shift)
spk = spk * self.graded_spikes_factor
return spk

@staticmethod
def _surrogate_bypass(input_):
return (input_ > 0).float()

@staticmethod
class ATan(torch.autograd.Function):
"""
Surrogate gradient of the Heaviside step function.

**Forward pass:** Heaviside step function shifted.

.. math::

S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
0 & \\text{if U < U$_{\\rm thr}$}
\\end{cases}

**Backward pass:** Gradient of shifted arc-tan function.

.. math::

S&≈\\frac{1}{π}\\text{arctan}(πU \\frac{α}{2}) \\\\
\\frac{∂S}{∂U}&=\\frac{1}{π}\
\\frac{1}{(1+(πU\\frac{α}{2})^2)}


:math:`alpha` defaults to 2, and can be modified by calling
``surrogate.atan(alpha=2)``.

Adapted from:

*W. Fang, Z. Yu, Y. Chen, T. Masquelier, T. Huang, Y. Tian (2021)
Incorporating Learnable Membrane Time Constants to Enhance Learning
of Spiking Neural Networks. Proc. IEEE/CVF Int. Conf. Computer
Vision (ICCV), pp. 2661-2671.*"""

@staticmethod
def forward(ctx, input_, alpha=2.0):
ctx.save_for_backward(input_)
ctx.alpha = alpha
out = (input_ > 0).float()
return out

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()
grad = (
ctx.alpha
/ 2
/ (1 + (torch.pi / 2 * ctx.alpha * input_).pow_(2))
* grad_input
)
return grad, None

def weight_hh_enable(self):
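# Mask the hidden-to-hidden weights to a diagonal matrix so that each
# neuron's membrane potential decays only into itself rather than
# leaking into other neurons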
mask = torch.eye(self.hidden_size, self.hidden_size)
self.rnn.weight_hh_l0.data = self.rnn.weight_hh_l0.data * mask

def grad_hook(self, grad):
device = grad.device
# Create a mask that is 1 on the diagonal and 0 elsewhere
mask = torch.eye(self.hidden_size, self.hidden_size, device=device)
# Use the mask to zero out non-diagonal elements of the gradient
return grad * mask

def _beta_to_weight_hh(self):
with torch.no_grad():
if self.beta is not None:
# Set all weights to the scalar value of self.beta
if isinstance(self.beta, float) or isinstance(self.beta, int):
self.rnn.weight_hh_l0.fill_(self.beta)
elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor):
if len(self.beta) == 1:
self.rnn.weight_hh_l0.fill_(self.beta[0])
elif len(self.beta) == self.hidden_size:
# Replace each value with the corresponding value in self.beta
for i in range(self.hidden_size):
self.rnn.weight_hh_l0.data[i].fill_(self.beta[i])
else:
raise ValueError("Beta must be either a single value or of length 'hidden_size'.")

def _beta_buffer(self, beta, learn_beta):
if not isinstance(beta, torch.Tensor):
if beta is not None:
beta = torch.as_tensor([beta]) # TODO: or .tensor() if no copy
self.register_buffer("beta", beta)

def _graded_spikes_buffer(
self, graded_spikes_factor, learn_graded_spikes_factor
):
if not isinstance(graded_spikes_factor, torch.Tensor):
graded_spikes_factor = torch.as_tensor(graded_spikes_factor)
if learn_graded_spikes_factor:
self.graded_spikes_factor = nn.Parameter(graded_spikes_factor)
else:
self.register_buffer("graded_spikes_factor", graded_spikes_factor)

def _threshold_buffer(self, threshold, learn_threshold):
if not isinstance(threshold, torch.Tensor):
threshold = torch.as_tensor(threshold)
if learn_threshold:
self.threshold = nn.Parameter(threshold)
else:
self.register_buffer("threshold", threshold)