Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement of Spike Time Tiling Coefficient (STTC) #438

Merged
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions elephant/spike_train_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,26 +926,34 @@ def run_T(spiketrain):
N = len(spiketrain)
time_A = 2 * N * dt # maximum possible time

if N == 1: # for just one spike in train
if N == 1: # for only a single spike in the train

# Check difference between start of recording and single spike
if spiketrain[0] - spiketrain.t_start < dt:
time_A += -dt + spiketrain[0] - spiketrain.t_start
if spiketrain[0] + dt > spiketrain.t_stop:
time_A += -dt - spiketrain[0] + spiketrain.t_stop
else: # if more than one spike in train
# Vectorized loop of spike time differences
time_A += spiketrain[0] - spiketrain.t_start
Moritz-Alexander-Kern marked this conversation as resolved.
Show resolved Hide resolved

# Check difference between single spike and end of recording
elif spiketrain[0] + dt > spiketrain.t_stop:
time_A += - dt - spiketrain[0] + spiketrain.t_stop

else: # if more than a single spike in the train

# Calculate difference between consecutive spikes
diff = np.diff(spiketrain)
diff_overlap = diff[diff < 2 * dt]
# Subtract overlap
time_A += -2 * dt * len(diff_overlap) + np.sum(diff_overlap)

# check if spikes are within dt of the start and/or end
# if so subtract overlap of first and/or last spike
# Find spikes whose tiles overlap
idx = np.where(diff < 2 * dt)[0]
# Subtract overlapping "2*dt" tiles and add differences instead
time_A += - 2 * dt * len(idx) + diff[idx].sum()

# Check if spikes are within +/-dt of the start and/or end
# if so, subtract overlap of first and/or last spike
if (spiketrain[0] - spiketrain.t_start) < dt:
time_A += spiketrain[0] - dt - spiketrain.t_start

if (spiketrain.t_stop - spiketrain[N - 1]) < dt:
time_A += -spiketrain[-1] - dt + spiketrain.t_stop
time_A += - spiketrain[-1] - dt + spiketrain.t_stop

# Calculate the proportion of total recorded time to "tiled" time
T = time_A / (spiketrain.t_stop - spiketrain.t_start)
return T.simplified.item() # enforce simplification, strip units

Expand Down Expand Up @@ -1066,4 +1074,4 @@ def spike_train_timescale(binned_spiketrain, max_tau):
# Calculate the timescale using trapezoidal integration
integr = (corrfct / corrfct[0]) ** 2
timescale = 2 * integrate.trapz(integr, dx=bin_size)
return pq.Quantity(timescale, units=binned_spiketrain.units, copy=False)
return pq.Quantity(timescale, units=binned_spiketrain.units, copy=False)