diff --git a/ChangeLog b/ChangeLog
index 94481da..6a5e325 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,6 +1,11 @@
 CHANGES
 =======
 
+v1.2.1
+------
+* Support Sinabs 2.0
+* Support individual time constants
+
 v1.1.2
 ------
 * Update README
diff --git a/cuda/bindings.cu b/cuda/bindings.cu
index 40e18ac..31d1059 100644
--- a/cuda/bindings.cu
+++ b/cuda/bindings.cu
@@ -16,10 +16,10 @@ void lifForward(
     const torch::Tensor& vmem,
     const torch::Tensor& input,
     const torch::Tensor& vmemPostInitial,
-    const torch::Tensor& alpha, 
+    const torch::Tensor& alpha,
     const torch::Tensor& membrSubtract,
-    const float theta,
-    const float thetaLow,
+    const torch::Tensor& theta,
+    const torch::Tensor& thetaLow,
     const bool applyThetaLow,
     const int maxNumSpikes)
 {
@@ -29,6 +29,8 @@ void lifForward(
     CHECK_INPUT(vmemPostInitial);
     CHECK_INPUT(alpha);
     CHECK_INPUT(membrSubtract);
+    CHECK_INPUT(theta);
+    CHECK_INPUT(thetaLow);
 
     // check if tensors are on same device
     CHECK_DEVICE(input, vmem);
@@ -36,6 +38,8 @@ void lifForward(
     CHECK_DEVICE(input, vmemPostInitial);
     CHECK_DEVICE(input, alpha);
     CHECK_DEVICE(input, membrSubtract);
+    CHECK_DEVICE(input, theta);
+    CHECK_DEVICE(input, thetaLow);
 
     // set the current cuda device to wherever the tensor input resides
     cudaSetDevice(input.device().index());
@@ -58,7 +62,9 @@ void lifForward(
         vmemPostInitial.data_ptr(),
         alpha.data_ptr(),
         membrSubtract.data_ptr(),
-        theta, thetaLow, applyThetaLow, maxNumSpikesU, nNeurons, nTimesteps);
+        theta.data_ptr(),
+        thetaLow.data_ptr(),
+        applyThetaLow, maxNumSpikesU, nNeurons, nTimesteps);
 
     return;
 }
diff --git a/cuda/lif_kernels.h b/cuda/lif_kernels.h
index 8000f64..1e3fe28 100644
--- a/cuda/lif_kernels.h
+++ b/cuda/lif_kernels.h
@@ -28,10 +28,10 @@
  * @param vmemPostInitial 1D-tensor (nNeurons) with the initial membrane potentials (after reset)
  * @param alhpa 1D-tensor with decay factor of the neuron states (exp(-dt/tau)).
  *              For IAF neurons set to 1.
- * @param theta Firing threshold
+ * @param theta 1D-tensor of firing thresholds
  * @param membrSubtract 1D tensor with values that are subtracted from the membrane
  *                      potential when spiking
- * @param thetaLow Lower bound to vmem
+ * @param thetaLow 1D-tensor of lower bounds to vmem
  * @param applyThetaLow Flag whether vmem is lower bounded
  * @param maxNumSpikes Maximum number of spikes a neuron can emit per time step
  * @param nNeurons Number of neurons/batches
@@ -45,8 +45,8 @@ __global__ void lifForwardKernel(
     const scalarType* __restrict__ vmemPostInitial,
     const scalarType* __restrict__ alpha,
     const scalarType* __restrict__ membrSubtract,
-    float theta,
-    float thetaLow,
+    const scalarType* __restrict__ theta,
+    const scalarType* __restrict__ thetaLow,
     bool applyThetaLow,
     unsigned maxNumSpikes,
     unsigned nNeurons,
@@ -73,13 +73,13 @@ __global__ void lifForwardKernel(
         vmemCurr += input[linearID];
 
         // Apply lower threshold
-        if (applyThetaLow && (vmemCurr < thetaLow)){
-            vmemCurr = thetaLow;
+        if (applyThetaLow && (vmemCurr < thetaLow[neuronID])){
+            vmemCurr = thetaLow[neuronID];
         }
 
         // Generate spikes
-        if(vmemCurr >= theta){
-            activation = min(unsigned(vmemCurr / theta), maxNumSpikes);
+        if(vmemCurr >= theta[neuronID]){
+            activation = min(unsigned(vmemCurr / theta[neuronID]), maxNumSpikes);
         } else {
             activation = 0;
         }
@@ -271,8 +271,8 @@ __global__ void lifBackwardAlphaKernel(
  *                      For IAF neurons set to 1.
  * @param membrSubtract 1D-tensor of value that is subtracted from the membrane potential
  *                      when spiking
- * @param theta Firing threshold
- * @param thetaLow Lower bound to vmem
+ * @param theta 1D-tensor of firing thresholds
+ * @param thetaLow 1D-tensor of lower bounds to vmem
  * @param applyThetaLow Flag whether vmem is lower bounded
  * @param multipleSpikes Flag whether multiple spikes can be emitted in a single time step
  * @param nNeurons Number of neurons/batches
@@ -286,8 +286,8 @@ void lifForwardCuda(
     const scalarType* vmemPostInitial,
     const scalarType* alpha,
     const scalarType* membrSubtract,
-    const float theta,
-    const float thetaLow,
+    const scalarType* theta,
+    const scalarType* thetaLow,
     const bool applyThetaLow,
     const unsigned maxNumSpikes,
     const unsigned nNeurons,
@@ -302,14 +302,14 @@ void lifForwardCuda(
         vmem,
         input,
         vmemPostInitial,
-            alpha,
-            membrSubtract,
-            theta,
-            thetaLow,
-            applyThetaLow,
-            maxNumSpikes,
-            nNeurons,
-            nTimesteps);
+        alpha,
+        membrSubtract,
+        theta,
+        thetaLow,
+        applyThetaLow,
+        maxNumSpikes,
+        nNeurons,
+        nTimesteps);
 }
 
 
diff --git a/setup.py b/setup.py
index b07a88f..b2d892b 100644
--- a/setup.py
+++ b/setup.py
@@ -40,7 +40,6 @@ def run(self):
 
 # Handle versions
 version = versioneer.get_version()
-version_major = version.split(".")[0]
 
 # Install
 setup(
@@ -63,6 +62,6 @@ def run(self):
         )
     ],
     cmdclass=cmdclass,
-    install_requires=["torch", "sinabs"],
+    install_requires=["torch", f"sinabs >= 1.2.9"],
 )
 
diff --git a/sinabs/exodus/layers/iaf.py b/sinabs/exodus/layers/iaf.py
index e54ae89..88f6303 100644
--- a/sinabs/exodus/layers/iaf.py
+++ b/sinabs/exodus/layers/iaf.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 import torch
 import numpy as np
 from sinabs.layers import SqueezeMixin
@@ -54,12 +54,12 @@ class IAF(LIF):
 
     def __init__(
         self,
-        spike_threshold: float = 1.0,
+        spike_threshold: Optional[Union[float, torch.Tensor]] = 1.0,
         spike_fn: Callable = MultiSpike,
         reset_fn: Callable = MembraneSubtract(),
         surrogate_grad_fn: Callable = SingleExponential(),
-        tau_syn: Optional[float] = None,
-        min_v_mem: Optional[float] = None,
+        tau_syn: Optional[Union[float, torch.Tensor]] = None,
+        min_v_mem: Optional[Union[float, torch.Tensor]] = None,
         shape: Optional[torch.Size] = None,
         record_states: bool = False,
         decay_early: bool = True,
diff --git a/sinabs/exodus/layers/lif.py b/sinabs/exodus/layers/lif.py
index 79609da..365bcf3 100644
--- a/sinabs/exodus/layers/lif.py
+++ b/sinabs/exodus/layers/lif.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Tuple, Union
 import torch
 from sinabs.layers import SqueezeMixin
 
@@ -16,6 +16,32 @@
 
 __all__ = ["LIF", "LIFSqueeze"]
 
+def expand_to_1d_contiguous(
+    value: Union[float, torch.Tensor], shape: Tuple[int]
+) -> torch.Tensor:
+    """
+    Expand tensor to tensor of given shape.
+    Then flatten and make contiguous.
+
+    Not flattening immediately ensures that non-scalar tensors are
+    broadcast correctly to the required shape.
+
+    Parameters
+    ----------
+    value: float or torch.Tensor
+        Tensor to be expanded
+    shape: tuple of ints
+        shape to expand to
+
+    Returns
+    -------
+    torch.Tensor
+        Contiguous 1D-tensor from expanded `value`
+    """
+
+    expanded_tensor = torch.as_tensor(value).expand(shape)
+    return expanded_tensor.float().flatten().contiguous()
+
 
 class LIF(LIFSinabs):
     """
@@ -70,11 +96,11 @@ def __init__(
         self,
         tau_mem: Union[float, torch.Tensor],
         tau_syn: Optional[Union[float, torch.Tensor]] = None,
-        spike_threshold: float = 1.0,
+        spike_threshold: Optional[Union[float, torch.Tensor]] = 1.0,
         spike_fn: Callable = MultiSpike,
         reset_fn: Callable = MembraneSubtract(),
         surrogate_grad_fn: Callable = SingleExponential(),
-        min_v_mem: Optional[float] = None,
+        min_v_mem: Optional[Union[float, torch.Tensor]] = None,
         train_alphas: bool = False,
         shape: Optional[torch.Size] = None,
         norm_input: bool = True,
@@ -158,7 +184,9 @@ def _prepare_input(self, input_data: torch.Tensor):
 
     def _forward_synaptic(self, input_2d: torch.Tensor):
         """Evolve synaptic dynamics"""
-        alpha_syn = self.alpha_syn_calculated.expand(self.v_mem.shape).flatten()
+        alpha_syn = expand_to_1d_contiguous(
+            self.alpha_syn_calculated, self.v_mem.shape
+        )
 
         if self.decay_early:
             input_2d = input_2d * alpha_syn.unsqueeze(1)
@@ -166,7 +194,7 @@ def _forward_synaptic(self, input_2d: torch.Tensor):
 
         # Apply exponential filter to input
         return LeakyIntegrator.apply(
             input_2d.contiguous(),  # Input data
-            alpha_syn.contiguous(),  # Synaptic alpha
+            alpha_syn,  # Synaptic alpha
             self.i_syn.flatten().contiguous(),  # Initial synaptic states
         )
@@ -174,8 +202,9 @@ def _forward_membrane(self, i_syn_2d: torch.Tensor):
         """Evolve membrane dynamics"""
 
         # Broadcast alpha to number of neurons (x batches)
-        alpha_mem = self.alpha_mem_calculated.expand(self.v_mem.shape)
-        alpha_mem = alpha_mem.flatten().contiguous()
+        alpha_mem = expand_to_1d_contiguous(
+            self.alpha_mem_calculated, self.v_mem.shape
+        )
 
         if self.norm_input:
             # Rescale input with 1 - alpha (based on approximation that
@@ -196,19 +225,35 @@ def _forward_membrane(self, i_syn_2d: torch.Tensor):
 
             return v_mem, v_mem
 
+        # Expand spike threshold
+        spike_threshold = expand_to_1d_contiguous(
+            self.spike_threshold, self.v_mem.shape
+        )
+
+        # Expand min_v_mem
+        if self.min_v_mem is None:
+            min_v_mem = None
+        else:
+            min_v_mem = expand_to_1d_contiguous(
+                self.min_v_mem, self.v_mem.shape
+            )
+
         # Expand membrane subtract
         membrane_subtract = self.reset_fn.subtract_value
         if membrane_subtract is None:
-            membrane_subtract = self.spike_threshold
-        membrane_subtract = torch.full_like(alpha_mem, membrane_subtract)
+            membrane_subtract = spike_threshold
+        else:
+            membrane_subtract = expand_to_1d_contiguous(
+                membrane_subtract, self.v_mem.shape
+            )
 
         output_2d, v_mem_2d = IntegrateAndFire.apply(
             i_syn_2d.contiguous(),  # Input data
             alpha_mem,  # Alphas
             self.v_mem.flatten().contiguous(),  # Initial vmem
-            self.spike_threshold,  # Spike threshold
+            spike_threshold,  # Spike threshold
             membrane_subtract,  # Membrane subtract
-            self.min_v_mem,  # Lower bound on vmem
+            min_v_mem,  # Lower bound on vmem
             self.surrogate_grad_fn,  # Surrogate gradient
             self.max_num_spikes_per_bin,  # Max. number of spikes per bin
         )
diff --git a/sinabs/exodus/spike.py b/sinabs/exodus/spike.py
index a0d19ae..a73d3c4 100644
--- a/sinabs/exodus/spike.py
+++ b/sinabs/exodus/spike.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 import torch
 
 import exodus_cuda
@@ -52,7 +52,7 @@ def forward(
             raise ValueError("'v_mem' has to be contiguous.")
         if not v_mem.ndim == 2:
             raise ValueError("'v_mem' must be 2D, (N, Time)")
-        if min_v_mem is not None and (threshold <= min_v_mem):
+        if min_v_mem is not None and (threshold <= min_v_mem).any():
             raise ValueError("`threshold` must be greater than `min_v_mem`.")
 
         spikes = exodus_cuda.spikeForward(
@@ -111,9 +111,9 @@ def forward(
         inp: torch.tensor,
         alpha: torch.tensor,
         v_mem_init: torch.tensor,
-        threshold: float,
+        threshold: torch.tensor,
         membrane_subtract: torch.tensor,
-        min_v_mem: float,
+        min_v_mem: Union[torch.tensor, None],
         surrogate_grad_fn: Callable,
         max_num_spikes_per_bin: Optional[int] = None,
     ):
@@ -134,12 +134,12 @@ def forward(
         activations : torch.tensor
             1D, shape (N,). Activations from previous time step. Has to be contiguous.
-        threshold: float
-            Firing threshold
+        threshold: torch.tensor
+            1D, shape (N,). Firing thresholds
         membrane_subtract: torch.Tensor
             1D, shape (N,). Value that is subracted from membrane
             potential after spike
-        min_v_mem: float
-            Lower limit for v_mem
+        min_v_mem: torch.Tensor or None
+            1D, shape (N,). Lower limits for v_mem. If 'None', don't apply limits
         surrogate_grad_fn: Callable
             Calculates surrogate gradients as function of v_mem
         max_num_spikes_per_bin: int
@@ -156,6 +156,9 @@ def forward(
 
         if membrane_subtract is None:
             membrane_subtract = torch.ones_like(alpha) * threshold
+        if not (apply_min_v_mem := (min_v_mem is not None)):
+            # Pass some empty tensor to match CUDA function signature
+            min_v_mem = torch.empty_like(threshold)
 
         if not inp.ndim == 2:
             raise ValueError("'inp' must be 2D, (N, Time)")
@@ -173,11 +176,19 @@ def forward(
             raise ValueError("'v_mem_init' must be 1D, (N,)")
         if not v_mem_init.is_contiguous():
             raise ValueError("'v_mem_init' has to be contiguous.")
-
-        if min_v_mem is not None and threshold <= min_v_mem:
-            raise ValueError("`threshold` must be greater than `min_v_mem`.")
+        if not threshold.ndim == 1:
+            raise ValueError("'threshold' must be 1D, (N,)")
+        if not threshold.is_contiguous():
+            raise ValueError("'threshold' has to be contiguous.")
         if (alpha < 0).any() or (alpha > 1).any():
             raise ValueError("'alpha' must be between 0 and 1.")
+        if apply_min_v_mem:
+            if not min_v_mem.ndim == 1:
+                raise ValueError("'min_v_mem' must be 1D, (N,)")
+            if not min_v_mem.is_contiguous():
+                raise ValueError("'min_v_mem' has to be contiguous.")
+            if (threshold <= min_v_mem).any():
+                raise ValueError("`threshold` must be greater than `min_v_mem`.")
 
         v_mem = torch.empty_like(inp).contiguous()
         output_spikes = torch.empty_like(inp).contiguous()
@@ -190,20 +201,21 @@ def forward(
             alpha,
             membrane_subtract,
             threshold,
-            min_v_mem if min_v_mem is not None else 0,
-            min_v_mem is not None,
+            min_v_mem,
+            apply_min_v_mem,
             -1 if max_num_spikes_per_bin is None else max_num_spikes_per_bin,
         )
-        ctx.threshold = threshold
-        ctx.min_v_mem = min_v_mem
         ctx.surrogate_grad_fn = surrogate_grad_fn
+        ctx.apply_min_v_mem = apply_min_v_mem
 
         # vmem is stored before reset (to calculate surrogate gradients in backward)
         # however, vmem_initial should already have reset applied
         if alpha.requires_grad:
-            ctx.save_for_backward(output_spikes, v_mem, v_mem_init, alpha, membrane_subtract)
+            ctx.save_for_backward(
+                output_spikes, v_mem, v_mem_init, alpha, membrane_subtract, threshold, min_v_mem
+            )
         else:
-            ctx.save_for_backward(v_mem, alpha, membrane_subtract)
+            ctx.save_for_backward(v_mem, alpha, membrane_subtract, threshold, min_v_mem)
         ctx.get_alpha_grads = alpha.requires_grad
 
         return output_spikes, v_mem
@@ -217,18 +229,18 @@ def backward(ctx, grad_output, grad_v_mem):
         # )
 
         if ctx.get_alpha_grads:
-            (output_spikes, v_mem, v_mem_init, alpha, membrane_subtract) = ctx.saved_tensors
+            (output_spikes, v_mem, v_mem_init, alpha, membrane_subtract, threshold, min_v_mem) = ctx.saved_tensors
         else:
-            (v_mem, alpha, membrane_subtract) = ctx.saved_tensors
+            (v_mem, alpha, membrane_subtract, threshold, min_v_mem) = ctx.saved_tensors
 
         # Surrogate gradients
-        surrogates = ctx.surrogate_grad_fn(v_mem, ctx.threshold)
+        surrogates = ctx.surrogate_grad_fn(v_mem, threshold.unsqueeze(1))
 
         # Gradient becomes 0 where v_mem is clipped to lower threshold
-        if ctx.min_v_mem is None:
-            not_clipped = torch.ones_like(surrogates)
+        if ctx.apply_min_v_mem:
+            not_clipped = (v_mem > min_v_mem.unsqueeze(1)).float()
         else:
-            not_clipped = (v_mem > ctx.min_v_mem).float()
+            not_clipped = torch.ones_like(surrogates)
 
         # Gradient wrt. input
         # Scaling membrane_subtract with alpha compensates for different execution order
diff --git a/tests/test_spike_functions.py b/tests/test_spike_functions.py
index 256ac54..bdba070 100644
--- a/tests/test_spike_functions.py
+++ b/tests/test_spike_functions.py
@@ -1,66 +1,76 @@
 import pytest
+from itertools import product
 import torch
 
 from sinabs.exodus.spike import IntegrateAndFire
 from sinabs import activation as sa
 from sinabs.layers.functional.lif import lif_forward
 
 
-def test_integratefire():
-    inp = torch.rand((2, 10), requires_grad=True, device="cuda")
-    v_mem_initial = torch.zeros(2, device="cuda")
-    surrogate_gradient_fn = sa.Heaviside(0)
-
-    thr = 1.0
-    alpha = torch.as_tensor(0.9).expand(v_mem_initial.shape).contiguous().cuda()
-    membrane_subtract = torch.as_tensor(thr).expand(v_mem_initial.shape)
-
-    out, v_mem = IntegrateAndFire.apply(
-        inp,
-        alpha,
-        v_mem_initial,
-        thr,
-        membrane_subtract.contiguous().float().cuda(),
-        -thr,
-        surrogate_gradient_fn,
-    )
-
-    out.sum().backward()
+v_mem_initials = (torch.zeros(2).cuda(), torch.rand(2).cuda() - 0.5)
+alphas = (torch.ones(2).cuda() * 0.9, torch.rand(2).cuda())
+thresholds = (torch.ones(2).cuda(), torch.tensor([0.3, 0.9]).cuda())
+min_v_mem = (None, -torch.ones(2).cuda(), torch.tensor([-0.3, -0.4]).cuda(), torch.tensor([1.2, 0]).cuda())
+membrane_subtract = (None, torch.tensor([0.1, 0.2]).cuda())
+argvals = (v_mem_initials, alphas, thresholds, min_v_mem, membrane_subtract)
+combined_args = product(*argvals)
+argnames = "v_mem_initial,alpha,threshold,min_v_mem,membrane_subtract"
 
 
-def test_integratefire_backprop_vmem():
+@pytest.mark.parametrize(argnames, combined_args)
+def test_integratefire(
+    v_mem_initial, alpha, threshold, min_v_mem, membrane_subtract
+):
     inp = torch.rand((2, 10), requires_grad=True, device="cuda")
-    v_mem_initial = torch.zeros(2, device="cuda")
     surrogate_gradient_fn = sa.Heaviside(0)
 
-    thr = 1
-    alpha = torch.as_tensor(0.9).expand(v_mem_initial.shape).contiguous().cuda()
-    membrane_subtract = torch.as_tensor(thr).expand(v_mem_initial.shape)
-
-    out, v_mem = IntegrateAndFire.apply(
-        inp,
-        alpha,
-        v_mem_initial,
-        thr,
-        membrane_subtract.contiguous().float().cuda(),
-        -thr,
-        surrogate_gradient_fn,
-    )
-
-
-args = ("spikes", "vmem", "sum")
+    if membrane_subtract is None:
+        membrane_subtract = threshold
+    def apply():
+        return IntegrateAndFire.apply(
+            inp,
+            alpha,
+            v_mem_initial,
+            threshold,
+            membrane_subtract,
+            min_v_mem,
+            surrogate_gradient_fn,
+        )
+
+    if min_v_mem is not None and (min_v_mem >= threshold).any():
+        with pytest.raises(ValueError):
+            apply()
+    else:
+        # Test forward pass and backpropagation through output
+        out, v_mem = apply()
+        out.sum().backward()
+
+        # Test forward pass and backpropagation through v_mem
+        out, v_mem = apply()
+        v_mem.sum().backward()
+
+
+backward_varnames = ("spikes", "vmem", "sum")
+argvals_ext = (
+    v_mem_initials,
+    alphas,
+    thresholds,
+    min_v_mem[:-1],  # Avoid min_v_mem > thr error
+    membrane_subtract,
+    backward_varnames,
+)
+combined_args_ext = product(*argvals_ext)
 
 
-@pytest.mark.parametrize("backward_var", args)
-def test_compare_integratefire(backward_var):
+@pytest.mark.parametrize(argnames + ",backward_var", combined_args_ext)
+def test_compare_integratefire(
+    v_mem_initial, alpha, threshold, min_v_mem, membrane_subtract, backward_var
+):
     torch.manual_seed(1)
 
     time_steps = 100
     batchsize = 10
-    n_neurons = 8
+    n_neurons = 2
     num_epochs = 3
-    thr = 1
-    min_v_mem = None  # -1
     surrogate_grad_fn = sa.PeriodicExponential()
     max_num_spikes_per_bin = 2
@@ -70,34 +80,34 @@ def test_compare_integratefire(backward_var):
         requires_grad=True,
         device="cuda",
     )
-    v_mem_init_sinabs = torch.rand(
-        batchsize, n_neurons, requires_grad=True, device="cuda"
-    )
-    alpha_sinabs = torch.rand(n_neurons, requires_grad=True, device="cuda")
+    v_mem_init_sinabs = v_mem_initial.clone().requires_grad_(True)
+    alpha_sinabs = alpha.clone().requires_grad_(True)
 
     # Copy data without connecting gradients
     input_exodus = input_sinabs.clone().detach().requires_grad_(True)
-    v_mem_init_exodus = v_mem_init_sinabs.clone().detach().requires_grad_(True)
-    alpha_exodus = alpha_sinabs.clone().detach().requires_grad_(True)
+    v_mem_init_exodus = v_mem_initial.clone().requires_grad_(True)
+    alpha_exodus = alpha.clone().requires_grad_(True)
 
     out_exodus, vmem_exodus = evolve_exodus(
         data=input_exodus,
         alpha=alpha_exodus,
         v_mem_init=v_mem_init_exodus,
-        threshold=thr,
+        threshold=threshold,
         min_v_mem=min_v_mem,
         surrogate_grad_fn=surrogate_grad_fn,
         max_num_spikes_per_bin=max_num_spikes_per_bin,
+        membrane_subtract=membrane_subtract,
     )
 
     out_sinabs, vmem_sinabs = evolve_sinabs(
         data=input_sinabs,
         alpha=alpha_sinabs,
         v_mem_init=v_mem_init_sinabs,
-        threshold=thr,
+        threshold=threshold,
         min_v_mem=min_v_mem,
         surrogate_grad_fn=surrogate_grad_fn,
         max_num_spikes_per_bin=max_num_spikes_per_bin,
+        membrane_subtract=membrane_subtract,
    )
 
     assert torch.allclose(out_exodus, out_sinabs)
@@ -132,15 +142,26 @@ def evolve_exodus(
     data: torch.tensor,
     alpha: torch.tensor,
     v_mem_init: torch.tensor,
-    threshold: float,
-    min_v_mem: float,
+    threshold: torch.tensor,
+    min_v_mem: torch.tensor,
     surrogate_grad_fn,
     max_num_spikes_per_bin=None,
+    membrane_subtract=None,
 ):
-    alpha = alpha.expand(v_mem_init.shape).flatten().contiguous()
-    v_mem_init = v_mem_init.flatten().contiguous()
-    membrane_subtract = torch.full_like(alpha, threshold)
+    # This normally happens inside the Exodus LIF layer,
+    # which is being circumvented here
     batchsize, timesteps, *trailing_dim = data.shape[1:]
+    expanded_shape = (batchsize, *trailing_dim)
+    alpha = alpha.expand(expanded_shape).flatten().contiguous()
+    v_mem_init = v_mem_init.expand(expanded_shape).flatten().contiguous()
+    threshold = threshold.expand(expanded_shape).flatten().contiguous()
+    if min_v_mem is not None:
+        min_v_mem = min_v_mem.expand(expanded_shape).flatten().contiguous()
+
+    if membrane_subtract is None:
+        membrane_subtract = threshold
+    else:
+        membrane_subtract = membrane_subtract.expand(expanded_shape).flatten().contiguous()
 
     for inp in data:
         inp = inp.movedim(1, -1).reshape(-1, timesteps)
@@ -172,6 +193,7 @@ def evolve_sinabs(
     min_v_mem: float,
     surrogate_grad_fn,
     max_num_spikes_per_bin=None,
+    membrane_subtract=None,
 ):
     if max_num_spikes_per_bin is not None:
         spike_fn = sa.MaxSpike(max_num_spikes_per_bin)
@@ -190,7 +212,7 @@ def evolve_sinabs(
         state=state,
         spike_threshold=threshold,
         spike_fn=spike_fn,
-        reset_fn=sa.MembraneSubtract(),
+        reset_fn=sa.MembraneSubtract(subtract_value=membrane_subtract),
         surrogate_grad_fn=surrogate_grad_fn,
         min_v_mem=min_v_mem,
         norm_input=False,
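
Usage sketch (illustrative addition, not part of the patch): with these changes, spike_threshold, min_v_mem and the time constants of the exodus layers can be given per neuron as tensors; expand_to_1d_contiguous() broadcasts them against the membrane-state shape and flattens them before they reach the CUDA kernels. The snippet below is a minimal sketch under the assumptions that a CUDA device is available and that LIF is exposed as sinabs.exodus.layers.LIF; the parameter values are purely illustrative.

    import torch
    from sinabs.exodus.layers import LIF

    n_neurons = 4

    # Per-neuron parameters passed as 1D tensors instead of scalars
    layer = LIF(
        tau_mem=torch.linspace(10.0, 40.0, n_neurons),  # individual membrane time constants
        spike_threshold=torch.full((n_neurons,), 1.0),  # per-neuron firing thresholds
        min_v_mem=-torch.ones(n_neurons),               # per-neuron lower bound on v_mem
    ).cuda()

    # Input of shape (batch, time, neurons); the 1D parameters broadcast
    # against the (batch, neurons) membrane state inside the layer.
    inp = torch.rand(2, 100, n_neurons, device="cuda")
    out = layer(inp)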