Skip to content

Commit

Permalink
Merge pull request #4635 from thewtex/transform-dict-repr
Browse files Browse the repository at this point in the history
BUG: Make dict_from_transform more consistent with other dict representations
  • Loading branch information
thewtex authored May 3, 2024
2 parents 0ea4ae9 + 5397f81 commit afc4879
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,13 @@

keys_to_test1 = [
"name",
"parametersValueType",
"transformType",
"inputDimension",
"outputDimension",
"inputSpaceName",
"outputSpaceName",
"numberOfParameters",
"numberOfFixedParameters",
]
keys_to_test2 = ["parameters", "fixedParameters"]
keys_to_test3 = ["transformParameterization", "parametersValueType", "inputDimension", "outputDimension"]

transform_object_list = []
for i, transform_type in enumerate(transforms_to_test):
Expand All @@ -60,6 +57,8 @@
# Test all the parameters
for k in keys_to_test2:
assert np.array_equal(serialize_deserialize[k], transform[k])
for k in keys_to_test3:
assert serialize_deserialize["transformType"][k], transform["transformType"][k]
transform_object_list.append(transform)

print("Individual Transforms Test Done")
Expand Down Expand Up @@ -93,6 +92,9 @@
for k in keys_to_test2:
assert np.array_equal(transform_obj[k], transform_object_list[i][k])

for k in keys_to_test3:
assert transform_object_list[i]["transformType"][k], transform["transformType"][k]


