-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathspikeFRInder.py
207 lines (158 loc) · 6.63 KB
/
spikeFRInder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""
Functions for using frequency-domain FRI method for spike inference.
Authors: Benjamin Bejar, Gavin Mischler
"""
import numpy as np
from numpy.fft import fft, fftshift, ifft
from math import pi
from scipy.linalg import toeplitz, svd, pinv, inv
from scipy.optimize import golden
from scipy.signal import hamming
from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt
from utils import F, Cadzow
# ============================================================================
# Calcium signal sliding window estimation
# ============================================================================
def sliding_window_predict(signal,
                           Fs,
                           K,
                           alpha_grid_ends=(7, 1.75),
                           window_lengths=(301, 601, 801, 1101),
                           jump_size=30,
                           OF=4,
                           smoothing_sigma=7.5):
    """
    Use multiple sliding windows to estimate spikes. Returns an array of
    aggregated spike likelihoods with the same length as the input signal.

    For every window length, a window of that many samples is slid across the
    signal in steps of ``jump_size`` samples. At each position,
    ``estimate_tk_ak`` is run on the windowed segment, and
    ``sqrt(max(a_k, 0))`` of every estimated amplitude is added at the
    corresponding spike index of the full signal. Contributions from all
    window lengths and positions are summed.

    Parameters
    ----------
    signal : numpy 1-D array
        Calcium signal (preprocessed).
    Fs : int
        Sampling frequency of the signal.
    K : int
        Estimate of the number of spikes in the entire signal. For each window
        length it is rescaled to ``max(round(window_len / len(signal) * K), 1)``
        spikes per window.
    alpha_grid_ends : sequence (length=2), default=(7, 1.75)
        Endpoints of the grid within which to search for the time constant,
        alpha, for decaying exponentials of the form exp(-t*alpha).
    window_lengths : sequence of int, default=(301, 601, 801, 1101)
        The window lengths (in samples) to use. Each length is used with a
        sliding window approach and the results are summed together. Lengths
        larger than the signal contribute nothing.
    jump_size : int, default=30
        Jump size of the sliding window approach (number of samples). Must be
        positive; zero raises ValueError.
    OF : int, default=4
        Oversampling factor, used to truncate the Fourier series.
    smoothing_sigma : float or None, default=7.5
        Sigma parameter for final smoothing of the joint histogram using
        scipy.ndimage.gaussian_filter1d. If `None`, no smoothing is performed
        and the joint histogram is returned.

    Returns
    -------
    final_histogram : numpy 1-D float array with same shape as signal
        Joint histogram spiking estimate, same shape as input signal.
    """
    signal = np.asarray(signal)
    N = len(signal)
    # Accumulate in float: np.zeros_like on an integer-valued signal would
    # silently truncate the sqrt-amplitude contributions.
    likelihood_counts = np.zeros(signal.shape, dtype=float)
    for window_len in window_lengths:
        if window_len > N:
            # no window of this length fits inside the signal
            continue
        # Per-window spike count, proportional to the window's share of the
        # signal. Kept in its own name so K is not re-scaled cumulatively
        # from one window length to the next.
        K_win = max(int(round(window_len / N * K)), 1)
        # Every start such that [start, start + window_len) fits in the
        # signal, including the window that ends exactly at the last sample.
        for start_pt in range(0, N - window_len + 1, jump_size):
            sig = signal[start_pt:start_pt + window_len]
            d = estimate_tk_ak(signal=sig, Fs=Fs, K=K_win,
                               alpha_grid_ends=alpha_grid_ends, OF=OF)
            # ravel: tolerate a 0-d index array when K_win == 1; clip: keep
            # every index inside the current window (and thus the signal).
            tk_indices = np.clip(np.ravel(d['tk_indices']), 0, window_len - 1)
            weights = np.sqrt(np.maximum(np.ravel(d['ak_hat']), 0))
            # np.add.at accumulates repeated indices; `arr[idx] += w` keeps
            # only one contribution when two spikes share an index.
            np.add.at(likelihood_counts, tk_indices + start_pt, weights)
    if smoothing_sigma is not None:
        return gaussian_filter1d(likelihood_counts, sigma=smoothing_sigma)
    return likelihood_counts
def estimate_tk_ak(signal, Fs, K, alpha_grid_ends, OF=4):
    """
    Estimate the indices and amplitudes of the most likely spikes in the
    signal using the frequency-domain FRI method.

    Steps, as implemented below:

    1. Scale ``alpha_grid_ends`` by the signal duration, since the model
       assumes the signal is 1 second long.
    2. Keep the Fourier coefficients of orders -L..L of ``fft(signal)``.
    3. For each alpha on a 6-point grid, form ``Sk = Zk * (alpha + j*w_k)``,
       pass it through ``Cadzow`` (project helper from ``utils``) and record
       the smallest singular value of the Toeplitz annihilation matrix.
    4. Refine the best grid alpha with ``scipy.optimize.golden`` applied to
       ``F`` (project helper from ``utils``; presumably the same error as in
       step 3 evaluated at a single alpha -- confirm in utils).
    5. Recover spike locations from the roots of the annihilating filter
       (Prony's method) and amplitudes by least squares.

    Parameters
    ----------
    signal : array_like, shape=(N,)
        Signal to estimate K spikes within.
    Fs : float
        Sampling frequency of signal.
    K : int
        Number of spikes assumed to be in the signal.
    alpha_grid_ends : sequence (length=2), e.g. [7, 1.75]
        Endpoints (per second) of the grid within which to search for the
        time constant, alpha, for decaying exponentials of the form
        exp(-t*alpha). The endpoints may be given in either order.
    OF : int, positive, or None, default=4
        Oversampling factor defining the number of Fourier coefficients to
        use from the signal, via L = (OF * K * 2) + 1 (orders -L..L are
        kept). L is capped at (N - 1) // 2, the highest order available from
        the FFT. If None, that maximum is used.

    Returns
    -------
    d : dict
        'tk_indices' : int array, shape=(K,), sorted sample indices in
            [0, N - 1] of the predicted spike locations.
        'ak_hat' : float array, shape=(K,), estimated amplitudes of the
            detected spikes, in the same order as 'tk_indices'.
    """
    signal = np.asarray(signal)
    N_full = len(signal)
    # The model assumes a 1-second signal, so express the per-second alpha
    # bounds in units of this signal's duration.
    seconds = N_full / Fs
    alpha_grid_ends = [alpha * seconds for alpha in alpha_grid_ends]

    # ---- Truncated Fourier series ------------------------------------------
    Zk_full = fft(signal)
    L_max = (N_full - 1) // 2  # highest order with distinct +/- bins
    if OF is None:
        L = L_max
    else:
        L = min(OF * K * 2 + 1, L_max)
    N = 2 * L + 1
    # FFT layout: bins [0..L] hold orders 0..L and the last L bins hold
    # orders -L..-1. Slicing from N_full - L (rather than -L) keeps L == 0
    # from selecting the whole array.
    Zk_tilde = np.concatenate((Zk_full[:L + 1], Zk_full[N_full - L:]))
    # fftshift reorders the coefficients to orders -L..L, matching wk.
    Zk_shifted = fftshift(Zk_tilde)
    wk = 2 * pi * np.arange(-L, L + 1)

    # ---- Coarse grid search for alpha --------------------------------------
    alpha_grid = np.linspace(alpha_grid_ends[0], alpha_grid_ends[1], 6)
    e = np.zeros(alpha_grid.size)
    for ii, alpha in enumerate(alpha_grid):
        Sk = Cadzow(Zk_shifted * (alpha + 1j * wk), K, N)
        # Smallest singular value of the annihilation Toeplitz matrix: the
        # residual of the best length-(K+1) annihilating filter for Sk.
        e[ii] = svd(toeplitz(Sk[K:], Sk[K::-1]), compute_uv=False)[-1]

    # ---- Refine alpha ------------------------------------------------------
    minindx = int(np.argmin(e))
    alpha_hat = alpha_grid[minindx]
    if 0 < minindx < alpha_grid.size - 1:
        # interior minimum: the neighbouring grid points bracket it
        bracket = tuple(alpha_grid[minindx - 1:minindx + 2])
        try:
            alpha_hat = golden(F, (Zk_tilde, wk, K, True), bracket)
        except ValueError:
            # golden() requires f(middle) strictly below both ends; on ties
            # keep the grid minimum instead of failing.
            alpha_hat = alpha_grid[minindx]
    # A minimum on the grid boundary has no 3-point bracket, so the edge grid
    # value itself is used. The default grid runs from high to low alpha, so
    # the clip bounds must be ordered: np.clip with a_min > a_max would
    # return a_max unconditionally.
    alpha_hat = float(np.clip(alpha_hat, alpha_grid.min(), alpha_grid.max()))

    # ---- Prony's method: spike locations -----------------------------------
    Sk = Cadzow(Zk_shifted * (alpha_hat + 1j * wk), K, N)
    Vh = svd(toeplitz(Sk[K:], Sk[K::-1]))[2]
    # annihilating filter = right singular vector of the smallest singular value
    h = Vh[-1, :].conj()
    h = h / h[0]
    # the phase of each filter root gives a normalized location in [0, 1)
    tk_hat = np.sort(np.mod(np.angle(np.roots(h[::-1])) / 2.0 / pi, 1)).reshape(K)

    # ---- Least-squares amplitudes ------------------------------------------
    uk = np.exp(-1j * 2 * pi * tk_hat)
    # Vandermonde system Z @ a = Sk with Z[i, k] = uk[k] ** m_i, where m_i
    # runs over the Fourier orders held in Sk (centred on zero).
    orders = np.arange(len(Sk)) - len(Sk) // 2
    Z = uk[np.newaxis, :] ** orders[:, np.newaxis]
    ak_hat = np.real(pinv(Z) @ Sk).reshape(K)

    # tk_hat should lie in [0, 1), but np.mod can round a tiny negative phase
    # up to exactly 1.0; clip so every index addresses a real sample.
    tk_hat_indices = np.clip((tk_hat * N_full).astype(int), 0, N_full - 1)
    return {'tk_indices': tk_hat_indices,
            'ak_hat': ak_hat}