Skip to content

Commit

Permalink
Spike sorting rerun (#755)
Browse files Browse the repository at this point in the history
* add wf extraction to SpikeSorting task

* ibl-neuropixel version

* add the sync dataset in the spike sorting loader probe info

* add keys

* add to metrics dict

* bitwise qc label

* reverse the bitwise qc labels so 0 is always passin

* adjust wf extraction call

* rename bitwise qc key

* fix test metrics

* check for pyks import

* add the cell qc computation at the spike sorting stage

* PopeyeDataHandler

* add wf extraction to SpikeSorting task

* ibl-neuropixel version

* add the sync dataset in the spike sorting loader probe info

* add keys

* add to metrics dict

* bitwise qc label

* reverse the bitwise qc labels so 0 is always passin

* adjust wf extraction call

* rename bitwise qc key

* fix test metrics

* check for pyks import

* add the cell qc computation at the spike sorting stage

* PopeyeDataHandler

* update spike sorting plots

* pass when symlink exists

* revert patcher path changes

* patcher patch

* sdsc spikesorting registration

* bugfix: SpikeSortingLoader.raw_electrophysiology regex to match cbin files on SDSC

* random tempdir for pykilosort

* fix ss reg task

* final fix RegisterSpikeSorting

* add low memory option to SS task

* Popeye patcher: allows overriding  SDSC_PATCH_PATH with env variable

* popeye data handler Path bugfix

* revision in sdsc registration

* update revision in constructor

* SDSC DataHandler get patch path from env

* Revert "SDSC DataHandler get patch path from env"

This reverts commit eeea95c.

* pin slidingRP requirement - we can unpin when merging with newer ibllib

* slidingRP

* fix slidingRP bug

* spike sorting if pykilosort is available, do not run subprocess

* flake

* revert regex for pyks find version

* task data handler allows empty input dataset list

* changes to spike sorting loader

* spike sorting loader gets a good_units parameters

---------

Co-authored-by: owinter <[email protected]>
Co-authored-by: Gaelle <[email protected]>
Co-authored-by: chris-langfield <[email protected]>
  • Loading branch information
4 people authored Jun 10, 2024
1 parent ac3f36e commit 0cfb83a
Show file tree
Hide file tree
Showing 10 changed files with 397 additions and 205 deletions.
120 changes: 100 additions & 20 deletions brainbox/io/one.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from dataclasses import dataclass, field
import gc
import logging
import re
import os
from pathlib import Path


import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
Expand All @@ -19,13 +19,14 @@
from neuropixel import TIP_SIZE_UM, trace_header
import spikeglx

import ibldsp.voltage
from iblutil.util import Bunch
from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
from iblatlas.atlas import AllenAtlas, BrainRegions
from iblatlas import atlas
from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
from ibllib.pipes import histology
from ibllib.pipes.ephys_alignment import EphysAlignment
from ibllib.plots import vertical_lines
from ibllib.plots import vertical_lines, Density

import brainbox.plot
from brainbox.io.spikeglx import Streamer
Expand Down Expand Up @@ -916,16 +917,18 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_
if missing == 'raise':
raise e

def download_spike_sorting(self, **kwargs):
def download_spike_sorting(self, objects=None, **kwargs):
"""
Downloads spikes, clusters and channels
:param spike_sorter: (defaults to 'pykilosort')
:param dataset_types: list of extra dataset types
:param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels']
:return:
"""
for obj in ['spikes', 'clusters', 'channels']:
objects = ['spikes', 'clusters', 'channels'] if objects is None else objects
for obj in objects:
self.download_spike_sorting_object(obj=obj, **kwargs)
self.spike_sorting_path = self.files['spikes'][0].parent
self.spike_sorting_path = self.files['clusters'][0].parent

def download_raw_electrophysiology(self, band='ap'):
"""
Expand Down Expand Up @@ -963,7 +966,7 @@ def raw_electrophysiology(self, stream=True, band='ap', **kwargs):
return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs)
else:
raw_data_files = self.download_raw_electrophysiology(band=band)
cbin_file = next(filter(lambda f: f.name.endswith(f'.{band}.cbin'), raw_data_files), None)
cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None)
if cbin_file is not None:
return spikeglx.Reader(cbin_file)

Expand Down Expand Up @@ -999,7 +1002,7 @@ def load_channels(self, **kwargs):
self.histology = 'alf'
return channels

def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
"""
Loads spikes, clusters and channels
Expand All @@ -1013,20 +1016,44 @@ def load_spike_sorting(self, spike_sorter='pykilosort', **kwargs):
- traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data
:param spike_sorter: (defaults to 'pykilosort')
:param dataset_types: list of extra dataset types
:param revision: for example "2024-05-06", (defaults to None):
:param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one
:param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates']
:param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table
:param kwargs: additional arguments to be passed to one.api.One.load_object
:return:
"""
if len(self.collections) == 0:
return {}, {}, {}
self.files = {}
self.spike_sorter = spike_sorter
self.download_spike_sorting(spike_sorter=spike_sorter, **kwargs)
channels = self.load_channels(spike_sorter=spike_sorter, **kwargs)
self.revision = revision
objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None
self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs)
channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs)
clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards)
spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)

if good_units:
spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards)
else:
spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards)
if enforce_version:
self._assert_version_consistency()
return spikes, clusters, channels

def _assert_version_consistency(self):
"""
Makes sure the state of the spike sorting object matches the files downloaded
:return: None
"""
for k in ['spikes', 'clusters', 'channels', 'passingSpikes']:
for fn in self.files.get(k, []):
if self.spike_sorter:
assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \
f"You required strict version {self.spike_sorter}, {fn} does not match"
if self.revision:
assert fn.relative_to(self.session_path).parts[3] == f"#{self.revision}#", \
f"You required strict revision {self.revision}, {fn} does not match"

@staticmethod
def compute_metrics(spikes, clusters=None):
nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size
Expand Down Expand Up @@ -1079,6 +1106,8 @@ def _get_probe_info(self):
if self._sync is None:
timestamps = self.one.load_dataset(
self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}')
_ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks
self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}')
try:
ap_meta = spikeglx.read_meta_data(self.one.load_dataset(
self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))
Expand Down Expand Up @@ -1116,7 +1145,13 @@ def samples2times(self, values, direction='forward'):
def pid2ref(self):
return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}"

def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, **kwargs):
def _default_plot_title(self, spikes):
title = f"{self.pid2ref}, {self.pid} \n" \
f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
return title

def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None,
drift=None, title=None, **kwargs):
"""
:param spikes: spikes dictionary or Bunch
:param channels: channels dictionary or Bunch.
Expand All @@ -1138,9 +1173,9 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_
# set default raster plot parameters
kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5}
brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs)
title_str = f"{self.pid2ref}, {self.pid} \n" \
f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters"
axs[0, 0].title.set_text(title_str)
if title is None:
title = self._default_plot_title(spikes)
axs[0, 0].title.set_text(title)
for k, ts in time_series.items():
vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0])
if 'atlas_id' in channels:
Expand All @@ -1150,10 +1185,55 @@ def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_
axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1])
fig.tight_layout()

