Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpikeSortingLoader.raw_waveforms() #823

Merged
merged 4 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 0 additions & 117 deletions brainbox/behavior/wheel.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
"""
Set of functions to handle wheel data.
"""
import logging
import warnings
import traceback

import numpy as np
from numpy import pi
from iblutil.numerical import between_sorted
Expand Down Expand Up @@ -68,42 +64,6 @@ def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None
return yinterp, t


def velocity(re_ts, re_pos):
"""
(DEPRECATED) Compute wheel velocity from non-uniformly sampled wheel data. Returns the velocity
at the same samples locations as the position through interpolation.
Parameters
----------
re_ts : array_like
Array of timestamps
re_pos: array_like
Array of unwrapped wheel positions
Returns
-------
np.ndarray
numpy array of velocities
"""
for line in traceback.format_stack():
print(line.strip())

msg = 'brainbox.behavior.wheel.velocity will soon be removed. Use velocity_filtered instead.'
warnings.warn(msg, FutureWarning)
logging.getLogger(__name__).warning(msg)

dp = np.diff(re_pos)
dt = np.diff(re_ts)
# Compute raw velocity
vel = dp / dt
# Compute velocity time scale
tv = re_ts[:-1] + dt / 2
# interpolate over original time scale
if tv.size > 1:
ifcn = interpolate.interp1d(tv, vel, fill_value="extrapolate")
return ifcn(re_ts)


def velocity_filtered(pos, fs, corner_frequency=20, order=8):
"""
Compute wheel velocity from uniformly sampled wheel data.
Expand All @@ -130,83 +90,6 @@ def velocity_filtered(pos, fs, corner_frequency=20, order=8):
return vel, acc


def velocity_smoothed(pos, freq, smooth_size=0.03):
"""
(DEPRECATED) Compute wheel velocity from uniformly sampled wheel data.
Parameters
----------
pos : array_like
Array of wheel positions
smooth_size : float
Size of Gaussian smoothing window in seconds
freq : float
Sampling frequency of the data
Returns
-------
vel : np.ndarray
Array of velocity values
acc : np.ndarray
Array of acceleration values
"""
for line in traceback.format_stack():
print(line.strip())

msg = 'brainbox.behavior.wheel.velocity_smoothed will be removed. Use velocity_filtered instead.'
warnings.warn(msg, FutureWarning)
logging.getLogger(__name__).warning(msg)

# Define our smoothing window with an area of 1 so the units won't be changed
std_samps = np.round(smooth_size * freq) # Standard deviation relative to sampling frequency
N = std_samps * 6 # Number of points in the Gaussian covering +/-3 standard deviations
gauss_std = (N - 1) / 6
win = scipy.signal.windows.gaussian(N, gauss_std)
win = win / win.sum() # Normalize amplitude

# Convolve and multiply by sampling frequency to restore original units
vel = np.insert(scipy.signal.convolve(np.diff(pos), win, mode='same'), 0, 0) * freq
acc = np.insert(scipy.signal.convolve(np.diff(vel), win, mode='same'), 0, 0) * freq

return vel, acc


def last_movement_onset(t, vel, event_time):
"""
(DEPRECATED) Find the time at which movement started, given an event timestamp that occurred during the
movement.
Movement start is defined as the first sample after the velocity has been zero for at least 50ms.
Wheel inputs should be evenly sampled.
:param t: numpy array of wheel timestamps in seconds
:param vel: numpy array of wheel velocities
:param event_time: timestamp anywhere during movement of interest, e.g. peak velocity
:return: timestamp of movement onset
"""
for line in traceback.format_stack():
print(line.strip())

msg = 'brainbox.behavior.wheel.last_movement_onset has been deprecated. Use get_movement_onset instead.'
warnings.warn(msg, FutureWarning)
logging.getLogger(__name__).warning(msg)

# Look back from timestamp
threshold = 50e-3
mask = t < event_time
times = t[mask]
vel = vel[mask]
t = None # Initialize
for i, t in enumerate(times[::-1]):
i = times.size - i
idx = np.min(np.where((t - times) < threshold))
if np.max(np.abs(vel[idx:i])) < 0.5:
break

# Return timestamp
return t


def get_movement_onset(intervals, event_times):
"""
Find the time at which movement started, given an event timestamp that occurred during the
Expand Down
16 changes: 16 additions & 0 deletions brainbox/io/one.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import spikeglx

import ibldsp.voltage
from ibldsp.waveform_extraction import WaveformsLoader
from iblutil.util import Bunch
from iblatlas.atlas import AllenAtlas, BrainRegions
from iblatlas import atlas
Expand Down Expand Up @@ -975,6 +976,21 @@ def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
if cbin_file is not None:
return spikeglx.Reader(cbin_file)

def download_raw_waveforms(self, **kwargs):
"""
Downloads raw waveforms extracted from sorting to local disk.
"""
_logger.debug(f"loading waveforms from {self.collection}")
return self.one.load_object(
self.eid, "waveforms",
attribute=["traces", "templates", "table", "channels"],
collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs
)

def raw_waveforms(self, **kwargs):
wf_paths = self.download_raw_waveforms(**kwargs)
return WaveformsLoader(wf_paths[0].parent, wfs_dtype=np.float16)

def load_channels(self, **kwargs):
"""
Loads channels
Expand Down
6 changes: 0 additions & 6 deletions brainbox/tests/test_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,6 @@ def test_get_movement_onset(self):
with self.assertRaises(ValueError):
wheel.get_movement_onset(intervals, np.random.permutation(self.trials['feedback_times']))

def test_velocity_deprecation(self):
"""Ensure brainbox.behavior.wheel.velocity is removed."""
from datetime import datetime
self.assertTrue(datetime.today() < datetime(2024, 8, 1),
'remove brainbox.behavior.wheel.velocity, velocity_smoothed and last_movement_onset')


class TestTraining(unittest.TestCase):
def setUp(self):
Expand Down
36 changes: 0 additions & 36 deletions ibllib/oneibl/stream.py

This file was deleted.

9 changes: 0 additions & 9 deletions ibllib/tests/test_oneibl.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,14 +590,5 @@ def test_server_upload_data(self, register_dataset_mock):
self.assertDictEqual(expected, handler.processed)


class TestStream(unittest.TestCase):
"""Test for oneibl.stream module."""

def test_deprecation(self):
"""Ensure oneibl.stream module removed."""
from datetime import datetime
self.assertTrue(datetime.today() < datetime(2024, 8, 1), 'remove oneibl.stream module')


if __name__ == '__main__':
unittest.main()
Loading