From 4ab6ff85844b7e4eb2fccb0217c1fdcf1b3c0f12 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Mon, 11 May 2020 11:08:53 +0200 Subject: [PATCH 01/58] Add first version of synchrofact detection --- elephant/spike_train_processing.py | 261 ++++++++++++++ elephant/test/test_spike_train_processing.py | 346 +++++++++++++++++++ 2 files changed, 607 insertions(+) create mode 100644 elephant/spike_train_processing.py create mode 100644 elephant/test/test_spike_train_processing.py diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py new file mode 100644 index 000000000..8ca9dd432 --- /dev/null +++ b/elephant/spike_train_processing.py @@ -0,0 +1,261 @@ +from __future__ import division + +import neo +import elephant.conversion as conv +import quantities as pq +import numpy as np +import warnings + + +def get_index(lst, obj): + for index, item in enumerate(lst): + if item is obj: + return index + return None + + +def detect_synchrofacts(block, sampling_rate, segment='all', n=2, spread=2, + invert=False, delete=False, unit_type='all'): + """ + Given block with spike trains, find all spikes engaged + in synchronous events of size *n* or higher. Two events are considered + synchronous if they occur within spread/sampling_rate of one another. + + *Args* + ------ + block [list]: + a block containing neo spike trains + + segment [int or iterable or str. Default: 1]: + indices of segments in the block. Can be an integer, an iterable object + or a string containing 'all'. Indicates on which segments of the block + the synchrofact removal should be performed. + + n [int. Default: 2]: + minimum number of coincident spikes to report synchrony + + spread [int. Default: 2]: + number of bins of size 1/sampling_rate in which to check for + synchronous spikes. *n* spikes within *spread* consecutive bins are + considered synchronous. + + sampling_rate [quantity. Default: 30000/s]: + Sampling rate of the spike trains. The spike trains are binned with + binsize dt = 1/sampling_rate and *n* spikes within *spread* consecutive + bins are considered synchronous. + Groups of *n* or more synchronous spikes are deleted/annotated. + + invert [bool. Default: True]: + invert the mask for annotation/deletion (Default:False). + False annotates synchrofacts with False and other spikes with True or + deletes everything except for synchrofacts for delete = True. + + delete [bool. Default: False]: + delete spikes engaged in synchronous activity. If set to False the + spiketrains are array-annotated and the spike times are kept unchanged. + + unit_type [list of strings. Default 'all']: + selects only spiketrain of certain units / channels for synchrofact + extraction. 
unit_type = 'all' considers all provided spiketrains + Accepted unit types: 'sua', 'mua', 'idX' + (where X is the id number requested) + """ + # TODO: refactor docs, correct description of spread parameter + + if isinstance(segment, str): + if 'all' in segment.lower(): + segment = range(len(block.segments)) + else: + raise ValueError('Input parameter segment not understood.') + + elif isinstance(segment, int): + segment = [segment] + + # make sure all quantities have units s + binsize = (1 / sampling_rate).rescale(pq.s) + + for seg in segment: + # data check + if len(block.segments[seg].spiketrains) == 0: + warnings.warn( + 'Segment {0} does not contain any spiketrains!'.format(seg)) + continue + + selected_sts, index = [], [] + + # considering all spiketrains for unit_type == 'all' + if isinstance(unit_type, str): + if 'all' in unit_type.lower(): + selected_sts = block.segments[seg].spiketrains + index = range(len(block.segments[seg].spiketrains)) + + else: + # extracting spiketrains which should be used for synchrofact + # extraction based on given unit type + # possible improvement by using masks for different conditions + # and adding them up + for i, st in enumerate(block.segments[seg].spiketrains): + take_it = False + for utype in unit_type: + if (utype[:2] == 'id' and + st.annotations['unit_id'] == int( + utype[2:])): + take_it = True + elif ((utype == 'sua' or utype == 'mua') + and utype in st.annotations + and st.annotations[utype]): + take_it = True + if take_it: + selected_sts.append(st) + index.append(i) + + # if no spiketrains were selected + if len(selected_sts) == 0: + warnings.warn( + 'No matching spike trains for given unit selection' + 'criteria %s found' % unit_type) + # we can skip to the next segment immediately since there are no + # spiketrains to perform synchrofact detection on + continue + else: + # find times of synchrony of size >=n + bst = conv.BinnedSpikeTrain(selected_sts, + binsize=binsize) + # TODO: adapt everything below, find_complexity_intervals should + # return a neo.Epoch instead + # TODO: we can probably clean up all implicit units once we use + # neo.Epoch for intervals + # TODO: use conversion._detect_rounding_errors to ensure that + # there are no rounding errors + complexity_intervals = find_complexity_intervals(bst, + min_complexity=n, + spread=spread) + # get a sorted flattened array of the interval edges + boundaries = complexity_intervals[1:].flatten(order='F') + + # j = index of pre-selected sts in selected_sts + # idx = index of pre-selected sts in original + # block.segments[seg].spiketrains + for j, idx in enumerate(index): + + # all indices of spikes that are within the half-open intervals + # defined by the boundaries + # note that every second entry in boundaries is an upper boundary + mask = np.array( + np.searchsorted(boundaries, + selected_sts[j].times.rescale(pq.s).magnitude, + side='right') % 2, + dtype=np.bool) + if invert: + mask = np.invert(mask) + + if delete: + old_st = selected_sts[j] + new_st = old_st[np.logical_not(mask)] + block.segments[seg].spiketrains[idx] = new_st + unit = old_st.unit + if unit is not None: + unit.spiketrains[get_index(unit.spiketrains, + old_st)] = new_st + del old_st + else: + block.segments[seg].spiketrains[idx].array_annotate( + synchrofacts=mask) + + +def find_complexity_intervals(bst, min_complexity=2, spread=1): + """ + Calculate the complexity (i.e. number of synchronous spikes) + for each bin. + + For `spread = 1` this corresponds to a simple bincount. 
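The effect of the `spread` parameter can be checked directly on a bincount array; a small plain-numpy sketch (illustrative only, not part of this patch), using the bin counts from the docstring example further below:

>>> import numpy as np
>>> bincount = np.array([0, 2, 0, 0, 0, 0, 1, 1, 0, 0])
>>> # spread = 1: only bins that already hold >= 2 spikes count as synchronous
>>> bincount[bincount >= 2]
array([2])
>>> # spread = 2: the single spikes in bins 6 and 7 sit in consecutive bins,
>>> # so they merge into one event of complexity 2
>>> int(bincount[6:8].sum())
2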
+ + For `spread > 1` jittered synchrony is included, then spikes within + `spread` consecutive bins are considered to be synchronous. + Every bin of such a jittered synchronous event is assigned the + complexity of the whole event, see example below. + + Parameters + ---------- + min_complexity : int, optional + Minimum complexity to report + Default: 2. + spread : int, optional + Number of bins in which to check for synchronous spikes. + Spikes within `spread` consecutive bins are considered synchronous. + Default: 2. + + Returns + ------- + complexity_intervals : np.ndarray + An array containing complexity values, left and right edges of all + intervals with at least `min_complexity` spikes separated by fewer + than `spread - 1` empty bins. + Output shape (3, num_complexity_intervals) + + Raises + ------ + ValueError + When `t_stop` is smaller than `t_start`. + + Examples + -------- + >>> import elephant.conversion as conv + >>> import neo + >>> import quantities as pq + >>> st1 = neo.SpikeTrain([1, 6] * pq.ms, + ... t_stop=10.0 * pq.ms) + >>> st2 = neo.SpikeTrain([1, 7] * pq.ms, + ... t_stop=10.0 * pq.ms) + >>> bst = conv.BinnedSpikeTrain([st1, st2], num_bins=10, + ... binsize=1 * pq.ms, + ... t_start=0 * pq.ms) + >>> print(bst.complexity().magnitude.flatten()) + [0. 2. 0. 0. 0. 0. 1. 1. 0. 0.] + >>> print(bst.complexity(spread=2).magnitude.flatten()) + [0. 2. 0. 0. 0. 0. 2. 2. 0. 0.] + """ + # TODO: documentation, example + bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() + + if spread == 1: + bin_indices = np.where(bincount >= min_complexity)[0] + complexities = bincount[bin_indices] + left_edges = bst.bin_edges[bin_indices].rescale(pq.s).magnitude + right_edges = bst.bin_edges[bin_indices + 1].rescale(pq.s).magnitude + else: + i = 0 + complexities = [] + left_edges = [] + right_edges = [] + while i < len(bincount): + current_bincount = bincount[i] + if current_bincount == 0: + i += 1 + else: + last_window_sum = current_bincount + last_nonzero_index = 0 + current_window = bincount[i:i+spread] + window_sum = current_window.sum() + while window_sum > last_window_sum: + last_nonzero_index = np.nonzero(current_window)[0][-1] + current_window = bincount[i: + i + last_nonzero_index + + spread] + last_window_sum = window_sum + window_sum = current_window.sum() + if window_sum >= min_complexity: + complexities.append(window_sum) + left_edges.append( + bst.bin_edges[i].rescale(pq.s).magnitude.item()) + right_edges.append( + bst.bin_edges[ + i + last_nonzero_index + 1 + ].rescale(pq.s).magnitude.item()) + i += last_nonzero_index + 1 + + # TODO: return a neo.Epoch instead + complexity_intervals = np.vstack((complexities, left_edges, right_edges)) + + return complexity_intervals + diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py new file mode 100644 index 000000000..0874e92c5 --- /dev/null +++ b/elephant/test/test_spike_train_processing.py @@ -0,0 +1,346 @@ +# -*- coding: utf-8 -*- +""" +Unit tests for the synchrofact detection app +""" + +import unittest + +import neo +import numpy as np +from numpy.testing import assert_array_almost_equal +from numpy.testing import assert_array_equal +import quantities as pq + +from elephant import spike_train_processing + + +def generate_block(spike_times, segment_edges=[0, 10, 20]*pq.s): + """ + Generate a block with segments with start and end times given by segment_edges + and with spike trains given by spike_times. 
+ """ + n_segments = len(segment_edges) - 1 + + # Create Block to contain all generated data + block = neo.Block() + + # Create multiple Segments + block.segments = [neo.Segment(index=i, + t_start=segment_edges[i], + t_stop=segment_edges[i+1]) + for i in range(n_segments)] + + # Create multiple ChannelIndexes + block.channel_indexes = [neo.ChannelIndex(name='C%d' % i, index=i) + for i in range(len(spike_times[0]))] + + # Attach multiple Units to each ChannelIndex + for i, channel_idx in enumerate(block.channel_indexes): + channel_idx.units = [neo.Unit('U1')] + for seg_idx, seg in enumerate(block.segments): + train = neo.SpikeTrain(spike_times[seg_idx][i], + t_start=segment_edges[seg_idx], + t_stop=segment_edges[seg_idx+1]) + seg.spiketrains.append(train) + channel_idx.units[0].spiketrains.append(train) + + block.create_many_to_one_relationship() + return block + + +class SynchrofactDetectionTestCase(unittest.TestCase): + + def test_no_synchrofacts(self): + + # nothing to find here + # there was an error for spread > 1 when nothing was found + # since boundaries is then set to [] and we later check boundaries.shape + # fixed by skipping the interval merge step when there are no intervals + + sampling_rate = 1 / pq.s + + spike_times = np.array([[[1, 9], [3, 7]], [[12, 19], [15, 17]]]) * pq.s + + block = generate_block(spike_times) + + # test annotation + spike_train_processing.detect_synchrofacts(block, segment='all', n=2, spread=2, + sampling_rate=sampling_rate, + invert=False, delete=False, + unit_type='all') + + correct_annotations = [[np.array([False, False]), np.array([False, False])], + [np.array([False, False]), np.array([False, False])]] + + annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] + for seg in block.segments] + + assert_array_equal(annotations, correct_annotations) + + # test deletion + spike_train_processing.detect_synchrofacts(block, segment='all', n=2, spread=2, + sampling_rate=sampling_rate, + invert=False, delete=True, + unit_type='all') + + correct_spike_times = np.array( + [[spikes[mask] for spikes, mask in zip(seg_spike_times, seg_mask)] + for seg_spike_times, seg_mask in zip(spike_times, + np.logical_not(correct_annotations) + ) + ]) + + cleaned_spike_times = np.array( + [[st.times for st in seg.spiketrains] for seg in block.segments]) + + for correct_seg, cleaned_seg in zip(correct_spike_times, cleaned_spike_times): + for correct_st, cleaned_st in zip(correct_seg, cleaned_seg): + assert_array_almost_equal(cleaned_st, correct_st) + + def test_spread_1(self): + + # basic test with a minimum number of two spikes per synchrofact + # only taking into account multiple spikes + # within one bin of size 1 / sampling_rate + + sampling_rate = 1 / pq.s + + spike_times = np.array([[[1, 5, 9], [1, 4, 8]], + [[11, 16, 19], [12, 16, 18]]]) * pq.s + + block = generate_block(spike_times) + + # test annotation + spike_train_processing.detect_synchrofacts(block, segment='all', n=2, + spread=1, + sampling_rate=sampling_rate, + invert=False, delete=False, + unit_type='all') + + correct_annotations = np.array([[[True, False, False], [True, False, False]], + [[False, True, False], [False, True, False]]]) + + annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] + for seg in block.segments] + + assert_array_equal(annotations, correct_annotations) + + # test deletion + spike_train_processing.detect_synchrofacts(block, segment='all', n=2, spread=1, + sampling_rate=sampling_rate, + invert=False, delete=True, + unit_type='all') + + 
correct_spike_times = np.array([[spikes[mask] + for spikes, mask in zip(seg_spike_times, + seg_mask)] + for seg_spike_times, seg_mask + in zip(spike_times, + np.logical_not(correct_annotations))]) + + cleaned_spike_times = np.array([[st.times for st in seg.spiketrains] + for seg in block.segments]) + + assert_array_almost_equal(cleaned_spike_times, correct_spike_times) + + def test_spread_2(self): + + # test synchrofact search taking into account adjacent bins + # this requires an additional loop with shifted binning + + sampling_rate = 1 / pq.s + + spike_times = np.array([[[1, 5, 9], [1, 4, 7]], + [[10, 12, 19], [11, 15, 17]]]) * pq.s + + block = generate_block(spike_times) + + # test annotation + spike_train_processing.detect_synchrofacts(block, segment='all', + n=2, spread=2, + sampling_rate=sampling_rate, + invert=False, delete=False, + unit_type='all') + + correct_annotations = [[np.array([True, True, False]), + np.array([True, True, False])], + [np.array([True, True, False]), + np.array([True, False, False])]] + + annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] + for seg in block.segments] + + assert_array_equal(annotations, correct_annotations) + + # test deletion + spike_train_processing.detect_synchrofacts(block, segment='all', n=2, spread=2, + sampling_rate=sampling_rate, + invert=False, delete=True, + unit_type='all') + + correct_spike_times = np.array([[spikes[mask] for spikes, mask in + zip(seg_spike_times, seg_mask)] + for seg_spike_times, seg_mask in + zip(spike_times, + np.logical_not(correct_annotations))]) + + cleaned_spike_times = np.array([[st.times for st in seg.spiketrains] + for seg in block.segments]) + + for correct_seg, cleaned_seg in zip(correct_spike_times, cleaned_spike_times): + for correct_st, cleaned_st in zip(correct_seg, cleaned_seg): + assert_array_almost_equal(cleaned_st, correct_st) + + def test_n_equals_3(self): + + # test synchrofact detection with a minimum number of + # three spikes per synchrofact + + sampling_rate = 1 / pq.s + + spike_times = np.array([[[1, 1, 5, 10], [1, 4, 7, 9]], + [[12, 15, 16, 18], [11, 13, 15, 19]]]) * pq.s + + block = generate_block(spike_times) + + # test annotation + spike_train_processing.detect_synchrofacts(block, segment='all', n=3, spread=2, + sampling_rate=sampling_rate, + invert=False, delete=False, + unit_type='all') + + correct_annotations = [[np.array([True, True, False, False]), + np.array([True, False, False, False])], + [np.array([True, True, True, False]), + np.array([True, True, True, False])]] + + annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] + for seg in block.segments] + + assert_array_equal(annotations, correct_annotations) + + # test deletion + spike_train_processing.detect_synchrofacts(block, segment='all', n=3, spread=2, + sampling_rate=sampling_rate, + invert=False, delete=True, + unit_type='all') + + correct_spike_times = np.array([[spikes[mask] for spikes, mask in + zip(seg_spike_times, seg_mask)] + for seg_spike_times, seg_mask in + zip(spike_times, + np.logical_not(correct_annotations))]) + + cleaned_spike_times = np.array([[st.times for st in seg.spiketrains] + for seg in block.segments]) + + for correct_seg, cleaned_seg in zip(correct_spike_times, cleaned_spike_times): + for correct_st, cleaned_st in zip(correct_seg, cleaned_seg): + assert_array_almost_equal(cleaned_st, correct_st) + + def test_binning_for_input_with_rounding_errors(self): + + # redo the test_n_equals_3 with inputs divided by 30000 + # which leads to rounding 
errors + # these errors have to be accounted for by proper binning; + # check if we still get the correct result + + sampling_rate = 30000. / pq.s + + spike_times = np.array([[[1, 1, 5, 10], [1, 4, 7, 9]], + [[12, 15, 16, 18], [11, 13, 15, 19]]]) / 30000. * pq.s + + block = generate_block(spike_times, + segment_edges=[0./30000., 10./30000., 20./30000.]*pq.s) + + # test annotation + spike_train_processing.detect_synchrofacts(block, segment='all', n=3, spread=2, + sampling_rate=sampling_rate, + invert=False, delete=False, + unit_type='all') + + correct_annotations = [[np.array([True, True, False, False]), + np.array([True, False, False, False])], + [np.array([True, True, True, False]), + np.array([True, True, True, False])]] + + annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] + for seg in block.segments] + + assert_array_equal(annotations, correct_annotations) + + # test deletion + spike_train_processing.detect_synchrofacts(block, segment='all', n=3, spread=2, + sampling_rate=sampling_rate, + invert=False, delete=True, + unit_type='all') + + correct_spike_times = np.array([[spikes[mask] for spikes, mask in + zip(seg_spike_times, seg_mask)] + for seg_spike_times, seg_mask in + zip(spike_times, + np.logical_not(correct_annotations))]) + + cleaned_spike_times = np.array([[st.times for st in seg.spiketrains] + for seg in block.segments]) + + for correct_seg, cleaned_seg in zip(correct_spike_times, cleaned_spike_times): + for correct_st, cleaned_st in zip(correct_seg, cleaned_seg): + assert_array_almost_equal(cleaned_st, correct_st) + + def test_correct_transfer_of_spiketrain_attributes(self): + + # for delete=True the spiketrains in the block are changed, + # test if their attributes remain correct + + sampling_rate = 1 / pq.s + + spike_times = np.array([[[1, 1, 5, 9]]]) * pq.s + + block = generate_block(spike_times, segment_edges=[0, 10]*pq.s) + + block.segments[0].spiketrains[0].annotate(cool_spike_train=True) + block.segments[0].spiketrains[0].array_annotate( + spike_number=np.arange(len( + block.segments[0].spiketrains[0].times.magnitude))) + block.segments[0].spiketrains[0].waveforms = np.sin( + np.arange(len( + block.segments[0].spiketrains[0].times.magnitude))[:, np.newaxis] + + np.arange(len( + block.segments[0].spiketrains[0].times.magnitude))[np.newaxis, :]) + + correct_mask = np.array([False, False, True, True]) + + # store the correct attributes + correct_annotations = block.segments[0].spiketrains[0].annotations.copy() + correct_waveforms = block.segments[0].spiketrains[0].waveforms[ + correct_mask].copy() + correct_array_annotations = { + key: value[correct_mask] for key, value in + block.segments[0].spiketrains[0].array_annotations.items()} + + # perform a synchrofact search with delete=True + spike_train_processing.detect_synchrofacts(block, segment='all', + n=2, spread=1, + sampling_rate=sampling_rate, + invert=False, delete=True, + unit_type='all') + + # Ensure that the spiketrain was not duplicated + self.assertEqual(len(block.filter(objects=neo.SpikeTrain)), 1) + + cleaned_annotations = block.segments[0].spiketrains[0].annotations + cleaned_waveforms = block.segments[0].spiketrains[0].waveforms + cleaned_array_annotations = block.segments[0].spiketrains[0].array_annotations + + self.assertDictEqual(correct_annotations, cleaned_annotations) + assert_array_almost_equal(cleaned_waveforms, correct_waveforms) + self.assertTrue(len(cleaned_array_annotations) + == len(correct_array_annotations)) + for key, value in correct_array_annotations.items(): + 
self.assertTrue(key in cleaned_array_annotations.keys()) + assert_array_almost_equal(value, cleaned_array_annotations[key]) + + +if __name__ == '__main__': + unittest.main() From 7b85e3d6361bf0c032e90483605d98094ab6bd77 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Fri, 15 May 2020 22:08:58 +0200 Subject: [PATCH 02/58] Implement TODOs, adapt tests accordingly --- elephant/spike_train_processing.py | 226 ++++++------ elephant/test/test_spike_train_processing.py | 357 ++++++------------- 2 files changed, 220 insertions(+), 363 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 8ca9dd432..48791dec3 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -14,8 +14,19 @@ def get_index(lst, obj): return None -def detect_synchrofacts(block, sampling_rate, segment='all', n=2, spread=2, - invert=False, delete=False, unit_type='all'): +def _check_spiketrains(spiketrains): + if len(spiketrains) == 0: + raise ValueError('The spiketrains should not be empty!') + + # check that all list elements are spike trains + for spiketrain in spiketrains: + if not isinstance(spiketrain, neo.SpikeTrain): + raise TypeError('not all elements of spiketrains are' + 'neo.SpikeTrain objects') + + +def detect_synchrofacts(spiketrains, sampling_rate, spread=2, + invert=False, deletion_threshold=None): """ Given block with spike trains, find all spikes engaged in synchronous events of size *n* or higher. Two events are considered @@ -41,7 +52,7 @@ def detect_synchrofacts(block, sampling_rate, segment='all', n=2, spread=2, sampling_rate [quantity. Default: 30000/s]: Sampling rate of the spike trains. The spike trains are binned with - binsize dt = 1/sampling_rate and *n* spikes within *spread* consecutive + bin_size dt = 1/sampling_rate and *n* spikes within *spread* consecutive bins are considered synchronous. Groups of *n* or more synchronous spikes are deleted/annotated. 
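The bin_size = 1/sampling_rate binning described in this docstring can be reproduced on its own; a minimal sketch (illustrative, not part of the commit), reusing the `conv.BinnedSpikeTrain` call and sparse bincount that the module itself uses:

>>> import numpy as np
>>> import quantities as pq
>>> import neo
>>> import elephant.conversion as conv
>>> sampling_rate = 30000 / pq.s
>>> st1 = neo.SpikeTrain([1, 6] * pq.ms, t_stop=10.0 * pq.ms)
>>> st2 = neo.SpikeTrain([1, 7] * pq.ms, t_stop=10.0 * pq.ms)
>>> # one bin per sample: bin width = 1 / sampling_rate
>>> bst = conv.BinnedSpikeTrain([st1, st2],
...                             binsize=(1 / sampling_rate).rescale(pq.ms))
>>> # spikes per bin, summed across both trains; the coincident spikes at
>>> # 1 ms land in the same bin, so the maximum count is 2
>>> bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze()
>>> int(bincount.max())
2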
@@ -62,118 +73,64 @@ def detect_synchrofacts(block, sampling_rate, segment='all', n=2, spread=2, """ # TODO: refactor docs, correct description of spread parameter - if isinstance(segment, str): - if 'all' in segment.lower(): - segment = range(len(block.segments)) - else: - raise ValueError('Input parameter segment not understood.') - - elif isinstance(segment, int): - segment = [segment] - - # make sure all quantities have units s - binsize = (1 / sampling_rate).rescale(pq.s) - - for seg in segment: - # data check - if len(block.segments[seg].spiketrains) == 0: - warnings.warn( - 'Segment {0} does not contain any spiketrains!'.format(seg)) - continue - - selected_sts, index = [], [] - - # considering all spiketrains for unit_type == 'all' - if isinstance(unit_type, str): - if 'all' in unit_type.lower(): - selected_sts = block.segments[seg].spiketrains - index = range(len(block.segments[seg].spiketrains)) - - else: - # extracting spiketrains which should be used for synchrofact - # extraction based on given unit type - # possible improvement by using masks for different conditions - # and adding them up - for i, st in enumerate(block.segments[seg].spiketrains): - take_it = False - for utype in unit_type: - if (utype[:2] == 'id' and - st.annotations['unit_id'] == int( - utype[2:])): - take_it = True - elif ((utype == 'sua' or utype == 'mua') - and utype in st.annotations - and st.annotations[utype]): - take_it = True - if take_it: - selected_sts.append(st) - index.append(i) - - # if no spiketrains were selected - if len(selected_sts) == 0: - warnings.warn( - 'No matching spike trains for given unit selection' - 'criteria %s found' % unit_type) - # we can skip to the next segment immediately since there are no - # spiketrains to perform synchrofact detection on - continue - else: - # find times of synchrony of size >=n - bst = conv.BinnedSpikeTrain(selected_sts, - binsize=binsize) - # TODO: adapt everything below, find_complexity_intervals should - # return a neo.Epoch instead - # TODO: we can probably clean up all implicit units once we use - # neo.Epoch for intervals - # TODO: use conversion._detect_rounding_errors to ensure that - # there are no rounding errors - complexity_intervals = find_complexity_intervals(bst, - min_complexity=n, - spread=spread) - # get a sorted flattened array of the interval edges - boundaries = complexity_intervals[1:].flatten(order='F') - - # j = index of pre-selected sts in selected_sts - # idx = index of pre-selected sts in original - # block.segments[seg].spiketrains - for j, idx in enumerate(index): - - # all indices of spikes that are within the half-open intervals - # defined by the boundaries - # note that every second entry in boundaries is an upper boundary - mask = np.array( - np.searchsorted(boundaries, - selected_sts[j].times.rescale(pq.s).magnitude, - side='right') % 2, - dtype=np.bool) - if invert: - mask = np.invert(mask) + if deletion_threshold is not None and deletion_threshold <= 1: + raise ValueError('A deletion_threshold <= 1 would result' + 'in deletion of all spikes.') - if delete: - old_st = selected_sts[j] - new_st = old_st[np.logical_not(mask)] - block.segments[seg].spiketrains[idx] = new_st - unit = old_st.unit - if unit is not None: - unit.spiketrains[get_index(unit.spiketrains, - old_st)] = new_st - del old_st - else: - block.segments[seg].spiketrains[idx].array_annotate( - synchrofacts=mask) + _check_spiketrains(spiketrains) + + # find times of synchrony of size >=n + complexity_epoch = find_complexity_intervals(spiketrains, + 
sampling_rate, + spread=spread) + complexity = complexity_epoch.array_annotations['complexity'] + right_edges = complexity_epoch.times + complexity_epoch.durations + + # j = index of pre-selected sts in spiketrains + # idx = index of pre-selected sts in original + # block.segments[seg].spiketrains + for idx, st in enumerate(spiketrains): + + # all indices of spikes that are within the half-open intervals + # defined by the boundaries + # note that every second entry in boundaries is an upper boundary + spike_to_epoch_idx = np.searchsorted(right_edges, + st.times.rescale( + right_edges.units)) + complexity_per_spike = complexity[spike_to_epoch_idx] + st.array_annotate(complexity=complexity_per_spike) -def find_complexity_intervals(bst, min_complexity=2, spread=1): + if deletion_threshold is not None: + mask = complexity_per_spike < deletion_threshold + if invert: + mask = np.invert(mask) + old_st = st + new_st = old_st[mask] + spiketrains[idx] = new_st + unit = old_st.unit + segment = old_st.segment + if unit is not None: + unit.spiketrains[get_index(unit.spiketrains, + old_st)] = new_st + if segment is not None: + segment.spiketrains[get_index(segment.spiketrains, + old_st)] = new_st + del old_st + + return complexity_epoch + + +def find_complexity_intervals(spiketrains, sampling_rate, + bin_size=None, spread=1): """ Calculate the complexity (i.e. number of synchronous spikes) for each bin. For `spread = 1` this corresponds to a simple bincount. - For `spread > 1` jittered synchrony is included, then spikes within - `spread` consecutive bins are considered to be synchronous. - Every bin of such a jittered synchronous event is assigned the - complexity of the whole event, see example below. + For `spread > 1` spikes separated by fewer than `spread - 1` + empty bins are considered synchronous. Parameters ---------- @@ -208,21 +165,34 @@ def find_complexity_intervals(bst, min_complexity=2, spread=1): >>> st2 = neo.SpikeTrain([1, 7] * pq.ms, ... t_stop=10.0 * pq.ms) >>> bst = conv.BinnedSpikeTrain([st1, st2], num_bins=10, - ... binsize=1 * pq.ms, + ... bin_size=1 * pq.ms, ... t_start=0 * pq.ms) >>> print(bst.complexity().magnitude.flatten()) [0. 2. 0. 0. 0. 0. 1. 1. 0. 0.] >>> print(bst.complexity(spread=2).magnitude.flatten()) [0. 2. 0. 0. 0. 0. 2. 2. 0. 0.] 
""" + _check_spiketrains(spiketrains) + + if bin_size is None: + bin_size = 1 / sampling_rate + elif bin_size < 1 / sampling_rate: + raise ValueError('The bin size should be at least' + '1 / sampling_rate (which is the' + 'default).') + # TODO: documentation, example + min_t_start = min([st.t_start for st in spiketrains]) + + bst = conv.BinnedSpikeTrain(spiketrains, + binsize=bin_size) bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() if spread == 1: - bin_indices = np.where(bincount >= min_complexity)[0] + bin_indices = np.nonzero(bincount)[0] complexities = bincount[bin_indices] - left_edges = bst.bin_edges[bin_indices].rescale(pq.s).magnitude - right_edges = bst.bin_edges[bin_indices + 1].rescale(pq.s).magnitude + left_edges = bst.bin_edges[bin_indices] + right_edges = bst.bin_edges[bin_indices + 1] else: i = 0 complexities = [] @@ -244,18 +214,32 @@ def find_complexity_intervals(bst, min_complexity=2, spread=1): + spread] last_window_sum = window_sum window_sum = current_window.sum() - if window_sum >= min_complexity: - complexities.append(window_sum) - left_edges.append( - bst.bin_edges[i].rescale(pq.s).magnitude.item()) - right_edges.append( - bst.bin_edges[ - i + last_nonzero_index + 1 - ].rescale(pq.s).magnitude.item()) + complexities.append(window_sum) + left_edges.append( + bst.bin_edges[i].magnitude.item()) + right_edges.append( + bst.bin_edges[ + i + last_nonzero_index + 1 + ].magnitude.item()) i += last_nonzero_index + 1 - # TODO: return a neo.Epoch instead - complexity_intervals = np.vstack((complexities, left_edges, right_edges)) + # we dropped units above, neither concatenate nor append works with + # arrays of quantities + left_edges *= bst.bin_edges.units + right_edges *= bst.bin_edges.units + + # ensure that spikes are not on the bin edges + bin_shift = .5 / sampling_rate + left_edges -= bin_shift + right_edges -= bin_shift + + # ensure that epoch does not start before the minimum t_start + left_edges[0] = min(min_t_start, left_edges[0]) + + complexity_epoch = neo.Epoch(times=left_edges, + durations=right_edges - left_edges, + array_annotations={'complexity': + complexities}) - return complexity_intervals + return complexity_epoch diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index 0874e92c5..e912a1e33 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -14,88 +14,60 @@ from elephant import spike_train_processing -def generate_block(spike_times, segment_edges=[0, 10, 20]*pq.s): - """ - Generate a block with segments with start and end times given by segment_edges - and with spike trains given by spike_times. 
- """ - n_segments = len(segment_edges) - 1 - - # Create Block to contain all generated data - block = neo.Block() - - # Create multiple Segments - block.segments = [neo.Segment(index=i, - t_start=segment_edges[i], - t_stop=segment_edges[i+1]) - for i in range(n_segments)] - - # Create multiple ChannelIndexes - block.channel_indexes = [neo.ChannelIndex(name='C%d' % i, index=i) - for i in range(len(spike_times[0]))] - - # Attach multiple Units to each ChannelIndex - for i, channel_idx in enumerate(block.channel_indexes): - channel_idx.units = [neo.Unit('U1')] - for seg_idx, seg in enumerate(block.segments): - train = neo.SpikeTrain(spike_times[seg_idx][i], - t_start=segment_edges[seg_idx], - t_stop=segment_edges[seg_idx+1]) - seg.spiketrains.append(train) - channel_idx.units[0].spiketrains.append(train) - - block.create_many_to_one_relationship() - return block - - class SynchrofactDetectionTestCase(unittest.TestCase): - def test_no_synchrofacts(self): + def _test_template(self, spiketrains, correct_complexities, sampling_rate, + spread, deletion_threshold=2, invert=False): + # test annotation + spike_train_processing.detect_synchrofacts( + spiketrains, + spread=spread, + sampling_rate=sampling_rate, + invert=invert, + deletion_threshold=None) - # nothing to find here - # there was an error for spread > 1 when nothing was found - # since boundaries is then set to [] and we later check boundaries.shape - # fixed by skipping the interval merge step when there are no intervals + annotations = [st.array_annotations['complexity'] + for st in spiketrains] - sampling_rate = 1 / pq.s + assert_array_equal(annotations, correct_complexities) - spike_times = np.array([[[1, 9], [3, 7]], [[12, 19], [15, 17]]]) * pq.s + correct_spike_times = np.array( + [spikes[mask] for spikes, mask + in zip(spiketrains, correct_complexities < deletion_threshold) + ]) - block = generate_block(spike_times) + # test deletion + spike_train_processing.detect_synchrofacts( + spiketrains, + spread=spread, + sampling_rate=sampling_rate, + invert=invert, + deletion_threshold=deletion_threshold) - # test annotation - spike_train_processing.detect_synchrofacts(block, segment='all', n=2, spread=2, - sampling_rate=sampling_rate, - invert=False, delete=False, - unit_type='all') + cleaned_spike_times = np.array( + [st.times for st in spiketrains]) - correct_annotations = [[np.array([False, False]), np.array([False, False])], - [np.array([False, False]), np.array([False, False])]] + for correct_st, cleaned_st in zip(correct_spike_times, + cleaned_spike_times): + assert_array_almost_equal(cleaned_st, correct_st) - annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] - for seg in block.segments] + def test_no_synchrofacts(self): - assert_array_equal(annotations, correct_annotations) + # nothing to find here + # there used to be an error for spread > 1 when nothing was found - # test deletion - spike_train_processing.detect_synchrofacts(block, segment='all', n=2, spread=2, - sampling_rate=sampling_rate, - invert=False, delete=True, - unit_type='all') + sampling_rate = 1 / pq.s - correct_spike_times = np.array( - [[spikes[mask] for spikes, mask in zip(seg_spike_times, seg_mask)] - for seg_spike_times, seg_mask in zip(spike_times, - np.logical_not(correct_annotations) - ) - ]) + spiketrains = [neo.SpikeTrain([1, 9, 12, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([3, 7, 15, 17] * pq.s, + t_stop=20*pq.s)] - cleaned_spike_times = np.array( - [[st.times for st in seg.spiketrains] for seg in block.segments]) + 
correct_annotations = np.array([[1, 1, 1, 1], + [1, 1, 1, 1]]) - for correct_seg, cleaned_seg in zip(correct_spike_times, cleaned_spike_times): - for correct_st, cleaned_st in zip(correct_seg, cleaned_seg): - assert_array_almost_equal(cleaned_st, correct_st) + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=2, invert=False, deletion_threshold=2) def test_spread_1(self): @@ -105,43 +77,16 @@ def test_spread_1(self): sampling_rate = 1 / pq.s - spike_times = np.array([[[1, 5, 9], [1, 4, 8]], - [[11, 16, 19], [12, 16, 18]]]) * pq.s - - block = generate_block(spike_times) - - # test annotation - spike_train_processing.detect_synchrofacts(block, segment='all', n=2, - spread=1, - sampling_rate=sampling_rate, - invert=False, delete=False, - unit_type='all') - - correct_annotations = np.array([[[True, False, False], [True, False, False]], - [[False, True, False], [False, True, False]]]) - - annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] - for seg in block.segments] + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, + t_stop=20*pq.s)] - assert_array_equal(annotations, correct_annotations) - - # test deletion - spike_train_processing.detect_synchrofacts(block, segment='all', n=2, spread=1, - sampling_rate=sampling_rate, - invert=False, delete=True, - unit_type='all') + correct_annotations = np.array([[2, 1, 1, 1, 2, 1], + [2, 1, 1, 1, 2, 1]]) - correct_spike_times = np.array([[spikes[mask] - for spikes, mask in zip(seg_spike_times, - seg_mask)] - for seg_spike_times, seg_mask - in zip(spike_times, - np.logical_not(correct_annotations))]) - - cleaned_spike_times = np.array([[st.times for st in seg.spiketrains] - for seg in block.segments]) - - assert_array_almost_equal(cleaned_spike_times, correct_spike_times) + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, invert=False, deletion_threshold=2) def test_spread_2(self): @@ -150,46 +95,16 @@ def test_spread_2(self): sampling_rate = 1 / pq.s - spike_times = np.array([[[1, 5, 9], [1, 4, 7]], - [[10, 12, 19], [11, 15, 17]]]) * pq.s - - block = generate_block(spike_times) + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] - # test annotation - spike_train_processing.detect_synchrofacts(block, segment='all', - n=2, spread=2, - sampling_rate=sampling_rate, - invert=False, delete=False, - unit_type='all') - - correct_annotations = [[np.array([True, True, False]), - np.array([True, True, False])], - [np.array([True, True, False]), - np.array([True, False, False])]] - - annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] - for seg in block.segments] - - assert_array_equal(annotations, correct_annotations) - - # test deletion - spike_train_processing.detect_synchrofacts(block, segment='all', n=2, spread=2, - sampling_rate=sampling_rate, - invert=False, delete=True, - unit_type='all') - - correct_spike_times = np.array([[spikes[mask] for spikes, mask in - zip(seg_spike_times, seg_mask)] - for seg_spike_times, seg_mask in - zip(spike_times, - np.logical_not(correct_annotations))]) + correct_annotations = np.array([[2, 2, 1, 3, 3, 1], + [2, 2, 1, 3, 1, 1]]) - cleaned_spike_times = np.array([[st.times for st in seg.spiketrains] - for seg in block.segments]) - - for correct_seg, cleaned_seg in zip(correct_spike_times, cleaned_spike_times): - for correct_st, cleaned_st in 
zip(correct_seg, cleaned_seg): - assert_array_almost_equal(cleaned_st, correct_st) + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=2, invert=False, deletion_threshold=2) def test_n_equals_3(self): @@ -198,95 +113,41 @@ def test_n_equals_3(self): sampling_rate = 1 / pq.s - spike_times = np.array([[[1, 1, 5, 10], [1, 4, 7, 9]], - [[12, 15, 16, 18], [11, 13, 15, 19]]]) * pq.s - - block = generate_block(spike_times) - - # test annotation - spike_train_processing.detect_synchrofacts(block, segment='all', n=3, spread=2, - sampling_rate=sampling_rate, - invert=False, delete=False, - unit_type='all') - - correct_annotations = [[np.array([True, True, False, False]), - np.array([True, False, False, False])], - [np.array([True, True, True, False]), - np.array([True, True, True, False])]] - - annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] - for seg in block.segments] - - assert_array_equal(annotations, correct_annotations) + spiketrains = [neo.SpikeTrain([1, 1, 5, 10, 13, 16, 17, 19] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 9, 12, 14, 16, 20] * pq.s, + t_stop=21*pq.s)] - # test deletion - spike_train_processing.detect_synchrofacts(block, segment='all', n=3, spread=2, - sampling_rate=sampling_rate, - invert=False, delete=True, - unit_type='all') - - correct_spike_times = np.array([[spikes[mask] for spikes, mask in - zip(seg_spike_times, seg_mask)] - for seg_spike_times, seg_mask in - zip(spike_times, - np.logical_not(correct_annotations))]) + correct_annotations = np.array([[3, 3, 2, 2, 3, 3, 3, 2], + [3, 2, 1, 2, 3, 3, 3, 2]]) - cleaned_spike_times = np.array([[st.times for st in seg.spiketrains] - for seg in block.segments]) - - for correct_seg, cleaned_seg in zip(correct_spike_times, cleaned_spike_times): - for correct_st, cleaned_st in zip(correct_seg, cleaned_seg): - assert_array_almost_equal(cleaned_st, correct_st) + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=2, invert=False, deletion_threshold=3) def test_binning_for_input_with_rounding_errors(self): - # redo the test_n_equals_3 with inputs divided by 30000 - # which leads to rounding errors + # a test with inputs divided by 30000 which leads to rounding errors # these errors have to be accounted for by proper binning; # check if we still get the correct result - sampling_rate = 30000. / pq.s - - spike_times = np.array([[[1, 1, 5, 10], [1, 4, 7, 9]], - [[12, 15, 16, 18], [11, 13, 15, 19]]]) / 30000. 
* pq.s - - block = generate_block(spike_times, - segment_edges=[0./30000., 10./30000., 20./30000.]*pq.s) - - # test annotation - spike_train_processing.detect_synchrofacts(block, segment='all', n=3, spread=2, - sampling_rate=sampling_rate, - invert=False, delete=False, - unit_type='all') - - correct_annotations = [[np.array([True, True, False, False]), - np.array([True, False, False, False])], - [np.array([True, True, True, False]), - np.array([True, True, True, False])]] - - annotations = [[st.array_annotations['synchrofacts'] for st in seg.spiketrains] - for seg in block.segments] + sampling_rate = 30000 / pq.s - assert_array_equal(annotations, correct_annotations) + spiketrains = [neo.SpikeTrain(np.arange(1000) * pq.s / 30000, + t_stop=.1 * pq.s), + neo.SpikeTrain(np.arange(2000, step=2) * pq.s / 30000, + t_stop=.1 * pq.s)] - # test deletion - spike_train_processing.detect_synchrofacts(block, segment='all', n=3, spread=2, - sampling_rate=sampling_rate, - invert=False, delete=True, - unit_type='all') + first_annotations = np.ones(1000) + first_annotations[::2] = 2 - correct_spike_times = np.array([[spikes[mask] for spikes, mask in - zip(seg_spike_times, seg_mask)] - for seg_spike_times, seg_mask in - zip(spike_times, - np.logical_not(correct_annotations))]) + second_annotations = np.ones(1000) + second_annotations[:500] = 2 - cleaned_spike_times = np.array([[st.times for st in seg.spiketrains] - for seg in block.segments]) + correct_annotations = np.array([first_annotations, + second_annotations]) - for correct_seg, cleaned_seg in zip(correct_spike_times, cleaned_spike_times): - for correct_st, cleaned_st in zip(correct_seg, cleaned_seg): - assert_array_almost_equal(cleaned_st, correct_st) + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, invert=False, deletion_threshold=2) def test_correct_transfer_of_spiketrain_attributes(self): @@ -295,43 +156,55 @@ def test_correct_transfer_of_spiketrain_attributes(self): sampling_rate = 1 / pq.s - spike_times = np.array([[[1, 1, 5, 9]]]) * pq.s + spiketrain = neo.SpikeTrain([1, 1, 5, 0] * pq.s, + t_stop=10 * pq.s) + + block = neo.Block() - block = generate_block(spike_times, segment_edges=[0, 10]*pq.s) + channel_index = neo.ChannelIndex(name='Channel 1', index=1) + block.channel_indexes.append(channel_index) - block.segments[0].spiketrains[0].annotate(cool_spike_train=True) - block.segments[0].spiketrains[0].array_annotate( - spike_number=np.arange(len( - block.segments[0].spiketrains[0].times.magnitude))) - block.segments[0].spiketrains[0].waveforms = np.sin( - np.arange(len( - block.segments[0].spiketrains[0].times.magnitude))[:, np.newaxis] + - np.arange(len( - block.segments[0].spiketrains[0].times.magnitude))[np.newaxis, :]) + unit = neo.Unit('Unit 1') + channel_index.units.append(unit) + unit.spiketrains.append(spiketrain) + spiketrain.unit = unit + + segment = neo.Segment() + block.segments.append(segment) + segment.spiketrains.append(spiketrain) + spiketrain.segment = segment + + spiketrain.annotate(cool_spike_train=True) + spiketrain.array_annotate( + spike_number=np.arange(len(spiketrain.times.magnitude))) + spiketrain.waveforms = np.sin( + np.arange(len(spiketrain.times.magnitude))[:, np.newaxis] + + np.arange(len(spiketrain.times.magnitude))[np.newaxis, :]) correct_mask = np.array([False, False, True, True]) # store the correct attributes - correct_annotations = block.segments[0].spiketrains[0].annotations.copy() - correct_waveforms = block.segments[0].spiketrains[0].waveforms[ - correct_mask].copy() - 
correct_array_annotations = { - key: value[correct_mask] for key, value in - block.segments[0].spiketrains[0].array_annotations.items()} + correct_annotations = spiketrain.annotations.copy() + correct_waveforms = spiketrain.waveforms[correct_mask].copy() + correct_array_annotations = {key: value[correct_mask] for key, value in + spiketrain.array_annotations.items()} # perform a synchrofact search with delete=True - spike_train_processing.detect_synchrofacts(block, segment='all', - n=2, spread=1, - sampling_rate=sampling_rate, - invert=False, delete=True, - unit_type='all') + spike_train_processing.detect_synchrofacts([spiketrain], + spread=1, + sampling_rate=sampling_rate, + invert=False, + deletion_threshold=2) # Ensure that the spiketrain was not duplicated self.assertEqual(len(block.filter(objects=neo.SpikeTrain)), 1) - cleaned_annotations = block.segments[0].spiketrains[0].annotations - cleaned_waveforms = block.segments[0].spiketrains[0].waveforms - cleaned_array_annotations = block.segments[0].spiketrains[0].array_annotations + cleaned_spiketrain = segment.spiketrains[0] + + cleaned_annotations = cleaned_spiketrain.annotations + cleaned_waveforms = cleaned_spiketrain.waveforms + cleaned_array_annotations = cleaned_spiketrain.array_annotations + cleaned_array_annotations.pop('complexity') self.assertDictEqual(correct_annotations, cleaned_annotations) assert_array_almost_equal(cleaned_waveforms, correct_waveforms) From da26aceaca1d64f2b0dda8d7e2952ff00f8b3500 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Fri, 15 May 2020 22:28:15 +0200 Subject: [PATCH 03/58] Add tests for raised errors, cleanup --- elephant/spike_train_processing.py | 5 ++-- elephant/test/test_spike_train_processing.py | 28 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 48791dec3..46da79b01 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -11,7 +11,6 @@ def get_index(lst, obj): for index, item in enumerate(lst): if item is obj: return index - return None def _check_spiketrains(spiketrains): @@ -52,8 +51,8 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=2, sampling_rate [quantity. Default: 30000/s]: Sampling rate of the spike trains. The spike trains are binned with - bin_size dt = 1/sampling_rate and *n* spikes within *spread* consecutive - bins are considered synchronous. + bin_size dt = 1/sampling_rate and *n* spikes within *spread* + consecutive bins are considered synchronous. Groups of *n* or more synchronous spikes are deleted/annotated. invert [bool. 
Default: True]: diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index e912a1e33..c9d9b565e 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -214,6 +214,34 @@ def test_correct_transfer_of_spiketrain_attributes(self): self.assertTrue(key in cleaned_array_annotations.keys()) assert_array_almost_equal(value, cleaned_array_annotations[key]) + def test_wrong_input_errors(self): + self.assertRaises(ValueError, + spike_train_processing.detect_synchrofacts, + [], 1 / pq.s) + self.assertRaises(TypeError, + spike_train_processing.detect_synchrofacts, + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), + np.arange(2)], + 1 / pq.s) + self.assertRaises(ValueError, + spike_train_processing.detect_synchrofacts, + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], + 1 / pq.s, + deletion_threshold=-1) + self.assertRaises(ValueError, + spike_train_processing.find_complexity_intervals, + [], 1 / pq.s) + self.assertRaises(TypeError, + spike_train_processing.find_complexity_intervals, + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), + np.arange(2)], + 1 / pq.s) + self.assertRaises(ValueError, + spike_train_processing.find_complexity_intervals, + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], + sampling_rate=1 / pq.s, + bin_size=.5 * pq.s) + if __name__ == '__main__': unittest.main() From 68b2ae5d51ba12067424904a20a82702d0224f4c Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Fri, 15 May 2020 22:38:48 +0200 Subject: [PATCH 04/58] Change spread to spread - 1 --- elephant/spike_train_processing.py | 30 +++++++++----------- elephant/test/test_spike_train_processing.py | 18 ++++++------ 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 46da79b01..93194fdb0 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -24,7 +24,7 @@ def _check_spiketrains(spiketrains): 'neo.SpikeTrain objects') -def detect_synchrofacts(spiketrains, sampling_rate, spread=2, +def detect_synchrofacts(spiketrains, sampling_rate, spread=1, invert=False, deletion_threshold=None): """ Given block with spike trains, find all spikes engaged @@ -44,10 +44,9 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=2, n [int. Default: 2]: minimum number of coincident spikes to report synchrony - spread [int. Default: 2]: - number of bins of size 1/sampling_rate in which to check for - synchronous spikes. *n* spikes within *spread* consecutive bins are - considered synchronous. + spread [int. Default: 1]: + the number of bins to look ahead of each spike for more spikes to add + to the same synchronous event sampling_rate [quantity. Default: 30000/s]: Sampling rate of the spike trains. The spike trains are binned with @@ -70,7 +69,7 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=2, Accepted unit types: 'sua', 'mua', 'idX' (where X is the id number requested) """ - # TODO: refactor docs, correct description of spread parameter + # TODO: refactor docs if deletion_threshold is not None and deletion_threshold <= 1: raise ValueError('A deletion_threshold <= 1 would result' @@ -121,15 +120,15 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=2, def find_complexity_intervals(spiketrains, sampling_rate, - bin_size=None, spread=1): + bin_size=None, spread=0): """ Calculate the complexity (i.e. number of synchronous spikes) for each bin. - For `spread = 1` this corresponds to a simple bincount. 
+ For `spread = 0` this corresponds to a simple bincount. - For `spread > 1` spikes separated by fewer than `spread - 1` - empty bins are considered synchronous. + For `spread > 0` spikes within `spread` bins of one another are considered + synchronous. Parameters ---------- @@ -139,15 +138,14 @@ def find_complexity_intervals(spiketrains, sampling_rate, spread : int, optional Number of bins in which to check for synchronous spikes. Spikes within `spread` consecutive bins are considered synchronous. - Default: 2. + Default: 0. Returns ------- complexity_intervals : np.ndarray An array containing complexity values, left and right edges of all intervals with at least `min_complexity` spikes separated by fewer - than `spread - 1` empty bins. - Output shape (3, num_complexity_intervals) + than `spread` empty bins. Raises ------ @@ -187,7 +185,7 @@ def find_complexity_intervals(spiketrains, sampling_rate, binsize=bin_size) bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() - if spread == 1: + if spread == 0: bin_indices = np.nonzero(bincount)[0] complexities = bincount[bin_indices] left_edges = bst.bin_edges[bin_indices] @@ -204,13 +202,13 @@ def find_complexity_intervals(spiketrains, sampling_rate, else: last_window_sum = current_bincount last_nonzero_index = 0 - current_window = bincount[i:i+spread] + current_window = bincount[i:i + spread + 1] window_sum = current_window.sum() while window_sum > last_window_sum: last_nonzero_index = np.nonzero(current_window)[0][-1] current_window = bincount[i: i + last_nonzero_index - + spread] + + spread + 1] last_window_sum = window_sum window_sum = current_window.sum() complexities.append(window_sum) diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index c9d9b565e..df0d12de4 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -54,7 +54,7 @@ def _test_template(self, spiketrains, correct_complexities, sampling_rate, def test_no_synchrofacts(self): # nothing to find here - # there used to be an error for spread > 1 when nothing was found + # there used to be an error for spread > 0 when nothing was found sampling_rate = 1 / pq.s @@ -67,9 +67,9 @@ def test_no_synchrofacts(self): [1, 1, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=2, invert=False, deletion_threshold=2) + spread=1, invert=False, deletion_threshold=2) - def test_spread_1(self): + def test_spread_0(self): # basic test with a minimum number of two spikes per synchrofact # only taking into account multiple spikes @@ -86,9 +86,9 @@ def test_spread_1(self): [2, 1, 1, 1, 2, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert=False, deletion_threshold=2) + spread=0, invert=False, deletion_threshold=2) - def test_spread_2(self): + def test_spread_1(self): # test synchrofact search taking into account adjacent bins # this requires an additional loop with shifted binning @@ -104,7 +104,7 @@ def test_spread_2(self): [2, 2, 1, 3, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=2, invert=False, deletion_threshold=2) + spread=1, invert=False, deletion_threshold=2) def test_n_equals_3(self): @@ -122,7 +122,7 @@ def test_n_equals_3(self): [3, 2, 1, 2, 3, 3, 3, 2]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=2, invert=False, deletion_threshold=3) + spread=1, invert=False, deletion_threshold=3) def 
test_binning_for_input_with_rounding_errors(self): @@ -147,7 +147,7 @@ def test_binning_for_input_with_rounding_errors(self): second_annotations]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert=False, deletion_threshold=2) + spread=0, invert=False, deletion_threshold=2) def test_correct_transfer_of_spiketrain_attributes(self): @@ -191,7 +191,7 @@ def test_correct_transfer_of_spiketrain_attributes(self): # perform a synchrofact search with delete=True spike_train_processing.detect_synchrofacts([spiketrain], - spread=1, + spread=0, sampling_rate=sampling_rate, invert=False, deletion_threshold=2) From 3b822d83e9b21679693269be02cfad3f9bff59ee Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 13:26:19 +0200 Subject: [PATCH 05/58] remove unused imports --- elephant/spike_train_processing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 93194fdb0..af87fe19f 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -2,9 +2,7 @@ import neo import elephant.conversion as conv -import quantities as pq import numpy as np -import warnings def get_index(lst, obj): From c0891f457d3c5aafd6f66314ec2f810a76bb5c37 Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 13:27:44 +0200 Subject: [PATCH 06/58] add check for list instance in input and rewrite a bit error messages --- elephant/spike_train_processing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index af87fe19f..eba008cbb 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -12,13 +12,16 @@ def get_index(lst, obj): def _check_spiketrains(spiketrains): + if not isinstance(spiketrains, list): + raise TypeError('spiketrains should be a list of neo.SpikeTrain') + if len(spiketrains) == 0: raise ValueError('The spiketrains should not be empty!') # check that all list elements are spike trains for spiketrain in spiketrains: if not isinstance(spiketrain, neo.SpikeTrain): - raise TypeError('not all elements of spiketrains are' + raise TypeError('not all elements in the spiketrains list are' 'neo.SpikeTrain objects') From f0c8fa34fc354d3c79769ce3653c5c00c7c1740b Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 13:29:08 +0200 Subject: [PATCH 07/58] refactor docstring for synchrofact detection function. Reorder kwargs to a more meaninful way and rename invert with invert_delete --- elephant/spike_train_processing.py | 106 +++++++++++++++++++---------- 1 file changed, 69 insertions(+), 37 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index eba008cbb..5f6ec1cb0 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -26,49 +26,81 @@ def _check_spiketrains(spiketrains): def detect_synchrofacts(spiketrains, sampling_rate, spread=1, - invert=False, deletion_threshold=None): + deletion_threshold=None, invert_delete=False): """ - Given block with spike trains, find all spikes engaged - in synchronous events of size *n* or higher. Two events are considered - synchronous if they occur within spread/sampling_rate of one another. + Given a list of neo.Spiketrain objects, calculate the number of synchronous + spikes found and optionally delete or extract them from the given list + *in-place*. 
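A usage sketch of the refactored interface, with spike times and expected complexities taken from the accompanying `test_spread_1` unit test (illustrative only):

>>> import quantities as pq
>>> import neo
>>> from elephant import spike_train_processing
>>> sts = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, t_stop=21 * pq.s),
...        neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, t_stop=21 * pq.s)]
>>> epoch = spike_train_processing.detect_synchrofacts(
...     sts, sampling_rate=1 / pq.s, spread=1, deletion_threshold=None)
>>> # per-spike complexities are attached as array annotations;
>>> # the unit test expects [2, 2, 1, 3, 3, 1] here for the first train
>>> complexities = sts[0].array_annotations['complexity']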
- *Args* - ------ - block [list]: - a block containing neo spike trains + The spike trains are binned at sampling precission + (i.e. bin_size = 1 / `sampling_rate`) + + Two spikes are considered synchronous if they occur separated by strictly + fewer than `spread - 1` empty bins from one another. See + `elephant.spike_train_processing.complexity_intervals` for a detailed + description of how synchronous events are counted. + + Synchronous events are considered within the same spike train and across + different spike trains in the `spiketrains` list. Such that, synchronous + events can be found both in multi-unit and single-unit spike trains. - segment [int or iterable or str. Default: 1]: - indices of segments in the block. Can be an integer, an iterable object - or a string containing 'all'. Indicates on which segments of the block - the synchrofact removal should be performed. + The spike trains in the `spiketrains` list are annotated with the + complexity value of each spike in their :attr:`array_annotations`. - n [int. Default: 2]: - minimum number of coincident spikes to report synchrony - spread [int. Default: 1]: - the number of bins to look ahead of each spike for more spikes to add - to the same synchronous event + Parameters + ---------- + spiketrains: list of neo.SpikeTrains + a list of neo.SpikeTrains objects. These spike trains should have been + recorded simultaneously. - sampling_rate [quantity. Default: 30000/s]: + sampling_rate: pq.Quantity Sampling rate of the spike trains. The spike trains are binned with - bin_size dt = 1/sampling_rate and *n* spikes within *spread* - consecutive bins are considered synchronous. - Groups of *n* or more synchronous spikes are deleted/annotated. - - invert [bool. Default: True]: - invert the mask for annotation/deletion (Default:False). - False annotates synchrofacts with False and other spikes with True or - deletes everything except for synchrofacts for delete = True. - - delete [bool. Default: False]: - delete spikes engaged in synchronous activity. If set to False the - spiketrains are array-annotated and the spike times are kept unchanged. - - unit_type [list of strings. Default 'all']: - selects only spiketrain of certain units / channels for synchrofact - extraction. unit_type = 'all' considers all provided spiketrains - Accepted unit types: 'sua', 'mua', 'idX' - (where X is the id number requested) + bin_size = 1 / `sampling_rate`. + + spread: int + Number of bins in which to check for synchronous spikes. + Spikes that occur separated by `spread - 1` or less empty bins are + considered synchronous. + Default: 1 + + deletion_threshold: int, optional + Threshold value for the deletion of spikes engaged in synchronous + activity. + `deletion_threshold = None` leads to no spikes being deleted, spike + trains are array-annotated and the spike times are kept unchanged. + `deletion_threshold >= 2` leads to all spikes with a larger or equal + complexity value to be deleted *in-place*. + `deletion_threshold` cannot be set to 1 (this would delete all spikes + and there are definitely more efficient ways of doing this) + `deletion_threshold <= 0` leads to a ValueError. + Default: None + + invert_delete: bool + Inversion of the mask for deletion of synchronous events. + `invert_delete = False` leads to the deletion of all spikes with + complexity >= `deletion_threshold`, i.e. deletes synchronous spikes. + `invert_delete = True` leads to the deletion of all spikes with + complexity < `deletion_threshold`, i.e. returns synchronous spikes. 
+ Default: False + + Returns + ------- + complexity_epoch: neo.Epoch + An epoch object containing complexity values, left edges and durations + of all intervals with at least one spike. + Calculated with `elephant.spike_train_processing.complexity_intervals`. + Complexity values per spike can be accessed with: + >>> complexity_epoch.array_annotations['complexity'] + The left edges of the intervals with: + >>> complexity_epoch.times + And the durations with: + >>> complexity_epoch.durations + + See also + -------- + elephant.spike_train_processing.complexity_intervals + """ # TODO: refactor docs @@ -102,7 +134,7 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, if deletion_threshold is not None: mask = complexity_per_spike < deletion_threshold - if invert: + if invert_delete: mask = np.invert(mask) old_st = st new_st = old_st[mask] From 771632cc5edf03ea9d732733298dd03c175cffbb Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 13:30:17 +0200 Subject: [PATCH 08/58] refactor docstring for complexity function, rename function to shorter name --- elephant/spike_train_processing.py | 99 ++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 31 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 5f6ec1cb0..919c54f06 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -111,9 +111,9 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, _check_spiketrains(spiketrains) # find times of synchrony of size >=n - complexity_epoch = find_complexity_intervals(spiketrains, - sampling_rate, - spread=spread) + complexity_epoch = complexity_intervals(spiketrains, + sampling_rate, + spread=spread) complexity = complexity_epoch.array_annotations['complexity'] right_edges = complexity_epoch.times + complexity_epoch.durations @@ -152,33 +152,47 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, return complexity_epoch -def find_complexity_intervals(spiketrains, sampling_rate, - bin_size=None, spread=0): +def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): """ - Calculate the complexity (i.e. number of synchronous spikes) - for each bin. + Calculate the complexity (i.e. number of synchronous spikes found) + for each bin interval in a list of spiketrains. - For `spread = 0` this corresponds to a simple bincount. + Complexity is calculated by counting the number of spikes (i.e. non-empty + bins) that occur separated by `spread - 1` or less empty bins, within and + across spike trains in the `spiketrains` list. - For `spread > 0` spikes within `spread` bins of one another are considered - synchronous. Parameters ---------- - min_complexity : int, optional - Minimum complexity to report - Default: 2. + spiketrains: list of neo.SpikeTrains + a list of neo.SpikeTrains objects. These spike trains should have been + recorded simultaneously. + + sampling_rate: pq.Quantity + Sampling rate of the spike trains. + + bin_size: pq.Quantity + Bin size for calculating the complexity values. If `bin_size = None` + the spike trains are binned with `bin_size = 1 / sampling_rate`. + Default: None + spread : int, optional Number of bins in which to check for synchronous spikes. - Spikes within `spread` consecutive bins are considered synchronous. - Default: 0. + Spikes that occur separated by `spread - 1` or less empty bins are + considered synchronous. + `spread = 0` corresponds to a bincount accross spike trains. 
+ `spread = 1` corresponds to counting consecutive spikes. + `spread = 2` corresponds to counting consecutive spikes and spikes + separated by exactly 1 empty bin. + `spread = n` corresponds to counting spikes separated by exactly or + less than `n - 1` empty bins. + Default: 0 Returns ------- - complexity_intervals : np.ndarray - An array containing complexity values, left and right edges of all - intervals with at least `min_complexity` spikes separated by fewer - than `spread` empty bins. + complexity_intervals : neo.Epoch + An epoch object containing complexity values, left edges and durations + of all intervals with at least one spike. Raises ------ @@ -187,20 +201,43 @@ def find_complexity_intervals(spiketrains, sampling_rate, Examples -------- - >>> import elephant.conversion as conv + Here the behavior of + `elephant.spike_train_processing.complexity_intervals` is shown, by + applying the function to some sample spiketrains. >>> import neo >>> import quantities as pq - >>> st1 = neo.SpikeTrain([1, 6] * pq.ms, - ... t_stop=10.0 * pq.ms) - >>> st2 = neo.SpikeTrain([1, 7] * pq.ms, - ... t_stop=10.0 * pq.ms) - >>> bst = conv.BinnedSpikeTrain([st1, st2], num_bins=10, - ... bin_size=1 * pq.ms, - ... t_start=0 * pq.ms) - >>> print(bst.complexity().magnitude.flatten()) - [0. 2. 0. 0. 0. 0. 1. 1. 0. 0.] - >>> print(bst.complexity(spread=2).magnitude.flatten()) - [0. 2. 0. 0. 0. 0. 2. 2. 0. 0.] + ... + >>> sampling_rate = 1/pq.ms + >>> st1 = neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10.0 * pq.ms) + >>> st2 = neo.SpikeTrain([1, 5, 8] * pq.ms, t_stop=10.0 * pq.ms) + ... + >>> # spread = 0, a simple bincount + >>> ep1 = complexity_intervals([st1, st2], sampling_rate) + >>> print(ep1.array_annotations['complexity'].flatten()) + [2, 1, 1, 1, 1] + >>> print(ep1.times) + [0. 3.5 4.5 5.5 7.5] ms + >>> print(ep1.durations) + [1.5, 1. , 1. , 1. , 1. ] ms + ... + >>> # spread = 1, consecutive spikes + >>> ep2 = complexity_intervals([st1, st2], sampling_rate, spread=1) + >>> print(ep2.array_annotations['complexity'].flatten()) + [2, 3, 1] + >>> print(ep2.times) + [0. 3.5 7.5] ms + >>> print(ep2.durations) + [1.5 3. 1. ] ms + ... + >>> # spread = 2, consecutive spikes and separated by 1 empty bin + >>> ep3 = complexity_intervals([st1, st2], sampling_rate, spread=2) + >>> print(ep3.array_annotations['complexity'].flatten()) + [2, 4] + >>> print(ep3.times) + [0. 3.5] ms + >>> print(ep3.durations) + [1.5 5. 
] ms + """ _check_spiketrains(spiketrains) From bd531b824a3914e77090a8aa9f88081e06bd8e98 Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 13:31:20 +0200 Subject: [PATCH 09/58] relocate line of code that came too early --- elephant/spike_train_processing.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 919c54f06..5521005fe 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -249,8 +249,6 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): 'default).') # TODO: documentation, example - min_t_start = min([st.t_start for st in spiketrains]) - bst = conv.BinnedSpikeTrain(spiketrains, binsize=bin_size) bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() @@ -290,8 +288,8 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): ].magnitude.item()) i += last_nonzero_index + 1 - # we dropped units above, neither concatenate nor append works with - # arrays of quantities + # we dropped units above, because neither concatenate nor append works + # with arrays of quantities left_edges *= bst.bin_edges.units right_edges *= bst.bin_edges.units @@ -300,7 +298,8 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): left_edges -= bin_shift right_edges -= bin_shift - # ensure that epoch does not start before the minimum t_start + # ensure that an epoch does not start before the minimum t_start + min_t_start = min([st.t_start for st in spiketrains]) left_edges[0] = min(min_t_start, left_edges[0]) complexity_epoch = neo.Epoch(times=left_edges, @@ -309,4 +308,3 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): complexities}) return complexity_epoch - From 035a29fcbf5f7a8500f2e9eb3db90acfed8e4572 Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 14:35:53 +0200 Subject: [PATCH 10/58] remove TODO flags --- elephant/spike_train_processing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 5521005fe..9515d2b45 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -102,8 +102,6 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, elephant.spike_train_processing.complexity_intervals """ - # TODO: refactor docs - if deletion_threshold is not None and deletion_threshold <= 1: raise ValueError('A deletion_threshold <= 1 would result' 'in deletion of all spikes.') @@ -248,7 +246,6 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): '1 / sampling_rate (which is the' 'default).') - # TODO: documentation, example bst = conv.BinnedSpikeTrain(spiketrains, binsize=bin_size) bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() From c7503862916e2df4edc61a9039154eda331720b4 Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 15:45:51 +0200 Subject: [PATCH 11/58] add the module to the documentation --- doc/modules.rst | 1 + doc/reference/spike_train_processing.rst | 12 ++++++++++++ 2 files changed, 13 insertions(+) create mode 100644 doc/reference/spike_train_processing.rst diff --git a/doc/modules.rst b/doc/modules.rst index e89b27a18..9780e1225 100644 --- a/doc/modules.rst +++ b/doc/modules.rst @@ -23,6 +23,7 @@ Function Reference by Module reference/spike_train_correlation reference/spike_train_dissimilarity reference/spike_train_generation + 
reference/spike_train_processing reference/spike_train_surrogates reference/sta reference/statistics diff --git a/doc/reference/spike_train_processing.rst b/doc/reference/spike_train_processing.rst new file mode 100644 index 000000000..aca083c97 --- /dev/null +++ b/doc/reference/spike_train_processing.rst @@ -0,0 +1,12 @@ +====================== +Spike train processing +====================== + + +.. testsetup:: + + from elephant.spike_train_processing import synchrofacts, complexity_intervals + + +.. automodule:: elephant.spike_train_processing + :members: From bf5ca69ad46e9e3a12d22d4d4fb44cf8a254f2d1 Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 15:46:16 +0200 Subject: [PATCH 12/58] make helper function private --- elephant/spike_train_processing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 9515d2b45..5d109cb76 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -5,7 +5,7 @@ import numpy as np -def get_index(lst, obj): +def _get_index(lst, obj): for index, item in enumerate(lst): if item is obj: return index @@ -140,10 +140,10 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, unit = old_st.unit segment = old_st.segment if unit is not None: - unit.spiketrains[get_index(unit.spiketrains, + unit.spiketrains[_get_index(unit.spiketrains, old_st)] = new_st if segment is not None: - segment.spiketrains[get_index(segment.spiketrains, + segment.spiketrains[_get_index(segment.spiketrains, old_st)] = new_st del old_st From d910dc95d60392db599a13670f0c569dd00d2647 Mon Sep 17 00:00:00 2001 From: Aitor Date: Mon, 18 May 2020 15:47:35 +0200 Subject: [PATCH 13/58] refactor documentation for nice formatting after build --- elephant/spike_train_processing.py | 49 ++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 5d109cb76..363cc1f46 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -1,3 +1,11 @@ +# -*- coding: utf-8 -*- +""" +Module for spike train processing + +:copyright: Copyright 2014-2020 by the Elephant team, see `doc/authors.rst`. +:license: Modified BSD, see LICENSE.txt for details. +""" + from __future__ import division import neo @@ -50,51 +58,67 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, Parameters ---------- - spiketrains: list of neo.SpikeTrains + spiketrains : list of neo.SpikeTrains a list of neo.SpikeTrains objects. These spike trains should have been recorded simultaneously. - sampling_rate: pq.Quantity + sampling_rate : pq.Quantity Sampling rate of the spike trains. The spike trains are binned with bin_size = 1 / `sampling_rate`. - spread: int + spread : int Number of bins in which to check for synchronous spikes. Spikes that occur separated by `spread - 1` or less empty bins are considered synchronous. + Default: 1 - deletion_threshold: int, optional + deletion_threshold : int, optional Threshold value for the deletion of spikes engaged in synchronous activity. + `deletion_threshold = None` leads to no spikes being deleted, spike trains are array-annotated and the spike times are kept unchanged. + `deletion_threshold >= 2` leads to all spikes with a larger or equal complexity value to be deleted *in-place*. 
+ `deletion_threshold` cannot be set to 1 (this would delete all spikes and there are definitely more efficient ways of doing this) + `deletion_threshold <= 0` leads to a ValueError. + Default: None - invert_delete: bool + invert_delete : bool Inversion of the mask for deletion of synchronous events. + `invert_delete = False` leads to the deletion of all spikes with complexity >= `deletion_threshold`, i.e. deletes synchronous spikes. + `invert_delete = True` leads to the deletion of all spikes with complexity < `deletion_threshold`, i.e. returns synchronous spikes. + Default: False Returns ------- - complexity_epoch: neo.Epoch + complexity_epoch : neo.Epoch An epoch object containing complexity values, left edges and durations of all intervals with at least one spike. + Calculated with `elephant.spike_train_processing.complexity_intervals`. + Complexity values per spike can be accessed with: + >>> complexity_epoch.array_annotations['complexity'] + The left edges of the intervals with: + >>> complexity_epoch.times + And the durations with: + >>> complexity_epoch.durations See also @@ -162,28 +186,34 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): Parameters ---------- - spiketrains: list of neo.SpikeTrains + spiketrains : list of neo.SpikeTrains a list of neo.SpikeTrains objects. These spike trains should have been recorded simultaneously. - sampling_rate: pq.Quantity + sampling_rate : pq.Quantity Sampling rate of the spike trains. - bin_size: pq.Quantity + bin_size : pq.Quantity Bin size for calculating the complexity values. If `bin_size = None` the spike trains are binned with `bin_size = 1 / sampling_rate`. + Default: None spread : int, optional Number of bins in which to check for synchronous spikes. Spikes that occur separated by `spread - 1` or less empty bins are considered synchronous. + `spread = 0` corresponds to a bincount accross spike trains. + `spread = 1` corresponds to counting consecutive spikes. + `spread = 2` corresponds to counting consecutive spikes and spikes separated by exactly 1 empty bin. + `spread = n` corresponds to counting spikes separated by exactly or less than `n - 1` empty bins. + Default: 0 Returns @@ -202,6 +232,7 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): Here the behavior of `elephant.spike_train_processing.complexity_intervals` is shown, by applying the function to some sample spiketrains. + >>> import neo >>> import quantities as pq ... 
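The docstrings refactored in the patches above describe the intended workflow: annotate each spike with its complexity value, optionally delete spikes above a threshold in-place, and inspect the returned epoch. What follows is a minimal usage sketch of that workflow, assuming the API exactly as documented at this point in the series (`detect_synchrofacts` in `elephant.spike_train_processing`, with the `sampling_rate`, `spread`, `deletion_threshold` and `invert_delete` arguments); the toy spike times mirror those used in the unit tests of this series.

    import neo
    import quantities as pq
    from elephant.spike_train_processing import detect_synchrofacts

    # Toy data, binned at bin_size = 1 / sampling_rate = 1 s.
    sampling_rate = 1 / pq.s
    spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, t_stop=20 * pq.s),
                   neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, t_stop=20 * pq.s)]

    # deletion_threshold=None: annotate only, keep every spike.
    epoch = detect_synchrofacts(spiketrains, sampling_rate, spread=0,
                                deletion_threshold=None)

    # Each spike now carries its complexity value ...
    print(spiketrains[0].array_annotations['complexity'])

    # ... and the returned epoch describes every interval containing spikes.
    print(epoch.times, epoch.durations, epoch.array_annotations['complexity'])

    # deletion_threshold=2: remove all spikes with complexity >= 2 in-place
    # (invert_delete=True would instead keep only those synchronous spikes).
    detect_synchrofacts(spiketrains, sampling_rate, spread=0,
                        deletion_threshold=2, invert_delete=False)

With `spread=0` only spikes falling into the same bin count as synchronous; `spread=1` additionally groups spikes in adjacent bins, as in the doctest examples shown above.
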
From 7b769a7b886e4a132878d01904e9acf22facfdf5 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 20 May 2020 12:54:17 +0200 Subject: [PATCH 14/58] Refactor and cleanup --- elephant/spike_train_processing.py | 144 ++++++++----------- elephant/test/test_spike_train_processing.py | 66 ++++++--- 2 files changed, 106 insertions(+), 104 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 363cc1f46..cab523523 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -19,22 +19,28 @@ def _get_index(lst, obj): return index -def _check_spiketrains(spiketrains): - if not isinstance(spiketrains, list): - raise TypeError('spiketrains should be a list of neo.SpikeTrain') - +def _check_consistency_of_spiketrainlist(spiketrains, t_start=None, + t_stop=None): if len(spiketrains) == 0: raise ValueError('The spiketrains should not be empty!') - - # check that all list elements are spike trains for spiketrain in spiketrains: if not isinstance(spiketrain, neo.SpikeTrain): - raise TypeError('not all elements in the spiketrains list are' - 'neo.SpikeTrain objects') + raise TypeError( + "spike train must be instance of :class:`SpikeTrain` of Neo!\n" + " Found: %s, value %s" % ( + type(spiketrain), str(spiketrain))) + if t_start is None and not spiketrain.t_start == spiketrains[ + 0].t_start: + raise ValueError( + "the spike trains must have the same t_start!") + if t_stop is None and not spiketrain.t_stop == spiketrains[ + 0].t_stop: + raise ValueError( + "the spike trains must have the same t_stop!") def detect_synchrofacts(spiketrains, sampling_rate, spread=1, - deletion_threshold=None, invert_delete=False): + deletion_threshold=None, invert_delete=False): """ Given a list of neo.Spiketrain objects, calculate the number of synchronous spikes found and optionally delete or extract them from the given list @@ -45,8 +51,8 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, Two spikes are considered synchronous if they occur separated by strictly fewer than `spread - 1` empty bins from one another. See - `elephant.spike_train_processing.complexity_intervals` for a detailed - description of how synchronous events are counted. + `elephant.statistics.complexity_intervals` for a detailed description of + how synchronous events are counted. Synchronous events are considered within the same spike train and across different spike trains in the `spiketrains` list. Such that, synchronous @@ -61,44 +67,32 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, spiketrains : list of neo.SpikeTrains a list of neo.SpikeTrains objects. These spike trains should have been recorded simultaneously. - sampling_rate : pq.Quantity Sampling rate of the spike trains. The spike trains are binned with bin_size = 1 / `sampling_rate`. - spread : int Number of bins in which to check for synchronous spikes. Spikes that occur separated by `spread - 1` or less empty bins are considered synchronous. - Default: 1 - deletion_threshold : int, optional Threshold value for the deletion of spikes engaged in synchronous activity. - - `deletion_threshold = None` leads to no spikes being deleted, spike - trains are array-annotated and the spike times are kept unchanged. - - `deletion_threshold >= 2` leads to all spikes with a larger or equal - complexity value to be deleted *in-place*. 
- - `deletion_threshold` cannot be set to 1 (this would delete all spikes - and there are definitely more efficient ways of doing this) - - `deletion_threshold <= 0` leads to a ValueError. - + * `deletion_threshold = None` leads to no spikes being deleted, spike + trains are array-annotated and the spike times are kept unchanged. + * `deletion_threshold >= 2` leads to all spikes with a larger or + equal complexity value to be deleted *in-place*. + * `deletion_threshold` cannot be set to 1 (this would delete all + spikes and there are definitely more efficient ways of doing this) + * `deletion_threshold <= 0` leads to a ValueError. Default: None - invert_delete : bool Inversion of the mask for deletion of synchronous events. - - `invert_delete = False` leads to the deletion of all spikes with - complexity >= `deletion_threshold`, i.e. deletes synchronous spikes. - - `invert_delete = True` leads to the deletion of all spikes with - complexity < `deletion_threshold`, i.e. returns synchronous spikes. - + * `invert_delete = False` leads to the deletion of all spikes with + complexity >= `deletion_threshold`, + i.e. deletes synchronous spikes. + * `invert_delete = True` leads to the deletion of all spikes with + complexity < `deletion_threshold`, i.e. returns synchronous spikes. Default: False Returns @@ -107,7 +101,8 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, An epoch object containing complexity values, left edges and durations of all intervals with at least one spike. - Calculated with `elephant.spike_train_processing.complexity_intervals`. + Calculated with + `elephant.spike_train_processing.precise_complexity_intervals`. Complexity values per spike can be accessed with: @@ -123,19 +118,17 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, See also -------- - elephant.spike_train_processing.complexity_intervals + elephant.statistics.precise_complexity_intervals """ if deletion_threshold is not None and deletion_threshold <= 1: raise ValueError('A deletion_threshold <= 1 would result' 'in deletion of all spikes.') - _check_spiketrains(spiketrains) - # find times of synchrony of size >=n - complexity_epoch = complexity_intervals(spiketrains, - sampling_rate, - spread=spread) + complexity_epoch = precise_complexity_intervals(spiketrains, + sampling_rate, + spread=spread) complexity = complexity_epoch.array_annotations['complexity'] right_edges = complexity_epoch.times + complexity_epoch.durations @@ -165,19 +158,19 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, segment = old_st.segment if unit is not None: unit.spiketrains[_get_index(unit.spiketrains, - old_st)] = new_st + old_st)] = new_st if segment is not None: segment.spiketrains[_get_index(segment.spiketrains, - old_st)] = new_st + old_st)] = new_st del old_st return complexity_epoch -def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): +def precise_complexity_intervals(spiketrains, sampling_rate, spread=0): """ Calculate the complexity (i.e. number of synchronous spikes found) - for each bin interval in a list of spiketrains. + at `sampling_rate` precision in a list of spiketrains. Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur separated by `spread - 1` or less empty bins, within and @@ -189,31 +182,18 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): spiketrains : list of neo.SpikeTrains a list of neo.SpikeTrains objects. These spike trains should have been recorded simultaneously. 
- sampling_rate : pq.Quantity Sampling rate of the spike trains. - - bin_size : pq.Quantity - Bin size for calculating the complexity values. If `bin_size = None` - the spike trains are binned with `bin_size = 1 / sampling_rate`. - - Default: None - spread : int, optional Number of bins in which to check for synchronous spikes. Spikes that occur separated by `spread - 1` or less empty bins are considered synchronous. - - `spread = 0` corresponds to a bincount accross spike trains. - - `spread = 1` corresponds to counting consecutive spikes. - - `spread = 2` corresponds to counting consecutive spikes and spikes - separated by exactly 1 empty bin. - - `spread = n` corresponds to counting spikes separated by exactly or - less than `n - 1` empty bins. - + * `spread = 0` corresponds to a bincount accross spike trains. + * `spread = 1` corresponds to counting consecutive spikes. + * `spread = 2` corresponds to counting consecutive spikes and + spikes separated by exactly 1 empty bin. + * `spread = n` corresponds to counting spikes separated by exactly + or less than `n - 1` empty bins. Default: 0 Returns @@ -230,55 +210,51 @@ def complexity_intervals(spiketrains, sampling_rate, bin_size=None, spread=0): Examples -------- Here the behavior of - `elephant.spike_train_processing.complexity_intervals` is shown, by + `elephant.spike_train_processing.precise_complexity_intervals` is shown, by applying the function to some sample spiketrains. >>> import neo >>> import quantities as pq - ... + >>> sampling_rate = 1/pq.ms >>> st1 = neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10.0 * pq.ms) >>> st2 = neo.SpikeTrain([1, 5, 8] * pq.ms, t_stop=10.0 * pq.ms) - ... + >>> # spread = 0, a simple bincount - >>> ep1 = complexity_intervals([st1, st2], sampling_rate) + >>> ep1 = precise_complexity_intervals([st1, st2], sampling_rate) >>> print(ep1.array_annotations['complexity'].flatten()) - [2, 1, 1, 1, 1] + [2 1 1 1 1] >>> print(ep1.times) [0. 3.5 4.5 5.5 7.5] ms >>> print(ep1.durations) - [1.5, 1. , 1. , 1. , 1. ] ms - ... + [1.5 1. 1. 1. 1. ] ms + >>> # spread = 1, consecutive spikes - >>> ep2 = complexity_intervals([st1, st2], sampling_rate, spread=1) + >>> ep2 = precise_complexity_intervals([st1, st2], sampling_rate, spread=1) >>> print(ep2.array_annotations['complexity'].flatten()) - [2, 3, 1] + [2 3 1] >>> print(ep2.times) [0. 3.5 7.5] ms >>> print(ep2.durations) [1.5 3. 1. ] ms - ... + >>> # spread = 2, consecutive spikes and separated by 1 empty bin - >>> ep3 = complexity_intervals([st1, st2], sampling_rate, spread=2) + >>> ep3 = precise_complexity_intervals([st1, st2], sampling_rate, spread=2) >>> print(ep3.array_annotations['complexity'].flatten()) - [2, 4] + [2 4] >>> print(ep3.times) [0. 3.5] ms >>> print(ep3.durations) [1.5 5. 
] ms """ - _check_spiketrains(spiketrains) + _check_consistency_of_spiketrainlist(spiketrains) - if bin_size is None: - bin_size = 1 / sampling_rate - elif bin_size < 1 / sampling_rate: - raise ValueError('The bin size should be at least' - '1 / sampling_rate (which is the' - 'default).') + bin_size = 1 / sampling_rate bst = conv.BinnedSpikeTrain(spiketrains, binsize=bin_size) + bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() if spread == 0: diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index df0d12de4..3e6b69ef5 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -17,13 +17,13 @@ class SynchrofactDetectionTestCase(unittest.TestCase): def _test_template(self, spiketrains, correct_complexities, sampling_rate, - spread, deletion_threshold=2, invert=False): + spread, deletion_threshold=2, invert_delete=False): # test annotation spike_train_processing.detect_synchrofacts( spiketrains, spread=spread, sampling_rate=sampling_rate, - invert=invert, + invert_delete=invert_delete, deletion_threshold=None) annotations = [st.array_annotations['complexity'] @@ -31,17 +31,24 @@ def _test_template(self, spiketrains, correct_complexities, sampling_rate, assert_array_equal(annotations, correct_complexities) - correct_spike_times = np.array( - [spikes[mask] for spikes, mask - in zip(spiketrains, correct_complexities < deletion_threshold) - ]) + if invert_delete: + correct_spike_times = np.array( + [spikes[mask] for spikes, mask + in zip(spiketrains, + correct_complexities >= deletion_threshold) + ]) + else: + correct_spike_times = np.array( + [spikes[mask] for spikes, mask + in zip(spiketrains, correct_complexities < deletion_threshold) + ]) # test deletion spike_train_processing.detect_synchrofacts( spiketrains, spread=spread, sampling_rate=sampling_rate, - invert=invert, + invert_delete=invert_delete, deletion_threshold=deletion_threshold) cleaned_spike_times = np.array( @@ -67,7 +74,8 @@ def test_no_synchrofacts(self): [1, 1, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert=False, deletion_threshold=2) + spread=1, invert_delete=False, + deletion_threshold=2) def test_spread_0(self): @@ -86,7 +94,8 @@ def test_spread_0(self): [2, 1, 1, 1, 2, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, invert=False, deletion_threshold=2) + spread=0, invert_delete=False, + deletion_threshold=2) def test_spread_1(self): @@ -104,7 +113,8 @@ def test_spread_1(self): [2, 2, 1, 3, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert=False, deletion_threshold=2) + spread=1, invert_delete=False, + deletion_threshold=2) def test_n_equals_3(self): @@ -122,7 +132,27 @@ def test_n_equals_3(self): [3, 2, 1, 2, 3, 3, 3, 2]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert=False, deletion_threshold=3) + spread=1, invert_delete=False, + deletion_threshold=3) + + def test_invert_delete(self): + + # test synchrofact search taking into account adjacent bins + # this requires an additional loop with shifted binning + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_annotations = np.array([[2, 2, 1, 3, 3, 1], + [2, 2, 1, 3, 1, 1]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + 
spread=1, invert_delete=True, + deletion_threshold=2) def test_binning_for_input_with_rounding_errors(self): @@ -147,7 +177,8 @@ def test_binning_for_input_with_rounding_errors(self): second_annotations]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, invert=False, deletion_threshold=2) + spread=0, invert_delete=False, + deletion_threshold=2) def test_correct_transfer_of_spiketrain_attributes(self): @@ -193,7 +224,7 @@ def test_correct_transfer_of_spiketrain_attributes(self): spike_train_processing.detect_synchrofacts([spiketrain], spread=0, sampling_rate=sampling_rate, - invert=False, + invert_delete=False, deletion_threshold=2) # Ensure that the spiketrain was not duplicated @@ -229,18 +260,13 @@ def test_wrong_input_errors(self): 1 / pq.s, deletion_threshold=-1) self.assertRaises(ValueError, - spike_train_processing.find_complexity_intervals, + spike_train_processing.precise_complexity_intervals, [], 1 / pq.s) self.assertRaises(TypeError, - spike_train_processing.find_complexity_intervals, + spike_train_processing.precise_complexity_intervals, [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), np.arange(2)], 1 / pq.s) - self.assertRaises(ValueError, - spike_train_processing.find_complexity_intervals, - [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], - sampling_rate=1 / pq.s, - bin_size=.5 * pq.s) if __name__ == '__main__': From 2fd59eb46cab566acd30f4fb70d9b94aa19c7f0c Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 20 May 2020 12:54:48 +0200 Subject: [PATCH 15/58] Add wrapper to calculate the complexity histogram --- elephant/spike_train_processing.py | 27 +++++++++++++++ elephant/test/test_spike_train_processing.py | 36 ++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index cab523523..d9fdf1959 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -312,3 +312,30 @@ def precise_complexity_intervals(spiketrains, sampling_rate, spread=0): complexities}) return complexity_epoch + + +def precise_complexity_histogram(spiketrains, **kwargs): + """ + This is a wrapper for `precise_complexity_intervals` which calculates + the complexity histogram; the number of occurrences of events of + different complexities. + + Parameters + ---------- + spiketrains : list of neo.SpikeTrains + a list of neo.SpikeTrains objects. These spike trains should have been + recorded simultaneously. + **kwargs + Additional keyword arguments passed to + `precise_complexity_intervals`. + + Returns + ------- + complexity_histogram : np.ndarray + A histogram of complexities. `complexity_histogram[i]` corresponds + to the number of events of complexity `i` for `i > 0`. 
+ """ + complexity_epoch = precise_complexity_intervals(spiketrains, **kwargs) + complexity_histogram = np.bincount( + complexity_epoch.array_annotations['complexity']) + return complexity_histogram diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index 3e6b69ef5..dbfdc372b 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -245,6 +245,42 @@ def test_correct_transfer_of_spiketrain_attributes(self): self.assertTrue(key in cleaned_array_annotations.keys()) assert_array_almost_equal(value, cleaned_array_annotations[key]) + def test_complexity_histogram_spread_0(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, + t_stop=20*pq.s)] + + correct_histogram = np.array([0, 8, 2]) + + histogram = spike_train_processing.precise_complexity_histogram( + spiketrains, + sampling_rate=sampling_rate, + spread=0) + + assert_array_equal(histogram, correct_histogram) + + def test_complexity_histogram_spread_1(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_histogram = np.array([0, 5, 2, 1]) + + histogram = spike_train_processing.precise_complexity_histogram( + spiketrains, + sampling_rate=sampling_rate, + spread=1) + + assert_array_equal(histogram, correct_histogram) + def test_wrong_input_errors(self): self.assertRaises(ValueError, spike_train_processing.detect_synchrofacts, From 87919a51e6967f1484b152147f8728d8f99d8af2 Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 20 May 2020 18:25:26 +0200 Subject: [PATCH 16/58] move check consistency of spiketrain list to utils --- elephant/spike_train_processing.py | 21 +-------------------- elephant/statistics.py | 28 +++++----------------------- elephant/utils.py | 27 +++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 43 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index d9fdf1959..12aa7f1cb 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -11,6 +11,7 @@ import neo import elephant.conversion as conv import numpy as np +from .utils import _check_consistency_of_spiketrainlist def _get_index(lst, obj): @@ -19,26 +20,6 @@ def _get_index(lst, obj): return index -def _check_consistency_of_spiketrainlist(spiketrains, t_start=None, - t_stop=None): - if len(spiketrains) == 0: - raise ValueError('The spiketrains should not be empty!') - for spiketrain in spiketrains: - if not isinstance(spiketrain, neo.SpikeTrain): - raise TypeError( - "spike train must be instance of :class:`SpikeTrain` of Neo!\n" - " Found: %s, value %s" % ( - type(spiketrain), str(spiketrain))) - if t_start is None and not spiketrain.t_start == spiketrains[ - 0].t_start: - raise ValueError( - "the spike trains must have the same t_start!") - if t_stop is None and not spiketrain.t_stop == spiketrains[ - 0].t_stop: - raise ValueError( - "the spike trains must have the same t_stop!") - - def detect_synchrofacts(spiketrains, sampling_rate, spread=1, deletion_threshold=None, invert_delete=False): """ diff --git a/elephant/statistics.py b/elephant/statistics.py index e569cd0e6..44956c975 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -70,6 +70,7 @@ import elephant.conversion as conv import elephant.kernels as 
kernels import warnings +from .utils import _check_consistency_of_spiketrainlist cv = scipy.stats.variation @@ -515,8 +516,10 @@ def instantaneous_rate(spiketrain, sampling_period, kernel='auto', """ # Merge spike trains if list of spike trains given: if isinstance(spiketrain, list): - _check_consistency_of_spiketrainlist( - spiketrain, t_start=t_start, t_stop=t_stop) + _check_consistency_of_spiketrainlist(spiketrain, + same_t_start=t_start, + same_t_stop=t_stop, + same_units=True) if t_start is None: t_start = spiketrain[0].t_start if t_stop is None: @@ -1071,24 +1074,3 @@ def sskernel(spiketimes, tin=None, w=None, bootstrap=False): 'C': C, 'confb95': confb95, 'yb': yb} - - -def _check_consistency_of_spiketrainlist(spiketrainlist, t_start=None, - t_stop=None): - for spiketrain in spiketrainlist: - if not isinstance(spiketrain, SpikeTrain): - raise TypeError( - "spike train must be instance of :class:`SpikeTrain` of Neo!\n" - " Found: %s, value %s" % ( - type(spiketrain), str(spiketrain))) - if t_start is None and not spiketrain.t_start == spiketrainlist[ - 0].t_start: - raise ValueError( - "the spike trains must have the same t_start!") - if t_stop is None and not spiketrain.t_stop == spiketrainlist[ - 0].t_stop: - raise ValueError( - "the spike trains must have the same t_stop!") - if not spiketrain.units == spiketrainlist[0].units: - raise ValueError( - "the spike trains must have the same units!") diff --git a/elephant/utils.py b/elephant/utils.py index 156f9c8f7..9f60d27e1 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -2,6 +2,8 @@ import numpy as np +from neo import SpikeTrain + def is_binary(array): """ @@ -17,3 +19,28 @@ def is_binary(array): """ array = np.asarray(array) return ((array == 0) | (array == 1)).all() + + +def _check_consistency_of_spiketrainlist(spiketrains, + same_t_start=None, + same_t_stop=None, + same_units=False): + """ + Private function to check lists of spiketrains. + """ + if len(spiketrains) == 0: + raise ValueError('The spiketrains list is empty!') + for st in spiketrains: + if not isinstance(st, SpikeTrain): + raise TypeError( + 'elements in spiketrains list must be instances of ' + ':class:`SpikeTrain` of Neo!' + 'Found: %s, value %s' % (type(st), str(st))) + if same_t_start and not st.t_start == spiketrains[0].t_start: + raise ValueError( + "the spike trains must have the same t_start!") + if same_t_stop and not st.t_stop == spiketrains[0].t_stop: + raise ValueError( + "the spike trains must have the same t_stop!") + if same_units and not st.units == st[0].units: + raise ValueError('The spike trains must have the same units!') From 386682baffb94a2465d3f84b1aed6788a3d56cf2 Mon Sep 17 00:00:00 2001 From: Aitor Date: Thu, 21 May 2020 14:02:13 +0200 Subject: [PATCH 17/58] refactor docs --- elephant/spike_train_processing.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 12aa7f1cb..2ab27ab50 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -32,8 +32,8 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, Two spikes are considered synchronous if they occur separated by strictly fewer than `spread - 1` empty bins from one another. See - `elephant.statistics.complexity_intervals` for a detailed description of - how synchronous events are counted. + `elephant.statistics.precise_complexity_intervals` for a detailed + description of how synchronous events are counted. 
Synchronous events are considered within the same spike train and across different spike trains in the `spiketrains` list. Such that, synchronous @@ -81,25 +81,16 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, complexity_epoch : neo.Epoch An epoch object containing complexity values, left edges and durations of all intervals with at least one spike. - Calculated with `elephant.spike_train_processing.precise_complexity_intervals`. - - Complexity values per spike can be accessed with: - - >>> complexity_epoch.array_annotations['complexity'] - - The left edges of the intervals with: - - >>> complexity_epoch.times - - And the durations with: - - >>> complexity_epoch.durations + * ``complexity_epoch.array_annotations['complexity']`` contains the + complexity values per spike. + * ``complexity_epoch.times`` contains the left edges. + * ``complexity_epoch.durations`` contains the durations. See also -------- - elephant.statistics.precise_complexity_intervals + elephant.spike_train_processing.precise_complexity_intervals """ if deletion_threshold is not None and deletion_threshold <= 1: @@ -157,7 +148,6 @@ def precise_complexity_intervals(spiketrains, sampling_rate, spread=0): bins) that occur separated by `spread - 1` or less empty bins, within and across spike trains in the `spiketrains` list. - Parameters ---------- spiketrains : list of neo.SpikeTrains From 791b30aeee3a4c820ef9ae494fe01ae5a318f334 Mon Sep 17 00:00:00 2001 From: Aitor Date: Thu, 21 May 2020 14:02:41 +0200 Subject: [PATCH 18/58] update checking of spiketrainlist in detect_synchrofacts --- elephant/spike_train_processing.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 2ab27ab50..bc5343dc4 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -94,9 +94,14 @@ def detect_synchrofacts(spiketrains, sampling_rate, spread=1, """ if deletion_threshold is not None and deletion_threshold <= 1: - raise ValueError('A deletion_threshold <= 1 would result' + raise ValueError('A deletion_threshold <= 1 would result ' 'in deletion of all spikes.') + if isinstance(spiketrains, list): + _check_consistency_of_spiketrainlist(spiketrains) + else: + raise TypeError('spiketrains should be a list of neo.SpikeTrain') + # find times of synchrony of size >=n complexity_epoch = precise_complexity_intervals(spiketrains, sampling_rate, From 8b69e2d7d0add26473bd4e74e4d05c2f5dd16535 Mon Sep 17 00:00:00 2001 From: Aitor Date: Thu, 21 May 2020 14:11:18 +0200 Subject: [PATCH 19/58] add test for utils spiketrainlist checking --- elephant/test/test_utils.py | 48 +++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 elephant/test/test_utils.py diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py new file mode 100644 index 000000000..c6063963e --- /dev/null +++ b/elephant/test/test_utils.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" +Unit tests for the synchrofact detection app +""" + +import unittest + +import neo +import numpy as np +import quantities as pq + +from elephant import utils + + +class checkSpiketrainTestCase(unittest.TestCase): + + def test_wrong_input_errors(self): + self.assertRaises(ValueError, + utils._check_consistency_of_spiketrains, + [], 1 / pq.s) + self.assertRaises(TypeError, + utils._check_consistency_of_spiketrains, + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), + np.arange(2)], + 1 / pq.s) + self.assertRaises(ValueError, + 
utils._check_consistency_of_spiketrains, + [neo.SpikeTrain([1]*pq.s, + t_start=1*pq.s, + t_stop=2*pq.s), + neo.SpikeTrain([1]*pq.s, + t_start=0*pq.s, + t_stop=2*pq.s)], + same_t_start=True) + self.assertRaises(ValueError, + utils._check_consistency_of_spiketrains, + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), + neo.SpikeTrain([1]*pq.s, t_stop=3*pq.s)], + same_t_stop=True) + self.assertRaises(ValueError, + utils._check_consistency_of_spiketrains, + [neo.SpikeTrain([1]*pq.ms, t_stop=2*pq.s), + neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], + same_units=True) + + +if __name__ == '__main__': + unittest.main() From 034d1d220b312fd986e77c8f91c150d453217fa7 Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 10 Jun 2020 13:13:13 +0200 Subject: [PATCH 20/58] create complexity class, that can replace complexity_pdf in statistics --- elephant/spike_train_processing.py | 325 +++++++++++++---------------- elephant/statistics.py | 58 ----- 2 files changed, 149 insertions(+), 234 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index bc5343dc4..467bb004e 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -11,7 +11,10 @@ import neo import elephant.conversion as conv import numpy as np +import quantities as pq from .utils import _check_consistency_of_spiketrainlist +from elephant.statistics import time_histogram +import warnings def _get_index(lst, obj): @@ -22,130 +25,11 @@ def _get_index(lst, obj): def detect_synchrofacts(spiketrains, sampling_rate, spread=1, deletion_threshold=None, invert_delete=False): +class complexity: """ - Given a list of neo.Spiketrain objects, calculate the number of synchronous - spikes found and optionally delete or extract them from the given list - *in-place*. + docstring TODO - The spike trains are binned at sampling precission - (i.e. bin_size = 1 / `sampling_rate`) - - Two spikes are considered synchronous if they occur separated by strictly - fewer than `spread - 1` empty bins from one another. See - `elephant.statistics.precise_complexity_intervals` for a detailed - description of how synchronous events are counted. - - Synchronous events are considered within the same spike train and across - different spike trains in the `spiketrains` list. Such that, synchronous - events can be found both in multi-unit and single-unit spike trains. - - The spike trains in the `spiketrains` list are annotated with the - complexity value of each spike in their :attr:`array_annotations`. - - - Parameters - ---------- - spiketrains : list of neo.SpikeTrains - a list of neo.SpikeTrains objects. These spike trains should have been - recorded simultaneously. - sampling_rate : pq.Quantity - Sampling rate of the spike trains. The spike trains are binned with - bin_size = 1 / `sampling_rate`. - spread : int - Number of bins in which to check for synchronous spikes. - Spikes that occur separated by `spread - 1` or less empty bins are - considered synchronous. - Default: 1 - deletion_threshold : int, optional - Threshold value for the deletion of spikes engaged in synchronous - activity. - * `deletion_threshold = None` leads to no spikes being deleted, spike - trains are array-annotated and the spike times are kept unchanged. - * `deletion_threshold >= 2` leads to all spikes with a larger or - equal complexity value to be deleted *in-place*. 
- * `deletion_threshold` cannot be set to 1 (this would delete all - spikes and there are definitely more efficient ways of doing this) - * `deletion_threshold <= 0` leads to a ValueError. - Default: None - invert_delete : bool - Inversion of the mask for deletion of synchronous events. - * `invert_delete = False` leads to the deletion of all spikes with - complexity >= `deletion_threshold`, - i.e. deletes synchronous spikes. - * `invert_delete = True` leads to the deletion of all spikes with - complexity < `deletion_threshold`, i.e. returns synchronous spikes. - Default: False - - Returns - ------- - complexity_epoch : neo.Epoch - An epoch object containing complexity values, left edges and durations - of all intervals with at least one spike. - Calculated with - `elephant.spike_train_processing.precise_complexity_intervals`. - * ``complexity_epoch.array_annotations['complexity']`` contains the - complexity values per spike. - * ``complexity_epoch.times`` contains the left edges. - * ``complexity_epoch.durations`` contains the durations. - - See also - -------- - elephant.spike_train_processing.precise_complexity_intervals - - """ - if deletion_threshold is not None and deletion_threshold <= 1: - raise ValueError('A deletion_threshold <= 1 would result ' - 'in deletion of all spikes.') - - if isinstance(spiketrains, list): - _check_consistency_of_spiketrainlist(spiketrains) - else: - raise TypeError('spiketrains should be a list of neo.SpikeTrain') - - # find times of synchrony of size >=n - complexity_epoch = precise_complexity_intervals(spiketrains, - sampling_rate, - spread=spread) - complexity = complexity_epoch.array_annotations['complexity'] - right_edges = complexity_epoch.times + complexity_epoch.durations - - # j = index of pre-selected sts in spiketrains - # idx = index of pre-selected sts in original - # block.segments[seg].spiketrains - for idx, st in enumerate(spiketrains): - - # all indices of spikes that are within the half-open intervals - # defined by the boundaries - # note that every second entry in boundaries is an upper boundary - spike_to_epoch_idx = np.searchsorted(right_edges, - st.times.rescale( - right_edges.units)) - complexity_per_spike = complexity[spike_to_epoch_idx] - - st.array_annotate(complexity=complexity_per_spike) - - if deletion_threshold is not None: - mask = complexity_per_spike < deletion_threshold - if invert_delete: - mask = np.invert(mask) - old_st = st - new_st = old_st[mask] - spiketrains[idx] = new_st - unit = old_st.unit - segment = old_st.segment - if unit is not None: - unit.spiketrains[_get_index(unit.spiketrains, - old_st)] = new_st - if segment is not None: - segment.spiketrains[_get_index(segment.spiketrains, - old_st)] = new_st - del old_st - - return complexity_epoch - - -def precise_complexity_intervals(spiketrains, sampling_rate, spread=0): - """ + COPIED FROM PREVIOUS GET EPOCHS AS IS: Calculate the complexity (i.e. number of synchronous spikes found) at `sampling_rate` precision in a list of spiketrains. @@ -222,23 +106,132 @@ def precise_complexity_intervals(spiketrains, sampling_rate, spread=0): [0. 3.5] ms >>> print(ep3.durations) [1.5 5. 
] ms - """ - _check_consistency_of_spiketrainlist(spiketrains) - bin_size = 1 / sampling_rate + def __init__(self, spiketrains, + sampling_rate=None, + bin_size=None, + binary=True, + spread=0): + + if isinstance(spiketrains, list): + _check_consistency_of_spiketrainlist(spiketrains) + else: + raise TypeError('spiketrains should be a list of neo.SpikeTrain') + self.input_spiketrains = spiketrains + self.sampling_rate = sampling_rate + self.bin_size = bin_size + self.binary = binary + self.spread = spread + + if bin_size is None and sampling_rate is None: + raise ValueError('No bin_size or sampling_rate was speficied!') + elif bin_size is None and sampling_rate is not None: + self.bin_size = 1 / self.sampling_rate + + if spread < 0: + raise ValueError('Spread must be >=0') + elif spread == 0: + self.time_histogram, self.histogram = self._histogram_no_spread() + else: + print('Complexity calculated at sampling rate precision') + # self.epoch = self.precise_complexity_intervals() + self.epoch = self.get_epoch() + self.histogram = self._histogram_with_spread() + + return self + + @property + def pdf(self): + """ + Normalization of the Complexity Histogram (probabilty distribution) + """ + norm_hist = self.histogram / self.histogram.sum() + # Convert the Complexity pdf to an neo.AnalogSignal + pdf = neo.AnalogSignal( + np.array(norm_hist).reshape(len(norm_hist), 1) * + pq.dimensionless, t_start=0 * pq.dimensionless, + sampling_period=1 * pq.dimensionless) + return pdf + + # @property + # def epoch(self): + # if self.spread == 0: + # warnings.warn('No epoch for cases with spread = 0') + # return None + # else: + # return self._epoch + + def _histogram_no_spread(self): + """ + Complexity Distribution of a list of `neo.SpikeTrain` objects. + + Probability density computed from the complexity histogram which is the + histogram of the entries of the population histogram of clipped + (binary) spike trains computed with a bin width of `binsize`. + It provides for each complexity (== number of active neurons per bin) + the number of occurrences. The normalization of that histogram to 1 is + the probability density. + + Implementation is based on [1]_. + + Returns + ------- + complexity_distribution : neo.AnalogSignal + A `neo.AnalogSignal` object containing the histogram values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. + + See also + -------- + elephant.conversion.BinnedSpikeTrain + + References + ---------- + .. [1] S. Gruen, M. Abeles, & M. Diesmann, "Impact of higher-order + correlations on coincidence distributions of massively parallel + data," In "Dynamic Brain - from Neural Spikes to Behaviors", + pp. 96-114, Springer Berlin Heidelberg, 2008. + + """ + # Computing the population histogram with parameter binary=True to + # clip the spike trains before summing + pophist = time_histogram(self.input_spiketrains, + self.bin_size, + binary=self.binary) + + # Computing the histogram of the entries of pophist + complexity_hist = np.histogram( + pophist.magnitude, + bins=range(0, len(self.input_spiketrains) + 2))[0] + + return pophist, complexity_hist + + def _histogram_with_spread(self): + """ + Calculate the complexity histogram; + the number of occurrences of events of different complexities. + + Returns + ------- + complexity_histogram : np.ndarray + A histogram of complexities. `complexity_histogram[i]` corresponds + to the number of events of complexity `i` for `i > 0`. 
+ """ + complexity_histogram = np.bincount( + self.epoch.array_annotations['complexity']) + return complexity_histogram + + def get_epoch(self): + bst = conv.BinnedSpikeTrain(self.input_spiketrains, + binsize=self.bin_size) + + if self.binary: + binarized = bst.to_sparse_bool_array() + bincount = np.array(binarized.sum(axis=0)).squeeze() + else: + bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() - bst = conv.BinnedSpikeTrain(spiketrains, - binsize=bin_size) - - bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() - - if spread == 0: - bin_indices = np.nonzero(bincount)[0] - complexities = bincount[bin_indices] - left_edges = bst.bin_edges[bin_indices] - right_edges = bst.bin_edges[bin_indices + 1] - else: i = 0 complexities = [] left_edges = [] @@ -250,13 +243,13 @@ def precise_complexity_intervals(spiketrains, sampling_rate, spread=0): else: last_window_sum = current_bincount last_nonzero_index = 0 - current_window = bincount[i:i + spread + 1] + current_window = bincount[i:i + self.spread + 1] window_sum = current_window.sum() while window_sum > last_window_sum: last_nonzero_index = np.nonzero(current_window)[0][-1] current_window = bincount[i: i + last_nonzero_index - + spread + 1] + + self.spread + 1] last_window_sum = window_sum window_sum = current_window.sum() complexities.append(window_sum) @@ -268,50 +261,30 @@ def precise_complexity_intervals(spiketrains, sampling_rate, spread=0): ].magnitude.item()) i += last_nonzero_index + 1 - # we dropped units above, because neither concatenate nor append works + # we dropped units above, neither concatenate nor append works # with arrays of quantities left_edges *= bst.bin_edges.units right_edges *= bst.bin_edges.units - # ensure that spikes are not on the bin edges - bin_shift = .5 / sampling_rate - left_edges -= bin_shift - right_edges -= bin_shift - - # ensure that an epoch does not start before the minimum t_start - min_t_start = min([st.t_start for st in spiketrains]) - left_edges[0] = min(min_t_start, left_edges[0]) - - complexity_epoch = neo.Epoch(times=left_edges, - durations=right_edges - left_edges, - array_annotations={'complexity': - complexities}) + if self.sampling_rate: + # ensure that spikes are not on the bin edges + bin_shift = .5 / self.sampling_rate + left_edges -= bin_shift + right_edges -= bin_shift + else: + warnings.warn('No sampling rate specified. ' + 'Note that using the complexity epoch to get ' + 'precise spike times can lead to rounding errors.') - return complexity_epoch + # ensure that an epoch does not start before the minimum t_start + min_t_start = min([st.t_start for st in self.input_spiketrains]) + left_edges[0] = min(min_t_start, left_edges[0]) + complexity_epoch = neo.Epoch(times=left_edges, + durations=right_edges - left_edges, + array_annotations={'complexity': + complexities}) -def precise_complexity_histogram(spiketrains, **kwargs): - """ - This is a wrapper for `precise_complexity_intervals` which calculates - the complexity histogram; the number of occurrences of events of - different complexities. + return complexity_epoch - Parameters - ---------- - spiketrains : list of neo.SpikeTrains - a list of neo.SpikeTrains objects. These spike trains should have been - recorded simultaneously. - **kwargs - Additional keyword arguments passed to - `precise_complexity_intervals`. - Returns - ------- - complexity_histogram : np.ndarray - A histogram of complexities. `complexity_histogram[i]` corresponds - to the number of events of complexity `i` for `i > 0`. 
- """ - complexity_epoch = precise_complexity_intervals(spiketrains, **kwargs) - complexity_histogram = np.bincount( - complexity_epoch.array_annotations['complexity']) - return complexity_histogram diff --git a/elephant/statistics.py b/elephant/statistics.py index 44956c975..70dc03197 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -763,64 +763,6 @@ def time_histogram(spiketrains, binsize, t_start=None, t_stop=None, t_start=t_start) -def complexity_pdf(spiketrains, binsize): - """ - Complexity Distribution of a list of `neo.SpikeTrain` objects. - - Probability density computed from the complexity histogram which is the - histogram of the entries of the population histogram of clipped (binary) - spike trains computed with a bin width of `binsize`. - It provides for each complexity (== number of active neurons per bin) the - number of occurrences. The normalization of that histogram to 1 is the - probability density. - - Implementation is based on [1]_. - - Parameters - ---------- - spiketrains : list of neo.SpikeTrain - Spike trains with a common time axis (same `t_start` and `t_stop`) - binsize : pq.Quantity - Width of the histogram's time bins. - - Returns - ------- - complexity_distribution : neo.AnalogSignal - A `neo.AnalogSignal` object containing the histogram values. - `neo.AnalogSignal[j]` is the histogram computed between - `t_start + j * binsize` and `t_start + (j + 1) * binsize`. - - See also - -------- - elephant.conversion.BinnedSpikeTrain - - References - ---------- - .. [1] S. Gruen, M. Abeles, & M. Diesmann, "Impact of higher-order - correlations on coincidence distributions of massively parallel - data," In "Dynamic Brain - from Neural Spikes to Behaviors", - pp. 96-114, Springer Berlin Heidelberg, 2008. - - """ - # Computing the population histogram with parameter binary=True to clip the - # spike trains before summing - pophist = time_histogram(spiketrains, binsize, binary=True) - - # Computing the histogram of the entries of pophist (=Complexity histogram) - complexity_hist = np.histogram( - pophist.magnitude, bins=range(0, len(spiketrains) + 2))[0] - - # Normalization of the Complexity Histogram to 1 (probabilty distribution) - complexity_hist = complexity_hist / complexity_hist.sum() - # Convert the Complexity pdf to an neo.AnalogSignal - complexity_distribution = neo.AnalogSignal( - np.array(complexity_hist).reshape(len(complexity_hist), 1) * - pq.dimensionless, t_start=0 * pq.dimensionless, - sampling_period=1 * pq.dimensionless) - - return complexity_distribution - - """ Kernel Bandwidth Optimization. From 399e7c49dc8941a9f662599c5c2ab4f5a155e8d0 Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 10 Jun 2020 13:13:47 +0200 Subject: [PATCH 21/58] Work in progress, Synchrotool class. Will have to be changed eventually. 
--- elephant/spike_train_processing.py | 164 +++++++++++++++++++++++++++-- 1 file changed, 156 insertions(+), 8 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 467bb004e..368f4d228 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -17,14 +17,6 @@ import warnings -def _get_index(lst, obj): - for index, item in enumerate(lst): - if item is obj: - return index - - -def detect_synchrofacts(spiketrains, sampling_rate, spread=1, - deletion_threshold=None, invert_delete=False): class complexity: """ docstring TODO @@ -288,3 +280,159 @@ def get_epoch(self): return complexity_epoch +class synchrotool(complexity): + complexity.__doc_ + + def __init__(self, spiketrains, + sampling_rate=None, + spread=1): + self.spiketrains = spiketrains + self.sampling_rate = sampling_rate + self.spread = spread + + # self.super(...).__init__() + + # find times of synchrony of size >=n + # complexity_epoch = + + # ... + return self + + def annotate_synchrofacts(self): + return None + + def delete_synchrofacts(self): + return None + + def extract_synchrofacts(self): + return None + + # def delete_synchrofacts(self, in_place=False): + # + # if not in_place: + # # return clean_spiketrains + # + # @property + # def synchrofacts(self): + # self.synchrofacts = self.detect_synchrofacts(deletion_threshold=1, + # invert_delete=True) + + def detect_synchrofacts(self, + deletion_threshold=None, + invert_delete=False): + """ + Given a list of neo.Spiketrain objects, calculate the number of synchronous + spikes found and optionally delete or extract them from the given list + *in-place*. + + The spike trains are binned at sampling precission + (i.e. bin_size = 1 / `sampling_rate`) + + Two spikes are considered synchronous if they occur separated by strictly + fewer than `spread - 1` empty bins from one another. See + `elephant.statistics.precise_complexity_intervals` for a detailed + description of how synchronous events are counted. + + Synchronous events are considered within the same spike train and across + different spike trains in the `spiketrains` list. Such that, synchronous + events can be found both in multi-unit and single-unit spike trains. + + The spike trains in the `spiketrains` list are annotated with the + complexity value of each spike in their :attr:`array_annotations`. + + + Parameters + ---------- + spiketrains : list of neo.SpikeTrains + a list of neo.SpikeTrains objects. These spike trains should have been + recorded simultaneously. + sampling_rate : pq.Quantity + Sampling rate of the spike trains. The spike trains are binned with + bin_size = 1 / `sampling_rate`. + spread : int + Number of bins in which to check for synchronous spikes. + Spikes that occur separated by `spread - 1` or less empty bins are + considered synchronous. + Default: 1 + deletion_threshold : int, optional + Threshold value for the deletion of spikes engaged in synchronous + activity. + * `deletion_threshold = None` leads to no spikes being deleted, spike + trains are array-annotated and the spike times are kept unchanged. + * `deletion_threshold >= 2` leads to all spikes with a larger or + equal complexity value to be deleted *in-place*. + * `deletion_threshold` cannot be set to 1 (this would delete all + spikes and there are definitely more efficient ways of doing this) + * `deletion_threshold <= 0` leads to a ValueError. + Default: None + invert_delete : bool + Inversion of the mask for deletion of synchronous events. 
+ * `invert_delete = False` leads to the deletion of all spikes with + complexity >= `deletion_threshold`, + i.e. deletes synchronous spikes. + * `invert_delete = True` leads to the deletion of all spikes with + complexity < `deletion_threshold`, i.e. returns synchronous spikes. + Default: False + + Returns + ------- + complexity_epoch : neo.Epoch + An epoch object containing complexity values, left edges and durations + of all intervals with at least one spike. + Calculated with + `elephant.spike_train_processing.precise_complexity_intervals`. + * ``complexity_epoch.array_annotations['complexity']`` contains the + complexity values per spike. + * ``complexity_epoch.times`` contains the left edges. + * ``complexity_epoch.durations`` contains the durations. + + See also + -------- + elephant.spike_train_processing.precise_complexity_intervals + + """ + if deletion_threshold is not None and deletion_threshold <= 1: + raise ValueError('A deletion_threshold <= 1 would result ' + 'in deletion of all spikes.') + + complexity = complexity_epoch.array_annotations['complexity'] + right_edges = complexity_epoch.times + complexity_epoch.durations + + # j = index of pre-selected sts in spiketrains + # idx = index of pre-selected sts in original + # block.segments[seg].spiketrains + for idx, st in enumerate(spiketrains): + + # all indices of spikes that are within the half-open intervals + # defined by the boundaries + # note that every second entry in boundaries is an upper boundary + spike_to_epoch_idx = np.searchsorted(right_edges, + st.times.rescale( + right_edges.units)) + complexity_per_spike = complexity[spike_to_epoch_idx] + + st.array_annotate(complexity=complexity_per_spike) + + if deletion_threshold is not None: + mask = complexity_per_spike < deletion_threshold + if invert_delete: + mask = np.invert(mask) + old_st = st + new_st = old_st[mask] + spiketrains[idx] = new_st + unit = old_st.unit + segment = old_st.segment + if unit is not None: + unit.spiketrains[self._get_index(unit.spiketrains, + old_st)] = new_st + if segment is not None: + segment.spiketrains[self._get_index(segment.spiketrains, + old_st)] = new_st + del old_st + + return complexity_epoch + + def _get_index(lst, obj): + for index, item in enumerate(lst): + if item is obj: + return index From b002cd5046aae79bfae25a09101223f0cb008b2a Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Fri, 19 Jun 2020 18:08:38 +0200 Subject: [PATCH 22/58] Move complexity class to statistics move documentation to more appropriate locations add deprecation warning to complexity_pdf start implementing time_histogram for spread > 0 and epoch for spread = 0 --- elephant/spike_train_processing.py | 269 -------------------- elephant/statistics.py | 378 +++++++++++++++++++++++++++++ elephant/test/test_statistics.py | 2 +- 3 files changed, 379 insertions(+), 270 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 368f4d228..9fc80b41c 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -8,276 +8,7 @@ from __future__ import division -import neo -import elephant.conversion as conv import numpy as np -import quantities as pq -from .utils import _check_consistency_of_spiketrainlist -from elephant.statistics import time_histogram -import warnings - - -class complexity: - """ - docstring TODO - - COPIED FROM PREVIOUS GET EPOCHS AS IS: - Calculate the complexity (i.e. number of synchronous spikes found) - at `sampling_rate` precision in a list of spiketrains. 
- - Complexity is calculated by counting the number of spikes (i.e. non-empty - bins) that occur separated by `spread - 1` or less empty bins, within and - across spike trains in the `spiketrains` list. - - Parameters - ---------- - spiketrains : list of neo.SpikeTrains - a list of neo.SpikeTrains objects. These spike trains should have been - recorded simultaneously. - sampling_rate : pq.Quantity - Sampling rate of the spike trains. - spread : int, optional - Number of bins in which to check for synchronous spikes. - Spikes that occur separated by `spread - 1` or less empty bins are - considered synchronous. - * `spread = 0` corresponds to a bincount accross spike trains. - * `spread = 1` corresponds to counting consecutive spikes. - * `spread = 2` corresponds to counting consecutive spikes and - spikes separated by exactly 1 empty bin. - * `spread = n` corresponds to counting spikes separated by exactly - or less than `n - 1` empty bins. - Default: 0 - - Returns - ------- - complexity_intervals : neo.Epoch - An epoch object containing complexity values, left edges and durations - of all intervals with at least one spike. - - Raises - ------ - ValueError - When `t_stop` is smaller than `t_start`. - - Examples - -------- - Here the behavior of - `elephant.spike_train_processing.precise_complexity_intervals` is shown, by - applying the function to some sample spiketrains. - - >>> import neo - >>> import quantities as pq - - >>> sampling_rate = 1/pq.ms - >>> st1 = neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10.0 * pq.ms) - >>> st2 = neo.SpikeTrain([1, 5, 8] * pq.ms, t_stop=10.0 * pq.ms) - - >>> # spread = 0, a simple bincount - >>> ep1 = precise_complexity_intervals([st1, st2], sampling_rate) - >>> print(ep1.array_annotations['complexity'].flatten()) - [2 1 1 1 1] - >>> print(ep1.times) - [0. 3.5 4.5 5.5 7.5] ms - >>> print(ep1.durations) - [1.5 1. 1. 1. 1. ] ms - - >>> # spread = 1, consecutive spikes - >>> ep2 = precise_complexity_intervals([st1, st2], sampling_rate, spread=1) - >>> print(ep2.array_annotations['complexity'].flatten()) - [2 3 1] - >>> print(ep2.times) - [0. 3.5 7.5] ms - >>> print(ep2.durations) - [1.5 3. 1. ] ms - - >>> # spread = 2, consecutive spikes and separated by 1 empty bin - >>> ep3 = precise_complexity_intervals([st1, st2], sampling_rate, spread=2) - >>> print(ep3.array_annotations['complexity'].flatten()) - [2 4] - >>> print(ep3.times) - [0. 3.5] ms - >>> print(ep3.durations) - [1.5 5. 
] ms - """ - - def __init__(self, spiketrains, - sampling_rate=None, - bin_size=None, - binary=True, - spread=0): - - if isinstance(spiketrains, list): - _check_consistency_of_spiketrainlist(spiketrains) - else: - raise TypeError('spiketrains should be a list of neo.SpikeTrain') - self.input_spiketrains = spiketrains - self.sampling_rate = sampling_rate - self.bin_size = bin_size - self.binary = binary - self.spread = spread - - if bin_size is None and sampling_rate is None: - raise ValueError('No bin_size or sampling_rate was speficied!') - elif bin_size is None and sampling_rate is not None: - self.bin_size = 1 / self.sampling_rate - - if spread < 0: - raise ValueError('Spread must be >=0') - elif spread == 0: - self.time_histogram, self.histogram = self._histogram_no_spread() - else: - print('Complexity calculated at sampling rate precision') - # self.epoch = self.precise_complexity_intervals() - self.epoch = self.get_epoch() - self.histogram = self._histogram_with_spread() - - return self - - @property - def pdf(self): - """ - Normalization of the Complexity Histogram (probabilty distribution) - """ - norm_hist = self.histogram / self.histogram.sum() - # Convert the Complexity pdf to an neo.AnalogSignal - pdf = neo.AnalogSignal( - np.array(norm_hist).reshape(len(norm_hist), 1) * - pq.dimensionless, t_start=0 * pq.dimensionless, - sampling_period=1 * pq.dimensionless) - return pdf - - # @property - # def epoch(self): - # if self.spread == 0: - # warnings.warn('No epoch for cases with spread = 0') - # return None - # else: - # return self._epoch - - def _histogram_no_spread(self): - """ - Complexity Distribution of a list of `neo.SpikeTrain` objects. - - Probability density computed from the complexity histogram which is the - histogram of the entries of the population histogram of clipped - (binary) spike trains computed with a bin width of `binsize`. - It provides for each complexity (== number of active neurons per bin) - the number of occurrences. The normalization of that histogram to 1 is - the probability density. - - Implementation is based on [1]_. - - Returns - ------- - complexity_distribution : neo.AnalogSignal - A `neo.AnalogSignal` object containing the histogram values. - `neo.AnalogSignal[j]` is the histogram computed between - `t_start + j * binsize` and `t_start + (j + 1) * binsize`. - - See also - -------- - elephant.conversion.BinnedSpikeTrain - - References - ---------- - .. [1] S. Gruen, M. Abeles, & M. Diesmann, "Impact of higher-order - correlations on coincidence distributions of massively parallel - data," In "Dynamic Brain - from Neural Spikes to Behaviors", - pp. 96-114, Springer Berlin Heidelberg, 2008. - - """ - # Computing the population histogram with parameter binary=True to - # clip the spike trains before summing - pophist = time_histogram(self.input_spiketrains, - self.bin_size, - binary=self.binary) - - # Computing the histogram of the entries of pophist - complexity_hist = np.histogram( - pophist.magnitude, - bins=range(0, len(self.input_spiketrains) + 2))[0] - - return pophist, complexity_hist - - def _histogram_with_spread(self): - """ - Calculate the complexity histogram; - the number of occurrences of events of different complexities. - - Returns - ------- - complexity_histogram : np.ndarray - A histogram of complexities. `complexity_histogram[i]` corresponds - to the number of events of complexity `i` for `i > 0`. 
- """ - complexity_histogram = np.bincount( - self.epoch.array_annotations['complexity']) - return complexity_histogram - - def get_epoch(self): - bst = conv.BinnedSpikeTrain(self.input_spiketrains, - binsize=self.bin_size) - - if self.binary: - binarized = bst.to_sparse_bool_array() - bincount = np.array(binarized.sum(axis=0)).squeeze() - else: - bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() - - i = 0 - complexities = [] - left_edges = [] - right_edges = [] - while i < len(bincount): - current_bincount = bincount[i] - if current_bincount == 0: - i += 1 - else: - last_window_sum = current_bincount - last_nonzero_index = 0 - current_window = bincount[i:i + self.spread + 1] - window_sum = current_window.sum() - while window_sum > last_window_sum: - last_nonzero_index = np.nonzero(current_window)[0][-1] - current_window = bincount[i: - i + last_nonzero_index - + self.spread + 1] - last_window_sum = window_sum - window_sum = current_window.sum() - complexities.append(window_sum) - left_edges.append( - bst.bin_edges[i].magnitude.item()) - right_edges.append( - bst.bin_edges[ - i + last_nonzero_index + 1 - ].magnitude.item()) - i += last_nonzero_index + 1 - - # we dropped units above, neither concatenate nor append works - # with arrays of quantities - left_edges *= bst.bin_edges.units - right_edges *= bst.bin_edges.units - - if self.sampling_rate: - # ensure that spikes are not on the bin edges - bin_shift = .5 / self.sampling_rate - left_edges -= bin_shift - right_edges -= bin_shift - else: - warnings.warn('No sampling rate specified. ' - 'Note that using the complexity epoch to get ' - 'precise spike times can lead to rounding errors.') - - # ensure that an epoch does not start before the minimum t_start - min_t_start = min([st.t_start for st in self.input_spiketrains]) - left_edges[0] = min(min_t_start, left_edges[0]) - - complexity_epoch = neo.Epoch(times=left_edges, - durations=right_edges - left_edges, - array_annotations={'complexity': - complexities}) - - return complexity_epoch class synchrotool(complexity): diff --git a/elephant/statistics.py b/elephant/statistics.py index 70dc03197..9832a893e 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -763,6 +763,384 @@ def time_histogram(spiketrains, binsize, t_start=None, t_stop=None, t_start=t_start) +def complexity_pdf(spiketrains, binsize): + """ + Deprecated in favor of the complexity class which has a pdf attribute. + Will be removed in the next release! + + Complexity Distribution of a list of `neo.SpikeTrain` objects. + + Probability density computed from the complexity histogram which is the + histogram of the entries of the population histogram of clipped (binary) + spike trains computed with a bin width of `binsize`. + It provides for each complexity (== number of active neurons per bin) the + number of occurrences. The normalization of that histogram to 1 is the + probability density. + + Implementation is based on [1]_. + + Parameters + ---------- + spiketrains : list of neo.SpikeTrain + Spike trains with a common time axis (same `t_start` and `t_stop`) + binsize : pq.Quantity + Width of the histogram's time bins. + + Returns + ------- + complexity_distribution : neo.AnalogSignal + A `neo.AnalogSignal` object containing the histogram values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. + + See also + -------- + elephant.conversion.BinnedSpikeTrain + + References + ---------- + .. [1] S. Gruen, M. Abeles, & M. 
Diesmann, "Impact of higher-order + correlations on coincidence distributions of massively parallel + data," In "Dynamic Brain - from Neural Spikes to Behaviors", + pp. 96-114, Springer Berlin Heidelberg, 2008. + + """ + warnings.warn("complexity_pdf is deprecated in favor of the complexity " + "class which has a pdf attribute. complexity_pdf will be " + "removed in the next Elephant release.", DeprecationWarning) + + complexity_obj = complexity(spiketrains, bin_size=binsize) + + return complexity_obj.pdf + + +class complexity: + """ + Complexity Distribution of a list of `neo.SpikeTrain` objects. + docstring TODO + + COPIED FROM PREVIOUS GET EPOCHS AS IS: + Calculate the complexity (i.e. number of synchronous spikes found) + at `sampling_rate` precision in a list of spiketrains. + + Complexity is calculated by counting the number of spikes (i.e. non-empty + bins) that occur separated by `spread - 1` or less empty bins, within and + across spike trains in the `spiketrains` list. + + Implementation is based on [1]_. + + Parameters + ---------- + spiketrains : list of neo.SpikeTrain + Spike trains with a common time axis (same `t_start` and `t_stop`) + bin_size : pq.Quantity + Width of the histogram's time bins. + sampling_rate : pq.Quantity + Sampling rate of the spike trains. + spread : int, optional + Number of bins in which to check for synchronous spikes. + Spikes that occur separated by `spread - 1` or less empty bins are + considered synchronous. + * `spread = 0` corresponds to a bincount accross spike trains. + * `spread = 1` corresponds to counting consecutive spikes. + * `spread = 2` corresponds to counting consecutive spikes and + spikes separated by exactly 1 empty bin. + * `spread = n` corresponds to counting spikes separated by exactly + or less than `n - 1` empty bins. + Default: 0 + + Raises + ------ + ValueError + When `t_stop` is smaller than `t_start`. + + See also + -------- + elephant.conversion.BinnedSpikeTrain + + References + ---------- + .. [1] S. Gruen, M. Abeles, & M. Diesmann, "Impact of higher-order + correlations on coincidence distributions of massively parallel + data," In "Dynamic Brain - from Neural Spikes to Behaviors", + pp. 96-114, Springer Berlin Heidelberg, 2008. + + Examples + -------- + Here the behavior of + `elephant.spike_train_processing.precise_complexity_intervals` is shown, by + applying the function to some sample spiketrains. + + >>> import neo + >>> import quantities as pq + + >>> sampling_rate = 1/pq.ms + >>> st1 = neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10.0 * pq.ms) + >>> st2 = neo.SpikeTrain([1, 5, 8] * pq.ms, t_stop=10.0 * pq.ms) + + >>> # spread = 0, a simple bincount + >>> ep1 = precise_complexity_intervals([st1, st2], sampling_rate) + >>> print(ep1.array_annotations['complexity'].flatten()) + [2 1 1 1 1] + >>> print(ep1.times) + [0. 3.5 4.5 5.5 7.5] ms + >>> print(ep1.durations) + [1.5 1. 1. 1. 1. ] ms + + >>> # spread = 1, consecutive spikes + >>> ep2 = precise_complexity_intervals([st1, st2], sampling_rate, spread=1) + >>> print(ep2.array_annotations['complexity'].flatten()) + [2 3 1] + >>> print(ep2.times) + [0. 3.5 7.5] ms + >>> print(ep2.durations) + [1.5 3. 1. ] ms + + >>> # spread = 2, consecutive spikes and separated by 1 empty bin + >>> ep3 = precise_complexity_intervals([st1, st2], sampling_rate, spread=2) + >>> print(ep3.array_annotations['complexity'].flatten()) + [2 4] + >>> print(ep3.times) + [0. 3.5] ms + >>> print(ep3.durations) + [1.5 5. 
] ms
+    """
+
+    def __init__(self, spiketrains,
+                 sampling_rate=None,
+                 bin_size=None,
+                 binary=True,
+                 spread=0,
+                 tolerance=1e-8):
+
+        if isinstance(spiketrains, list):
+            _check_consistency_of_spiketrainlist(spiketrains)
+        else:
+            raise TypeError('spiketrains should be a list of neo.SpikeTrain')
+        self.input_spiketrains = spiketrains
+        self.t_start = spiketrains[0].t_start
+        self.t_stop = spiketrains[0].t_stop
+        self.sampling_rate = sampling_rate
+        self.bin_size = bin_size
+        self.binary = binary
+        self.spread = spread
+
+        if bin_size is None and sampling_rate is None:
+            raise ValueError('No bin_size or sampling_rate was specified!')
+        elif bin_size is None and sampling_rate is not None:
+            self.bin_size = 1 / self.sampling_rate
+
+        if spread < 0:
+            raise ValueError('Spread must be >=0')
+        elif spread == 0:
+            self.time_histogram, self.histogram = self._histogram_no_spread()
+            self.epoch = self._epoch_no_spread()
+        else:
+            print('Complexity calculated at sampling rate precision')
+            self.epoch = self.get_epoch()
+            self.time_histogram, self.histogram = self._histogram_with_spread()
+
+    @property
+    def pdf(self):
+        """
+        Probability density computed from the complexity histogram which is the
+        histogram of the entries of the population histogram of
+        spike trains computed with a bin width of `complexity.bin_size`.
+        It provides for each complexity (== number of active neurons per bin)
+        the number of occurrences. The normalization of that histogram to 1 is
+        the probability density.
+
+        Returns
+        -------
+        complexity_distribution : neo.AnalogSignal
+            A `neo.AnalogSignal` object containing the histogram values.
+            `neo.AnalogSignal[j]` is the histogram computed between
+            `t_start + j * binsize` and `t_start + (j + 1) * binsize`.
+        """
+        norm_hist = self.histogram / self.histogram.sum()
+        # Convert the Complexity pdf to a neo.AnalogSignal
+        pdf = neo.AnalogSignal(
+            np.array(norm_hist).reshape(len(norm_hist), 1) *
+            pq.dimensionless, t_start=0 * pq.dimensionless,
+            sampling_period=1 * pq.dimensionless)
+        return pdf
+
+    def _histogram_no_spread(self):
+        """
+        Complexity Distribution of a list of `neo.SpikeTrain` objects.
+
+        Probability density computed from the complexity histogram which is the
+        histogram of the entries of the population histogram of clipped
+        (binary) spike trains computed with a bin width of `binsize`.
+        It provides for each complexity (== number of active neurons per bin)
+        the number of occurrences. The normalization of that histogram to 1 is
+        the probability density.
+
+        Implementation is based on [1]_.
+
+        Returns
+        -------
+        complexity_distribution : neo.AnalogSignal
+            A `neo.AnalogSignal` object containing the histogram values.
+            `neo.AnalogSignal[j]` is the histogram computed between
+            `t_start + j * binsize` and `t_start + (j + 1) * binsize`.
+
+        See also
+        --------
+        elephant.conversion.BinnedSpikeTrain
+
+        References
+        ----------
+        .. [1] S. Gruen, M. Abeles, & M. Diesmann, "Impact of higher-order
+               correlations on coincidence distributions of massively parallel
+               data," In "Dynamic Brain - from Neural Spikes to Behaviors",
+               pp. 96-114, Springer Berlin Heidelberg, 2008.
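
In plain numpy terms, this helper amounts to binning the clipped spike trains
into a population histogram and then counting how often each complexity value
occurs. A rough illustration on a hypothetical population histogram of three
spike trains (a sketch only, not the actual implementation, which works on
neo objects via `time_histogram`):

    import numpy as np

    # number of active neurons per time bin (clipped population histogram)
    pophist = np.array([0, 2, 0, 3, 1, 0])
    n_spiketrains = 3

    # complexity_hist[k] = number of bins in which exactly k neurons fired
    complexity_hist, _ = np.histogram(pophist, bins=range(n_spiketrains + 2))
    # -> array([3, 1, 1, 1])

Normalising this count vector to unit sum yields the `pdf` property shown
above.
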
+ + """ + # Computing the population histogram with parameter binary=True to + # clip the spike trains before summing + pophist = time_histogram(self.input_spiketrains, + self.bin_size, + binary=self.binary) + + # Computing the histogram of the entries of pophist + complexity_hist = np.histogram( + pophist.magnitude, + bins=range(0, len(self.input_spiketrains) + 2))[0] + + return pophist, complexity_hist + + def _histogram_with_spread(self): + """ + Calculate the complexity histogram; + the number of occurrences of events of different complexities. + + Returns + ------- + complexity_histogram : np.ndarray + A histogram of complexities. `complexity_histogram[i]` corresponds + to the number of events of complexity `i` for `i > 0`. + """ + complexity_histogram = np.bincount( + self.epoch.array_annotations['complexity']) + num_bins = ((self.t_stop - self.t_start).rescale( + self.bin_size.units) / self.bin_size.magnitude).item() + if conv._detect_rounding_errors(num_bins, tolerance=self.tolerance): + warnings.warn('Correcting a rounding error in the histogram ' + 'calculation by increasing num_bins by 1. ' + 'You can set tolerance=None to disable this ' + 'behaviour.') + num_bins += 1 + time_histogram = np.zeros((num_bins, ), dtype=int) + + start_bins = ((self.epoch.times - self.t_start).rescale( + self.binsize.units) / self.binsize).magnitude.flatten() + stop_bins = ((self.epoch.times + self.epoch.durations + - self.t_start).rescale( + self.binsize.units) / self.binsize).magnitude.flatten() + + rounding_error_indices = conv._detect_rounding_errors(start_bins, + self.tolerance) + + num_rounding_corrections = rounding_error_indices.sum() + if num_rounding_corrections > 0: + warnings.warn('Correcting {} rounding errors by shifting ' + 'the affected spikes into the following bin. ' + 'You can set tolerance=None to disable this ' + 'behaviour.'.format(num_rounding_corrections)) + start_bins[rounding_error_indices] += .5 + + start_bins = start_bins.astype(int) + + rounding_error_indices = conv._detect_rounding_errors(stop_bins, + self.tolerance) + + num_rounding_corrections = rounding_error_indices.sum() + if num_rounding_corrections > 0: + warnings.warn('Correcting {} rounding errors by shifting ' + 'the affected spikes into the following bin. 
' + 'You can set tolerance=None to disable this ' + 'behaviour.'.format(num_rounding_corrections)) + stop_bins[rounding_error_indices] += .5 + + stop_bins = stop_bins.astype(int) + + for idx, (start, stop) in enumerate(zip(start_bins, stop_bins)): + time_histogram[start:stop] = \ + self.epoch.array_annotations[idx]['compexity'] + + return time_histogram, complexity_histogram + + def _epoch_no_spread(self): + # TODO + return epoch + + def get_epoch(self): + bst = conv.BinnedSpikeTrain(self.input_spiketrains, + binsize=self.bin_size) + + if self.binary: + binarized = bst.to_sparse_bool_array() + bincount = np.array(binarized.sum(axis=0)).squeeze() + else: + bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() + + i = 0 + complexities = [] + left_edges = [] + right_edges = [] + while i < len(bincount): + current_bincount = bincount[i] + if current_bincount == 0: + i += 1 + else: + last_window_sum = current_bincount + last_nonzero_index = 0 + current_window = bincount[i:i + self.spread + 1] + window_sum = current_window.sum() + while window_sum > last_window_sum: + last_nonzero_index = np.nonzero(current_window)[0][-1] + current_window = bincount[i: + i + last_nonzero_index + + self.spread + 1] + last_window_sum = window_sum + window_sum = current_window.sum() + complexities.append(window_sum) + left_edges.append( + bst.bin_edges[i].magnitude.item()) + right_edges.append( + bst.bin_edges[ + i + last_nonzero_index + 1 + ].magnitude.item()) + i += last_nonzero_index + 1 + + # we dropped units above, neither concatenate nor append works + # with arrays of quantities + left_edges *= bst.bin_edges.units + right_edges *= bst.bin_edges.units + + if self.sampling_rate: + # ensure that spikes are not on the bin edges + bin_shift = .5 / self.sampling_rate + left_edges -= bin_shift + right_edges -= bin_shift + else: + warnings.warn('No sampling rate specified. ' + 'Note that using the complexity epoch to get ' + 'precise spike times can lead to rounding errors.') + + # ensure that an epoch does not start before the minimum t_start + min_t_start = min([st.t_start for st in self.input_spiketrains]) + left_edges[0] = min(min_t_start, left_edges[0]) + + complexity_epoch = neo.Epoch(times=left_edges, + durations=right_edges - left_edges, + array_annotations={'complexity': + complexities}) + + return complexity_epoch + + """ Kernel Bandwidth Optimization. diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index f0749f866..b5c7416d0 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -353,7 +353,7 @@ def test_2short_spike_train(self): an input with more than 1 entry. 
""" self.assertTrue(math.isnan(es.lv(seq, with_nan=True))) - + class CV2TestCase(unittest.TestCase): def setUp(self): From 6e360bba89e35d1ab5ff382c8da17af5435ac426 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 24 Jun 2020 13:04:07 +0200 Subject: [PATCH 23/58] Test whether spiketrains is a list --- elephant/statistics.py | 6 ++---- elephant/test/test_utils.py | 15 +++++++++------ elephant/utils.py | 8 +++++--- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 9832a893e..383c32bb0 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -913,10 +913,8 @@ def __init__(self, spiketrains, spread=0, tolerance=1e-8): - if isinstance(spiketrains, list): - _check_consistency_of_spiketrainlist(spiketrains) - else: - raise TypeError('spiketrains should be a list of neo.SpikeTrain') + _check_consistency_of_spiketrainlist(spiketrains) + self.input_spiketrains = spiketrains self.t_start = spiketrains[0].t_start self.t_stop = spiketrains[0].t_stop diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py index c6063963e..4474752cb 100644 --- a/elephant/test/test_utils.py +++ b/elephant/test/test_utils.py @@ -16,15 +16,18 @@ class checkSpiketrainTestCase(unittest.TestCase): def test_wrong_input_errors(self): self.assertRaises(ValueError, - utils._check_consistency_of_spiketrains, + utils._check_consistency_of_spiketrainlist, [], 1 / pq.s) self.assertRaises(TypeError, - utils._check_consistency_of_spiketrains, + utils._check_consistency_of_spiketrainlist, + neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)) + self.assertRaises(TypeError, + utils._check_consistency_of_spiketrainlist, [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), np.arange(2)], 1 / pq.s) self.assertRaises(ValueError, - utils._check_consistency_of_spiketrains, + utils._check_consistency_of_spiketrainlist, [neo.SpikeTrain([1]*pq.s, t_start=1*pq.s, t_stop=2*pq.s), @@ -33,13 +36,13 @@ def test_wrong_input_errors(self): t_stop=2*pq.s)], same_t_start=True) self.assertRaises(ValueError, - utils._check_consistency_of_spiketrains, + utils._check_consistency_of_spiketrainlist, [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), neo.SpikeTrain([1]*pq.s, t_stop=3*pq.s)], same_t_stop=True) self.assertRaises(ValueError, - utils._check_consistency_of_spiketrains, - [neo.SpikeTrain([1]*pq.ms, t_stop=2*pq.s), + utils._check_consistency_of_spiketrainlist, + [neo.SpikeTrain([1]*pq.ms, t_stop=2000*pq.ms), neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], same_units=True) diff --git a/elephant/utils.py b/elephant/utils.py index 9f60d27e1..d4f708466 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -22,12 +22,14 @@ def is_binary(array): def _check_consistency_of_spiketrainlist(spiketrains, - same_t_start=None, - same_t_stop=None, + same_t_start=False, + same_t_stop=False, same_units=False): """ Private function to check lists of spiketrains. 
""" + if not isinstance(spiketrains, list): + raise TypeError('spiketrains should be a list of neo.SpikeTrain') if len(spiketrains) == 0: raise ValueError('The spiketrains list is empty!') for st in spiketrains: @@ -42,5 +44,5 @@ def _check_consistency_of_spiketrainlist(spiketrains, if same_t_stop and not st.t_stop == spiketrains[0].t_stop: raise ValueError( "the spike trains must have the same t_stop!") - if same_units and not st.units == st[0].units: + if same_units and not st.units == spiketrains[0].units: raise ValueError('The spike trains must have the same units!') From 13867993dc27be213f33891611811f350b9a3752 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 24 Jun 2020 13:04:35 +0200 Subject: [PATCH 24/58] Finish complexity class (minus docs) --- elephant/statistics.py | 45 +++++++- elephant/test/test_statistics.py | 176 ++++++++++++++++++++++++++++--- 2 files changed, 200 insertions(+), 21 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 383c32bb0..05f85a4ed 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -922,6 +922,7 @@ def __init__(self, spiketrains, self.bin_size = bin_size self.binary = binary self.spread = spread + self.tolerance = tolerance if bin_size is None and sampling_rate is None: raise ValueError('No bin_size or sampling_rate was speficied!') @@ -1029,13 +1030,25 @@ def _histogram_with_spread(self): 'You can set tolerance=None to disable this ' 'behaviour.') num_bins += 1 + num_bins = int(num_bins) time_histogram = np.zeros((num_bins, ), dtype=int) start_bins = ((self.epoch.times - self.t_start).rescale( - self.binsize.units) / self.binsize).magnitude.flatten() + self.bin_size.units) / self.bin_size).magnitude.flatten() stop_bins = ((self.epoch.times + self.epoch.durations - self.t_start).rescale( - self.binsize.units) / self.binsize).magnitude.flatten() + self.bin_size.units) / self.bin_size).magnitude.flatten() + + if self.sampling_rate is not None: + shift = (.5 / self.sampling_rate / self.bin_size + ).simplified.magnitude.item() + # account for the first bin not being shifted in the epoch creation + # if the shift would move it past t_start + if self.epoch.times[0] == self.t_start: + start_bins[1:] += shift + else: + start_bins += shift + stop_bins += shift rounding_error_indices = conv._detect_rounding_errors(start_bins, self.tolerance) @@ -1065,12 +1078,34 @@ def _histogram_with_spread(self): for idx, (start, stop) in enumerate(zip(start_bins, stop_bins)): time_histogram[start:stop] = \ - self.epoch.array_annotations[idx]['compexity'] + self.epoch.array_annotations['complexity'][idx] + + time_histogram = neo.AnalogSignal( + signal=time_histogram.reshape(time_histogram.size, 1), + sampling_period=self.bin_size, units=pq.dimensionless, + t_start=self.t_start) + + empty_bins = (self.t_stop - self.t_start - self.epoch.durations.sum()) + empty_bins = empty_bins.rescale(self.bin_size.units) / self.bin_size + if conv._detect_rounding_errors(empty_bins, tolerance=self.tolerance): + warnings.warn('Correcting a rounding error in the histogram ' + 'calculation by increasing num_bins by 1. 
' + 'You can set tolerance=None to disable this ' + 'behaviour.') + empty_bins += 1 + empty_bins = int(empty_bins) + + complexity_histogram[0] = empty_bins return time_histogram, complexity_histogram def _epoch_no_spread(self): - # TODO + epoch = neo.Epoch(self.time_histogram.times, + durations=self.bin_size + * np.ones(self.time_histogram.shape), + array_annotations={ + 'complexity': + self.time_histogram.magnitude.flatten()}) return epoch def get_epoch(self): @@ -1129,7 +1164,7 @@ def get_epoch(self): # ensure that an epoch does not start before the minimum t_start min_t_start = min([st.t_start for st in self.input_spiketrains]) - left_edges[0] = min(min_t_start, left_edges[0]) + left_edges[0] = max(min_t_start, left_edges[0]) complexity_epoch = neo.Epoch(times=left_edges, durations=right_edges - left_edges, diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index b5c7416d0..6950642c7 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -641,31 +641,175 @@ def test_time_histogram_output(self): class ComplexityPdfTestCase(unittest.TestCase): - def setUp(self): - self.spiketrain_a = neo.SpikeTrain( + def test_complexity_pdf_old(self): + spiketrain_a = neo.SpikeTrain( [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) - self.spiketrain_b = neo.SpikeTrain( + spiketrain_b = neo.SpikeTrain( [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) - self.spiketrain_c = neo.SpikeTrain( + spiketrain_c = neo.SpikeTrain( [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) - self.spiketrains = [ - self.spiketrain_a, self.spiketrain_b, self.spiketrain_c] - - def tearDown(self): - del self.spiketrain_a - self.spiketrain_a = None - del self.spiketrain_b - self.spiketrain_b = None - - def test_complexity_pdf(self): + spiketrains = [ + spiketrain_a, spiketrain_b, spiketrain_c] + # runs the previous function which will be deprecated targ = np.array([0.92, 0.01, 0.01, 0.06]) - complexity = es.complexity_pdf(self.spiketrains, binsize=0.1*pq.s) + complexity = es.complexity_pdf(spiketrains, binsize=0.1*pq.s) assert_array_equal(targ, complexity.magnitude[:, 0]) self.assertEqual(1, complexity.magnitude[:, 0].sum()) - self.assertEqual(len(self.spiketrains)+1, len(complexity)) + self.assertEqual(len(spiketrains)+1, len(complexity)) self.assertIsInstance(complexity, neo.AnalogSignal) self.assertEqual(complexity.units, 1*pq.dimensionless) + def test_complexity_pdf(self): + spiketrain_a = neo.SpikeTrain( + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) + spiketrain_b = neo.SpikeTrain( + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) + spiketrain_c = neo.SpikeTrain( + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) + spiketrains = [ + spiketrain_a, spiketrain_b, spiketrain_c] + # runs the previous function which will be deprecated + targ = np.array([0.92, 0.01, 0.01, 0.06]) + complexity_obj = es.complexity(spiketrains, bin_size=0.1*pq.s) + pdf = complexity_obj.pdf + assert_array_equal(targ, complexity_obj.pdf.magnitude[:, 0]) + self.assertEqual(1, pdf.magnitude[:, 0].sum()) + self.assertEqual(len(spiketrains)+1, len(pdf)) + self.assertIsInstance(pdf, neo.AnalogSignal) + self.assertEqual(pdf.units, 1*pq.dimensionless) + + def test_complexity_histogram_spread_0(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, + t_stop=20*pq.s)] + + correct_histogram = 
np.array([10, 8, 2]) + + correct_time_histogram = np.array([0, 2, 0, 0, 1, 1, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 2, 0, 1, 1]) + + complexity_obj = es.complexity(spiketrains, + sampling_rate=sampling_rate, + spread=0) + + assert_array_equal(complexity_obj.histogram, + correct_histogram) + + assert_array_equal( + complexity_obj.time_histogram.magnitude.flatten().astype(int), + correct_time_histogram) + + def test_complexity_epoch_spread_0(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, + t_stop=20*pq.s)] + + complexity_obj = es.complexity(spiketrains, + sampling_rate=sampling_rate, + spread=0) + + self.assertIsInstance(complexity_obj.epoch, + neo.Epoch) + + def test_complexity_histogram_spread_1(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([0, 1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_histogram = np.array([9, 5, 1, 2]) + + correct_time_histogram = np.array([3, 3, 0, 0, 2, 2, 0, 1, 0, 1, 0, + 3, 3, 3, 0, 0, 1, 0, 1, 0, 1]) + + complexity_obj = es.complexity(spiketrains, + sampling_rate=sampling_rate, + spread=1) + + assert_array_equal(complexity_obj.histogram, + correct_histogram) + + assert_array_equal( + complexity_obj.time_histogram.magnitude.flatten().astype(int), + correct_time_histogram) + + def test_complexity_histogram_spread_2(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_histogram = np.array([5, 0, 1, 1, 0, 0, 0, 1]) + + correct_time_histogram = np.array([0, 2, 0, 0, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 0, 0, 3, 3, 3, 3, 3]) + + complexity_obj = es.complexity(spiketrains, + sampling_rate=sampling_rate, + spread=2) + + assert_array_equal(complexity_obj.histogram, + correct_histogram) + + assert_array_equal( + complexity_obj.time_histogram.magnitude.flatten().astype(int), + correct_time_histogram) + + def test_wrong_input_errors(self): + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + self.assertRaises(ValueError, + es.complexity, + spiketrains) + + self.assertRaises(ValueError, + es.complexity, + spiketrains, + sampling_rate=1*pq.s, + spread=-7) + + def test_binning_for_input_with_rounding_errors(self): + + # a test with inputs divided by 30000 which leads to rounding errors + # these errors have to be accounted for by proper binning; + # check if we still get the correct result + + sampling_rate = 30000 / pq.s + + spiketrains = [neo.SpikeTrain(np.arange(1000, step=2) * pq.s / 30000, + t_stop=.1 * pq.s), + neo.SpikeTrain(np.arange(2000, step=4) * pq.s / 30000, + t_stop=.1 * pq.s)] + + correct_time_histogram = np.zeros(3000) + correct_time_histogram[:1000:2] = 1 + correct_time_histogram[:2000:4] += 1 + + complexity_obj = es.complexity(spiketrains, + sampling_rate=sampling_rate, + spread=1) + + assert_array_equal( + complexity_obj.time_histogram.magnitude.flatten().astype(int), + correct_time_histogram) + if __name__ == '__main__': unittest.main() From 35c081fdca9d9c84102870e34c8e0c81991255b6 Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 1 Jul 2020 12:02:20 +0200 Subject: [PATCH 25/58] create list of raised errors in util --- elephant/utils.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff 
--git a/elephant/utils.py b/elephant/utils.py index d4f708466..ae6979f39 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -26,7 +26,25 @@ def _check_consistency_of_spiketrainlist(spiketrains, same_t_stop=False, same_units=False): """ - Private function to check lists of spiketrains. + Private function to check the consistency of a list of neo.SpikeTrain + + Raises + ------ + TypeError + When `spiketrains` is not a list. + ValueError + When `spiketrains` is an empty list. + TypeError + When the elements in `spiketrains` are not instances of neo.SpikeTrain + ValueError + When `t_start` is not the same for all spiketrains, + if same_t_start=True + ValueError + When `t_stop` is not the same for all spiketrains, + if same_t_stop=True + ValueError + When `units` are not the same for all spiketrains, + if same_units=True """ if not isinstance(spiketrains, list): raise TypeError('spiketrains should be a list of neo.SpikeTrain') From 95189e79f5fbfd9bb9ca447c61a452592815d8aa Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 1 Jul 2020 12:42:58 +0200 Subject: [PATCH 26/58] Update docstrings --- elephant/statistics.py | 200 +++++++++++++++++++++++------------------ 1 file changed, 111 insertions(+), 89 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 05f85a4ed..7a7adf77c 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -816,47 +816,108 @@ def complexity_pdf(spiketrains, binsize): class complexity: """ - Complexity Distribution of a list of `neo.SpikeTrain` objects. - docstring TODO - - COPIED FROM PREVIOUS GET EPOCHS AS IS: - Calculate the complexity (i.e. number of synchronous spikes found) - at `sampling_rate` precision in a list of spiketrains. + Class for complexity distribution (i.e. number of synchronous spikes found) + of a list of `neo.SpikeTrain` objects. Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur separated by `spread - 1` or less empty bins, within and across spike trains in the `spiketrains` list. - Implementation is based on [1]_. + Implementation (without spread) is based on [1]_. Parameters ---------- spiketrains : list of neo.SpikeTrain Spike trains with a common time axis (same `t_start` and `t_stop`) - bin_size : pq.Quantity - Width of the histogram's time bins. - sampling_rate : pq.Quantity - Sampling rate of the spike trains. + sampling_rate : pq.Quantity, optional + Sampling rate of the spike trains with units of 1/time. + Default: None + bin_size : pq.Quantity, optional + Width of the histogram's time bins with units of time. + The user must specify the `bin_size` or the `sampling_rate`. + * If no `bin_size` is specified and the `sampling_rate` is available + 1/`sampling_rate` is used. + * If both are given then `bin_size` is used. + Default: None + binary : bool, optional + * If `True` then the time histograms will be binary. + * If `False` the total number of synchronous spikes is counted in the + time histogram. + Default: True spread : int, optional Number of bins in which to check for synchronous spikes. Spikes that occur separated by `spread - 1` or less empty bins are considered synchronous. - * `spread = 0` corresponds to a bincount accross spike trains. - * `spread = 1` corresponds to counting consecutive spikes. - * `spread = 2` corresponds to counting consecutive spikes and + * ``spread = 0`` corresponds to a bincount accross spike trains. + * ``spread = 1`` corresponds to counting consecutive spikes. 
+ * ``spread = 2`` corresponds to counting consecutive spikes and spikes separated by exactly 1 empty bin. - * `spread = n` corresponds to counting spikes separated by exactly + * ``spread = n`` corresponds to counting spikes separated by exactly or less than `n - 1` empty bins. Default: 0 + tolerance : float, optional + Tolerance for rounding errors in the binning process and in the input + data + Default: 1e-8 + + Attributes + ---------- + epoch : neo.Epoch + An epoch object containing complexity values, left edges and durations + of all intervals with at least one spike. + * ``epoch.array_annotations['complexity']`` contains the + complexity values per spike. + * ``epoch.times`` contains the left edges. + * ``epoch.durations`` contains the durations. + time_histogram : neo.Analogsignal + A `neo.AnalogSignal` object containing the histogram values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. + * If ``binary = True`` : Number of neurons that spiked in each bin, + regardless of the number of spikes. + * If ``binary = False`` : Number of neurons and spikes per neurons + in each bin. + complexity_histogram : np.ndarray + The number of occurrences of events of different complexities. + `complexity_hist[i]` corresponds to the number of events of + complexity `i` for `i > 0`. + pdf : neo.AnalogSignal + The normalization of `self.complexityhistogram` to 1. + A `neo.AnalogSignal` object containing the pdf values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. Raises ------ ValueError When `t_stop` is smaller than `t_start`. + ValueError + When both `sampling_rate` and `bin_size` are not specified. + ValueError + When `spread` is not a positive integer. + TypeError + When `spiketrains` is not a list. + ValueError + When `spiketrains` is an empty list. + TypeError + When the elements in `spiketrains` are not instances of neo.SpikeTrain + ValueError + When `t_start` is not the same for all spiketrains + ValueError + When `t_stop` is not the same for all spiketrains + + Notes + ----- + * Note that with most common parameter combinations spike times can end up + on bin edges. This makes the binning susceptible to rounding errors which + is accounted for by moving spikes which are within tolerance of the next + bin edge into the following bin. This can be adjusted using the tolerance + parameter and turned off by setting `tolerance=None`. See also -------- elephant.conversion.BinnedSpikeTrain + elephant.spike_train_processing.synchotool References ---------- @@ -873,37 +934,39 @@ class complexity: >>> import neo >>> import quantities as pq + >>> from elephant.statistics import complexity + + >>> sr = 1/pq.ms - >>> sampling_rate = 1/pq.ms >>> st1 = neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10.0 * pq.ms) >>> st2 = neo.SpikeTrain([1, 5, 8] * pq.ms, t_stop=10.0 * pq.ms) + >>> sts = [st1, st2] >>> # spread = 0, a simple bincount - >>> ep1 = precise_complexity_intervals([st1, st2], sampling_rate) - >>> print(ep1.array_annotations['complexity'].flatten()) - [2 1 1 1 1] - >>> print(ep1.times) - [0. 3.5 4.5 5.5 7.5] ms - >>> print(ep1.durations) - [1.5 1. 1. 1. 1. ] ms + >>> cpx = complexity(sts, sampling_rate=sr) + Complexity calculated at sampling rate precision + >>> print(cpx.histogram) + [5 4 1] + >>> print(cpx.time_histogram.flatten()) + [0 2 0 0 1 1 1 0 1 0] dimensionless + >>> print(cpx.time_histogram.times) + [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] 
ms >>> # spread = 1, consecutive spikes - >>> ep2 = precise_complexity_intervals([st1, st2], sampling_rate, spread=1) - >>> print(ep2.array_annotations['complexity'].flatten()) - [2 3 1] - >>> print(ep2.times) - [0. 3.5 7.5] ms - >>> print(ep2.durations) - [1.5 3. 1. ] ms + >>> cpx = complexity(sts, sampling_rate=sr, spread=1) + Complexity calculated at sampling rate precision + >>> print(cpx.histogram) + [5 4 1] + >>> print(cpx.time_histogram.flatten()) + [0 2 0 0 3 3 3 0 1 0] dimensionless >>> # spread = 2, consecutive spikes and separated by 1 empty bin - >>> ep3 = precise_complexity_intervals([st1, st2], sampling_rate, spread=2) - >>> print(ep3.array_annotations['complexity'].flatten()) - [2 4] - >>> print(ep3.times) - [0. 3.5] ms - >>> print(ep3.durations) - [1.5 5. ] ms + >>> cpx = complexity(sts, sampling_rate=sr, spread=2) + Complexity calculated at sampling rate precision + >>> print(cpx.histogram) + [4 0 1 0 1] + >>> print(cpx.time_histogram.flatten()) + [0 2 0 0 4 4 4 4 4 0] dimensionless """ def __init__(self, spiketrains, @@ -936,25 +999,13 @@ def __init__(self, spiketrains, self.epoch = self._epoch_no_spread() else: print('Complexity calculated at sampling rate precision') - self.epoch = self.get_epoch() + self.epoch = self._epoch_with_spread() self.time_histogram, self.histogram = self._histogram_with_spread() @property def pdf(self): """ - Probability density computed from the complexity histogram which is the - histogram of the entries of the population histogram of - spike trains computed with a bin width of `complexity.bin_size`. - It provides for each complexity (== number of active neurons per bin) - the number of occurrences. The normalization of that histogram to 1 is - the probability density. - - Returns - ------- - complexity_distribution : neo.AnalogSignal - A `neo.AnalogSignal` object containing the histogram values. - `neo.AnalogSignal[j]` is the histogram computed between - `t_start + j * binsize` and `t_start + (j + 1) * binsize`. + Probability density computed from the complexity histogram. """ norm_hist = self.histogram / self.histogram.sum() # Convert the Complexity pdf to an neo.AnalogSignal @@ -966,35 +1017,7 @@ def pdf(self): def _histogram_no_spread(self): """ - Complexity Distribution of a list of `neo.SpikeTrain` objects. - - Probability density computed from the complexity histogram which is the - histogram of the entries of the population histogram of clipped - (binary) spike trains computed with a bin width of `binsize`. - It provides for each complexity (== number of active neurons per bin) - the number of occurrences. The normalization of that histogram to 1 is - the probability density. - - Implementation is based on [1]_. - - Returns - ------- - complexity_distribution : neo.AnalogSignal - A `neo.AnalogSignal` object containing the histogram values. - `neo.AnalogSignal[j]` is the histogram computed between - `t_start + j * binsize` and `t_start + (j + 1) * binsize`. - - See also - -------- - elephant.conversion.BinnedSpikeTrain - - References - ---------- - .. [1] S. Gruen, M. Abeles, & M. Diesmann, "Impact of higher-order - correlations on coincidence distributions of massively parallel - data," In "Dynamic Brain - from Neural Spikes to Behaviors", - pp. 96-114, Springer Berlin Heidelberg, 2008. 
- + Calculate the complexity histogram and time histogram for `spread` = 0 """ # Computing the population histogram with parameter binary=True to # clip the spike trains before summing @@ -1011,14 +1034,7 @@ def _histogram_no_spread(self): def _histogram_with_spread(self): """ - Calculate the complexity histogram; - the number of occurrences of events of different complexities. - - Returns - ------- - complexity_histogram : np.ndarray - A histogram of complexities. `complexity_histogram[i]` corresponds - to the number of events of complexity `i` for `i > 0`. + Calculate the complexity histogram and time histogram for `spread` > 0 """ complexity_histogram = np.bincount( self.epoch.array_annotations['complexity']) @@ -1100,6 +1116,9 @@ def _histogram_with_spread(self): return time_histogram, complexity_histogram def _epoch_no_spread(self): + """ + Get an epoch object of the complexity distribution with `spread` = 0 + """ epoch = neo.Epoch(self.time_histogram.times, durations=self.bin_size * np.ones(self.time_histogram.shape), @@ -1108,7 +1127,10 @@ def _epoch_no_spread(self): self.time_histogram.magnitude.flatten()}) return epoch - def get_epoch(self): + def _epoch_with_spread(self): + """ + Get an epoch object of the complexity distribution with `spread` > 0 + """ bst = conv.BinnedSpikeTrain(self.input_spiketrains, binsize=self.bin_size) From 385ed2aa78daaf98c989a9c32b6c0520c586d383 Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 1 Jul 2020 12:43:44 +0200 Subject: [PATCH 27/58] Fix time_histogram not being returned as list of int (by not multiplying directly with pq.dimensionless) --- elephant/statistics.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 7a7adf77c..7605fcc8c 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -748,18 +748,21 @@ def time_histogram(spiketrains, binsize, t_start=None, t_stop=None, # Renormalise the histogram if output == 'counts': # Raw - bin_hist = bin_hist * pq.dimensionless + bin_hist = bin_hist + units = pq.dimensionless elif output == 'mean': # Divide by number of input spike trains - bin_hist = bin_hist * 1. / len(spiketrains) * pq.dimensionless + bin_hist = bin_hist * 1. / len(spiketrains) + units = pq.dimensionless elif output == 'rate': # Divide by number of input spike trains and bin width bin_hist = bin_hist * 1. 
/ len(spiketrains) / binsize + units = bin_hist.units else: raise ValueError('Parameter output is not valid.') return neo.AnalogSignal(signal=bin_hist.reshape(bin_hist.size, 1), - sampling_period=binsize, units=bin_hist.units, + sampling_period=binsize, units=units, t_start=t_start) From db9a299649509a04c12b95922caba3b708ca2efe Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 1 Jul 2020 12:44:17 +0200 Subject: [PATCH 28/58] enforce t_start and t_stop checks --- elephant/statistics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 7605fcc8c..00f86eb93 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -979,7 +979,9 @@ def __init__(self, spiketrains, spread=0, tolerance=1e-8): - _check_consistency_of_spiketrainlist(spiketrains) + _check_consistency_of_spiketrainlist(spiketrains, + same_t_start=True, + same_t_stop=True) self.input_spiketrains = spiketrains self.t_start = spiketrains[0].t_start From e82399d7be9656314cbc6c14c2cafa73b9191d82 Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 1 Jul 2020 12:44:45 +0200 Subject: [PATCH 29/58] rename variables to avoid overwriting time_histogram function call --- elephant/statistics.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 00f86eb93..c690959fd 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1032,16 +1032,16 @@ def _histogram_no_spread(self): # Computing the histogram of the entries of pophist complexity_hist = np.histogram( - pophist.magnitude, + time_hist.magnitude, bins=range(0, len(self.input_spiketrains) + 2))[0] - return pophist, complexity_hist + return time_hist, complexity_hist def _histogram_with_spread(self): """ Calculate the complexity histogram and time histogram for `spread` > 0 """ - complexity_histogram = np.bincount( + complexity_hist = np.bincount( self.epoch.array_annotations['complexity']) num_bins = ((self.t_stop - self.t_start).rescale( self.bin_size.units) / self.bin_size.magnitude).item() @@ -1052,7 +1052,7 @@ def _histogram_with_spread(self): 'behaviour.') num_bins += 1 num_bins = int(num_bins) - time_histogram = np.zeros((num_bins, ), dtype=int) + time_hist = np.zeros((num_bins, ), dtype=int) start_bins = ((self.epoch.times - self.t_start).rescale( self.bin_size.units) / self.bin_size).magnitude.flatten() @@ -1098,11 +1098,11 @@ def _histogram_with_spread(self): stop_bins = stop_bins.astype(int) for idx, (start, stop) in enumerate(zip(start_bins, stop_bins)): - time_histogram[start:stop] = \ + time_hist[start:stop] = \ self.epoch.array_annotations['complexity'][idx] - time_histogram = neo.AnalogSignal( - signal=time_histogram.reshape(time_histogram.size, 1), + time_hist = neo.AnalogSignal( + signal=time_hist.reshape(time_hist.size, 1), sampling_period=self.bin_size, units=pq.dimensionless, t_start=self.t_start) @@ -1116,9 +1116,9 @@ def _histogram_with_spread(self): empty_bins += 1 empty_bins = int(empty_bins) - complexity_histogram[0] = empty_bins + complexity_hist[0] = empty_bins - return time_histogram, complexity_histogram + return time_hist, complexity_hist def _epoch_no_spread(self): """ From 2c6b02e8b78322ae4acd6e395a552c09d7318d8d Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 1 Jul 2020 12:45:07 +0200 Subject: [PATCH 30/58] include tolerance to binning in spread > 0 case --- elephant/statistics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/elephant/statistics.py 
b/elephant/statistics.py index c690959fd..773f51b3a 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1137,7 +1137,8 @@ def _epoch_with_spread(self): Get an epoch object of the complexity distribution with `spread` > 0 """ bst = conv.BinnedSpikeTrain(self.input_spiketrains, - binsize=self.bin_size) + binsize=self.bin_size, + tolerance=self.tolerance) if self.binary: binarized = bst.to_sparse_bool_array() From 7b0c5414cfe939cd2497cfadc75aec06c48042dc Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 1 Jul 2020 18:16:39 +0200 Subject: [PATCH 31/58] Apply bin shifting to epoch without spread --- elephant/statistics.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 05f85a4ed..9ac71d235 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1100,9 +1100,25 @@ def _histogram_with_spread(self): return time_histogram, complexity_histogram def _epoch_no_spread(self): - epoch = neo.Epoch(self.time_histogram.times, - durations=self.bin_size - * np.ones(self.time_histogram.shape), + left_edges = self.time_histogram.times + durations = self.bin_size * np.ones(self.time_histogram.shape) + if self.sampling_rate: + # ensure that spikes are not on the bin edges + bin_shift = .5 / self.sampling_rate + left_edges -= bin_shift + else: + warnings.warn('No sampling rate specified. ' + 'Note that using the complexity epoch to get ' + 'precise spike times can lead to rounding errors.') + + # ensure that an epoch does not start before the minimum t_start + min_t_start = min([st.t_start for st in self.input_spiketrains]) + if left_edges[0] < min_t_start: + left_edges[0] = min_t_start + durations[0] -= bin_shift + + epoch = neo.Epoch(left_edges, + durations=durations, array_annotations={ 'complexity': self.time_histogram.magnitude.flatten()}) From e36b74a57380e016b03fb481d3feef3dd5121706 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 1 Jul 2020 18:17:13 +0200 Subject: [PATCH 32/58] Implement synchrofact detection as child class Adapt the tests to this class structure --- elephant/spike_train_processing.py | 137 +++++++++---------- elephant/test/test_spike_train_processing.py | 130 +++++++----------- 2 files changed, 116 insertions(+), 151 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 9fc80b41c..398fbc2a8 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -7,54 +7,73 @@ """ from __future__ import division - +from elephant.statistics import complexity +from copy import deepcopy import numpy as np class synchrotool(complexity): - complexity.__doc_ + # complexity.__doc_ def __init__(self, spiketrains, - sampling_rate=None, - spread=1): - self.spiketrains = spiketrains - self.sampling_rate = sampling_rate - self.spread = spread - - # self.super(...).__init__() + sampling_rate, + binary=True, + spread=0, + tolerance=1e-8): + + self.annotated = False + + super().__init__(spiketrains=spiketrains, + sampling_rate=sampling_rate, + binary=binary, + spread=spread, + tolerance=tolerance) + + def delete_synchrofacts(self, threshold, + in_place=False, invert=False): + + if not self.annotated: + self.annotate_synchrofacts() + + if threshold <= 1: + raise ValueError('A deletion threshold <= 1 would result ' + 'in the deletion of all spikes.') + + if in_place: + spiketrain_list = self.input_spiketrains + else: + spiketrain_list = deepcopy(self.input_spiketrains) + + for idx, st in 
enumerate(spiketrain_list): + mask = st.array_annotations['complexity'] < threshold + if invert: + mask = np.invert(mask) + new_st = st[mask] + spiketrain_list[idx] = new_st + if in_place: + unit = st.unit + segment = st.segment + if unit is not None: + unit.spiketrains[ + self._get_index(unit.spiketrains, st) + ] = new_st + if segment is not None: + segment.spiketrains[ + self._get_index(segment.spiketrains, st) + ] = new_st - # find times of synchrony of size >=n - # complexity_epoch = + return spiketrain_list - # ... - return self + def extract_synchrofacts(self, threshold, in_place=False): + return self.delete_synchrofacts(threshold=threshold, + in_place=in_place, + invert=True) def annotate_synchrofacts(self): - return None - - def delete_synchrofacts(self): - return None - - def extract_synchrofacts(self): - return None - - # def delete_synchrofacts(self, in_place=False): - # - # if not in_place: - # # return clean_spiketrains - # - # @property - # def synchrofacts(self): - # self.synchrofacts = self.detect_synchrofacts(deletion_threshold=1, - # invert_delete=True) - - def detect_synchrofacts(self, - deletion_threshold=None, - invert_delete=False): """ - Given a list of neo.Spiketrain objects, calculate the number of synchronous - spikes found and optionally delete or extract them from the given list - *in-place*. + Given a list of neo.Spiketrain objects, calculate the number of + synchronous spikes found and optionally delete or extract them from + the given list *in-place*. The spike trains are binned at sampling precission (i.e. bin_size = 1 / `sampling_rate`) @@ -122,48 +141,26 @@ def detect_synchrofacts(self, elephant.spike_train_processing.precise_complexity_intervals """ - if deletion_threshold is not None and deletion_threshold <= 1: - raise ValueError('A deletion_threshold <= 1 would result ' - 'in deletion of all spikes.') + epoch_complexities = self.epoch.array_annotations['complexity'] + right_edges = self.epoch.times.magnitude.flatten() + self.epoch.durations.rescale(self.epoch.times.units).magnitude.flatten() + print(self.epoch.times) + print(self.epoch.array_annotations['complexity']) - complexity = complexity_epoch.array_annotations['complexity'] - right_edges = complexity_epoch.times + complexity_epoch.durations - - # j = index of pre-selected sts in spiketrains - # idx = index of pre-selected sts in original - # block.segments[seg].spiketrains - for idx, st in enumerate(spiketrains): + for idx, st in enumerate(self.input_spiketrains): # all indices of spikes that are within the half-open intervals # defined by the boundaries # note that every second entry in boundaries is an upper boundary - spike_to_epoch_idx = np.searchsorted(right_edges, - st.times.rescale( - right_edges.units)) - complexity_per_spike = complexity[spike_to_epoch_idx] + spike_to_epoch_idx = np.searchsorted( + right_edges, + st.times.rescale(self.epoch.times.units).magnitude.flatten()) + complexity_per_spike = epoch_complexities[spike_to_epoch_idx] st.array_annotate(complexity=complexity_per_spike) - if deletion_threshold is not None: - mask = complexity_per_spike < deletion_threshold - if invert_delete: - mask = np.invert(mask) - old_st = st - new_st = old_st[mask] - spiketrains[idx] = new_st - unit = old_st.unit - segment = old_st.segment - if unit is not None: - unit.spiketrains[self._get_index(unit.spiketrains, - old_st)] = new_st - if segment is not None: - segment.spiketrains[self._get_index(segment.spiketrains, - old_st)] = new_st - del old_st - - return complexity_epoch + self.annotated = 
True - def _get_index(lst, obj): + def _get_index(self, lst, obj): for index, item in enumerate(lst): if item is obj: return index diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index dbfdc372b..3205d7c7c 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -17,14 +17,17 @@ class SynchrofactDetectionTestCase(unittest.TestCase): def _test_template(self, spiketrains, correct_complexities, sampling_rate, - spread, deletion_threshold=2, invert_delete=False): - # test annotation - spike_train_processing.detect_synchrofacts( + spread, deletion_threshold=2, invert_delete=False, + in_place=False, binary=True): + + synchrofact_obj = spike_train_processing.synchrotool( spiketrains, - spread=spread, sampling_rate=sampling_rate, - invert_delete=invert_delete, - deletion_threshold=None) + binary=binary, + spread=spread) + + # test annotation + synchrofact_obj.annotate_synchrofacts() annotations = [st.array_annotations['complexity'] for st in spiketrains] @@ -44,12 +47,9 @@ def _test_template(self, spiketrains, correct_complexities, sampling_rate, ]) # test deletion - spike_train_processing.detect_synchrofacts( - spiketrains, - spread=spread, - sampling_rate=sampling_rate, - invert_delete=invert_delete, - deletion_threshold=deletion_threshold) + synchrofact_obj.delete_synchrofacts(threshold=deletion_threshold, + in_place=in_place, + invert=invert_delete) cleaned_spike_times = np.array( [st.times for st in spiketrains]) @@ -94,7 +94,7 @@ def test_spread_0(self): [2, 1, 1, 1, 2, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, invert_delete=False, + spread=0, invert_delete=False, in_place=True, deletion_threshold=2) def test_spread_1(self): @@ -113,7 +113,7 @@ def test_spread_1(self): [2, 2, 1, 3, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert_delete=False, + spread=1, invert_delete=False, in_place=True, deletion_threshold=2) def test_n_equals_3(self): @@ -132,8 +132,8 @@ def test_n_equals_3(self): [3, 2, 1, 2, 3, 3, 3, 2]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert_delete=False, - deletion_threshold=3) + spread=1, invert_delete=False, binary=False, + in_place=True, deletion_threshold=3) def test_invert_delete(self): @@ -151,7 +151,7 @@ def test_invert_delete(self): [2, 2, 1, 3, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert_delete=True, + spread=1, invert_delete=True, in_place=True, deletion_threshold=2) def test_binning_for_input_with_rounding_errors(self): @@ -177,7 +177,7 @@ def test_binning_for_input_with_rounding_errors(self): second_annotations]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, invert_delete=False, + spread=0, invert_delete=False, in_place=True, deletion_threshold=2) def test_correct_transfer_of_spiketrain_attributes(self): @@ -221,11 +221,15 @@ def test_correct_transfer_of_spiketrain_attributes(self): spiketrain.array_annotations.items()} # perform a synchrofact search with delete=True - spike_train_processing.detect_synchrofacts([spiketrain], - spread=0, - sampling_rate=sampling_rate, - invert_delete=False, - deletion_threshold=2) + synchrofact_obj = spike_train_processing.synchrotool( + [spiketrain], + spread=0, + sampling_rate=sampling_rate, + binary=False) + synchrofact_obj.delete_synchrofacts( + invert=False, + in_place=True, + threshold=2) # Ensure 
that the spiketrain was not duplicated self.assertEqual(len(block.filter(objects=neo.SpikeTrain)), 1) @@ -245,64 +249,28 @@ def test_correct_transfer_of_spiketrain_attributes(self): self.assertTrue(key in cleaned_array_annotations.keys()) assert_array_almost_equal(value, cleaned_array_annotations[key]) - def test_complexity_histogram_spread_0(self): - - sampling_rate = 1 / pq.s - - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, - t_stop=20*pq.s), - neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, - t_stop=20*pq.s)] - - correct_histogram = np.array([0, 8, 2]) - - histogram = spike_train_processing.precise_complexity_histogram( - spiketrains, - sampling_rate=sampling_rate, - spread=0) - - assert_array_equal(histogram, correct_histogram) - - def test_complexity_histogram_spread_1(self): - - sampling_rate = 1 / pq.s - - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21*pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21*pq.s)] - - correct_histogram = np.array([0, 5, 2, 1]) - - histogram = spike_train_processing.precise_complexity_histogram( - spiketrains, - sampling_rate=sampling_rate, - spread=1) - - assert_array_equal(histogram, correct_histogram) - - def test_wrong_input_errors(self): - self.assertRaises(ValueError, - spike_train_processing.detect_synchrofacts, - [], 1 / pq.s) - self.assertRaises(TypeError, - spike_train_processing.detect_synchrofacts, - [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), - np.arange(2)], - 1 / pq.s) - self.assertRaises(ValueError, - spike_train_processing.detect_synchrofacts, - [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], - 1 / pq.s, - deletion_threshold=-1) - self.assertRaises(ValueError, - spike_train_processing.precise_complexity_intervals, - [], 1 / pq.s) - self.assertRaises(TypeError, - spike_train_processing.precise_complexity_intervals, - [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), - np.arange(2)], - 1 / pq.s) + # def test_wrong_input_errors(self): + # self.assertRaises(ValueError, + # spike_train_processing.detect_synchrofacts, + # [], 1 / pq.s) + # self.assertRaises(TypeError, + # spike_train_processing.detect_synchrofacts, + # [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), + # np.arange(2)], + # 1 / pq.s) + # self.assertRaises(ValueError, + # spike_train_processing.detect_synchrofacts, + # [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], + # 1 / pq.s, + # deletion_threshold=-1) + # self.assertRaises(ValueError, + # spike_train_processing.precise_complexity_intervals, + # [], 1 / pq.s) + # self.assertRaises(TypeError, + # spike_train_processing.precise_complexity_intervals, + # [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), + # np.arange(2)], + # 1 / pq.s) if __name__ == '__main__': From 39bb6747a627c3e15873ae7d6242b021e7e6ed41 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 1 Jul 2020 18:21:56 +0200 Subject: [PATCH 33/58] Fix overlooked occurrence of renamed variable --- elephant/statistics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 84656d9ab..bc9ad21a3 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1026,9 +1026,9 @@ def _histogram_no_spread(self): """ # Computing the population histogram with parameter binary=True to # clip the spike trains before summing - pophist = time_histogram(self.input_spiketrains, - self.bin_size, - binary=self.binary) + time_hist = time_histogram(self.input_spiketrains, + self.bin_size, + binary=self.binary) # Computing the histogram of the entries of pophist complexity_hist = np.histogram( 
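For reference, the `spread = 0` branch touched by these patches amounts to a population time histogram followed by a plain histogram of its entries. A minimal standalone sketch of that computation (not part of the patch; the two spike trains and the 1 ms bin are invented for illustration):

    import neo
    import numpy as np
    import quantities as pq
    from elephant.statistics import time_histogram

    st1 = neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10 * pq.ms)
    st2 = neo.SpikeTrain([1, 7] * pq.ms, t_stop=10 * pq.ms)

    # population histogram, clipped to 0/1 per train and bin (binary=True)
    time_hist = time_histogram([st1, st2], 1 * pq.ms, binary=True)

    # complexity histogram: entry i counts the bins in which i trains fired
    complexity_hist = np.histogram(
        time_hist.magnitude, bins=range(0, len([st1, st2]) + 2))[0]
    print(complexity_hist)  # [6 3 1]: 6 empty bins, 3 single spikes, 1 coincidence
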
From 9758ab93f6775980a8310bbbaca9900b166c04da Mon Sep 17 00:00:00 2001 From: Aitor Date: Sat, 4 Jul 2020 18:19:20 +0200 Subject: [PATCH 34/58] style changes to docstring --- elephant/statistics.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 773f51b3a..080045022 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -52,6 +52,7 @@ fanofactor complexity_pdf + complexity :copyright: Copyright 2014-2020 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. @@ -860,7 +861,7 @@ class complexity: Default: 0 tolerance : float, optional Tolerance for rounding errors in the binning process and in the input - data + data. Default: 1e-8 Attributes @@ -894,21 +895,22 @@ class complexity: ------ ValueError When `t_stop` is smaller than `t_start`. - ValueError + When both `sampling_rate` and `bin_size` are not specified. - ValueError + When `spread` is not a positive integer. - TypeError - When `spiketrains` is not a list. - ValueError + When `spiketrains` is an empty list. - TypeError - When the elements in `spiketrains` are not instances of neo.SpikeTrain - ValueError + When `t_start` is not the same for all spiketrains - ValueError + When `t_stop` is not the same for all spiketrains + TypeError + When `spiketrains` is not a list. + + When the elements in `spiketrains` are not instances of neo.SpikeTrain + Notes ----- * Note that with most common parameter combinations spike times can end up @@ -920,7 +922,7 @@ class complexity: See also -------- elephant.conversion.BinnedSpikeTrain - elephant.spike_train_processing.synchotool + elephant.spike_train_processing.synchrotool References ---------- @@ -1026,9 +1028,9 @@ def _histogram_no_spread(self): """ # Computing the population histogram with parameter binary=True to # clip the spike trains before summing - pophist = time_histogram(self.input_spiketrains, - self.bin_size, - binary=self.binary) + time_hist = time_histogram(self.input_spiketrains, + self.bin_size, + binary=self.binary) # Computing the histogram of the entries of pophist complexity_hist = np.histogram( From 411ac502a9acbd91b5b1ea04fa2b194b8424eb19 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 15 Jul 2020 11:57:32 +0200 Subject: [PATCH 35/58] Account for rounding errors in input check --- elephant/conversion.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/elephant/conversion.py b/elephant/conversion.py index 553ec406a..33524ab07 100644 --- a/elephant/conversion.py +++ b/elephant/conversion.py @@ -633,12 +633,17 @@ def _check_consistency(self, spiketrains, binsize, num_bins, t_start, raise ValueError( 'too many / too large time bins. 
Some spike trains are ' 'not defined in the ending time') - if num_bins != int(( + # account for rounding errors in the reference num_bins + num_bins_test = (( (t_stop - t_start).rescale( - binsize.units) / binsize).magnitude): + binsize.units) / binsize).magnitude) + if _detect_rounding_errors(num_bins_test, tolerance=self.tolerance): + num_bins_test += 1 + num_bins_test = int(num_bins_test) + if num_bins != num_bins_test: raise ValueError( "Inconsistent arguments t_start (%s), " % t_start + - "t_stop (%s), binsize (%d) " % (t_stop, binsize) + + "t_stop (%s), binsize (%s) " % (t_stop, binsize) + "and num_bins (%d)" % num_bins) if num_bins - int(num_bins) != 0 or num_bins < 0: raise TypeError( From 2dc37fe19567b84c2739ab8b102de2f5148eed69 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 15 Jul 2020 11:58:39 +0200 Subject: [PATCH 36/58] Test for spread<=1 error raising --- elephant/test/test_spike_train_processing.py | 31 ++++++-------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index 3205d7c7c..511d19cd2 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -249,28 +249,15 @@ def test_correct_transfer_of_spiketrain_attributes(self): self.assertTrue(key in cleaned_array_annotations.keys()) assert_array_almost_equal(value, cleaned_array_annotations[key]) - # def test_wrong_input_errors(self): - # self.assertRaises(ValueError, - # spike_train_processing.detect_synchrofacts, - # [], 1 / pq.s) - # self.assertRaises(TypeError, - # spike_train_processing.detect_synchrofacts, - # [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), - # np.arange(2)], - # 1 / pq.s) - # self.assertRaises(ValueError, - # spike_train_processing.detect_synchrofacts, - # [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], - # 1 / pq.s, - # deletion_threshold=-1) - # self.assertRaises(ValueError, - # spike_train_processing.precise_complexity_intervals, - # [], 1 / pq.s) - # self.assertRaises(TypeError, - # spike_train_processing.precise_complexity_intervals, - # [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), - # np.arange(2)], - # 1 / pq.s) + def test_wrong_input_errors(self): + synchrofact_obj = spike_train_processing.synchrotool( + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], + sampling_rate=1/pq.s, + binary=True, + spread=1) + self.assertRaises(ValueError, + synchrofact_obj.delete_synchrofacts, + -1) if __name__ == '__main__': From 1264a805c1e25b803f58cdc97336e09767fe2240 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 15 Jul 2020 11:59:39 +0200 Subject: [PATCH 37/58] Test for no sampling rate warning --- elephant/test/test_statistics.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 6950642c7..4d043bb06 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -785,6 +785,18 @@ def test_wrong_input_errors(self): sampling_rate=1*pq.s, spread=-7) + @unittest.skipUnless(python_version_major == 3, "assertWarns requires 3.2") + def test_sampling_rate_warning(self): + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + with self.assertWarns(UserWarning): + es.complexity(spiketrains, + bin_size=1*pq.s, + spread=1) + def test_binning_for_input_with_rounding_errors(self): # a test with inputs divided by 30000 which leads to rounding errors From 
c48df7f978e11dabfe87a8c3f6fcba1e42cc4296 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 15 Jul 2020 12:01:10 +0200 Subject: [PATCH 38/58] Test for num_bins rounding error --- elephant/test/test_statistics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 4d043bb06..0a130387e 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -803,14 +803,14 @@ def test_binning_for_input_with_rounding_errors(self): # these errors have to be accounted for by proper binning; # check if we still get the correct result - sampling_rate = 30000 / pq.s + sampling_rate = 333 / pq.s - spiketrains = [neo.SpikeTrain(np.arange(1000, step=2) * pq.s / 30000, - t_stop=.1 * pq.s), - neo.SpikeTrain(np.arange(2000, step=4) * pq.s / 30000, - t_stop=.1 * pq.s)] + spiketrains = [neo.SpikeTrain(np.arange(1000, step=2) * pq.s / 333, + t_stop=30.33333333333 * pq.s), + neo.SpikeTrain(np.arange(2000, step=4) * pq.s / 333, + t_stop=30.33333333333 * pq.s)] - correct_time_histogram = np.zeros(3000) + correct_time_histogram = np.zeros(10101) correct_time_histogram[:1000:2] = 1 correct_time_histogram[:2000:4] += 1 From 764544b4dd32cfe9c127a515b7ae1bdfe046243f Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 15 Jul 2020 12:02:15 +0200 Subject: [PATCH 39/58] Cleanup --- elephant/spike_train_processing.py | 8 +------- elephant/test/test_statistics.py | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 398fbc2a8..216871f9f 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -29,8 +29,7 @@ def __init__(self, spiketrains, spread=spread, tolerance=tolerance) - def delete_synchrofacts(self, threshold, - in_place=False, invert=False): + def delete_synchrofacts(self, threshold, in_place=False, invert=False): if not self.annotated: self.annotate_synchrofacts() @@ -64,11 +63,6 @@ def delete_synchrofacts(self, threshold, return spiketrain_list - def extract_synchrofacts(self, threshold, in_place=False): - return self.delete_synchrofacts(threshold=threshold, - in_place=in_place, - invert=True) - def annotate_synchrofacts(self): """ Given a list of neo.Spiketrain objects, calculate the number of diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 0a130387e..ca9029985 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -641,7 +641,7 @@ def test_time_histogram_output(self): class ComplexityPdfTestCase(unittest.TestCase): - def test_complexity_pdf_old(self): + def test_complexity_pdf_deprecated(self): spiketrain_a = neo.SpikeTrain( [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) spiketrain_b = neo.SpikeTrain( From 6b4cd4f38104310afc909923bd7c8fc8f1f48e30 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 15 Jul 2020 12:18:36 +0200 Subject: [PATCH 40/58] Cleanup --- elephant/statistics.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 19f08f5c3..0dde6625b 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -985,6 +985,12 @@ def __init__(self, spiketrains, same_t_start=True, same_t_stop=True) + if bin_size is None and sampling_rate is None: + raise ValueError('No bin_size or sampling_rate was specified!') + + if spread < 0: + raise ValueError('Spread must be >=0') + 
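A hedged numeric sketch of the rounding problem addressed by the `_check_consistency` fix and exercised by the 333 Hz test above: dividing the recording duration by the bin width can land just below the intended integer bin count, so the value is bumped up when it is within tolerance of the next integer. The check below is a simplified stand-in for `conversion._detect_rounding_errors`, not the actual implementation:

    t_start, t_stop = 0.0, 30.33333333333   # s, as in the test case
    bin_size = 1.0 / 333                    # s
    tolerance = 1e-8

    num_bins = (t_stop - t_start) / bin_size   # ~10100.999999998891
    if 1 - (num_bins % 1) < tolerance:         # within tolerance of an integer
        num_bins += 1
    num_bins = int(num_bins)
    print(num_bins)                            # 10101 rather than the naive 10100
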
self.input_spiketrains = spiketrains self.t_start = spiketrains[0].t_start self.t_stop = spiketrains[0].t_stop @@ -994,18 +1000,13 @@ def __init__(self, spiketrains, self.spread = spread self.tolerance = tolerance - if bin_size is None and sampling_rate is None: - raise ValueError('No bin_size or sampling_rate was speficied!') - elif bin_size is None and sampling_rate is not None: + if bin_size is None and sampling_rate is not None: self.bin_size = 1 / self.sampling_rate - if spread < 0: - raise ValueError('Spread must be >=0') - elif spread == 0: + if spread == 0: self.time_histogram, self.histogram = self._histogram_no_spread() self.epoch = self._epoch_no_spread() else: - print('Complexity calculated at sampling rate precision') self.epoch = self._epoch_with_spread() self.time_histogram, self.histogram = self._histogram_with_spread() From dc7b9f45ddfb5fda4443297fcf20765e17773824 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 15 Jul 2020 12:41:59 +0200 Subject: [PATCH 41/58] Remove prints --- elephant/spike_train_processing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 216871f9f..c021bbdc3 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -136,9 +136,11 @@ def annotate_synchrofacts(self): """ epoch_complexities = self.epoch.array_annotations['complexity'] - right_edges = self.epoch.times.magnitude.flatten() + self.epoch.durations.rescale(self.epoch.times.units).magnitude.flatten() - print(self.epoch.times) - print(self.epoch.array_annotations['complexity']) + right_edges = ( + self.epoch.times.magnitude.flatten() + + self.epoch.durations.rescale( + self.epoch.times.units).magnitude.flatten() + ) for idx, st in enumerate(self.input_spiketrains): From 934ebc3f3941079fb100696183c83ec4baf96142 Mon Sep 17 00:00:00 2001 From: Aitor Date: Wed, 15 Jul 2020 14:50:06 +0200 Subject: [PATCH 42/58] update docs, rename `invert` kwarg to `mode` --- elephant/spike_train_processing.py | 228 ++++++++++++++++++++--------- elephant/statistics.py | 20 +-- 2 files changed, 166 insertions(+), 82 deletions(-) diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index c021bbdc3..1885fd910 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -13,7 +13,117 @@ class synchrotool(complexity): - # complexity.__doc_ + """ + Tool class to find, remove and/or annotate the presence of synchronous + spiking events across multiple spike trains. + + The complexity is used to characterize synchronous events within the same + spike train and across different spike trains in the `spiketrains` list. + Such that, synchronous events can be found both in multi-unit and + single-unit spike trains. + + This class inherits from ``elephant.statistics.complexity``, see its + documentation for more details. + + *The rest of this documentation is copied from + ``elephant.statistics.complexity`` !!! + TODO: Figure out a better way to merge the docstrings.* + + Parameters + ---------- + spiketrains : list of neo.SpikeTrain + Spike trains with a common time axis (same `t_start` and `t_stop`) + sampling_rate : pq.Quantity, optional + Sampling rate of the spike trains with units of 1/time. + Default: None + bin_size : pq.Quantity, optional + Width of the histogram's time bins with units of time. + The user must specify the `bin_size` or the `sampling_rate`. 
+ * If no `bin_size` is specified and the `sampling_rate` is available + 1/`sampling_rate` is used. + * If both are given then `bin_size` is used. + Default: None + binary : bool, optional + * If `True` then the time histograms will be binary. + * If `False` the total number of synchronous spikes is counted in the + time histogram. + Default: True + spread : int, optional + Number of bins in which to check for synchronous spikes. + Spikes that occur separated by `spread - 1` or less empty bins are + considered synchronous. + * ``spread = 0`` corresponds to a bincount accross spike trains. + * ``spread = 1`` corresponds to counting consecutive spikes. + * ``spread = 2`` corresponds to counting consecutive spikes and + spikes separated by exactly 1 empty bin. + * ``spread = n`` corresponds to counting spikes separated by exactly + or less than `n - 1` empty bins. + Default: 0 + tolerance : float, optional + Tolerance for rounding errors in the binning process and in the input + data. + Default: 1e-8 + + Attributes + ---------- + epoch : neo.Epoch + An epoch object containing complexity values, left edges and durations + of all intervals with at least one spike. + * ``epoch.array_annotations['complexity']`` contains the + complexity values per spike. + * ``epoch.times`` contains the left edges. + * ``epoch.durations`` contains the durations. + time_histogram : neo.Analogsignal + A `neo.AnalogSignal` object containing the histogram values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. + * If ``binary = True`` : Number of neurons that spiked in each bin, + regardless of the number of spikes. + * If ``binary = False`` : Number of neurons and spikes per neurons + in each bin. + complexity_histogram : np.ndarray + The number of occurrences of events of different complexities. + `complexity_hist[i]` corresponds to the number of events of + complexity `i` for `i > 0`. + pdf : neo.AnalogSignal + The normalization of `self.complexityhistogram` to 1. + A `neo.AnalogSignal` object containing the pdf values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. + + Raises + ------ + ValueError + When `t_stop` is smaller than `t_start`. + + When both `sampling_rate` and `bin_size` are not specified. + + When `spread` is not a positive integer. + + When `spiketrains` is an empty list. + + When `t_start` is not the same for all spiketrains + + When `t_stop` is not the same for all spiketrains + + TypeError + When `spiketrains` is not a list. + + When the elements in `spiketrains` are not instances of neo.SpikeTrain + + Notes + ----- + * Note that with most common parameter combinations spike times can end up + on bin edges. This makes the binning susceptible to rounding errors which + is accounted for by moving spikes which are within tolerance of the next + bin edge into the following bin. This can be adjusted using the tolerance + parameter and turned off by setting `tolerance=None`. + + See also + -------- + elephant.statistics.complexity + + """ def __init__(self, spiketrains, sampling_rate, @@ -29,11 +139,52 @@ def __init__(self, spiketrains, spread=spread, tolerance=tolerance) - def delete_synchrofacts(self, threshold, in_place=False, invert=False): + def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): + """ + Delete or extract synchronous spiking events. 
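As a usage sketch of the workflow this method documents (the spike trains and the 1 kHz sampling rate are invented; the class is still spelled `synchrotool` at this point in the series and is renamed to `Synchrotool` in a later patch):

    import neo
    import quantities as pq
    from elephant.spike_train_processing import synchrotool

    sts = [neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10 * pq.ms),
           neo.SpikeTrain([1, 7] * pq.ms, t_stop=10 * pq.ms)]

    synchro = synchrotool(sts, sampling_rate=1 / pq.ms, spread=0)
    synchro.annotate_synchrofacts()        # adds a 'complexity' array annotation
    print(sts[0].array_annotations['complexity'])

    # drop spikes with complexity >= 2 ...
    cleaned = synchro.delete_synchrofacts(threshold=2, in_place=False,
                                          mode='delete')
    # ... or keep only the synchronous ones
    only_sync = synchro.delete_synchrofacts(threshold=2, in_place=False,
                                            mode='extract')
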
+ + Parameters + ---------- + threshold : int + Threshold value for the deletion of spikes engaged in synchronous + activity. + * `deletion_threshold >= 2` leads to all spikes with a larger or + equal complexity value to be deleted/extracted. + * `deletion_threshold <= 1` leads to a ValueError, since this + would delete/extract all spikes and there are definitely more + efficient ways of doing so. + in_place : bool + Determines whether the modification are made in place + on ``self.input_spiketrains``. + Default: False + mode : bool + Inversion of the mask for deletion of synchronous events. + * ``'delete'`` leads to the deletion of all spikes with + complexity >= `threshold`, + i.e. deletes synchronous spikes. + * ``'extract'`` leads to the deletion of all spikes with + complexity < `threshold`, i.e. extracts synchronous spikes. + Default: 'delete' + + Returns + ------- + list of neo.SpikeTrain + List of spiketrains where the spikes with + ``complexity >= threshold`` have been deleted/extracted. + * If ``in_place`` is True, the returned list is the same as + ``self.input_spiketrains``. + * If ``in_place`` is False, the returned list is a deepcopy of + ``self.input_spiketrains``. + + """ if not self.annotated: self.annotate_synchrofacts() + if mode not in ['delete', 'extract']: + raise ValueError(str(mode) + ' is not a valid mode. ' + "valid modes are ['delete', 'extract']") + if threshold <= 1: raise ValueError('A deletion threshold <= 1 would result ' 'in the deletion of all spikes.') @@ -45,7 +196,7 @@ def delete_synchrofacts(self, threshold, in_place=False, invert=False): for idx, st in enumerate(spiketrain_list): mask = st.array_annotations['complexity'] < threshold - if invert: + if mode == 'extract': mask = np.invert(mask) new_st = st[mask] spiketrain_list[idx] = new_st @@ -65,75 +216,8 @@ def delete_synchrofacts(self, threshold, in_place=False, invert=False): def annotate_synchrofacts(self): """ - Given a list of neo.Spiketrain objects, calculate the number of - synchronous spikes found and optionally delete or extract them from - the given list *in-place*. - - The spike trains are binned at sampling precission - (i.e. bin_size = 1 / `sampling_rate`) - - Two spikes are considered synchronous if they occur separated by strictly - fewer than `spread - 1` empty bins from one another. See - `elephant.statistics.precise_complexity_intervals` for a detailed - description of how synchronous events are counted. - - Synchronous events are considered within the same spike train and across - different spike trains in the `spiketrains` list. Such that, synchronous - events can be found both in multi-unit and single-unit spike trains. - - The spike trains in the `spiketrains` list are annotated with the - complexity value of each spike in their :attr:`array_annotations`. - - - Parameters - ---------- - spiketrains : list of neo.SpikeTrains - a list of neo.SpikeTrains objects. These spike trains should have been - recorded simultaneously. - sampling_rate : pq.Quantity - Sampling rate of the spike trains. The spike trains are binned with - bin_size = 1 / `sampling_rate`. - spread : int - Number of bins in which to check for synchronous spikes. - Spikes that occur separated by `spread - 1` or less empty bins are - considered synchronous. - Default: 1 - deletion_threshold : int, optional - Threshold value for the deletion of spikes engaged in synchronous - activity. 
- * `deletion_threshold = None` leads to no spikes being deleted, spike - trains are array-annotated and the spike times are kept unchanged. - * `deletion_threshold >= 2` leads to all spikes with a larger or - equal complexity value to be deleted *in-place*. - * `deletion_threshold` cannot be set to 1 (this would delete all - spikes and there are definitely more efficient ways of doing this) - * `deletion_threshold <= 0` leads to a ValueError. - Default: None - invert_delete : bool - Inversion of the mask for deletion of synchronous events. - * `invert_delete = False` leads to the deletion of all spikes with - complexity >= `deletion_threshold`, - i.e. deletes synchronous spikes. - * `invert_delete = True` leads to the deletion of all spikes with - complexity < `deletion_threshold`, i.e. returns synchronous spikes. - Default: False - - Returns - ------- - complexity_epoch : neo.Epoch - An epoch object containing complexity values, left edges and durations - of all intervals with at least one spike. - Calculated with - `elephant.spike_train_processing.precise_complexity_intervals`. - * ``complexity_epoch.array_annotations['complexity']`` contains the - complexity values per spike. - * ``complexity_epoch.times`` contains the left edges. - * ``complexity_epoch.durations`` contains the durations. - - See also - -------- - elephant.spike_train_processing.precise_complexity_intervals - + Annotate the complexity of each spike in the array_annotations + *in-place*. """ epoch_complexities = self.epoch.array_annotations['complexity'] right_edges = ( diff --git a/elephant/statistics.py b/elephant/statistics.py index 0dde6625b..c52b8422d 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -839,14 +839,14 @@ class complexity: bin_size : pq.Quantity, optional Width of the histogram's time bins with units of time. The user must specify the `bin_size` or the `sampling_rate`. - * If no `bin_size` is specified and the `sampling_rate` is available - 1/`sampling_rate` is used. - * If both are given then `bin_size` is used. + * If no `bin_size` is specified and the `sampling_rate` is available + 1/`sampling_rate` is used. + * If both are given then `bin_size` is used. Default: None binary : bool, optional - * If `True` then the time histograms will be binary. - * If `False` the total number of synchronous spikes is counted in the - time histogram. + * If `True` then the time histograms will be binary. + * If `False` the total number of synchronous spikes is counted in the + time histogram. Default: True spread : int, optional Number of bins in which to check for synchronous spikes. @@ -877,10 +877,10 @@ class complexity: A `neo.AnalogSignal` object containing the histogram values. `neo.AnalogSignal[j]` is the histogram computed between `t_start + j * binsize` and `t_start + (j + 1) * binsize`. - * If ``binary = True`` : Number of neurons that spiked in each bin, - regardless of the number of spikes. - * If ``binary = False`` : Number of neurons and spikes per neurons - in each bin. + * If ``binary = True`` : Number of neurons that spiked in each bin, + regardless of the number of spikes. + * If ``binary = False`` : Number of neurons and spikes per neurons + in each bin. complexity_histogram : np.ndarray The number of occurrences of events of different complexities. 
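The per-spike lookup inside `annotate_synchrofacts` (shown earlier in this series) is just an `np.searchsorted` of the spike times against the right edges of the complexity epoch. A small worked illustration with invented numbers:

    import numpy as np

    right_edges = np.array([1.5, 4.5, 7.5])        # epoch right edges in s (assumed)
    epoch_complexities = np.array([2, 1, 3])       # complexity of each epoch (assumed)
    spike_times = np.array([1.0, 4.0, 6.0, 7.0])   # s

    spike_to_epoch_idx = np.searchsorted(right_edges, spike_times)
    print(epoch_complexities[spike_to_epoch_idx])  # [2 1 3 3]
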
`complexity_hist[i]` corresponds to the number of events of From e8c52fedec53c6a30e9a4be798ad81f0c2cf6f91 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Mon, 20 Jul 2020 15:54:13 +0200 Subject: [PATCH 43/58] Account for renaming of invert_delete in tests --- elephant/test/test_spike_train_processing.py | 22 ++++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index 511d19cd2..00aa590fc 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -17,7 +17,7 @@ class SynchrofactDetectionTestCase(unittest.TestCase): def _test_template(self, spiketrains, correct_complexities, sampling_rate, - spread, deletion_threshold=2, invert_delete=False, + spread, deletion_threshold=2, mode='delete', in_place=False, binary=True): synchrofact_obj = spike_train_processing.synchrotool( @@ -34,7 +34,7 @@ def _test_template(self, spiketrains, correct_complexities, sampling_rate, assert_array_equal(annotations, correct_complexities) - if invert_delete: + if mode == 'extract': correct_spike_times = np.array( [spikes[mask] for spikes, mask in zip(spiketrains, @@ -49,7 +49,7 @@ def _test_template(self, spiketrains, correct_complexities, sampling_rate, # test deletion synchrofact_obj.delete_synchrofacts(threshold=deletion_threshold, in_place=in_place, - invert=invert_delete) + mode=mode) cleaned_spike_times = np.array( [st.times for st in spiketrains]) @@ -74,7 +74,7 @@ def test_no_synchrofacts(self): [1, 1, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert_delete=False, + spread=1, mode='delete', deletion_threshold=2) def test_spread_0(self): @@ -94,7 +94,7 @@ def test_spread_0(self): [2, 1, 1, 1, 2, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, invert_delete=False, in_place=True, + spread=0, mode='delete', in_place=True, deletion_threshold=2) def test_spread_1(self): @@ -113,7 +113,7 @@ def test_spread_1(self): [2, 2, 1, 3, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert_delete=False, in_place=True, + spread=1, mode='delete', in_place=True, deletion_threshold=2) def test_n_equals_3(self): @@ -132,10 +132,10 @@ def test_n_equals_3(self): [3, 2, 1, 2, 3, 3, 3, 2]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert_delete=False, binary=False, + spread=1, mode='delete', binary=False, in_place=True, deletion_threshold=3) - def test_invert_delete(self): + def test_extract(self): # test synchrofact search taking into account adjacent bins # this requires an additional loop with shifted binning @@ -151,7 +151,7 @@ def test_invert_delete(self): [2, 2, 1, 3, 1, 1]]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, invert_delete=True, in_place=True, + spread=1, mode='extract', in_place=True, deletion_threshold=2) def test_binning_for_input_with_rounding_errors(self): @@ -177,7 +177,7 @@ def test_binning_for_input_with_rounding_errors(self): second_annotations]) self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, invert_delete=False, in_place=True, + spread=0, mode='delete', in_place=True, deletion_threshold=2) def test_correct_transfer_of_spiketrain_attributes(self): @@ -227,7 +227,7 @@ def test_correct_transfer_of_spiketrain_attributes(self): sampling_rate=sampling_rate, binary=False) synchrofact_obj.delete_synchrofacts( 
- invert=False, + mode='delete', in_place=True, threshold=2) From f6a9d202f9de556eb26b6b309aeeaf84c2779de4 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Mon, 20 Jul 2020 15:54:56 +0200 Subject: [PATCH 44/58] Only account for shift when it actually happens --- elephant/statistics.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index c52b8422d..c974b94f9 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1129,21 +1129,23 @@ def _epoch_no_spread(self): """ left_edges = self.time_histogram.times durations = self.bin_size * np.ones(self.time_histogram.shape) + if self.sampling_rate: # ensure that spikes are not on the bin edges bin_shift = .5 / self.sampling_rate left_edges -= bin_shift + + # ensure that an epoch does not start before the minimum t_start + min_t_start = min([st.t_start for st in self.input_spiketrains]) + if left_edges[0] < min_t_start: + left_edges[0] = min_t_start + durations[0] -= bin_shift + else: warnings.warn('No sampling rate specified. ' 'Note that using the complexity epoch to get ' 'precise spike times can lead to rounding errors.') - # ensure that an epoch does not start before the minimum t_start - min_t_start = min([st.t_start for st in self.input_spiketrains]) - if left_edges[0] < min_t_start: - left_edges[0] = min_t_start - durations[0] -= bin_shift - epoch = neo.Epoch(left_edges, durations=durations, array_annotations={ From ba0eb64be23eb6d3aaaea8f9dc0032efaa26ad24 Mon Sep 17 00:00:00 2001 From: dizcza Date: Tue, 11 Aug 2020 13:35:12 +0200 Subject: [PATCH 45/58] decorative refactoring --- doc/reference/spike_train_processing.rst | 1 - elephant/spike_train_processing.py | 61 +++++++++++++----------- elephant/statistics.py | 35 +++++++------- elephant/test/test_statistics.py | 10 ++-- requirements/requirements-docs.txt | 1 + 5 files changed, 59 insertions(+), 49 deletions(-) diff --git a/doc/reference/spike_train_processing.rst b/doc/reference/spike_train_processing.rst index aca083c97..b701938bd 100644 --- a/doc/reference/spike_train_processing.rst +++ b/doc/reference/spike_train_processing.rst @@ -9,4 +9,3 @@ Spike train processing .. automodule:: elephant.spike_train_processing - :members: diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 1885fd910..152851f08 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -1,16 +1,26 @@ # -*- coding: utf-8 -*- """ -Module for spike train processing +Module for spike train processing. + + +.. autosummary:: + :toctree: toctree/spike_train_processing/ + + synchrotool + :copyright: Copyright 2014-2020 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. """ -from __future__ import division -from elephant.statistics import complexity +from __future__ import division, print_function, unicode_literals + from copy import deepcopy + import numpy as np +from elephant.statistics import complexity + class synchrotool(complexity): """ @@ -37,6 +47,7 @@ class synchrotool(complexity): Sampling rate of the spike trains with units of 1/time. Default: None bin_size : pq.Quantity, optional + FIXME: no bin_size detected Width of the histogram's time bins with units of time. The user must specify the `bin_size` or the `sampling_rate`. * If no `bin_size` is specified and the `sampling_rate` is available @@ -85,11 +96,6 @@ class synchrotool(complexity): The number of occurrences of events of different complexities. 
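A hedged numeric illustration of the half-bin shift handled in `_epoch_no_spread` above (all values invented): the epoch left edges are moved back by half a sampling period so that no spike sits exactly on an epoch boundary, and the first epoch is clipped so it never starts before `t_start`:

    import numpy as np

    sampling_rate = 10.0                      # Hz, assumed
    bin_size = 1.0 / sampling_rate            # 0.1 s
    left_edges = np.array([0.0, 0.1, 0.2])    # bin left edges in s, assumed
    durations = np.full(3, bin_size)
    t_start = 0.0

    bin_shift = 0.5 / sampling_rate           # half a sampling period = 0.05 s
    left_edges -= bin_shift                   # [-0.05, 0.05, 0.15]

    # do not let the first epoch start before t_start; shorten it instead
    if left_edges[0] < t_start:
        left_edges[0] = t_start
        durations[0] -= bin_shift

    print(left_edges)   # [0.   0.05 0.15]
    print(durations)    # [0.05 0.1  0.1 ]
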
`complexity_hist[i]` corresponds to the number of events of complexity `i` for `i > 0`. - pdf : neo.AnalogSignal - The normalization of `self.complexityhistogram` to 1. - A `neo.AnalogSignal` object containing the pdf values. - `neo.AnalogSignal[j]` is the histogram computed between - `t_start + j * binsize` and `t_start + (j + 1) * binsize`. Raises ------ @@ -113,11 +119,11 @@ class synchrotool(complexity): Notes ----- - * Note that with most common parameter combinations spike times can end up - on bin edges. This makes the binning susceptible to rounding errors which - is accounted for by moving spikes which are within tolerance of the next - bin edge into the following bin. This can be adjusted using the tolerance - parameter and turned off by setting `tolerance=None`. + Note that with most common parameter combinations spike times can end up + on bin edges. This makes the binning susceptible to rounding errors which + is accounted for by moving spikes which are within tolerance of the next + bin edge into the following bin. This can be adjusted using the tolerance + parameter and turned off by setting `tolerance=None`. See also -------- @@ -133,11 +139,11 @@ def __init__(self, spiketrains, self.annotated = False - super().__init__(spiketrains=spiketrains, - sampling_rate=sampling_rate, - binary=binary, - spread=spread, - tolerance=tolerance) + super(synchrotool, self).__init__(spiketrains=spiketrains, + sampling_rate=sampling_rate, + binary=binary, + spread=spread, + tolerance=tolerance) def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): """ @@ -204,13 +210,13 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): unit = st.unit segment = st.segment if unit is not None: - unit.spiketrains[ - self._get_index(unit.spiketrains, st) - ] = new_st + new_index = self._get_spiketrain_index( + unit.spiketrains, st) + unit.spiketrains[new_index] = new_st if segment is not None: - segment.spiketrains[ - self._get_index(segment.spiketrains, st) - ] = new_st + new_index = self._get_spiketrain_index( + segment.spiketrains, st) + segment.spiketrains[new_index] = new_st return spiketrain_list @@ -240,7 +246,8 @@ def annotate_synchrofacts(self): self.annotated = True - def _get_index(self, lst, obj): - for index, item in enumerate(lst): - if item is obj: + def _get_spiketrain_index(self, spiketrain_list, spiketrain): + for index, item in enumerate(spiketrain_list): + if item is spiketrain: return index + raise ValueError("Spiketrain is not found in the list") diff --git a/elephant/statistics.py b/elephant/statistics.py index d2f0f60b0..f86fdab45 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -847,7 +847,7 @@ def complexity_pdf(spiketrains, bin_size): complexity_obj = complexity(spiketrains, bin_size=bin_size) - return complexity_obj.pdf + return complexity_obj.pdf() class complexity: @@ -917,11 +917,6 @@ class complexity: The number of occurrences of events of different complexities. `complexity_hist[i]` corresponds to the number of events of complexity `i` for `i > 0`. - pdf : neo.AnalogSignal - The normalization of `self.complexityhistogram` to 1. - A `neo.AnalogSignal` object containing the pdf values. - `neo.AnalogSignal[j]` is the histogram computed between - `t_start + j * binsize` and `t_start + (j + 1) * binsize`. 
Raises ------ @@ -982,7 +977,7 @@ class complexity: >>> # spread = 0, a simple bincount >>> cpx = complexity(sts, sampling_rate=sr) Complexity calculated at sampling rate precision - >>> print(cpx.histogram) + >>> print(cpx.complexity_histogram) [5 4 1] >>> print(cpx.time_histogram.flatten()) [0 2 0 0 1 1 1 0 1 0] dimensionless @@ -992,7 +987,7 @@ class complexity: >>> # spread = 1, consecutive spikes >>> cpx = complexity(sts, sampling_rate=sr, spread=1) Complexity calculated at sampling rate precision - >>> print(cpx.histogram) + >>> print(cpx.complexity_histogram) [5 4 1] >>> print(cpx.time_histogram.flatten()) [0 2 0 0 3 3 3 0 1 0] dimensionless @@ -1000,7 +995,7 @@ class complexity: >>> # spread = 2, consecutive spikes and separated by 1 empty bin >>> cpx = complexity(sts, sampling_rate=sr, spread=2) Complexity calculated at sampling rate precision - >>> print(cpx.histogram) + >>> print(cpx.complexity_histogram) [4 0 1 0 1] >>> print(cpx.time_histogram.flatten()) [0 2 0 0 4 4 4 4 4 0] dimensionless @@ -1036,18 +1031,26 @@ def __init__(self, spiketrains, self.bin_size = 1 / self.sampling_rate if spread == 0: - self.time_histogram, self.histogram = self._histogram_no_spread() + self.time_histogram, self.complexity_histogram = \ + self._histogram_no_spread() self.epoch = self._epoch_no_spread() else: self.epoch = self._epoch_with_spread() - self.time_histogram, self.histogram = self._histogram_with_spread() + self.time_histogram, self.complexity_histogram = \ + self._histogram_with_spread() - @property def pdf(self): """ Probability density computed from the complexity histogram. + + Returns + ------- + pdf : neo.AnalogSignal + A `neo.AnalogSignal` object containing the pdf values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. 
""" - norm_hist = self.histogram / self.histogram.sum() + norm_hist = self.complexity_histogram / self.complexity_histogram.sum() # Convert the Complexity pdf to an neo.AnalogSignal pdf = neo.AnalogSignal( np.array(norm_hist).reshape(len(norm_hist), 1) * @@ -1137,7 +1140,7 @@ def _histogram_with_spread(self): self.epoch.array_annotations['complexity'][idx] time_hist = neo.AnalogSignal( - signal=time_hist.reshape(time_hist.size, 1), + signal=np.expand_dims(time_hist, axis=1), sampling_period=self.bin_size, units=pq.dimensionless, t_start=self.t_start) @@ -1168,7 +1171,7 @@ def _epoch_no_spread(self): left_edges -= bin_shift # ensure that an epoch does not start before the minimum t_start - min_t_start = min([st.t_start for st in self.input_spiketrains]) + min_t_start = min(st.t_start for st in self.input_spiketrains) if left_edges[0] < min_t_start: left_edges[0] = min_t_start durations[0] -= bin_shift @@ -1244,7 +1247,7 @@ def _epoch_with_spread(self): 'precise spike times can lead to rounding errors.') # ensure that an epoch does not start before the minimum t_start - min_t_start = min([st.t_start for st in self.input_spiketrains]) + min_t_start = min(st.t_start for st in self.input_spiketrains) left_edges[0] = max(min_t_start, left_edges[0]) complexity_epoch = neo.Epoch(times=left_edges, diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index dba405986..cc2c7338b 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -777,8 +777,8 @@ def test_complexity_pdf(self): # runs the previous function which will be deprecated targ = np.array([0.92, 0.01, 0.01, 0.06]) complexity_obj = statistics.complexity(spiketrains, bin_size=0.1*pq.s) - pdf = complexity_obj.pdf - assert_array_equal(targ, complexity_obj.pdf.magnitude[:, 0]) + pdf = complexity_obj.pdf() + assert_array_equal(targ, complexity_obj.pdf().magnitude[:, 0]) self.assertEqual(1, pdf.magnitude[:, 0].sum()) self.assertEqual(len(spiketrains)+1, len(pdf)) self.assertIsInstance(pdf, neo.AnalogSignal) @@ -802,7 +802,7 @@ def test_complexity_histogram_spread_0(self): sampling_rate=sampling_rate, spread=0) - assert_array_equal(complexity_obj.histogram, + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( @@ -843,7 +843,7 @@ def test_complexity_histogram_spread_1(self): sampling_rate=sampling_rate, spread=1) - assert_array_equal(complexity_obj.histogram, + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( @@ -868,7 +868,7 @@ def test_complexity_histogram_spread_2(self): sampling_rate=sampling_rate, spread=2) - assert_array_equal(complexity_obj.histogram, + assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) assert_array_equal( diff --git a/requirements/requirements-docs.txt b/requirements/requirements-docs.txt index 460baeab5..ff9f7303b 100644 --- a/requirements/requirements-docs.txt +++ b/requirements/requirements-docs.txt @@ -6,3 +6,4 @@ nbsphinx>=0.5.0 sphinxcontrib-bibtex>=1.0.0 sphinx-tabs>=1.1.13 matplotlib>=3.1.0 +# conda install -c conda-forge pandoc From 92292ff91084e664f95f1bd1f6dfbff3a9896eb4 Mon Sep 17 00:00:00 2001 From: dizcza Date: Wed, 30 Sep 2020 17:12:50 +0200 Subject: [PATCH 46/58] python 2 issue; Capitalize class names; fixed TODOs --- .gitignore | 8 +- elephant/spike_train_processing.py | 114 +++---------------- elephant/statistics.py | 23 ++-- elephant/test/test_spike_train_processing.py | 28 ++--- elephant/test/test_statistics.py | 42 +++---- 5 files 
changed, 65 insertions(+), 150 deletions(-) diff --git a/.gitignore b/.gitignore index 4233fc9a1..6c4e72a86 100644 --- a/.gitignore +++ b/.gitignore @@ -44,11 +44,9 @@ lib lib64 # sphinx build directory doc/_build -doc/reference/toctree/asset/elephant.asset.synchronous* -doc/reference/toctree/parallel -doc/reference/toctree/statistics -doc/reference/toctree/unitary_event_analysis -doc/reference/toctree/causality +doc/reference/toctree/* +!doc/reference/toctree/asset/elephant.asset.ASSET.rst +!doc/reference/toctree/kernels *.h5 # setup.py dist directory dist diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py index 152851f08..7956bc08a 100644 --- a/elephant/spike_train_processing.py +++ b/elephant/spike_train_processing.py @@ -6,7 +6,7 @@ .. autosummary:: :toctree: toctree/spike_train_processing/ - synchrotool + Synchrotool :copyright: Copyright 2014-2020 by the Elephant team, see `doc/authors.rst`. @@ -19,10 +19,14 @@ import numpy as np -from elephant.statistics import complexity +from elephant.statistics import Complexity +__all__ = [ + "Synchrotool" +] -class synchrotool(complexity): + +class Synchrotool(Complexity): """ Tool class to find, remove and/or annotate the presence of synchronous spiking events across multiple spike trains. @@ -32,114 +36,26 @@ class synchrotool(complexity): Such that, synchronous events can be found both in multi-unit and single-unit spike trains. - This class inherits from ``elephant.statistics.complexity``, see its - documentation for more details. - - *The rest of this documentation is copied from - ``elephant.statistics.complexity`` !!! - TODO: Figure out a better way to merge the docstrings.* - - Parameters - ---------- - spiketrains : list of neo.SpikeTrain - Spike trains with a common time axis (same `t_start` and `t_stop`) - sampling_rate : pq.Quantity, optional - Sampling rate of the spike trains with units of 1/time. - Default: None - bin_size : pq.Quantity, optional - FIXME: no bin_size detected - Width of the histogram's time bins with units of time. - The user must specify the `bin_size` or the `sampling_rate`. - * If no `bin_size` is specified and the `sampling_rate` is available - 1/`sampling_rate` is used. - * If both are given then `bin_size` is used. - Default: None - binary : bool, optional - * If `True` then the time histograms will be binary. - * If `False` the total number of synchronous spikes is counted in the - time histogram. - Default: True - spread : int, optional - Number of bins in which to check for synchronous spikes. - Spikes that occur separated by `spread - 1` or less empty bins are - considered synchronous. - * ``spread = 0`` corresponds to a bincount accross spike trains. - * ``spread = 1`` corresponds to counting consecutive spikes. - * ``spread = 2`` corresponds to counting consecutive spikes and - spikes separated by exactly 1 empty bin. - * ``spread = n`` corresponds to counting spikes separated by exactly - or less than `n - 1` empty bins. - Default: 0 - tolerance : float, optional - Tolerance for rounding errors in the binning process and in the input - data. - Default: 1e-8 - - Attributes - ---------- - epoch : neo.Epoch - An epoch object containing complexity values, left edges and durations - of all intervals with at least one spike. - * ``epoch.array_annotations['complexity']`` contains the - complexity values per spike. - * ``epoch.times`` contains the left edges. - * ``epoch.durations`` contains the durations. 
- time_histogram : neo.Analogsignal - A `neo.AnalogSignal` object containing the histogram values. - `neo.AnalogSignal[j]` is the histogram computed between - `t_start + j * binsize` and `t_start + (j + 1) * binsize`. - * If ``binary = True`` : Number of neurons that spiked in each bin, - regardless of the number of spikes. - * If ``binary = False`` : Number of neurons and spikes per neurons - in each bin. - complexity_histogram : np.ndarray - The number of occurrences of events of different complexities. - `complexity_hist[i]` corresponds to the number of events of - complexity `i` for `i > 0`. - - Raises - ------ - ValueError - When `t_stop` is smaller than `t_start`. - - When both `sampling_rate` and `bin_size` are not specified. - - When `spread` is not a positive integer. - - When `spiketrains` is an empty list. - - When `t_start` is not the same for all spiketrains - - When `t_stop` is not the same for all spiketrains - - TypeError - When `spiketrains` is not a list. - - When the elements in `spiketrains` are not instances of neo.SpikeTrain - - Notes - ----- - Note that with most common parameter combinations spike times can end up - on bin edges. This makes the binning susceptible to rounding errors which - is accounted for by moving spikes which are within tolerance of the next - bin edge into the following bin. This can be adjusted using the tolerance - parameter and turned off by setting `tolerance=None`. + This class inherits from :func:`elephant.statistics.Complexity`, see its + documentation for more details and input parameters description. See also -------- - elephant.statistics.complexity + elephant.statistics.Complexity """ def __init__(self, spiketrains, sampling_rate, + bin_size=None, binary=True, spread=0, tolerance=1e-8): self.annotated = False - super(synchrotool, self).__init__(spiketrains=spiketrains, + super(Synchrotool, self).__init__(spiketrains=spiketrains, + bin_size=bin_size, sampling_rate=sampling_rate, binary=binary, spread=spread, @@ -222,8 +138,8 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): def annotate_synchrofacts(self): """ - Annotate the complexity of each spike in the array_annotations - *in-place*. + Annotate the complexity of each spike in the + ``self.epoch.array_annotations`` *in-place*. """ epoch_complexities = self.epoch.array_annotations['complexity'] right_edges = ( diff --git a/elephant/statistics.py b/elephant/statistics.py index 07d6c6b78..056635bc9 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -53,7 +53,7 @@ fanofactor complexity_pdf - complexity + Complexity :copyright: Copyright 2014-2020 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. @@ -89,6 +89,7 @@ "instantaneous_rate", "time_histogram", "complexity_pdf", + "Complexity", "fftkernel", "optimal_kernel_bandwidth" ] @@ -969,12 +970,12 @@ def complexity_pdf(spiketrains, bin_size): "class which has a pdf attribute. complexity_pdf will be " "removed in the next Elephant release.", DeprecationWarning) - complexity_obj = complexity(spiketrains, bin_size=bin_size) + complexity = Complexity(spiketrains, bin_size=bin_size) - return complexity_obj.pdf() + return complexity.pdf() -class complexity: +class Complexity(object): """ Class for complexity distribution (i.e. number of synchronous spikes found) of a list of `neo.SpikeTrain` objects. 
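With this patch the deprecated `complexity_pdf` function becomes a thin wrapper around the new `Complexity` class, whose `pdf()` is now called as a method. A minimal usage sketch of the new call path (the spike train values below are made up for illustration; only the class name, `bin_size` keyword and `pdf()` method are taken from the hunks above):

import neo
import quantities as pq
from elephant.statistics import Complexity

# any list of spike trains sharing t_start and t_stop
spiketrains = [
    neo.SpikeTrain([0.1, 0.35, 0.7] * pq.s, t_stop=1 * pq.s),
    neo.SpikeTrain([0.1, 0.5, 0.7] * pq.s, t_stop=1 * pq.s),
]

# preferred API: construct a Complexity object and call pdf() on it
cpx = Complexity(spiketrains, bin_size=0.1 * pq.s)
pdf = cpx.pdf()   # neo.AnalogSignal holding the normalized complexity histogram

# the deprecated module-level complexity_pdf() now only performs the two
# lines above and additionally emits a DeprecationWarning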
@@ -989,10 +990,10 @@ class complexity: ---------- spiketrains : list of neo.SpikeTrain Spike trains with a common time axis (same `t_start` and `t_stop`) - sampling_rate : pq.Quantity, optional + sampling_rate : pq.Quantity or None, optional Sampling rate of the spike trains with units of 1/time. Default: None - bin_size : pq.Quantity, optional + bin_size : pq.Quantity or None, optional Width of the histogram's time bins with units of time. The user must specify the `bin_size` or the `sampling_rate`. * If no `bin_size` is specified and the `sampling_rate` is available @@ -1073,7 +1074,7 @@ class complexity: See also -------- elephant.conversion.BinnedSpikeTrain - elephant.spike_train_processing.synchrotool + elephant.spike_train_processing.Synchrotool References ---------- @@ -1090,7 +1091,7 @@ class complexity: >>> import neo >>> import quantities as pq - >>> from elephant.statistics import complexity + >>> from elephant.statistics import Complexity >>> sr = 1/pq.ms @@ -1099,7 +1100,7 @@ class complexity: >>> sts = [st1, st2] >>> # spread = 0, a simple bincount - >>> cpx = complexity(sts, sampling_rate=sr) + >>> cpx = Complexity(sts, sampling_rate=sr) Complexity calculated at sampling rate precision >>> print(cpx.complexity_histogram) [5 4 1] @@ -1109,7 +1110,7 @@ class complexity: [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] ms >>> # spread = 1, consecutive spikes - >>> cpx = complexity(sts, sampling_rate=sr, spread=1) + >>> cpx = Complexity(sts, sampling_rate=sr, spread=1) Complexity calculated at sampling rate precision >>> print(cpx.complexity_histogram) [5 4 1] @@ -1117,7 +1118,7 @@ class complexity: [0 2 0 0 3 3 3 0 1 0] dimensionless >>> # spread = 2, consecutive spikes and separated by 1 empty bin - >>> cpx = complexity(sts, sampling_rate=sr, spread=2) + >>> cpx = Complexity(sts, sampling_rate=sr, spread=2) Complexity calculated at sampling rate precision >>> print(cpx.complexity_histogram) [4 0 1 0 1] diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py index 00aa590fc..b69b54383 100644 --- a/elephant/test/test_spike_train_processing.py +++ b/elephant/test/test_spike_train_processing.py @@ -20,7 +20,7 @@ def _test_template(self, spiketrains, correct_complexities, sampling_rate, spread, deletion_threshold=2, mode='delete', in_place=False, binary=True): - synchrofact_obj = spike_train_processing.synchrotool( + synchrofact_obj = spike_train_processing.Synchrotool( spiketrains, sampling_rate=sampling_rate, binary=binary, @@ -35,24 +35,24 @@ def _test_template(self, spiketrains, correct_complexities, sampling_rate, assert_array_equal(annotations, correct_complexities) if mode == 'extract': - correct_spike_times = np.array( - [spikes[mask] for spikes, mask - in zip(spiketrains, - correct_complexities >= deletion_threshold) - ]) + correct_spike_times = [ + spikes[mask] for spikes, mask + in zip(spiketrains, + correct_complexities >= deletion_threshold) + ] else: - correct_spike_times = np.array( - [spikes[mask] for spikes, mask - in zip(spiketrains, correct_complexities < deletion_threshold) - ]) + correct_spike_times = [ + spikes[mask] for spikes, mask + in zip(spiketrains, + correct_complexities < deletion_threshold) + ] # test deletion synchrofact_obj.delete_synchrofacts(threshold=deletion_threshold, in_place=in_place, mode=mode) - cleaned_spike_times = np.array( - [st.times for st in spiketrains]) + cleaned_spike_times = [st.times for st in spiketrains] for correct_st, cleaned_st in zip(correct_spike_times, cleaned_spike_times): @@ -221,7 
+221,7 @@ def test_correct_transfer_of_spiketrain_attributes(self): spiketrain.array_annotations.items()} # perform a synchrofact search with delete=True - synchrofact_obj = spike_train_processing.synchrotool( + synchrofact_obj = spike_train_processing.Synchrotool( [spiketrain], spread=0, sampling_rate=sampling_rate, @@ -250,7 +250,7 @@ def test_correct_transfer_of_spiketrain_attributes(self): assert_array_almost_equal(value, cleaned_array_annotations[key]) def test_wrong_input_errors(self): - synchrofact_obj = spike_train_processing.synchrotool( + synchrofact_obj = spike_train_processing.Synchrotool( [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], sampling_rate=1/pq.s, binary=True, diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index c47fa0845..6807f4cd8 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -841,7 +841,7 @@ def test_complexity_pdf(self): spiketrain_a, spiketrain_b, spiketrain_c] # runs the previous function which will be deprecated targ = np.array([0.92, 0.01, 0.01, 0.06]) - complexity_obj = statistics.complexity(spiketrains, bin_size=0.1*pq.s) + complexity_obj = statistics.Complexity(spiketrains, bin_size=0.1 * pq.s) pdf = complexity_obj.pdf() assert_array_equal(targ, complexity_obj.pdf().magnitude[:, 0]) self.assertEqual(1, pdf.magnitude[:, 0].sum()) @@ -863,9 +863,9 @@ def test_complexity_histogram_spread_0(self): correct_time_histogram = np.array([0, 2, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 2, 0, 1, 1]) - complexity_obj = statistics.complexity(spiketrains, - sampling_rate=sampling_rate, - spread=0) + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=0) assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) @@ -883,9 +883,9 @@ def test_complexity_epoch_spread_0(self): neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, t_stop=20*pq.s)] - complexity_obj = statistics.complexity(spiketrains, - sampling_rate=sampling_rate, - spread=0) + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=0) self.assertIsInstance(complexity_obj.epoch, neo.Epoch) @@ -904,9 +904,9 @@ def test_complexity_histogram_spread_1(self): correct_time_histogram = np.array([3, 3, 0, 0, 2, 2, 0, 1, 0, 1, 0, 3, 3, 3, 0, 0, 1, 0, 1, 0, 1]) - complexity_obj = statistics.complexity(spiketrains, - sampling_rate=sampling_rate, - spread=1) + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=1) assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) @@ -929,9 +929,9 @@ def test_complexity_histogram_spread_2(self): correct_time_histogram = np.array([0, 2, 0, 0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 0, 0, 3, 3, 3, 3, 3]) - complexity_obj = statistics.complexity(spiketrains, - sampling_rate=sampling_rate, - spread=2) + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=2) assert_array_equal(complexity_obj.complexity_histogram, correct_histogram) @@ -947,11 +947,11 @@ def test_wrong_input_errors(self): t_stop=21*pq.s)] self.assertRaises(ValueError, - statistics.complexity, + statistics.Complexity, spiketrains) self.assertRaises(ValueError, - statistics.complexity, + statistics.Complexity, spiketrains, sampling_rate=1*pq.s, spread=-7) @@ -964,9 +964,9 @@ def test_sampling_rate_warning(self): t_stop=21*pq.s)] with self.assertWarns(UserWarning): - statistics.complexity(spiketrains, - bin_size=1*pq.s, - spread=1) + statistics.Complexity(spiketrains, + bin_size=1*pq.s, 
+ spread=1) def test_binning_for_input_with_rounding_errors(self): @@ -985,9 +985,9 @@ def test_binning_for_input_with_rounding_errors(self): correct_time_histogram[:1000:2] = 1 correct_time_histogram[:2000:4] += 1 - complexity_obj = statistics.complexity(spiketrains, - sampling_rate=sampling_rate, - spread=1) + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=1) assert_array_equal( complexity_obj.time_histogram.magnitude.flatten().astype(int), From b47fb3dd2d8c4a5893ed238193cf6c7313bedec9 Mon Sep 17 00:00:00 2001 From: dizcza Date: Mon, 5 Oct 2020 14:00:32 +0200 Subject: [PATCH 47/58] merged spike_train_processing into spike_train_synchrony --- doc/modules.rst | 1 - doc/reference/spike_train_processing.rst | 11 - elephant/spike_train_processing.py | 169 ----------- elephant/spike_train_synchrony.py | 154 ++++++++++ elephant/statistics.py | 6 +- elephant/test/test_spike_train_processing.py | 264 ----------------- elephant/test/test_spike_train_synchrony.py | 284 +++++++++++++++++-- 7 files changed, 421 insertions(+), 468 deletions(-) delete mode 100644 doc/reference/spike_train_processing.rst delete mode 100644 elephant/spike_train_processing.py delete mode 100644 elephant/test/test_spike_train_processing.py diff --git a/doc/modules.rst b/doc/modules.rst index 85bce9751..fd2e790b5 100644 --- a/doc/modules.rst +++ b/doc/modules.rst @@ -22,7 +22,6 @@ Function Reference by Module reference/spade reference/spectral reference/spike_train_generation - reference/spike_train_processing reference/spike_train_surrogates reference/sta reference/statistics diff --git a/doc/reference/spike_train_processing.rst b/doc/reference/spike_train_processing.rst deleted file mode 100644 index b701938bd..000000000 --- a/doc/reference/spike_train_processing.rst +++ /dev/null @@ -1,11 +0,0 @@ -====================== -Spike train processing -====================== - - -.. testsetup:: - - from elephant.spike_train_processing import synchrofacts, complexity_intervals - - -.. automodule:: elephant.spike_train_processing diff --git a/elephant/spike_train_processing.py b/elephant/spike_train_processing.py deleted file mode 100644 index 7956bc08a..000000000 --- a/elephant/spike_train_processing.py +++ /dev/null @@ -1,169 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Module for spike train processing. - - -.. autosummary:: - :toctree: toctree/spike_train_processing/ - - Synchrotool - - -:copyright: Copyright 2014-2020 by the Elephant team, see `doc/authors.rst`. -:license: Modified BSD, see LICENSE.txt for details. -""" - -from __future__ import division, print_function, unicode_literals - -from copy import deepcopy - -import numpy as np - -from elephant.statistics import Complexity - -__all__ = [ - "Synchrotool" -] - - -class Synchrotool(Complexity): - """ - Tool class to find, remove and/or annotate the presence of synchronous - spiking events across multiple spike trains. - - The complexity is used to characterize synchronous events within the same - spike train and across different spike trains in the `spiketrains` list. - Such that, synchronous events can be found both in multi-unit and - single-unit spike trains. - - This class inherits from :func:`elephant.statistics.Complexity`, see its - documentation for more details and input parameters description. 
- - See also - -------- - elephant.statistics.Complexity - - """ - - def __init__(self, spiketrains, - sampling_rate, - bin_size=None, - binary=True, - spread=0, - tolerance=1e-8): - - self.annotated = False - - super(Synchrotool, self).__init__(spiketrains=spiketrains, - bin_size=bin_size, - sampling_rate=sampling_rate, - binary=binary, - spread=spread, - tolerance=tolerance) - - def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): - """ - Delete or extract synchronous spiking events. - - Parameters - ---------- - threshold : int - Threshold value for the deletion of spikes engaged in synchronous - activity. - * `deletion_threshold >= 2` leads to all spikes with a larger or - equal complexity value to be deleted/extracted. - * `deletion_threshold <= 1` leads to a ValueError, since this - would delete/extract all spikes and there are definitely more - efficient ways of doing so. - in_place : bool - Determines whether the modification are made in place - on ``self.input_spiketrains``. - Default: False - mode : bool - Inversion of the mask for deletion of synchronous events. - * ``'delete'`` leads to the deletion of all spikes with - complexity >= `threshold`, - i.e. deletes synchronous spikes. - * ``'extract'`` leads to the deletion of all spikes with - complexity < `threshold`, i.e. extracts synchronous spikes. - Default: 'delete' - - Returns - ------- - list of neo.SpikeTrain - List of spiketrains where the spikes with - ``complexity >= threshold`` have been deleted/extracted. - * If ``in_place`` is True, the returned list is the same as - ``self.input_spiketrains``. - * If ``in_place`` is False, the returned list is a deepcopy of - ``self.input_spiketrains``. - - """ - - if not self.annotated: - self.annotate_synchrofacts() - - if mode not in ['delete', 'extract']: - raise ValueError(str(mode) + ' is not a valid mode. ' - "valid modes are ['delete', 'extract']") - - if threshold <= 1: - raise ValueError('A deletion threshold <= 1 would result ' - 'in the deletion of all spikes.') - - if in_place: - spiketrain_list = self.input_spiketrains - else: - spiketrain_list = deepcopy(self.input_spiketrains) - - for idx, st in enumerate(spiketrain_list): - mask = st.array_annotations['complexity'] < threshold - if mode == 'extract': - mask = np.invert(mask) - new_st = st[mask] - spiketrain_list[idx] = new_st - if in_place: - unit = st.unit - segment = st.segment - if unit is not None: - new_index = self._get_spiketrain_index( - unit.spiketrains, st) - unit.spiketrains[new_index] = new_st - if segment is not None: - new_index = self._get_spiketrain_index( - segment.spiketrains, st) - segment.spiketrains[new_index] = new_st - - return spiketrain_list - - def annotate_synchrofacts(self): - """ - Annotate the complexity of each spike in the - ``self.epoch.array_annotations`` *in-place*. 
- """ - epoch_complexities = self.epoch.array_annotations['complexity'] - right_edges = ( - self.epoch.times.magnitude.flatten() - + self.epoch.durations.rescale( - self.epoch.times.units).magnitude.flatten() - ) - - for idx, st in enumerate(self.input_spiketrains): - - # all indices of spikes that are within the half-open intervals - # defined by the boundaries - # note that every second entry in boundaries is an upper boundary - spike_to_epoch_idx = np.searchsorted( - right_edges, - st.times.rescale(self.epoch.times.units).magnitude.flatten()) - complexity_per_spike = epoch_complexities[spike_to_epoch_idx] - - st.array_annotate(complexity=complexity_per_spike) - - self.annotated = True - - def _get_spiketrain_index(self, spiketrain_list, spiketrain): - for index, item in enumerate(spiketrain_list): - if item is spiketrain: - return index - raise ValueError("Spiketrain is not found in the list") diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index 81a002e22..78c522a0b 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -10,6 +10,7 @@ :toctree: toctree/spike_train_synchrony/ spike_contrast + Synchrotool :copyright: Copyright 2015-2020 by the Elephant team, see `doc/authors.rst`. @@ -18,17 +19,26 @@ from __future__ import division, print_function, unicode_literals from collections import namedtuple +from copy import deepcopy import neo import numpy as np import quantities as pq +from elephant.statistics import Complexity from elephant.utils import is_time_quantity SpikeContrastTrace = namedtuple("SpikeContrastTrace", ( "contrast", "active_spiketrains", "synchrony")) +__all__ = [ + "SpikeContrastTrace", + "spike_contrast", + "Synchrotool" +] + + def _get_theta_and_n_per_bin(spiketrains, t_start, t_stop, bin_size): """ Calculates theta (amount of spikes per bin) and the amount of active spike @@ -216,3 +226,147 @@ def spike_contrast(spiketrains, t_start=None, t_stop=None, return synchrony, spike_contrast_trace return synchrony + + + +class Synchrotool(Complexity): + """ + Tool class to find, remove and/or annotate the presence of synchronous + spiking events across multiple spike trains. + + The complexity is used to characterize synchronous events within the same + spike train and across different spike trains in the `spiketrains` list. + Such that, synchronous events can be found both in multi-unit and + single-unit spike trains. + + This class inherits from :func:`elephant.statistics.Complexity`, see its + documentation for more details and input parameters description. + + See also + -------- + elephant.statistics.Complexity + + """ + + def __init__(self, spiketrains, + sampling_rate, + bin_size=None, + binary=True, + spread=0, + tolerance=1e-8): + + self.annotated = False + + super(Synchrotool, self).__init__(spiketrains=spiketrains, + bin_size=bin_size, + sampling_rate=sampling_rate, + binary=binary, + spread=spread, + tolerance=tolerance) + + def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): + """ + Delete or extract synchronous spiking events. + + Parameters + ---------- + threshold : int + Threshold value for the deletion of spikes engaged in synchronous + activity. + * `deletion_threshold >= 2` leads to all spikes with a larger or + equal complexity value to be deleted/extracted. + * `deletion_threshold <= 1` leads to a ValueError, since this + would delete/extract all spikes and there are definitely more + efficient ways of doing so. 
+ in_place : bool + Determines whether the modification are made in place + on ``self.input_spiketrains``. + Default: False + mode : bool + Inversion of the mask for deletion of synchronous events. + * ``'delete'`` leads to the deletion of all spikes with + complexity >= `threshold`, + i.e. deletes synchronous spikes. + * ``'extract'`` leads to the deletion of all spikes with + complexity < `threshold`, i.e. extracts synchronous spikes. + Default: 'delete' + + Returns + ------- + list of neo.SpikeTrain + List of spiketrains where the spikes with + ``complexity >= threshold`` have been deleted/extracted. + * If ``in_place`` is True, the returned list is the same as + ``self.input_spiketrains``. + * If ``in_place`` is False, the returned list is a deepcopy of + ``self.input_spiketrains``. + + """ + + if not self.annotated: + self.annotate_synchrofacts() + + if mode not in ['delete', 'extract']: + raise ValueError(str(mode) + ' is not a valid mode. ' + "valid modes are ['delete', 'extract']") + + if threshold <= 1: + raise ValueError('A deletion threshold <= 1 would result ' + 'in the deletion of all spikes.') + + if in_place: + spiketrain_list = self.input_spiketrains + else: + spiketrain_list = deepcopy(self.input_spiketrains) + + for idx, st in enumerate(spiketrain_list): + mask = st.array_annotations['complexity'] < threshold + if mode == 'extract': + mask = np.invert(mask) + new_st = st[mask] + spiketrain_list[idx] = new_st + if in_place: + unit = st.unit + segment = st.segment + if unit is not None: + new_index = self._get_spiketrain_index( + unit.spiketrains, st) + unit.spiketrains[new_index] = new_st + if segment is not None: + new_index = self._get_spiketrain_index( + segment.spiketrains, st) + segment.spiketrains[new_index] = new_st + + return spiketrain_list + + def annotate_synchrofacts(self): + """ + Annotate the complexity of each spike in the + ``self.epoch.array_annotations`` *in-place*. + """ + epoch_complexities = self.epoch.array_annotations['complexity'] + right_edges = ( + self.epoch.times.magnitude.flatten() + + self.epoch.durations.rescale( + self.epoch.times.units).magnitude.flatten() + ) + + for idx, st in enumerate(self.input_spiketrains): + + # all indices of spikes that are within the half-open intervals + # defined by the boundaries + # note that every second entry in boundaries is an upper boundary + spike_to_epoch_idx = np.searchsorted( + right_edges, + st.times.rescale(self.epoch.times.units).magnitude.flatten()) + complexity_per_spike = epoch_complexities[spike_to_epoch_idx] + + st.array_annotate(complexity=complexity_per_spike) + + self.annotated = True + + def _get_spiketrain_index(self, spiketrain_list, spiketrain): + for index, item in enumerate(spiketrain_list): + if item is spiketrain: + return index + raise ValueError("Spiketrain is not found in the list") diff --git a/elephant/statistics.py b/elephant/statistics.py index 056635bc9..5e15455d5 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1074,7 +1074,7 @@ class Complexity(object): See also -------- elephant.conversion.BinnedSpikeTrain - elephant.spike_train_processing.Synchrotool + elephant.spike_train_synchrony.Synchrotool References ---------- @@ -1085,10 +1085,6 @@ class Complexity(object): Examples -------- - Here the behavior of - `elephant.spike_train_processing.precise_complexity_intervals` is shown, by - applying the function to some sample spiketrains. 
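The class added above to `elephant.spike_train_synchrony` is exercised in the updated test module further down; as a compact reference, here is a hypothetical end-to-end use of its public API (spike times and sampling rate copied from the spread-0 test case, all other parameters left at their defaults):

import neo
import quantities as pq
from elephant.spike_train_synchrony import Synchrotool

sampling_rate = 1 / pq.s
spiketrains = [
    neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, t_stop=20 * pq.s),
    neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, t_stop=20 * pq.s),
]

tool = Synchrotool(spiketrains, sampling_rate=sampling_rate, spread=0)

# attach a 'complexity' array annotation to every spike
tool.annotate_synchrofacts()

# drop every spike that participates in an event of complexity >= 2;
# delete_synchrofacts() calls annotate_synchrofacts() itself if the
# annotation step above were skipped
cleaned = tool.delete_synchrofacts(threshold=2, in_place=False,
                                   mode='delete')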
- >>> import neo >>> import quantities as pq >>> from elephant.statistics import Complexity diff --git a/elephant/test/test_spike_train_processing.py b/elephant/test/test_spike_train_processing.py deleted file mode 100644 index b69b54383..000000000 --- a/elephant/test/test_spike_train_processing.py +++ /dev/null @@ -1,264 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Unit tests for the synchrofact detection app -""" - -import unittest - -import neo -import numpy as np -from numpy.testing import assert_array_almost_equal -from numpy.testing import assert_array_equal -import quantities as pq - -from elephant import spike_train_processing - - -class SynchrofactDetectionTestCase(unittest.TestCase): - - def _test_template(self, spiketrains, correct_complexities, sampling_rate, - spread, deletion_threshold=2, mode='delete', - in_place=False, binary=True): - - synchrofact_obj = spike_train_processing.Synchrotool( - spiketrains, - sampling_rate=sampling_rate, - binary=binary, - spread=spread) - - # test annotation - synchrofact_obj.annotate_synchrofacts() - - annotations = [st.array_annotations['complexity'] - for st in spiketrains] - - assert_array_equal(annotations, correct_complexities) - - if mode == 'extract': - correct_spike_times = [ - spikes[mask] for spikes, mask - in zip(spiketrains, - correct_complexities >= deletion_threshold) - ] - else: - correct_spike_times = [ - spikes[mask] for spikes, mask - in zip(spiketrains, - correct_complexities < deletion_threshold) - ] - - # test deletion - synchrofact_obj.delete_synchrofacts(threshold=deletion_threshold, - in_place=in_place, - mode=mode) - - cleaned_spike_times = [st.times for st in spiketrains] - - for correct_st, cleaned_st in zip(correct_spike_times, - cleaned_spike_times): - assert_array_almost_equal(cleaned_st, correct_st) - - def test_no_synchrofacts(self): - - # nothing to find here - # there used to be an error for spread > 0 when nothing was found - - sampling_rate = 1 / pq.s - - spiketrains = [neo.SpikeTrain([1, 9, 12, 19] * pq.s, - t_stop=20*pq.s), - neo.SpikeTrain([3, 7, 15, 17] * pq.s, - t_stop=20*pq.s)] - - correct_annotations = np.array([[1, 1, 1, 1], - [1, 1, 1, 1]]) - - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, mode='delete', - deletion_threshold=2) - - def test_spread_0(self): - - # basic test with a minimum number of two spikes per synchrofact - # only taking into account multiple spikes - # within one bin of size 1 / sampling_rate - - sampling_rate = 1 / pq.s - - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, - t_stop=20*pq.s), - neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, - t_stop=20*pq.s)] - - correct_annotations = np.array([[2, 1, 1, 1, 2, 1], - [2, 1, 1, 1, 2, 1]]) - - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, mode='delete', in_place=True, - deletion_threshold=2) - - def test_spread_1(self): - - # test synchrofact search taking into account adjacent bins - # this requires an additional loop with shifted binning - - sampling_rate = 1 / pq.s - - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21*pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21*pq.s)] - - correct_annotations = np.array([[2, 2, 1, 3, 3, 1], - [2, 2, 1, 3, 1, 1]]) - - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, mode='delete', in_place=True, - deletion_threshold=2) - - def test_n_equals_3(self): - - # test synchrofact detection with a minimum number of - # three spikes per synchrofact - - 
sampling_rate = 1 / pq.s - - spiketrains = [neo.SpikeTrain([1, 1, 5, 10, 13, 16, 17, 19] * pq.s, - t_stop=21*pq.s), - neo.SpikeTrain([1, 4, 7, 9, 12, 14, 16, 20] * pq.s, - t_stop=21*pq.s)] - - correct_annotations = np.array([[3, 3, 2, 2, 3, 3, 3, 2], - [3, 2, 1, 2, 3, 3, 3, 2]]) - - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, mode='delete', binary=False, - in_place=True, deletion_threshold=3) - - def test_extract(self): - - # test synchrofact search taking into account adjacent bins - # this requires an additional loop with shifted binning - - sampling_rate = 1 / pq.s - - spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, - t_stop=21*pq.s), - neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, - t_stop=21*pq.s)] - - correct_annotations = np.array([[2, 2, 1, 3, 3, 1], - [2, 2, 1, 3, 1, 1]]) - - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=1, mode='extract', in_place=True, - deletion_threshold=2) - - def test_binning_for_input_with_rounding_errors(self): - - # a test with inputs divided by 30000 which leads to rounding errors - # these errors have to be accounted for by proper binning; - # check if we still get the correct result - - sampling_rate = 30000 / pq.s - - spiketrains = [neo.SpikeTrain(np.arange(1000) * pq.s / 30000, - t_stop=.1 * pq.s), - neo.SpikeTrain(np.arange(2000, step=2) * pq.s / 30000, - t_stop=.1 * pq.s)] - - first_annotations = np.ones(1000) - first_annotations[::2] = 2 - - second_annotations = np.ones(1000) - second_annotations[:500] = 2 - - correct_annotations = np.array([first_annotations, - second_annotations]) - - self._test_template(spiketrains, correct_annotations, sampling_rate, - spread=0, mode='delete', in_place=True, - deletion_threshold=2) - - def test_correct_transfer_of_spiketrain_attributes(self): - - # for delete=True the spiketrains in the block are changed, - # test if their attributes remain correct - - sampling_rate = 1 / pq.s - - spiketrain = neo.SpikeTrain([1, 1, 5, 0] * pq.s, - t_stop=10 * pq.s) - - block = neo.Block() - - channel_index = neo.ChannelIndex(name='Channel 1', index=1) - block.channel_indexes.append(channel_index) - - unit = neo.Unit('Unit 1') - channel_index.units.append(unit) - unit.spiketrains.append(spiketrain) - spiketrain.unit = unit - - segment = neo.Segment() - block.segments.append(segment) - segment.spiketrains.append(spiketrain) - spiketrain.segment = segment - - spiketrain.annotate(cool_spike_train=True) - spiketrain.array_annotate( - spike_number=np.arange(len(spiketrain.times.magnitude))) - spiketrain.waveforms = np.sin( - np.arange(len(spiketrain.times.magnitude))[:, np.newaxis] - + np.arange(len(spiketrain.times.magnitude))[np.newaxis, :]) - - correct_mask = np.array([False, False, True, True]) - - # store the correct attributes - correct_annotations = spiketrain.annotations.copy() - correct_waveforms = spiketrain.waveforms[correct_mask].copy() - correct_array_annotations = {key: value[correct_mask] for key, value in - spiketrain.array_annotations.items()} - - # perform a synchrofact search with delete=True - synchrofact_obj = spike_train_processing.Synchrotool( - [spiketrain], - spread=0, - sampling_rate=sampling_rate, - binary=False) - synchrofact_obj.delete_synchrofacts( - mode='delete', - in_place=True, - threshold=2) - - # Ensure that the spiketrain was not duplicated - self.assertEqual(len(block.filter(objects=neo.SpikeTrain)), 1) - - cleaned_spiketrain = segment.spiketrains[0] - - cleaned_annotations = cleaned_spiketrain.annotations - 
cleaned_waveforms = cleaned_spiketrain.waveforms - cleaned_array_annotations = cleaned_spiketrain.array_annotations - cleaned_array_annotations.pop('complexity') - - self.assertDictEqual(correct_annotations, cleaned_annotations) - assert_array_almost_equal(cleaned_waveforms, correct_waveforms) - self.assertTrue(len(cleaned_array_annotations) - == len(correct_array_annotations)) - for key, value in correct_array_annotations.items(): - self.assertTrue(key in cleaned_array_annotations.keys()) - assert_array_almost_equal(value, cleaned_array_annotations[key]) - - def test_wrong_input_errors(self): - synchrofact_obj = spike_train_processing.Synchrotool( - [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], - sampling_rate=1/pq.s, - binary=True, - spread=1) - self.assertRaises(ValueError, - synchrofact_obj.delete_synchrofacts, - -1) - - -if __name__ == '__main__': - unittest.main() diff --git a/elephant/test/test_spike_train_synchrony.py b/elephant/test/test_spike_train_synchrony.py index c546fe267..b566426cd 100644 --- a/elephant/test/test_spike_train_synchrony.py +++ b/elephant/test/test_spike_train_synchrony.py @@ -5,15 +5,17 @@ import neo import numpy as np -from numpy.testing import assert_array_equal +import quantities as pq +from numpy.testing import assert_array_almost_equal, assert_array_equal from quantities import Hz, ms, second -import elephant.spike_train_synchrony as spc import elephant.spike_train_generation as stgen +from elephant.spike_train_synchrony import Synchrotool, spike_contrast, \ + _get_theta_and_n_per_bin, _binning_half_overlap from elephant.test.download import download, unzip -class TestUM(unittest.TestCase): +class TestSpikeContrast(unittest.TestCase): def test_spike_contrast_random(self): # randomly generated spiketrains that share the same t_start and @@ -39,7 +41,7 @@ def test_spike_contrast_random(self): t_stop=10000. * ms) spike_trains = [spike_train_1, spike_train_2, spike_train_3, spike_train_4, spike_train_5, spike_train_6] - synchrony = spc.spike_contrast(spike_trains) + synchrony = spike_contrast(spike_trains) self.assertAlmostEqual(synchrony, 0.2098687, places=6) def test_spike_contrast_same_signal(self): @@ -48,7 +50,7 @@ def test_spike_contrast_same_signal(self): t_start=0. * ms, t_stop=10000. * ms) spike_trains = [spike_train, spike_train] - synchrony = spc.spike_contrast(spike_trains, min_bin=1 * ms) + synchrony = spike_contrast(spike_trains, min_bin=1 * ms) self.assertEqual(synchrony, 1.0) def test_spike_contrast_double_duration(self): @@ -64,7 +66,7 @@ def test_spike_contrast_double_duration(self): t_stop=10000. * ms) spike_trains = [spike_train_1, spike_train_2, spike_train_3] - synchrony = spc.spike_contrast(spike_trains, t_stop=20000 * ms) + synchrony = spike_contrast(spike_trains, t_stop=20000 * ms) self.assertEqual(synchrony, 0.5) def test_spike_contrast_non_overlapping_spiketrains(self): @@ -76,7 +78,7 @@ def test_spike_contrast_non_overlapping_spiketrains(self): t_start=5000. * ms, t_stop=10000. * ms) spiketrains = [spike_train_1, spike_train_2] - synchrony = spc.spike_contrast(spiketrains, t_stop=5000 * ms) + synchrony = spike_contrast(spiketrains, t_stop=5000 * ms) # the synchrony of non-overlapping spiketrains must be zero self.assertEqual(synchrony, 0.) @@ -86,7 +88,7 @@ def test_spike_contrast_trace(self): t_stop=1000. * ms) spike_train_2 = stgen.homogeneous_poisson_process(rate=20 * Hz, t_stop=1000. 
* ms) - synchrony, trace = spc.spike_contrast([spike_train_1, spike_train_2], + synchrony, trace = spike_contrast([spike_train_1, spike_train_2], return_trace=True) self.assertEqual(synchrony, max(trace.synchrony)) self.assertEqual(len(trace.contrast), len(trace.active_spiketrains)) @@ -94,28 +96,28 @@ def test_spike_contrast_trace(self): def test_invalid_data(self): # invalid spiketrains - self.assertRaises(TypeError, spc.spike_contrast, [[0, 1], [1.5, 2.3]]) - self.assertRaises(ValueError, spc.spike_contrast, + self.assertRaises(TypeError, spike_contrast, [[0, 1], [1.5, 2.3]]) + self.assertRaises(ValueError, spike_contrast, [neo.SpikeTrain([10] * ms, t_stop=1000 * ms), neo.SpikeTrain([20] * ms, t_stop=1000 * ms)]) # a single spiketrain spiketrain_valid = neo.SpikeTrain([0, 1000] * ms, t_stop=1000 * ms) - self.assertRaises(ValueError, spc.spike_contrast, [spiketrain_valid]) + self.assertRaises(ValueError, spike_contrast, [spiketrain_valid]) spiketrain_valid2 = neo.SpikeTrain([500, 800] * ms, t_stop=1000 * ms) spiketrains = [spiketrain_valid, spiketrain_valid2] # invalid shrink factor - self.assertRaises(ValueError, spc.spike_contrast, spiketrains, + self.assertRaises(ValueError, spike_contrast, spiketrains, bin_shrink_factor=0.) # invalid t_start, t_stop, and min_bin - self.assertRaises(TypeError, spc.spike_contrast, spiketrains, + self.assertRaises(TypeError, spike_contrast, spiketrains, t_start=0) - self.assertRaises(TypeError, spc.spike_contrast, spiketrains, + self.assertRaises(TypeError, spike_contrast, spiketrains, t_stop=1000) - self.assertRaises(TypeError, spc.spike_contrast, spiketrains, + self.assertRaises(TypeError, spike_contrast, spiketrains, min_bin=0.01) def test_get_theta_and_n_per_bin(self): @@ -124,7 +126,7 @@ def test_get_theta_and_n_per_bin(self): [1, 2, 3, 9], [1, 2, 2.5] ] - theta, n = spc._get_theta_and_n_per_bin(spike_trains, + theta, n = _get_theta_and_n_per_bin(spike_trains, t_start=0, t_stop=10, bin_size=5) @@ -137,7 +139,7 @@ def test_binning_half_overlap(self): t_start = 0 t_stop = 10 edges = np.arange(t_start, t_stop + bin_step, bin_step) - histogram = spc._binning_half_overlap(spiketrain, edges=edges) + histogram = _binning_half_overlap(spiketrain, edges=edges) assert_array_equal(histogram, [3, 1, 1]) def test_spike_contrast_with_Izhikevich_network_auto(self): @@ -169,9 +171,255 @@ def test_spike_contrast_with_Izhikevich_network_auto(self): neo.SpikeTrain(st, t_start=0 * second, t_stop=2 * second, units=second) for st in simulation['spiketrains']] - synchrony = spc.spike_contrast(spiketrains) + synchrony = spike_contrast(spiketrains) self.assertAlmostEqual(synchrony, synchrony_true, places=2) +class SynchrofactDetectionTestCase(unittest.TestCase): + + def _test_template(self, spiketrains, correct_complexities, sampling_rate, + spread, deletion_threshold=2, mode='delete', + in_place=False, binary=True): + + synchrofact_obj = Synchrotool( + spiketrains, + sampling_rate=sampling_rate, + binary=binary, + spread=spread) + + # test annotation + synchrofact_obj.annotate_synchrofacts() + + annotations = [st.array_annotations['complexity'] + for st in spiketrains] + + assert_array_equal(annotations, correct_complexities) + + if mode == 'extract': + correct_spike_times = [ + spikes[mask] for spikes, mask + in zip(spiketrains, + correct_complexities >= deletion_threshold) + ] + else: + correct_spike_times = [ + spikes[mask] for spikes, mask + in zip(spiketrains, + correct_complexities < deletion_threshold) + ] + + # test deletion + 
synchrofact_obj.delete_synchrofacts(threshold=deletion_threshold, + in_place=in_place, + mode=mode) + + cleaned_spike_times = [st.times for st in spiketrains] + + for correct_st, cleaned_st in zip(correct_spike_times, + cleaned_spike_times): + assert_array_almost_equal(cleaned_st, correct_st) + + def test_no_synchrofacts(self): + + # nothing to find here + # there used to be an error for spread > 0 when nothing was found + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 9, 12, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([3, 7, 15, 17] * pq.s, + t_stop=20*pq.s)] + + correct_annotations = np.array([[1, 1, 1, 1], + [1, 1, 1, 1]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, mode='delete', + deletion_threshold=2) + + def test_spread_0(self): + + # basic test with a minimum number of two spikes per synchrofact + # only taking into account multiple spikes + # within one bin of size 1 / sampling_rate + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, + t_stop=20*pq.s)] + + correct_annotations = np.array([[2, 1, 1, 1, 2, 1], + [2, 1, 1, 1, 2, 1]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=0, mode='delete', in_place=True, + deletion_threshold=2) + + def test_spread_1(self): + + # test synchrofact search taking into account adjacent bins + # this requires an additional loop with shifted binning + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_annotations = np.array([[2, 2, 1, 3, 3, 1], + [2, 2, 1, 3, 1, 1]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, mode='delete', in_place=True, + deletion_threshold=2) + + def test_n_equals_3(self): + + # test synchrofact detection with a minimum number of + # three spikes per synchrofact + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 1, 5, 10, 13, 16, 17, 19] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 9, 12, 14, 16, 20] * pq.s, + t_stop=21*pq.s)] + + correct_annotations = np.array([[3, 3, 2, 2, 3, 3, 3, 2], + [3, 2, 1, 2, 3, 3, 3, 2]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, mode='delete', binary=False, + in_place=True, deletion_threshold=3) + + def test_extract(self): + + # test synchrofact search taking into account adjacent bins + # this requires an additional loop with shifted binning + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_annotations = np.array([[2, 2, 1, 3, 3, 1], + [2, 2, 1, 3, 1, 1]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, mode='extract', in_place=True, + deletion_threshold=2) + + def test_binning_for_input_with_rounding_errors(self): + + # a test with inputs divided by 30000 which leads to rounding errors + # these errors have to be accounted for by proper binning; + # check if we still get the correct result + + sampling_rate = 30000 / pq.s + + spiketrains = [neo.SpikeTrain(np.arange(1000) * pq.s / 30000, + t_stop=.1 * pq.s), + neo.SpikeTrain(np.arange(2000, step=2) * pq.s / 30000, + t_stop=.1 * pq.s)] + + first_annotations = np.ones(1000) + first_annotations[::2] = 2 + + second_annotations = 
np.ones(1000) + second_annotations[:500] = 2 + + correct_annotations = np.array([first_annotations, + second_annotations]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=0, mode='delete', in_place=True, + deletion_threshold=2) + + def test_correct_transfer_of_spiketrain_attributes(self): + + # for delete=True the spiketrains in the block are changed, + # test if their attributes remain correct + + sampling_rate = 1 / pq.s + + spiketrain = neo.SpikeTrain([1, 1, 5, 0] * pq.s, + t_stop=10 * pq.s) + + block = neo.Block() + + channel_index = neo.ChannelIndex(name='Channel 1', index=1) + block.channel_indexes.append(channel_index) + + unit = neo.Unit('Unit 1') + channel_index.units.append(unit) + unit.spiketrains.append(spiketrain) + spiketrain.unit = unit + + segment = neo.Segment() + block.segments.append(segment) + segment.spiketrains.append(spiketrain) + spiketrain.segment = segment + + spiketrain.annotate(cool_spike_train=True) + spiketrain.array_annotate( + spike_number=np.arange(len(spiketrain.times.magnitude))) + spiketrain.waveforms = np.sin( + np.arange(len(spiketrain.times.magnitude))[:, np.newaxis] + + np.arange(len(spiketrain.times.magnitude))[np.newaxis, :]) + + correct_mask = np.array([False, False, True, True]) + + # store the correct attributes + correct_annotations = spiketrain.annotations.copy() + correct_waveforms = spiketrain.waveforms[correct_mask].copy() + correct_array_annotations = {key: value[correct_mask] for key, value in + spiketrain.array_annotations.items()} + + # perform a synchrofact search with delete=True + synchrofact_obj = Synchrotool( + [spiketrain], + spread=0, + sampling_rate=sampling_rate, + binary=False) + synchrofact_obj.delete_synchrofacts( + mode='delete', + in_place=True, + threshold=2) + + # Ensure that the spiketrain was not duplicated + self.assertEqual(len(block.filter(objects=neo.SpikeTrain)), 1) + + cleaned_spiketrain = segment.spiketrains[0] + + cleaned_annotations = cleaned_spiketrain.annotations + cleaned_waveforms = cleaned_spiketrain.waveforms + cleaned_array_annotations = cleaned_spiketrain.array_annotations + cleaned_array_annotations.pop('complexity') + + self.assertDictEqual(correct_annotations, cleaned_annotations) + assert_array_almost_equal(cleaned_waveforms, correct_waveforms) + self.assertTrue(len(cleaned_array_annotations) + == len(correct_array_annotations)) + for key, value in correct_array_annotations.items(): + self.assertTrue(key in cleaned_array_annotations.keys()) + assert_array_almost_equal(value, cleaned_array_annotations[key]) + + def test_wrong_input_errors(self): + synchrofact_obj = Synchrotool( + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], + sampling_rate=1/pq.s, + binary=True, + spread=1) + self.assertRaises(ValueError, + synchrofact_obj.delete_synchrofacts, + -1) + + if __name__ == '__main__': unittest.main() From 194aedb15f3e7b45ba1cc0427ee116a3be570f97 Mon Sep 17 00:00:00 2001 From: dizcza Date: Mon, 5 Oct 2020 14:07:07 +0200 Subject: [PATCH 48/58] converted a string to a comment --- elephant/test/test_statistics.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 6807f4cd8..000310516 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -436,10 +436,8 @@ def test_lvr_refractoriness_kwarg(self): def test_2short_spike_train(self): seq = [1] with self.assertWarns(UserWarning): - """ - Catches UserWarning: Input size is too small. 
Please provide - an input with more than 1 entry. - """ + # Catches UserWarning: Input size is too small. Please provide + # an input with more than 1 entry. self.assertTrue(math.isnan(statistics.lvr(seq, with_nan=True))) From 727bbf061c0ad3956a0c9f66b4a2c02f14ef2293 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 11 Nov 2020 18:34:04 +0100 Subject: [PATCH 49/58] Faster algorithm for calculating complexity epochs --- elephant/statistics.py | 67 +++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 5e15455d5..bd2e01e12 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1323,39 +1323,40 @@ def _epoch_with_spread(self): else: bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() - i = 0 - complexities = [] - left_edges = [] - right_edges = [] - while i < len(bincount): - current_bincount = bincount[i] - if current_bincount == 0: - i += 1 - else: - last_window_sum = current_bincount - last_nonzero_index = 0 - current_window = bincount[i:i + self.spread + 1] - window_sum = current_window.sum() - while window_sum > last_window_sum: - last_nonzero_index = np.nonzero(current_window)[0][-1] - current_window = bincount[i: - i + last_nonzero_index - + self.spread + 1] - last_window_sum = window_sum - window_sum = current_window.sum() - complexities.append(window_sum) - left_edges.append( - bst.bin_edges[i].magnitude.item()) - right_edges.append( - bst.bin_edges[ - i + last_nonzero_index + 1 - ].magnitude.item()) - i += last_nonzero_index + 1 - - # we dropped units above, neither concatenate nor append works - # with arrays of quantities - left_edges *= bst.bin_edges.units - right_edges *= bst.bin_edges.units + nonzero_indices = np.nonzero(bincount)[0] + left_diff = np.diff(nonzero_indices, + prepend=-self.spread - 1) + right_diff = np.diff(nonzero_indices, + append=len(bincount) + self.spread + 1) + + # standalone bins (no merging required) + single_bin_indices = np.logical_and(left_diff > self.spread, + right_diff > self.spread) + single_bins = nonzero_indices[single_bin_indices] + + # bins separated by fewer than spread bins form clusters + # that have to be merged + cluster_start_indices = np.logical_and(left_diff > self.spread, + right_diff <= self.spread) + cluster_starts = nonzero_indices[cluster_start_indices] + cluster_stop_indices = np.logical_and(left_diff <= self.spread, + right_diff > self.spread) + cluster_stops = nonzero_indices[cluster_stop_indices] + 1 + + single_bin_complexities = bincount[single_bins] + cluster_complexities = [bincount[start:stop].sum() + for start, stop in zip(cluster_starts, + cluster_stops)] + + # merge standalone bins and clusters and sort them + combined_starts = np.concatenate((single_bins, cluster_starts)) + combined_stops = np.concatenate((single_bins + 1, cluster_stops)) + combined_complexities = np.concatenate((single_bin_complexities, + cluster_complexities)) + sorting = np.argsort(combined_starts, kind='mergesort') + left_edges = bst.bin_edges[combined_starts[sorting]] + right_edges = bst.bin_edges[combined_stops[sorting]] + complexities = combined_complexities[sorting].astype(int) if self.sampling_rate: # ensure that spikes are not on the bin edges From 0b2dbe603fbf1ee75f212dda5663d06ca198702f Mon Sep 17 00:00:00 2001 From: dizcza Date: Fri, 13 Nov 2020 09:22:26 +0100 Subject: [PATCH 50/58] check_neo_consistency test --- elephant/test/test_utils.py | 24 ++++++++---------- elephant/utils.py | 50 
++----------------------------------- 2 files changed, 12 insertions(+), 62 deletions(-) diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py index 4474752cb..5376ef2a4 100644 --- a/elephant/test/test_utils.py +++ b/elephant/test/test_utils.py @@ -15,36 +15,32 @@ class checkSpiketrainTestCase(unittest.TestCase): def test_wrong_input_errors(self): - self.assertRaises(ValueError, - utils._check_consistency_of_spiketrainlist, - [], 1 / pq.s) self.assertRaises(TypeError, - utils._check_consistency_of_spiketrainlist, - neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)) + utils.check_neo_consistency, + [], object_type=neo.SpikeTrain) self.assertRaises(TypeError, - utils._check_consistency_of_spiketrainlist, + utils.check_neo_consistency, [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), - np.arange(2)], - 1 / pq.s) + np.arange(2)], object_type=neo.SpikeTrain) self.assertRaises(ValueError, - utils._check_consistency_of_spiketrainlist, + utils.check_neo_consistency, [neo.SpikeTrain([1]*pq.s, t_start=1*pq.s, t_stop=2*pq.s), neo.SpikeTrain([1]*pq.s, t_start=0*pq.s, t_stop=2*pq.s)], - same_t_start=True) + object_type=neo.SpikeTrain) self.assertRaises(ValueError, - utils._check_consistency_of_spiketrainlist, + utils.check_neo_consistency, [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), neo.SpikeTrain([1]*pq.s, t_stop=3*pq.s)], - same_t_stop=True) + object_type=neo.SpikeTrain) self.assertRaises(ValueError, - utils._check_consistency_of_spiketrainlist, + utils.check_neo_consistency, [neo.SpikeTrain([1]*pq.ms, t_stop=2000*pq.ms), neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], - same_units=True) + object_type=neo.SpikeTrain) if __name__ == '__main__': diff --git a/elephant/utils.py b/elephant/utils.py index 3a6b60eec..f351ccc21 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -26,51 +26,6 @@ def is_binary(array): return ((array == 0) | (array == 1)).all() -def _check_consistency_of_spiketrainlist(spiketrains, - same_t_start=False, - same_t_stop=False, - same_units=False): - """ - Private function to check the consistency of a list of neo.SpikeTrain - - Raises - ------ - TypeError - When `spiketrains` is not a list. - ValueError - When `spiketrains` is an empty list. - TypeError - When the elements in `spiketrains` are not instances of neo.SpikeTrain - ValueError - When `t_start` is not the same for all spiketrains, - if same_t_start=True - ValueError - When `t_stop` is not the same for all spiketrains, - if same_t_stop=True - ValueError - When `units` are not the same for all spiketrains, - if same_units=True - """ - if not isinstance(spiketrains, list): - raise TypeError('spiketrains should be a list of neo.SpikeTrain') - if len(spiketrains) == 0: - raise ValueError('The spiketrains list is empty!') - for st in spiketrains: - if not isinstance(st, SpikeTrain): - raise TypeError( - 'elements in spiketrains list must be instances of ' - ':class:`SpikeTrain` of Neo!' - 'Found: %s, value %s' % (type(st), str(st))) - if same_t_start and not st.t_start == spiketrains[0].t_start: - raise ValueError( - "the spike trains must have the same t_start!") - if same_t_stop and not st.t_stop == spiketrains[0].t_stop: - raise ValueError( - "the spike trains must have the same t_stop!") - if same_units and not st.units == spiketrains[0].units: - raise ValueError('The spike trains must have the same units!') - - def deprecated_alias(**aliases): """ A deprecation decorator constructor. 
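Returning to the faster `_epoch_with_spread` from the previous patch: the `np.diff` calls with `prepend`/`append` classify the occupied bins into standalone bins and mergeable clusters in a single pass. A small worked illustration of that indexing logic with an invented bincount and `spread = 1` (a sketch only, not the full epoch construction):

import numpy as np

spread = 1
bincount = np.array([0, 2, 0, 1, 1, 0, 0, 3, 0])

nonzero_indices = np.nonzero(bincount)[0]                    # [1, 3, 4, 7]
left_diff = np.diff(nonzero_indices, prepend=-spread - 1)    # [3, 2, 1, 3]
right_diff = np.diff(nonzero_indices,
                     append=len(bincount) + spread + 1)      # [2, 1, 3, 4]

# occupied bins whose nearest occupied neighbours are more than `spread`
# bins away on both sides stay standalone
single_bins = nonzero_indices[np.logical_and(left_diff > spread,
                                             right_diff > spread)]
# -> [1, 7]

# the remaining occupied bins open or close clusters that get merged
cluster_starts = nonzero_indices[np.logical_and(left_diff > spread,
                                                right_diff <= spread)]
cluster_stops = nonzero_indices[np.logical_and(left_diff <= spread,
                                               right_diff > spread)] + 1
# -> starts [3], stops [5]: bins 3 and 4 form one epoch with complexity
#    bincount[3:5].sum() == 2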
@@ -214,9 +169,8 @@ def check_neo_consistency(neo_objects, object_type, t_start=None, units = neo_objects[0].units start = neo_objects[0].t_start.item() stop = neo_objects[0].t_stop.item() - except AttributeError: - raise TypeError("The input must be a list of {}. Got {}".format( - object_type.__name__, type(neo_objects[0]).__name__)) + except (IndexError, AttributeError): + raise TypeError(f"The input must be a list of {object_type.__name__}") if tolerance is None: tolerance = 0 for neo_obj in neo_objects: From a094f7a77267c730cb078a4fdbf26f228ae3ecc2 Mon Sep 17 00:00:00 2001 From: dizcza Date: Fri, 13 Nov 2020 09:43:22 +0100 Subject: [PATCH 51/58] updated _epoch_with_spread according to the recent changes in master --- elephant/statistics.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index ce8d853b7..d83a37720 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1317,10 +1317,8 @@ def _epoch_with_spread(self): tolerance=self.tolerance) if self.binary: - binarized = bst.to_sparse_bool_array() - bincount = np.array(binarized.sum(axis=0)).squeeze() - else: - bincount = np.array(bst.to_sparse_array().sum(axis=0)).squeeze() + bst = bst.binarize() + bincount = bst.get_num_of_spikes(axis=0) nonzero_indices = np.nonzero(bincount)[0] left_diff = np.diff(nonzero_indices, @@ -1367,9 +1365,9 @@ def _epoch_with_spread(self): 'Note that using the complexity epoch to get ' 'precise spike times can lead to rounding errors.') - # ensure that an epoch does not start before the minimum t_start - min_t_start = min(st.t_start for st in self.input_spiketrains) - left_edges[0] = max(min_t_start, left_edges[0]) + # Ensure that an epoch does not start before the minimum t_start. + # Note: all spike trains share the same t_start and t_stop. + left_edges[0] = max(self.t_start, left_edges[0]) complexity_epoch = neo.Epoch(times=left_edges, durations=right_edges - left_edges, From 819e159dd4dd70968cb8190771e6585c85bdaf1a Mon Sep 17 00:00:00 2001 From: dizcza Date: Mon, 30 Nov 2020 12:41:42 +0100 Subject: [PATCH 52/58] optimized quantity calls --- elephant/spike_train_synchrony.py | 6 +++--- elephant/statistics.py | 36 +++++++++++++++---------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index fa8235d69..578aeb321 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -240,7 +240,7 @@ class Synchrotool(Complexity): Such that, synchronous events can be found both in multi-unit and single-unit spike trains. - This class inherits from :func:`elephant.statistics.Complexity`, see its + This class inherits from :class:`elephant.statistics.Complexity`, see its documentation for more details and input parameters description. See also @@ -308,8 +308,8 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): self.annotate_synchrofacts() if mode not in ['delete', 'extract']: - raise ValueError(str(mode) + ' is not a valid mode. ' - "valid modes are ['delete', 'extract']") + raise ValueError(f"Invalid mode '{mode}'. 
Valid modes are: " + f"'delete', 'extract'") if threshold <= 1: raise ValueError('A deletion threshold <= 1 would result ' diff --git a/elephant/statistics.py b/elephant/statistics.py index 3a7347a85..188306bf6 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1183,8 +1183,9 @@ def pdf(self): norm_hist = self.complexity_histogram / self.complexity_histogram.sum() # Convert the Complexity pdf to an neo.AnalogSignal pdf = neo.AnalogSignal( - np.array(norm_hist).reshape(len(norm_hist), 1) * - pq.dimensionless, t_start=0 * pq.dimensionless, + np.expand_dims(norm_hist, axis=1), + units=pq.dimensionless, + t_start=0 * pq.dimensionless, sampling_period=1 * pq.dimensionless) return pdf @@ -1211,8 +1212,8 @@ def _histogram_with_spread(self): """ complexity_hist = np.bincount( self.epoch.array_annotations['complexity']) - num_bins = ((self.t_stop - self.t_start).rescale( - self.bin_size.units) / self.bin_size.magnitude).item() + num_bins = (self.t_stop - self.t_start).rescale( + self.bin_size.units).item() / self.bin_size.item() if conv._detect_rounding_errors(num_bins, tolerance=self.tolerance): warnings.warn('Correcting a rounding error in the histogram ' 'calculation by increasing num_bins by 1. ' @@ -1220,17 +1221,16 @@ def _histogram_with_spread(self): 'behaviour.') num_bins += 1 num_bins = int(num_bins) - time_hist = np.zeros((num_bins, ), dtype=int) + time_hist = np.zeros(num_bins, dtype=int) - start_bins = ((self.epoch.times - self.t_start).rescale( - self.bin_size.units) / self.bin_size).magnitude.flatten() - stop_bins = ((self.epoch.times + self.epoch.durations - - self.t_start).rescale( - self.bin_size.units) / self.bin_size).magnitude.flatten() + start_bins = (self.epoch.times - self.t_start).rescale( + self.bin_size.units).magnitude / self.bin_size.item() + stop_bins = (self.epoch.times + self.epoch.durations - self.t_start + ).rescale(self.bin_size.units + ).magnitude / self.bin_size.item() if self.sampling_rate is not None: - shift = (.5 / self.sampling_rate / self.bin_size - ).simplified.magnitude.item() + shift = (.5 / self.sampling_rate / self.bin_size).simplified.item() # account for the first bin not being shifted in the epoch creation # if the shift would move it past t_start if self.epoch.times[0] == self.t_start: @@ -1275,7 +1275,8 @@ def _histogram_with_spread(self): t_start=self.t_start) empty_bins = (self.t_stop - self.t_start - self.epoch.durations.sum()) - empty_bins = empty_bins.rescale(self.bin_size.units) / self.bin_size + empty_bins = empty_bins.rescale(self.bin_size.units + ).magnitude / self.bin_size.item() if conv._detect_rounding_errors(empty_bins, tolerance=self.tolerance): warnings.warn('Correcting a rounding error in the histogram ' 'calculation by increasing num_bins by 1. ' @@ -1300,12 +1301,11 @@ def _epoch_no_spread(self): bin_shift = .5 / self.sampling_rate left_edges -= bin_shift - # ensure that an epoch does not start before the minimum t_start - min_t_start = min(st.t_start for st in self.input_spiketrains) - if left_edges[0] < min_t_start: - left_edges[0] = min_t_start + # Ensure that an epoch does not start before the minimum t_start. + # Note: all spike trains share the same t_start and t_stop. + if left_edges[0] < self.t_start: + left_edges[0] = self.t_start durations[0] -= bin_shift - else: warnings.warn('No sampling rate specified. 
' 'Note that using the complexity epoch to get ' From e20e8ce6ca50cb2d3cee6cdb3911acaaca89ebf9 Mon Sep 17 00:00:00 2001 From: dizcza Date: Mon, 30 Nov 2020 12:42:04 +0100 Subject: [PATCH 53/58] overwrite autosummary docs --- doc/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index bf57e658c..3cb1e5b96 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -123,8 +123,8 @@ # the autosummary fields of each module. autosummary_generate = True -# don't overwrite our custom toctree/*.rst -autosummary_generate_overwrite = False +# Set to False to not overwrite our custom toctree/*.rst +autosummary_generate_overwrite = True # -- Options for HTML output --------------------------------------------- From 72335208b184f22481c255e18db11dec4d87d8c7 Mon Sep 17 00:00:00 2001 From: dizcza Date: Mon, 30 Nov 2020 13:43:57 +0100 Subject: [PATCH 54/58] round_binning_errors utility function --- doc/conf.py | 3 +- doc/modules.rst | 1 + doc/reference/utils.rst | 5 +++ elephant/conversion.py | 34 +++--------------- elephant/statistics.py | 46 +++---------------------- elephant/test/test_utils.py | 19 ++++++++-- elephant/utils.py | 69 ++++++++++++++++++++++++++++++++++--- 7 files changed, 99 insertions(+), 78 deletions(-) create mode 100644 doc/reference/utils.rst diff --git a/doc/conf.py b/doc/conf.py index 3cb1e5b96..a580dd7d1 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -344,7 +344,8 @@ # configuration for intersphinx: refer to Viziphant intersphinx_mapping = { - 'viziphant': ('https://viziphant.readthedocs.io/en/latest/', None) + 'viziphant': ('https://viziphant.readthedocs.io/en/latest/', None), + 'numpy': ('https://numpy.org/doc/stable', None) } # Use more reliable mathjax source diff --git a/doc/modules.rst b/doc/modules.rst index fd2e790b5..c99f87d66 100644 --- a/doc/modules.rst +++ b/doc/modules.rst @@ -26,6 +26,7 @@ Function Reference by Module reference/sta reference/statistics reference/unitary_event_analysis + reference/utils reference/waveform_features diff --git a/doc/reference/utils.rst b/doc/reference/utils.rst new file mode 100644 index 000000000..cf9c2d1f9 --- /dev/null +++ b/doc/reference/utils.rst @@ -0,0 +1,5 @@ +================= +Utility functions +================= + +.. automodule:: elephant.utils diff --git a/elephant/conversion.py b/elephant/conversion.py index 5e6f679b8..f5afd13fd 100644 --- a/elephant/conversion.py +++ b/elephant/conversion.py @@ -27,7 +27,7 @@ import scipy.sparse as sps from elephant.utils import is_binary, deprecated_alias, \ - check_neo_consistency, get_common_start_stop_times + check_neo_consistency, get_common_start_stop_times, round_binning_errors __all__ = [ "binarize", @@ -185,18 +185,6 @@ def binarize(spiketrain, sampling_rate=None, t_start=None, t_stop=None, ########################################################################### -def _detect_rounding_errors(values, tolerance): - """ - Finds rounding errors in values that will be cast to int afterwards. - Returns True for values that are within tolerance of the next integer. - Works for both scalars and numpy arrays. 
- """ - if tolerance is None or tolerance == 0: - return np.zeros_like(values, dtype=bool) - # same as '1 - (values % 1) <= tolerance' but faster - return 1 - tolerance <= values % 1 - - class BinnedSpikeTrain(object): """ Class which calculates a binned spike train and provides methods to @@ -417,12 +405,8 @@ def get_n_bins(): n_bins = (self._t_stop - self._t_start) / self._bin_size if isinstance(n_bins, pq.Quantity): n_bins = n_bins.simplified.item() - if _detect_rounding_errors(n_bins, tolerance=tolerance): - warnings.warn('Correcting a rounding error in the calculation ' - 'of n_bins by increasing n_bins by 1. ' - 'You can set tolerance=None to disable this ' - 'behaviour.') - return int(n_bins) + n_bins = round_binning_errors(n_bins, tolerance=tolerance) + return n_bins def check_n_bins_consistency(): if self.n_bins != get_n_bins(): @@ -825,17 +809,7 @@ def _create_sparse_matrix(self, spiketrains, tolerance): # shift spikes that are very close # to the right edge into the next bin - rounding_error_indices = _detect_rounding_errors( - bins, tolerance=tolerance) - num_rounding_corrections = rounding_error_indices.sum() - if num_rounding_corrections > 0: - warnings.warn('Correcting {} rounding errors by shifting ' - 'the affected spikes into the following bin. ' - 'You can set tolerance=None to disable this ' - 'behaviour.'.format(num_rounding_corrections)) - bins[rounding_error_indices] += .5 - - bins = bins.astype(np.int32) + bins = round_binning_errors(bins, tolerance=tolerance) valid_bins = bins[bins < self.n_bins] n_discarded += len(bins) - len(valid_bins) f, c = np.unique(valid_bins, return_counts=True) diff --git a/elephant/statistics.py b/elephant/statistics.py index 188306bf6..75ec70aaa 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -73,7 +73,7 @@ import elephant.kernels as kernels from elephant.conversion import BinnedSpikeTrain from elephant.utils import deprecated_alias, check_neo_consistency, \ - is_time_quantity + is_time_quantity, round_binning_errors # do not import unicode_literals # (quantities rescale does not work with unicodes) @@ -1214,13 +1214,7 @@ def _histogram_with_spread(self): self.epoch.array_annotations['complexity']) num_bins = (self.t_stop - self.t_start).rescale( self.bin_size.units).item() / self.bin_size.item() - if conv._detect_rounding_errors(num_bins, tolerance=self.tolerance): - warnings.warn('Correcting a rounding error in the histogram ' - 'calculation by increasing num_bins by 1. ' - 'You can set tolerance=None to disable this ' - 'behaviour.') - num_bins += 1 - num_bins = int(num_bins) + num_bins = round_binning_errors(num_bins, tolerance=self.tolerance) time_hist = np.zeros(num_bins, dtype=int) start_bins = (self.epoch.times - self.t_start).rescale( @@ -1239,31 +1233,8 @@ def _histogram_with_spread(self): start_bins += shift stop_bins += shift - rounding_error_indices = conv._detect_rounding_errors(start_bins, - self.tolerance) - - num_rounding_corrections = rounding_error_indices.sum() - if num_rounding_corrections > 0: - warnings.warn('Correcting {} rounding errors by shifting ' - 'the affected spikes into the following bin. 
' - 'You can set tolerance=None to disable this ' - 'behaviour.'.format(num_rounding_corrections)) - start_bins[rounding_error_indices] += .5 - - start_bins = start_bins.astype(int) - - rounding_error_indices = conv._detect_rounding_errors(stop_bins, - self.tolerance) - - num_rounding_corrections = rounding_error_indices.sum() - if num_rounding_corrections > 0: - warnings.warn('Correcting {} rounding errors by shifting ' - 'the affected spikes into the following bin. ' - 'You can set tolerance=None to disable this ' - 'behaviour.'.format(num_rounding_corrections)) - stop_bins[rounding_error_indices] += .5 - - stop_bins = stop_bins.astype(int) + start_bins = round_binning_errors(start_bins, tolerance=self.tolerance) + stop_bins = round_binning_errors(stop_bins, tolerance=self.tolerance) for idx, (start, stop) in enumerate(zip(start_bins, stop_bins)): time_hist[start:stop] = \ @@ -1277,14 +1248,7 @@ def _histogram_with_spread(self): empty_bins = (self.t_stop - self.t_start - self.epoch.durations.sum()) empty_bins = empty_bins.rescale(self.bin_size.units ).magnitude / self.bin_size.item() - if conv._detect_rounding_errors(empty_bins, tolerance=self.tolerance): - warnings.warn('Correcting a rounding error in the histogram ' - 'calculation by increasing num_bins by 1. ' - 'You can set tolerance=None to disable this ' - 'behaviour.') - empty_bins += 1 - empty_bins = int(empty_bins) - + empty_bins = round_binning_errors(empty_bins, tolerance=self.tolerance) complexity_hist[0] = empty_bins return time_hist, complexity_hist diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py index 5376ef2a4..5db79bd23 100644 --- a/elephant/test/test_utils.py +++ b/elephant/test/test_utils.py @@ -10,11 +10,12 @@ import quantities as pq from elephant import utils +from numpy.testing import assert_array_equal -class checkSpiketrainTestCase(unittest.TestCase): +class TestUtils(unittest.TestCase): - def test_wrong_input_errors(self): + def test_check_neo_consistency(self): self.assertRaises(TypeError, utils.check_neo_consistency, [], object_type=neo.SpikeTrain) @@ -42,6 +43,20 @@ def test_wrong_input_errors(self): neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], object_type=neo.SpikeTrain) + def test_round_binning_errors(self): + with self.assertWarns(UserWarning): + n_bins = utils.round_binning_errors(0.999999, tolerance=1e-6) + self.assertEqual(n_bins, 1) + self.assertEqual(utils.round_binning_errors(0.999999, tolerance=None), + 0) + array = np.array([0, 0.7, 1 - 1e-8, 1 - 1e-9]) + with self.assertWarns(UserWarning): + corrected = utils.round_binning_errors(array.copy()) + assert_array_equal(corrected, [0, 0, 1, 1]) + assert_array_equal( + utils.round_binning_errors(array.copy(), tolerance=None), + [0, 0, 0, 0]) + if __name__ == '__main__': unittest.main() diff --git a/elephant/utils.py b/elephant/utils.py index f351ccc21..eb226518c 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -1,3 +1,13 @@ +""" +.. autosummary:: + :toctree: toctree/utils + + is_time_quantity + get_common_start_stop_times + check_neo_consistency + round_binning_errors +""" + from __future__ import division, print_function, unicode_literals import warnings @@ -7,8 +17,6 @@ import numpy as np import quantities as pq -from neo import SpikeTrain - def is_binary(array): """ @@ -98,7 +106,7 @@ def is_time_quantity(x, allow_none=False): def get_common_start_stop_times(neo_objects): """ - Extracts the `t_start`and the `t_stop` from the input neo objects. + Extracts the common `t_start` and the `t_stop` from the input neo objects. 
If a single neo object is given, its `t_start` and `t_stop` is returned. Otherwise, the aligned times are returned: the maximal `t_start` and @@ -138,7 +146,7 @@ def get_common_start_stop_times(neo_objects): def check_neo_consistency(neo_objects, object_type, t_start=None, - t_stop=None, tolerance=1e-6): + t_stop=None, tolerance=1e-8): """ Checks that all input neo objects share the same units, t_start, and t_stop. @@ -215,3 +223,56 @@ def check_same_units(quantities, object_type=pq.Quantity): raise ValueError("The input quantities must have the same units, " "which is achieved with object.rescale('ms') " "operation.") + + +def round_binning_errors(values, tolerance=1e-8): + """ + Round the input `values` in-place due to the machine floating point + precision errors. + + Parameters + ---------- + values : np.ndarray or float + An input array or a scalar. + tolerance : float or None, optional + The precision error absolute tolerance; acts as ``atol`` in + :func:`numpy.isclose` function. If None, no rounding is performed. + Default: 1e-8 + + Returns + ------- + values : np.ndarray or int + Corrected integer values. + + Examples + -------- + >>> from elephant.utils import round_binning_errors + >>> round_binning_errors(0.999999, tolerance=None) + 0 + >>> round_binning_errors(0.999999, tolerance=1e-6) + 1 + """ + if tolerance is None or tolerance == 0: + if isinstance(values, np.ndarray): + return values.astype(np.int32) + return int(values) # a scalar + + # same as '1 - (values % 1) <= tolerance' but faster + correction_mask = 1 - tolerance <= values % 1 + if isinstance(values, np.ndarray): + num_corrections = correction_mask.sum() + if num_corrections > 0: + warnings.warn(f'Correcting {num_corrections} rounding errors by ' + f'shifting the affected spikes into the following ' + f'bin. You can set tolerance=None to disable this ' + 'behaviour.') + values[correction_mask] += 0.5 + return values.astype(np.int32) + + if correction_mask: + warnings.warn('Correcting a rounding error in the calculation ' + 'of the number of bins by incrementing the value by 1. ' + 'You can set tolerance=None to disable this ' + 'behaviour.') + values += 0.5 + return int(values) From e62f2e10ecf6a6810934a57a19fe0f48291d3d8e Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 3 Dec 2020 14:44:23 +0100 Subject: [PATCH 55/58] Support groups only, remove unit and channelindex --- elephant/spike_train_synchrony.py | 24 +++++++++++++++------ elephant/test/test_spike_train_synchrony.py | 15 +++++++------ 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index 578aeb321..ea2dac8b4 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -237,7 +237,7 @@ class Synchrotool(Complexity): The complexity is used to characterize synchronous events within the same spike train and across different spike trains in the `spiketrains` list. - Such that, synchronous events can be found both in multi-unit and + This way synchronous events can be found both in multi-unit and single-unit spike trains. 
This class inherits from :class:`elephant.statistics.Complexity`, see its @@ -327,17 +327,29 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): new_st = st[mask] spiketrain_list[idx] = new_st if in_place: - unit = st.unit segment = st.segment - if unit is not None: - new_index = self._get_spiketrain_index( - unit.spiketrains, st) - unit.spiketrains[new_index] = new_st + if segment is not None: + # replace link to spiketrain in segment new_index = self._get_spiketrain_index( segment.spiketrains, st) segment.spiketrains[new_index] = new_st + block = segment.block + if block is not None: + # replace link to spiketrain in groups + for group in block.groups: + try: + idx = self._get_spiketrain_index( + group.spiketrains, + st) + except ValueError: + # st is not in this group, move to next group + continue + + # st found in group, replace with new_st + group.spiketrains[idx] = new_st + return spiketrain_list def annotate_synchrofacts(self): diff --git a/elephant/test/test_spike_train_synchrony.py b/elephant/test/test_spike_train_synchrony.py index 861f07fdb..ee7e1e374 100644 --- a/elephant/test/test_spike_train_synchrony.py +++ b/elephant/test/test_spike_train_synchrony.py @@ -374,16 +374,13 @@ def test_correct_transfer_of_spiketrain_attributes(self): block = neo.Block() - channel_index = neo.ChannelIndex(name='Channel 1', index=1) - block.channel_indexes.append(channel_index) - - unit = neo.Unit('Unit 1') - channel_index.units.append(unit) - unit.spiketrains.append(spiketrain) - spiketrain.unit = unit + group = neo.Group(name='Test Group') + block.groups.append(group) + group.spiketrains.append(spiketrain) segment = neo.Segment() block.segments.append(segment) + segment.block = block segment.spiketrains.append(spiketrain) spiketrain.segment = segment @@ -418,6 +415,10 @@ def test_correct_transfer_of_spiketrain_attributes(self): cleaned_spiketrain = segment.spiketrains[0] + # Ensure that the spiketrain is also in the group + self.assertEqual(len(block.groups[0].spiketrains), 1) + self.assertIs(block.groups[0].spiketrains[0], cleaned_spiketrain) + cleaned_annotations = cleaned_spiketrain.annotations cleaned_waveforms = cleaned_spiketrain.waveforms cleaned_array_annotations = cleaned_spiketrain.array_annotations From 06e5739b4027b3ffab4641ca488dc09febf3cba0 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 3 Dec 2020 15:03:02 +0100 Subject: [PATCH 56/58] Fix issues in docstrings spotted by @kohlerca --- elephant/spike_train_synchrony.py | 11 +++++++++-- elephant/statistics.py | 20 +++++++++++++++----- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index ea2dac8b4..8959dd896 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -279,11 +279,11 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): * `deletion_threshold <= 1` leads to a ValueError, since this would delete/extract all spikes and there are definitely more efficient ways of doing so. - in_place : bool + in_place : bool, optional Determines whether the modification are made in place on ``self.input_spiketrains``. Default: False - mode : bool + mode : {'delete', 'extract'}, optional Inversion of the mask for deletion of synchronous events. * ``'delete'`` leads to the deletion of all spikes with complexity >= `threshold`, @@ -292,6 +292,13 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): complexity < `threshold`, i.e. 
extracts synchronous spikes. Default: 'delete' + Raises + ------ + ValueError + If `mode` is not one in {'delete', 'extract'}. + + If `threshold <= 1`. + Returns ------- list of neo.SpikeTrain diff --git a/elephant/statistics.py b/elephant/statistics.py index 75ec70aaa..dad28e5d4 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -1003,17 +1003,20 @@ class Complexity(object): Spike trains with a common time axis (same `t_start` and `t_stop`) sampling_rate : pq.Quantity or None, optional Sampling rate of the spike trains with units of 1/time. + Used to shift the epoch edges in order to avoid rounding errors. + If None using the epoch to slice spike trains may introduce + rounding errors. Default: None bin_size : pq.Quantity or None, optional Width of the histogram's time bins with units of time. The user must specify the `bin_size` or the `sampling_rate`. - * If no `bin_size` is specified and the `sampling_rate` is available - 1/`sampling_rate` is used. + * If None and the `sampling_rate` is available + 1/`sampling_rate` is used. * If both are given then `bin_size` is used. Default: None binary : bool, optional - * If `True` then the time histograms will be binary. - * If `False` the total number of synchronous spikes is counted in the + * If True then the time histograms will be binary. + * If False the total number of synchronous spikes is counted in the time histogram. Default: True spread : int, optional @@ -1027,9 +1030,10 @@ class Complexity(object): * ``spread = n`` corresponds to counting spikes separated by exactly or less than `n - 1` empty bins. Default: 0 - tolerance : float, optional + tolerance : float or None, optional Tolerance for rounding errors in the binning process and in the input data. + If None possible binning errors are not accounted for. Default: 1e-8 Attributes @@ -1074,6 +1078,12 @@ class Complexity(object): When the elements in `spiketrains` are not instances of neo.SpikeTrain + Warns + ----- + UserWarning + If no sampling rate is supplied which may lead to rounding errors + when using the epoch to slice spike trains. 
+ Notes ----- * Note that with most common parameter combinations spike times can end up From ffd54fb12d5bf9b94e363e20d47ea24202890441 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 3 Dec 2020 15:09:36 +0100 Subject: [PATCH 57/58] pep8 --- elephant/test/test_statistics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index f2413da42..d6b727262 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -949,7 +949,8 @@ def test_complexity_pdf(self): spiketrain_a, spiketrain_b, spiketrain_c] # runs the previous function which will be deprecated targ = np.array([0.92, 0.01, 0.01, 0.06]) - complexity_obj = statistics.Complexity(spiketrains, bin_size=0.1 * pq.s) + complexity_obj = statistics.Complexity(spiketrains, + bin_size=0.1 * pq.s) pdf = complexity_obj.pdf() assert_array_equal(targ, complexity_obj.pdf().magnitude[:, 0]) self.assertEqual(1, pdf.magnitude[:, 0].sum()) From 410304519043e7f460a59fea298f481c6ef29ce0 Mon Sep 17 00:00:00 2001 From: dizcza Date: Thu, 3 Dec 2020 16:57:10 +0100 Subject: [PATCH 58/58] simplified --- elephant/spike_train_synchrony.py | 45 ++++++++++++++++--------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index 8959dd896..9f68a173c 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -335,27 +335,30 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): spiketrain_list[idx] = new_st if in_place: segment = st.segment - - if segment is not None: - # replace link to spiketrain in segment - new_index = self._get_spiketrain_index( - segment.spiketrains, st) - segment.spiketrains[new_index] = new_st - - block = segment.block - if block is not None: - # replace link to spiketrain in groups - for group in block.groups: - try: - idx = self._get_spiketrain_index( - group.spiketrains, - st) - except ValueError: - # st is not in this group, move to next group - continue - - # st found in group, replace with new_st - group.spiketrains[idx] = new_st + if segment is None: + continue + + # replace link to spiketrain in segment + new_index = self._get_spiketrain_index( + segment.spiketrains, st) + segment.spiketrains[new_index] = new_st + + block = segment.block + if block is None: + continue + + # replace link to spiketrain in groups + for group in block.groups: + try: + idx = self._get_spiketrain_index( + group.spiketrains, + st) + except ValueError: + # st is not in this group, move to next group + continue + + # st found in group, replace with new_st + group.spiketrains[idx] = new_st return spiketrain_list
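
A minimal usage sketch of the synchrofact workflow introduced in this patch
series (illustrative only, not part of the patches themselves: the spike
times, the 30 kHz sampling rate and the deletion threshold are made-up
values, and the exact constructor signature of Synchrotool/Complexity is
assumed from the docstrings above):

    import neo
    import quantities as pq
    from elephant.spike_train_synchrony import Synchrotool

    # Two spike trains sharing a spike at t = 1 ms, i.e. a synchronous
    # event of complexity 2 at 30 kHz resolution.
    spiketrains = [
        neo.SpikeTrain([1, 5, 9] * pq.ms, t_stop=10 * pq.ms),
        neo.SpikeTrain([1, 4, 8] * pq.ms, t_stop=10 * pq.ms),
    ]

    synchrotool = Synchrotool(spiketrains, sampling_rate=30000 * pq.Hz,
                              spread=0)

    # Annotate every spike with the size of the synchronous event it
    # belongs to (array annotation 'complexity') ...
    synchrotool.annotate_synchrofacts()

    # ... or remove all spikes that take part in events of size >= 2.
    cleaned = synchrotool.delete_synchrofacts(threshold=2, in_place=False,
                                              mode='delete')

With in_place=True the links to the replaced spike trains in the parent
neo.Segment and neo.Group containers are updated as well, which is what the
last two patches in this series take care of.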