Skip to content

Commit

Permalink
Use TimeSeriesReference
Browse files Browse the repository at this point in the history
Use TimeSeriesReference in tests

Add test for backwards compatibility of TimeIntervals

Update ObjectMapper (not yet working)

Fix test

Demonstrate two approaches to resolving mapping issue

Fix
  • Loading branch information
rly committed Aug 9, 2021
1 parent 8b314cf commit bef1be8
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 21 deletions.
7 changes: 4 additions & 3 deletions src/pynwb/epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from hdmf.data_utils import DataIO

from . import register_class, CORE_NAMESPACE
from .base import TimeSeries
from .base import TimeSeries, TimeSeriesReferenceVectorData, TimeSeriesReference
from hdmf.common import DynamicTable


Expand All @@ -20,7 +20,8 @@ class TimeIntervals(DynamicTable):
{'name': 'start_time', 'description': 'Start time of epoch, in seconds', 'required': True},
{'name': 'stop_time', 'description': 'Stop time of epoch, in seconds', 'required': True},
{'name': 'tags', 'description': 'user-defined tags', 'index': True},
{'name': 'timeseries', 'description': 'index into a TimeSeries object', 'index': True}
{'name': 'timeseries', 'description': 'index into a TimeSeries object', 'index': True,
'class': TimeSeriesReferenceVectorData}
)

