Skip to content

Commit

Permalink
Merge pull request #20 from rchristie/fitted_coordinates
Browse files Browse the repository at this point in the history
PEP 8 fixes, rename output fitted coordinates
  • Loading branch information
hsorby authored Nov 4, 2021
2 parents d6d54a8 + 94c596b commit 528182c
Show file tree
Hide file tree
Showing 7 changed files with 419 additions and 443 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def readfile(filename, split=False):
# into the 'requirements.txt' file.
requires = [
# minimal requirements listing
"opencmiss.math",
"opencmiss.maths",
"opencmiss.utils >= 0.3",
"opencmiss.zinc >= 3.4"
]
Expand Down
286 changes: 170 additions & 116 deletions src/scaffoldfitter/fitter.py

Large diffs are not rendered by default.

26 changes: 16 additions & 10 deletions src/scaffoldfitter/fitterstep.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Base class for fitter steps.
"""
import abc


class FitterStep:
Expand Down Expand Up @@ -33,13 +34,18 @@ def getDefaultGroupName(cls):
def getFitter(self):
return self._fitter

def _setFitter(self, fitter):
'''
def setFitter(self, fitter):
"""
Should only be called by Fitter when adding or removing from it.
'''
"""
self._fitter = fitter

def decodeSettingsJSONDict(self, dctIn : dict):
@classmethod
@abc.abstractmethod
def getJsonTypeId(cls):
pass

def decodeSettingsJSONDict(self, dctIn: dict):
"""
Decode definition of step from JSON dict.
"""
Expand All @@ -55,8 +61,8 @@ def encodeSettingsJSONDict(self) -> dict:
:return: Settings in a dict ready for passing to json.dump.
"""
return {
self.getJsonTypeId() : True,
"groupSettings" : self._groupSettings
self.getJsonTypeId(): True,
"groupSettings": self._groupSettings
}

def getGroupSettingsNames(self):
Expand All @@ -65,7 +71,7 @@ def getGroupSettingsNames(self):
"""
return list(self._groupSettings.keys())

def clearGroupSetting(self, groupName : str, settingName : str):
def clearGroupSetting(self, groupName: str, settingName: str):
"""
Clear setting for group, removing group settings dict if empty.
:param groupName: Exact model group name, or None for default group.
Expand All @@ -79,7 +85,7 @@ def clearGroupSetting(self, groupName : str, settingName : str):
if len(groupSettings) == 0:
self._groupSettings.pop(groupName)

def _getInheritedGroupSetting(self, groupName : str, settingName : str):
def _getInheritedGroupSetting(self, groupName: str, settingName: str):
"""
:param groupName: Exact model group name, or None for default group.
:param settingName: Exact setting name.
Expand All @@ -98,7 +104,7 @@ def _getInheritedGroupSetting(self, groupName : str, settingName : str):
if inheritedValue != "<not set>":
return inheritedValue

def getGroupSetting(self, groupName : str, settingName : str, defaultValue):
def getGroupSetting(self, groupName: str, settingName: str, defaultValue):
"""
Get group setting of supplied name, with reset & inherit ability.
:param groupName: Exact model group name, or None for default group.
Expand Down Expand Up @@ -130,7 +136,7 @@ def getGroupSetting(self, groupName : str, settingName : str, defaultValue):
value = defaultValue
return value, setLocally, inheritable

def setGroupSetting(self, groupName : str, settingName : str, value):
def setGroupSetting(self, groupName: str, settingName: str, value):
"""
Set value of setting or None to reset to default.
:param groupName: Exact model group name, or None for default group.
Expand Down
126 changes: 66 additions & 60 deletions src/scaffoldfitter/fitterstepalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import math
from opencmiss.maths.vectorops import div, euler_to_rotation_matrix, matrix_vector_mult, mult, sub
from opencmiss.utils.zinc.field import assignFieldParameters, get_group_list, create_field_euler_angles_rotation_matrix
from opencmiss.utils.zinc.field import get_group_list, create_field_euler_angles_rotation_matrix
from opencmiss.utils.zinc.finiteelement import evaluate_field_nodeset_mean, getNodeNameCentres
from opencmiss.utils.zinc.general import ChangeManager
from opencmiss.zinc.element import Mesh
Expand All @@ -15,8 +15,8 @@
from scaffoldfitter.fitterstep import FitterStep


def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_value=1.0, \
translation_offsets=None, translation_scale_factor=1.0):
def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_value=1.0,
translation_offsets=None, translation_scale_factor=1.0):
"""
Create constant fields for rotation, scale and translation containing the supplied
values, plus the transformed coordinates applying them in the supplied order.
Expand Down Expand Up @@ -44,11 +44,13 @@ def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_
translation = fieldmodule.createFieldConstant(translation_offsets)
rotation_matrix = create_field_euler_angles_rotation_matrix(fieldmodule, rotation)
rotated_coordinates = fieldmodule.createFieldMatrixMultiply(components_count, rotation_matrix, coordinates)
transformed_coordinates = rotated_coordinates*scale + (translation if (translation_scale_factor == 1.0) else \
translation*fieldmodule.createFieldConstant([ translation_scale_factor ]*components_count))
transformed_coordinates = rotated_coordinates*scale + (
translation if (translation_scale_factor == 1.0) else
translation*fieldmodule.createFieldConstant([translation_scale_factor]*components_count))
assert transformed_coordinates.isValid()
return transformed_coordinates, rotation, scale, translation


class FitterStepAlign(FitterStep):

_jsonTypeId = "_FitterStepAlign"
Expand All @@ -57,15 +59,15 @@ def __init__(self):
super(FitterStepAlign, self).__init__()
self._alignGroups = False
self._alignMarkers = False
self._rotation = [ 0.0, 0.0, 0.0 ]
self._rotation = [0.0, 0.0, 0.0]
self._scale = 1.0
self._translation = [ 0.0, 0.0, 0.0 ]
self._translation = [0.0, 0.0, 0.0]

@classmethod
def getJsonTypeId(cls):
return cls._jsonTypeId

def decodeSettingsJSONDict(self, dctIn : dict):
def decodeSettingsJSONDict(self, dctIn: dict):
"""
Decode definition of step from JSON dict.
"""
Expand All @@ -86,11 +88,11 @@ def encodeSettingsJSONDict(self) -> dict:
"""
dct = super().encodeSettingsJSONDict()
dct.update({
"alignGroups" : self._alignGroups,
"alignMarkers" : self._alignMarkers,
"rotation" : self._rotation,
"scale" : self._scale,
"translation" : self._translation
"alignGroups": self._alignGroups,
"alignMarkers": self._alignMarkers,
"rotation": self._rotation,
"scale": self._scale,
"translation": self._translation
})
return dct

Expand Down Expand Up @@ -177,14 +179,14 @@ def run(self, modelFileNameStem=None):
assert modelCoordinates, "Align: Missing model coordinates"
if self._alignGroups or self._alignMarkers:
self._doAutoAlign()
fieldmodule = self._fitter._fieldmodule
fieldmodule = self._fitter.getFieldmodule()
with ChangeManager(fieldmodule):
# rotate, scale and translate model
modelCoordinatesTransformed = createFieldsTransformations(
modelCoordinates, self._rotation, self._scale, self._translation)[0]
fieldassignment = self._fitter._modelCoordinatesField.createFieldassignment(modelCoordinatesTransformed)
fieldassignment = self._fitter.getModelCoordinatesField().createFieldassignment(modelCoordinatesTransformed)
result = fieldassignment.assign()
assert result in [ RESULT_OK, RESULT_WARNING_PART_DONE ], "Align: Failed to transform model"
assert result in [RESULT_OK, RESULT_WARNING_PART_DONE], "Align: Failed to transform model"
self._fitter.updateModelReferenceCoordinates()
del fieldassignment
del modelCoordinatesTransformed
Expand All @@ -198,12 +200,10 @@ def _doAutoAlign(self):
Perform auto alignment to groups and/or markers.
"""
modelCoordinates = self._fitter.getModelCoordinatesField()
componentsCount = modelCoordinates.getNumberOfComponents()
pointMap = {} # dict group/marker name -> (modelCoordinates, dataCoordinates)

if self._alignGroups:
fieldmodule = self._fitter._fieldmodule
datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
fieldmodule = self._fitter.getFieldmodule()
dataCoordinates = self._fitter.getDataCoordinatesField()
groups = get_group_list(fieldmodule)
with ChangeManager(fieldmodule):
Expand All @@ -220,20 +220,21 @@ def _doAutoAlign(self):
coordinates_integral = evaluate_field_mesh_integral(modelCoordinates, modelCoordinates, meshGroup)
mass = evaluate_field_mesh_integral(one, modelCoordinates, meshGroup)
meanModelCoordinates = div(coordinates_integral, mass)
pointMap[groupName] = ( meanModelCoordinates, meanDataCoordinates )
pointMap[groupName] = (meanModelCoordinates, meanDataCoordinates)
del one

if self._alignMarkers:
markerGroup = self._fitter.getMarkerGroup()
assert markerGroup, "Align: No marker group to align with"
markerPrefix = markerGroup.getName()

markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields()
assert markerNodeGroup and markerCoordinates and markerName, "Align: No marker group, coordinates or name fields"
assert markerNodeGroup and markerCoordinates and markerName, \
"Align: No marker group, coordinates or name fields"
modelMarkers = getNodeNameCentres(markerNodeGroup, markerCoordinates, markerName)

markerDataGroup, markerDataCoordinates, markerDataName = self._fitter.getMarkerDataFields()
assert markerDataGroup and markerDataCoordinates and markerDataName, "Align: No marker data group, coordinates or name fields"
assert markerDataGroup and markerDataCoordinates and markerDataName, \
"Align: No marker data group, coordinates or name fields"
dataMarkers = getNodeNameCentres(markerDataGroup, markerDataCoordinates, markerDataName)

# match model and data markers, warn of unmatched markers
Expand All @@ -243,9 +244,10 @@ def _doAutoAlign(self):
matchName = modelName.strip().casefold()
for dataName in dataMarkers:
if dataName.strip().casefold() == matchName:
pointMap[modelName] = ( modelMarkers[modelName], dataMarkers[dataName] )
pointMap[modelName] = (modelMarkers[modelName], dataMarkers[dataName])
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' found in data" + (" as '" + dataName +"'" if (dataName != modelName) else ""))
print("Align: Model marker '" + modelName + "' found in data" +
(" as '" + dataName + "'" if (dataName != modelName) else ""))
dataMarkers.pop(dataName)
break
else:
Expand All @@ -265,7 +267,7 @@ def _optimiseAlignment(self, pointMap):
:param pointMap: dict name -> (modelCoordinates, dataCoordinates)
"""
assert len(pointMap) >= 3, "Align: Only " + str(len(pointMap)) + " points - need at least 3"
region = self._fitter._context.createRegion()
region = self._fitter.getContext().createRegion()
fieldmodule = region.getFieldmodule()

with ChangeManager(fieldmodule):
Expand All @@ -281,7 +283,7 @@ def _optimiseAlignment(self, pointMap):
datasum = [0.0, 0.0, 0.0]
modelMin = copy.deepcopy(list(pointMap.values())[0][0])
modelMax = copy.deepcopy(list(pointMap.values())[0][0])
dataMin = copy.deepcopy(list(pointMap.values())[0][1])
dataMin = copy.deepcopy(list(pointMap.values())[0][1])
dataMax = copy.deepcopy(list(pointMap.values())[0][1])
for name, positions in pointMap.items():
modelx = positions[0]
Expand All @@ -292,7 +294,8 @@ def _optimiseAlignment(self, pointMap):
fieldcache.setNode(node)
result1 = modelCoordinates.assignReal(fieldcache, positions[0])
result2 = dataCoordinates.assignReal(fieldcache, positions[1])
assert (result1 == RESULT_OK) and (result2 == RESULT_OK), "Align: Failed to set up data for alignment to markers optimisation"
assert (result1 == RESULT_OK) and (result2 == RESULT_OK), \
"Align: Failed to set up data for alignment to markers optimisation"
for c in range(3):
modelMin[c] = min(modelx[c], modelMin[c])
modelMax[c] = max(modelx[c], modelMax[c])
Expand All @@ -314,8 +317,8 @@ def _optimiseAlignment(self, pointMap):
translationScaleFactor = 1.0
first = True
fieldcache = fieldmodule.createFieldcache()
modelCoordinatesTransformed, rotation, scale, translation = createFieldsTransformations(modelCoordinates,
scale_value=scaleFactor)
modelCoordinatesTransformed, rotation, scale, translation = \
createFieldsTransformations(modelCoordinates, scale_value=scaleFactor)

# create objective = sum of squares of vector from modelCoordinatesTransformed to dataCoordinates
markerDiff = fieldmodule.createFieldSubtract(dataCoordinates, modelCoordinatesTransformed)
Expand Down Expand Up @@ -345,41 +348,42 @@ def _optimiseAlignment(self, pointMap):
rotation.assignReal(fieldcache, minRotationAngles)
translation.assignReal(fieldcache, minTranslation)

assert objective.isValid(), "Align: Failed to set up objective function for alignment to markers optimisation"
assert objective.isValid(), \
"Align: Failed to set up objective function for alignment to markers optimisation"

optimisation = fieldmodule.createOptimisation()
optimisation.setMethod(Optimisation.METHOD_LEAST_SQUARES_QUASI_NEWTON)
#optimisation.setMethod(Optimisation.METHOD_QUASI_NEWTON)
# optimisation.setMethod(Optimisation.METHOD_QUASI_NEWTON)
optimisation.addObjectiveField(objective)
optimisation.addDependentField(rotation)
optimisation.addDependentField(scale)
optimisation.addDependentField(translation)

#FunctionTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE)
#GradientTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE)
#StepTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE)
#MaximumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP)
#MinimumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP)
#LinesearchTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE)
#TrustRegionSize = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE)

#tol_scale = dataScale*dataScale
#FunctionTolerance *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE, FunctionTolerance)
#GradientTolerance *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE, GradientTolerance)
#StepTolerance *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE, StepTolerance)
#MaximumStep *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP, MaximumStep)
#MinimumStep *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP, MinimumStep)
#LinesearchTolerance *= dataScale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE, LinesearchTolerance)
#TrustRegionSize *= dataScale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE, TrustRegionSize)

#if self.getDiagnosticLevel() > 0:
# FunctionTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE)
# GradientTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE)
# StepTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE)
# MaximumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP)
# MinimumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP)
# LinesearchTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE)
# TrustRegionSize = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE)

# tol_scale = dataScale*dataScale
# FunctionTolerance *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE, FunctionTolerance)
# GradientTolerance *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE, GradientTolerance)
# StepTolerance *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE, StepTolerance)
# MaximumStep *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP, MaximumStep)
# MinimumStep *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP, MinimumStep)
# LinesearchTolerance *= dataScale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE, LinesearchTolerance)
# TrustRegionSize *= dataScale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE, TrustRegionSize)

# if self.getDiagnosticLevel() > 0:
# print("Function Tolerance", FunctionTolerance)
# print("Gradient Tolerance", GradientTolerance)
# print("Step Tolerance", StepTolerance)
Expand All @@ -397,10 +401,12 @@ def _optimiseAlignment(self, pointMap):
result1, self._rotation = rotation.evaluateReal(fieldcache, 3)
result2, self._scale = scale.evaluateReal(fieldcache, 1)
result3, self._translation = translation.evaluateReal(fieldcache, 3)
self._translation = [ s*translationScaleFactor for s in self._translation ]
assert (result1 == RESULT_OK) and (result2 == RESULT_OK) and (result3 == RESULT_OK), "Align: Failed to evaluate transformation for alignment to markers"
self._translation = [s*translationScaleFactor for s in self._translation]
assert (result1 == RESULT_OK) and (result2 == RESULT_OK) and (result3 == RESULT_OK), \
"Align: Failed to evaluate transformation for alignment to markers"


def evaluate_field_mesh_integral(field : Field, coordinates : Field, mesh: Mesh, number_of_points = 4):
def evaluate_field_mesh_integral(field: Field, coordinates: Field, mesh: Mesh, number_of_points=4):
"""
Integrate value of a field over mesh using Gaussian Quadrature.
:param field: Field to integrate over mesh.
Expand Down
3 changes: 2 additions & 1 deletion src/scaffoldfitter/fitterstepconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scaffoldfitter.fitterstep import FitterStep
import sys


class FitterStepConfig(FitterStep):

_jsonTypeId = "_FitterStepConfig"
Expand All @@ -19,7 +20,7 @@ def __init__(self):
def getJsonTypeId(cls):
return cls._jsonTypeId

def decodeSettingsJSONDict(self, dctIn : dict):
def decodeSettingsJSONDict(self, dctIn: dict):
"""
Decode definition of step from JSON dict.
"""
Expand Down
Loading

0 comments on commit 528182c

Please sign in to comment.