Skip to content

Commit

Permalink
ENH: Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cookpa committed May 17, 2024
1 parent a54ed05 commit dab3df0
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions tests/test_core_ants_transform_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ def setUp(self):
self.txs = [tx2d, tx3d]
self.pixeltypes = ['unsigned char', 'unsigned int', 'float']

self.matrix_offset_types = ['AffineTransform',
self.matrix_offset_types = ['AffineTransform',
'CenteredAffineTransform',
'Euler2DTransform',
'Euler3DTransform',
'Rigid2DTransform',
'QuaternionRigidTransform',
'Similarity2DTransform',
'CenteredSimilarity2DTransform',
'Similarity3DTransform',
'Similarity3DTransform',
'CenteredRigid2DTransform',
'CenteredEuler3DTransform',
'CenteredEuler3DTransform',
'Rigid3DTransform']

def tearDown(self):
Expand Down Expand Up @@ -95,16 +95,27 @@ def test_read_write_transform(self):
# file doesnt exist
with self.assertRaises(Exception):
ants.read_transform('blah-blah.mat')


def test_from_displacement_components(self):
vec_np = np.ndarray((2,2,3), dtype=np.float32)
vec = ants.from_numpy(vec_np, origin=(0,0), spacing=(1,1), has_components=True)
# should get ValueError here because the 2D vector field has 3 components
with self.assertRaises(ValueError):
ants.transform_from_displacement_field(vec)
vec_np = np.ndarray((2,2,2,3), dtype=np.float32)
vec = ants.from_numpy(vec_np, origin=(0,0,0), spacing=(1,1,1), has_components=True)
# should work here because the 3D vector field has 3 components
tx = ants.transform_from_displacement_field(vec)

def test_from_displacement(self):
fi = ants.image_read(ants.get_ants_data('r16') )
mi = ants.image_read(ants.get_ants_data('r64') )
fi = ants.resample_image(fi,(60,60),1,0)
mi = ants.resample_image(mi,(60,60),1,0) # speed up
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = ('SyN') )
vec = ants.image_read( mytx['fwdtransforms'][0] )
atx = ants.transform_from_displacement_field( vec )
# read transform, which calls transform_from_displacement_field
atx = ants.read_transform( mytx['fwdtransforms'][0] )

def test_to_displacement(self):
fi = ants.image_read(ants.get_ants_data('r16') )
mi = ants.image_read(ants.get_ants_data('r64') )
Expand All @@ -113,11 +124,11 @@ def test_to_displacement(self):
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = ('SyN') )
vec = ants.image_read( mytx['fwdtransforms'][0] )
atx = ants.transform_from_displacement_field( vec )
field = ants.transform_to_displacement_field( atx, fi )
field = ants.transform_to_displacement_field( atx, fi )

def test_catch_error(self):
with self.assertRaises(Exception):
ants.write_transform(123, 'test.mat')
ants.write_transform(123, 'test.mat')


if __name__ == '__main__':
Expand Down

0 comments on commit dab3df0

Please sign in to comment.