-
Notifications
You must be signed in to change notification settings - Fork 1
/
base.py
144 lines (117 loc) · 4.98 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from enum import Enum, auto
import torch
import warnings
class TensorLayout(Enum):
    """Layout tag for the tensor held by a ``MetaTensor``."""

    Conv = auto()
    FC = auto()


class DataType(Enum):
    """Value-kind tag for the tensor held by a ``MetaTensor``."""

    Spike = auto()
    Dense = auto()


class MetaTensor:
    """Wrap a tensor together with its layout and data-type tags.

    The spike statistics treat the tensor as 5-D, ``(batch, d1, d2, d3, time)``:
    dims 1-3 together index neurons and dim 4 indexes time steps. For the FC
    layout dims 2 and 3 must be singletons, so neurons live in dim 1 only.
    """

    def __init__(self, tensor: torch.Tensor, tensor_layout: TensorLayout, data_type: DataType):
        self._data = tensor
        self._tensor_layout = tensor_layout
        self._data_type = data_type
        if self.hasFCLayout():
            # FC tensors keep their neurons in dim 1; dims 2 and 3 must be size 1.
            assert tensor.shape[2:4] == (1, 1)

    def getTensor(self):
        """Return the wrapped tensor detached from the autograd graph."""
        return self._data.detach()

    def _neuronTimeView(self):
        """Return the detached data reshaped to ``(batch, neurons, time)``.

        Dims 1-3 are merged into a single neuron axis for both layouts.
        Unlike an unqualified ``squeeze()``, this never drops the batch,
        neuron or time axis when one of them happens to have size 1.
        """
        assert self.hasConvLayout() or self.hasFCLayout()
        return self._data.detach().flatten(start_dim=1, end_dim=3)

    def getMeanNumSpikesPerNeuron(self):
        """
        :return: For each sample in the batch, the number of spikes summed over
            time and averaged over all neurons (tensor of shape ``(batch,)``).
            Returns None (with a warning) if data does not contain spikes.
        """
        if not self.isSpikeType():
            warnings.warn("Data does not contain spikes. Returning None.", Warning)
            return None
        # Sum over time, then average over the flattened neuron axis.
        return self._neuronTimeView().sum(2).mean(dim=1)

    def getSpikeCounts(self):
        """
        :return: dict with
            - 'num_spikes': per-sample total spike count, tensor of shape ``(batch,)``
            - 'num_neurons': neurons per sample (product of dims 1-3), int
            - 'num_steps': size of the time dimension (dim 4), int
            - 'fraction_spiking': fraction of (sample, neuron) pairs over the
              whole batch with at least one spike, float
            - 'spikes_per_neuron': 1-D long tensor of per-neuron spike counts,
              flattened over the batch, spiking neurons only
            - 'steps_in_batch': 1-D tensor ``num_steps / spikes_per_neuron`` for
              the same spiking neurons
            Returns None (with a warning) if data does not contain spikes.
        """
        if not self.isSpikeType():
            warnings.warn("Data does not contain spikes. Returning None.", Warning)
            return None
        data = self._neuronTimeView()  # (batch, neurons, time)
        num_steps = data.shape[2]
        # Per-neuron counts are truncated to integers via .long();
        # NOTE(review): assumes spikes are encoded with amplitude 1 — confirm.
        num_spikes_per_neuron = data.sum(dim=2).long()
        is_spiking = num_spikes_per_neuron >= 1
        spiking_counts = num_spikes_per_neuron[is_spiking]
        # Scalar: share of all (sample, neuron) pairs in the batch that spiked at least once.
        fraction_spiking = is_spiking.sum().float() / num_spikes_per_neuron.numel()
        return {
            'num_spikes': data.sum(dim=(1, 2)),
            'num_neurons': data.shape[1],
            'num_steps': num_steps,
            'fraction_spiking': fraction_spiking.item(),
            'spikes_per_neuron': spiking_counts,
            'steps_in_batch': num_steps / spiking_counts.float(),
        }

    def isSpikeType(self):
        return self._data_type == DataType.Spike

    def isDenseType(self):
        return self._data_type == DataType.Dense

    def hasConvLayout(self):
        return self._tensor_layout == TensorLayout.Conv

    def hasFCLayout(self):
        return self._tensor_layout == TensorLayout.FC
class SpikeModule(torch.nn.Module):
    """``torch.nn.Module`` that also keeps a name -> MetaTensor registry.

    The output entry is reserved: ``addMetaTensor`` refuses the output key, so
    it can only be set through ``addOutputMetaTensor``, which checks that the
    tensor is Dense with an FC layout. ``addInputMetaTensor`` checks that the
    input is Spike-typed with a Conv layout.
    """

    _input_key = 'input'
    _output_key = 'output'

    def __init__(self):
        super().__init__()
        # Registry of recorded MetaTensors, keyed by name.
        self._data = {}

    def getMetaTensorDict(self):
        """Return the registry itself (callers receive the live dict, not a copy)."""
        return self._data

    def addMetaTensor(self, key: str, value: MetaTensor):
        """Record ``value`` under ``key``; the output key is rejected here."""
        assert key != self._output_key, 'Use addOutputMetaTensor function instead'
        self._data[key] = value

    def addInputMetaTensor(self, value: MetaTensor):
        """Record the module input; it must be Spike-typed with a Conv layout."""
        assert value.isSpikeType()
        assert value.hasConvLayout(), 'Does not have to be but is reasonable for the moment'
        self._data[self._input_key] = value

    def addOutputMetaTensor(self, value: MetaTensor):
        """Record the module output; it must be Dense-typed with an FC layout."""
        assert value.isDenseType()
        assert value.hasFCLayout(), 'Does not have to be but is reasonable for the moment'
        self._data[self._output_key] = value
def getNeuronConfig(type: str='SRMALPHA',
                    theta: float=10.,
                    tauSr: float=1.,
                    tauRef: float=1.,
                    scaleRef: float=2.,
                    tauRho: float=0.3,  # Was set to 0.2 previously (e.g. for fullRes run)
                    scaleRho: float=1.):
    """Bundle neuron parameters into a configuration dictionary.

    :param type: neuron type
    :param theta: neuron threshold
    :param tauSr: neuron time constant
    :param tauRef: neuron refractory time constant
    :param scaleRef: neuron refractory response scaling (relative to theta)
    :param tauRho: spike function derivative time constant (relative to theta)
    :param scaleRho: spike function derivative scale factor
    :return: dict mapping each parameter name above to its value, in that order
    """
    # Keyword construction keeps the key order identical to the parameter order.
    config = dict(
        type=type,
        theta=theta,
        tauSr=tauSr,
        tauRef=tauRef,
        scaleRef=scaleRef,
        tauRho=tauRho,
        scaleRho=scaleRho,
    )
    return config