Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
k1o0 committed Aug 13, 2024
1 parent 0016522 commit 4967b4d
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 169 deletions.
117 changes: 0 additions & 117 deletions brainbox/behavior/wheel.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
"""
Set of functions to handle wheel data.
"""
import logging
import warnings
import traceback

import numpy as np
from numpy import pi
from iblutil.numerical import between_sorted
Expand Down Expand Up @@ -68,42 +64,6 @@ def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None
return yinterp, t


def velocity(re_ts, re_pos):
"""
(DEPRECATED) Compute wheel velocity from non-uniformly sampled wheel data. Returns the velocity
at the same samples locations as the position through interpolation.
Parameters
----------
re_ts : array_like
Array of timestamps
re_pos: array_like
Array of unwrapped wheel positions
Returns
-------
np.ndarray
numpy array of velocities
"""
for line in traceback.format_stack():
print(line.strip())

msg = 'brainbox.behavior.wheel.velocity will soon be removed. Use velocity_filtered instead.'
warnings.warn(msg, FutureWarning)
logging.getLogger(__name__).warning(msg)

dp = np.diff(re_pos)
dt = np.diff(re_ts)
# Compute raw velocity
vel = dp / dt
# Compute velocity time scale
tv = re_ts[:-1] + dt / 2
# interpolate over original time scale
if tv.size > 1:
ifcn = interpolate.interp1d(tv, vel, fill_value="extrapolate")
return ifcn(re_ts)


def velocity_filtered(pos, fs, corner_frequency=20, order=8):
"""
Compute wheel velocity from uniformly sampled wheel data.
Expand All @@ -130,83 +90,6 @@ def velocity_filtered(pos, fs, corner_frequency=20, order=8):
return vel, acc


def velocity_smoothed(pos, freq, smooth_size=0.03):
"""
(DEPRECATED) Compute wheel velocity from uniformly sampled wheel data.
Parameters
----------
pos : array_like
Array of wheel positions
smooth_size : float
Size of Gaussian smoothing window in seconds
freq : float
Sampling frequency of the data
Returns
-------
vel : np.ndarray
Array of velocity values
acc : np.ndarray
Array of acceleration values
"""
for line in traceback.format_stack():
print(line.strip())

msg = 'brainbox.behavior.wheel.velocity_smoothed will be removed. Use velocity_filtered instead.'
warnings.warn(msg, FutureWarning)
logging.getLogger(__name__).warning(msg)

# Define our smoothing window with an area of 1 so the units won't be changed
std_samps = np.round(smooth_size * freq) # Standard deviation relative to sampling frequency
N = std_samps * 6 # Number of points in the Gaussian covering +/-3 standard deviations
gauss_std = (N - 1) / 6
win = scipy.signal.windows.gaussian(N, gauss_std)
win = win / win.sum() # Normalize amplitude

# Convolve and multiply by sampling frequency to restore original units
vel = np.insert(scipy.signal.convolve(np.diff(pos), win, mode='same'), 0, 0) * freq
acc = np.insert(scipy.signal.convolve(np.diff(vel), win, mode='same'), 0, 0) * freq

return vel, acc


def last_movement_onset(t, vel, event_time):
"""
(DEPRECATED) Find the time at which movement started, given an event timestamp that occurred during the
movement.
Movement start is defined as the first sample after the velocity has been zero for at least 50ms.
Wheel inputs should be evenly sampled.
:param t: numpy array of wheel timestamps in seconds
:param vel: numpy array of wheel velocities
:param event_time: timestamp anywhere during movement of interest, e.g. peak velocity
:return: timestamp of movement onset
"""
for line in traceback.format_stack():
print(line.strip())

msg = 'brainbox.behavior.wheel.last_movement_onset has been deprecated. Use get_movement_onset instead.'
warnings.warn(msg, FutureWarning)
logging.getLogger(__name__).warning(msg)

# Look back from timestamp
threshold = 50e-3
mask = t < event_time
times = t[mask]
vel = vel[mask]
t = None # Initialize
for i, t in enumerate(times[::-1]):
i = times.size - i
idx = np.min(np.where((t - times) < threshold))
if np.max(np.abs(vel[idx:i])) < 0.5:
break

# Return timestamp
return t


def get_movement_onset(intervals, event_times):
"""
Find the time at which movement started, given an event timestamp that occurred during the
Expand Down
6 changes: 0 additions & 6 deletions brainbox/tests/test_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,6 @@ def test_get_movement_onset(self):
with self.assertRaises(ValueError):
wheel.get_movement_onset(intervals, np.random.permutation(self.trials['feedback_times']))

def test_velocity_deprecation(self):
"""Ensure brainbox.behavior.wheel.velocity is removed."""
from datetime import datetime
self.assertTrue(datetime.today() < datetime(2024, 8, 1),
'remove brainbox.behavior.wheel.velocity, velocity_smoothed and last_movement_onset')


class TestTraining(unittest.TestCase):
def setUp(self):
Expand Down
36 changes: 0 additions & 36 deletions ibllib/oneibl/stream.py

This file was deleted.

5 changes: 4 additions & 1 deletion ibllib/pipes/dynamic_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def _get_trials_tasks(session_path, acquisition_description=None, sync_tasks=Non
_logger.debug('%s (protocol #%i, task #%i) = %s.%s',
protocol, i, j, task.__module__, task.__name__)
# Rename the class to something more informative
task_name = f'Trials_{task.__name__}_{i:02}'
task_name = f'{task.__name__}_{i:02}'
if not (task.__name__.startswith('TrainingStatus') or task.__name__.endswith('RegisterRaw')):
task_name = f'Trials_{task_name}'
# For now we assume that the second task in the list is always the trials extractor, which is dependent
# on the sync task and sync arguments
if j == 1:
Expand Down Expand Up @@ -413,6 +415,7 @@ def make_pipeline(session_path, **pkwargs):

# Syncing tasks
(sync, sync_args), = acquisition_description['sync'].items()
sync_args = sync_args.copy() # ensure acquisition_description unchanged
sync_label = _sync_label(sync, **sync_args) # get the format of the DAQ data. This informs the extractor task
sync_args['sync_collection'] = sync_args.pop('collection') # rename the key so it matches task run arguments
sync_args['sync_ext'] = sync_args.pop('extension', None)
Expand Down
2 changes: 2 additions & 0 deletions ibllib/pipes/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def _input_files_to_register(self, assert_all_exist=False):
-------
list of pathlib.Path
A list of input files to register.
# TODO This method currently does not support wildcards
"""
try:
input_files = self.input_files
Expand Down
9 changes: 0 additions & 9 deletions ibllib/tests/test_oneibl.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,14 +590,5 @@ def test_server_upload_data(self, register_dataset_mock):
self.assertDictEqual(expected, handler.processed)


class TestStream(unittest.TestCase):
"""Test for oneibl.stream module."""

def test_deprecation(self):
"""Ensure oneibl.stream module removed."""
from datetime import datetime
self.assertTrue(datetime.today() < datetime(2024, 8, 1), 'remove oneibl.stream module')


if __name__ == '__main__':
unittest.main()

0 comments on commit 4967b4d

Please sign in to comment.