Skip to content

Commit

Permalink
make MotionCorrection test (#1374)
Browse files Browse the repository at this point in the history
* attempt to make MotionCorrection test. I don't understand why it's failing

* Fix CorrectedImageStack nwbfields

* Fix tests

* Fix test

* Fix flake8

* Fix test

Co-authored-by: Ryan Ly <[email protected]>
  • Loading branch information
bendichter and rly authored Jun 21, 2021
1 parent 928e071 commit 51cc8f7
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 27 deletions.
6 changes: 3 additions & 3 deletions src/pynwb/ophys.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ class CorrectedImageStack(NWBDataInterface):
assumed to be 2-D (has only x & y dimensions).
"""

__nwbfields__ = ({'name': 'corrected', 'child': True},
{'name': 'xy_translation', 'child': True},
__nwbfields__ = ({'name': 'corrected', 'child': True, 'required_name': 'corrected'},
{'name': 'xy_translation', 'child': True, 'required_name': 'xy_translation'},
'original')

@docval({'name': 'name', 'type': str,
Expand Down Expand Up @@ -193,7 +193,7 @@ class MotionCorrection(MultiContainerInterface):
'get': 'get_corrected_image_stack',
'create': 'create_corrected_image_stack',
'type': CorrectedImageStack,
'attr': 'corrected_images_stacks'
'attr': 'corrected_image_stacks'
}


Expand Down
105 changes: 85 additions & 20 deletions tests/integration/hdf5/test_ophys.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,46 @@
from copy import deepcopy
from abc import ABCMeta

import numpy as np
from pynwb.ophys import (
ImagingPlane,
OpticalChannel,
PlaneSegmentation,
ImageSegmentation,
TwoPhotonSeries,
RoiResponseSeries
RoiResponseSeries,
MotionCorrection,
CorrectedImageStack,
)
from pynwb.base import TimeSeries
from pynwb.image import ImageSeries
from pynwb.device import Device
from pynwb.testing import NWBH5IOMixin, AcquisitionH5IOMixin, TestCase


def make_imaging_plane():
""" Make an ImagingPlane and related objects """
device = Device(name='dev1')
optical_channel = OpticalChannel(
name='optchan1',
description='a fake OpticalChannel',
emission_lambda=500.
)
imaging_plane = ImagingPlane(
name='imgpln1',
optical_channel=optical_channel,
description='a fake ImagingPlane',
device=device,
excitation_lambda=600.,
imaging_rate=300.,
indicator='GFP',
location='somewhere in the brain',
reference_frame='unknown'
)

return device, optical_channel, imaging_plane


class TestImagingPlaneIO(NWBH5IOMixin, TestCase):

def setUpContainer(self):
Expand Down Expand Up @@ -50,31 +77,69 @@ def getContainer(self, nwbfile):
return nwbfile.get_imaging_plane(self.container.name)


class TestTwoPhotonSeriesIO(AcquisitionH5IOMixin, TestCase):
class TestMotionCorrection(NWBH5IOMixin, TestCase):

def make_imaging_plane(self):
""" Make an ImagingPlane and related objects """
self.device = Device(name='dev1')
self.optical_channel = OpticalChannel(
name='optchan1',
description='a fake OpticalChannel',
emission_lambda=500.
def setUpContainer(self):
""" Return the test ImagingPlane to read/write """

self.device, self.optical_channel, self.imaging_plane = make_imaging_plane()

self.two_photon_series = TwoPhotonSeries(
name='TwoPhotonSeries',
data=np.ones((1000, 100, 100)),
imaging_plane=self.imaging_plane,
rate=1.0,
unit='normalized amplitude'
)
self.imaging_plane = ImagingPlane(
name='imgpln1',
optical_channel=self.optical_channel,
description='a fake ImagingPlane',
device=self.device,
excitation_lambda=600.,
imaging_rate=300.,
indicator='GFP',
location='somewhere in the brain',
reference_frame='unknown'

corrected = ImageSeries(
name='corrected',
data=np.ones((1000, 100, 100)),
unit='na',
format='raw',
starting_time=0.0,
rate=1.0
)

xy_translation = TimeSeries(
name='xy_translation',
data=np.ones((1000, 2)),
unit='pixels',
starting_time=0.0,
rate=1.0,
)

corrected_image_stack = CorrectedImageStack(
corrected=corrected,
original=self.two_photon_series,
xy_translation=xy_translation,
)

return MotionCorrection(corrected_image_stacks=[corrected_image_stack])

def addContainer(self, nwbfile):
""" Add the test ImagingPlane and Device to the given NWBFile """

nwbfile.add_device(self.device)
nwbfile.add_imaging_plane(self.imaging_plane)
nwbfile.add_acquisition(self.two_photon_series)

ophys_module = nwbfile.create_processing_module(
name='ophys',
description='optical physiology processed data'
)
ophys_module.add(self.container)

def getContainer(self, nwbfile):
""" Return the test ImagingPlane from the given NWBFile """
return nwbfile.processing['ophys'].data_interfaces['MotionCorrection']


class TestTwoPhotonSeriesIO(AcquisitionH5IOMixin, TestCase):

def setUpContainer(self):
""" Return the test TwoPhotonSeries to read/write """
self.make_imaging_plane()
self.device, self.optical_channel, self.imaging_plane = make_imaging_plane()
data = [[[1., 1.] * 2] * 2]
timestamps = list(map(lambda x: x/10, range(10)))
fov = [2.0, 2.0, 5.0]
Expand Down
41 changes: 37 additions & 4 deletions tests/unit/test_ophys.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,47 @@ def test_init(self):

class MotionCorrectionConstructor(TestCase):
def test_init(self):
MotionCorrection(list())
corrected = ImageSeries(
name='corrected',
data=np.ones((1000, 100, 100)),
unit='na',
format='raw',
starting_time=0.0,
rate=1.0
)

xy_translation = TimeSeries(
name='xy_translation',
data=np.ones((1000, 2)),
unit='pixels',
starting_time=0.0,
rate=1.0,
)

ip = create_imaging_plane()

image_series = TwoPhotonSeries(
name='TwoPhotonSeries1',
data=np.ones((1000, 100, 100)),
imaging_plane=ip,
rate=1.0,
unit='normalized amplitude'
)

corrected_image_stack = CorrectedImageStack(
corrected=corrected,
original=image_series,
xy_translation=xy_translation,
)

motion_correction = MotionCorrection(corrected_image_stacks=[corrected_image_stack])
self.assertEqual(motion_correction.corrected_image_stacks['CorrectedImageStack'], corrected_image_stack)


class CorrectedImageStackConstructor(TestCase):
def test_init(self):
is1 = ImageSeries(
name='is1',
name='corrected',
data=np.ones((2, 2, 2)),
unit='unit',
external_file=['external_file'],
Expand All @@ -216,7 +250,6 @@ def test_init(self):
)
is2 = ImageSeries(
name='is2',
data=np.ones((2, 2, 2)),
unit='unit',
external_file=['external_file'],
starting_frame=[1, 2, 3],
Expand All @@ -225,7 +258,7 @@ def test_init(self):
)
tstamps = np.arange(1.0, 100.0, 0.1, dtype=np.float)
ts = TimeSeries(
name="test_ts",
name='xy_translation',
data=list(range(len(tstamps))),
unit='unit',
timestamps=tstamps
Expand Down

0 comments on commit 51cc8f7

Please sign in to comment.