# Test for transformation using de-serialized BSpline Transform
ImageDimension = 2
Expand Down
5 changes: 2 additions & 3 deletions Wrapping/Generators/Python/PyBase/pyBase.i
Original file line number Diff line number Diff line change
Expand Up @@ -430,15 +430,15 @@ str = str
Return keys related to the transform's metadata.
These keys are used in the dictionary resulting from dict(transform).
"""
result = ['name', 'inputDimension', 'outputDimension', 'inputSpaceName', 'outputSpaceName', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
result = ['transformType', 'name', 'inputSpaceName', 'outputSpaceName', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
return result
def __getitem__(self, key):
"""Access metadata keys, see help(transform.keys), for string keys."""
import itk
if isinstance(key, str):
state = itk.dict_from_transform(self)
return state[0][key]
return state[key]
def __setitem__(self, key, value):
if isinstance(key, str):
Expand Down Expand Up @@ -474,7 +474,6 @@ str = str
def __setstate__(self, state):
"""Set object state, necessary for serialization with pickle."""
import itk
import numpy as np
deserialized = itk.transform_from_dict(state)
self.__dict__['this'] = deserialized
%}
Expand Down
5 changes: 5 additions & 0 deletions Wrapping/Generators/Python/Tests/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ def custom_callback(name, progress):
parameters = np.asarray(transforms[0].GetParameters())
assert np.allclose(parameters, np.array(baseline_additional_transform_params))

transform_dict = itk.dict_from_transform(transforms[0])
transform_back = itk.transform_from_dict(transform_dict)
transform_dict = itk.dict_from_transform(transforms)
transform_back = itk.transform_from_dict(transform_dict)

# pipeline, auto_pipeline and templated class are tested in other files

# BridgeNumPy
Expand Down
104 changes: 67 additions & 37 deletions Wrapping/Generators/Python/itk/support/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,57 +981,81 @@ def dict_from_pointset(pointset: "itkt.PointSet") -> Dict:
)


def dict_from_transform(transform: "itkt.TransformBase") -> Dict:
def dict_from_transform(transform: Union["itkt.TransformBase", List["itkt.TransformBase"]]) -> Union[List[Dict], Dict]:
"""Serialize a Python itk.Transform object to a pickable Python dictionary.
If the transform is a list of transforms, then a list of dictionaries is returned.
If the transform is a single, non-Composite transform, then a single dictionary is returned.
Composite transforms and nested composite transforms are flattened into a list of dictionaries.
"""
import itk
datatype_dict = {"double": itk.D, "float": itk.F}

def update_transform_dict(current_transform):
current_transform_type = current_transform.GetTransformTypeAsString()
current_transform_type_split = current_transform_type.split("_")
component = itk.template(current_transform)

in_transform_dict = dict()
in_transform_dict["name"] = current_transform.GetObjectName()
transform_type = dict()
transform_parameterization = current_transform_type_split[0].replace("Transform", "")
transform_type["transformParameterization"] = transform_parameterization

datatype_dict = {"double": itk.D, "float": itk.F}
in_transform_dict["parametersValueType"] = python_to_js(
transform_type["parametersValueType"] = python_to_js(
datatype_dict[current_transform_type_split[1]]
)
in_transform_dict["inputDimension"] = int(current_transform_type_split[2])
in_transform_dict["outputDimension"] = int(current_transform_type_split[3])
in_transform_dict["transformType"] = current_transform_type_split[0]
transform_type["inputDimension"] = int(current_transform_type_split[2])
transform_type["outputDimension"] = int(current_transform_type_split[3])

in_transform_dict["inputSpaceName"] = current_transform.GetInputSpaceName()
in_transform_dict["outputSpaceName"] = current_transform.GetOutputSpaceName()
transform_dict = dict()
transform_dict['transformType'] = transform_type
transform_dict["name"] = current_transform.GetObjectName()

transform_dict["inputSpaceName"] = current_transform.GetInputSpaceName()
transform_dict["outputSpaceName"] = current_transform.GetOutputSpaceName()

# To avoid copying the parameters for the Composite Transform
# as it is a copy of child transforms.
if "Composite" not in current_transform_type_split[0]:
p = np.array(current_transform.GetParameters())
in_transform_dict["parameters"] = p
transform_dict["parameters"] = p

fp = np.array(current_transform.GetFixedParameters())
in_transform_dict["fixedParameters"] = fp
transform_dict["fixedParameters"] = fp

in_transform_dict["numberOfParameters"] = p.shape[0]
in_transform_dict["numberOfFixedParameters"] = fp.shape[0]
transform_dict["numberOfParameters"] = p.shape[0]
transform_dict["numberOfFixedParameters"] = fp.shape[0]

return in_transform_dict
return transform_dict

dict_array = []
transform_type = transform.GetTransformTypeAsString()
if "CompositeTransform" in transform_type:
# Add the transforms inside the composite transform
# range is over-ridden so using this hack to create a list
for i, _ in enumerate([0] * transform.GetNumberOfTransforms()):
current_transform = transform.GetNthTransform(i)
dict_array.append(update_transform_dict(current_transform))
multi = False
def add_transform_dict(transform):
transform_type = transform.GetTransformTypeAsString()
if "CompositeTransform" in transform_type:
# Add the transforms inside the composite transform
# range is over-ridden so using this hack to create a list
for i, _ in enumerate([0] * transform.GetNumberOfTransforms()):
current_transform = transform.GetNthTransform(i)
dict_array.append(update_transform_dict(current_transform))
return True
else:
dict_array.append(update_transform_dict(transform))
return False
if isinstance(transform, list):
multi = True
for t in transform:
add_transform_dict(t)
else:
dict_array.append(update_transform_dict(transform))
multi = add_transform_dict(transform)

return dict_array
if multi:
return dict_array
else:
return dict_array[0]

def transform_from_dict(transform_dict: Union[Dict, List[Dict]]) -> "itkt.TransformBase":
"""Deserialize a dictionary representing an itk.Transform object.
def transform_from_dict(transform_dict: Dict) -> "itkt.TransformBase":
If the dictionary represents a list of transforms, then a Composite Transform is returned."""
import itk

def set_parameters(transform, transform_parameters, transform_fixed_parameters, data_type):
Expand All @@ -1055,35 +1079,41 @@ def special_transform_check(transform_name):

parametersValueType_dict = {"float32": itk.F, "float64": itk.D}

if not isinstance(transform_dict, list):
transform_dict = [transform_dict]

# Loop over all the transforms in the dictionary
transforms_list = []
for i, _ in enumerate(transform_dict):
data_type = parametersValueType_dict[transform_dict[i]["parametersValueType"]]
transform_type = transform_dict[i]["transformType"]
data_type = parametersValueType_dict[transform_type["parametersValueType"]]

transform_parameterization = transform_type["transformParameterization"] + 'Transform'

# No template parameter needed for transforms having 2D or 3D name
# Also for some selected transforms
if special_transform_check(transform_dict[i]["transformType"]):
transform_template = getattr(itk, transform_dict[i]["transformType"])
if special_transform_check(transform_parameterization):
transform_template = getattr(itk, transform_parameterization)
transform = transform_template[data_type].New()
# Currently only BSpline Transform has 3 template parameters
# For future extensions the information will have to be encoded in
# the transformType variable. The transform object once added in a
# composite transform lose the information for other template parameters ex. BSpline.
# The Spline order is fixed as 3 here.
elif transform_dict[i]["transformType"] == "BSplineTransform":
transform_template = getattr(itk, transform_dict[i]["transformType"])
elif transform_parameterization == "BSplineTransform":
transform_template = getattr(itk, transform_parameterization)
transform = transform_template[
data_type, transform_dict[i]["inputDimension"], 3
data_type, transform_type["inputDimension"], 3
].New()
else:
transform_template = getattr(itk, transform_dict[i]["transformType"])
transform_template = getattr(itk, transform_parameterization)
if len(transform_template.items()[0][0]) > 2:
transform = transform_template[
data_type, transform_dict[i]["inputDimension"], transform_dict[i]["outputDimension"]
data_type, transform_type["inputDimension"], transform_type["outputDimension"]
].New()
else:
transform = transform_template[
data_type, transform_dict[i]["inputDimension"]
data_type, transform_type["inputDimension"]
].New()

transform.SetObjectName(transform_dict[i]["name"])
Expand All @@ -1102,8 +1132,8 @@ def special_transform_check(transform_name):
if len(transforms_list) > 1:
# Create a Composite Transform object
# and add all the transforms in it.
data_type = parametersValueType_dict[transform_dict[0]["parametersValueType"]]
transform = itk.CompositeTransform[data_type, transforms_list[0]['inputDimension']].New()
data_type = parametersValueType_dict[transform_dict[0]["transformType"]["parametersValueType"]]
transform = itk.CompositeTransform[data_type, transforms_list[0]["transformType"]['inputDimension']].New()
for current_transform in transforms_list:
transform.AddTransform(current_transform)
else:
Expand Down

0 comments on commit afc4879

Please sign in to comment.