
[Feature] Synchrofact Detection #322

Merged: 65 commits, Dec 4, 2020
4ab6ff8
Add first version of synchrofact detection
Kleinjohann May 11, 2020
7b85e3d
Implement TODOs, adapt tests accordingly
Kleinjohann May 15, 2020
da26ace
Add tests for raised errors, cleanup
Kleinjohann May 15, 2020
68b2ae5
Change spread to spread - 1
Kleinjohann May 15, 2020
3b822d8
remove unused imports
morales-gregorio May 18, 2020
c0891f4
add check for list instance in input and rewrite a bit error messages
morales-gregorio May 18, 2020
f0c8fa3
refactor docstring for synchrofact detection function. Reorder kwargs…
morales-gregorio May 18, 2020
771632c
refactor docstring for complexity function, rename function to shorte…
morales-gregorio May 18, 2020
bd531b8
relocate line of code that came too early
morales-gregorio May 18, 2020
035a29f
remove TODO flags
morales-gregorio May 18, 2020
c750386
add the module to the documentation
morales-gregorio May 18, 2020
bf5ca69
make helper function private
morales-gregorio May 18, 2020
d910dc9
refactor documentation for nice formatting after build
morales-gregorio May 18, 2020
7b769a7
Refactor and cleanup
Kleinjohann May 20, 2020
2fd59eb
Add wrapper to calculate the complexity histogram
Kleinjohann May 20, 2020
87919a5
move check consistency of spiketrain list to utils
morales-gregorio May 20, 2020
386682b
refactor docs
morales-gregorio May 21, 2020
791b30a
update checking of spiketrainlist in detect_synchrofacts
morales-gregorio May 21, 2020
8b69e2d
add test for utils spiketrainlist checking
morales-gregorio May 21, 2020
034d1d2
create complexity class, that can replace complexity_pdf in statistics
morales-gregorio Jun 10, 2020
399e7c4
Work in progress, Synchrotool class. Will have to be changed eventually.
morales-gregorio Jun 10, 2020
b002cd5
Move complexity class to statistics
Kleinjohann Jun 19, 2020
6e360bb
Test whether spiketrains is a list
Kleinjohann Jun 24, 2020
1386799
Finish complexity class (minus docs)
Kleinjohann Jun 24, 2020
35c081f
create list of raised errors in util
morales-gregorio Jul 1, 2020
95189e7
Update docstrings
morales-gregorio Jul 1, 2020
385ed2a
Fix time_histogram not being returned as list of int (by not multiply…
morales-gregorio Jul 1, 2020
db9a299
enforce t_start and t_stop checks
morales-gregorio Jul 1, 2020
e82399d
rename variables to avoid overwriting time_histogram function call
morales-gregorio Jul 1, 2020
2c6b02e
include tolerance to binning in spread > 0 case
morales-gregorio Jul 1, 2020
7b0c541
Apply bin shifting to epoch without spread
Kleinjohann Jul 1, 2020
e36b74a
Implement synchrofact detection as child class
Kleinjohann Jul 1, 2020
d46cbff
Merge branch 'branchout_synchrofact' of github.com:INM-6/elephant int…
Kleinjohann Jul 1, 2020
39bb674
Fix overlooked occurrence of renamed variable
Kleinjohann Jul 1, 2020
9758ab9
style changes to docstring
morales-gregorio Jul 4, 2020
41a252c
Merge branch 'branchout_synchrofact' of github.com:INM-6/elephant int…
morales-gregorio Jul 4, 2020
411ac50
Account for rounding errors in input check
Kleinjohann Jul 15, 2020
2dc37fe
Test for spread<=1 error raising
Kleinjohann Jul 15, 2020
1264a80
Test for no sampling rate warning
Kleinjohann Jul 15, 2020
c48df7f
Test for num_bins rounding error
Kleinjohann Jul 15, 2020
764544b
Cleanup
Kleinjohann Jul 15, 2020
6b4cd4f
Cleanup
Kleinjohann Jul 15, 2020
dc7b9f4
Remove prints
Kleinjohann Jul 15, 2020
934ebc3
update docs, rename `invert` kwarg to `mode`
morales-gregorio Jul 15, 2020
e8c52fe
Account for renaming of invert_delete in tests
Kleinjohann Jul 20, 2020
f6a9d20
Only account for shift when it actually happens
Kleinjohann Jul 20, 2020
f333a6d
Merge branch 'master' into synchrofact_detection
dizcza Aug 11, 2020
ba0eb64
decorative refactoring
dizcza Aug 11, 2020
f8761e8
Merge branch 'master' into synchrofact_detection
dizcza Sep 30, 2020
92292ff
python 2 issue; Capitalize class names; fixed TODOs
dizcza Sep 30, 2020
2285cd2
Merge branch 'master' into synchrofact_detection
dizcza Oct 5, 2020
b47fb3d
merged spike_train_processing into spike_train_synchrony
dizcza Oct 5, 2020
194aedb
converted a string to a comment
dizcza Oct 5, 2020
727bbf0
Faster algorithm for calculating complexity epochs
Kleinjohann Nov 11, 2020
865a75d
Merge branch 'master' into synchrofact_detection
dizcza Nov 13, 2020
0b2dbe6
check_neo_consistency test
dizcza Nov 13, 2020
a094f7a
updated _epoch_with_spread according to the recent changes in master
dizcza Nov 13, 2020
43e5522
Merge branch 'master' into synchrofact_detection
dizcza Nov 13, 2020
819e159
optimized quantity calls
dizcza Nov 30, 2020
e20e8ce
overwrite autosummary docs
dizcza Nov 30, 2020
7233520
round_binning_errors utility function
dizcza Nov 30, 2020
e62f2e1
Support groups only, remove unit and channelindex
Kleinjohann Dec 3, 2020
06e5739
Fix issues in docstrings spotted by @kohlerca
Kleinjohann Dec 3, 2020
ffd54fb
pep8
Kleinjohann Dec 3, 2020
4103045
simplified
dizcza Dec 3, 2020
7 changes: 4 additions & 3 deletions doc/conf.py
Original file line number Diff line number Diff line change
@@ -123,8 +123,8 @@
# the autosummary fields of each module.
autosummary_generate = True

# don't overwrite our custom toctree/*.rst
autosummary_generate_overwrite = False
# Set to False to not overwrite our custom toctree/*.rst
autosummary_generate_overwrite = True

# -- Options for HTML output ---------------------------------------------

@@ -344,7 +344,8 @@

# configuration for intersphinx: refer to Viziphant
intersphinx_mapping = {
'viziphant': ('https://viziphant.readthedocs.io/en/latest/', None)
'viziphant': ('https://viziphant.readthedocs.io/en/latest/', None),
'numpy': ('https://numpy.org/doc/stable', None)
}

# Use more reliable mathjax source
1 change: 1 addition & 0 deletions doc/modules.rst
@@ -26,6 +26,7 @@ Function Reference by Module
reference/sta
reference/statistics
reference/unitary_event_analysis
reference/utils
reference/waveform_features


5 changes: 5 additions & 0 deletions doc/reference/utils.rst
@@ -0,0 +1,5 @@
=================
Utility functions
=================

.. automodule:: elephant.utils
34 changes: 4 additions & 30 deletions elephant/conversion.py
@@ -27,7 +27,7 @@
import scipy.sparse as sps

from elephant.utils import is_binary, deprecated_alias, \
check_neo_consistency, get_common_start_stop_times
check_neo_consistency, get_common_start_stop_times, round_binning_errors

__all__ = [
"binarize",
@@ -185,18 +185,6 @@ def binarize(spiketrain, sampling_rate=None, t_start=None, t_stop=None,
###########################################################################


def _detect_rounding_errors(values, tolerance):
"""
Finds rounding errors in values that will be cast to int afterwards.
Returns True for values that are within tolerance of the next integer.
Works for both scalars and numpy arrays.
"""
if tolerance is None or tolerance == 0:
return np.zeros_like(values, dtype=bool)
# same as '1 - (values % 1) <= tolerance' but faster
return 1 - tolerance <= values % 1


class BinnedSpikeTrain(object):
"""
Class which calculates a binned spike train and provides methods to
@@ -417,12 +405,8 @@ def get_n_bins():
n_bins = (self._t_stop - self._t_start) / self._bin_size
if isinstance(n_bins, pq.Quantity):
n_bins = n_bins.simplified.item()
if _detect_rounding_errors(n_bins, tolerance=tolerance):
warnings.warn('Correcting a rounding error in the calculation '
'of n_bins by increasing n_bins by 1. '
'You can set tolerance=None to disable this '
'behaviour.')
return int(n_bins)
n_bins = round_binning_errors(n_bins, tolerance=tolerance)
return n_bins

def check_n_bins_consistency():
if self.n_bins != get_n_bins():
@@ -825,17 +809,7 @@ def _create_sparse_matrix(self, spiketrains, tolerance):

# shift spikes that are very close
# to the right edge into the next bin
rounding_error_indices = _detect_rounding_errors(
bins, tolerance=tolerance)
num_rounding_corrections = rounding_error_indices.sum()
if num_rounding_corrections > 0:
warnings.warn('Correcting {} rounding errors by shifting '
'the affected spikes into the following bin. '
'You can set tolerance=None to disable this '
'behaviour.'.format(num_rounding_corrections))
bins[rounding_error_indices] += .5

bins = bins.astype(np.int32)
bins = round_binning_errors(bins, tolerance=tolerance)
valid_bins = bins[bins < self.n_bins]
n_discarded += len(bins) - len(valid_bins)
f, c = np.unique(valid_bins, return_counts=True)
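The `conversion.py` hunk above replaces the local `_detect_rounding_errors` helper with the shared `round_binning_errors` utility. A minimal, self-contained recreation of the detect-and-shift logic the helper implemented (the function name `detect_rounding_errors` and the sample `bins` values are illustrative, not part of Elephant's public API):

```python
import numpy as np

def detect_rounding_errors(values, tolerance):
    # Flags values lying within `tolerance` below the next integer,
    # which a plain int cast would truncate into the wrong bin.
    if tolerance is None or tolerance == 0:
        return np.zeros_like(values, dtype=bool)
    # same as '1 - (values % 1) <= tolerance' but faster
    return 1 - tolerance <= values % 1

# 2.9999999995 is bin 3 up to float error; truncation alone gives 2
bins = np.array([1.2, 2.9999999995, 4.0])
errors = detect_rounding_errors(bins, tolerance=1e-8)
# shift flagged values past the integer boundary before casting
corrected = np.where(errors, bins + 0.5, bins).astype(np.int32)
```

`round_binning_errors` centralizes this pattern so both `get_n_bins` and `_create_sparse_matrix` warn and correct consistently.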
175 changes: 175 additions & 0 deletions elephant/spike_train_synchrony.py
@@ -10,6 +10,7 @@
:toctree: toctree/spike_train_synchrony/

spike_contrast
Synchrotool


:copyright: Copyright 2015-2020 by the Elephant team, see `doc/authors.rst`.
@@ -18,17 +19,26 @@
from __future__ import division, print_function, unicode_literals

from collections import namedtuple
from copy import deepcopy

import neo
import numpy as np
import quantities as pq

from elephant.statistics import Complexity
from elephant.utils import is_time_quantity

SpikeContrastTrace = namedtuple("SpikeContrastTrace", (
"contrast", "active_spiketrains", "synchrony"))


__all__ = [
"SpikeContrastTrace",
"spike_contrast",
"Synchrotool"
]


def _get_theta_and_n_per_bin(spiketrains, t_start, t_stop, bin_size):
"""
Calculates theta (amount of spikes per bin) and the amount of active spike
@@ -218,3 +228,168 @@ def spike_contrast(spiketrains, t_start=None, t_stop=None,
return synchrony, spike_contrast_trace

return synchrony


class Synchrotool(Complexity):
"""
Tool class to find, remove and/or annotate the presence of synchronous
spiking events across multiple spike trains.

The complexity is used to characterize synchronous events within the same
spike train and across different spike trains in the `spiketrains` list.
This way synchronous events can be found both in multi-unit and
single-unit spike trains.

This class inherits from :class:`elephant.statistics.Complexity`, see its
documentation for more details and input parameters description.

See also
--------
elephant.statistics.Complexity

"""

def __init__(self, spiketrains,
sampling_rate,
bin_size=None,
binary=True,
spread=0,
tolerance=1e-8):

self.annotated = False

super(Synchrotool, self).__init__(spiketrains=spiketrains,
bin_size=bin_size,
sampling_rate=sampling_rate,
binary=binary,
spread=spread,
tolerance=tolerance)

def delete_synchrofacts(self, threshold, in_place=False, mode='delete'):
"""
Delete or extract synchronous spiking events.

Parameters
----------
threshold : int
Threshold value for the deletion of spikes engaged in synchronous
activity.
* `threshold >= 2` leads to all spikes with a complexity value
greater than or equal to the threshold being deleted/extracted.
* `threshold <= 1` leads to a ValueError, since this
would delete/extract all spikes and there are definitely more
efficient ways of doing so.
in_place : bool, optional
Determines whether the modifications are made in place
on ``self.input_spiketrains``.
Default: False
mode : {'delete', 'extract'}, optional
Inversion of the mask for deletion of synchronous events.
* ``'delete'`` leads to the deletion of all spikes with
complexity >= `threshold`,
i.e. deletes synchronous spikes.
* ``'extract'`` leads to the deletion of all spikes with
complexity < `threshold`, i.e. extracts synchronous spikes.
Default: 'delete'

Raises
------
ValueError
If `mode` is not one of {'delete', 'extract'}.

If `threshold <= 1`.

Returns
-------
list of neo.SpikeTrain
List of spiketrains where the spikes with
``complexity >= threshold`` have been deleted/extracted.
* If ``in_place`` is True, the returned list is the same as
``self.input_spiketrains``.
* If ``in_place`` is False, the returned list is a deepcopy of
``self.input_spiketrains``.

"""

if not self.annotated:
self.annotate_synchrofacts()

if mode not in ['delete', 'extract']:
raise ValueError(f"Invalid mode '{mode}'. Valid modes are: "
f"'delete', 'extract'")

if threshold <= 1:
raise ValueError('A deletion threshold <= 1 would result '
'in the deletion of all spikes.')

if in_place:
spiketrain_list = self.input_spiketrains
else:
spiketrain_list = deepcopy(self.input_spiketrains)

for idx, st in enumerate(spiketrain_list):
mask = st.array_annotations['complexity'] < threshold
if mode == 'extract':
mask = np.invert(mask)
new_st = st[mask]
spiketrain_list[idx] = new_st
if in_place:
segment = st.segment
if segment is None:
continue

# replace link to spiketrain in segment
new_index = self._get_spiketrain_index(
segment.spiketrains, st)
segment.spiketrains[new_index] = new_st

block = segment.block
if block is None:
continue

# replace link to spiketrain in groups
for group in block.groups:
try:
idx = self._get_spiketrain_index(
group.spiketrains,
st)
except ValueError:
# st is not in this group, move to next group
continue

# st found in group, replace with new_st
group.spiketrains[idx] = new_st

return spiketrain_list

def annotate_synchrofacts(self):
"""
Annotate each spike *in place* with the complexity of the
synchronous event it takes part in, as stored in
``self.epoch.array_annotations['complexity']``.
"""
epoch_complexities = self.epoch.array_annotations['complexity']
right_edges = (
self.epoch.times.magnitude.flatten()
+ self.epoch.durations.rescale(
self.epoch.times.units).magnitude.flatten()
)

for idx, st in enumerate(self.input_spiketrains):

# all indices of spikes that are within the half-open intervals
# defined by the boundaries
# note that every second entry in boundaries is an upper boundary
spike_to_epoch_idx = np.searchsorted(
right_edges,
st.times.rescale(self.epoch.times.units).magnitude.flatten())
complexity_per_spike = epoch_complexities[spike_to_epoch_idx]

st.array_annotate(complexity=complexity_per_spike)

self.annotated = True
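The `np.searchsorted` call above maps every spike to the epoch interval it falls into: because the right edges are sorted, the insertion index of a spike time among the edges is the index of the first edge at or beyond it. A minimal sketch with hypothetical edge and complexity values:

```python
import numpy as np

# Hypothetical epoch layout: right edges of the complexity intervals
# (in seconds) and the complexity value attached to each interval.
right_edges = np.array([0.1, 0.2, 0.3, 0.4])
epoch_complexities = np.array([1, 3, 1, 2])

# Each spike belongs to the half-open interval ending at the first
# right edge >= its time; searchsorted finds that edge's index.
spike_times = np.array([0.05, 0.15, 0.35])
spike_to_epoch_idx = np.searchsorted(right_edges, spike_times)
complexity_per_spike = epoch_complexities[spike_to_epoch_idx]
```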

def _get_spiketrain_index(self, spiketrain_list, spiketrain):
for index, item in enumerate(spiketrain_list):
if item is spiketrain:
return index
raise ValueError("Spiketrain is not found in the list")
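The masking logic in `delete_synchrofacts` reduces to a boolean filter over the per-spike complexity annotations, inverted for `mode='extract'`. A sketch with hypothetical complexity values (no Elephant import needed):

```python
import numpy as np

# Hypothetical per-spike complexities, as annotate_synchrofacts would
# store them in st.array_annotations['complexity'].
complexity = np.array([1, 2, 1, 4, 2])
threshold = 2

# mode='delete': keep only spikes below the threshold
keep_mask = complexity < threshold
# mode='extract': invert the mask, keeping only synchronous spikes
extract_mask = np.invert(keep_mask)
```

In the real class the surviving spikes (`st[mask]`) replace the original train, and with `in_place=True` the links inside the parent `neo.Segment` and `neo.Group` containers are rewired to the new train as well.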