self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
if 'drift' in self.files:
drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
if drift is None:
self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore')
if 'drift' in self.files:
drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards)
if isinstance(drift, dict):
axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5)
axs[0, 0].set(ylim=[-15, 15])

if save_dir is not None:
png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
fig.savefig(png_file)
plt.close(fig)
gc.collect()
else:
return fig, axs

def plot_rawdata_snippet(self, sr, spikes, clusters, t0,
channels=None,
br: BrainRegions = None,
save_dir=None,
label='raster',
gain=-93,
title=None):

# compute the raw data offset and destripe, we take 400ms around t0
first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs))
raw = sr[first_sample:last_sample, :-sr.nsync].T
channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True
destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels)
# filter out the spikes according to good/bad clusters and to the time slice
spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample]))
ss = spikes['samples'][spike_sel]
sc = clusters['channels'][spikes['clusters'][spike_sel]]
sok = clusters['label'][spikes['clusters'][spike_sel]] == 1
if title is None:
title = self._default_plot_title(spikes)
# display the raw data snippet with spikes overlaid
fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col')
Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s')
axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5)
axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5)
axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035])
# adds the channel locations if available
if (channels is not None) and ('atlas_id' in channels):
br = br or BrainRegions()
plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
brain_regions=br, display=True, ax=axs[1], title=self.histology)
axs[1].get_yaxis().set_visible(False)
fig.tight_layout()

