Skip to content

Commit

Permalink
Merge pull request #638 from ANTsX/ants_read_warp_transform
Browse files Browse the repository at this point in the history
Ants read warp transform
  • Loading branch information
Nicholas Cullen, PhD authored May 18, 2024
2 parents 1c81298 + dab3df0 commit fe4fbbb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
12 changes: 10 additions & 2 deletions ants/core/ants_transform_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ def transform_from_displacement_field(field):
"""
if not isinstance(field, iio.ANTsImage):
raise ValueError("field must be ANTsImage type")
if field.dimension < 2 or field.dimension > 3:
raise ValueError("Unsupported displacement field dimension: %i" % field.dimension)
if field.components != field.dimension:
raise ValueError("Displacement field must have same number of components as the image dimension")
libfn = utils.get_lib_fn("antsTransformFromDisplacementField")
field = field.clone("float")
txptr = libfn(field.pointer)
Expand All @@ -263,6 +267,7 @@ def transform_from_displacement_field(field):
pointer=txptr,
)


def transform_to_displacement_field(xfrm, ref):
"""
Convert displacement field ANTsTransform to displacement field
Expand Down Expand Up @@ -299,6 +304,7 @@ def transform_to_displacement_field(xfrm, ref):
field_ptr = libfn(xfrm.pointer, ref.pointer)
return iio2.from_pointer(field_ptr)


def read_transform(filename, precision="float"):
"""
Read a transform from file
Expand Down Expand Up @@ -329,7 +335,9 @@ def read_transform(filename, precision="float"):
if not os.path.exists(filename):
raise ValueError("filename does not exist!")

# intentionally ignore dimension
if filename.endswith('.nii') or filename.endswith('.nii.gz'):
return transform_from_displacement_field(iio2.image_read(filename))

libfn1 = utils.get_lib_fn("getTransformDimensionFromFile")
dimensionUse = libfn1(filename)

Expand Down Expand Up @@ -377,7 +385,7 @@ def write_transform(transform, filename):
"""
if not isinstance(transform, tio.ANTsTransform):
raise Exception('Only ANTsTransform instances can be written to file. Check that you are not passing in a filepath to a saved transform.')

filename = os.path.expanduser(filename)
libfn = utils.get_lib_fn("writeTransform")
libfn(transform.pointer, filename)
31 changes: 21 additions & 10 deletions tests/test_core_ants_transform_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ def setUp(self):
self.txs = [tx2d, tx3d]
self.pixeltypes = ['unsigned char', 'unsigned int', 'float']

self.matrix_offset_types = ['AffineTransform',
self.matrix_offset_types = ['AffineTransform',
'CenteredAffineTransform',
'Euler2DTransform',
'Euler3DTransform',
'Rigid2DTransform',
'QuaternionRigidTransform',
'Similarity2DTransform',
'CenteredSimilarity2DTransform',
'Similarity3DTransform',
'Similarity3DTransform',
'CenteredRigid2DTransform',
'CenteredEuler3DTransform',
'CenteredEuler3DTransform',
'Rigid3DTransform']

def tearDown(self):
Expand Down Expand Up @@ -95,16 +95,27 @@ def test_read_write_transform(self):
# file doesnt exist
with self.assertRaises(Exception):
ants.read_transform('blah-blah.mat')


def test_from_displacement_components(self):
vec_np = np.ndarray((2,2,3), dtype=np.float32)
vec = ants.from_numpy(vec_np, origin=(0,0), spacing=(1,1), has_components=True)
# should get ValueError here because the 2D vector field has 3 components
with self.assertRaises(ValueError):
ants.transform_from_displacement_field(vec)
vec_np = np.ndarray((2,2,2,3), dtype=np.float32)
vec = ants.from_numpy(vec_np, origin=(0,0,0), spacing=(1,1,1), has_components=True)
# should work here because the 3D vector field has 3 components
tx = ants.transform_from_displacement_field(vec)

def test_from_displacement(self):
fi = ants.image_read(ants.get_ants_data('r16') )
mi = ants.image_read(ants.get_ants_data('r64') )
fi = ants.resample_image(fi,(60,60),1,0)
mi = ants.resample_image(mi,(60,60),1,0) # speed up
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = ('SyN') )
vec = ants.image_read( mytx['fwdtransforms'][0] )
atx = ants.transform_from_displacement_field( vec )
# read transform, which calls transform_from_displacement_field
atx = ants.read_transform( mytx['fwdtransforms'][0] )

def test_to_displacement(self):
fi = ants.image_read(ants.get_ants_data('r16') )
mi = ants.image_read(ants.get_ants_data('r64') )
Expand All @@ -113,11 +124,11 @@ def test_to_displacement(self):
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = ('SyN') )
vec = ants.image_read( mytx['fwdtransforms'][0] )
atx = ants.transform_from_displacement_field( vec )
field = ants.transform_to_displacement_field( atx, fi )
field = ants.transform_to_displacement_field( atx, fi )

def test_catch_error(self):
with self.assertRaises(Exception):
ants.write_transform(123, 'test.mat')
ants.write_transform(123, 'test.mat')


if __name__ == '__main__':
Expand Down

0 comments on commit fe4fbbb

Please sign in to comment.