diff --git a/doc/conf.py b/doc/conf.py index bf57e658c..a580dd7d1 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -123,8 +123,8 @@ # the autosummary fields of each module. autosummary_generate = True -# don't overwrite our custom toctree/*.rst -autosummary_generate_overwrite = False +# Set to False to not overwrite our custom toctree/*.rst +autosummary_generate_overwrite = True # -- Options for HTML output --------------------------------------------- @@ -344,7 +344,8 @@ # configuration for intersphinx: refer to Viziphant intersphinx_mapping = { - 'viziphant': ('https://viziphant.readthedocs.io/en/latest/', None) + 'viziphant': ('https://viziphant.readthedocs.io/en/latest/', None), + 'numpy': ('https://numpy.org/doc/stable', None) } # Use more reliable mathjax source diff --git a/doc/modules.rst b/doc/modules.rst index fd2e790b5..c99f87d66 100644 --- a/doc/modules.rst +++ b/doc/modules.rst @@ -26,6 +26,7 @@ Function Reference by Module reference/sta reference/statistics reference/unitary_event_analysis + reference/utils reference/waveform_features diff --git a/doc/reference/utils.rst b/doc/reference/utils.rst new file mode 100644 index 000000000..cf9c2d1f9 --- /dev/null +++ b/doc/reference/utils.rst @@ -0,0 +1,5 @@ +================= +Utility functions +================= + +.. automodule:: elephant.utils diff --git a/elephant/conversion.py b/elephant/conversion.py index 5e6f679b8..f5afd13fd 100644 --- a/elephant/conversion.py +++ b/elephant/conversion.py @@ -27,7 +27,7 @@ import scipy.sparse as sps from elephant.utils import is_binary, deprecated_alias, \ - check_neo_consistency, get_common_start_stop_times + check_neo_consistency, get_common_start_stop_times, round_binning_errors __all__ = [ "binarize", @@ -185,18 +185,6 @@ def binarize(spiketrain, sampling_rate=None, t_start=None, t_stop=None, ########################################################################### -def _detect_rounding_errors(values, tolerance): - """ - Finds rounding errors in values that will be cast to int afterwards. - Returns True for values that are within tolerance of the next integer. - Works for both scalars and numpy arrays. - """ - if tolerance is None or tolerance == 0: - return np.zeros_like(values, dtype=bool) - # same as '1 - (values % 1) <= tolerance' but faster - return 1 - tolerance <= values % 1 - - class BinnedSpikeTrain(object): """ Class which calculates a binned spike train and provides methods to @@ -417,12 +405,8 @@ def get_n_bins(): n_bins = (self._t_stop - self._t_start) / self._bin_size if isinstance(n_bins, pq.Quantity): n_bins = n_bins.simplified.item() - if _detect_rounding_errors(n_bins, tolerance=tolerance): - warnings.warn('Correcting a rounding error in the calculation ' - 'of n_bins by increasing n_bins by 1. ' - 'You can set tolerance=None to disable this ' - 'behaviour.') - return int(n_bins) + n_bins = round_binning_errors(n_bins, tolerance=tolerance) + return n_bins def check_n_bins_consistency(): if self.n_bins != get_n_bins(): @@ -825,17 +809,7 @@ def _create_sparse_matrix(self, spiketrains, tolerance): # shift spikes that are very close # to the right edge into the next bin - rounding_error_indices = _detect_rounding_errors( - bins, tolerance=tolerance) - num_rounding_corrections = rounding_error_indices.sum() - if num_rounding_corrections > 0: - warnings.warn('Correcting {} rounding errors by shifting ' - 'the affected spikes into the following bin. 
' - 'You can set tolerance=None to disable this ' - 'behaviour.'.format(num_rounding_corrections)) - bins[rounding_error_indices] += .5 - - bins = bins.astype(np.int32) + bins = round_binning_errors(bins, tolerance=tolerance) valid_bins = bins[bins < self.n_bins] n_discarded += len(bins) - len(valid_bins) f, c = np.unique(valid_bins, return_counts=True) diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index 31bd404cf..9f68a173c 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -10,6 +10,7 @@ :toctree: toctree/spike_train_synchrony/ spike_contrast + Synchrotool :copyright: Copyright 2015-2020 by the Elephant team, see `doc/authors.rst`. @@ -18,17 +19,26 @@ from __future__ import division, print_function, unicode_literals from collections import namedtuple +from copy import deepcopy import neo import numpy as np import quantities as pq +from elephant.statistics import Complexity from elephant.utils import is_time_quantity SpikeContrastTrace = namedtuple("SpikeContrastTrace", ( "contrast", "active_spiketrains", "synchrony")) +__all__ = [ + "SpikeContrastTrace", + "spike_contrast", + "Synchrotool" +] + + def _get_theta_and_n_per_bin(spiketrains, t_start, t_stop, bin_size): """ Calculates theta (amount of spikes per bin) and the amount of active spike @@ -218,3 +228,168 @@ def spike_contrast(spiketrains, t_start=None, t_stop=None, return synchrony, spike_contrast_trace return synchrony + + +class Synchrotool(Complexity): + """ + Tool class to find, remove and/or annotate the presence of synchronous + spiking events across multiple spike trains. + + The complexity is used to characterize synchronous events within the same + spike train and across different spike trains in the `spiketrains` list. + This way synchronous events can be found both in multi-unit and + single-unit spike trains. + + This class inherits from :class:`elephant.statistics.Complexity`, see its + documentation for more details and input parameters description. + + See also + -------- + elephant.statistics.Complexity + + """ + + def __init__(self, spiketrains, + sampling_rate, + bin_size=None, + binary=True, + spread=0, + tolerance=1e-8): + + self.annotated = False + + super(Synchrotool, self).__init__(spiketrains=spiketrains, + bin_size=bin_size, + sampling_rate=sampling_rate, + binary=binary, + spread=spread, + tolerance=tolerance) + + def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): + """ + Delete or extract synchronous spiking events. + + Parameters + ---------- + threshold : int + Threshold value for the deletion of spikes engaged in synchronous + activity. + * `deletion_threshold >= 2` leads to all spikes with a larger or + equal complexity value to be deleted/extracted. + * `deletion_threshold <= 1` leads to a ValueError, since this + would delete/extract all spikes and there are definitely more + efficient ways of doing so. + in_place : bool, optional + Determines whether the modification are made in place + on ``self.input_spiketrains``. + Default: False + mode : {'delete', 'extract'}, optional + Inversion of the mask for deletion of synchronous events. + * ``'delete'`` leads to the deletion of all spikes with + complexity >= `threshold`, + i.e. deletes synchronous spikes. + * ``'extract'`` leads to the deletion of all spikes with + complexity < `threshold`, i.e. extracts synchronous spikes. + Default: 'delete' + + Raises + ------ + ValueError + If `mode` is not one in {'delete', 'extract'}. + + If `threshold <= 1`. 
+ + Returns + ------- + list of neo.SpikeTrain + List of spiketrains where the spikes with + ``complexity >= threshold`` have been deleted/extracted. + * If ``in_place`` is True, the returned list is the same as + ``self.input_spiketrains``. + * If ``in_place`` is False, the returned list is a deepcopy of + ``self.input_spiketrains``. + + """ + + if not self.annotated: + self.annotate_synchrofacts() + + if mode not in ['delete', 'extract']: + raise ValueError(f"Invalid mode '{mode}'. Valid modes are: " + f"'delete', 'extract'") + + if threshold <= 1: + raise ValueError('A deletion threshold <= 1 would result ' + 'in the deletion of all spikes.') + + if in_place: + spiketrain_list = self.input_spiketrains + else: + spiketrain_list = deepcopy(self.input_spiketrains) + + for idx, st in enumerate(spiketrain_list): + mask = st.array_annotations['complexity'] < threshold + if mode == 'extract': + mask = np.invert(mask) + new_st = st[mask] + spiketrain_list[idx] = new_st + if in_place: + segment = st.segment + if segment is None: + continue + + # replace link to spiketrain in segment + new_index = self._get_spiketrain_index( + segment.spiketrains, st) + segment.spiketrains[new_index] = new_st + + block = segment.block + if block is None: + continue + + # replace link to spiketrain in groups + for group in block.groups: + try: + idx = self._get_spiketrain_index( + group.spiketrains, + st) + except ValueError: + # st is not in this group, move to next group + continue + + # st found in group, replace with new_st + group.spiketrains[idx] = new_st + + return spiketrain_list + + def annotate_synchrofacts(self): + """ + Annotate the complexity of each spike in the + ``self.epoch.array_annotations`` *in-place*. + """ + epoch_complexities = self.epoch.array_annotations['complexity'] + right_edges = ( + self.epoch.times.magnitude.flatten() + + self.epoch.durations.rescale( + self.epoch.times.units).magnitude.flatten() + ) + + for idx, st in enumerate(self.input_spiketrains): + + # all indices of spikes that are within the half-open intervals + # defined by the boundaries + # note that every second entry in boundaries is an upper boundary + spike_to_epoch_idx = np.searchsorted( + right_edges, + st.times.rescale(self.epoch.times.units).magnitude.flatten()) + complexity_per_spike = epoch_complexities[spike_to_epoch_idx] + + st.array_annotate(complexity=complexity_per_spike) + + self.annotated = True + + def _get_spiketrain_index(self, spiketrain_list, spiketrain): + for index, item in enumerate(spiketrain_list): + if item is spiketrain: + return index + raise ValueError("Spiketrain is not found in the list") diff --git a/elephant/statistics.py b/elephant/statistics.py index bcb41451d..dad28e5d4 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -53,31 +53,30 @@ fanofactor complexity_pdf + Complexity :copyright: Copyright 2014-2020 by the Elephant team, see `doc/authors.rst`. :license: Modified BSD, see LICENSE.txt for details. 
""" from __future__ import division, print_function -# do not import unicode_literals -# (quantities rescale does not work with unicodes) +import math import warnings import neo import numpy as np -import math import quantities as pq -import scipy.signal import scipy.stats -from neo.core import SpikeTrain +import elephant.conversion as conv import elephant.kernels as kernels from elephant.conversion import BinnedSpikeTrain -from elephant.utils import deprecated_alias, get_common_start_stop_times, \ - check_neo_consistency +from elephant.utils import deprecated_alias, check_neo_consistency, \ + is_time_quantity, round_binning_errors -from elephant.utils import is_time_quantity +# do not import unicode_literals +# (quantities rescale does not work with unicodes) __all__ = [ "isi", @@ -90,6 +89,7 @@ "instantaneous_rate", "time_histogram", "complexity_pdf", + "Complexity", "fftkernel", "optimal_kernel_bandwidth" ] @@ -926,7 +926,7 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, # Divide by number of input spike trains and bin width bin_hist = bin_hist / (len(spiketrains) * bin_size) else: - raise ValueError('Parameter output is not valid.') + raise ValueError(f'Parameter output ({output}) is not valid.') return neo.AnalogSignal(signal=np.expand_dims(bin_hist, axis=1), sampling_period=bin_size, units=bin_hist.units, @@ -937,6 +937,9 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, @deprecated_alias(binsize='bin_size') def complexity_pdf(spiketrains, bin_size): """ + Deprecated in favor of the complexity class which has a pdf attribute. + Will be removed in the next release! + Complexity Distribution of a list of `neo.SpikeTrain` objects. Probability density computed from the complexity histogram which is the @@ -974,23 +977,388 @@ def complexity_pdf(spiketrains, bin_size): pp. 96-114, Springer Berlin Heidelberg, 2008. """ - # Computing the population histogram with parameter binary=True to clip the - # spike trains before summing - pophist = time_histogram(spiketrains, bin_size, binary=True) - - # Computing the histogram of the entries of pophist (=Complexity histogram) - complexity_hist = np.histogram( - pophist.magnitude, bins=range(0, len(spiketrains) + 2))[0] - - # Normalization of the Complexity Histogram to 1 (probabilty distribution) - complexity_hist = complexity_hist / complexity_hist.sum() - # Convert the Complexity pdf to an neo.AnalogSignal - complexity_distribution = neo.AnalogSignal( - np.expand_dims(complexity_hist, axis=1) * - pq.dimensionless, t_start=0 * pq.dimensionless, - sampling_period=1 * pq.dimensionless) - - return complexity_distribution + warnings.warn("complexity_pdf is deprecated in favor of the complexity " + "class which has a pdf attribute. complexity_pdf will be " + "removed in the next Elephant release.", DeprecationWarning) + + complexity = Complexity(spiketrains, bin_size=bin_size) + + return complexity.pdf() + + +class Complexity(object): + """ + Class for complexity distribution (i.e. number of synchronous spikes found) + of a list of `neo.SpikeTrain` objects. + + Complexity is calculated by counting the number of spikes (i.e. non-empty + bins) that occur separated by `spread - 1` or less empty bins, within and + across spike trains in the `spiketrains` list. + + Implementation (without spread) is based on [1]_. 
+ + Parameters + ---------- + spiketrains : list of neo.SpikeTrain + Spike trains with a common time axis (same `t_start` and `t_stop`) + sampling_rate : pq.Quantity or None, optional + Sampling rate of the spike trains with units of 1/time. + Used to shift the epoch edges in order to avoid rounding errors. + If None using the epoch to slice spike trains may introduce + rounding errors. + Default: None + bin_size : pq.Quantity or None, optional + Width of the histogram's time bins with units of time. + The user must specify the `bin_size` or the `sampling_rate`. + * If None and the `sampling_rate` is available + 1/`sampling_rate` is used. + * If both are given then `bin_size` is used. + Default: None + binary : bool, optional + * If True then the time histograms will be binary. + * If False the total number of synchronous spikes is counted in the + time histogram. + Default: True + spread : int, optional + Number of bins in which to check for synchronous spikes. + Spikes that occur separated by `spread - 1` or less empty bins are + considered synchronous. + * ``spread = 0`` corresponds to a bincount accross spike trains. + * ``spread = 1`` corresponds to counting consecutive spikes. + * ``spread = 2`` corresponds to counting consecutive spikes and + spikes separated by exactly 1 empty bin. + * ``spread = n`` corresponds to counting spikes separated by exactly + or less than `n - 1` empty bins. + Default: 0 + tolerance : float or None, optional + Tolerance for rounding errors in the binning process and in the input + data. + If None possible binning errors are not accounted for. + Default: 1e-8 + + Attributes + ---------- + epoch : neo.Epoch + An epoch object containing complexity values, left edges and durations + of all intervals with at least one spike. + * ``epoch.array_annotations['complexity']`` contains the + complexity values per spike. + * ``epoch.times`` contains the left edges. + * ``epoch.durations`` contains the durations. + time_histogram : neo.Analogsignal + A `neo.AnalogSignal` object containing the histogram values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. + * If ``binary = True`` : Number of neurons that spiked in each bin, + regardless of the number of spikes. + * If ``binary = False`` : Number of neurons and spikes per neurons + in each bin. + complexity_histogram : np.ndarray + The number of occurrences of events of different complexities. + `complexity_hist[i]` corresponds to the number of events of + complexity `i` for `i > 0`. + + Raises + ------ + ValueError + When `t_stop` is smaller than `t_start`. + + When both `sampling_rate` and `bin_size` are not specified. + + When `spread` is not a positive integer. + + When `spiketrains` is an empty list. + + When `t_start` is not the same for all spiketrains + + When `t_stop` is not the same for all spiketrains + + TypeError + When `spiketrains` is not a list. + + When the elements in `spiketrains` are not instances of neo.SpikeTrain + + Warns + ----- + UserWarning + If no sampling rate is supplied which may lead to rounding errors + when using the epoch to slice spike trains. + + Notes + ----- + * Note that with most common parameter combinations spike times can end up + on bin edges. This makes the binning susceptible to rounding errors which + is accounted for by moving spikes which are within tolerance of the next + bin edge into the following bin. 
This can be adjusted using the tolerance + parameter and turned off by setting `tolerance=None`. + + See also + -------- + elephant.conversion.BinnedSpikeTrain + elephant.spike_train_synchrony.Synchrotool + + References + ---------- + .. [1] S. Gruen, M. Abeles, & M. Diesmann, "Impact of higher-order + correlations on coincidence distributions of massively parallel + data," In "Dynamic Brain - from Neural Spikes to Behaviors", + pp. 96-114, Springer Berlin Heidelberg, 2008. + + Examples + -------- + >>> import neo + >>> import quantities as pq + >>> from elephant.statistics import Complexity + + >>> sr = 1/pq.ms + + >>> st1 = neo.SpikeTrain([1, 4, 6] * pq.ms, t_stop=10.0 * pq.ms) + >>> st2 = neo.SpikeTrain([1, 5, 8] * pq.ms, t_stop=10.0 * pq.ms) + >>> sts = [st1, st2] + + >>> # spread = 0, a simple bincount + >>> cpx = Complexity(sts, sampling_rate=sr) + Complexity calculated at sampling rate precision + >>> print(cpx.complexity_histogram) + [5 4 1] + >>> print(cpx.time_histogram.flatten()) + [0 2 0 0 1 1 1 0 1 0] dimensionless + >>> print(cpx.time_histogram.times) + [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] ms + + >>> # spread = 1, consecutive spikes + >>> cpx = Complexity(sts, sampling_rate=sr, spread=1) + Complexity calculated at sampling rate precision + >>> print(cpx.complexity_histogram) + [5 4 1] + >>> print(cpx.time_histogram.flatten()) + [0 2 0 0 3 3 3 0 1 0] dimensionless + + >>> # spread = 2, consecutive spikes and separated by 1 empty bin + >>> cpx = Complexity(sts, sampling_rate=sr, spread=2) + Complexity calculated at sampling rate precision + >>> print(cpx.complexity_histogram) + [4 0 1 0 1] + >>> print(cpx.time_histogram.flatten()) + [0 2 0 0 4 4 4 4 4 0] dimensionless + """ + + def __init__(self, spiketrains, + sampling_rate=None, + bin_size=None, + binary=True, + spread=0, + tolerance=1e-8): + + check_neo_consistency(spiketrains, object_type=neo.SpikeTrain) + + if bin_size is None and sampling_rate is None: + raise ValueError('No bin_size or sampling_rate was specified!') + + if spread < 0: + raise ValueError('Spread must be >=0') + + self.input_spiketrains = spiketrains + self.t_start = spiketrains[0].t_start + self.t_stop = spiketrains[0].t_stop + self.sampling_rate = sampling_rate + self.bin_size = bin_size + self.binary = binary + self.spread = spread + self.tolerance = tolerance + + if bin_size is None and sampling_rate is not None: + self.bin_size = 1 / self.sampling_rate + + if spread == 0: + self.time_histogram, self.complexity_histogram = \ + self._histogram_no_spread() + self.epoch = self._epoch_no_spread() + else: + self.epoch = self._epoch_with_spread() + self.time_histogram, self.complexity_histogram = \ + self._histogram_with_spread() + + def pdf(self): + """ + Probability density computed from the complexity histogram. + + Returns + ------- + pdf : neo.AnalogSignal + A `neo.AnalogSignal` object containing the pdf values. + `neo.AnalogSignal[j]` is the histogram computed between + `t_start + j * binsize` and `t_start + (j + 1) * binsize`. 
+ """ + norm_hist = self.complexity_histogram / self.complexity_histogram.sum() + # Convert the Complexity pdf to an neo.AnalogSignal + pdf = neo.AnalogSignal( + np.expand_dims(norm_hist, axis=1), + units=pq.dimensionless, + t_start=0 * pq.dimensionless, + sampling_period=1 * pq.dimensionless) + return pdf + + def _histogram_no_spread(self): + """ + Calculate the complexity histogram and time histogram for `spread` = 0 + """ + # Computing the population histogram with parameter binary=True to + # clip the spike trains before summing + time_hist = time_histogram(self.input_spiketrains, + self.bin_size, + binary=self.binary) + + # Computing the histogram of the entries of pophist + complexity_hist = np.histogram( + time_hist.magnitude, + bins=range(0, len(self.input_spiketrains) + 2))[0] + + return time_hist, complexity_hist + + def _histogram_with_spread(self): + """ + Calculate the complexity histogram and time histogram for `spread` > 0 + """ + complexity_hist = np.bincount( + self.epoch.array_annotations['complexity']) + num_bins = (self.t_stop - self.t_start).rescale( + self.bin_size.units).item() / self.bin_size.item() + num_bins = round_binning_errors(num_bins, tolerance=self.tolerance) + time_hist = np.zeros(num_bins, dtype=int) + + start_bins = (self.epoch.times - self.t_start).rescale( + self.bin_size.units).magnitude / self.bin_size.item() + stop_bins = (self.epoch.times + self.epoch.durations - self.t_start + ).rescale(self.bin_size.units + ).magnitude / self.bin_size.item() + + if self.sampling_rate is not None: + shift = (.5 / self.sampling_rate / self.bin_size).simplified.item() + # account for the first bin not being shifted in the epoch creation + # if the shift would move it past t_start + if self.epoch.times[0] == self.t_start: + start_bins[1:] += shift + else: + start_bins += shift + stop_bins += shift + + start_bins = round_binning_errors(start_bins, tolerance=self.tolerance) + stop_bins = round_binning_errors(stop_bins, tolerance=self.tolerance) + + for idx, (start, stop) in enumerate(zip(start_bins, stop_bins)): + time_hist[start:stop] = \ + self.epoch.array_annotations['complexity'][idx] + + time_hist = neo.AnalogSignal( + signal=np.expand_dims(time_hist, axis=1), + sampling_period=self.bin_size, units=pq.dimensionless, + t_start=self.t_start) + + empty_bins = (self.t_stop - self.t_start - self.epoch.durations.sum()) + empty_bins = empty_bins.rescale(self.bin_size.units + ).magnitude / self.bin_size.item() + empty_bins = round_binning_errors(empty_bins, tolerance=self.tolerance) + complexity_hist[0] = empty_bins + + return time_hist, complexity_hist + + def _epoch_no_spread(self): + """ + Get an epoch object of the complexity distribution with `spread` = 0 + """ + left_edges = self.time_histogram.times + durations = self.bin_size * np.ones(self.time_histogram.shape) + + if self.sampling_rate: + # ensure that spikes are not on the bin edges + bin_shift = .5 / self.sampling_rate + left_edges -= bin_shift + + # Ensure that an epoch does not start before the minimum t_start. + # Note: all spike trains share the same t_start and t_stop. + if left_edges[0] < self.t_start: + left_edges[0] = self.t_start + durations[0] -= bin_shift + else: + warnings.warn('No sampling rate specified. 
' + 'Note that using the complexity epoch to get ' + 'precise spike times can lead to rounding errors.') + + epoch = neo.Epoch(left_edges, + durations=durations, + array_annotations={ + 'complexity': + self.time_histogram.magnitude.flatten()}) + return epoch + + def _epoch_with_spread(self): + """ + Get an epoch object of the complexity distribution with `spread` > 0 + """ + bst = conv.BinnedSpikeTrain(self.input_spiketrains, + binsize=self.bin_size, + tolerance=self.tolerance) + + if self.binary: + bst = bst.binarize() + bincount = bst.get_num_of_spikes(axis=0) + + nonzero_indices = np.nonzero(bincount)[0] + left_diff = np.diff(nonzero_indices, + prepend=-self.spread - 1) + right_diff = np.diff(nonzero_indices, + append=len(bincount) + self.spread + 1) + + # standalone bins (no merging required) + single_bin_indices = np.logical_and(left_diff > self.spread, + right_diff > self.spread) + single_bins = nonzero_indices[single_bin_indices] + + # bins separated by fewer than spread bins form clusters + # that have to be merged + cluster_start_indices = np.logical_and(left_diff > self.spread, + right_diff <= self.spread) + cluster_starts = nonzero_indices[cluster_start_indices] + cluster_stop_indices = np.logical_and(left_diff <= self.spread, + right_diff > self.spread) + cluster_stops = nonzero_indices[cluster_stop_indices] + 1 + + single_bin_complexities = bincount[single_bins] + cluster_complexities = [bincount[start:stop].sum() + for start, stop in zip(cluster_starts, + cluster_stops)] + + # merge standalone bins and clusters and sort them + combined_starts = np.concatenate((single_bins, cluster_starts)) + combined_stops = np.concatenate((single_bins + 1, cluster_stops)) + combined_complexities = np.concatenate((single_bin_complexities, + cluster_complexities)) + sorting = np.argsort(combined_starts, kind='mergesort') + left_edges = bst.bin_edges[combined_starts[sorting]] + right_edges = bst.bin_edges[combined_stops[sorting]] + complexities = combined_complexities[sorting].astype(int) + + if self.sampling_rate: + # ensure that spikes are not on the bin edges + bin_shift = .5 / self.sampling_rate + left_edges -= bin_shift + right_edges -= bin_shift + else: + warnings.warn('No sampling rate specified. ' + 'Note that using the complexity epoch to get ' + 'precise spike times can lead to rounding errors.') + + # Ensure that an epoch does not start before the minimum t_start. + # Note: all spike trains share the same t_start and t_stop. 
+ left_edges[0] = max(self.t_start, left_edges[0]) + + complexity_epoch = neo.Epoch(times=left_edges, + durations=right_edges - left_edges, + array_annotations={'complexity': + complexities}) + + return complexity_epoch """ diff --git a/elephant/test/test_spike_train_synchrony.py b/elephant/test/test_spike_train_synchrony.py index ee840bb07..ee7e1e374 100644 --- a/elephant/test/test_spike_train_synchrony.py +++ b/elephant/test/test_spike_train_synchrony.py @@ -5,15 +5,17 @@ import neo import numpy as np -from numpy.testing import assert_array_equal +import quantities as pq +from numpy.testing import assert_array_almost_equal, assert_array_equal from quantities import Hz, ms, second -import elephant.spike_train_synchrony as spc import elephant.spike_train_generation as stgen +from elephant.spike_train_synchrony import Synchrotool, spike_contrast, \ + _get_theta_and_n_per_bin, _binning_half_overlap from elephant.test.download import download, unzip -class TestUM(unittest.TestCase): +class TestSpikeContrast(unittest.TestCase): def test_spike_contrast_random(self): # randomly generated spiketrains that share the same t_start and @@ -39,7 +41,7 @@ def test_spike_contrast_random(self): t_stop=10000. * ms) spike_trains = [spike_train_1, spike_train_2, spike_train_3, spike_train_4, spike_train_5, spike_train_6] - synchrony = spc.spike_contrast(spike_trains) + synchrony = spike_contrast(spike_trains) self.assertAlmostEqual(synchrony, 0.2098687, places=6) def test_spike_contrast_same_signal(self): @@ -48,7 +50,7 @@ def test_spike_contrast_same_signal(self): t_start=0. * ms, t_stop=10000. * ms) spike_trains = [spike_train, spike_train] - synchrony = spc.spike_contrast(spike_trains, min_bin=1 * ms) + synchrony = spike_contrast(spike_trains, min_bin=1 * ms) self.assertEqual(synchrony, 1.0) def test_spike_contrast_double_duration(self): @@ -64,7 +66,7 @@ def test_spike_contrast_double_duration(self): t_stop=10000. * ms) spike_trains = [spike_train_1, spike_train_2, spike_train_3] - synchrony = spc.spike_contrast(spike_trains, t_stop=20000 * ms) + synchrony = spike_contrast(spike_trains, t_stop=20000 * ms) self.assertEqual(synchrony, 0.5) def test_spike_contrast_non_overlapping_spiketrains(self): @@ -76,7 +78,7 @@ def test_spike_contrast_non_overlapping_spiketrains(self): t_start=5000. * ms, t_stop=10000. * ms) spiketrains = [spike_train_1, spike_train_2] - synchrony = spc.spike_contrast(spiketrains, t_stop=5000 * ms) + synchrony = spike_contrast(spiketrains, t_stop=5000 * ms) # the synchrony of non-overlapping spiketrains must be zero self.assertEqual(synchrony, 0.) @@ -86,7 +88,7 @@ def test_spike_contrast_trace(self): t_stop=1000. * ms) spike_train_2 = stgen.homogeneous_poisson_process(rate=20 * Hz, t_stop=1000. 
* ms) - synchrony, trace = spc.spike_contrast([spike_train_1, spike_train_2], + synchrony, trace = spike_contrast([spike_train_1, spike_train_2], return_trace=True) self.assertEqual(synchrony, max(trace.synchrony)) self.assertEqual(len(trace.contrast), len(trace.active_spiketrains)) @@ -94,28 +96,28 @@ def test_spike_contrast_trace(self): def test_invalid_data(self): # invalid spiketrains - self.assertRaises(TypeError, spc.spike_contrast, [[0, 1], [1.5, 2.3]]) - self.assertRaises(ValueError, spc.spike_contrast, + self.assertRaises(TypeError, spike_contrast, [[0, 1], [1.5, 2.3]]) + self.assertRaises(ValueError, spike_contrast, [neo.SpikeTrain([10] * ms, t_stop=1000 * ms), neo.SpikeTrain([20] * ms, t_stop=1000 * ms)]) # a single spiketrain spiketrain_valid = neo.SpikeTrain([0, 1000] * ms, t_stop=1000 * ms) - self.assertRaises(ValueError, spc.spike_contrast, [spiketrain_valid]) + self.assertRaises(ValueError, spike_contrast, [spiketrain_valid]) spiketrain_valid2 = neo.SpikeTrain([500, 800] * ms, t_stop=1000 * ms) spiketrains = [spiketrain_valid, spiketrain_valid2] # invalid shrink factor - self.assertRaises(ValueError, spc.spike_contrast, spiketrains, + self.assertRaises(ValueError, spike_contrast, spiketrains, bin_shrink_factor=0.) # invalid t_start, t_stop, and min_bin - self.assertRaises(TypeError, spc.spike_contrast, spiketrains, + self.assertRaises(TypeError, spike_contrast, spiketrains, t_start=0) - self.assertRaises(TypeError, spc.spike_contrast, spiketrains, + self.assertRaises(TypeError, spike_contrast, spiketrains, t_stop=1000) - self.assertRaises(TypeError, spc.spike_contrast, spiketrains, + self.assertRaises(TypeError, spike_contrast, spiketrains, min_bin=0.01) def test_t_start_agnostic(self): @@ -126,7 +128,7 @@ def test_t_start_agnostic(self): spike_train_2 = stgen.homogeneous_poisson_process(rate=10 * Hz, t_stop=t_stop) spiketrains = [spike_train_1, spike_train_2] - synchrony_target = spc.spike_contrast(spiketrains) + synchrony_target = spike_contrast(spiketrains) # a check for developer: test meaningful result assert synchrony_target > 0 t_shift = 20 * second @@ -136,7 +138,7 @@ def test_t_start_agnostic(self): t_stop=t_stop + t_shift) for st in spiketrains ] - synchrony = spc.spike_contrast(spiketrains_shifted) + synchrony = spike_contrast(spiketrains_shifted) self.assertAlmostEqual(synchrony, synchrony_target) def test_get_theta_and_n_per_bin(self): @@ -145,7 +147,7 @@ def test_get_theta_and_n_per_bin(self): [1, 2, 3, 9], [1, 2, 2.5] ] - theta, n = spc._get_theta_and_n_per_bin(spike_trains, + theta, n = _get_theta_and_n_per_bin(spike_trains, t_start=0, t_stop=10, bin_size=5) @@ -158,7 +160,7 @@ def test_binning_half_overlap(self): t_start = 0 t_stop = 10 edges = np.arange(t_start, t_stop + bin_step, bin_step) - histogram = spc._binning_half_overlap(spiketrain, edges=edges) + histogram = _binning_half_overlap(spiketrain, edges=edges) assert_array_equal(histogram, [3, 1, 1]) def test_spike_contrast_with_Izhikevich_network_auto(self): @@ -190,9 +192,256 @@ def test_spike_contrast_with_Izhikevich_network_auto(self): neo.SpikeTrain(st, t_start=0 * second, t_stop=2 * second, units=second) for st in simulation['spiketrains']] - synchrony = spc.spike_contrast(spiketrains) + synchrony = spike_contrast(spiketrains) self.assertAlmostEqual(synchrony, synchrony_true, places=2) +class SynchrofactDetectionTestCase(unittest.TestCase): + + def _test_template(self, spiketrains, correct_complexities, sampling_rate, + spread, deletion_threshold=2, mode='delete', + in_place=False, binary=True): 
+ + synchrofact_obj = Synchrotool( + spiketrains, + sampling_rate=sampling_rate, + binary=binary, + spread=spread) + + # test annotation + synchrofact_obj.annotate_synchrofacts() + + annotations = [st.array_annotations['complexity'] + for st in spiketrains] + + assert_array_equal(annotations, correct_complexities) + + if mode == 'extract': + correct_spike_times = [ + spikes[mask] for spikes, mask + in zip(spiketrains, + correct_complexities >= deletion_threshold) + ] + else: + correct_spike_times = [ + spikes[mask] for spikes, mask + in zip(spiketrains, + correct_complexities < deletion_threshold) + ] + + # test deletion + synchrofact_obj.delete_synchrofacts(threshold=deletion_threshold, + in_place=in_place, + mode=mode) + + cleaned_spike_times = [st.times for st in spiketrains] + + for correct_st, cleaned_st in zip(correct_spike_times, + cleaned_spike_times): + assert_array_almost_equal(cleaned_st, correct_st) + + def test_no_synchrofacts(self): + + # nothing to find here + # there used to be an error for spread > 0 when nothing was found + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 9, 12, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([3, 7, 15, 17] * pq.s, + t_stop=20*pq.s)] + + correct_annotations = np.array([[1, 1, 1, 1], + [1, 1, 1, 1]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, mode='delete', + deletion_threshold=2) + + def test_spread_0(self): + + # basic test with a minimum number of two spikes per synchrofact + # only taking into account multiple spikes + # within one bin of size 1 / sampling_rate + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, + t_stop=20*pq.s)] + + correct_annotations = np.array([[2, 1, 1, 1, 2, 1], + [2, 1, 1, 1, 2, 1]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=0, mode='delete', in_place=True, + deletion_threshold=2) + + def test_spread_1(self): + + # test synchrofact search taking into account adjacent bins + # this requires an additional loop with shifted binning + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_annotations = np.array([[2, 2, 1, 3, 3, 1], + [2, 2, 1, 3, 1, 1]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, mode='delete', in_place=True, + deletion_threshold=2) + + def test_n_equals_3(self): + + # test synchrofact detection with a minimum number of + # three spikes per synchrofact + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 1, 5, 10, 13, 16, 17, 19] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 9, 12, 14, 16, 20] * pq.s, + t_stop=21*pq.s)] + + correct_annotations = np.array([[3, 3, 2, 2, 3, 3, 3, 2], + [3, 2, 1, 2, 3, 3, 3, 2]]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=1, mode='delete', binary=False, + in_place=True, deletion_threshold=3) + + def test_extract(self): + + # test synchrofact search taking into account adjacent bins + # this requires an additional loop with shifted binning + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_annotations = np.array([[2, 2, 1, 3, 3, 1], + [2, 2, 1, 3, 1, 1]]) + + self._test_template(spiketrains, 
correct_annotations, sampling_rate, + spread=1, mode='extract', in_place=True, + deletion_threshold=2) + + def test_binning_for_input_with_rounding_errors(self): + + # a test with inputs divided by 30000 which leads to rounding errors + # these errors have to be accounted for by proper binning; + # check if we still get the correct result + + sampling_rate = 30000 / pq.s + + spiketrains = [neo.SpikeTrain(np.arange(1000) * pq.s / 30000, + t_stop=.1 * pq.s), + neo.SpikeTrain(np.arange(2000, step=2) * pq.s / 30000, + t_stop=.1 * pq.s)] + + first_annotations = np.ones(1000) + first_annotations[::2] = 2 + + second_annotations = np.ones(1000) + second_annotations[:500] = 2 + + correct_annotations = np.array([first_annotations, + second_annotations]) + + self._test_template(spiketrains, correct_annotations, sampling_rate, + spread=0, mode='delete', in_place=True, + deletion_threshold=2) + + def test_correct_transfer_of_spiketrain_attributes(self): + + # for delete=True the spiketrains in the block are changed, + # test if their attributes remain correct + + sampling_rate = 1 / pq.s + + spiketrain = neo.SpikeTrain([1, 1, 5, 0] * pq.s, + t_stop=10 * pq.s) + + block = neo.Block() + + group = neo.Group(name='Test Group') + block.groups.append(group) + group.spiketrains.append(spiketrain) + + segment = neo.Segment() + block.segments.append(segment) + segment.block = block + segment.spiketrains.append(spiketrain) + spiketrain.segment = segment + + spiketrain.annotate(cool_spike_train=True) + spiketrain.array_annotate( + spike_number=np.arange(len(spiketrain.times.magnitude))) + spiketrain.waveforms = np.sin( + np.arange(len(spiketrain.times.magnitude))[:, np.newaxis] + + np.arange(len(spiketrain.times.magnitude))[np.newaxis, :]) + + correct_mask = np.array([False, False, True, True]) + + # store the correct attributes + correct_annotations = spiketrain.annotations.copy() + correct_waveforms = spiketrain.waveforms[correct_mask].copy() + correct_array_annotations = {key: value[correct_mask] for key, value in + spiketrain.array_annotations.items()} + + # perform a synchrofact search with delete=True + synchrofact_obj = Synchrotool( + [spiketrain], + spread=0, + sampling_rate=sampling_rate, + binary=False) + synchrofact_obj.delete_synchrofacts( + mode='delete', + in_place=True, + threshold=2) + + # Ensure that the spiketrain was not duplicated + self.assertEqual(len(block.filter(objects=neo.SpikeTrain)), 1) + + cleaned_spiketrain = segment.spiketrains[0] + + # Ensure that the spiketrain is also in the group + self.assertEqual(len(block.groups[0].spiketrains), 1) + self.assertIs(block.groups[0].spiketrains[0], cleaned_spiketrain) + + cleaned_annotations = cleaned_spiketrain.annotations + cleaned_waveforms = cleaned_spiketrain.waveforms + cleaned_array_annotations = cleaned_spiketrain.array_annotations + cleaned_array_annotations.pop('complexity') + + self.assertDictEqual(correct_annotations, cleaned_annotations) + assert_array_almost_equal(cleaned_waveforms, correct_waveforms) + self.assertTrue(len(cleaned_array_annotations) + == len(correct_array_annotations)) + for key, value in correct_array_annotations.items(): + self.assertTrue(key in cleaned_array_annotations.keys()) + assert_array_almost_equal(value, cleaned_array_annotations[key]) + + def test_wrong_input_errors(self): + synchrofact_obj = Synchrotool( + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], + sampling_rate=1/pq.s, + binary=True, + spread=1) + self.assertRaises(ValueError, + synchrofact_obj.delete_synchrofacts, + -1) + + if __name__ == 
'__main__': unittest.main() diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index ede8c0d32..d6b727262 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -441,10 +441,8 @@ def test_lvr_refractoriness_kwarg(self): def test_2short_spike_train(self): seq = [1] with self.assertWarns(UserWarning): - """ - Catches UserWarning: Input size is too small. Please provide - an input with more than 1 entry. - """ + # Catches UserWarning: Input size is too small. Please provide + # an input with more than 1 entry. self.assertTrue(math.isnan(statistics.lvr(seq, with_nan=True))) @@ -593,8 +591,8 @@ def test_rate_estimation_consistency(self): center_kernel=center_kernel) num_spikes = len(self.spike_train) auc = spint.cumtrapz( - y=rate_estimate.magnitude.squeeze(), - x=rate_estimate.times.simplified.magnitude)[-1] + y=rate_estimate.magnitude[:, 0], + x=rate_estimate.times.rescale('s').magnitude)[-1] self.assertAlmostEqual(num_spikes, auc, delta=0.01 * num_spikes) @@ -922,32 +920,187 @@ def test_annotations(self): class ComplexityPdfTestCase(unittest.TestCase): - def setUp(self): - self.spiketrain_a = neo.SpikeTrain( + def test_complexity_pdf_deprecated(self): + spiketrain_a = neo.SpikeTrain( [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) - self.spiketrain_b = neo.SpikeTrain( + spiketrain_b = neo.SpikeTrain( [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) - self.spiketrain_c = neo.SpikeTrain( + spiketrain_c = neo.SpikeTrain( [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) - self.spiketrains = [ - self.spiketrain_a, self.spiketrain_b, self.spiketrain_c] - - def tearDown(self): - del self.spiketrain_a - self.spiketrain_a = None - del self.spiketrain_b - self.spiketrain_b = None - - def test_complexity_pdf(self): + spiketrains = [ + spiketrain_a, spiketrain_b, spiketrain_c] + # runs the previous function which will be deprecated targ = np.array([0.92, 0.01, 0.01, 0.06]) - complexity = statistics.complexity_pdf(self.spiketrains, - bin_size=0.1 * pq.s) + complexity = statistics.complexity_pdf(spiketrains, binsize=0.1*pq.s) assert_array_equal(targ, complexity.magnitude[:, 0]) self.assertEqual(1, complexity.magnitude[:, 0].sum()) - self.assertEqual(len(self.spiketrains) + 1, len(complexity)) + self.assertEqual(len(spiketrains)+1, len(complexity)) self.assertIsInstance(complexity, neo.AnalogSignal) self.assertEqual(complexity.units, 1 * pq.dimensionless) + def test_complexity_pdf(self): + spiketrain_a = neo.SpikeTrain( + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 6.7] * pq.s, t_stop=10.0 * pq.s) + spiketrain_b = neo.SpikeTrain( + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) + spiketrain_c = neo.SpikeTrain( + [0.5, 0.7, 1.2, 2.3, 4.3, 5.5, 8.0] * pq.s, t_stop=10.0 * pq.s) + spiketrains = [ + spiketrain_a, spiketrain_b, spiketrain_c] + # runs the previous function which will be deprecated + targ = np.array([0.92, 0.01, 0.01, 0.06]) + complexity_obj = statistics.Complexity(spiketrains, + bin_size=0.1 * pq.s) + pdf = complexity_obj.pdf() + assert_array_equal(targ, complexity_obj.pdf().magnitude[:, 0]) + self.assertEqual(1, pdf.magnitude[:, 0].sum()) + self.assertEqual(len(spiketrains)+1, len(pdf)) + self.assertIsInstance(pdf, neo.AnalogSignal) + self.assertEqual(pdf.units, 1*pq.dimensionless) + + def test_complexity_histogram_spread_0(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * 
pq.s, + t_stop=20*pq.s)] + + correct_histogram = np.array([10, 8, 2]) + + correct_time_histogram = np.array([0, 2, 0, 0, 1, 1, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 2, 0, 1, 1]) + + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=0) + + assert_array_equal(complexity_obj.complexity_histogram, + correct_histogram) + + assert_array_equal( + complexity_obj.time_histogram.magnitude.flatten().astype(int), + correct_time_histogram) + + def test_complexity_epoch_spread_0(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 16, 19] * pq.s, + t_stop=20*pq.s), + neo.SpikeTrain([1, 4, 8, 12, 16, 18] * pq.s, + t_stop=20*pq.s)] + + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=0) + + self.assertIsInstance(complexity_obj.epoch, + neo.Epoch) + + def test_complexity_histogram_spread_1(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([0, 1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_histogram = np.array([9, 5, 1, 2]) + + correct_time_histogram = np.array([3, 3, 0, 0, 2, 2, 0, 1, 0, 1, 0, + 3, 3, 3, 0, 0, 1, 0, 1, 0, 1]) + + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=1) + + assert_array_equal(complexity_obj.complexity_histogram, + correct_histogram) + + assert_array_equal( + complexity_obj.time_histogram.magnitude.flatten().astype(int), + correct_time_histogram) + + def test_complexity_histogram_spread_2(self): + + sampling_rate = 1 / pq.s + + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + correct_histogram = np.array([5, 0, 1, 1, 0, 0, 0, 1]) + + correct_time_histogram = np.array([0, 2, 0, 0, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 0, 0, 3, 3, 3, 3, 3]) + + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, + spread=2) + + assert_array_equal(complexity_obj.complexity_histogram, + correct_histogram) + + assert_array_equal( + complexity_obj.time_histogram.magnitude.flatten().astype(int), + correct_time_histogram) + + def test_wrong_input_errors(self): + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + self.assertRaises(ValueError, + statistics.Complexity, + spiketrains) + + self.assertRaises(ValueError, + statistics.Complexity, + spiketrains, + sampling_rate=1*pq.s, + spread=-7) + + def test_sampling_rate_warning(self): + spiketrains = [neo.SpikeTrain([1, 5, 9, 11, 13, 20] * pq.s, + t_stop=21*pq.s), + neo.SpikeTrain([1, 4, 7, 12, 16, 18] * pq.s, + t_stop=21*pq.s)] + + with self.assertWarns(UserWarning): + statistics.Complexity(spiketrains, + bin_size=1*pq.s, + spread=1) + + def test_binning_for_input_with_rounding_errors(self): + + # a test with inputs divided by 30000 which leads to rounding errors + # these errors have to be accounted for by proper binning; + # check if we still get the correct result + + sampling_rate = 333 / pq.s + + spiketrains = [neo.SpikeTrain(np.arange(1000, step=2) * pq.s / 333, + t_stop=30.33333333333 * pq.s), + neo.SpikeTrain(np.arange(2000, step=4) * pq.s / 333, + t_stop=30.33333333333 * pq.s)] + + correct_time_histogram = np.zeros(10101) + correct_time_histogram[:1000:2] = 1 + correct_time_histogram[:2000:4] += 1 + + complexity_obj = statistics.Complexity(spiketrains, + sampling_rate=sampling_rate, 
+ spread=1) + + assert_array_equal( + complexity_obj.time_histogram.magnitude.flatten().astype(int), + correct_time_histogram) + if __name__ == '__main__': unittest.main() diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py new file mode 100644 index 000000000..5db79bd23 --- /dev/null +++ b/elephant/test/test_utils.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +""" +Unit tests for the synchrofact detection app +""" + +import unittest + +import neo +import numpy as np +import quantities as pq + +from elephant import utils +from numpy.testing import assert_array_equal + + +class TestUtils(unittest.TestCase): + + def test_check_neo_consistency(self): + self.assertRaises(TypeError, + utils.check_neo_consistency, + [], object_type=neo.SpikeTrain) + self.assertRaises(TypeError, + utils.check_neo_consistency, + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), + np.arange(2)], object_type=neo.SpikeTrain) + self.assertRaises(ValueError, + utils.check_neo_consistency, + [neo.SpikeTrain([1]*pq.s, + t_start=1*pq.s, + t_stop=2*pq.s), + neo.SpikeTrain([1]*pq.s, + t_start=0*pq.s, + t_stop=2*pq.s)], + object_type=neo.SpikeTrain) + self.assertRaises(ValueError, + utils.check_neo_consistency, + [neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s), + neo.SpikeTrain([1]*pq.s, t_stop=3*pq.s)], + object_type=neo.SpikeTrain) + self.assertRaises(ValueError, + utils.check_neo_consistency, + [neo.SpikeTrain([1]*pq.ms, t_stop=2000*pq.ms), + neo.SpikeTrain([1]*pq.s, t_stop=2*pq.s)], + object_type=neo.SpikeTrain) + + def test_round_binning_errors(self): + with self.assertWarns(UserWarning): + n_bins = utils.round_binning_errors(0.999999, tolerance=1e-6) + self.assertEqual(n_bins, 1) + self.assertEqual(utils.round_binning_errors(0.999999, tolerance=None), + 0) + array = np.array([0, 0.7, 1 - 1e-8, 1 - 1e-9]) + with self.assertWarns(UserWarning): + corrected = utils.round_binning_errors(array.copy()) + assert_array_equal(corrected, [0, 0, 1, 1]) + assert_array_equal( + utils.round_binning_errors(array.copy(), tolerance=None), + [0, 0, 0, 0]) + + +if __name__ == '__main__': + unittest.main() diff --git a/elephant/utils.py b/elephant/utils.py index dcd1bb916..eb226518c 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -1,3 +1,13 @@ +""" +.. autosummary:: + :toctree: toctree/utils + + is_time_quantity + get_common_start_stop_times + check_neo_consistency + round_binning_errors +""" + from __future__ import division, print_function, unicode_literals import warnings @@ -96,7 +106,7 @@ def is_time_quantity(x, allow_none=False): def get_common_start_stop_times(neo_objects): """ - Extracts the `t_start`and the `t_stop` from the input neo objects. + Extracts the common `t_start` and the `t_stop` from the input neo objects. If a single neo object is given, its `t_start` and `t_stop` is returned. Otherwise, the aligned times are returned: the maximal `t_start` and @@ -136,7 +146,7 @@ def get_common_start_stop_times(neo_objects): def check_neo_consistency(neo_objects, object_type, t_start=None, - t_stop=None, tolerance=1e-6): + t_stop=None, tolerance=1e-8): """ Checks that all input neo objects share the same units, t_start, and t_stop. @@ -167,9 +177,8 @@ def check_neo_consistency(neo_objects, object_type, t_start=None, units = neo_objects[0].units start = neo_objects[0].t_start.item() stop = neo_objects[0].t_stop.item() - except AttributeError: - raise TypeError("The input must be a list of {}. 
Got {}".format( - object_type.__name__, type(neo_objects[0]).__name__)) + except (IndexError, AttributeError): + raise TypeError(f"The input must be a list of {object_type.__name__}") if tolerance is None: tolerance = 0 for neo_obj in neo_objects: @@ -214,3 +223,56 @@ def check_same_units(quantities, object_type=pq.Quantity): raise ValueError("The input quantities must have the same units, " "which is achieved with object.rescale('ms') " "operation.") + + +def round_binning_errors(values, tolerance=1e-8): + """ + Round the input `values` in-place due to the machine floating point + precision errors. + + Parameters + ---------- + values : np.ndarray or float + An input array or a scalar. + tolerance : float or None, optional + The precision error absolute tolerance; acts as ``atol`` in + :func:`numpy.isclose` function. If None, no rounding is performed. + Default: 1e-8 + + Returns + ------- + values : np.ndarray or int + Corrected integer values. + + Examples + -------- + >>> from elephant.utils import round_binning_errors + >>> round_binning_errors(0.999999, tolerance=None) + 0 + >>> round_binning_errors(0.999999, tolerance=1e-6) + 1 + """ + if tolerance is None or tolerance == 0: + if isinstance(values, np.ndarray): + return values.astype(np.int32) + return int(values) # a scalar + + # same as '1 - (values % 1) <= tolerance' but faster + correction_mask = 1 - tolerance <= values % 1 + if isinstance(values, np.ndarray): + num_corrections = correction_mask.sum() + if num_corrections > 0: + warnings.warn(f'Correcting {num_corrections} rounding errors by ' + f'shifting the affected spikes into the following ' + f'bin. You can set tolerance=None to disable this ' + 'behaviour.') + values[correction_mask] += 0.5 + return values.astype(np.int32) + + if correction_mask: + warnings.warn('Correcting a rounding error in the calculation ' + 'of the number of bins by incrementing the value by 1. ' + 'You can set tolerance=None to disable this ' + 'behaviour.') + values += 0.5 + return int(values) diff --git a/requirements/requirements-docs.txt b/requirements/requirements-docs.txt index 2b8e828ca..2e7d10072 100644 --- a/requirements/requirements-docs.txt +++ b/requirements/requirements-docs.txt @@ -6,3 +6,4 @@ nbsphinx>=0.8.0 sphinxcontrib-bibtex>=1.0.0 sphinx-tabs>=1.3.0 matplotlib>=3.3.2 +# conda install -c conda-forge pandoc