if save_dir is not None:
png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir)
Expand Down
50 changes: 33 additions & 17 deletions brainbox/metrics/single_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@
'missed_spikes_est': dict(spks_per_bin=10, sigma=4, min_num_bins=50),
'acceptable_contamination': 0.1,
'bin_size': 0.25,
'med_amp_thresh_uv': 50,
'med_amp_thresh_uv': 50, # units below this threshold are considered noise
'min_isi': 0.0001,
'presence_window': 10,
'refractory_period': 0.0015,
'RPslide_thresh': 0.1,
'RPmax_confidence': 90, # a unit needs to pass with at least this confidence percentage (0 - 100)
}


Expand Down Expand Up @@ -942,7 +943,11 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
'presence_ratio',
'presence_ratio_std',
'slidingRP_viol',
'spike_count'
'spike_count',
'slidingRP_viol_forced',
'max_confidence',
'min_contamination',
'n_spikes_below2'
]
if tbounds:
ispi = between_sorted(spike_times, tbounds)
Expand Down Expand Up @@ -982,6 +987,10 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
srp = metrics.slidingRP_all(spikeTimes=spike_times, spikeClusters=spike_clusters,
sampleRate=30000, binSizeCorr=1 / 30000)
r.slidingRP_viol[ir] = srp['value']
r.slidingRP_viol_forced[ir] = srp['value_forced']
r.max_confidence[ir] = srp['max_confidence']
r.min_contamination[ir] = srp['min_contamination']
r.n_spikes_below2 = srp['n_spikes_below2']

# loop over each cluster to compute the rest of the metrics
for ic in np.arange(nclust):
Expand All @@ -1000,29 +1009,36 @@ def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
r.missed_spikes_est[ic], _, _ = missed_spikes_est(amps, **params['missed_spikes_est'])
# wonder if there is a need to low-cut this
r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600

r.label = compute_labels(r)
r.label, r.bitwise_fail = compute_labels(r, return_bitwise=True)
return r


def compute_labels(r, params=METRICS_PARAMS, return_details=False):
def compute_labels(r, params=METRICS_PARAMS, return_bitwise=False):
"""
From a dataframe or a dictionary of unit metrics, compute a lablel
From a dataframe or a dictionary of unit metrics, compute a label
:param r: dictionary or pandas dataframe containing unit qcs
:param return_details: False (returns a full dictionary of metrics)
:param return_bitwise: True (returns a full dictionary of metrics)
:return: vector of proportion of qcs passed between 0 and 1, where 1 denotes an all pass
"""
# right now the score is a value between 0 and 1 denoting the proportion of passing qcs
# we could eventually do a bitwise qc
# right now the score is a value between 0 and 1 denoting the proportion of passing qcs,
# where 1 means passing and 0 means failing
labels = np.c_[
r.slidingRP_viol,
r['max_confidence'] >= params['RPmax_confidence'], # this is the least significant bit
r.noise_cutoff < params['noise_cutoff']['nc_threshold'],
r.amp_median > params['med_amp_thresh_uv'] / 1e6,
# add a new metric here on higher significant bits
]
if not return_details:
return np.mean(labels, axis=1)
column_names = ['slidingRP_viol', 'noise_cutoff', 'amp_median']
qcdict = {}
for c in np.arange(labels.shape[1]):
qcdict[column_names[c]] = labels[:, c]
return np.mean(labels, axis=1), qcdict
# The first column takes binary values 001 or 000 to represent fail or pass,
# the second, 010 or 000, the third, 100 or 000 etc.
# The bitwise or "sum" produces 111 if all metrics fail, or 000 if all metrics pass
# All other permutations are also captured, i.e. 110 == 000 || 010 || 100 means
# the second and third metrics failed and the first metric was a pass
score = np.mean(labels, axis=1)
if return_bitwise:
# note the cast to uint8 casts nan to 0
# a nan implies no metrics was computed which we mark as a failure here
n_criteria = labels.shape[1]
bitwise = np.bitwise_or.reduce(2 ** np.arange(n_criteria) * (~ labels.astype(bool)).astype(np.uint8), axis=1)
return score, bitwise.astype(np.uint8)
else:
return score
2 changes: 1 addition & 1 deletion brainbox/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def driftmap(ts, feat, ax=None, plot_style='bincount',
else:
# compute raster map as a function of site depth
R, times, depths = bincount2D(
ts[iok], feat[iok], t_bin, d_bin, weights=weights)
ts[iok], feat[iok], t_bin, d_bin, weights=weights[iok] if weights is not None else None)
# plot raster map
ax.imshow(R, aspect='auto', cmap='binary', vmin=0, vmax=vmax or np.std(R) * 4,
extent=np.r_[times[[0, -1]], depths[[0, -1]]], origin='lower', **kwargs)
Expand Down
4 changes: 3 additions & 1 deletion brainbox/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def _assertions(dfm, idf, target_cid):
assert np.allclose(dfm['drift'][idf], np.array(cid) * 100 * 4 * 3.6, rtol=1.1)
assert np.allclose(dfm['firing_rate'][idf], frs, rtol=1.1)
assert np.allclose(dfm['cluster_id'], target_cid)