@docval({'name': 'name', 'type': str, 'doc': 'name of this TimeIntervals'}, # required
Expand Down Expand Up @@ -51,7 +52,7 @@ def add_interval(self, **kwargs):
tmp = list()
for ts in timeseries:
idx_start, count = self.__calculate_idx_count(start_time, stop_time, ts)
tmp.append((idx_start, count, ts))
tmp.append(TimeSeriesReference(idx_start=idx_start, count=count, timeseries=ts))
timeseries = tmp
rkwargs['timeseries'] = timeseries
return super(TimeIntervals, self).add_row(**rkwargs)
Expand Down
83 changes: 77 additions & 6 deletions src/pynwb/io/epoch.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,80 @@
from .. import register_map
# from .. import register_map
#
# from pynwb.base import TimeSeriesReference, TimeSeriesReferenceVectorData
# from pynwb.epoch import TimeIntervals
# from hdmf.build import GroupBuilder, DatasetBuilder, BuildManager
# from hdmf.build.manager import Proxy
# from hdmf.common import VectorData, VectorIndex
# from hdmf.common.io.table import DynamicTableMap
# from hdmf.container import AbstractContainer
# from hdmf.utils import call_docval_func, docval
#

from pynwb.epoch import TimeIntervals
from hdmf.common.io.table import DynamicTableMap
# @register_map(TimeIntervals)
# class TimeIntervalsMap(DynamicTableMap):

# TODO: Both approaches below (overriding construct(), or defining a 'columns' constructor argument via
# columns_carg) work independently -- comment one out and use the other. However, the first approach is
# hacky because it overrides construct(), which is not meant to be public, and the second approach does
# not set the object ID of the new TimeSeriesReferenceVectorData to match that of the VectorData on disk.
# TODO: Update HDMF construct() to allow alteration of the builder, e.g., via a prebuild hook.

@register_map(TimeIntervals)
class TimeIntervalsMap(DynamicTableMap):
    """ObjectMapper for TimeIntervals.

    Currently a pass-through of DynamicTableMap. Custom handling for reading a
    legacy (non-TimeSeriesReferenceVectorData) 'timeseries' column is planned
    but not yet enabled -- see the commented-out approaches and TODO notes in
    this module.
    """
    pass
# # override construct() to change the neurodata_type of the 'timeseries' column from VectorData to
# # TimeSeriesReferenceVectorData - TODO make this not hacky
# @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder),
# 'doc': 'the builder to construct the AbstractContainer from'},
# {'name': 'manager', 'type': BuildManager, 'doc': 'the BuildManager for this build'},
# {'name': 'parent', 'type': (Proxy, AbstractContainer),
# 'doc': 'the parent AbstractContainer/Proxy for the AbstractContainer being built', 'default': None})
# def construct(self, **kwargs):
# builder = kwargs['builder']
# timeseries_builder = builder.get('timeseries')
# if timeseries_builder.attributes['neurodata_type'] != 'TimeSeriesReferenceVectorData':
# # override builder attributes
# timeseries_builder.attributes['neurodata_type'] = 'TimeSeriesReferenceVectorData'
# timeseries_builder.attributes['namespace'] = 'core'
# obj = call_docval_func(super().construct, kwargs)
# return obj

# @DynamicTableMap.constructor_arg('columns')
# def columns_carg(self, builder, manager):
# # handle case when a TimeIntervals is read with a non-TimeSeriesReferenceVectorData "timeseries" column
# # these data are read completely here (not lazily)
# timeseries_builder = builder.get('timeseries')
# if timeseries_builder.attributes['neurodata_type'] != 'TimeSeriesReferenceVectorData':
# # override builder attributes
# timeseries_builder.attributes['neurodata_type'] = 'TimeSeriesReferenceVectorData'
# timeseries_builder.attributes['namespace'] = 'core'
# # construct new columns list
# columns = list()
# for dset_builder in builder.datasets.values():
# dset_obj = manager.construct(dset_builder)
# # go through only the column datasets and replace the 'timeseries_index' and 'timeseries' columns
# # without changing the order
# if isinstance(dset_obj, VectorData):
# if dset_obj.name == 'timeseries':
# pass
# elif dset_obj.name == 'timeseries_index':
# # TODO do we need to update children?
# new_ts_column = TimeSeriesReferenceVectorData(
# name='timeseries',
# description='index into a TimeSeries object',
# ) # TODO match object ID??
# new_ts_index_column = VectorIndex(
# name='timeseries_index',
# data=list(),
# target=new_ts_column
# ) # TODO match object ID??
# for row in dset_obj:
# new_row = list()
# for tup in row:
# new_row.append(TimeSeriesReference(*tup))
# new_ts_index_column.add_vector(new_row)
#
# columns.append(new_ts_index_column)
# columns.append(new_ts_column)
# else:
# columns.append(dset_obj)
#
# return columns
#
# return None # do not override
29 changes: 29 additions & 0 deletions src/pynwb/testing/make_test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,34 @@ def make_nwbfile_imageseries_no_unit():
_write(test_name, nwbfile)


def make_nwbfile_timeintervals_old_tuple():
    """Write a test NWB file whose TimeIntervals 'timeseries' column references a
    TimeSeries using the old non-neurodata-type (plain tuple) representation
    rather than TimeSeriesReferenceVectorData.
    """
    nwbfile = NWBFile(
        session_description='ADDME',
        identifier='ADDME',
        session_start_time=datetime.now().astimezone(),
    )

    # A small acquisition TimeSeries for the interval to point at.
    acquired = TimeSeries(
        name='test_timeseries',
        data=[0, 1, 2, 3, 4],
        unit='unit',
        rate=1.,
    )
    nwbfile.add_acquisition(acquired)

    intervals = nwbfile.create_time_intervals(
        name='test_intervals',
        description='test table',
    )
    intervals.add_interval(
        start_time=0.,
        stop_time=2.,
        tags=[],
        timeseries=acquired,
    )

    _write('timeintervals_non_ndtype_tsref', nwbfile)


if __name__ == '__main__':

if __version__ == '1.1.2':
Expand All @@ -125,3 +153,4 @@ def make_nwbfile_imageseries_no_unit():
make_nwbfile_timeseries_no_unit()
make_nwbfile_imageseries_no_data()
make_nwbfile_imageseries_no_unit()
make_nwbfile_timeintervals_old_tuple()
Binary file not shown.
11 changes: 11 additions & 0 deletions tests/back_compat/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings

from pynwb import NWBHDF5IO, validate, TimeSeries
from pynwb.base import TimeSeriesReference
from pynwb.image import ImageSeries
from pynwb.testing import TestCase

Expand Down Expand Up @@ -67,3 +68,13 @@ def test_read_imageseries_no_unit(self):
with NWBHDF5IO(str(f), 'r') as io:
read_nwbfile = io.read()
self.assertEqual(read_nwbfile.acquisition['test_imageseries'].unit, ImageSeries.DEFAULT_UNIT)

def test_read_timeintervals_non_ndtype_tsref(self):
    """Verify that a TimeIntervals 'timeseries' column written as plain tuples
    (without TimeSeriesReference) is read back as TimeSeriesReference objects.
    """
    path = Path(__file__).parent / '1.5.1_timeintervals_non_ndtype_tsref.nwb'
    with NWBHDF5IO(str(path), 'r') as io:
        nwbfile = io.read()
        timeseries = nwbfile.acquisition['test_timeseries']
        expected = [TimeSeriesReference(0, 2, timeseries)]
        actual = nwbfile.intervals['test_intervals'][0, 'timeseries']
        self.assertEqual(actual, expected)
17 changes: 9 additions & 8 deletions tests/integration/hdf5/test_nwbfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from hdmf.common import DynamicTable

from pynwb import NWBFile, TimeSeries, NWBHDF5IO, get_manager
from pynwb.base import TimeSeriesReference
from pynwb.file import Subject
from pynwb.epoch import TimeIntervals
from pynwb.ecephys import ElectricalSeries
Expand Down Expand Up @@ -275,10 +276,10 @@ def addContainer(self, nwbfile):
'bar': ['fish', 'fowl', 'dog', 'cat'],
'start_time': [0.2, 0.25, 0.30, 0.35],
'stop_time': [0.25, 0.30, 0.40, 0.45],
'timeseries': [[(2, 1, tsa)],
[(3, 1, tsa)],
[(3, 1, tsa)],
[(4, 1, tsa)]],
'timeseries': [[TimeSeriesReference(2, 1, tsa)],
[TimeSeriesReference(3, 1, tsa)],
[TimeSeriesReference(3, 1, tsa)],
[TimeSeriesReference(4, 1, tsa)]],
'tags': [[''], [''], ['fizz', 'buzz'], ['qaz']]
}),
'epochs',
Expand All @@ -305,10 +306,10 @@ def test_df_comparison(self):
'bar': ['fish', 'fowl', 'dog', 'cat'],
'start_time': [0.2, 0.25, 0.30, 0.35],
'stop_time': [0.25, 0.30, 0.40, 0.45],
'timeseries': [[(2, 1, tsa)],
[(3, 1, tsa)],
[(3, 1, tsa)],
[(4, 1, tsa)]],
'timeseries': [[TimeSeriesReference(2, 1, tsa)],
[TimeSeriesReference(3, 1, tsa)],
[TimeSeriesReference(3, 1, tsa)],
[TimeSeriesReference(4, 1, tsa)]],
'tags': [[''], [''], ['fizz', 'buzz'], ['qaz']]
},
index=pd.Index(np.arange(4), name='id')
Expand Down
12 changes: 8 additions & 4 deletions tests/unit/test_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from datetime import datetime
from dateutil import tz

