Skip to content

Commit

Permalink
Performance and API enhancement for STTC (NeuralEnsemble#244)
Browse files Browse the repository at this point in the history
* Speedup for STTC
* Implemented spike search in other sptr with numpy search

Memory complexity: O(max(n1, n2)) (exactly max(n1, n2) is used
additionally at peak)
Runtime complexity: O(max(n1, n2)*log(max(n1, n2))), because search
happens n1 times with log(n2) for binary search (or n2 times with
log(n1))
  • Loading branch information
muellerbjoern authored and dizcza committed Oct 16, 2019
1 parent a833a96 commit ce9e7ac
Showing 1 changed file with 50 additions and 33 deletions.
83 changes: 50 additions & 33 deletions elephant/spike_train_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,51 +650,68 @@ def spike_time_tiling_coefficient(spiketrain_1, spiketrain_2, dt=0.005 * pq.s):
Study of Retinal Waves. Journal of Neuroscience, 34(43), 14288–14303.
"""

def run_P(spiketrain_1, spiketrain_2, N1, N2, dt):
def run_P(spiketrain_1, spiketrain_2):
"""
Check every spike in train 1 to see if there's a spike in train 2
within dt
"""
Nab = 0
j = 0
for i in range(N1):
while j < N2: # don't need to search all j each iteration
if np.abs(spiketrain_1[i] - spiketrain_2[j]) <= dt:
Nab = Nab + 1
break
elif spiketrain_2[j] > spiketrain_1[i]:
break
else:
j = j + 1
return Nab

def run_T(spiketrain, N, dt):
N2 = len(spiketrain_2)

# Search spikes of spiketrain_1 in spiketrain_2
# ind will contain index of
ind = np.searchsorted(spiketrain_2.times, spiketrain_1.times)

# To prevent IndexErrors
# If a spike of spiketrain_1 is after the last spike of spiketrain_2,
# the index is N2, however spiketrain_2[N2] raises an IndexError.
# By shifting this index, the spike of spiketrain_1 will be compared
# to the last 2 spikes of spiketrain_2 (negligible overhead).
# Note: Not necessary for index 0 that will be shifted to -1,
# because spiketrain_2[-1] is valid (additional negligible comparison)
ind[ind == N2] = N2 - 1

# Compare to nearest spike in spiketrain_2 BEFORE spike in spiketrain_1
close_left = np.abs(
spiketrain_2.times[ind - 1] - spiketrain_1.times) <= dt
# Compare to nearest spike in spiketrain_2 AFTER (or simultaneous)
# spike in spiketrain_2
close_right = np.abs(
spiketrain_2.times[ind] - spiketrain_1.times) <= dt

# spiketrain_2 spikes that are in [-dt, dt] range of spiketrain_1
# spikes are counted only ONCE (as per original implementation)
close = close_left + close_right

# Count how many spikes in spiketrain_1 have a "partner" in
# spiketrain_2
return np.count_nonzero(close)

def run_T(spiketrain):
"""
Calculate the proportion of the total recording time 'tiled' by spikes.
"""
N = len(spiketrain)
time_A = 2 * N * dt # maximum possible time

if N == 1: # for just one spike in train
if spiketrain[0] - spiketrain.t_start < dt:
time_A = time_A - dt + spiketrain[0] - spiketrain.t_start
time_A += -dt + spiketrain[0] - spiketrain.t_start
if spiketrain[0] + dt > spiketrain.t_stop:
time_A = time_A - dt - spiketrain[0] + spiketrain.t_stop

time_A += -dt - spiketrain[0] + spiketrain.t_stop
else: # if more than one spike in train
i = 0
while i < (N - 1):
diff = spiketrain[i + 1] - spiketrain[i]

if diff < (2 * dt): # subtract overlap
time_A = time_A - 2 * dt + diff
i += 1
# check if spikes are within dt of the start and/or end
# if so subtract overlap of first and/or last spike
# Vectorized loop of spike time differences
diff = np.diff(spiketrain)
diff_overlap = diff[diff < 2 * dt]
# Subtract overlap
time_A += -2 * dt * len(diff_overlap) + np.sum(diff_overlap)

# check if spikes are within dt of the start and/or end
# if so subtract overlap of first and/or last spike
if (spiketrain[0] - spiketrain.t_start) < dt:
time_A = time_A + spiketrain[0] - dt - spiketrain.t_start
time_A += spiketrain[0] - dt - spiketrain.t_start

if (spiketrain.t_stop - spiketrain[N - 1]) < dt:
time_A = time_A - spiketrain[-1] - dt + spiketrain.t_stop
time_A += -spiketrain[-1] - dt + spiketrain.t_stop

T = time_A / (spiketrain.t_stop - spiketrain.t_start)
return T.simplified.item() # enforce simplification, strip units
Expand All @@ -705,11 +722,11 @@ def run_T(spiketrain, N, dt):
if N1 == 0 or N2 == 0:
index = np.nan
else:
TA = run_T(spiketrain_1, N1, dt)
TB = run_T(spiketrain_2, N2, dt)
PA = run_P(spiketrain_1, spiketrain_2, N1, N2, dt)
TA = run_T(spiketrain_1)
TB = run_T(spiketrain_2)
PA = run_P(spiketrain_1, spiketrain_2)
PA = PA / N1
PB = run_P(spiketrain_2, spiketrain_1, N2, N1, dt)
PB = run_P(spiketrain_2, spiketrain_1)
PB = PB / N2
# check if the P and T values are 1 to avoid division by zero
# This only happens for TA = PB = 1 and/or TB = PA = 1,
Expand Down

0 comments on commit ce9e7ac

Please sign in to comment.