# test expected bitwise qc values:
expected_labels = 1 - np.sum(np.unpackbits(dfm['bitwise_fail']).reshape(-1, 8), axis=1) / 3
assert np.allclose(dfm['label'], expected_labels)
# check with missing clusters
dfm = quick_unit_metrics(c, t, a, d, cluster_ids=np.arange(5), tbounds=[100, 900])
idf, _ = ismember(np.arange(5), cid)
Expand Down
31 changes: 28 additions & 3 deletions ibllib/oneibl/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def getData(self, one=None):
for file in self.signature['input_files']:
dfs.append(filter_datasets(session_datasets, filename=file[0], collection=file[1],
wildcards=True, assert_unique=False))
if len(dfs) == 0:
return pd.DataFrame()
df = pd.concat(dfs)

# Some cases the eid is stored in the index. If so we drop this level
Expand Down Expand Up @@ -413,23 +415,29 @@ class SDSCDataHandler(DataHandler):
:param signature: input and output file signatures
:param one: ONE instance
"""

def __init__(self, task, session_path, signatures, one=None):
super().__init__(session_path, signatures, one=one)
self.task = task
self.SDSC_PATCH_PATH = SDSC_PATCH_PATH
self.SDSC_ROOT_PATH = SDSC_ROOT_PATH

def setUp(self):
"""Function to create symlinks to necessary data to run tasks."""
df = super().getData()

SDSC_TMP = Path(SDSC_PATCH_PATH.joinpath(self.task.__class__.__name__))
SDSC_TMP = Path(self.SDSC_PATCH_PATH.joinpath(self.task.__class__.__name__))
for i, d in df.iterrows():
file_path = Path(d['session_path']).joinpath(d['rel_path'])
uuid = i
file_uuid = add_uuid_string(file_path, uuid)
file_link = SDSC_TMP.joinpath(file_path)
file_link.parent.mkdir(exist_ok=True, parents=True)
file_link.symlink_to(
Path(SDSC_ROOT_PATH.joinpath(file_uuid)))
try:
file_link.symlink_to(
Path(self.SDSC_ROOT_PATH.joinpath(file_uuid)))
except FileExistsError:
pass

self.task.session_path = SDSC_TMP.joinpath(d['session_path'])

Expand All @@ -448,3 +456,20 @@ def cleanUp(self):
"""Function to clean up symlinks created to run task."""
assert SDSC_PATCH_PATH.parts[0:4] == self.task.session_path.parts[0:4]
shutil.rmtree(self.task.session_path)


class PopeyeDataHandler(SDSCDataHandler):

def __init__(self, task, session_path, signatures, one=None):
super().__init__(task, session_path, signatures, one=one)
self.SDSC_PATCH_PATH = Path(os.getenv('SDSC_PATCH_PATH', "/mnt/sdceph/users/ibl/data/quarantine/tasks/"))
self.SDSC_ROOT_PATH = Path("/mnt/sdceph/users/ibl/data")

def uploadData(self, outputs, version, **kwargs):
raise NotImplementedError(
"Cannot register data from Popeye. Login as Datauser and use the RegisterSpikeSortingSDSC task."
)

def cleanUp(self):
"""Symlinks are preserved until registration."""
pass
Loading

0 comments on commit 0cfb83a

Please sign in to comment.