
Merge branch 'develop' into main
Felix Bauer committed Mar 14, 2024
2 parents b367535 + 6d18542 commit 2618b84
Showing 8 changed files with 210 additions and 121 deletions.
5 changes: 5 additions & 0 deletions ChangeLog
@@ -1,6 +1,11 @@
CHANGES
=======

v1.2.1
------
* Support Sinabs 2.0
* Support individual time constants

v1.1.2
------
* Update README
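The new "individual time constants" entry refers to the per-neuron parameters introduced by the layer and kernel changes below. A minimal usage sketch, assuming the constructor arguments and import path from the layer diffs further down (exodus layers run on a CUDA device):

import torch
from sinabs.exodus.layers import LIF

n_neurons = 4
layer = LIF(
    tau_mem=torch.tensor([10.0, 20.0, 30.0, 40.0]),  # one membrane time constant per neuron
    spike_threshold=torch.full((n_neurons,), 1.0),    # thresholds may now be per-neuron tensors too
    min_v_mem=torch.full((n_neurons,), -1.0),         # as may the lower bound on v_mem
).cuda()

spikes = layer(torch.rand(1, 100, n_neurons, device="cuda"))  # (batch, time, neurons)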
14 changes: 10 additions & 4 deletions cuda/bindings.cu
@@ -16,10 +16,10 @@ void lifForward(
const torch::Tensor& vmem,
const torch::Tensor& input,
const torch::Tensor& vmemPostInitial,
const torch::Tensor& alpha,
const torch::Tensor& alpha,
const torch::Tensor& membrSubtract,
const float theta,
const float thetaLow,
const torch::Tensor& theta,
const torch::Tensor& thetaLow,
const bool applyThetaLow,
const int maxNumSpikes)
{
@@ -29,13 +29,17 @@ void lifForward(
CHECK_INPUT(vmemPostInitial);
CHECK_INPUT(alpha);
CHECK_INPUT(membrSubtract);
CHECK_INPUT(theta);
CHECK_INPUT(thetaLow);

// check if tensors are on same device
CHECK_DEVICE(input, vmem);
CHECK_DEVICE(input, outputSpikes);
CHECK_DEVICE(input, vmemPostInitial);
CHECK_DEVICE(input, alpha);
CHECK_DEVICE(input, membrSubtract);
CHECK_DEVICE(input, theta);
CHECK_DEVICE(input, thetaLow);

// set the current cuda device to wherever the tensor input resides
cudaSetDevice(input.device().index());
@@ -58,7 +62,9 @@ void lifForward(
vmemPostInitial.data_ptr<float>(),
alpha.data_ptr<float>(),
membrSubtract.data_ptr<float>(),
theta, thetaLow, applyThetaLow, maxNumSpikesU, nNeurons, nTimesteps);
theta.data_ptr<float>(),
thetaLow.data_ptr<float>(),
applyThetaLow, maxNumSpikesU, nNeurons, nTimesteps);

return;
}
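With theta and thetaLow promoted from floats to tensors, the binding now applies the same contiguity and device checks to them as to the other arguments. A hedged sketch of the resulting caller-side contract (names and shapes are assumptions, not part of the commit): every per-neuron argument is a contiguous float32 tensor of length nNeurons on the same CUDA device as the input.

import torch

device = torch.device("cuda:0")
n_neurons, n_timesteps = 8, 100

input_ = torch.rand(n_neurons, n_timesteps, device=device).contiguous()
theta = torch.full((n_neurons,), 1.0, device=device)  # per-neuron firing thresholds
theta_low = torch.zeros(n_neurons, device=device)     # per-neuron lower bounds on vmem
# CHECK_INPUT / CHECK_DEVICE in lifForward reject non-contiguous tensors or
# tensors living on a different device than the input.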
40 changes: 20 additions & 20 deletions cuda/lif_kernels.h
@@ -28,10 +28,10 @@
* @param vmemPostInitial 1D-tensor (nNeurons) with the initial membrane potentials (after reset)
* @param alpha 1D-tensor with decay factor of the neuron states (exp(-dt/tau)).
* For IAF neurons set to 1.
* @param theta Firing threshold
* @param theta 1D-tensor of firing thresholds
* @param membrSubtract 1D tensor with values that are subtracted from the membrane
* potential when spiking
* @param thetaLow Lower bound to vmem
* @param thetaLow 1D-tensor of lower bounds to vmem
* @param applyThetaLow Flag whether vmem is lower bounded
* @param maxNumSpikes Maximum number of spikes a neuron can emit per time step
* @param nNeurons Number of neurons/batches
@@ -45,8 +45,8 @@ __global__ void lifForwardKernel(
const scalarType* __restrict__ vmemPostInitial,
const scalarType* __restrict__ alpha,
const scalarType* __restrict__ membrSubtract,
float theta,
float thetaLow,
const scalarType* __restrict__ theta,
const scalarType* __restrict__ thetaLow,
bool applyThetaLow,
unsigned maxNumSpikes,
unsigned nNeurons,
@@ -73,13 +73,13 @@
vmemCurr += input[linearID];

// Apply lower threshold
if (applyThetaLow && (vmemCurr < thetaLow)){
vmemCurr = thetaLow;
if (applyThetaLow && (vmemCurr < thetaLow[neuronID])){
vmemCurr = thetaLow[neuronID];
}

// Generate spikes
if(vmemCurr >= theta){
activation = min(unsigned(vmemCurr / theta), maxNumSpikes);
if(vmemCurr >= theta[neuronID]){
activation = min(unsigned(vmemCurr / theta[neuronID]), maxNumSpikes);
} else {
activation = 0;
}
@@ -271,8 +271,8 @@ __global__ void lifBackwardAlphaKernel(
* For IAF neurons set to 1.
* @param membrSubtract 1D-tensor of value that is subtracted from the membrane potential
* when spiking
* @param theta Firing threshold
* @param thetaLow Lower bound to vmem
* @param theta 1D-tensor of firing thresholds
* @param thetaLow 1D-tensor of lower bounds to vmem
* @param applyThetaLow Flag whether vmem is lower bounded
* @param multipleSpikes Flag whether multiple spikes can be emitted in a single time step
* @param nNeurons Number of neurons/batches
@@ -286,8 +286,8 @@ void lifForwardCuda(
const scalarType* vmemPostInitial,
const scalarType* alpha,
const scalarType* membrSubtract,
const float theta,
const float thetaLow,
const scalarType* theta,
const scalarType* thetaLow,
const bool applyThetaLow,
const unsigned maxNumSpikes,
const unsigned nNeurons,
@@ -302,14 +302,14 @@ void lifForwardCuda(
vmem,
input,
vmemPostInitial,
alpha,
membrSubtract,
theta,
thetaLow,
applyThetaLow,
maxNumSpikes,
nNeurons,
nTimesteps);
alpha,
membrSubtract,
theta,
thetaLow,
applyThetaLow,
maxNumSpikes,
nNeurons,
nTimesteps);
}


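For reference, a Python sketch of the per-neuron update that lifForwardKernel now performs with tensor-valued theta and thetaLow. The lower-bound and spike-generation steps mirror the kernel lines shown above; the decay and reset steps are assumed from the parameter documentation (alpha as decay factor, membrSubtract subtracted once per emitted spike), so treat this as an approximation rather than an excerpt of the kernel.

import torch

def lif_forward_reference(inp, vmem_post_initial, alpha, membr_subtract,
                          theta, theta_low, apply_theta_low, max_num_spikes):
    """inp: (n_neurons, n_timesteps); all other tensors: (n_neurons,)."""
    vmem = vmem_post_initial.clone()
    spikes, vmem_rec = torch.zeros_like(inp), torch.zeros_like(inp)
    for t in range(inp.shape[1]):
        vmem = alpha * vmem + inp[:, t]            # decay, then integrate input (assumed order)
        if apply_theta_low:
            vmem = torch.maximum(vmem, theta_low)  # per-neuron lower bound
        activation = torch.where(
            vmem >= theta,
            torch.clamp(torch.div(vmem, theta, rounding_mode="floor"), max=max_num_spikes),
            torch.zeros_like(vmem),
        )                                          # per-neuron threshold, capped spike count
        vmem = vmem - activation * membr_subtract  # subtract once per emitted spike
        spikes[:, t], vmem_rec[:, t] = activation, vmem
    return spikes, vmem_rec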
3 changes: 1 addition & 2 deletions setup.py
@@ -40,7 +40,6 @@ def run(self):

# Handle versions
version = versioneer.get_version()
version_major = version.split(".")[0]

# Install
setup(
@@ -63,6 +62,6 @@ def run(self):
)
],
cmdclass=cmdclass,
install_requires=["torch", "sinabs"],
install_requires=["torch", f"sinabs >= 1.2.9"],
)

8 changes: 4 additions & 4 deletions sinabs/exodus/layers/iaf.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Callable, Optional, Union
import torch
import numpy as np
from sinabs.layers import SqueezeMixin
@@ -54,12 +54,12 @@ class IAF(LIF):

def __init__(
self,
spike_threshold: float = 1.0,
spike_threshold: Optional[Union[float, torch.Tensor]] = 1.0,
spike_fn: Callable = MultiSpike,
reset_fn: Callable = MembraneSubtract(),
surrogate_grad_fn: Callable = SingleExponential(),
tau_syn: Optional[float] = None,
min_v_mem: Optional[float] = None,
tau_syn: Optional[Union[float, torch.Tensor]] = None,
min_v_mem: Optional[Union[float, torch.Tensor]] = None,
shape: Optional[torch.Size] = None,
record_states: bool = False,
decay_early: bool = True,
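The IAF signature is widened in the same way: spike_threshold, tau_syn and min_v_mem now accept per-neuron tensors as well as floats. A brief, hedged usage sketch (import path assumed):

import torch
from sinabs.exodus.layers import IAF

layer = IAF(
    spike_threshold=torch.tensor([0.5, 1.0, 2.0]),  # per-neuron firing thresholds
    min_v_mem=torch.tensor([-0.5, -1.0, -2.0]),     # per-neuron lower bounds on v_mem
)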
67 changes: 56 additions & 11 deletions sinabs/exodus/layers/lif.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Union
from typing import Callable, Optional, Tuple, Union

import torch
from sinabs.layers import SqueezeMixin
@@ -16,6 +16,32 @@

__all__ = ["LIF", "LIFSqueeze"]

def expand_to_1d_contiguous(
value: Union[float, torch.Tensor], shape: Tuple[int]
) -> torch.Tensor:
"""
Expand tensor to tensor of given shape.
Then flatten and make contiguous.
Not flattening immediately ensures that non-scalar tensors are
broadcast correctly to the required shape.
Parameters
----------
value: float or torch.Tensor
Tensor to be expanded
shape: tuple of ints
shape to expand to
Returns
-------
torch.Tensor
Contiguous 1D-tensor from expanded `value`
"""

expanded_tensor = torch.as_tensor(value).expand(shape)
return expanded_tensor.float().flatten().contiguous()


class LIF(LIFSinabs):
"""
@@ -70,11 +96,11 @@ def __init__(
self,
tau_mem: Union[float, torch.Tensor],
tau_syn: Optional[Union[float, torch.Tensor]] = None,
spike_threshold: float = 1.0,
spike_threshold: Optional[Union[float, torch.Tensor]] = 1.0,
spike_fn: Callable = MultiSpike,
reset_fn: Callable = MembraneSubtract(),
surrogate_grad_fn: Callable = SingleExponential(),
min_v_mem: Optional[float] = None,
min_v_mem: Optional[Union[float, torch.Tensor]] = None,
train_alphas: bool = False,
shape: Optional[torch.Size] = None,
norm_input: bool = True,
@@ -158,24 +184,27 @@ def _prepare_input(self, input_data: torch.Tensor):
def _forward_synaptic(self, input_2d: torch.Tensor):
"""Evolve synaptic dynamics"""

alpha_syn = self.alpha_syn_calculated.expand(self.v_mem.shape).flatten()
alpha_syn = expand_to_1d_contiguous(
self.alpha_syn_calculated, self.v_mem.shape
)

if self.decay_early:
input_2d = input_2d * alpha_syn.unsqueeze(1)

# Apply exponential filter to input
return LeakyIntegrator.apply(
input_2d.contiguous(), # Input data
alpha_syn.contiguous(), # Synaptic alpha
alpha_syn, # Synaptic alpha
self.i_syn.flatten().contiguous(), # Initial synaptic states
)

def _forward_membrane(self, i_syn_2d: torch.Tensor):
"""Evolve membrane dynamics"""

# Broadcast alpha to number of neurons (x batches)
alpha_mem = self.alpha_mem_calculated.expand(self.v_mem.shape)
alpha_mem = alpha_mem.flatten().contiguous()
alpha_mem = expand_to_1d_contiguous(
self.alpha_mem_calculated, self.v_mem.shape
)

if self.norm_input:
# Rescale input with 1 - alpha (based on approximation that
@@ -196,19 +225,35 @@ def _forward_membrane(self, i_syn_2d: torch.Tensor):

return v_mem, v_mem

# Expand spike threshold
spike_threshold = expand_to_1d_contiguous(
self.spike_threshold, self.v_mem.shape
)

# Expand min_v_mem
if self.min_v_mem is None:
min_v_mem = None
else:
min_v_mem = expand_to_1d_contiguous(
self.min_v_mem, self.v_mem.shape
)

# Expand membrane subtract
membrane_subtract = self.reset_fn.subtract_value
if membrane_subtract is None:
membrane_subtract = self.spike_threshold
membrane_subtract = torch.full_like(alpha_mem, membrane_subtract)
membrane_subtract = spike_threshold
else:
membrane_subtract = expand_to_1d_contiguous(
membrane_subtract, self.v_mem.shape
)

output_2d, v_mem_2d = IntegrateAndFire.apply(
i_syn_2d.contiguous(), # Input data
alpha_mem, # Alphas
self.v_mem.flatten().contiguous(), # Initial vmem
self.spike_threshold, # Spike threshold
spike_threshold, # Spike threshold
membrane_subtract, # Membrane subtract
self.min_v_mem, # Lower bound on vmem
min_v_mem, # Lower bound on vmem
self.surrogate_grad_fn, # Surrogate gradient
self.max_num_spikes_per_bin, # Max. number of spikes per bin
)
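Two illustrative notes on the refactor above, both hedged sketches rather than excerpts from the commit. First, the effective behaviour of expand_to_1d_contiguous: a scalar is broadcast to every entry of v_mem, a per-neuron tensor is tiled across the batch dimension, and the result is always a flat, contiguous float tensor. Second, when reset_fn.subtract_value is None, the membrane subtraction now defaults to the (possibly per-neuron) spike threshold.

import torch
# assumes expand_to_1d_contiguous from sinabs/exodus/layers/lif.py is in scope

v_mem_shape = (2, 3)  # e.g. (batch, neurons); exact ordering assumed
expand_to_1d_contiguous(0.5, v_mem_shape)
# -> tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
expand_to_1d_contiguous(torch.tensor([1.0, 2.0, 3.0]), v_mem_shape)
# -> tensor([1., 2., 3., 1., 2., 3.])

# Default reset amount, as selected in _forward_membrane:
# membrane_subtract = spike_threshold when reset_fn.subtract_value is None,
# so heterogeneous thresholds imply heterogeneous reset subtraction as well.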