from pynwb.epoch import TimeIntervals
from pynwb import TimeSeries, NWBFile
from pynwb.base import TimeSeriesReference
from pynwb.epoch import TimeIntervals
from pynwb.testing import TestCase


Expand Down Expand Up @@ -35,8 +36,11 @@ def get_dataframe(self):
'foo': [1, 2, 3, 4],
'bar': ['fish', 'fowl', 'dog', 'cat'],
'start_time': [0.2, 0.25, 0.30, 0.35],
'stop_time': [0.25, 0.30, 0.40, 0.45],
'timeseries': [[tsa], [tsb], [], [tsb, tsa]],
'stop_time': [0.25, 0.30, 0.40, 0.1],
'timeseries': [[TimeSeriesReference(2, 1, tsa)],
[TimeSeriesReference(1, 0, tsb)],
[],
[TimeSeriesReference(1, 2, tsb), TimeSeriesReference(4, 6, tsa)]],
'keys': ['q', 'w', 'e', 'r'],
'tags': [[], [], ['fizz', 'buzz'], ['qaz']]
})
Expand All @@ -46,7 +50,7 @@ def test_dataframe_roundtrip(self):
epochs = TimeIntervals.from_dataframe(df, name='test epochs')
obtained = epochs.to_dataframe()

self.assertIs(obtained.loc[3, 'timeseries'][1], df.loc[3, 'timeseries'][1])
self.assertEqual(obtained.loc[3, 'timeseries'][1], df.loc[3, 'timeseries'][1])
self.assertEqual(obtained.loc[2, 'foo'], df.loc[2, 'foo'])

def test_dataframe_roundtrip_drop_ts(self):
Expand Down

0 comments on commit bef1be8

Please sign in to comment.