From ee78e3945161cfae7da8f5ca88f016ec3dcfd188 Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 18 Nov 2023 18:15:01 -0800 Subject: [PATCH 01/10] add LeakyParallel neuron --- snntorch/_neurons/__init__.py | 3 +- snntorch/_neurons/leakyparallel.py | 267 +++++++++++++++++++++++++++++ 2 files changed, 269 insertions(+), 1 deletion(-) create mode 100644 snntorch/_neurons/leakyparallel.py diff --git a/snntorch/_neurons/__init__.py b/snntorch/_neurons/__init__.py index 2d84b8bb..c2ec2ed4 100644 --- a/snntorch/_neurons/__init__.py +++ b/snntorch/_neurons/__init__.py @@ -12,6 +12,7 @@ "alpha", "lapicque", "leaky", + "leakyparallel", "rleaky", "rsynaptic", "synaptic", @@ -32,4 +33,4 @@ from .sconv2dlstm import SConv2dLSTM from .slstm import SLSTM -# from .slstm import SLSTM +from .leakyparallel import LeakyParallel \ No newline at end of file diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py new file mode 100644 index 00000000..37e23f83 --- /dev/null +++ b/snntorch/_neurons/leakyparallel.py @@ -0,0 +1,267 @@ +from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +import torch +import torch.nn as nn + +class LeakyParallel(nn.Module): + """ + A parallel implementation of the Leaky neuron with an input linear layer. + All time steps are passed to the input at once. + + First-order leaky integrate-and-fire neuron model. + Input is assumed to be a current injection. + Membrane potential decays exponentially with rate beta. + For :math:`U[T] > U_{\\rm thr} ⇒ S[T+1] = 1`. + + .. math:: + + U[t+1] = βU[t] + I_{\\rm in}[t+1] + + + * :math:`I_{\\rm in}` - Input current + * :math:`U` - Membrane potential + * :math:`U_{\\rm thr}` - Membrane threshold + * :math:`β` - Membrane potential decay rate + + Example:: + + import torch + import torch.nn as nn + import snntorch as snn + + beta = 0.5 + + # Define Network + class Net(nn.Module): + def __init__(self): + super().__init__() + + # initialize layers + self.lif1 = snn.ParallelLeaky(input_size=784, hidden_size=128) + self.lif2 = snn.ParallelLeaky(input_size=128, hidden_size=10, beta=beta) + + def forward(self, x): + spk1 = self.lif1(x) + spk2 = self.lif2(spk1) + return spk2 + + + :param input_size: The number of expected features in the input `x` + :type input_size: int + + :param hidden_size: The number of features in the hidden state `h` + :type hidden_size: int + + :param beta: membrane potential decay rate. Clipped between 0 and 1 + during the forward-pass. May be a single-valued tensor (i.e., equal + decay rate for all neurons in a layer), or multi-valued (one weight per + neuron). If left unspecified, then the decay rates will be randomly initialized based on PyTorch's initialization for RNN. Defaults to None + :type beta: float or torch.tensor, optional + + :param bias: If `False`, then the layer does not use bias weights `b_ih` and `b_hh`. Defaults to True + :type bias: Bool, optional + + :param threshold: Threshold for :math:`mem` to reach in order to + generate a spike `S=1`. Defaults to 1 + :type threshold: float, optional + + :param dropout: If non-zero, introduces a Dropout layer on the RNN output with dropout probability equal to dropout. Defaults to 0 + :type dropout: float, optional + + :param spike_grad: Surrogate gradient for the term dS/dU. Defaults to + None (corresponds to ATan surrogate gradient. 
See + `snntorch.surrogate` for more options) + :type spike_grad: surrogate gradient function from snntorch.surrogate, + optional + + :param surrogate_disable: Disables surrogate gradients regardless of + `spike_grad` argument. Useful for ONNX compatibility. Defaults + to False + :type surrogate_disable: bool, Optional + + :param init_hidden: Instantiates state variables as instance variables. + Defaults to False + :type init_hidden: bool, optional + + :param inhibition: If `True`, suppresses all spiking other than the + neuron with the highest state. Defaults to False + :type inhibition: bool, optional + + :param learn_beta: Option to enable learnable beta. Defaults to False + :type learn_beta: bool, optional + + :param learn_threshold: Option to enable learnable threshold. Defaults + to False + :type learn_threshold: bool, optional + + + + Inputs: \\input_ + - **input_** of shape of shape `(L, H_{in})` for unbatched input, + or `(L, N, H_{in})` containing the features of the input sequence. + + Outputs: spk + - **spk** of shape `(L, batch, input_size)`: tensor containing the + output spikes. + + where: + + `L = sequence length` + + `N = batch size` + + `H_{in} = input_size` + + `H_{out} = hidden_size` + + Learnable Parameters: + - **rnn.weight_ih_l** (torch.Tensor) - the learnable input-hidden weights of shape (hidden_size, input_size) + - **rnn.weight_hh_l** (torch.Tensor) - the learnable hidden-hidden weights of the k-th layer which are sampled from `beta` of shape (hidden_size, hidden_size) + - **bias_ih_l** - the learnable input-hidden bias of the k-th layer, of shape (hidden_size) + - **bias_hh_l** - the learnable hidden-hidden bias of the k-th layer, of shape (hidden_size) + - **threshold** (torch.Tensor) - optional learnable thresholds + must be manually passed in, of shape `1` or`` (input_size). 
+ - **graded_spikes_factor** (torch.Tensor) - optional learnable graded spike factor + + """ + + def __init__( + self, + input_size, + hidden_size, + beta=None, + bias=True, + threshold=1.0, + dropout=0.0, + spike_grad=None, + surrogate_disable=False, + learn_beta=False, + learn_threshold=False, + graded_spikes_factor=1.0, + learn_graded_spikes_factor=False, + device=None, + dtype=None, + ): + super(LeakyParallel, self).__init__() + + self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu', + bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) + + if beta is not None: + beta = beta.clamp(0, 1) + + if spike_grad is None: + self.spike_grad = self.ATan.apply + else: + self.spike_grad = spike_grad + + self._threshold_buffer(threshold, learn_threshold) + self._graded_spikes_buffer( + graded_spikes_factor, learn_graded_spikes_factor + ) + + self.surrogate_disable = surrogate_disable + if self.surrogate_disable: + self.spike_grad = self._surrogate_bypass + + with torch.no_grad(): + if beta is not None: + # Set all weights to the scalar value of beta + if isinstance(beta, float) or isinstance(beta, int): + self.rnn.weight_hh_10.fill_(beta) + elif isinstance(beta, torch.Tensor) or isinstance(beta, torch.FloatTensor): + if len(beta) == 1: + self.rnn.weight_hh_10.fill_(beta) + elif len(beta) == hidden_size: + # Replace each value with the corresponding value in beta + for i in range(hidden_size): + self.rnn.weight_hh_l0.data[i].fill_(beta[i]) + else: + raise ValueError("Beta must be either a single value or of length 'hidden_size'.") + + if not learn_beta: + # Make the weights non-learnable + self.rnn.weight_hh_l0.requires_grad_(False) + + + def forward(self, input_): + mem = self.rnn(input_) + # mem[0] contains relu'd outputs, mem[1] contains final hidden state + mem_shift = mem[0] - self.threshold + spk = self.spike_grad(mem_shift) + spk = spk * self.graded_spikes_factor + return spk + + @staticmethod + def _surrogate_bypass(input_): + return (input_ > 0).float() + + @staticmethod + class ATan(torch.autograd.Function): + """ + Surrogate gradient of the Heaviside step function. + + **Forward pass:** Heaviside step function shifted. + + .. math:: + + S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\ + 0 & \\text{if U < U$_{\\rm thr}$} + \\end{cases} + + **Backward pass:** Gradient of shifted arc-tan function. + + .. math:: + + S&≈\\frac{1}{π}\\text{arctan}(πU \\frac{α}{2}) \\\\ + \\frac{∂S}{∂U}&=\\frac{1}{π}\ + \\frac{1}{(1+(πU\\frac{α}{2})^2)} + + + :math:`alpha` defaults to 2, and can be modified by calling + ``surrogate.atan(alpha=2)``. + + Adapted from: + + *W. Fang, Z. Yu, Y. Chen, T. Masquelier, T. Huang, Y. Tian (2021) + Incorporating Learnable Membrane Time Constants to Enhance Learning + of Spiking Neural Networks. Proc. IEEE/CVF Int. Conf. Computer + Vision (ICCV), pp. 
2661-2671.*""" + + @staticmethod + def forward(ctx, input_, alpha=2.0): + ctx.save_for_backward(input_) + ctx.alpha = alpha + out = (input_ > 0).float() + return out + + @staticmethod + def backward(ctx, grad_output): + (input_,) = ctx.saved_tensors + grad_input = grad_output.clone() + grad = ( + ctx.alpha + / 2 + / (1 + (torch.pi / 2 * ctx.alpha * input_).pow_(2)) + * grad_input + ) + return grad, None + + + + def _graded_spikes_buffer( + self, graded_spikes_factor, learn_graded_spikes_factor + ): + if not isinstance(graded_spikes_factor, torch.Tensor): + graded_spikes_factor = torch.as_tensor(graded_spikes_factor) + if learn_graded_spikes_factor: + self.graded_spikes_factor = nn.Parameter(graded_spikes_factor) + else: + self.register_buffer("graded_spikes_factor", graded_spikes_factor) + + def _threshold_buffer(self, threshold, learn_threshold): + if not isinstance(threshold, torch.Tensor): + threshold = torch.as_tensor(threshold) + if learn_threshold: + self.threshold = nn.Parameter(threshold) + else: + self.register_buffer("threshold", threshold) \ No newline at end of file From 1872916181ee002bb95d80192bc89712478727ab Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 18 Nov 2023 18:22:50 -0800 Subject: [PATCH 02/10] add beta buffer function --- snntorch/_neurons/leakyparallel.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index 37e23f83..466e85d0 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -145,7 +145,7 @@ def __init__( self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu', bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) - + self._beta_buffer if beta is not None: beta = beta.clamp(0, 1) @@ -247,7 +247,12 @@ def backward(ctx, grad_output): return grad, None - + def _beta_buffer(self, beta, learn_beta): + if not isinstance(beta, torch.Tensor): + beta = torch.as_tensor(beta) # TODO: or .tensor() if no copy + if not learn_beta: + self.register_buffer("beta", beta) + def _graded_spikes_buffer( self, graded_spikes_factor, learn_graded_spikes_factor ): From 4c8f9933a60577aae4b8a259f7016993ee2811a4 Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 18 Nov 2023 18:25:30 -0800 Subject: [PATCH 03/10] fix self instance of beta --- snntorch/_neurons/leakyparallel.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index 466e85d0..2e92ff9a 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -146,8 +146,8 @@ def __init__( self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu', bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) self._beta_buffer - if beta is not None: - beta = beta.clamp(0, 1) + if self.beta is not None: + self.beta = self.beta.clamp(0, 1) if spike_grad is None: self.spike_grad = self.ATan.apply @@ -164,17 +164,17 @@ def __init__( self.spike_grad = self._surrogate_bypass with torch.no_grad(): - if beta is not None: - # Set all weights to the scalar value of beta - if isinstance(beta, float) or isinstance(beta, int): - self.rnn.weight_hh_10.fill_(beta) - elif isinstance(beta, torch.Tensor) or isinstance(beta, torch.FloatTensor): - if len(beta) == 1: - self.rnn.weight_hh_10.fill_(beta) - elif len(beta) == hidden_size: - # Replace each value with the corresponding value in 
beta + if self.beta is not None: + # Set all weights to the scalar value of self.beta + if isinstance(self.beta, float) or isinstance(self.beta, int): + self.rnn.weight_hh_10.fill_(self.beta) + elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor): + if len(self.beta) == 1: + self.rnn.weight_hh_10.fill_(self.beta) + elif len(self.beta) == hidden_size: + # Replace each value with the corresponding value in self.beta for i in range(hidden_size): - self.rnn.weight_hh_l0.data[i].fill_(beta[i]) + self.rnn.weight_hh_l0.data[i].fill_(self.beta[i]) else: raise ValueError("Beta must be either a single value or of length 'hidden_size'.") From 29b7840d328ce8a4001b61ac0e92a3f84a8d3823 Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 18 Nov 2023 18:29:10 -0800 Subject: [PATCH 04/10] self instance fix in beta_buffer --- snntorch/_neurons/leakyparallel.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index 2e92ff9a..5781a4d1 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -145,6 +145,7 @@ def __init__( self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu', bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) + self.beta = beta self._beta_buffer if self.beta is not None: self.beta = self.beta.clamp(0, 1) @@ -247,11 +248,11 @@ def backward(ctx, grad_output): return grad, None - def _beta_buffer(self, beta, learn_beta): - if not isinstance(beta, torch.Tensor): - beta = torch.as_tensor(beta) # TODO: or .tensor() if no copy + def _beta_buffer(self, learn_beta): + if not isinstance(self.beta, torch.Tensor): + self.beta = torch.as_tensor(self.beta) # TODO: or .tensor() if no copy if not learn_beta: - self.register_buffer("beta", beta) + self.register_buffer("beta", self.beta) def _graded_spikes_buffer( self, graded_spikes_factor, learn_graded_spikes_factor From b30e3f345db9e662d6e59b108b216265b0009379 Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 18 Nov 2023 18:49:29 -0800 Subject: [PATCH 05/10] parallel leaky neuron bug fixes --- snntorch/_neurons/leakyparallel.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index 5781a4d1..8b82d969 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -145,8 +145,8 @@ def __init__( self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu', bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) - self.beta = beta - self._beta_buffer + + self._beta_buffer(beta, learn_beta) if self.beta is not None: self.beta = self.beta.clamp(0, 1) @@ -168,10 +168,10 @@ def __init__( if self.beta is not None: # Set all weights to the scalar value of self.beta if isinstance(self.beta, float) or isinstance(self.beta, int): - self.rnn.weight_hh_10.fill_(self.beta) + self.rnn.weight_hh_l0.fill_(self.beta) elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor): if len(self.beta) == 1: - self.rnn.weight_hh_10.fill_(self.beta) + self.rnn.weight_hh_l0.fill_(self.beta[0]) elif len(self.beta) == hidden_size: # Replace each value with the corresponding value in self.beta for i in range(hidden_size): @@ -248,11 +248,11 @@ def backward(ctx, grad_output): return grad, None - def _beta_buffer(self, learn_beta): - if not isinstance(self.beta, 
torch.Tensor): - self.beta = torch.as_tensor(self.beta) # TODO: or .tensor() if no copy - if not learn_beta: - self.register_buffer("beta", self.beta) + def _beta_buffer(self, beta, learn_beta): + if not isinstance(beta, torch.Tensor): + if beta is not None: + beta = torch.as_tensor([beta]) # TODO: or .tensor() if no copy + self.register_buffer("beta", beta) def _graded_spikes_buffer( self, graded_spikes_factor, learn_graded_spikes_factor From 0a426904c6f5932072bf44ec4e62f94737ee338f Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 18 Nov 2023 19:15:10 -0800 Subject: [PATCH 06/10] add gradient hook to clip learning non-diagonal values --- snntorch/_neurons/leakyparallel.py | 57 +++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index 8b82d969..b7c5b838 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -138,6 +138,7 @@ def __init__( learn_threshold=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False, + diagonal_enable=False, device=None, dtype=None, ): @@ -147,8 +148,16 @@ def __init__( bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) self._beta_buffer(beta, learn_beta) + self.hidden_size = hidden_size + if self.beta is not None: - self.beta = self.beta.clamp(0, 1) + self.beta = self.beta.clamp(0, 1) + + if diagonal_enable is False: + # Initial gradient and weights of w_hh are made diagonal + self._diagonal_enable(diagonal_enable) + # Register a gradient hook to clamp out non-diagonal matrices in backward pass + self.rnn.weight_hh_l0.register_hook(self.grad_hook) if spike_grad is None: self.spike_grad = self.ATan.apply @@ -164,20 +173,7 @@ def __init__( if self.surrogate_disable: self.spike_grad = self._surrogate_bypass - with torch.no_grad(): - if self.beta is not None: - # Set all weights to the scalar value of self.beta - if isinstance(self.beta, float) or isinstance(self.beta, int): - self.rnn.weight_hh_l0.fill_(self.beta) - elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor): - if len(self.beta) == 1: - self.rnn.weight_hh_l0.fill_(self.beta[0]) - elif len(self.beta) == hidden_size: - # Replace each value with the corresponding value in self.beta - for i in range(hidden_size): - self.rnn.weight_hh_l0.data[i].fill_(self.beta[i]) - else: - raise ValueError("Beta must be either a single value or of length 'hidden_size'.") + self._beta_to_weight_hh() if not learn_beta: # Make the weights non-learnable @@ -247,12 +243,41 @@ def backward(ctx, grad_output): ) return grad, None + def _diagonal_enable(self, diagonal_enable): + if diagonal_enable is False: + for i in range(self.hidden_size): + for j in range(self.hidden_size): + if i != j: + self.rnn.weight_hh_l0.data[i, j] = 0 + # self.rnn.weight_hh_l0.grad[i, j] = 0 + def grad_hook(self, grad): + # Create a mask that is 1 on the diagonal and 0 elsewhere + mask = torch.eye(self.hidden_size, self.hidden_size) + # Use the mask to zero out non-diagonal elements of the gradient + return grad * mask + + def _beta_to_weight_hh(self): + with torch.no_grad(): + if self.beta is not None: + # Set all weights to the scalar value of self.beta + if isinstance(self.beta, float) or isinstance(self.beta, int): + self.rnn.weight_hh_l0.fill_(self.beta) + elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor): + if len(self.beta) == 1: + self.rnn.weight_hh_l0.fill_(self.beta[0]) + elif 
len(self.beta) == self.hidden_size: + # Replace each value with the corresponding value in self.beta + for i in range(self.hidden_size): + self.rnn.weight_hh_l0.data[i].fill_(self.beta[i]) + else: + raise ValueError("Beta must be either a single value or of length 'hidden_size'.") + def _beta_buffer(self, beta, learn_beta): if not isinstance(beta, torch.Tensor): if beta is not None: beta = torch.as_tensor([beta]) # TODO: or .tensor() if no copy - self.register_buffer("beta", beta) + self.register_buffer("beta", beta) def _graded_spikes_buffer( self, graded_spikes_factor, learn_graded_spikes_factor From 284ce8903cbcc818a51ccc7a142cd72979cdb5f4 Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 18 Nov 2023 19:33:06 -0800 Subject: [PATCH 07/10] device fix --- snntorch/_neurons/leakyparallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index b7c5b838..35942b58 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -252,8 +252,9 @@ def _diagonal_enable(self, diagonal_enable): # self.rnn.weight_hh_l0.grad[i, j] = 0 def grad_hook(self, grad): + device = grad.device # Create a mask that is 1 on the diagonal and 0 elsewhere - mask = torch.eye(self.hidden_size, self.hidden_size) + mask = torch.eye(self.hidden_size, self.hidden_size, device=device) # Use the mask to zero out non-diagonal elements of the gradient return grad * mask From 67e5a8515accc2a42f26bf57e8fd17c7f7ad2e51 Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sun, 19 Nov 2023 11:31:57 -0800 Subject: [PATCH 08/10] fix leakyparallel weight_hh_l diagonal bug --- snntorch/_neurons/leakyparallel.py | 67 +++++++++++++++--------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index 35942b58..de81266d 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -1,11 +1,11 @@ -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF import torch import torch.nn as nn class LeakyParallel(nn.Module): """ - A parallel implementation of the Leaky neuron with an input linear layer. + A parallel implementation of the Leaky neuron with a fused input linear layer. All time steps are passed to the input at once. + This implementation uses `torch.nn.RNN` to accelerate the implementation. First-order leaky integrate-and-fire neuron model. Input is assumed to be a current injection. @@ -22,6 +22,15 @@ class LeakyParallel(nn.Module): * :math:`U_{\\rm thr}` - Membrane threshold * :math:`β` - Membrane potential decay rate + Several differences between `LeakyParallel` and `Leaky` include: + * Negative hidden states are clipped due to the forced ReLU operation in RNN + * Linear weights are included in addition to recurrent weights + * `beta` is clipped between [0,1] and cloned to `weight_hh_l` only upon layer initialization. It is unused otherwise + * There is no explicit reset mechanism + * Several functions such as `init_hidden`, `output`, `inhibition`, and `state_quant` are unavailable in `LeakyParallel` + * Only the output spike is returned. Membrane potential is not accessible by default + * RNN uses a hidden matrix of size (num_hidden, num_hidden) to transform the hidden state vector. This would 'leak' the membrane potential between LIF neurons, and so the hidden matrix is forced to a diagonal matrix by default. This can be disabled by setting `weight_hh_enable=True`. 
+ Example:: import torch @@ -36,8 +45,8 @@ def __init__(self): super().__init__() # initialize layers - self.lif1 = snn.ParallelLeaky(input_size=784, hidden_size=128) - self.lif2 = snn.ParallelLeaky(input_size=128, hidden_size=10, beta=beta) + self.lif1 = snn.LeakyParallel(input_size=784, hidden_size=128) + self.lif2 = snn.LeakyParallel(input_size=128, hidden_size=10, beta=beta) def forward(self, x): spk1 = self.lif1(x) @@ -78,14 +87,6 @@ def forward(self, x): to False :type surrogate_disable: bool, Optional - :param init_hidden: Instantiates state variables as instance variables. - Defaults to False - :type init_hidden: bool, optional - - :param inhibition: If `True`, suppresses all spiking other than the - neuron with the highest state. Defaults to False - :type inhibition: bool, optional - :param learn_beta: Option to enable learnable beta. Defaults to False :type learn_beta: bool, optional @@ -93,6 +94,9 @@ def forward(self, x): to False :type learn_threshold: bool, optional + :param weight_hh_enable: Option to set the hidden matrix to be dense or diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. Dense (True) would allow the membrane potential of one LIF neuron to influence all others, and follow the RNN default implementation. Defaults to False + :type weight_hh_enable: bool, optional + Inputs: \\input_ @@ -138,7 +142,7 @@ def __init__( learn_threshold=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False, - diagonal_enable=False, + weight_hh_enable=False, device=None, dtype=None, ): @@ -152,18 +156,24 @@ def __init__( if self.beta is not None: self.beta = self.beta.clamp(0, 1) - - if diagonal_enable is False: - # Initial gradient and weights of w_hh are made diagonal - self._diagonal_enable(diagonal_enable) - # Register a gradient hook to clamp out non-diagonal matrices in backward pass - self.rnn.weight_hh_l0.register_hook(self.grad_hook) if spike_grad is None: self.spike_grad = self.ATan.apply else: self.spike_grad = spike_grad + self._beta_to_weight_hh() + if weight_hh_enable is False: + # Initial gradient and weights of w_hh are made diagonal + self.weight_hh_enable() + # Register a gradient hook to clamp out non-diagonal matrices in backward pass + if learn_beta: + self.rnn.weight_hh_l0.register_hook(self.grad_hook) + + if not learn_beta: + # Make the weights non-learnable + self.rnn.weight_hh_l0.requires_grad_(False) + self._threshold_buffer(threshold, learn_threshold) self._graded_spikes_buffer( graded_spikes_factor, learn_graded_spikes_factor @@ -173,17 +183,12 @@ def __init__( if self.surrogate_disable: self.spike_grad = self._surrogate_bypass - self._beta_to_weight_hh() - - if not learn_beta: - # Make the weights non-learnable - self.rnn.weight_hh_l0.requires_grad_(False) - - def forward(self, input_): mem = self.rnn(input_) # mem[0] contains relu'd outputs, mem[1] contains final hidden state mem_shift = mem[0] - self.threshold + # print(mem[0]) + # print(self.rnn.weight_hh_l0) spk = self.spike_grad(mem_shift) spk = spk * self.graded_spikes_factor return spk @@ -243,13 +248,9 @@ def backward(ctx, grad_output): ) return grad, None - def _diagonal_enable(self, diagonal_enable): - if diagonal_enable is False: - for i in range(self.hidden_size): - for j in range(self.hidden_size): - if i != j: - self.rnn.weight_hh_l0.data[i, j] = 0 - # self.rnn.weight_hh_l0.grad[i, j] = 0 + def weight_hh_enable(self): + mask = torch.eye(self.hidden_size, self.hidden_size) + self.rnn.weight_hh_l0.data = self.rnn.weight_hh_l0.data * mask def grad_hook(self, grad): 
device = grad.device From 21e93792c6e4d4f945c8845c75e8a1c50c361e9c Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sun, 19 Nov 2023 14:55:18 -0800 Subject: [PATCH 09/10] update leaky parallel docstrings --- snntorch/_neurons/leakyparallel.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index de81266d..c05627ca 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -23,6 +23,7 @@ class LeakyParallel(nn.Module): * :math:`β` - Membrane potential decay rate Several differences between `LeakyParallel` and `Leaky` include: + * Negative hidden states are clipped due to the forced ReLU operation in RNN * Linear weights are included in addition to recurrent weights * `beta` is clipped between [0,1] and cloned to `weight_hh_l` only upon layer initialization. It is unused otherwise @@ -38,6 +39,11 @@ class LeakyParallel(nn.Module): import snntorch as snn beta = 0.5 + num_inputs = 784 + num_hidden = 128 + num_outputs = 10 + batch_size = 128 + x = torch.rand((num_steps, batch_size, num_inputs)) # Define Network class Net(nn.Module): @@ -45,8 +51,8 @@ def __init__(self): super().__init__() # initialize layers - self.lif1 = snn.LeakyParallel(input_size=784, hidden_size=128) - self.lif2 = snn.LeakyParallel(input_size=128, hidden_size=10, beta=beta) + self.lif1 = snn.LeakyParallel(input_size=num_inputs, hidden_size=num_hidden) # randomly initialize recurrent weights + self.lif2 = snn.LeakyParallel(input_size=num_hidden, hidden_size=num_outputs, beta=beta, learn_beta=True) # learnable recurrent weights initialized at beta def forward(self, x): spk1 = self.lif1(x) @@ -94,11 +100,13 @@ def forward(self, x): to False :type learn_threshold: bool, optional - :param weight_hh_enable: Option to set the hidden matrix to be dense or diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. Dense (True) would allow the membrane potential of one LIF neuron to influence all others, and follow the RNN default implementation. Defaults to False + :param weight_hh_enable: Option to set the hidden matrix to be dense or + diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. + Dense (True) would allow the membrane potential of one LIF neuron to + influence all others, and follow the RNN default implementation. Defaults to False :type weight_hh_enable: bool, optional - Inputs: \\input_ - **input_** of shape of shape `(L, H_{in})` for unbatched input, or `(L, N, H_{in})` containing the features of the input sequence. 
@@ -186,9 +194,7 @@ def __init__( def forward(self, input_): mem = self.rnn(input_) # mem[0] contains relu'd outputs, mem[1] contains final hidden state - mem_shift = mem[0] - self.threshold - # print(mem[0]) - # print(self.rnn.weight_hh_l0) + mem_shift = mem[0] - self.threshold # self.rnn.weight_hh_l0 spk = self.spike_grad(mem_shift) spk = spk * self.graded_spikes_factor return spk From 38e1cdd6682e87a8d9722697ecc95a5e6aa19abe Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sun, 19 Nov 2023 14:59:41 -0800 Subject: [PATCH 10/10] update docs for leakyparallel --- docs/snn.neurons_leakyparallel.rst | 9 +++++++++ docs/snntorch.rst | 5 ++++- 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 docs/snn.neurons_leakyparallel.rst diff --git a/docs/snn.neurons_leakyparallel.rst b/docs/snn.neurons_leakyparallel.rst new file mode 100644 index 00000000..87edbd08 --- /dev/null +++ b/docs/snn.neurons_leakyparallel.rst @@ -0,0 +1,9 @@ +=========================== +snn.LeakyParallel +=========================== + + +.. automodule:: snntorch._neurons.leakyparallel + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/snntorch.rst b/docs/snntorch.rst index 5f0d3828..8aa09acb 100644 --- a/docs/snntorch.rst +++ b/docs/snntorch.rst @@ -25,6 +25,9 @@ At present, the neurons available in :mod:`snntorch` are variants of the Leaky I * **Lapicque** - Lapicque's RC Neuron Model * **Alpha** - Alpha Membrane Model +Neuron models that accelerate training require passing data in parallel. Available neurons include: +* **LeakyParallel** - 1st Order Leaky Integrate-and-Fire Neuron + Additional models include spiking-LSTMs and spiking-ConvLSTMs: * **SLSTM** - Spiking long short-term memory cell with state-thresholding @@ -35,7 +38,7 @@ Additional models include spiking-LSTMs and spiking-ConvLSTMs: How to use snnTorch's neuron models ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The following arguments are common across all neuron models: +The following arguments are common across most neuron models: * **threshold** - firing threshold of the neuron * **spike_grad** - surrogate gradient function (see :mod:`snntorch.surrogate`)