Skip to content

Commit

Permalink
add multi index and time bounds to get_unit_spikes (#1001)
Browse files Browse the repository at this point in the history
* add in_interval arg for get_unit_spike_times and add test
* add multi index for get_unit_spikes and tests
  • Loading branch information
bendichter authored Dec 12, 2019
1 parent ccbf099 commit 2f77d49
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/pynwb/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
except ImportError:
from collections import Iterable # Python 2.7
import warnings
from bisect import bisect_left, bisect_right

from hdmf.utils import docval, getargs, popargs, call_docval_func, get_docval

Expand Down Expand Up @@ -191,11 +192,26 @@ def add_unit(self, **kwargs):
else:
elec_col.table = self.__electrode_table

@docval({'name': 'index', 'type': int,
'doc': 'the index of the unit in unit_ids to retrieve spike times for'})
@docval({'name': 'index', 'type': (int, list, tuple, np.ndarray),
'doc': 'the index of the unit in unit_ids to retrieve spike times for'},
{'name': 'in_interval', 'type': (tuple, list), 'doc': 'only return values within this interval',
'default': None, 'shape': (2,)})
def get_unit_spike_times(self, **kwargs):
index = getargs('index', kwargs)
return np.asarray(self['spike_times'][index])
index, in_interval = getargs('index', 'in_interval', kwargs)
if type(index) in (list, tuple):
return [self.get_unit_spike_times(i, in_interval=in_interval) for i in index]
if in_interval is None:
return np.asarray(self['spike_times'][index])
else:
st = self['spike_times']
unit_start = 0 if index == 0 else st.data[index - 1]
unit_stop = st.data[index]
start_time, stop_time = in_interval

ind_start = bisect_left(st.target, start_time, unit_start, unit_stop)
ind_stop = bisect_right(st.target, stop_time, ind_start, unit_stop)

return np.asarray(st.target[ind_start:ind_stop])

@docval({'name': 'index', 'type': int,
'doc': 'the index of the unit in unit_ids to retrieve observation intervals for'})
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@ def test_get_spike_times(self):
self.assertTrue(all(ut.get_unit_spike_times(0) == np.array([0, 1, 2])))
self.assertTrue(all(ut.get_unit_spike_times(1) == np.array([3, 4, 5])))

@staticmethod
def test_get_spike_times_interval():
ut = Units()
ut.add_unit(spike_times=[0, 1, 2])
ut.add_unit(spike_times=[3, 4, 5])
np.testing.assert_array_equal(ut.get_unit_spike_times(0, (.5, 3)), [1, 2])
np.testing.assert_array_equal(ut.get_unit_spike_times(0, (-.5, 1.1)), [0, 1])

def test_get_spike_times_multi(self):
ut = Units()
ut.add_unit(spike_times=[0, 1, 2])
ut.add_unit(spike_times=[3, 4, 5])
np.testing.assert_array_equal(ut.get_unit_spike_times((0, 1)), [[0, 1, 2], [3, 4, 5]])

def test_get_spike_times_multi_interval(self):
ut = Units()
ut.add_unit(spike_times=[0, 1, 2])
ut.add_unit(spike_times=[3, 4, 5])
np.testing.assert_array_equal(ut.get_unit_spike_times((0, 1), (1.5, 3.5)), [[2], [3]])

def test_times(self):
ut = Units()
ut.add_unit(spike_times=[0, 1, 2])
Expand Down

0 comments on commit 2f77d49

Please sign in to comment.