Skip to content

Commit

Permalink
Ensure data is sorted before sync/dejitter
Browse files Browse the repository at this point in the history
* Revert to existing dejitter segmentation logic

* Handles cases where data is recorded out-of-order
  • Loading branch information
jamieforth committed Dec 9, 2024
1 parent deb4c54 commit 8e5a76d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 29 deletions.
29 changes: 25 additions & 4 deletions src/pyxdf/pyxdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,14 +334,18 @@ def load_xdf(
f.read(chunklen - 2)

# Concatenate the signal across chunks
for stream in temp.values():
for stream_id, stream in temp.items():
if stream.time_stamps:
# stream with non-empty list of chunks
stream.time_stamps = np.concatenate(stream.time_stamps)
if stream.fmt == "string":
stream.time_series = list(itertools.chain(*stream.time_series))
else:
stream.time_series = np.concatenate(stream.time_series)
# Handle samples that may have arrived out-of-order, sorting
# data by ground truth timestamps if necessary. Identical
# timestamps will remain, but can be handled by dejittering.
stream = _ensure_sorted(stream_id, stream)
else:
# stream without any chunks
stream.time_stamps = np.zeros((0,))
Expand Down Expand Up @@ -534,6 +538,25 @@ def _scan_forward(f):
return False


def _ensure_sorted(stream_id, stream):
diffs = np.diff(stream.time_stamps)
non_strict_inc_count = np.sum(diffs <= 0)
if non_strict_inc_count > 0:
msg = " stream %d not monotonic %d sample(s) out-of-order. Sorting..."
logger.info(msg, stream_id, non_strict_inc_count)
ind = np.argsort(stream.time_stamps, kind="stable")
stream.time_stamps = stream.time_stamps[ind]
if stream.fmt == "string":
stream.time_series = np.array(stream.time_series)[ind].tolist()
else:
stream.time_series = stream.time_series[ind]
identical_timestamp_count = len(diffs) - np.count_nonzero(diffs)
if identical_timestamp_count > 0:
msg = " stream %d contains %d identical timestamp(s)."
logger.info(msg, stream_id, identical_timestamp_count)
return stream


def _clock_sync(
streams,
handle_clock_resets=True,
Expand Down Expand Up @@ -629,9 +652,7 @@ def _detect_breaks(stream, threshold_seconds=1.0, threshold_samples=500):
"""Detect breaks in the time_stamps of a stream."""
# Identify breaks in the time_stamps
diffs = np.diff(stream.time_stamps)
b_breaks = (diffs <= 0) | (
diffs > np.max((threshold_seconds, threshold_samples * stream.tdiff))
)
b_breaks = diffs > np.max((threshold_seconds, threshold_samples * stream.tdiff))
# find indices (+ 1 to compensate for lost sample in np.diff)
break_inds = np.where(b_breaks)[0] + 1
return break_inds
Expand Down
75 changes: 50 additions & 25 deletions test/test_jitter_removal.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from pyxdf.pyxdf import _detect_breaks
import numpy as np
from pyxdf.pyxdf import _detect_breaks, _ensure_sorted


class MockStreamData:
def __init__(self, time_stamps, tdiff):
self.time_stamps = time_stamps
def __init__(self, time_stamps, tdiff, fmt="float32"):
self.time_stamps = np.array(time_stamps)
self.tdiff = tdiff
self.fmt = fmt
if fmt == "string":
self.time_series = [str(x) for x in time_stamps]
else:
self.time_series = np.array(time_stamps, dtype=fmt)


# Monotonic timeseries data.


def test_detect_no_breaks_seconds():
Expand Down Expand Up @@ -33,7 +42,7 @@ def test_detect_breaks_seconds():
timestamps = list(range(-5, 5, 2))
stream = MockStreamData(timestamps, 1)
# if diff > 1 and larger 0 * tdiff -> 4
breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0)
breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=0)
assert breaks.size == len(timestamps) - 1


Expand Down Expand Up @@ -75,37 +84,53 @@ def test_detect_breaks_gap_in_positive():
assert breaks.size == len(timestamps) - 1


# Non-monotonic timeseries data.


def test_detect_breaks_reverse():
timestamps = list(reversed(range(-5, 5)))
stream = MockStreamData(timestamps, 1)
# if diff <= 0 -> 9
breaks = _detect_breaks(stream, threshold_seconds=0, threshold_samples=0)
assert breaks.size == len(timestamps) - 1
stream = _ensure_sorted(1, stream)
# Timeseries should now also be sorted.
assert np.all(stream.time_series == sorted(timestamps))
# if diff > 1 and larger 0 * tdiff -> 0
breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=0)
assert breaks.size == 0


def test_detect_breaks_gaps_non_monotonic():
timestamps = [-4, 1, -3, -2, -1, 1, 5, 1, 2]
def test_detect_breaks_non_monotonic_num():
timestamps = [-4, -5, -3, -2, 0, 0, 1, 2]
stream = MockStreamData(timestamps, 1)
# if diff <= 0 or diff > 1 and larger 1 * tdiff -> 5
stream = _ensure_sorted(1, stream)
# Timeseries data should now also be sorted.
assert np.all(stream.time_series == sorted(timestamps))
# if diff > 1 and larger 1 * tdiff -> 1
breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1)
assert list(breaks) == [1, 2, 5, 6, 7]
# if diff <= 0 or diff > 2 and larger 1 * tdiff -> 4
breaks = _detect_breaks(stream, threshold_seconds=2, threshold_samples=1)
assert list(breaks) == [1, 2, 6, 7]
# if diff <= 0 or diff > 0.1 and larger 0 * tdiff -> 8
assert breaks.size == 1
assert breaks[0] == 4
# if diff > 2 and larger 2 * tdiff -> 0
breaks = _detect_breaks(stream, threshold_seconds=2, threshold_samples=2)
assert breaks.size == 0
# if diff > 0.1 and larger 0 * tdiff -> 6
breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0)
assert breaks.size == len(timestamps) - 1
assert breaks.size == len(timestamps) - 2
assert list(breaks) == [1, 2, 3, 4, 6, 7]


def test_detect_breaks_strict_non_monotonic():
timestamps = [-4, -5, -3, -2, -1, 0, 0, 1, 2]
stream = MockStreamData(timestamps, 1)
# if diff <= 0 or diff > 1 and larger 1 * tdiff -> 3
def test_detect_breaks_non_monotonic_str():
timestamps = [-4, -5, -3, -2, 0, 0, 1, 2]
stream = MockStreamData(timestamps, 1, "string")
stream = _ensure_sorted(1, stream)
# Timeseries data should now also be sorted.
assert np.all(stream.time_series == [str(x) for x in sorted(timestamps)])
# if diff > 1 and larger 1 * tdiff -> 1
breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1)
assert list(breaks) == [1, 2, 6]
# if diff <= 0 or diff > 2 and larger 2 * tdiff -> 2
assert breaks.size == 1
assert breaks[0] == 4
# if diff > 2 and larger 2 * tdiff -> 0
breaks = _detect_breaks(stream, threshold_seconds=2, threshold_samples=2)
assert list(breaks) == [1, 6]
# if diff <= 0 or diff > 0.1 and larger 0 * tdiff -> 8
assert breaks.size == 0
# if diff > 0.1 and larger 0 * tdiff -> 6
breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0)
assert breaks.size == len(timestamps) - 1
assert breaks.size == len(timestamps) - 2
assert list(breaks) == [1, 2, 3, 4, 6, 7]

0 comments on commit 8e5a76d

Please sign in to comment.