diff --git a/CHANGELOG.md b/CHANGELOG.md index 98a771273..bd562619f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ - Added support for NWB 2.5.0. - Added support for updated ``IndexSeries`` type, new ``order_of_images`` field in ``Images``, and new neurodata_type ``ImageReferences``. @rly (#1483) +- Added support for HDMF 3.3.1. This is now the minimum version of HDMF supported. Importantly, HDMF 3.3 introduces + warnings when the constructor of a class mapped to an HDMF-common data type or an autogenerated data type class + is passed positional arguments instead of all keyword arguments. @rly (#1484) ### Documentation and tutorial enhancements: - Added tutorial on annotating data via ``TimeIntervals``. @oruebel (#1390) diff --git a/docs/gallery/general/extensions.py b/docs/gallery/general/extensions.py index 206dbed85..c7314e466 100644 --- a/docs/gallery/general/extensions.py +++ b/docs/gallery/general/extensions.py @@ -104,7 +104,7 @@ from pynwb import register_class, load_namespaces from pynwb.ecephys import ElectricalSeries -from hdmf.utils import docval, call_docval_func, getargs, get_docval +from hdmf.utils import docval, get_docval, popargs ns_path = "mylab.namespace.yaml" load_namespaces(ns_path) @@ -118,16 +118,16 @@ class TetrodeSeries(ElectricalSeries): @docval(*get_docval(ElectricalSeries.__init__) + ( {'name': 'trode_id', 'type': int, 'doc': 'the tetrode id'},)) def __init__(self, **kwargs): - call_docval_func(super(TetrodeSeries, self).__init__, kwargs) - self.trode_id = getargs('trode_id', kwargs) + trode_id = popargs('trode_id', kwargs) + super().__init__(**kwargs) + self.trode_id = trode_id #################### # .. note:: # -# See the API docs for more information about :py:func:`~hdmf.utils.docval` -# :py:func:`~hdmf.utils.call_docval_func`, :py:func:`~hdmf.utils.getargs` -# and :py:func:`~hdmf.utils.get_docval` +# See the API docs for more information about :py:func:`~hdmf.utils.docval`, +# :py:func:`~hdmf.utils.popargs`, and :py:func:`~hdmf.utils.get_docval` # # When extending :py:class:`~pynwb.core.NWBContainer` or :py:class:`~pynwb.core.NWBContainer` # subclasses, you should define the class field ``__nwbfields__``. 
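# Two short sketches to make this concrete (hedged: neither snippet is part of
# this diff, and the names and values below are hypothetical). First, because
# HDMF 3.3 warns when constructors of certain mapped classes receive
# positional arguments (see the CHANGELOG entry above), instantiate extension
# classes with keyword arguments only:
#
#     series = TetrodeSeries(name='tetrode_series', data=[1., 2., 3.],
#                            electrodes=electrode_region, rate=30000.,
#                            trode_id=1)
#
# Second, the ``__nwbfields__`` declaration just mentioned would take the
# tuple-of-strings form:
#
#     class TetrodeSeries(ElectricalSeries):
#         __nwbfields__ = ('trode_id',)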
This will @@ -301,7 +301,7 @@ class Potato(NWBContainer): {'name': 'weight', 'type': float, 'doc': 'weight of potato in grams'}, {'name': 'age', 'type': float, 'doc': 'age of potato in days'}) def __init__(self, **kwargs): - super(Potato, self).__init__(name=kwargs['name']) + super().__init__(name=kwargs['name']) self.weight = kwargs['weight'] self.age = kwargs['age'] diff --git a/environment-ros3.yml b/environment-ros3.yml index 23518060c..d22a1a2f2 100644 --- a/environment-ros3.yml +++ b/environment-ros3.yml @@ -6,7 +6,7 @@ channels: dependencies: - python=3.9 - h5py==3.6.0 - - hdmf==3.1.1 + - hdmf==3.3.1 - matplotlib==3.5.1 - numpy==1.21.0 - pandas==1.3.0 diff --git a/requirements-min.txt b/requirements-min.txt index f7adc6641..74671d51f 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -1,6 +1,6 @@ # minimum versions of package dependencies for installing PyNWB h5py==2.10 # support for selection of datasets with list of indices added in 2.10 -hdmf==3.1.1 +hdmf==3.3.1 numpy==1.16 pandas==1.0.5 python-dateutil==2.7 diff --git a/requirements.txt b/requirements.txt index 31b8d102e..2f8f7fd55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # pinned dependencies to reproduce an entire development environment to use PyNWB h5py==3.3.0 -hdmf==3.1.1 +hdmf==3.3.1 numpy==1.21.0 pandas==1.3.0 python-dateutil==2.8.1 diff --git a/setup.cfg b/setup.cfg index d20424c74..51294ec81 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,7 @@ per-file-ignores = setup.py:T001 test.py:T001 scripts/*:T001 +extend-ignore = E203 [metadata] description-file = README.rst diff --git a/src/pynwb/__init__.py b/src/pynwb/__init__.py index 6cc1499f6..57a508eec 100644 --- a/src/pynwb/__init__.py +++ b/src/pynwb/__init__.py @@ -8,7 +8,7 @@ import h5py from hdmf.spec import NamespaceCatalog -from hdmf.utils import docval, getargs, popargs, call_docval_func, get_docval +from hdmf.utils import docval, getargs, popargs, get_docval from hdmf.backends.io import HDMFIO from hdmf.backends.hdf5 import HDF5IO as _HDF5IO from hdmf.validate import ValidatorMap @@ -87,7 +87,7 @@ def get_manager(**kwargs): Get a BuildManager to use for I/O using the given extensions. 
If no extensions are provided, return a BuildManager that uses the core namespace ''' - type_map = call_docval_func(get_type_map, kwargs) + type_map = get_type_map(**kwargs) return BuildManager(type_map) @@ -227,13 +227,13 @@ def __init__(self, **kwargs): raise ValueError("cannot load namespaces from file when writing to it") tm = get_type_map() - super(NWBHDF5IO, self).load_namespaces(tm, path, file=file_obj, driver=driver) + super().load_namespaces(tm, path, file=file_obj, driver=driver) manager = BuildManager(tm) # XXX: Leaving this here in case we want to revert to this strategy for # loading cached namespaces # ns_catalog = NamespaceCatalog(NWBGroupSpec, NWBDatasetSpec, NWBNamespace) - # super(NWBHDF5IO, self).load_namespaces(ns_catalog, path) + # super().load_namespaces(ns_catalog, path) # tm = TypeMap(ns_catalog) # tm.copy_mappers(get_type_map()) else: @@ -243,7 +243,7 @@ def __init__(self, **kwargs): manager = get_manager(extensions=extensions) elif manager is None: manager = get_manager() - super(NWBHDF5IO, self).__init__(path, manager=manager, mode=mode, file=file_obj, comm=comm, driver=driver) + super().__init__(path, manager=manager, mode=mode, file=file_obj, comm=comm, driver=driver) @docval({'name': 'src_io', 'type': HDMFIO, 'doc': 'the HDMFIO object (such as NWBHDF5IO) that was used to read the data to export'}, @@ -287,7 +287,7 @@ def export(self, **kwargs): """ nwbfile = popargs('nwbfile', kwargs) kwargs['container'] = nwbfile - call_docval_func(super().export, kwargs) + super().export(**kwargs) from . import io as __io # noqa: F401,E402 diff --git a/src/pynwb/base.py b/src/pynwb/base.py index 625b4fe98..4d9bddc5c 100644 --- a/src/pynwb/base.py +++ b/src/pynwb/base.py @@ -4,7 +4,7 @@ import numpy as np -from hdmf.utils import docval, getargs, popargs, call_docval_func, get_docval +from hdmf.utils import docval, popargs_to_dict, get_docval, popargs from hdmf.common import DynamicTable, VectorData from hdmf.utils import get_data_shape @@ -33,50 +33,44 @@ class ProcessingModule(MultiContainerInterface): @docval({'name': 'name', 'type': str, 'doc': 'The name of this processing module'}, {'name': 'description', 'type': str, 'doc': 'Description of this processing module'}, {'name': 'data_interfaces', 'type': (list, tuple, dict), - 'doc': 'NWBDataInterfacess that belong to this ProcessingModule', 'default': None}) + 'doc': 'NWBDataInterfaces that belong to this ProcessingModule', 'default': None}) def __init__(self, **kwargs): - call_docval_func(super(ProcessingModule, self).__init__, kwargs) - self.description = popargs('description', kwargs) - self.data_interfaces = popargs('data_interfaces', kwargs) + description, data_interfaces = popargs("description", "data_interfaces", kwargs) + super().__init__(**kwargs) + self.description = description + self.data_interfaces = data_interfaces @property def containers(self): return self.data_interfaces - def __getitem__(self, arg): - return self.get(arg) - @docval({'name': 'container', 'type': (NWBDataInterface, DynamicTable), 'doc': 'the NWBDataInterface to add to this Module'}) def add_container(self, **kwargs): ''' Add an NWBContainer to this ProcessingModule ''' - container = getargs('container', kwargs) warn(PendingDeprecationWarning('add_container will be replaced by add')) - self.add(container) + self.add(kwargs['container']) @docval({'name': 'container_name', 'type': str, 'doc': 'the name of the NWBContainer to retrieve'}) def get_container(self, **kwargs): ''' Retrieve an NWBContainer from this ProcessingModule ''' - 
container_name = getargs('container_name', kwargs) warn(PendingDeprecationWarning('get_container will be replaced by get')) - return self.get(container_name) + return self.get(kwargs['container_name']) @docval({'name': 'NWBDataInterface', 'type': (NWBDataInterface, DynamicTable), 'doc': 'the NWBDataInterface to add to this Module'}) def add_data_interface(self, **kwargs): - NWBDataInterface = getargs('NWBDataInterface', kwargs) warn(PendingDeprecationWarning('add_data_interface will be replaced by add')) - self.add(NWBDataInterface) + self.add(kwargs['NWBDataInterface']) @docval({'name': 'data_interface_name', 'type': str, 'doc': 'the name of the NWBContainer to retrieve'}) def get_data_interface(self, **kwargs): - data_interface_name = getargs('data_interface_name', kwargs) warn(PendingDeprecationWarning('get_data_interface will be replaced by get')) - return self.get(data_interface_name) + return self.get(kwargs['data_interface_name']) @register_class('TimeSeries', CORE_NAMESPACE) @@ -153,20 +147,27 @@ class TimeSeries(NWBDataInterface): def __init__(self, **kwargs): """Create a TimeSeries object """ + keys_to_set = ("starting_time", + "rate", + "resolution", + "comments", + "description", + "conversion", + "offset", + "unit", + "control", + "control_description", + "continuity") + args_to_set = popargs_to_dict(keys_to_set, kwargs) + keys_to_process = ("data", "timestamps") # these are properties and cannot be set with setattr + args_to_process = popargs_to_dict(keys_to_process, kwargs) + super().__init__(**kwargs) - call_docval_func(super(TimeSeries, self).__init__, kwargs) - keys = ("resolution", - "comments", - "description", - "conversion", - "offset", - "unit", - "control", - "control_description", - "continuity") - - data_shape = get_data_shape(data=kwargs["data"], strict_no_data_load=True) - timestamps_shape = get_data_shape(data=kwargs["timestamps"], strict_no_data_load=True) + for key, val in args_to_set.items(): + setattr(self, key, val) + + data_shape = get_data_shape(data=args_to_process["data"], strict_no_data_load=True) + timestamps_shape = get_data_shape(data=args_to_process["timestamps"], strict_no_data_load=True) if ( # check that the shape is known data_shape is not None and timestamps_shape is not None @@ -182,32 +183,25 @@ def __init__(self, **kwargs): ): warn("Length of data does not match length of timestamps. Your data may be transposed. 
Time should be on " "the 0th dimension") - for key in keys: - val = kwargs.get(key) - if val is not None: - setattr(self, key, val) - data = getargs('data', kwargs) + data = args_to_process['data'] self.fields['data'] = data + if isinstance(data, TimeSeries): + data.__add_link('data_link', self) - timestamps = kwargs.get('timestamps') - starting_time = kwargs.get('starting_time') - rate = kwargs.get('rate') + timestamps = args_to_process['timestamps'] if timestamps is not None: - if rate is not None: + if self.rate is not None: raise ValueError('Specifying rate and timestamps is not supported.') - if starting_time is not None: + if self.starting_time is not None: raise ValueError('Specifying starting_time and timestamps is not supported.') self.fields['timestamps'] = timestamps self.timestamps_unit = self.__time_unit self.interval = 1 if isinstance(timestamps, TimeSeries): timestamps.__add_link('timestamp_link', self) - elif rate is not None: - self.rate = rate - if starting_time is not None: - self.starting_time = starting_time - else: + elif self.rate is not None: + if self.starting_time is None: # override default if rate is provided but not starting time self.starting_time = 0.0 self.starting_time_unit = self.__time_unit else: @@ -235,14 +229,16 @@ def no_len_warning(attr): else: warn(no_len_warning('data'), UserWarning) - if hasattr(self, 'timestamps'): - if hasattr(self.timestamps, '__len__'): - try: - return len(self.timestamps) - except TypeError: - warn(unreadable_warning('timestamps'), UserWarning) - elif not (hasattr(self, 'rate') and hasattr(self, 'starting_time')): - warn(no_len_warning('timestamps'), UserWarning) + # only get here if self.data has no __len__ or __len__ is unreadable + if hasattr(self.timestamps, '__len__'): + try: + return len(self.timestamps) + except TypeError: + warn(unreadable_warning('timestamps'), UserWarning) + elif self.rate is None and self.starting_time is None: + warn(no_len_warning('timestamps'), UserWarning) + + return None @property def data(self): @@ -270,8 +266,7 @@ def timestamp_link(self): def __get_links(self, links): ret = self.fields.get(links, list()) - if ret is not None: - ret = set(ret) + ret = set(ret) return ret def __add_link(self, links_key, link): @@ -296,9 +291,11 @@ class Image(NWBData): {'name': 'resolution', 'type': 'float', 'doc': 'pixels / cm', 'default': None}, {'name': 'description', 'type': str, 'doc': 'description of image', 'default': None}) def __init__(self, **kwargs): - call_docval_func(super(Image, self).__init__, kwargs) - self.resolution = kwargs['resolution'] - self.description = kwargs['description'] + args_to_set = popargs_to_dict(("resolution", "description"), kwargs) + super().__init__(**kwargs) + + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('ImageReferences', CORE_NAMESPACE) @@ -343,11 +340,11 @@ class Images(MultiContainerInterface): {'name': 'order_of_images', 'type': 'ImageReferences', 'doc': 'Ordered dataset of references to Image objects stored in the parent group.', 'default': None},) def __init__(self, **kwargs): - name, description, images, order_of_images = popargs('name', 'description', 'images', 'order_of_images', kwargs) - super(Images, self).__init__(name, **kwargs) - self.description = description - self.images = images - self.order_of_images = order_of_images + + args_to_set = popargs_to_dict(("description", "images", "order_of_images"), kwargs) + super().__init__(**kwargs) + for key, val in args_to_set.items(): + setattr(self, key, val) class 
TimeSeriesReference(NamedTuple): @@ -516,7 +513,7 @@ class TimeSeriesReferenceVectorData(VectorData): "to be selected as well as an object reference to the TimeSeries."}, *get_docval(VectorData.__init__, 'data')) def __init__(self, **kwargs): - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) # CAUTION: Define any logic specific for init in the self._init_internal function, not here! self._init_internal() @@ -535,8 +532,8 @@ def _init_internal(self): 'must be convertible to a TimeSeriesReference'}) def add_row(self, **kwargs): """Append a data value to this column.""" - val = getargs('val', kwargs) - if not (isinstance(val, self.TIME_SERIES_REFERENCE_TUPLE)): + val = kwargs['val'] + if not isinstance(val, self.TIME_SERIES_REFERENCE_TUPLE): val = self.TIME_SERIES_REFERENCE_TUPLE(*val) val.check_types() super().append(val) @@ -546,8 +543,8 @@ def add_row(self, **kwargs): 'must be convertible to a TimeSeriesReference'}) def append(self, **kwargs): """Append a data value to this column.""" - arg = getargs('arg', kwargs) - if not (isinstance(arg, self.TIME_SERIES_REFERENCE_TUPLE)): + arg = kwargs['arg'] + if not isinstance(arg, self.TIME_SERIES_REFERENCE_TUPLE): arg = self.TIME_SERIES_REFERENCE_TUPLE(*arg) arg.check_types() super().append(arg) diff --git a/src/pynwb/behavior.py b/src/pynwb/behavior.py index cd981c76f..b9388e8df 100644 --- a/src/pynwb/behavior.py +++ b/src/pynwb/behavior.py @@ -37,7 +37,7 @@ def __init__(self, **kwargs): Create a SpatialSeries TimeSeries dataset """ name, data, reference_frame, unit = popargs('name', 'data', 'reference_frame', 'unit', kwargs) - super(SpatialSeries, self).__init__(name, data, unit, **kwargs) + super().__init__(name, data, unit, **kwargs) # NWB 2.5 restricts length of second dimension to be <= 3 allowed_data_shapes = ((None, ), (None, 1), (None, 2), (None, 3)) diff --git a/src/pynwb/core.py b/src/pynwb/core.py index 317dd1281..41f99c2f1 100644 --- a/src/pynwb/core.py +++ b/src/pynwb/core.py @@ -5,7 +5,7 @@ from hdmf.container import AbstractContainer, MultiContainerInterface as hdmf_MultiContainerInterface, Table from hdmf.common import DynamicTable, DynamicTableRegion # noqa: F401 from hdmf.common import VectorData, VectorIndex, ElementIdentifiers # noqa: F401 -from hdmf.utils import docval, getargs, call_docval_func +from hdmf.utils import docval, popargs from hdmf.utils import LabelledDict # noqa: F401 from . import CORE_NAMESPACE, register_class @@ -28,7 +28,7 @@ def get_ancestor(self, **kwargs): """ Traverse parent hierarchy and return first instance of the specified data_type """ - neurodata_type = getargs('neurodata_type', kwargs) + neurodata_type = kwargs['neurodata_type'] return super().get_ancestor(data_type=neurodata_type) @@ -52,8 +52,8 @@ class NWBData(NWBMixin, Data): @docval({'name': 'name', 'type': str, 'doc': 'the name of this container'}, {'name': 'data', 'type': ('scalar_data', 'array_data', 'data', Data), 'doc': 'the source of the data'}) def __init__(self, **kwargs): - call_docval_func(super(NWBData, self).__init__, kwargs) - self.__data = getargs('data', kwargs) + super().__init__(**kwargs) + self.__data = kwargs['data'] @property def data(self): @@ -97,8 +97,8 @@ class ScratchData(NWBData): 'doc': 'notes about the data. This argument will be deprecated. 
Use description instead', 'default': ''}, {'name': 'description', 'type': str, 'doc': 'notes about the data', 'default': None}) def __init__(self, **kwargs): - call_docval_func(super().__init__, kwargs) - notes, description = getargs('notes', 'description', kwargs) + notes, description = popargs('notes', 'description', kwargs) + super().__init__(**kwargs) if notes != '': warn('The `notes` argument of ScratchData.__init__ will be deprecated. Use description instead.', PendingDeprecationWarning) diff --git a/src/pynwb/device.py b/src/pynwb/device.py index 9f21e2b57..836816651 100644 --- a/src/pynwb/device.py +++ b/src/pynwb/device.py @@ -1,4 +1,4 @@ -from hdmf.utils import docval, call_docval_func, popargs +from hdmf.utils import docval, popargs from . import register_class, CORE_NAMESPACE from .core import NWBContainer @@ -21,6 +21,6 @@ class Device(NWBContainer): 'default': None}) def __init__(self, **kwargs): description, manufacturer = popargs('description', 'manufacturer', kwargs) - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) self.description = description self.manufacturer = manufacturer diff --git a/src/pynwb/ecephys.py b/src/pynwb/ecephys.py index 46a184547..bdb787a55 100644 --- a/src/pynwb/ecephys.py +++ b/src/pynwb/ecephys.py @@ -1,7 +1,7 @@ from collections.abc import Iterable import warnings -from hdmf.utils import docval, getargs, popargs, call_docval_func, get_docval +from hdmf.utils import docval, popargs, get_docval, popargs_to_dict from hdmf.data_utils import DataChunkIterator, assertEqualShape from hdmf.utils import get_data_shape @@ -30,15 +30,13 @@ class ElectrodeGroup(NWBContainer): {'name': 'position', 'type': 'array_data', 'doc': 'stereotaxic position of this electrode group (x, y, z)', 'default': None}) def __init__(self, **kwargs): - call_docval_func(super(ElectrodeGroup, self).__init__, kwargs) - description, location, device, position = popargs('description', 'location', 'device', 'position', kwargs) - self.description = description - self.location = location - self.device = device - if position and len(position) != 3: + args_to_set = popargs_to_dict(('description', 'location', 'device', 'position'), kwargs) + super().__init__(**kwargs) + if args_to_set['position'] and len(args_to_set['position']) != 3: raise Exception('ElectrodeGroup position argument must have three elements: x, y, z, but received: %s' - % position) - self.position = position + % args_to_set['position']) + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('ElectricalSeries', CORE_NAMESPACE) @@ -78,25 +76,25 @@ class ElectricalSeries(TimeSeries): *get_docval(TimeSeries.__init__, 'resolution', 'conversion', 'timestamps', 'starting_time', 'rate', 'comments', 'description', 'control', 'control_description', 'offset')) def __init__(self, **kwargs): - name, electrodes, data, channel_conversion, filtering = popargs('name', 'electrodes', 'data', - 'channel_conversion', 'filtering', kwargs) - data_shape = get_data_shape(data, strict_no_data_load=True) + args_to_set = popargs_to_dict(('electrodes', 'channel_conversion', 'filtering'), kwargs) + + data_shape = get_data_shape(kwargs['data'], strict_no_data_load=True) if ( data_shape is not None and len(data_shape) == 2 - and data_shape[1] != len(electrodes.data) + and data_shape[1] != len(args_to_set['electrodes'].data) ): - if data_shape[0] == len(electrodes.data): + if data_shape[0] == len(args_to_set['electrodes'].data): warnings.warn("The second dimension of data does not match the length of 
electrodes, but instead the " "first does. Data is oriented incorrectly and should be transposed.") else: warnings.warn("The second dimension of data does not match the length of electrodes. Your data may be " "transposed.") - super(ElectricalSeries, self).__init__(name, data, 'volts', **kwargs) - self.electrodes = electrodes - self.channel_conversion = channel_conversion - self.filtering = filtering + kwargs['unit'] = 'volts' # fixed value + super().__init__(**kwargs) + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('SpikeEventSeries', CORE_NAMESPACE) @@ -120,8 +118,8 @@ class SpikeEventSeries(ElectricalSeries): *get_docval(ElectricalSeries.__init__, 'resolution', 'conversion', 'comments', 'description', 'control', 'control_description', 'offset')) def __init__(self, **kwargs): - name, data, electrodes = popargs('name', 'data', 'electrodes', kwargs) - timestamps = getargs('timestamps', kwargs) + data = kwargs['data'] + timestamps = kwargs['timestamps'] if not (isinstance(data, TimeSeries) or isinstance(timestamps, TimeSeries)): if not (isinstance(data, DataChunkIterator) or isinstance(timestamps, DataChunkIterator)): if len(data) != len(timestamps): @@ -129,7 +127,7 @@ def __init__(self, **kwargs): else: # TODO: add check when we have DataChunkIterators pass - super(SpikeEventSeries, self).__init__(name, data, electrodes, **kwargs) + super().__init__(**kwargs) @register_class('EventDetection', CORE_NAMESPACE) @@ -155,14 +153,11 @@ class EventDetection(NWBDataInterface): {'name': 'times', 'type': ('array_data', 'data'), 'doc': 'Timestamps of events, in Seconds'}, {'name': 'name', 'type': str, 'doc': 'the name of this container', 'default': 'EventDetection'}) def __init__(self, **kwargs): - detection_method, source_electricalseries, source_idx, times = popargs( - 'detection_method', 'source_electricalseries', 'source_idx', 'times', kwargs) - super(EventDetection, self).__init__(**kwargs) - self.detection_method = detection_method - self.source_electricalseries = source_electricalseries - self.source_idx = source_idx - self.times = times - self.unit = 'seconds' + args_to_set = popargs_to_dict(('detection_method', 'source_electricalseries', 'source_idx', 'times'), kwargs) + super().__init__(**kwargs) + for key, val in args_to_set.items(): + setattr(self, key, val) + self.unit = 'seconds' # fixed value @register_class('EventWaveform', CORE_NAMESPACE) @@ -207,15 +202,12 @@ class Clustering(NWBDataInterface): 'shape': (None,)}, {'name': 'name', 'type': str, 'doc': 'the name of this container', 'default': 'Clustering'}) def __init__(self, **kwargs): - import warnings warnings.warn("use pynwb.misc.Units or NWBFile.units instead", DeprecationWarning) - description, num, peak_over_rms, times = popargs( - 'description', 'num', 'peak_over_rms', 'times', kwargs) - super(Clustering, self).__init__(**kwargs) - self.description = description - self.num = num - self.peak_over_rms = list(peak_over_rms) - self.times = times + args_to_set = popargs_to_dict(('description', 'num', 'peak_over_rms', 'times'), kwargs) + super().__init__(**kwargs) + args_to_set['peak_over_rms'] = list(args_to_set['peak_over_rms']) + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('ClusterWaveforms', CORE_NAMESPACE) @@ -244,15 +236,12 @@ class ClusterWaveforms(NWBDataInterface): 'doc': 'the standard deviations of waveforms for each cluster'}, {'name': 'name', 'type': str, 'doc': 'the name of this container', 'default': 'ClusterWaveforms'}) def __init__(self, **kwargs): - 
import warnings warnings.warn("use pynwb.misc.Units or NWBFile.units instead", DeprecationWarning) - clustering_interface, waveform_filtering, waveform_mean, waveform_sd = popargs( - 'clustering_interface', 'waveform_filtering', 'waveform_mean', 'waveform_sd', kwargs) - super(ClusterWaveforms, self).__init__(**kwargs) - self.clustering_interface = clustering_interface - self.waveform_filtering = waveform_filtering - self.waveform_mean = waveform_mean - self.waveform_sd = waveform_sd + args_to_set = popargs_to_dict(('clustering_interface', 'waveform_filtering', + 'waveform_mean', 'waveform_sd'), kwargs) + super().__init__(**kwargs) + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('LFP', CORE_NAMESPACE) @@ -356,7 +345,7 @@ def __init__(self, **kwargs): raise ValueError(error_msg) # Initialize the object - super(FeatureExtraction, self).__init__(**kwargs) + super().__init__(**kwargs) self.electrodes = electrodes self.description = description self.times = list(times) diff --git a/src/pynwb/epoch.py b/src/pynwb/epoch.py index 82177b1bf..6f7f674a6 100644 --- a/src/pynwb/epoch.py +++ b/src/pynwb/epoch.py @@ -1,6 +1,6 @@ from bisect import bisect_left -from hdmf.utils import docval, getargs, popargs, call_docval_func, get_docval +from hdmf.utils import docval, getargs, popargs, get_docval from hdmf.data_utils import DataIO from . import register_class, CORE_NAMESPACE @@ -29,7 +29,7 @@ class TimeIntervals(DynamicTable): 'default': "experimental intervals"}, *get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames')) def __init__(self, **kwargs): - call_docval_func(super(TimeIntervals, self).__init__, kwargs) + super().__init__(**kwargs) @docval({'name': 'start_time', 'type': 'float', 'doc': 'Start time of epoch, in seconds'}, {'name': 'stop_time', 'type': 'float', 'doc': 'Stop time of epoch, in seconds'}, @@ -55,7 +55,7 @@ def add_interval(self, **kwargs): tmp.append(TimeSeriesReference(idx_start, count, ts)) timeseries = tmp rkwargs['timeseries'] = timeseries - return super(TimeIntervals, self).add_row(**rkwargs) + return super().add_row(**rkwargs) def __calculate_idx_count(self, start_time, stop_time, ts_data): if isinstance(ts_data.timestamps, DataIO): diff --git a/src/pynwb/file.py b/src/pynwb/file.py index 25bfdd86d..9c2324256 100644 --- a/src/pynwb/file.py +++ b/src/pynwb/file.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd -from hdmf.utils import docval, getargs, call_docval_func, get_docval, popargs +from hdmf.utils import docval, getargs, get_docval, popargs, popargs_to_dict from . import register_class, CORE_NAMESPACE from .base import TimeSeries, ProcessingModule @@ -30,9 +30,10 @@ def _not_parent(arg): @register_class('LabMetaData', CORE_NAMESPACE) class LabMetaData(NWBContainer): + @docval({'name': 'name', 'type': str, 'doc': 'name of metadata'}) def __init__(self, **kwargs): - super(LabMetaData, self).__init__(kwargs['name']) + super().__init__(**kwargs) @register_class('Subject', CORE_NAMESPACE) @@ -74,24 +75,29 @@ class Subject(NWBContainer): 'doc': 'The datetime of the date of birth. 
May be supplied instead of age.'}, {'name': 'strain', 'type': str, 'doc': 'The strain of the subject, e.g., "C57BL/6J"', 'default': None}) def __init__(self, **kwargs): + keys_to_set = ("age", + "description", + "genotype", + "sex", + "species", + "subject_id", + "weight", + "date_of_birth", + "strain") + args_to_set = popargs_to_dict(keys_to_set, kwargs) kwargs['name'] = 'subject' - call_docval_func(super(Subject, self).__init__, kwargs) - self.age = getargs('age', kwargs) - self.description = getargs('description', kwargs) - self.genotype = getargs('genotype', kwargs) - self.sex = getargs('sex', kwargs) - self.species = getargs('species', kwargs) - self.subject_id = getargs('subject_id', kwargs) - weight = getargs('weight', kwargs) + super().__init__(**kwargs) + + weight = args_to_set['weight'] if isinstance(weight, float): - weight = str(weight) + ' kg' - self.weight = weight - self.strain = getargs('strain', kwargs) - date_of_birth = getargs('date_of_birth', kwargs) + args_to_set['weight'] = str(weight) + ' kg' + + date_of_birth = args_to_set['date_of_birth'] if date_of_birth and date_of_birth.tzinfo is None: - self.date_of_birth = _add_missing_timezone(date_of_birth) - else: - self.date_of_birth = date_of_birth + args_to_set['date_of_birth'] = _add_missing_timezone(date_of_birth) + + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('NWBFile', CORE_NAMESPACE) @@ -354,29 +360,16 @@ class NWBFile(MultiContainerInterface): {'name': 'icephys_experimental_conditions', 'type': ExperimentalConditionsTable, 'default': None, 'doc': 'the ExperimentalConditionsTable table that belongs to this NWBFile'}) def __init__(self, **kwargs): - kwargs['name'] = 'root' - call_docval_func(super(NWBFile, self).__init__, kwargs) - self.fields['session_description'] = getargs('session_description', kwargs) - self.fields['identifier'] = getargs('identifier', kwargs) - - self.fields['session_start_time'] = getargs('session_start_time', kwargs) - if self.fields['session_start_time'].tzinfo is None: - self.fields['session_start_time'] = _add_missing_timezone(self.fields['session_start_time']) - - self.fields['timestamps_reference_time'] = getargs('timestamps_reference_time', kwargs) - if self.fields['timestamps_reference_time'] is None: - self.fields['timestamps_reference_time'] = self.fields['session_start_time'] - elif self.fields['timestamps_reference_time'].tzinfo is None: - raise ValueError("'timestamps_reference_time' must be a timezone-aware datetime object.") - - self.fields['file_create_date'] = getargs('file_create_date', kwargs) - if self.fields['file_create_date'] is None: - self.fields['file_create_date'] = datetime.now(tzlocal()) - if isinstance(self.fields['file_create_date'], datetime): - self.fields['file_create_date'] = [self.fields['file_create_date']] - self.fields['file_create_date'] = list(map(_add_missing_timezone, self.fields['file_create_date'])) - - fieldnames = [ + keys_to_set = [ + 'session_description', + 'identifier', + 'session_start_time', + 'experimenter', + 'file_create_date', + 'ic_electrodes', + 'icephys_electrodes', + 'related_publications', + 'timestamps_reference_time', 'acquisition', 'analysis', 'stimulus', @@ -419,30 +412,70 @@ def __init__(self, **kwargs): 'icephys_repetitions', 'icephys_experimental_conditions' ] - for attr in fieldnames: - setattr(self, attr, kwargs.get(attr, None)) + args_to_set = popargs_to_dict(keys_to_set, kwargs) + kwargs['name'] = 'root' + super().__init__(**kwargs) + + # add timezone to session_start_time if missing + 
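# (assumption: _add_missing_timezone here is the module-level helper shown further below, which fills in the local timezone via tzlocal() and warns when it does so)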
session_start_time = args_to_set['session_start_time'] + if session_start_time.tzinfo is None: + args_to_set['session_start_time'] = _add_missing_timezone(session_start_time) + + # set timestamps_reference_time to session_start_time if not provided + # if provided, ensure that it has a timezone + timestamps_reference_time = args_to_set['timestamps_reference_time'] + if timestamps_reference_time is None: + args_to_set['timestamps_reference_time'] = args_to_set['session_start_time'] + elif timestamps_reference_time.tzinfo is None: + raise ValueError("'timestamps_reference_time' must be a timezone-aware datetime object.") + + # convert file_create_date to list and add timezone if missing + file_create_date = args_to_set['file_create_date'] + if file_create_date is None: + file_create_date = datetime.now(tzlocal()) + if isinstance(file_create_date, datetime): + file_create_date = [file_create_date] + args_to_set['file_create_date'] = list(map(_add_missing_timezone, file_create_date)) # backwards-compatibility code for ic_electrodes / icephys_electrodes - ic_elec_val = kwargs.get('icephys_electrodes', None) - if ic_elec_val is None and kwargs.get('ic_electrodes', None) is not None: - ic_elec_val = kwargs.get('ic_electrodes', None) + icephys_electrodes = args_to_set['icephys_electrodes'] + ic_electrodes = args_to_set['ic_electrodes'] + if icephys_electrodes is None and ic_electrodes is not None: warn("Use of the ic_electrodes parameter is deprecated. " "Use the icephys_electrodes parameter instead", DeprecationWarning) - setattr(self, 'icephys_electrodes', ic_elec_val) + args_to_set['icephys_electrodes'] = ic_electrodes + args_to_set.pop('ic_electrodes') # do not set this arg - experimenter = kwargs.get('experimenter', None) + # convert single experimenter to tuple + experimenter = args_to_set['experimenter'] if isinstance(experimenter, str): - experimenter = (experimenter,) - setattr(self, 'experimenter', experimenter) + args_to_set['experimenter'] = (experimenter,) - related_pubs = kwargs.get('related_publications', None) + # convert single related_publications to tuple + related_pubs = args_to_set['related_publications'] if isinstance(related_pubs, str): - related_pubs = (related_pubs,) - setattr(self, 'related_publications', related_pubs) + args_to_set['related_publications'] = (related_pubs,) - if getargs('source_script', kwargs) is None and getargs('source_script_file_name', kwargs) is not None: + # ensure source_script is provided if source_script_file_name is provided + if args_to_set['source_script'] is None and args_to_set['source_script_file_name'] is not None: raise ValueError("'source_script' cannot be None when 'source_script_file_name' is set") + # these attributes have no setters and can only be set using self.fields + keys_to_set_via_fields = ( + 'session_description', + 'identifier', + 'session_start_time', + 'timestamps_reference_time', + 'file_create_date' + ) + args_to_set_via_fields = popargs_to_dict(keys_to_set_via_fields, args_to_set) + + for key, val in args_to_set_via_fields.items(): + self.fields[key] = val + + for key, val in args_to_set.items(): + setattr(self, key, val) + self.__obj = None def all_children(self): @@ -524,7 +557,7 @@ def get_ic_electrode(self, *args, **kwargs): def __check_epochs(self): if self.epochs is None: - self.epochs = TimeIntervals('epochs', 'experimental epochs') + self.epochs = TimeIntervals(name='epochs', description='experimental epochs') @docval(*get_docval(TimeIntervals.add_column)) def add_epoch_column(self, **kwargs): @@ -534,7 +567,7 
@@ def add_epoch_column(self, **kwargs): """ self.__check_epochs() self.epoch_tags.update(kwargs.pop('tags', list())) - call_docval_func(self.epochs.add_column, kwargs) + self.epochs.add_column(**kwargs) def add_epoch_metadata_column(self, *args, **kwargs): """ @@ -557,7 +590,7 @@ def add_epoch(self, **kwargs): self.__check_epochs() if kwargs['tags'] is not None: self.epoch_tags.update(kwargs['tags']) - call_docval_func(self.epochs.add_interval, kwargs) + self.epochs.add_interval(**kwargs) def __check_electrodes(self): if self.electrodes is None: @@ -570,7 +603,7 @@ def add_electrode_column(self, **kwargs): See :py:meth:`~hdmf.common.DynamicTable.add_column` for more details """ self.__check_electrodes() - call_docval_func(self.electrodes.add_column, kwargs) + self.electrodes.add_column(**kwargs) @docval({'name': 'x', 'type': 'float', 'doc': 'the x coordinate of the position (+x is posterior)'}, {'name': 'y', 'type': 'float', 'doc': 'the y coordinate of the position (+y is inferior)'}, @@ -616,7 +649,7 @@ def add_electrode(self, **kwargs): else: d.pop(col_name) # remove args from d if not set - call_docval_func(self.electrodes.add_row, d) + self.electrodes.add_row(**d) @docval({'name': 'region', 'type': (slice, list, tuple), 'doc': 'the indices of the table'}, {'name': 'description', 'type': str, 'doc': 'a brief description of what this electrode is'}, @@ -633,7 +666,7 @@ def create_electrode_table_region(self, **kwargs): + str(len(self.electrodes))) desc = getargs('description', kwargs) name = getargs('name', kwargs) - return DynamicTableRegion(name, region, desc, self.electrodes) + return DynamicTableRegion(name=name, data=region, description=desc, table=self.electrodes) def __check_units(self): if self.units is None: @@ -646,7 +679,7 @@ def add_unit_column(self, **kwargs): See :py:meth:`~hdmf.common.DynamicTable.add_column` for more details """ self.__check_units() - call_docval_func(self.units.add_column, kwargs) + self.units.add_column(**kwargs) @docval(*get_docval(Units.add_unit), allow_extra=True) def add_unit(self, **kwargs): @@ -656,11 +689,11 @@ def add_unit(self, **kwargs): """ self.__check_units() - call_docval_func(self.units.add_unit, kwargs) + self.units.add_unit(**kwargs) def __check_trials(self): if self.trials is None: - self.trials = TimeIntervals('trials', 'experimental trials') + self.trials = TimeIntervals(name='trials', description='experimental trials') @docval(*get_docval(DynamicTable.add_column)) def add_trial_column(self, **kwargs): @@ -669,7 +702,7 @@ def add_trial_column(self, **kwargs): See :py:meth:`~hdmf.common.DynamicTable.add_column` for more details """ self.__check_trials() - call_docval_func(self.trials.add_column, kwargs) + self.trials.add_column(**kwargs) @docval(*get_docval(TimeIntervals.add_interval), allow_extra=True) def add_trial(self, **kwargs): @@ -681,11 +714,14 @@ def add_trial(self, **kwargs): been added (through calls to `add_trial_columns`). 
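For example, with hypothetical values: ``nwbfile.add_trial(start_time=10.0, stop_time=20.0)``.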
""" self.__check_trials() - call_docval_func(self.trials.add_interval, kwargs) + self.trials.add_interval(**kwargs) def __check_invalid_times(self): if self.invalid_times is None: - self.invalid_times = TimeIntervals('invalid_times', 'time intervals to be removed from analysis') + self.invalid_times = TimeIntervals( + name='invalid_times', + description='time intervals to be removed from analysis' + ) @docval(*get_docval(DynamicTable.add_column)) def add_invalid_times_column(self, **kwargs): @@ -694,8 +730,9 @@ def add_invalid_times_column(self, **kwargs): See :py:meth:`~hdmf.common.DynamicTable.add_column` for more details """ self.__check_invalid_times() - call_docval_func(self.invalid_times.add_column, kwargs) + self.invalid_times.add_column(**kwargs) + @docval(*get_docval(TimeIntervals.add_interval), allow_extra=True) def add_invalid_time_interval(self, **kwargs): """ Add a trial to the trial table. @@ -705,7 +742,7 @@ def add_invalid_time_interval(self, **kwargs): been added (through calls to `add_invalid_times_columns`). """ self.__check_invalid_times() - call_docval_func(self.invalid_times.add_interval, kwargs) + self.invalid_times.add_interval(**kwargs) @docval({'name': 'electrode_table', 'type': DynamicTable, 'doc': 'the ElectrodeTable for this file'}) def set_electrode_table(self, **kwargs): @@ -803,7 +840,7 @@ def add_intracellular_recording(self, **kwargs): self.add_icephys_electrode(electrode) # make sure the intracellular recordings table exists and if not create it using get_intracellular_recordings # Add the recoding to the intracellular_recordings table - return call_docval_func(self.get_intracellular_recordings().add_recording, kwargs) + return self.get_intracellular_recordings().add_recording(**kwargs) @docval(returns='The NWBFile.icephys_simultaneous_recordings table', rtype=SimultaneousRecordingsTable) def get_icephys_simultaneous_recordings(self): @@ -826,7 +863,7 @@ def add_icephys_simultaneous_recording(self, **kwargs): """ Add a new simultaneous recording to the icephys_simultaneous_recordings table """ - return call_docval_func(self.get_icephys_simultaneous_recordings().add_simultaneous_recording, kwargs) + return self.get_icephys_simultaneous_recordings().add_simultaneous_recording(**kwargs) @docval(returns='The NWBFile.icephys_sequential_recordings table', rtype=SequentialRecordingsTable) def get_icephys_sequential_recordings(self): @@ -850,7 +887,7 @@ def add_icephys_sequential_recording(self, **kwargs): Add a new sequential recording to the icephys_sequential_recordings table """ self.get_icephys_sequential_recordings() - return call_docval_func(self.icephys_sequential_recordings.add_sequential_recording, kwargs) + return self.icephys_sequential_recordings.add_sequential_recording(**kwargs) @docval(returns='The NWBFile.icephys_repetitions table', rtype=RepetitionsTable) def get_icephys_repetitions(self): @@ -873,7 +910,7 @@ def add_icephys_repetition(self, **kwargs): """ Add a new repetition to the RepetitionsTable table """ - return call_docval_func(self.get_icephys_repetitions().add_repetition, kwargs) + return self.get_icephys_repetitions().add_repetition(**kwargs) @docval(returns='The NWBFile.icephys_experimental_conditions table', rtype=ExperimentalConditionsTable) def get_icephys_experimental_conditions(self): @@ -896,7 +933,7 @@ def add_icephys_experimental_condition(self, **kwargs): """ Add a new condition to the ExperimentalConditionsTable table """ - return call_docval_func(self.get_icephys_experimental_conditions().add_experimental_condition, 
kwargs) + return self.get_icephys_experimental_conditions().add_experimental_condition(**kwargs) def get_icephys_meta_parent_table(self): """ @@ -1040,7 +1077,7 @@ def _add_missing_timezone(date): def _tablefunc(table_name, description, columns): - t = DynamicTable(table_name, description) + t = DynamicTable(name=table_name, description=description) for c in columns: if isinstance(c, tuple): t.add_column(c[0], c[1]) diff --git a/src/pynwb/icephys.py b/src/pynwb/icephys.py index 32706a34e..6b3adeec4 100644 --- a/src/pynwb/icephys.py +++ b/src/pynwb/icephys.py @@ -1,7 +1,7 @@ import warnings from hdmf.common import DynamicTable, AlignedDynamicTable -from hdmf.utils import docval, popargs, call_docval_func, get_docval, getargs +from hdmf.utils import docval, popargs, popargs_to_dict, get_docval, getargs from . import register_class, CORE_NAMESPACE from .base import TimeSeries, TimeSeriesReferenceVectorData @@ -55,19 +55,22 @@ class IntracellularElectrode(NWBContainer): {'name': 'cell_id', 'type': str, 'doc': 'Unique ID of cell.', 'default': None} ) def __init__(self, **kwargs): - slice, seal, description, location, resistance, filtering, initial_access_resistance, device, cell_id = popargs( - 'slice', 'seal', 'description', 'location', 'resistance', - 'filtering', 'initial_access_resistance', 'device', 'cell_id', kwargs) - call_docval_func(super().__init__, kwargs) - self.slice = slice - self.seal = seal - self.description = description - self.location = location - self.resistance = resistance - self.filtering = filtering - self.initial_access_resistance = initial_access_resistance - self.device = device - self.cell_id = cell_id + keys_to_set = ( + 'slice', + 'seal', + 'description', + 'location', + 'resistance', + 'filtering', + 'initial_access_resistance', + 'device', + 'cell_id' + ) + args_to_set = popargs_to_dict(keys_to_set, kwargs) + super().__init__(**kwargs) + + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('PatchClampSeries', CORE_NAMESPACE) @@ -300,7 +303,7 @@ def __init__(self, **kwargs): warnings.warn("Use of SweepTable is deprecated. Use the IntracellularRecordingsTable " "instead. 
See also the NWBFile.add_intracellular_recordings function.", DeprecationWarning) - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) @docval({'name': 'pcs', 'type': PatchClampSeries, 'doc': 'PatchClampSeries to add to the table must have a valid sweep_number'}) @@ -360,7 +363,7 @@ def __init__(self, **kwargs): kwargs['name'] = 'electrodes' kwargs['description'] = ('Table for storing intracellular electrode related metadata') # Initialize the DynamicTable - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) @register_class('IntracellularStimuliTable', CORE_NAMESPACE) @@ -383,7 +386,7 @@ def __init__(self, **kwargs): kwargs['name'] = 'stimuli' kwargs['description'] = ('Table for storing intracellular stimulus related metadata') # Initialize the DynamicTable - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) @register_class('IntracellularResponsesTable', CORE_NAMESPACE) @@ -406,7 +409,7 @@ def __init__(self, **kwargs): kwargs['name'] = 'responses' kwargs['description'] = ('Table for storing intracellular response related metadata') # Initialize the DynamicTable - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) @register_class('IntracellularRecordingsTable', CORE_NAMESPACE) @@ -464,7 +467,7 @@ def __init__(self, **kwargs): kwargs['category_tables'] = dynamic_table_arg kwargs['categories'] = categories_arg - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) @docval({'name': 'electrode', 'type': IntracellularElectrode, 'doc': 'The intracellular electrode used'}, {'name': 'stimulus_start_index', 'type': 'int', 'doc': 'Start index of the stimulus', 'default': None}, @@ -669,7 +672,7 @@ def __init__(self, **kwargs): 'IntracellularRecordingsTable table together that were recorded simultaneously ' 'from different electrodes.') # Initialize the DynamicTable - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) if self['recordings'].target.table is None: if intracellular_recordings_table is not None: self['recordings'].target.table = intracellular_recordings_table @@ -729,7 +732,7 @@ def __init__(self, **kwargs): 'group together simultaneous_recordings where the a sequence of stimuli of the ' 'same type with varying parameters have been presented in a sequence.') # Initialize the DynamicTable - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) if self['simultaneous_recordings'].target.table is None: if simultaneous_recordings_table is not None: self['simultaneous_recordings'].target.table = simultaneous_recordings_table @@ -786,7 +789,7 @@ def __init__(self, **kwargs): 'of stimulus, the RepetitionsTable table is typically used to group sets ' 'of stimuli applied in sequence.') # Initialize the DynamicTable - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) if self['sequential_recordings'].target.table is None: if sequential_recordings_table is not None: self['sequential_recordings'].target.table = sequential_recordings_table @@ -836,7 +839,7 @@ def __init__(self, **kwargs): kwargs['description'] = ('A table for grouping different intracellular recording repetitions together that ' 'belong to the same experimental conditions.') # Initialize the DynamicTable - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) if self['repetitions'].target.table is None: if repetitions_table is not None: self['repetitions'].target.table = repetitions_table diff --git a/src/pynwb/image.py b/src/pynwb/image.py index 
4e0377630..2fb7eb678 100644 --- a/src/pynwb/image.py +++ b/src/pynwb/image.py @@ -2,7 +2,7 @@ import numpy as np from collections.abc import Iterable -from hdmf.utils import docval, getargs, popargs, call_docval_func, get_docval +from hdmf.utils import docval, getargs, popargs, popargs_to_dict, get_docval from . import register_class, CORE_NAMESPACE from .base import TimeSeries, Image, Images @@ -56,13 +56,13 @@ class ImageSeries(TimeSeries): {'name': 'device', 'type': Device, 'doc': 'Device used to capture the images/video.', 'default': None},) def __init__(self, **kwargs): - bits_per_pixel, dimension, external_file, starting_frame, format, device = popargs( - 'bits_per_pixel', 'dimension', 'external_file', 'starting_frame', 'format', 'device', kwargs) + keys_to_set = ('bits_per_pixel', 'dimension', 'external_file', 'starting_frame', 'format', 'device') + args_to_set = popargs_to_dict(keys_to_set, kwargs) name, data, unit = getargs('name', 'data', 'unit', kwargs) if data is not None and unit is None: raise ValueError("Must supply 'unit' argument when supplying 'data' to %s '%s'." % (self.__class__.__name__, name)) - if external_file is None and data is None: + if args_to_set['external_file'] is None and data is None: raise ValueError("Must supply either external_file or data to %s '%s'." % (self.__class__.__name__, name)) @@ -72,17 +72,13 @@ def __init__(self, **kwargs): if unit is None: kwargs['unit'] = ImageSeries.DEFAULT_UNIT - call_docval_func(super(ImageSeries, self).__init__, kwargs) + # TODO catch warning when default data is used and timestamps are provided + super().__init__(**kwargs) - self.bits_per_pixel = bits_per_pixel - self.dimension = dimension - self.external_file = external_file - if external_file is not None: - self.starting_frame = starting_frame - else: - self.starting_frame = None - self.format = format - self.device = device + if args_to_set["external_file"] is None: + args_to_set["starting_frame"] = None # overwrite starting_frame + for key, val in args_to_set.items(): + setattr(self, key, val) @property def bits_per_pixel(self): @@ -136,7 +132,7 @@ def __init__(self, **kwargs): "a future version of NWB. 
Use the indexed_images field instead.") warnings.warn(msg, PendingDeprecationWarning) kwargs['unit'] = 'N/A' # fixed value starting in NWB 2.5 - super(IndexSeries, self).__init__(**kwargs) + super().__init__(**kwargs) self.indexed_timeseries = indexed_timeseries self.indexed_images = indexed_images if kwargs['conversion'] and kwargs['conversion'] != self.DEFAULT_CONVERSION: @@ -170,7 +166,7 @@ class ImageMaskSeries(ImageSeries): 'default': None},) def __init__(self, **kwargs): masked_imageseries = popargs('masked_imageseries', kwargs) - super(ImageMaskSeries, self).__init__(**kwargs) + super().__init__(**kwargs) self.masked_imageseries = masked_imageseries @@ -206,7 +202,7 @@ class OpticalSeries(ImageSeries): 'description', 'control', 'control_description', 'device', 'offset')) def __init__(self, **kwargs): distance, field_of_view, orientation = popargs('distance', 'field_of_view', 'orientation', kwargs) - super(OpticalSeries, self).__init__(**kwargs) + super().__init__(**kwargs) self.distance = distance self.field_of_view = field_of_view self.orientation = orientation @@ -221,7 +217,7 @@ class GrayscaleImage(Image): 'shape': (None, None)}, *get_docval(Image.__init__, 'resolution', 'description')) def __init__(self, **kwargs): - call_docval_func(super(GrayscaleImage, self).__init__, kwargs) + super().__init__(**kwargs) @register_class('RGBImage', CORE_NAMESPACE) @@ -234,7 +230,7 @@ class RGBImage(Image): 'shape': (None, None, 3)}, *get_docval(Image.__init__, 'resolution', 'description')) def __init__(self, **kwargs): - call_docval_func(super(RGBImage, self).__init__, kwargs) + super().__init__(**kwargs) @register_class('RGBAImage', CORE_NAMESPACE) @@ -247,4 +243,4 @@ class RGBAImage(Image): 'shape': (None, None, 4)}, *get_docval(Image.__init__, 'resolution', 'description')) def __init__(self, **kwargs): - call_docval_func(super(RGBAImage, self).__init__, kwargs) + super().__init__(**kwargs) diff --git a/src/pynwb/io/base.py b/src/pynwb/io/base.py index 7663ce87e..4b86e8713 100644 --- a/src/pynwb/io/base.py +++ b/src/pynwb/io/base.py @@ -9,7 +9,7 @@ class ModuleMap(NWBContainerMapper): def __init__(self, spec): - super(ModuleMap, self).__init__(spec) + super().__init__(spec) containers_spec = self.spec.get_neurodata_type('NWBDataInterface') table_spec = self.spec.get_neurodata_type('DynamicTable') self.map_spec('data_interfaces', containers_spec) @@ -24,7 +24,7 @@ def name(self, builder, manager): class TimeSeriesMap(NWBContainerMapper): def __init__(self, spec): - super(TimeSeriesMap, self).__init__(spec) + super().__init__(spec) data_spec = self.spec.get_dataset('data') self.map_spec('unit', data_spec.get_attribute('unit')) self.map_spec('resolution', data_spec.get_attribute('resolution')) diff --git a/src/pynwb/io/epoch.py b/src/pynwb/io/epoch.py index 022d9a148..3c2b42aa2 100644 --- a/src/pynwb/io/epoch.py +++ b/src/pynwb/io/epoch.py @@ -33,7 +33,7 @@ def columns_carg(self, builder, manager): # schema are compatible (i.e., only the neurodata_type was changed in 2.5) dset_obj.__class__ = TimeSeriesReferenceVectorData # Execute init logic specific for TimeSeriesReferenceVectorData - dset_obj. 
_init_internal() + dset_obj._init_internal() columns.append(dset_obj) # overwrite the columns constructor argument return columns diff --git a/src/pynwb/io/file.py b/src/pynwb/io/file.py index 19c6c384f..7484b7541 100644 --- a/src/pynwb/io/file.py +++ b/src/pynwb/io/file.py @@ -9,7 +9,7 @@ class NWBFileMap(ObjectMapper): def __init__(self, spec): - super(NWBFileMap, self).__init__(spec) + super().__init__(spec) acq_spec = self.spec.get_group('acquisition') self.unmap(acq_spec) diff --git a/src/pynwb/io/icephys.py b/src/pynwb/io/icephys.py index 0ed1a2c3b..24045d090 100644 --- a/src/pynwb/io/icephys.py +++ b/src/pynwb/io/icephys.py @@ -1,21 +1,16 @@ from .. import register_map -from pynwb.icephys import SweepTable, VoltageClampSeries, IntracellularRecordingsTable +from pynwb.icephys import VoltageClampSeries, IntracellularRecordingsTable from hdmf.common.io.table import DynamicTableMap from hdmf.common.io.alignedtable import AlignedDynamicTableMap from .base import TimeSeriesMap -@register_map(SweepTable) -class SweepTableMap(DynamicTableMap): - pass - - @register_map(VoltageClampSeries) class VoltageClampSeriesMap(TimeSeriesMap): def __init__(self, spec): - super(VoltageClampSeriesMap, self).__init__(spec) + super().__init__(spec) fields_with_unit = ('capacitance_fast', 'capacitance_slow', diff --git a/src/pynwb/io/ophys.py b/src/pynwb/io/ophys.py index a5b1b9d18..3e758c7ab 100644 --- a/src/pynwb/io/ophys.py +++ b/src/pynwb/io/ophys.py @@ -9,7 +9,7 @@ class PlaneSegmentationMap(DynamicTableMap): def __init__(self, spec): - super(PlaneSegmentationMap, self).__init__(spec) + super().__init__(spec) reference_images_spec = self.spec.get_group('reference_images').get_neurodata_type('ImageSeries') self.map_spec('reference_images', reference_images_spec) @@ -19,7 +19,7 @@ def __init__(self, spec): class ImagingPlaneMap(NWBContainerMapper): def __init__(self, spec): - super(ImagingPlaneMap, self).__init__(spec) + super().__init__(spec) manifold_spec = self.spec.get_dataset('manifold') origin_coords_spec = self.spec.get_dataset('origin_coords') grid_spacing_spec = self.spec.get_dataset('grid_spacing') diff --git a/src/pynwb/misc.py b/src/pynwb/misc.py index 64f199aa8..9c0c83ce1 100644 --- a/src/pynwb/misc.py +++ b/src/pynwb/misc.py @@ -3,7 +3,7 @@ import warnings from bisect import bisect_left, bisect_right -from hdmf.utils import docval, getargs, popargs, call_docval_func, get_docval +from hdmf.utils import docval, getargs, popargs, popargs_to_dict, get_docval from . 
import register_class, CORE_NAMESPACE from .base import TimeSeries @@ -26,7 +26,7 @@ class AnnotationSeries(TimeSeries): *get_docval(TimeSeries.__init__, 'timestamps', 'comments', 'description')) def __init__(self, **kwargs): name, data, timestamps = popargs('name', 'data', 'timestamps', kwargs) - super(AnnotationSeries, self).__init__(name, data, 'n/a', resolution=-1.0, timestamps=timestamps, **kwargs) + super().__init__(name=name, data=data, unit='n/a', resolution=-1.0, timestamps=timestamps, **kwargs) @docval({'name': 'time', 'type': 'float', 'doc': 'The time for the annotation'}, {'name': 'annotation', 'type': str, 'doc': 'the annotation'}) @@ -65,7 +65,7 @@ class AbstractFeatureSeries(TimeSeries): def __init__(self, **kwargs): name, data, features, feature_units = popargs('name', 'data', 'features', 'feature_units', kwargs) - super(AbstractFeatureSeries, self).__init__(name, data, "see 'feature_units'", **kwargs) + super().__init__(name=name, data=data, unit="see 'feature_units'", **kwargs) self.features = features self.feature_units = feature_units @@ -103,7 +103,7 @@ def __init__(self, **kwargs): name, data, timestamps = popargs('name', 'data', 'timestamps', kwargs) self.__interval_timestamps = timestamps self.__interval_data = data - super(IntervalSeries, self).__init__(name, data, 'n/a', resolution=-1.0, timestamps=timestamps, **kwargs) + super().__init__(name=name, data=data, unit='n/a', resolution=-1.0, timestamps=timestamps, **kwargs) @docval({'name': 'start', 'type': 'float', 'doc': 'The start time of the interval'}, {'name': 'stop', 'type': 'float', 'doc': 'The stop time of the interval'}) @@ -165,15 +165,18 @@ class Units(DynamicTable): 'doc': 'The smallest possible difference between two spike times', 'default': None} ) def __init__(self, **kwargs): - if kwargs.get('description', None) is None: + args_to_set = popargs_to_dict(("waveform_rate", "waveform_unit", "resolution"), kwargs) + electrode_table = popargs("electrode_table", kwargs) + if kwargs['description'] is None: kwargs['description'] = "data on spiking units" - call_docval_func(super(Units, self).__init__, kwargs) + super().__init__(**kwargs) + + for key, val in args_to_set.items(): + setattr(self, key, val) + if 'spike_times' not in self.colnames: self.__has_spike_times = False - self.__electrode_table = getargs('electrode_table', kwargs) - self.waveform_rate = getargs('waveform_rate', kwargs) - self.waveform_unit = getargs('waveform_unit', kwargs) - self.resolution = getargs('resolution', kwargs) + self.__electrode_table = electrode_table @docval({'name': 'spike_times', 'type': 'array_data', 'doc': 'the spike times for each unit', 'default': None, 'shape': (None,)}, @@ -198,7 +201,7 @@ def add_unit(self, **kwargs): """ Add a unit to this table """ - super(Units, self).add_row(**kwargs) + super().add_row(**kwargs) if 'electrodes' in self: elec_col = self['electrodes'].target if elec_col.table is None: @@ -276,7 +279,7 @@ class DecompositionSeries(TimeSeries): def __init__(self, **kwargs): metric, source_timeseries, bands, source_channels = popargs('metric', 'source_timeseries', 'bands', 'source_channels', kwargs) - super(DecompositionSeries, self).__init__(**kwargs) + super().__init__(**kwargs) self.source_timeseries = source_timeseries self.source_channels = source_channels if self.source_timeseries is None and self.source_channels is None: @@ -285,7 +288,10 @@ def __init__(self, **kwargs): "corresponding source_channels. 
(Optional)") self.metric = metric if bands is None: - bands = DynamicTable("bands", "data about the frequency bands that the signal was decomposed into") + bands = DynamicTable( + name="bands", + description="data about the frequency bands that the signal was decomposed into" + ) self.bands = bands def __check_column(self, name, desc): diff --git a/src/pynwb/ogen.py b/src/pynwb/ogen.py index 26a6c747f..5377fdc08 100644 --- a/src/pynwb/ogen.py +++ b/src/pynwb/ogen.py @@ -1,4 +1,4 @@ -from hdmf.utils import docval, popargs, get_docval, call_docval_func +from hdmf.utils import docval, popargs, get_docval, popargs_to_dict from . import register_class, CORE_NAMESPACE from .base import TimeSeries @@ -22,13 +22,11 @@ class OptogeneticStimulusSite(NWBContainer): {'name': 'excitation_lambda', 'type': 'float', 'doc': 'Excitation wavelength in nm.'}, {'name': 'location', 'type': str, 'doc': 'Location of stimulation site.'}) def __init__(self, **kwargs): - device, description, excitation_lambda, location = popargs( - 'device', 'description', 'excitation_lambda', 'location', kwargs) - call_docval_func(super(OptogeneticStimulusSite, self).__init__, kwargs) - self.device = device - self.description = description - self.excitation_lambda = excitation_lambda - self.location = location + args_to_set = popargs_to_dict(('device', 'description', 'excitation_lambda', 'location'), kwargs) + super().__init__(**kwargs) + + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('OptogeneticSeries', CORE_NAMESPACE) @@ -47,6 +45,7 @@ class OptogeneticSeries(TimeSeries): *get_docval(TimeSeries.__init__, 'resolution', 'conversion', 'timestamps', 'starting_time', 'rate', 'comments', 'description', 'control', 'control_description', 'offset')) def __init__(self, **kwargs): - name, data, site = popargs('name', 'data', 'site', kwargs) - super(OptogeneticSeries, self).__init__(name, data, 'watts', **kwargs) + site = popargs('site', kwargs) + kwargs['unit'] = 'watts' + super().__init__(**kwargs) self.site = site diff --git a/src/pynwb/ophys.py b/src/pynwb/ophys.py index 3789dad93..35905669e 100644 --- a/src/pynwb/ophys.py +++ b/src/pynwb/ophys.py @@ -2,7 +2,7 @@ import numpy as np import warnings -from hdmf.utils import docval, getargs, popargs, call_docval_func, get_docval, get_data_shape +from hdmf.utils import docval, popargs, get_docval, get_data_shape, popargs_to_dict from . import register_class, CORE_NAMESPACE from .base import TimeSeries @@ -24,7 +24,7 @@ class OpticalChannel(NWBContainer): {'name': 'emission_lambda', 'type': 'float', 'doc': 'Emission wavelength for channel, in nm.'}) # required def __init__(self, **kwargs): description, emission_lambda = popargs("description", "emission_lambda", kwargs) - call_docval_func(super(OpticalChannel, self).__init__, kwargs) + super().__init__(**kwargs) self.description = description self.emission_lambda = emission_lambda @@ -87,38 +87,37 @@ class ImagingPlane(NWBContainer): 'doc': "Measurement units for grid_spacing. 
The default value is 'meters'.", 'default': 'meters'}) def __init__(self, **kwargs): - optical_channel, description, device, excitation_lambda, imaging_rate, \ - indicator, location, manifold, conversion, unit, reference_frame, origin_coords, origin_coords_unit, \ - grid_spacing, grid_spacing_unit = popargs( - 'optical_channel', 'description', 'device', 'excitation_lambda', - 'imaging_rate', 'indicator', 'location', 'manifold', 'conversion', - 'unit', 'reference_frame', 'origin_coords', 'origin_coords_unit', 'grid_spacing', 'grid_spacing_unit', - kwargs) - call_docval_func(super(ImagingPlane, self).__init__, kwargs) - self.optical_channel = optical_channel if isinstance(optical_channel, list) else [optical_channel] - self.description = description - self.device = device - self.excitation_lambda = excitation_lambda - self.imaging_rate = imaging_rate - self.indicator = indicator - self.location = location - if manifold is not None: + keys_to_set = ('optical_channel', + 'description', + 'device', + 'excitation_lambda', + 'imaging_rate', + 'indicator', + 'location', + 'manifold', + 'conversion', + 'unit', + 'reference_frame', + 'origin_coords', + 'origin_coords_unit', + 'grid_spacing', + 'grid_spacing_unit') + args_to_set = popargs_to_dict(keys_to_set, kwargs) + super().__init__(**kwargs) + + if not isinstance(args_to_set['optical_channel'], list): + args_to_set['optical_channel'] = [args_to_set['optical_channel']] + if args_to_set['manifold'] is not None: warnings.warn("The 'manifold' argument is deprecated in favor of 'origin_coords' and 'grid_spacing'.", DeprecationWarning) - if conversion != 1.0: + if args_to_set['conversion'] != 1.0: warnings.warn("The 'conversion' argument is deprecated in favor of 'origin_coords' and 'grid_spacing'.", DeprecationWarning) - if unit != 'meters': + if args_to_set['unit'] != 'meters': warnings.warn("The 'unit' argument is deprecated in favor of 'origin_coords_unit' and 'grid_spacing_unit'.", DeprecationWarning) - self.manifold = manifold - self.conversion = conversion - self.unit = unit - self.reference_frame = reference_frame - self.origin_coords = origin_coords - self.origin_coords_unit = origin_coords_unit - self.grid_spacing = grid_spacing - self.grid_spacing_unit = grid_spacing_unit + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('TwoPhotonSeries', CORE_NAMESPACE) @@ -144,13 +143,15 @@ class TwoPhotonSeries(ImageSeries): 'dimension', 'resolution', 'conversion', 'timestamps', 'starting_time', 'rate', 'comments', 'description', 'control', 'control_description', 'device', 'offset')) def __init__(self, **kwargs): - field_of_view, imaging_plane, pmt_gain, scan_line_rate = popargs( - 'field_of_view', 'imaging_plane', 'pmt_gain', 'scan_line_rate', kwargs) - call_docval_func(super(TwoPhotonSeries, self).__init__, kwargs) - self.field_of_view = field_of_view - self.imaging_plane = imaging_plane - self.pmt_gain = pmt_gain - self.scan_line_rate = scan_line_rate + keys_to_set = ("field_of_view", + "imaging_plane", + "pmt_gain", + "scan_line_rate") + args_to_set = popargs_to_dict(keys_to_set, kwargs) + super().__init__(**kwargs) + + for key, val in args_to_set.items(): + setattr(self, key, val) @register_class('CorrectedImageStack', CORE_NAMESPACE) @@ -176,7 +177,7 @@ class CorrectedImageStack(NWBDataInterface): 'for example, to align each frame to a reference image. 
This must have the name "xy_translation".'}) def __init__(self, **kwargs): corrected, original, xy_translation = popargs('corrected', 'original', 'xy_translation', kwargs) - call_docval_func(super(CorrectedImageStack, self).__init__, kwargs) + super().__init__(**kwargs) self.corrected = corrected self.original = original self.xy_translation = xy_translation @@ -228,10 +229,9 @@ class PlaneSegmentation(DynamicTable): *get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames')) def __init__(self, **kwargs): imaging_plane, reference_images = popargs('imaging_plane', 'reference_images', kwargs) - if kwargs.get('name') is None: + if kwargs['name'] is None: kwargs['name'] = imaging_plane.name - columns, colnames = getargs('columns', 'colnames', kwargs) - call_docval_func(super(PlaneSegmentation, self).__init__, kwargs) + super().__init__(**kwargs) self.imaging_plane = imaging_plane if isinstance(reference_images, ImageSeries): reference_images = (reference_images,) @@ -260,7 +260,7 @@ def add_roi(self, **kwargs): rkwargs['pixel_mask'] = pixel_mask if voxel_mask is not None: rkwargs['voxel_mask'] = voxel_mask - return super(PlaneSegmentation, self).add_row(**rkwargs) + return super().add_row(**rkwargs) @staticmethod def pixel_to_image(pixel_mask): @@ -291,7 +291,7 @@ def image_to_pixel(image_mask): {'name': 'region', 'type': (slice, list, tuple), 'doc': 'the indices of the table', 'default': slice(None)}, {'name': 'name', 'type': str, 'doc': 'the name of the ROITableRegion', 'default': 'rois'}) def create_roi_table_region(self, **kwargs): - return call_docval_func(self.create_region, kwargs) + return self.create_region(**kwargs) @register_class('ImageSegmentation', CORE_NAMESPACE) @@ -362,7 +362,7 @@ def __init__(self, **kwargs): else: warnings.warn("The second dimension of data does not match the length of rois. Your data may be " "transposed.") - call_docval_func(super(RoiResponseSeries, self).__init__, kwargs) + super().__init__(**kwargs) self.rois = rois diff --git a/src/pynwb/retinotopy.py b/src/pynwb/retinotopy.py index 4671cf40a..47db2e04a 100644 --- a/src/pynwb/retinotopy.py +++ b/src/pynwb/retinotopy.py @@ -1,7 +1,7 @@ from collections.abc import Iterable import warnings -from hdmf.utils import docval, popargs, call_docval_func, get_docval +from hdmf.utils import docval, popargs, get_docval from . 
import register_class, CORE_NAMESPACE from .core import NWBDataInterface, NWBData @@ -27,7 +27,7 @@ class RetinotopyImage(NWBData): def __init__(self, **kwargs): bits_per_pixel, dimension, format, field_of_view = popargs( 'bits_per_pixel', 'dimension', 'format', 'field_of_view', kwargs) - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) self.bits_per_pixel = bits_per_pixel self.dimension = dimension self.format = format @@ -45,7 +45,7 @@ class FocalDepthImage(RetinotopyImage): {'name': 'focal_depth', 'type': 'float', 'doc': 'Focal depth offset, in meters.'}) def __init__(self, **kwargs): focal_depth = popargs('focal_depth', kwargs) - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) self.focal_depth = focal_depth @@ -63,7 +63,7 @@ class RetinotopyMap(NWBData): 'doc': 'Number of rows and columns in the image'}) def __init__(self, **kwargs): field_of_view, dimension = popargs('field_of_view', 'dimension', kwargs) - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) self.field_of_view = field_of_view self.dimension = dimension @@ -79,7 +79,7 @@ class AxisMap(RetinotopyMap): *get_docval(RetinotopyMap.__init__, 'dimension')) def __init__(self, **kwargs): unit = popargs('unit', kwargs) - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) self.unit = unit @@ -132,7 +132,7 @@ def __init__(self, **kwargs): focal_depth_image, sign_map, vasculature_image = popargs( 'axis_1_phase_map', 'axis_1_power_map', 'axis_2_phase_map', 'axis_2_power_map', 'axis_descriptions', 'focal_depth_image', 'sign_map', 'vasculature_image', kwargs) - call_docval_func(super().__init__, kwargs) + super().__init__(**kwargs) warnings.warn("The ImagingRetinotopy class currently cannot be written to or read from a file. " "This is a known bug and will be fixed in a future release of PyNWB.") self.axis_1_phase_map = axis_1_phase_map diff --git a/src/pynwb/spec.py b/src/pynwb/spec.py index d1e4160be..eaa70b6a1 100644 --- a/src/pynwb/spec.py +++ b/src/pynwb/spec.py @@ -3,7 +3,7 @@ from hdmf.spec import (LinkSpec, GroupSpec, DatasetSpec, SpecNamespace, NamespaceBuilder, AttributeSpec, DtypeSpec, RefSpec) from hdmf.spec.write import export_spec # noqa: F401 -from hdmf.utils import docval, get_docval, call_docval_func +from hdmf.utils import docval, get_docval from . import CORE_NAMESPACE @@ -32,7 +32,7 @@ class NWBRefSpec(RefSpec): @docval(*deepcopy(_ref_docval)) def __init__(self, **kwargs): - call_docval_func(super(NWBRefSpec, self).__init__, kwargs) + super().__init__(**kwargs) _attr_docval = __swap_inc_def(AttributeSpec) @@ -42,7 +42,7 @@ class NWBAttributeSpec(AttributeSpec): @docval(*deepcopy(_attr_docval)) def __init__(self, **kwargs): - call_docval_func(super(NWBAttributeSpec, self).__init__, kwargs) + super().__init__(**kwargs) _link_docval = __swap_inc_def(LinkSpec) @@ -52,7 +52,7 @@ class NWBLinkSpec(LinkSpec): @docval(*deepcopy(_link_docval)) def __init__(self, **kwargs): - call_docval_func(super(NWBLinkSpec, self).__init__, kwargs) + super().__init__(**kwargs) @property def neurodata_type_inc(self): @@ -60,7 +60,7 @@ def neurodata_type_inc(self): return self.data_type_inc -class BaseStorageOverride(object): +class BaseStorageOverride: ''' This class is used for the purpose of overriding BaseStorageSpec classmethods, without creating diamond inheritance hierarchies. 
@@ -97,7 +97,7 @@ def neurodata_type_def(self): def build_const_args(cls, spec_dict): """Extend base functionality to remap data_type_def and data_type_inc keys""" spec_dict = copy(spec_dict) - proxy = super(BaseStorageOverride, cls) + proxy = super() if proxy.inc_key() in spec_dict: spec_dict[cls.inc_key()] = spec_dict.pop(proxy.inc_key()) if proxy.def_key() in spec_dict: @@ -108,7 +108,7 @@ def build_const_args(cls, spec_dict): @classmethod def _translate_kwargs(cls, kwargs): """Swap neurodata_type_def and neurodata_type_inc for data_type_def and data_type_inc, respectively""" - proxy = super(BaseStorageOverride, cls) + proxy = super() kwargs[proxy.def_key()] = kwargs.pop(cls.def_key()) kwargs[proxy.inc_key()] = kwargs.pop(cls.inc_key()) return kwargs @@ -121,7 +121,7 @@ class NWBDtypeSpec(DtypeSpec): @docval(*deepcopy(_dtype_docval)) def __init__(self, **kwargs): - call_docval_func(super(NWBDtypeSpec, self).__init__, kwargs) + super().__init__(**kwargs) _dataset_docval = __swap_inc_def(DatasetSpec) @@ -139,7 +139,7 @@ def __init__(self, **kwargs): # set data_type_inc to NWBData only if it is not specified and the type is not an HDMF base type if kwargs['data_type_inc'] is None and kwargs['data_type_def'] not in (None, 'Data'): kwargs['data_type_inc'] = 'NWBData' - super(NWBDatasetSpec, self).__init__(**kwargs) + super().__init__(**kwargs) _group_docval = __swap_inc_def(GroupSpec) @@ -159,7 +159,7 @@ def __init__(self, **kwargs): # NWBContainer. This will be fixed in hdmf-common-schema 1.2.1. if kwargs['data_type_inc'] is None and kwargs['data_type_def'] not in (None, 'Container', 'CSRMatrix'): kwargs['data_type_inc'] = 'NWBContainer' - super(NWBGroupSpec, self).__init__(**kwargs) + super().__init__(**kwargs) @classmethod def dataset_spec_cls(cls): @@ -168,7 +168,7 @@ def dataset_spec_cls(cls): @docval({'name': 'neurodata_type', 'type': str, 'doc': 'the neurodata_type to retrieve'}) def get_neurodata_type(self, **kwargs): ''' Get a specification by "neurodata_type" ''' - return super(NWBGroupSpec, self).get_data_type(kwargs['neurodata_type']) + return super().get_data_type(kwargs['neurodata_type']) @docval(*deepcopy(_group_docval)) def add_group(self, **kwargs): @@ -215,5 +215,5 @@ class NWBNamespaceBuilder(NamespaceBuilder): def __init__(self, **kwargs): ''' Create a NWBNamespaceBuilder ''' kwargs['namespace_cls'] = NWBNamespace - call_docval_func(super(NWBNamespaceBuilder, self).__init__, kwargs) + super().__init__(**kwargs) self.include_namespace(CORE_NAMESPACE) diff --git a/tests/integration/hdf5/test_ecephys.py b/tests/integration/hdf5/test_ecephys.py index cc70ee9dc..c7babfb31 100644 --- a/tests/integration/hdf5/test_ecephys.py +++ b/tests/integration/hdf5/test_ecephys.py @@ -166,8 +166,8 @@ def setUpContainer(self): description='the first and third electrodes', table=self.table) sES = SpikeEventSeries(name='test_sES', - data=((1, 1, 1), (2, 2, 2)), - timestamps=[0., 1.], + data=((1, 1), (2, 2), (3, 3)), + timestamps=[0., 1., 2.], electrodes=region) ew = EventWaveform(sES) return ew diff --git a/tests/integration/hdf5/test_image.py b/tests/integration/hdf5/test_image.py index d10770f83..1183bb0a4 100644 --- a/tests/integration/hdf5/test_image.py +++ b/tests/integration/hdf5/test_image.py @@ -17,7 +17,7 @@ def setUpContainer(self): external_file=['external_file'], starting_frame=[1, 2, 3], format='tiff', - timestamps=list(), + timestamps=[1., 2., 3.], device=self.dev1, ) return iS diff --git a/tests/integration/hdf5/test_modular_storage.py 
b/tests/integration/hdf5/test_modular_storage.py index db1608865..6c86fc615 100644 --- a/tests/integration/hdf5/test_modular_storage.py +++ b/tests/integration/hdf5/test_modular_storage.py @@ -16,7 +16,7 @@ class TestTimeSeriesModular(TestCase): def setUp(self): self.start_time = datetime(1971, 1, 1, 12, tzinfo=tzutc()) - self.data = np.arange(2000).reshape((2, 1000)) + self.data = np.arange(2000).reshape((1000, 2)) self.timestamps = np.linspace(0, 1, 1000) self.container = TimeSeries( diff --git a/tests/integration/hdf5/test_nwbfile.py b/tests/integration/hdf5/test_nwbfile.py index 70909c029..90c02aac5 100644 --- a/tests/integration/hdf5/test_nwbfile.py +++ b/tests/integration/hdf5/test_nwbfile.py @@ -440,7 +440,7 @@ def setUpContainer(self): """ Return placeholder table for electrodes. Tested electrodes are added directly to the NWBFile in addContainer """ - return DynamicTable('electrodes', 'a placeholder table') + return DynamicTable(name='electrodes', description='a placeholder table') def addContainer(self, nwbfile): """ Add electrodes and related objects to the given NWBFile """ @@ -491,7 +491,7 @@ def setUpContainer(self): """ Return placeholder table for electrodes. Tested electrodes are added directly to the NWBFile in addContainer """ - return DynamicTable('electrodes', 'a placeholder table') + return DynamicTable(name='electrodes', description='a placeholder table') def addContainer(self, nwbfile): """ Add electrodes and related objects to the given NWBFile """ @@ -546,7 +546,7 @@ def setUpContainer(self): """ Return placeholder table for electrodes. Tested electrodes are added directly to the NWBFile in addContainer """ - return DynamicTable('electrodes', 'a placeholder table') + return DynamicTable(name='electrodes', description='a placeholder table') def addContainer(self, nwbfile): """ Add electrode table region and related objects to the given NWBFile """ diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py index 844f6989d..3f4e71454 100644 --- a/tests/unit/test_base.py +++ b/tests/unit/test_base.py @@ -1,274 +1,448 @@ -import warnings - import numpy as np -from pynwb.base import (ProcessingModule, TimeSeries, Images, Image, TimeSeriesReferenceVectorData, - TimeSeriesReference, ImageReferences) +from pynwb.base import ( + ProcessingModule, + TimeSeries, + Images, + Image, + TimeSeriesReferenceVectorData, + TimeSeriesReference, + ImageReferences +) from pynwb.testing import TestCase from hdmf.data_utils import DataChunkIterator from hdmf.backends.hdf5 import H5DataIO class TestProcessingModule(TestCase): - def setUp(self): - self.pm = ProcessingModule('test_procmod', 'a fake processing module') + self.pm = ProcessingModule( + name="test_procmod", description="a test processing module" + ) + + def _create_time_series(self): + ts = TimeSeries( + name="test_ts", + data=[0, 1, 2, 3, 4, 5], + unit="grams", + timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], + ) + return ts def test_init(self): - self.assertEqual(self.pm.name, 'test_procmod') - self.assertEqual(self.pm.description, 'a fake processing module') + """Test creating a ProcessingModule.""" + self.assertEqual(self.pm.name, "test_procmod") + self.assertEqual(self.pm.description, "a test processing module") def test_add_data_interface(self): - ts = TimeSeries('test_ts', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) + """Test adding a data interface to a ProcessingModule using add(...) 
and retrieving it.""" + ts = self._create_time_series() self.pm.add(ts) self.assertIn(ts.name, self.pm.containers) self.assertIs(ts, self.pm.containers[ts.name]) def test_deprecated_add_data_interface(self): - ts = TimeSeries('test_ts', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) - with self.assertWarnsWith(PendingDeprecationWarning, 'add_data_interface will be replaced by add'): + ts = self._create_time_series() + with self.assertWarnsWith( + PendingDeprecationWarning, "add_data_interface will be replaced by add" + ): self.pm.add_data_interface(ts) self.assertIn(ts.name, self.pm.containers) self.assertIs(ts, self.pm.containers[ts.name]) def test_deprecated_add_container(self): - ts = TimeSeries('test_ts', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) - with self.assertWarnsWith(PendingDeprecationWarning, 'add_container will be replaced by add'): + ts = self._create_time_series() + with self.assertWarnsWith( + PendingDeprecationWarning, "add_container will be replaced by add" + ): self.pm.add_container(ts) self.assertIn(ts.name, self.pm.containers) self.assertIs(ts, self.pm.containers[ts.name]) def test_get_data_interface(self): - ts = TimeSeries('test_ts', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) + """Test adding a data interface to a ProcessingModule and retrieving it using get(...).""" + ts = self._create_time_series() self.pm.add(ts) - tmp = self.pm.get('test_ts') + tmp = self.pm.get("test_ts") self.assertIs(tmp, ts) - self.assertIs(self.pm['test_ts'], self.pm.get('test_ts')) + self.assertIs(self.pm["test_ts"], self.pm.get("test_ts")) def test_deprecated_get_data_interface(self): - ts = TimeSeries('test_ts', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) + ts = self._create_time_series() self.pm.add(ts) - with self.assertWarnsWith(PendingDeprecationWarning, 'get_data_interface will be replaced by get'): - tmp = self.pm.get_data_interface('test_ts') + with self.assertWarnsWith( + PendingDeprecationWarning, "get_data_interface will be replaced by get" + ): + tmp = self.pm.get_data_interface("test_ts") self.assertIs(tmp, ts) def test_deprecated_get_container(self): - ts = TimeSeries('test_ts', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) + ts = self._create_time_series() self.pm.add(ts) - with self.assertWarnsWith(PendingDeprecationWarning, 'get_container will be replaced by get'): - tmp = self.pm.get_container('test_ts') + with self.assertWarnsWith( + PendingDeprecationWarning, "get_container will be replaced by get" + ): + tmp = self.pm.get_container("test_ts") self.assertIs(tmp, ts) def test_getitem(self): - ts = TimeSeries('test_ts', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) + """Test adding a data interface to a ProcessingModule and retrieving it using __getitem__(...).""" + ts = self._create_time_series() self.pm.add(ts) - tmp = self.pm['test_ts'] + tmp = self.pm["test_ts"] self.assertIs(tmp, ts) class TestTimeSeries(TestCase): - def test_init_no_parent(self): - ts = TimeSeries('test_ts', list(), 'unit', timestamps=list()) - self.assertEqual(ts.name, 'test_ts') + """Test creating an empty TimeSeries and that it has no parent.""" + ts = TimeSeries(name="test_ts", data=list(), unit="unit", timestamps=list()) + self.assertEqual(ts.name, "test_ts") self.assertIsNone(ts.parent) def test_init_datalink_set(self): - ts = TimeSeries('test_ts', list(), 'unit', timestamps=list()) + """Test creating a TimeSeries and 
that data_link is an empty set.""" + ts = TimeSeries(name="test_ts", data=list(), unit="unit", timestamps=list()) self.assertIsInstance(ts.data_link, set) self.assertEqual(len(ts.data_link), 0) def test_init_timestampslink_set(self): - ts = TimeSeries('test_ts', list(), 'unit', timestamps=list()) + """Test creating a TimeSeries and that timestamps_link is an empty set.""" + ts = TimeSeries(name="test_ts", data=list(), unit="unit", timestamps=list()) self.assertIsInstance(ts.timestamp_link, set) self.assertEqual(len(ts.timestamp_link), 0) - def test_init_data(self): - dat = [0, 1, 2, 3, 4] - ts = TimeSeries('test_ts', dat, 'volts', timestamps=[0.1, 0.2, 0.3, 0.4]) - self.assertIs(ts.data, dat) + def test_init_data_timestamps(self): + data = [0, 1, 2, 3, 4] + timestamps = [0.0, 0.1, 0.2, 0.3, 0.4] + ts = TimeSeries(name="test_ts", data=data, unit="volts", timestamps=timestamps) + self.assertIs(ts.data, data) + self.assertIs(ts.timestamps, timestamps) self.assertEqual(ts.conversion, 1.0) self.assertEqual(ts.offset, 0.0) self.assertEqual(ts.resolution, -1.0) - self.assertEqual(ts.unit, 'volts') - - def test_init_conversion(self): - dat = [0, 1, 2, 3, 4] + self.assertEqual(ts.unit, "volts") + self.assertEqual(ts.interval, 1) + self.assertEqual(ts.time_unit, "seconds") + self.assertEqual(ts.num_samples, 5) + self.assertIsNone(ts.continuity) + self.assertIsNone(ts.rate) + self.assertIsNone(ts.starting_time) + + def test_init_conversion_offset(self): + data = [0, 1, 2, 3, 4] + timestamps = [0.0, 0.1, 0.2, 0.3, 0.4] conversion = 2.1 - ts = TimeSeries('test_ts', dat, 'volts', timestamps=[0.1, 0.2, 0.3, 0.4], conversion=conversion) - self.assertIs(ts.data, dat) - self.assertEqual(ts.conversion, conversion) - - def test_init_offset(self): - dat = [0, 1, 2, 3, 4] offset = 1.2 - ts = TimeSeries('test_ts', dat, 'volts', timestamps=[0.1, 0.2, 0.3, 0.4], offset=offset) - self.assertIs(ts.data, dat) + ts = TimeSeries( + name="test_ts", + data=data, + unit="volts", + timestamps=timestamps, + conversion=conversion, + offset=offset, + ) + self.assertIs(ts.data, data) + self.assertEqual(ts.conversion, conversion) self.assertEqual(ts.offset, offset) - def test_init_timestamps(self): - dat = [0, 1, 2, 3, 4] - tstamps = [0.1, 0.2, 0.3, 0.4] - ts = TimeSeries('test_ts', dat, 'unit', timestamps=tstamps) - self.assertIs(ts.timestamps, tstamps) - self.assertEqual(ts.interval, 1) - self.assertEqual(ts.time_unit, "seconds") + def test_no_time(self): + with self.assertRaisesWith( + TypeError, "either 'timestamps' or 'rate' must be specified" + ): + TimeSeries(name="test_ts2", data=[10, 11, 12, 13, 14, 15], unit="grams") + + def test_no_starting_time(self): + """Test that if no starting_time is given, 0.0 is assumed.""" + ts1 = TimeSeries(name="test_ts1", data=[1, 2, 3], unit="unit", rate=0.1) + self.assertEqual(ts1.starting_time, 0.0) def test_init_rate(self): - ts = TimeSeries('test_ts', list(), 'unit', starting_time=0.0, rate=1.0) - self.assertEqual(ts.starting_time, 0.0) - self.assertEqual(ts.rate, 1.0) + ts = TimeSeries( + name="test_ts", + data=list(), + unit="volts", + starting_time=1.0, + rate=2.0, + ) + self.assertEqual(ts.starting_time, 1.0) + self.assertEqual(ts.starting_time_unit, "seconds") + self.assertEqual(ts.rate, 2.0) self.assertEqual(ts.time_unit, "seconds") + self.assertIsNone(ts.timestamps) def test_data_timeseries(self): - ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) - ts2 = TimeSeries('test_ts2', ts1, 'grams', timestamps=[1.0, 1.1, 1.2, - 
1.3, 1.4, 1.5]) - self.assertEqual(ts2.data, [0, 1, 2, 3, 4, 5]) + """Test that setting a TimeSeries.data to another TimeSeries links the data correctly.""" + data = [0, 1, 2, 3] + timestamps1 = [0.0, 0.1, 0.2, 0.3] + timestamps2 = [1.0, 1.1, 1.2, 1.3] + ts1 = TimeSeries( + name="test_ts1", data=data, unit="grams", timestamps=timestamps1 + ) + ts2 = TimeSeries( + name="test_ts2", data=ts1, unit="grams", timestamps=timestamps2 + ) + self.assertEqual(ts2.data, data) self.assertEqual(ts1.num_samples, ts2.num_samples) + self.assertEqual(ts1.data_link, set([ts2])) def test_timestamps_timeseries(self): - ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) - ts2 = TimeSeries('test_ts2', [10, 11, 12, 13, 14, 15], - 'grams', timestamps=ts1) - self.assertEqual(ts2.timestamps, [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) + """Test that setting a TimeSeries.timestamps to another TimeSeries links the timestamps correctly.""" + data1 = [0, 1, 2, 3] + data2 = [10, 11, 12, 13] + timestamps = [0.0, 0.1, 0.2, 0.3] + ts1 = TimeSeries( + name="test_ts1", data=data1, unit="grams", timestamps=timestamps + ) + ts2 = TimeSeries(name="test_ts2", data=data2, unit="grams", timestamps=ts1) + self.assertEqual(ts2.timestamps, timestamps) + self.assertEqual(ts1.timestamp_link, set([ts2])) def test_good_continuity_timeseries(self): - ts1 = TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], - continuity='continuous') - self.assertEqual(ts1.continuity, 'continuous') + ts = TimeSeries( + name="test_ts1", + data=[0, 1, 2, 3, 4, 5], + unit="grams", + timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], + continuity="continuous", + ) + self.assertEqual(ts.continuity, "continuous") def test_bad_continuity_timeseries(self): - with self.assertRaises(ValueError): - TimeSeries('test_ts1', [0, 1, 2, 3, 4, 5], - 'grams', timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], - continuity='wrong') + msg = ( + "TimeSeries.__init__: forbidden value for 'continuity' (got 'wrong', " + "expected ['continuous', 'instantaneous', 'step'])" + ) + with self.assertRaisesWith(ValueError, msg): + TimeSeries( + name="test_ts1", + data=[0, 1, 2, 3, 4, 5], + unit="grams", + timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], + continuity="wrong", + ) + + def _create_time_series_with_data(self, data): + ts = TimeSeries(name="test_ts1", data=data, unit="grams", rate=0.1) + return ts def test_dataio_list_data(self): length = 100 data = list(range(length)) - ts1 = TimeSeries('test_ts1', H5DataIO(data), - 'grams', starting_time=0.0, rate=0.1) - self.assertEqual(ts1.num_samples, length) - assert data == list(ts1.data) + ts = self._create_time_series_with_data(H5DataIO(data)) + self.assertEqual(ts.num_samples, length) + assert data == list(ts.data) def test_dataio_dci_data(self): - def generator_factory(): return (i for i in range(100)) data = H5DataIO(DataChunkIterator(data=generator_factory())) - ts1 = TimeSeries('test_ts1', data, - 'grams', starting_time=0.0, rate=0.1) - with self.assertWarnsWith(UserWarning, 'The data attribute on this TimeSeries (named: test_ts1) has a ' - '__len__, but it cannot be read'): - self.assertIs(ts1.num_samples, None) + ts = self._create_time_series_with_data(data) + with self.assertWarnsWith( + UserWarning, + "The data attribute on this TimeSeries (named: test_ts1) has a " + "__len__, but it cannot be read", + ): + self.assertIsNone(ts.num_samples) for xi, yi in zip(data, generator_factory()): assert np.allclose(xi, yi) def test_dci_data(self): - def generator_factory(): return (i 
for i in range(100)) data = DataChunkIterator(data=generator_factory()) - ts1 = TimeSeries('test_ts1', data, - 'grams', starting_time=0.0, rate=0.1) - with self.assertWarnsWith(UserWarning, 'The data attribute on this TimeSeries (named: test_ts1) has no ' - '__len__'): - self.assertIs(ts1.num_samples, None) + ts = self._create_time_series_with_data(data) + with self.assertWarnsWith( + UserWarning, + "The data attribute on this TimeSeries (named: test_ts1) has no __len__", + ): + self.assertIsNone(ts.num_samples) for xi, yi in zip(data, generator_factory()): assert np.allclose(xi, yi) def test_dci_data_arr(self): - def generator_factory(): - return (np.array([i, i+1]) for i in range(100)) + def generator_factory(): + return (np.array([i, i + 1]) for i in range(100)) data = DataChunkIterator(data=generator_factory()) - ts1 = TimeSeries('test_ts1', data, - 'grams', starting_time=0.0, rate=0.1) - # with self.assertWarnsRegex(UserWarning, r'.*name: \'test_ts1\'.*'): - with self.assertWarns(UserWarning): - self.assertIs(ts1.num_samples, None) + ts = self._create_time_series_with_data(data) + with self.assertWarnsWith( + UserWarning, + "The data attribute on this TimeSeries (named: test_ts1) has no __len__", + ): + self.assertIsNone(ts.num_samples) for xi, yi in zip(data, generator_factory()): assert np.allclose(xi, yi) - def test_no_time(self): - with self.assertRaisesWith(TypeError, "either 'timestamps' or 'rate' must be specified"): - TimeSeries('test_ts2', [10, 11, 12, 13, 14, 15], 'grams') + def test_dataio_list_timestamps(self): + length = 100 + timestamps = [i / 10 for i in range(length)] + ts = self._create_time_series_with_timestamps(H5DataIO(timestamps)) + with self.assertWarnsWith( + UserWarning, + "The data attribute on this TimeSeries (named: test_ts1) has no __len__", + ): + self.assertEqual(ts.num_samples, length) + assert timestamps == list(ts.timestamps) - def test_no_starting_time(self): - # if no starting_time is given, 0.0 is assumed - ts1 = TimeSeries('test_ts1', data=[1, 2, 3], unit='unit', rate=0.1) - self.assertEqual(ts1.starting_time, 0.0) + def _create_time_series_with_timestamps(self, timestamps): + # data has no __len__ for these tests + def generator_factory(): + return (i for i in range(100)) + + ts = TimeSeries( + name="test_ts1", + data=DataChunkIterator(data=generator_factory()), + unit="grams", + timestamps=timestamps, + ) + return ts + + def test_dataio_dci_timestamps(self): + def generator_factory(): + return (i for i in range(100)) + + timestamps = H5DataIO(DataChunkIterator(data=generator_factory())) + ts = self._create_time_series_with_timestamps(timestamps) + with self.assertWarns(UserWarning) as record: + self.assertIsNone(ts.num_samples) + assert len(record.warnings) == 2 + assert record.warnings[0].message.args[0] == ( + "The data attribute on this TimeSeries (named: test_ts1) has no __len__" + ) + assert record.warnings[1].message.args[0] == ( + "The timestamps attribute on this TimeSeries (named: test_ts1) has a " + "__len__, but it cannot be read" + ) + for xi, yi in zip(timestamps, generator_factory()): + assert np.allclose(xi, yi) + + def test_dci_timestamps(self): + def generator_factory(): + return (i for i in range(100)) + + timestamps = DataChunkIterator(data=generator_factory()) + ts = self._create_time_series_with_timestamps(timestamps) + with self.assertWarns(UserWarning) as record: + self.assertIsNone(ts.num_samples) + assert len(record.warnings) == 2 + assert record.warnings[0].message.args[0] == ( + "The data attribute on this TimeSeries (named: test_ts1) has no __len__" + ) + assert record.warnings[1].message.args[0] == ( + "The timestamps attribute on this TimeSeries (named: test_ts1) has no __len__" + ) + for xi, yi in zip(timestamps, 
generator_factory()): + assert np.allclose(xi, yi) + + def test_dci_timestamps_arr(self): + def generator_factory(): + return np.array(np.arange(100)) + + timestamps = DataChunkIterator(data=generator_factory()) + ts = self._create_time_series_with_timestamps(timestamps) + with self.assertWarns(UserWarning) as record: + self.assertIsNone(ts.num_samples) + assert len(record.warnings) == 2 + assert record.warnings[0].message.args[0] == ( + "The data attribute on this TimeSeries (named: test_ts1) has no __len__" + ) + assert record.warnings[1].message.args[0] == ( + "The timestamps attribute on this TimeSeries (named: test_ts1) has no __len__" + ) + for xi, yi in zip(timestamps, generator_factory()): + assert np.allclose(xi, yi) def test_conflicting_time_args(self): - with self.assertRaises(ValueError): - TimeSeries('test_ts2', [10, 11, 12, 13, 14, 15], 'grams', rate=30., - timestamps=[.3, .4, .5, .6, .7, .8]) - with self.assertRaises(ValueError): - TimeSeries('test_ts2', [10, 11, 12, 13, 14, 15], 'grams', - starting_time=30., timestamps=[.3, .4, .5, .6, .7, .8]) + with self.assertRaisesWith( + ValueError, "Specifying rate and timestamps is not supported." + ): + TimeSeries( + name="test_ts2", + data=[10, 11, 12], + unit="grams", + rate=30.0, + timestamps=[0.3, 0.4, 0.5], + ) + with self.assertRaisesWith( + ValueError, "Specifying starting_time and timestamps is not supported." + ): + TimeSeries( + name="test_ts2", + data=[10, 11, 12], + unit="grams", + starting_time=30.0, + timestamps=[0.3, 0.4, 0.5], + ) def test_dimension_warning(self): - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. - warnings.simplefilter("always") + msg = ( + "Length of data does not match length of timestamps. Your data may be " + "transposed. Time should be on the 0th dimension" + ) + with self.assertWarnsWith(UserWarning, msg): TimeSeries( - name='test_ts2', + name="test_ts2", data=[10, 11, 12], - unit='grams', - timestamps=[.3, .4, .5, .6, .7, .8], + unit="grams", + timestamps=[0.3, 0.4, 0.5, 0.6, 0.7, 0.8], ) - assert len(w) == 1 - assert ( - "Length of data does not match length of timestamps. Your data may be " - "transposed. 
Time should be on the 0th dimension" - ) in str(w[-1].message) class TestImage(TestCase): - - def test_image(self): - Image(name='test_image', data=np.ones((10, 10))) + def test_init(self): + im = Image(name="test_image", data=np.ones((10, 10))) + assert im.name == "test_image" + assert np.all(im.data == np.ones((10, 10))) class TestImages(TestCase): def test_images(self): - image1 = Image(name='test_image', data=np.ones((10, 10))) + image1 = Image(name='test_image1', data=np.ones((10, 10))) image2 = Image(name='test_image2', data=np.ones((10, 10))) image_references = ImageReferences(name='order_of_images', data=[image2, image1]) images = Images(name='images_name', images=[image1, image2], order_of_images=image_references) + assert images.name == "images_name" + assert images.images == dict(test_image1=image1, test_image2=image2) self.assertIs(images.order_of_images[0], image2) self.assertIs(images.order_of_images[1], image1) class TestTimeSeriesReferenceVectorData(TestCase): + def _create_time_series_with_rate(self): + ts = TimeSeries( + name="test", + description="test", + data=np.arange(10), + unit="unit", + starting_time=5.0, + rate=0.1, + ) + return ts + + def _create_time_series_with_timestamps(self): + ts = TimeSeries( + name="test", + description="test", + data=np.arange(10), + unit="unit", + timestamps=np.arange(10.0), + ) + return ts def test_init(self): temp = TimeSeriesReferenceVectorData() - self.assertEqual(temp.name, 'timeseries') - self.assertEqual(temp.description, - "Column storing references to a TimeSeries (rows). For each TimeSeries this " - "VectorData column stores the start_index and count to indicate the range in time " - "to be selected as well as an object reference to the TimeSeries.") + self.assertEqual(temp.name, "timeseries") + self.assertEqual( + temp.description, + "Column storing references to a TimeSeries (rows). 
For each TimeSeries this " + "VectorData column stores the start_index and count to indicate the range in time " + "to be selected as well as an object reference to the TimeSeries.", + ) self.assertListEqual(temp.data, []) - temp = TimeSeriesReferenceVectorData(name='test', description='test') - self.assertEqual(temp.name, 'test') - self.assertEqual(temp.description, 'test') + temp = TimeSeriesReferenceVectorData(name="test", description="test") + self.assertEqual(temp.name, "test") + self.assertEqual(temp.description, "test") def test_get_empty(self): """Get data from an empty TimeSeriesReferenceVectorData""" @@ -277,41 +451,73 @@ def test_get_empty(self): with self.assertRaises(IndexError): temp[0] - def test_get_length1_valid_data(self): + def test_append_get_length1_valid_data(self): """Get data from a TimeSeriesReferenceVectorData with one element and valid data""" temp = TimeSeriesReferenceVectorData() - value = TimeSeriesReference(0, 5, TimeSeries(name='test', description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) + value = TimeSeriesReference(0, 5, self._create_time_series_with_rate()) temp.append(value) self.assertTupleEqual(temp[0], value) - self.assertListEqual(temp[:], [TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*value), ]) + self.assertListEqual( + temp[:], + [ + TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*value), + ], + ) + + def test_add_row_get_length1_valid_data(self): + """Get data from a TimeSeriesReferenceVectorData with one element and valid data""" + temp = TimeSeriesReferenceVectorData() + value = TimeSeriesReference(0, 5, self._create_time_series_with_rate()) + temp.add_row(value) + self.assertTupleEqual(temp[0], value) + self.assertListEqual( + temp[:], + [ + TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*value), + ], + ) def test_get_length1_invalid_data(self): """Get data from a TimeSeriesReferenceVectorData with one element and invalid data""" temp = TimeSeriesReferenceVectorData() - value = TimeSeriesReference(-1, -1, TimeSeries(name='test', description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) + value = TimeSeriesReference(-1, -1, self._create_time_series_with_rate()) temp.append(value) # test index slicing re = temp[0] - self.assertTrue(isinstance(re, TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE)) - self.assertTupleEqual(re, TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE) + self.assertTrue( + isinstance(re, TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE) + ) + self.assertTupleEqual( + re, TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE + ) # test array slicing and list slicing - selection = [slice(None), [0, ]] + selection = [ + slice(None), + [ + 0, + ], + ] for s in selection: re = temp[s] self.assertTrue(isinstance(re, list)) self.assertTrue(len(re), 1) - self.assertTrue(isinstance(re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE)) - self.assertTupleEqual(re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE) + self.assertTrue( + isinstance( + re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE + ) + ) + self.assertTupleEqual( + re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE + ) def test_get_length5_valid_data(self): """Get data from a TimeSeriesReferenceVectorData with 5 elements""" temp = TimeSeriesReferenceVectorData() num_values = 5 - values = [TimeSeriesReference(0, 5, TimeSeries(name='test'+str(i), description='test', - 
data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - for i in range(num_values)] + values = [ + TimeSeriesReference(0, 5, self._create_time_series_with_rate()) + for i in range(num_values) + ] for v in values: temp.append(v) # Test single element selection @@ -320,26 +526,37 @@ def test_get_length5_valid_data(self): re = temp[i] self.assertTupleEqual(re, values[i]) # test slicing - re = temp[i:i+1] - self.assertTupleEqual(re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[i])) + re = temp[i : i + 1] + self.assertTupleEqual( + re[0], + TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[i]), + ) # Test multi element selection re = temp[0:2] - self.assertTupleEqual(re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[0])) - self.assertTupleEqual(re[1], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[1])) + self.assertTupleEqual( + re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[0]) + ) + self.assertTupleEqual( + re[1], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[1]) + ) def test_get_length5_with_invalid_data(self): """Get data from a TimeSeriesReferenceVectorData with 5 elements""" temp = TimeSeriesReferenceVectorData() num_values = 5 - values = [TimeSeriesReference(0, 5, TimeSeries(name='test'+str(i+1), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - for i in range(num_values-2)] - values = ([TimeSeriesReference(-1, -1, TimeSeries(name='test'+str(0), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, - rate=0.1)), ] - + values - + [TimeSeriesReference(-1, -1, TimeSeries(name='test'+str(5), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)), ]) + values = [ + TimeSeriesReference(0, 5, self._create_time_series_with_rate()) + for i in range(num_values - 2) + ] + values = ( + [ + TimeSeriesReference(-1, -1, self._create_time_series_with_rate()), + ] + + values + + [ + TimeSeriesReference(-1, -1, self._create_time_series_with_rate()), + ] + ) for v in values: temp.append(v) # Test single element selection @@ -347,21 +564,42 @@ def test_get_length5_with_invalid_data(self): # test index slicing re = temp[i] if i in [0, 4]: - self.assertTrue(isinstance(re, TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE)) - self.assertTupleEqual(re, TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE) + self.assertTrue( + isinstance( + re, TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE + ) + ) + self.assertTupleEqual( + re, TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE + ) else: self.assertTupleEqual(re, values[i]) # test slicing - re = temp[i:i+1] + re = temp[i : i + 1] if i in [0, 4]: - self.assertTrue(isinstance(re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE)) - self.assertTupleEqual(re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE) + self.assertTrue( + isinstance( + re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE + ) + ) + self.assertTupleEqual( + re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE + ) else: - self.assertTupleEqual(re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[i])) + self.assertTupleEqual( + re[0], + TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE( + *values[i] + ), + ) # Test multi element selection re = temp[0:2] - self.assertTupleEqual(re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE) - 
self.assertTupleEqual(re[1], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[1])) + self.assertTupleEqual( + re[0], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_NONE_TYPE + ) + self.assertTupleEqual( + re[1], TimeSeriesReferenceVectorData.TIME_SERIES_REFERENCE_TUPLE(*values[1]) + ) def test_add_row(self): v = TimeSeriesReferenceVectorData(name='a', description='a') @@ -385,9 +623,12 @@ def test_add_row_with_bad_tuple(self): v.add_row(val) def test_add_row_restricted_type(self): - v = TimeSeriesReferenceVectorData(name='a', description='a') - with self.assertRaisesWith(TypeError, "TimeSeriesReferenceVectorData.add_row: incorrect type for " - "'val' (got 'int', expected 'TimeSeriesReference or tuple')"): + v = TimeSeriesReferenceVectorData(name="a", description="a") + with self.assertRaisesWith( + TypeError, + "TimeSeriesReferenceVectorData.add_row: incorrect type for " + "'val' (got 'int', expected 'TimeSeriesReference or tuple')", + ): v.add_row(1) def test_append(self): @@ -412,111 +653,149 @@ def test_append_with_bad_tuple(self): v.append(val) def test_append_restricted_type(self): - v = TimeSeriesReferenceVectorData(name='a', description='a') - with self.assertRaisesWith(TypeError, "TimeSeriesReferenceVectorData.append: incorrect type for " - "'arg' (got 'float', expected 'TimeSeriesReference or tuple')"): + v = TimeSeriesReferenceVectorData(name="a", description="a") + with self.assertRaisesWith( + TypeError, + "TimeSeriesReferenceVectorData.append: incorrect type for " + "'arg' (got 'float', expected 'TimeSeriesReference or tuple')", + ): v.append(2.0) class TestTimeSeriesReference(TestCase): + def _create_time_series_with_rate(self): + ts = TimeSeries( + name="test", + description="test", + data=np.arange(10), + unit="unit", + starting_time=5.0, + rate=0.1, + ) + return ts + + def _create_time_series_with_timestamps(self): + ts = TimeSeries( + name="test", + description="test", + data=np.arange(10), + unit="unit", + timestamps=np.arange(10.0), + ) + return ts def test_check_types(self): # invalid selection but with correct types - tsr = TimeSeriesReference(-1, -1, TimeSeries(name='test'+str(0), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) + tsr = TimeSeriesReference(-1, -1, self._create_time_series_with_rate()) self.assertTrue(tsr.check_types()) # invalid types, use float instead of int for both idx_start and count - tsr = TimeSeriesReference(1.0, 5.0, TimeSeries(name='test'+str(0), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(TypeError, "idx_start must be an integer not "): + tsr = TimeSeriesReference(1.0, 5.0, self._create_time_series_with_rate()) + with self.assertRaisesWith( + TypeError, "idx_start must be an integer not " + ): tsr.check_types() # invalid types, use float instead of int for idx_start only - tsr = TimeSeriesReference(1.0, 5, TimeSeries(name='test'+str(0), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(TypeError, "idx_start must be an integer not "): + tsr = TimeSeriesReference(1.0, 5, self._create_time_series_with_rate()) + with self.assertRaisesWith( + TypeError, "idx_start must be an integer not " + ): tsr.check_types() # invalid types, use float instead of int for count only - tsr = TimeSeriesReference(1, 5.0, TimeSeries(name='test'+str(0), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(TypeError, 
"count must be an integer "): + tsr = TimeSeriesReference(1, 5.0, self._create_time_series_with_rate()) + with self.assertRaisesWith( + TypeError, "count must be an integer " + ): tsr.check_types() # invalid type for TimeSeries but valid idx_start and count tsr = TimeSeriesReference(1, 5, None) - with self.assertRaisesWith(TypeError, "timeseries must be of type TimeSeries. "): + with self.assertRaisesWith( + TypeError, "timeseries must be of type TimeSeries. " + ): tsr.check_types() def test_is_invalid(self): - tsr = TimeSeriesReference(-1, -1, TimeSeries(name='test'+str(0), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) + tsr = TimeSeriesReference(-1, -1, self._create_time_series_with_rate()) self.assertFalse(tsr.isvalid()) def test_is_valid(self): - tsr = TimeSeriesReference(0, 10, TimeSeries(name='test'+str(0), description='test', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) + tsr = TimeSeriesReference(0, 10, self._create_time_series_with_rate()) self.assertTrue(tsr.isvalid()) def test_is_valid_bad_index(self): # Error: negative start_index but positive count - tsr = TimeSeriesReference(-1, 10, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(IndexError, "'idx_start' -1 out of range for timeseries 'test0'"): + tsr = TimeSeriesReference(-1, 10, self._create_time_series_with_rate()) + with self.assertRaisesWith( + IndexError, "'idx_start' -1 out of range for timeseries 'test'" + ): tsr.isvalid() # Error: start_index too large - tsr = TimeSeriesReference(10, 0, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(IndexError, "'idx_start' 10 out of range for timeseries 'test0'"): + tsr = TimeSeriesReference(10, 0, self._create_time_series_with_rate()) + with self.assertRaisesWith( + IndexError, "'idx_start' 10 out of range for timeseries 'test'" + ): tsr.isvalid() # Error: positive start_index but negative count - tsr = TimeSeriesReference(0, -3, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(IndexError, "'count' -3 invalid. 'count' must be positive"): + tsr = TimeSeriesReference(0, -3, self._create_time_series_with_rate()) + with self.assertRaisesWith( + IndexError, "'count' -3 invalid. 
'count' must be positive" + ): tsr.isvalid() # Error: start_index + count too large - tsr = TimeSeriesReference(3, 10, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(IndexError, "'idx_start + count' out of range for timeseries 'test0'"): + tsr = TimeSeriesReference(3, 10, self._create_time_series_with_rate()) + with self.assertRaisesWith( + IndexError, "'idx_start + count' out of range for timeseries 'test'" + ): tsr.isvalid() + def test_is_valid_no_num_samples(self): + def generator_factory(): + return (i for i in range(100)) + + data = DataChunkIterator(data=generator_factory()) + ts = TimeSeries(name="test_ts1", data=data, unit="grams", rate=0.1) + tsr = TimeSeriesReference(0, 10, ts) + with self.assertWarnsWith( + UserWarning, + "The data attribute on this TimeSeries (named: test_ts1) has no __len__", + ): + self.assertTrue(tsr.isvalid()) + def test_timestamps_property(self): # Timestamps from starting_time and rate - tsr = TimeSeriesReference(5, 4, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) + tsr = TimeSeriesReference(5, 4, self._create_time_series_with_rate()) np.testing.assert_array_equal(tsr.timestamps, np.array([5.5, 5.6, 5.7, 5.8])) # Timestamps from timestamps directly - tsr = TimeSeriesReference(5, 4, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', - timestamps=np.arange(10).astype(float))) - np.testing.assert_array_equal(tsr.timestamps, np.array([5., 6., 7., 8.])) + tsr = TimeSeriesReference(5, 4, self._create_time_series_with_timestamps()) + np.testing.assert_array_equal(tsr.timestamps, np.array([5.0, 6.0, 7.0, 8.0])) def test_timestamps_property_invalid_reference(self): # Timestamps from starting_time and rate - tsr = TimeSeriesReference(-1, -1, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) + tsr = TimeSeriesReference(-1, -1, self._create_time_series_with_rate()) self.assertIsNone(tsr.timestamps) def test_timestamps_property_bad_reference(self): - tsr = TimeSeriesReference(0, 12, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', - timestamps=np.arange(10).astype(float))) - with self.assertRaisesWith(IndexError, "'idx_start + count' out of range for timeseries 'test0'"): + tsr = TimeSeriesReference(0, 12, self._create_time_series_with_timestamps()) + with self.assertRaisesWith( + IndexError, "'idx_start + count' out of range for timeseries 'test'" + ): tsr.timestamps - tsr = TimeSeriesReference(0, 12, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(IndexError, "'idx_start + count' out of range for timeseries 'test0'"): + tsr = TimeSeriesReference(0, 12, self._create_time_series_with_rate()) + with self.assertRaisesWith( + IndexError, "'idx_start + count' out of range for timeseries 'test'" + ): tsr.timestamps def test_data_property(self): - tsr = TimeSeriesReference(5, 4, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - np.testing.assert_array_equal(tsr.data, np.array([5., 6., 7., 8.])) + tsr = TimeSeriesReference(5, 4, self._create_time_series_with_rate()) + np.testing.assert_array_equal(tsr.data, np.array([5.0, 6.0, 7.0, 8.0])) def test_data_property_invalid_reference(self): - tsr = TimeSeriesReference(-1, -1, TimeSeries(name='test0', 
description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) + tsr = TimeSeriesReference(-1, -1, self._create_time_series_with_rate()) self.assertIsNone(tsr.data) def test_data_property_bad_reference(self): - tsr = TimeSeriesReference(0, 12, TimeSeries(name='test0', description='test0', - data=np.arange(10), unit='unit', starting_time=5.0, rate=0.1)) - with self.assertRaisesWith(IndexError, "'idx_start + count' out of range for timeseries 'test0'"): + tsr = TimeSeriesReference(0, 12, self._create_time_series_with_rate()) + with self.assertRaisesWith( + IndexError, "'idx_start + count' out of range for timeseries 'test'" + ): tsr.data diff --git a/tests/unit/test_behavior.py b/tests/unit/test_behavior.py index 664365070..0b7173da0 100644 --- a/tests/unit/test_behavior.py +++ b/tests/unit/test_behavior.py @@ -2,34 +2,49 @@ from pynwb import TimeSeries from pynwb.misc import IntervalSeries -from pynwb.behavior import SpatialSeries, BehavioralEpochs, BehavioralEvents, BehavioralTimeSeries, PupilTracking, \ - EyeTracking, CompassDirection, Position +from pynwb.behavior import (SpatialSeries, BehavioralEpochs, BehavioralEvents, BehavioralTimeSeries, PupilTracking, + EyeTracking, CompassDirection, Position) from pynwb.testing import TestCase class SpatialSeriesConstructor(TestCase): def test_init(self): - sS = SpatialSeries('test_sS', np.ones((2, 2)), 'reference_frame', timestamps=[1., 2., 3.]) + sS = SpatialSeries( + name='test_sS', + data=np.ones((3, 2)), + reference_frame='reference_frame', + timestamps=[1., 2., 3.] + ) self.assertEqual(sS.name, 'test_sS') self.assertEqual(sS.unit, 'meters') self.assertEqual(sS.reference_frame, 'reference_frame') def test_set_unit(self): - sS = SpatialSeries('test_sS', np.ones((2, 2)), 'reference_frame', 'degrees', - timestamps=[1., 2., 3.]) + sS = SpatialSeries( + name='test_sS', + data=np.ones((3, 2)), + reference_frame='reference_frame', + unit='degrees', + timestamps=[1., 2., 3.] + ) self.assertEqual(sS.unit, 'degrees') def test_gt_3_cols(self): msg = ("SpatialSeries 'test_sS' has data shape (5, 4) which is not compliant with NWB 2.5 and greater. " "The second dimension should have length <= 3 to represent at most x, y, z.") with self.assertWarnsWith(UserWarning, msg): - SpatialSeries("test_sS", np.ones((5, 4)), "reference_frame", "meters", rate=30.) + SpatialSeries( + name="test_sS", + data=np.ones((5, 4)), + reference_frame="reference_frame", + rate=30. 
+ ) class BehavioralEpochsConstructor(TestCase): def test_init(self): - data = [0, 1, 0, 1] - iS = IntervalSeries('test_iS', data, timestamps=[1., 2., 3.]) + data = [0, 1, 0] + iS = IntervalSeries(name='test_iS', data=data, timestamps=[1., 2., 3.]) bE = BehavioralEpochs(iS) self.assertEqual(bE.interval_series['test_iS'], iS) @@ -37,7 +52,7 @@ def test_init(self): class BehavioralEventsConstructor(TestCase): def test_init(self): - ts = TimeSeries('test_ts', np.ones((2, 2)), 'unit', timestamps=[1., 2., 3.]) + ts = TimeSeries(name='test_ts', data=np.ones((3, 2)), unit='unit', timestamps=[1., 2., 3.]) bE = BehavioralEvents(ts) self.assertEqual(bE.time_series['test_ts'], ts) @@ -45,7 +60,7 @@ def test_init(self): class BehavioralTimeSeriesConstructor(TestCase): def test_init(self): - ts = TimeSeries('test_ts', np.ones((2, 2)), 'unit', timestamps=[1., 2., 3.]) + ts = TimeSeries(name='test_ts', data=np.ones((3, 2)), unit='unit', timestamps=[1., 2., 3.]) bts = BehavioralTimeSeries(ts) self.assertEqual(bts.time_series['test_ts'], ts) @@ -53,7 +68,7 @@ def test_init(self): class PupilTrackingConstructor(TestCase): def test_init(self): - ts = TimeSeries('test_ts', np.ones((2, 2)), 'unit', timestamps=[1., 2., 3.]) + ts = TimeSeries(name='test_ts', data=np.ones((3, 2)), unit='unit', timestamps=[1., 2., 3.]) pt = PupilTracking(ts) self.assertEqual(pt.time_series['test_ts'], ts) @@ -61,7 +76,12 @@ def test_init(self): class EyeTrackingConstructor(TestCase): def test_init(self): - sS = SpatialSeries('test_sS', np.ones((2, 2)), 'reference_frame', timestamps=[1., 2., 3.]) + sS = SpatialSeries( + name='test_sS', + data=np.ones((3, 2)), + reference_frame='reference_frame', + timestamps=[1., 2., 3.] + ) et = EyeTracking(sS) self.assertEqual(et.spatial_series['test_sS'], sS) @@ -69,7 +89,12 @@ def test_init(self): class CompassDirectionConstructor(TestCase): def test_init(self): - sS = SpatialSeries('test_sS', np.ones((2, 2)), 'reference_frame', timestamps=[1., 2., 3.]) + sS = SpatialSeries( + name='test_sS', + data=np.ones((3, 2)), + reference_frame='reference_frame', + timestamps=[1., 2., 3.] + ) cd = CompassDirection(sS) self.assertEqual(cd.spatial_series['test_sS'], sS) @@ -77,7 +102,12 @@ def test_init(self): class PositionConstructor(TestCase): def test_init(self): - sS = SpatialSeries('test_sS', np.ones((2, 2)), 'reference_frame', timestamps=[1., 2., 3.]) + sS = SpatialSeries( + name='test_sS', + data=np.ones((3, 2)), + reference_frame='reference_frame', + timestamps=[1., 2., 3.] + ) pc = Position(sS) self.assertEqual(pc.spatial_series.get('test_sS'), sS) diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index acdb317a4..d909a5be1 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -1,7 +1,7 @@ from datetime import datetime from dateutil.tz import tzlocal -from hdmf.utils import docval, call_docval_func +from hdmf.utils import docval from pynwb import NWBFile, TimeSeries, available_namespaces from pynwb.core import NWBContainer @@ -14,7 +14,7 @@ class MyTestClass(NWBContainer): @docval({'name': 'name', 'type': str, 'doc': 'The name of this container'}) def __init__(self, **kwargs): - call_docval_func(super(MyTestClass, self).__init__, kwargs) + super().__init__(**kwargs) self.prop1 = 'test1' @@ -43,26 +43,26 @@ def test_print_file(self): identifier='identifier', session_start_time=datetime.now(tzlocal())) ts1 = TimeSeries( name='name1', - data=[1., 2., 3.] * 1000, + data=[1000, 2000, 3000], unit='unit', - timestamps=[1, 2, 3] + timestamps=[1., 2., 3.] 
) ts2 = TimeSeries( name='name2', - data=[1, 2, 3] * 1000, + data=[1000, 2000, 3000], unit='unit', - timestamps=[1, 2, 3] + timestamps=[1., 2., 3.] ) expected = """name1 pynwb.base.TimeSeries at 0x%d Fields: comments: no comments conversion: 1.0 - data: [1. 2. 3. ... 1. 2. 3.] + data: [1000 2000 3000] description: no description interval: 1 offset: 0.0 resolution: -1.0 - timestamps: [1 2 3] + timestamps: [1. 2. 3.] timestamps_unit: seconds unit: unit """ diff --git a/tests/unit/test_core_NWBContainer.py b/tests/unit/test_core_NWBContainer.py index b4fcd04f1..cbd5b0fe7 100644 --- a/tests/unit/test_core_NWBContainer.py +++ b/tests/unit/test_core_NWBContainer.py @@ -1,7 +1,7 @@ import unittest from pynwb.core import NWBContainer -from hdmf.utils import docval, call_docval_func +from hdmf.utils import docval class MyTestClass(NWBContainer): @@ -10,7 +10,7 @@ class MyTestClass(NWBContainer): @docval({'name': 'name', 'type': str, 'doc': 'The name of this container'}) def __init__(self, **kwargs): - call_docval_func(super(MyTestClass, self).__init__, kwargs) + super().__init__(**kwargs) self.prop1 = 'test1' diff --git a/tests/unit/test_ecephys.py b/tests/unit/test_ecephys.py index 1b934d9be..9ba88d89c 100644 --- a/tests/unit/test_ecephys.py +++ b/tests/unit/test_ecephys.py @@ -28,13 +28,22 @@ def make_electrode_table(): class ElectricalSeriesConstructor(TestCase): + def _create_table_and_region(self): + table = make_electrode_table() + region = DynamicTableRegion( + name='electrodes', + data=[0, 2], + description='the first and third electrodes', + table=table + ) + return table, region + def test_init(self): data = list(range(10)) ts = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] channel_conversion = [2., 6.3] filtering = 'Low-pass filter at 300 Hz' - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() eS = ElectricalSeries( name='test_eS', data=data, @@ -51,8 +60,7 @@ def test_init(self): self.assertEqual(eS.filtering, filtering) def test_link(self): - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() ts1 = ElectricalSeries('test_ts1', [0, 1, 2, 3, 4, 5], region, timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) ts2 = ElectricalSeries('test_ts2', ts1, region, timestamps=ts1) ts3 = ElectricalSeries('test_ts3', ts2, region, timestamps=ts2) @@ -62,15 +70,13 @@ def test_link(self): self.assertEqual(ts3.timestamps, [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) def test_invalid_data_shape(self): - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() with self.assertRaisesWith(ValueError, ("ElectricalSeries.__init__: incorrect shape for 'data' (got '(2, 2, 2, " "2)', expected '((None,), (None, None), (None, None, None))')")): ElectricalSeries('test_ts1', np.ones((2, 2, 2, 2)), region, timestamps=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]) def test_dimensions_warning(self): - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. 
warnings.simplefilter("always") @@ -103,23 +109,31 @@ def test_dimensions_warning(self): class SpikeEventSeriesConstructor(TestCase): - def test_init(self): + def _create_table_and_region(self): table = make_electrode_table() - region = DynamicTableRegion('electrodes', [1, 3], 'the second and fourth electrodes', table) - data = ((1, 1, 1), (2, 2, 2)) - timestamps = np.arange(2) - sES = SpikeEventSeries('test_sES', data, timestamps, region) + region = DynamicTableRegion( + name='electrodes', + data=[1, 3], + description='the second and fourth electrodes', + table=table + ) + return table, region + + def test_init(self): + table, region = self._create_table_and_region() + data = ((1, 1), (2, 2), (3, 3)) + timestamps = np.arange(3) + sES = SpikeEventSeries(name='test_sES', data=data, timestamps=timestamps, electrodes=region) self.assertEqual(sES.name, 'test_sES') # self.assertListEqual(sES.data, data) np.testing.assert_array_equal(sES.data, data) np.testing.assert_array_equal(sES.timestamps, timestamps) def test_no_rate(self): - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [1, 3], 'the second and fourth electrodes', table) + table, region = self._create_table_and_region() data = ((1, 1, 1), (2, 2, 2)) with self.assertRaises(TypeError): - SpikeEventSeries('test_sES', data, region, rate=1.) + SpikeEventSeries(name='test_sES', data=data, electrodes=region, rate=1.) class ElectrodeGroupConstructor(TestCase): @@ -150,11 +164,20 @@ def test_init_position_bad(self): class EventDetectionConstructor(TestCase): + def _create_table_and_region(self): + table = make_electrode_table() + region = DynamicTableRegion( + name='electrodes', + data=[0, 2], + description='the first and third electrodes', + table=table + ) + return table, region + def test_init(self): data = list(range(10)) ts = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() eS = ElectricalSeries('test_eS', data, region, timestamps=ts) eD = EventDetection('detection_method', eS, (1, 2, 3), (0.1, 0.2, 0.3)) self.assertEqual(eD.detection_method, 'detection_method') @@ -166,9 +189,18 @@ def test_init(self): class EventWaveformConstructor(TestCase): - def test_init(self): + def _create_table_and_region(self): table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + region = DynamicTableRegion( + name='electrodes', + data=[0, 2], + description='the first and third electrodes', + table=table + ) + return table, region + + def test_init(self): + table, region = self._create_table_and_region() sES = SpikeEventSeries('test_sES', list(range(10)), list(range(10)), region) ew = EventWaveform(sES) @@ -184,7 +216,7 @@ def test_init(self): peak_over_rms = [5.3, 6.3] with self.assertWarnsWith(DeprecationWarning, 'use pynwb.misc.Units or NWBFile.units instead'): - cc = Clustering('description', num, peak_over_rms, times) + cc = Clustering(description='description', num=num, peak_over_rms=peak_over_rms, times=times) self.assertEqual(cc.description, 'description') self.assertEqual(cc.num, num) self.assertEqual(cc.peak_over_rms, peak_over_rms) @@ -198,13 +230,18 @@ def test_init(self): num = [3, 4] peak_over_rms = [5.3, 6.3] with self.assertWarnsWith(DeprecationWarning, 'use pynwb.misc.Units or NWBFile.units instead'): - cc = Clustering('description', num, peak_over_rms, times) + cc = 
Clustering(description='description', num=num, peak_over_rms=peak_over_rms, times=times) means = [[7.3, 7.3]] stdevs = [[8.3, 8.3]] with self.assertWarnsWith(DeprecationWarning, 'use pynwb.misc.Units or NWBFile.units instead'): - cw = ClusterWaveforms(cc, 'filtering', means, stdevs) + cw = ClusterWaveforms( + clustering_interface=cc, + waveform_filtering='filtering', + waveform_mean=means, + waveform_sd=stdevs + ) self.assertEqual(cw.clustering_interface, cc) self.assertEqual(cw.waveform_filtering, 'filtering') self.assertEqual(cw.waveform_mean, means) @@ -213,10 +250,19 @@ def test_init(self): class LFPTest(TestCase): + def _create_table_and_region(self): + table = make_electrode_table() + region = DynamicTableRegion( + name='electrodes', + data=[0, 2], + description='the first and third electrodes', + table=table + ) + return table, region + def test_add_electrical_series(self): lfp = LFP() - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() eS = ElectricalSeries('test_eS', [0, 1, 2, 3], region, timestamps=[0.1, 0.2, 0.3, 0.4]) lfp.add_electrical_series(eS) self.assertEqual(lfp.electrical_series.get('test_eS'), eS) @@ -224,9 +270,18 @@ def test_add_electrical_series(self): class FilteredEphysTest(TestCase): - def test_init(self): + def _create_table_and_region(self): table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + region = DynamicTableRegion( + name='electrodes', + data=[0, 2], + description='the first and third electrodes', + table=table + ) + return table, region + + def test_init(self): + table, region = self._create_table_and_region() eS = ElectricalSeries('test_eS', [0, 1, 2, 3], region, timestamps=[0.1, 0.2, 0.3, 0.4]) fe = FilteredEphys(eS) self.assertEqual(fe.electrical_series.get('test_eS'), eS) @@ -234,8 +289,7 @@ def test_init(self): def test_add_electrical_series(self): fe = FilteredEphys() - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() eS = ElectricalSeries('test_eS', [0, 1, 2, 3], region, timestamps=[0.1, 0.2, 0.3, 0.4]) fe.add_electrical_series(eS) self.assertEqual(fe.electrical_series.get('test_eS'), eS) @@ -244,10 +298,19 @@ def test_add_electrical_series(self): class FeatureExtractionConstructor(TestCase): + def _create_table_and_region(self): + table = make_electrode_table() + region = DynamicTableRegion( + name='electrodes', + data=[0, 2], + description='the first and third electrodes', + table=table + ) + return table, region + def test_init(self): event_times = [1.9, 3.5] - table = make_electrode_table() - region = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() description = ['desc1', 'desc2', 'desc3'] features = [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]] fe = FeatureExtraction(region, description, event_times, features) @@ -257,32 +320,29 @@ def test_init(self): def test_invalid_init_mismatched_event_times(self): event_times = [] # Need 1 event time but give 0 - table = make_electrode_table() - electrodes = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() description = ['desc1', 'desc2', 'desc3'] features = [[[0, 1, 2], [3, 4, 5]]] - self.assertRaises(ValueError, 
FeatureExtraction, electrodes, description, event_times, features) + self.assertRaises(ValueError, FeatureExtraction, region, description, event_times, features) def test_invalid_init_mismatched_electrodes(self): event_times = [1] table = make_electrode_table() - electrodes = DynamicTableRegion('electrodes', [0], 'the first electrodes', table) + region = DynamicTableRegion(name='electrodes', data=[0], description='the first electrode', table=table) description = ['desc1', 'desc2', 'desc3'] features = [[[0, 1, 2], [3, 4, 5]]] - self.assertRaises(ValueError, FeatureExtraction, electrodes, description, event_times, features) + self.assertRaises(ValueError, FeatureExtraction, region, description, event_times, features) def test_invalid_init_mismatched_description(self): event_times = [1] - table = make_electrode_table() - electrodes = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() description = ['desc1', 'desc2', 'desc3', 'desc4'] # Need 3 descriptions but give 4 features = [[[0, 1, 2], [3, 4, 5]]] - self.assertRaises(ValueError, FeatureExtraction, electrodes, description, event_times, features) + self.assertRaises(ValueError, FeatureExtraction, region, description, event_times, features) def test_invalid_init_mismatched_description2(self): event_times = [1] - table = make_electrode_table() - electrodes = DynamicTableRegion('electrodes', [0, 2], 'the first and third electrodes', table) + table, region = self._create_table_and_region() description = ['desc1', 'desc2', 'desc3'] features = [[0, 1, 2], [3, 4, 5]] # Need 3D feature array but give only 2D array - self.assertRaises(ValueError, FeatureExtraction, electrodes, description, event_times, features) + self.assertRaises(ValueError, FeatureExtraction, region, description, event_times, features) diff --git a/tests/unit/test_extension.py b/tests/unit/test_extension.py index abfe96bf5..7664bbf22 100644 --- a/tests/unit/test_extension.py +++ b/tests/unit/test_extension.py @@ -105,7 +105,7 @@ class MyTestMetaData(LabMetaData): {'name': 'test_attr', 'type': float, 'doc': 'test attribute'}) def __init__(self, **kwargs): test_attr = popargs('test_attr', kwargs) - super(MyTestMetaData, self).__init__(**kwargs) + super().__init__(**kwargs) self.test_attr = test_attr nwbfile = NWBFile("a file with header data", "NB123A", datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())) diff --git a/tests/unit/test_file.py b/tests/unit/test_file.py index c68ff5589..0716924ac 100644 --- a/tests/unit/test_file.py +++ b/tests/unit/test_file.py @@ -19,9 +19,9 @@ def setUp(self): datetime(2017, 5, 2, 13, 0, 0, 1, tzinfo=tzutc()), datetime(2017, 5, 2, 14, tzinfo=tzutc())] self.path = 'nwbfile_test.h5' - self.nwbfile = NWBFile('a test session description for a test NWBFile', - 'FILE123', - self.start, + self.nwbfile = NWBFile(session_description='a test session description for a test NWBFile', + identifier='FILE123', + session_start_time=self.start, file_create_date=self.create, timestamps_reference_time=self.ref_time, experimenter='A test experimenter', diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index 8af641498..2497a8887 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -18,7 +18,7 @@ def test_init(self): external_file=['external_file'], starting_frame=[1, 2, 3], format='tiff', - timestamps=list(), + timestamps=[1., 2., 3.], device=dev, ) self.assertEqual(iS.name, 'test_iS') @@ -35,7 +35,7 @@ def test_no_data_no_file(self): ImageSeries( 
name='test_iS', unit='unit', - timestamps=list() + rate=2., ) def test_external_file_no_frame(self): @@ -43,7 +43,7 @@ def test_external_file_no_frame(self): name='test_iS', unit='unit', external_file=['external_file'], - timestamps=list() + timestamps=[1., 2., 3.] ) self.assertListEqual(iS.starting_frame, [0]) @@ -52,7 +52,7 @@ def test_data_no_frame(self): name='test_iS', unit='unit', data=np.ones((3, 3, 3)), - timestamps=list() + timestamps=[1., 2., 3.] ) self.assertIsNone(iS.starting_frame) diff --git a/tests/unit/test_misc.py b/tests/unit/test_misc.py index c8170d9fe..877849de5 100644 --- a/tests/unit/test_misc.py +++ b/tests/unit/test_misc.py @@ -11,7 +11,7 @@ class AnnotationSeriesConstructor(TestCase): def test_init(self): - aS = AnnotationSeries('test_aS', data=[1, 2, 3], timestamps=list()) + aS = AnnotationSeries('test_aS', data=[1, 2, 3], timestamps=[1., 2., 3.]) self.assertEqual(aS.name, 'test_aS') aS.add_annotation(2.0, 'comment') @@ -30,7 +30,7 @@ class DecompositionSeriesConstructor(TestCase): def test_init(self): timeseries = TimeSeries(name='dummy timeseries', description='desc', data=np.ones((3, 3)), unit='Volts', - timestamps=np.ones((3,))) + timestamps=[1., 2., 3.]) bands = DynamicTable(name='bands', description='band info for LFPSpectralAnalysis', columns=[ VectorData(name='band_name', description='name of bands', data=['alpha', 'beta', 'gamma']), VectorData(name='band_limits', description='low and high cutoffs in Hz', data=np.ones((3, 2))) @@ -38,7 +38,7 @@ def test_init(self): spec_anal = DecompositionSeries(name='LFPSpectralAnalysis', description='my description', data=np.ones((3, 3, 3)), - timestamps=np.ones((3,)), + timestamps=[1., 2., 3.], source_timeseries=timeseries, metric='amplitude', bands=bands) @@ -46,7 +46,7 @@ def test_init(self): self.assertEqual(spec_anal.name, 'LFPSpectralAnalysis') self.assertEqual(spec_anal.description, 'my description') np.testing.assert_equal(spec_anal.data, np.ones((3, 3, 3))) - np.testing.assert_equal(spec_anal.timestamps, np.ones((3,))) + np.testing.assert_equal(spec_anal.timestamps, [1., 2., 3.]) self.assertEqual(spec_anal.bands['band_name'].data, ['alpha', 'beta', 'gamma']) np.testing.assert_equal(spec_anal.bands['band_limits'].data, np.ones((3, 2))) self.assertEqual(spec_anal.source_timeseries, timeseries) @@ -59,7 +59,7 @@ def test_init_delayed_bands(self): spec_anal = DecompositionSeries(name='LFPSpectralAnalysis', description='my description', data=np.ones((3, 3, 3)), - timestamps=np.ones((3,)), + timestamps=[1., 2., 3.], source_timeseries=timeseries, metric='amplitude') for band_name in ['alpha', 'beta', 'gamma']: @@ -68,7 +68,7 @@ def test_init_delayed_bands(self): self.assertEqual(spec_anal.name, 'LFPSpectralAnalysis') self.assertEqual(spec_anal.description, 'my description') np.testing.assert_equal(spec_anal.data, np.ones((3, 3, 3))) - np.testing.assert_equal(spec_anal.timestamps, np.ones((3,))) + np.testing.assert_equal(spec_anal.timestamps, [1., 2., 3.]) self.assertEqual(spec_anal.bands['band_name'].data, ['alpha', 'beta', 'gamma']) np.testing.assert_equal(spec_anal.bands['band_limits'].data, np.ones((3, 2))) self.assertEqual(spec_anal.source_timeseries, timeseries) diff --git a/tests/unit/test_ophys.py b/tests/unit/test_ophys.py index 2cca3fab7..a74771023 100644 --- a/tests/unit/test_ophys.py +++ b/tests/unit/test_ophys.py @@ -252,6 +252,7 @@ def test_init(self): ) is2 = ImageSeries( name='is2', + data=np.ones((2, 2, 2)), unit='unit', external_file=['external_file'], starting_frame=[1, 2, 3], @@ -382,7 +383,7 @@ 
def set_up_dependencies(self): external_file=['external_file'], starting_frame=[1, 2, 3], format='tiff', - timestamps=list() + timestamps=[1., 2.] ) ip = create_imaging_plane()
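
The hunks above apply the same refactorings throughout: constructor calls move from positional to keyword arguments, the electrode-table fixture repeated across the ecephys test classes is pulled into per-class `_create_table_and_region` helpers, and degenerate `timestamps=list()` values are replaced with timestamps whose length matches the data. A minimal self-contained sketch of the resulting pattern follows, outside the patch itself; the stand-in `DynamicTable` replaces `make_electrode_table()` from `tests/unit/test_ecephys.py`, and everything other than the `DynamicTableRegion` and `ElectricalSeries` constructor signatures is an illustrative assumption.

# Sketch only, not part of the patch: the keyword-argument style and the
# shared-fixture helper that the updated tests converge on.
from hdmf.common import DynamicTable, DynamicTableRegion, VectorData
from pynwb.ecephys import ElectricalSeries


def _create_table_and_region():
    # Stand-in for make_electrode_table(): any DynamicTable with at least
    # three rows can back a DynamicTableRegion for this illustration.
    table = DynamicTable(
        name='electrodes',
        description='stand-in electrode metadata table',
        columns=[VectorData(name='location', description='brain area',
                            data=['CA1', 'CA1', 'CA2', 'CA2'])],
    )
    region = DynamicTableRegion(
        name='electrodes',
        data=[0, 2],  # rows 0 and 2, i.e. the first and third electrodes
        description='the first and third electrodes',
        table=table,
    )
    # Return the table as well so callers keep a live reference to it,
    # matching the `table, region = self._create_table_and_region()`
    # unpacking used in the hunks above.
    return table, region


table, region = _create_table_and_region()
es = ElectricalSeries(
    name='test_eS',                   # every argument passed by keyword
    data=[0, 1, 2, 3],
    electrodes=region,
    timestamps=[0.1, 0.2, 0.3, 0.4],  # one timestamp per data sample
)

Passing every argument by keyword avoids depending on docval argument order, which is the point of the positional-to-keyword conversion these hunks perform.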