Skip to content

Commit

Permalink
Refactor _jitter_removal and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
agricolab committed Dec 8, 2024
1 parent 4b025c3 commit 0e41c5e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/pyxdf/pyxdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,19 +625,26 @@ def _clock_sync(
return streams


def _jitter_removal(streams, threshold_seconds=1, threshold_samples=500):
def _detect_breaks(stream, threshold_seconds=1.0, threshold_samples=500):
"""Detect breaks in the time_stamps of a stream."""
# Identify breaks in the time_stamps
diffs = np.diff(stream.time_stamps)
b_breaks = np.abs(diffs) > np.max(
(threshold_seconds, threshold_samples * stream.tdiff)
)
# find indices (+ 1 to compensate for lost sample in np.diff)
break_inds = np.where(b_breaks)[0] + 1
return break_inds


def _jitter_removal(streams, threshold_seconds=1.0, threshold_samples=500):
for stream_id, stream in streams.items():
stream.effective_srate = 0 # will be recalculated if possible
nsamples = len(stream.time_stamps)
stream.segments = []
if nsamples > 0 and stream.srate > 0:
# Identify breaks in the time_stamps
diffs = np.diff(stream.time_stamps)
b_breaks = np.abs(diffs) > np.max(
(threshold_seconds, threshold_samples * stream.tdiff)
)
# find indices (+ 1 to compensate for lost sample in np.diff)
break_inds = np.where(b_breaks)[0] + 1
# find break indices
break_inds = _detect_breaks(stream, threshold_seconds, threshold_samples)

# Get indices delimiting segments without breaks
# 0th sample is a segment start and last sample is a segment stop
Expand Down
51 changes: 51 additions & 0 deletions test/test_jitter_removal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
from pyxdf.pyxdf import _detect_breaks


class MockStreamData:
def __init__(self, time_stamps, tdiff):
self.time_stamps = time_stamps
self.tdiff = tdiff


def test_detect_no_breaks():
timestamps = list(range(-5, 5))
stream = MockStreamData(timestamps, 1)
# if diff > 2 and larger 500 * tdiff -> 0
breaks = _detect_breaks(stream, threshold_seconds=2, threshold_samples=500)
assert breaks.size == 0
# if diff > 0.1 and larger 1 * tdiff -> 0
breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=1)
assert breaks.size == 0


def test_detect_breaks_reverse():
timestamps = list(reversed(range(-5, 5)))
stream = MockStreamData(timestamps, 1)
breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0)
assert breaks.size == len(timestamps) - 1


def test_detect_breaks_gap_in_negative():
timestamps = [-4, 1, 2, 3, 4]
stream = MockStreamData(timestamps, 1)
breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1)
assert breaks.size == 1
assert breaks[0] == 1
timestamps = [-4, -2, -1, 0, 1, 2, 3, 4]
stream = MockStreamData(timestamps, 1)
breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1)
assert breaks.size == 1
assert breaks[0] == 1


def test_detect_breaks_gap_in_positive():
timestamps = [1, 3, 4, 5, 6]
stream = MockStreamData(timestamps, 1)
# if diff > 1 and larger 1 * tdiff -> 1 -> 1
breaks = _detect_breaks(stream, threshold_seconds=1, threshold_samples=1)
assert breaks.size == 1
assert breaks[0] == 1
# if diff > 0.1 and larger 0 * tdiff ->
breaks = _detect_breaks(stream, threshold_seconds=0.1, threshold_samples=0)
assert breaks.size == len(timestamps) - 1

0 comments on commit 0e41c5e

Please sign in to comment.