def _ensure_sorted(stream_id, stream):
    """Sort a stream's samples by time stamp if any arrived out of order.

    The stream is modified in place: ``stream.time_stamps`` and
    ``stream.time_series`` are reordered together so every sample keeps its
    own time stamp. Samples sharing a time stamp keep their original relative
    order (stable sort); they are only reported, not altered, so that later
    processing (e.g. dejittering) can deal with them.

    Args:
        stream_id: Numeric id of the stream, used in log messages only.
        stream: Object with ``time_stamps`` (1-D array), ``time_series``
            (array indexed by sample, or a list of samples when
            ``stream.fmt == "string"``) and ``fmt`` attributes.

    Returns:
        The same ``stream`` object that was passed in.
    """
    time_stamps = np.asarray(stream.time_stamps)
    if time_stamps.size < 2:
        # Zero or one sample: nothing can be out of order or duplicated.
        return stream

    diffs = np.diff(time_stamps)
    # Only strictly decreasing steps require a sort; equal time stamps are
    # already in a valid (non-decreasing) order.
    out_of_order_count = np.count_nonzero(diffs < 0)
    if out_of_order_count > 0:
        msg = " stream %d not monotonic %d sample(s) out-of-order. Sorting..."
        logger.info(msg, stream_id, out_of_order_count)
        ind = np.argsort(time_stamps, kind="stable")
        time_stamps = time_stamps[ind]
        stream.time_stamps = time_stamps
        if stream.fmt == "string":
            # Reorder the Python list directly. Converting to a NumPy string
            # array would pad every sample to the longest string and strip
            # trailing NUL characters from the payload.
            samples = stream.time_series
            stream.time_series = [samples[i] for i in ind]
        else:
            stream.time_series = stream.time_series[ind]
        # Duplicates may only become adjacent after sorting, so the
        # differences have to be recomputed on the sorted time stamps.
        diffs = np.diff(time_stamps)

    identical_timestamp_count = np.count_nonzero(diffs == 0)
    if identical_timestamp_count > 0:
        msg = " stream %d contains %d identical timestamp(s)."
        logger.info(msg, stream_id, identical_timestamp_count)
    return stream
def test_detect_breaks_reverse():
    """A fully reversed stream is sorted first and then shows no breaks."""
    descending = list(range(4, -6, -1))
    expected_series = sorted(descending)
    stream = _ensure_sorted(1, MockStreamData(descending, 1))
    # The samples must have been reordered along with their time stamps.
    assert np.all(stream.time_series == expected_series)
    # Unit steps never exceed the 1 s threshold -> no breaks.
    found = _detect_breaks(stream, threshold_seconds=1, threshold_samples=0)
    assert found.size == 0


def test_detect_breaks_non_monotonic_num():
    """Numeric stream with one swapped pair and one duplicated time stamp."""
    timestamps = [-4, -5, -3, -2, 0, 0, 1, 2]
    expected_series = sorted(timestamps)
    stream = _ensure_sorted(1, MockStreamData(timestamps, 1))
    # The samples must have been reordered along with their time stamps.
    assert np.all(stream.time_series == expected_series)

    # diff > 1 s and > 1 * tdiff: only the jump from -2 to 0 qualifies.
    found = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1)
    assert found.size == 1
    assert found[0] == 4

    # diff > 2 s and > 2 * tdiff: nothing qualifies.
    found = _detect_breaks(stream, threshold_seconds=2, threshold_samples=2)
    assert found.size == 0

    # diff > 0.1 s: every step except the duplicated time stamp.
    found = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0)
    assert found.size == len(timestamps) - 2
    assert list(found) == [1, 2, 3, 4, 6, 7]


def test_detect_breaks_non_monotonic_str():
    """Same scenario as the numeric test, for a string-format stream."""
    timestamps = [-4, -5, -3, -2, 0, 0, 1, 2]
    expected_series = [str(value) for value in sorted(timestamps)]
    stream = _ensure_sorted(1, MockStreamData(timestamps, 1, "string"))
    # The samples must have been reordered along with their time stamps.
    assert np.all(stream.time_series == expected_series)

    # diff > 1 s and > 1 * tdiff: only the jump from -2 to 0 qualifies.
    found = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1)
    assert found.size == 1
    assert found[0] == 4

    # diff > 2 s and > 2 * tdiff: nothing qualifies.
    found = _detect_breaks(stream, threshold_seconds=2, threshold_samples=2)
    assert found.size == 0

    # diff > 0.1 s: every step except the duplicated time stamp.
    found = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0)
    assert found.size == len(timestamps) - 2
    assert list(found) == [1, 2, 3, 4, 6, 7]