Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PEP 8 fixes, rename output fitted coordinates #20

Merged
merged 2 commits into from
Nov 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def readfile(filename, split=False):
# into the 'requirements.txt' file.
requires = [
# minimal requirements listing
"opencmiss.math",
"opencmiss.maths",
"opencmiss.utils >= 0.3",
"opencmiss.zinc >= 3.4"
]
Expand Down
286 changes: 170 additions & 116 deletions src/scaffoldfitter/fitter.py

Large diffs are not rendered by default.

26 changes: 16 additions & 10 deletions src/scaffoldfitter/fitterstep.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Base class for fitter steps.
"""
import abc


class FitterStep:
Expand Down Expand Up @@ -33,13 +34,18 @@ def getDefaultGroupName(cls):
def getFitter(self):
return self._fitter

def _setFitter(self, fitter):
'''
def setFitter(self, fitter):
"""
Should only be called by Fitter when adding or removing from it.
'''
"""
self._fitter = fitter

def decodeSettingsJSONDict(self, dctIn : dict):
@classmethod
@abc.abstractmethod
def getJsonTypeId(cls):
pass

def decodeSettingsJSONDict(self, dctIn: dict):
"""
Decode definition of step from JSON dict.
"""
Expand All @@ -55,8 +61,8 @@ def encodeSettingsJSONDict(self) -> dict:
:return: Settings in a dict ready for passing to json.dump.
"""
return {
self.getJsonTypeId() : True,
"groupSettings" : self._groupSettings
self.getJsonTypeId(): True,
"groupSettings": self._groupSettings
}

def getGroupSettingsNames(self):
Expand All @@ -65,7 +71,7 @@ def getGroupSettingsNames(self):
"""
return list(self._groupSettings.keys())

def clearGroupSetting(self, groupName : str, settingName : str):
def clearGroupSetting(self, groupName: str, settingName: str):
"""
Clear setting for group, removing group settings dict if empty.
:param groupName: Exact model group name, or None for default group.
Expand All @@ -79,7 +85,7 @@ def clearGroupSetting(self, groupName : str, settingName : str):
if len(groupSettings) == 0:
self._groupSettings.pop(groupName)

def _getInheritedGroupSetting(self, groupName : str, settingName : str):
def _getInheritedGroupSetting(self, groupName: str, settingName: str):
"""
:param groupName: Exact model group name, or None for default group.
:param settingName: Exact setting name.
Expand All @@ -98,7 +104,7 @@ def _getInheritedGroupSetting(self, groupName : str, settingName : str):
if inheritedValue != "<not set>":
return inheritedValue

def getGroupSetting(self, groupName : str, settingName : str, defaultValue):
def getGroupSetting(self, groupName: str, settingName: str, defaultValue):
"""
Get group setting of supplied name, with reset & inherit ability.
:param groupName: Exact model group name, or None for default group.
Expand Down Expand Up @@ -130,7 +136,7 @@ def getGroupSetting(self, groupName : str, settingName : str, defaultValue):
value = defaultValue
return value, setLocally, inheritable

def setGroupSetting(self, groupName : str, settingName : str, value):
def setGroupSetting(self, groupName: str, settingName: str, value):
"""
Set value of setting or None to reset to default.
:param groupName: Exact model group name, or None for default group.
Expand Down
126 changes: 66 additions & 60 deletions src/scaffoldfitter/fitterstepalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import math
from opencmiss.maths.vectorops import div, euler_to_rotation_matrix, matrix_vector_mult, mult, sub
from opencmiss.utils.zinc.field import assignFieldParameters, get_group_list, create_field_euler_angles_rotation_matrix
from opencmiss.utils.zinc.field import get_group_list, create_field_euler_angles_rotation_matrix
from opencmiss.utils.zinc.finiteelement import evaluate_field_nodeset_mean, getNodeNameCentres
from opencmiss.utils.zinc.general import ChangeManager
from opencmiss.zinc.element import Mesh
Expand All @@ -15,8 +15,8 @@
from scaffoldfitter.fitterstep import FitterStep


def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_value=1.0, \
translation_offsets=None, translation_scale_factor=1.0):
def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_value=1.0,
translation_offsets=None, translation_scale_factor=1.0):
"""
Create constant fields for rotation, scale and translation containing the supplied
values, plus the transformed coordinates applying them in the supplied order.
Expand Down Expand Up @@ -44,11 +44,13 @@ def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_
translation = fieldmodule.createFieldConstant(translation_offsets)
rotation_matrix = create_field_euler_angles_rotation_matrix(fieldmodule, rotation)
rotated_coordinates = fieldmodule.createFieldMatrixMultiply(components_count, rotation_matrix, coordinates)
transformed_coordinates = rotated_coordinates*scale + (translation if (translation_scale_factor == 1.0) else \
translation*fieldmodule.createFieldConstant([ translation_scale_factor ]*components_count))
transformed_coordinates = rotated_coordinates*scale + (
translation if (translation_scale_factor == 1.0) else
translation*fieldmodule.createFieldConstant([translation_scale_factor]*components_count))
assert transformed_coordinates.isValid()
return transformed_coordinates, rotation, scale, translation


class FitterStepAlign(FitterStep):

_jsonTypeId = "_FitterStepAlign"
Expand All @@ -57,15 +59,15 @@ def __init__(self):
super(FitterStepAlign, self).__init__()
self._alignGroups = False
self._alignMarkers = False
self._rotation = [ 0.0, 0.0, 0.0 ]
self._rotation = [0.0, 0.0, 0.0]
self._scale = 1.0
self._translation = [ 0.0, 0.0, 0.0 ]
self._translation = [0.0, 0.0, 0.0]

@classmethod
def getJsonTypeId(cls):
return cls._jsonTypeId

def decodeSettingsJSONDict(self, dctIn : dict):
def decodeSettingsJSONDict(self, dctIn: dict):
"""
Decode definition of step from JSON dict.
"""
Expand All @@ -86,11 +88,11 @@ def encodeSettingsJSONDict(self) -> dict:
"""
dct = super().encodeSettingsJSONDict()
dct.update({
"alignGroups" : self._alignGroups,
"alignMarkers" : self._alignMarkers,
"rotation" : self._rotation,
"scale" : self._scale,
"translation" : self._translation
"alignGroups": self._alignGroups,
"alignMarkers": self._alignMarkers,
"rotation": self._rotation,
"scale": self._scale,
"translation": self._translation
})
return dct

Expand Down Expand Up @@ -177,14 +179,14 @@ def run(self, modelFileNameStem=None):
assert modelCoordinates, "Align: Missing model coordinates"
if self._alignGroups or self._alignMarkers:
self._doAutoAlign()
fieldmodule = self._fitter._fieldmodule
fieldmodule = self._fitter.getFieldmodule()
with ChangeManager(fieldmodule):
# rotate, scale and translate model
modelCoordinatesTransformed = createFieldsTransformations(
modelCoordinates, self._rotation, self._scale, self._translation)[0]
fieldassignment = self._fitter._modelCoordinatesField.createFieldassignment(modelCoordinatesTransformed)
fieldassignment = self._fitter.getModelCoordinatesField().createFieldassignment(modelCoordinatesTransformed)
result = fieldassignment.assign()
assert result in [ RESULT_OK, RESULT_WARNING_PART_DONE ], "Align: Failed to transform model"
assert result in [RESULT_OK, RESULT_WARNING_PART_DONE], "Align: Failed to transform model"
self._fitter.updateModelReferenceCoordinates()
del fieldassignment
del modelCoordinatesTransformed
Expand All @@ -198,12 +200,10 @@ def _doAutoAlign(self):
Perform auto alignment to groups and/or markers.
"""
modelCoordinates = self._fitter.getModelCoordinatesField()
componentsCount = modelCoordinates.getNumberOfComponents()
pointMap = {} # dict group/marker name -> (modelCoordinates, dataCoordinates)

if self._alignGroups:
fieldmodule = self._fitter._fieldmodule
datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
fieldmodule = self._fitter.getFieldmodule()
dataCoordinates = self._fitter.getDataCoordinatesField()
groups = get_group_list(fieldmodule)
with ChangeManager(fieldmodule):
Expand All @@ -220,20 +220,21 @@ def _doAutoAlign(self):
coordinates_integral = evaluate_field_mesh_integral(modelCoordinates, modelCoordinates, meshGroup)
mass = evaluate_field_mesh_integral(one, modelCoordinates, meshGroup)
meanModelCoordinates = div(coordinates_integral, mass)
pointMap[groupName] = ( meanModelCoordinates, meanDataCoordinates )
pointMap[groupName] = (meanModelCoordinates, meanDataCoordinates)
del one

if self._alignMarkers:
markerGroup = self._fitter.getMarkerGroup()
assert markerGroup, "Align: No marker group to align with"
markerPrefix = markerGroup.getName()

markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields()
assert markerNodeGroup and markerCoordinates and markerName, "Align: No marker group, coordinates or name fields"
assert markerNodeGroup and markerCoordinates and markerName, \
"Align: No marker group, coordinates or name fields"
modelMarkers = getNodeNameCentres(markerNodeGroup, markerCoordinates, markerName)

markerDataGroup, markerDataCoordinates, markerDataName = self._fitter.getMarkerDataFields()
assert markerDataGroup and markerDataCoordinates and markerDataName, "Align: No marker data group, coordinates or name fields"
assert markerDataGroup and markerDataCoordinates and markerDataName, \
"Align: No marker data group, coordinates or name fields"
dataMarkers = getNodeNameCentres(markerDataGroup, markerDataCoordinates, markerDataName)

# match model and data markers, warn of unmatched markers
Expand All @@ -243,9 +244,10 @@ def _doAutoAlign(self):
matchName = modelName.strip().casefold()
for dataName in dataMarkers:
if dataName.strip().casefold() == matchName:
pointMap[modelName] = ( modelMarkers[modelName], dataMarkers[dataName] )
pointMap[modelName] = (modelMarkers[modelName], dataMarkers[dataName])
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' found in data" + (" as '" + dataName +"'" if (dataName != modelName) else ""))
print("Align: Model marker '" + modelName + "' found in data" +
(" as '" + dataName + "'" if (dataName != modelName) else ""))
dataMarkers.pop(dataName)
break
else:
Expand All @@ -265,7 +267,7 @@ def _optimiseAlignment(self, pointMap):
:param pointMap: dict name -> (modelCoordinates, dataCoordinates)
"""
assert len(pointMap) >= 3, "Align: Only " + str(len(pointMap)) + " points - need at least 3"
region = self._fitter._context.createRegion()
region = self._fitter.getContext().createRegion()
fieldmodule = region.getFieldmodule()

with ChangeManager(fieldmodule):
Expand All @@ -281,7 +283,7 @@ def _optimiseAlignment(self, pointMap):
datasum = [0.0, 0.0, 0.0]
modelMin = copy.deepcopy(list(pointMap.values())[0][0])
modelMax = copy.deepcopy(list(pointMap.values())[0][0])
dataMin = copy.deepcopy(list(pointMap.values())[0][1])
dataMin = copy.deepcopy(list(pointMap.values())[0][1])
dataMax = copy.deepcopy(list(pointMap.values())[0][1])
for name, positions in pointMap.items():
modelx = positions[0]
Expand All @@ -292,7 +294,8 @@ def _optimiseAlignment(self, pointMap):
fieldcache.setNode(node)
result1 = modelCoordinates.assignReal(fieldcache, positions[0])
result2 = dataCoordinates.assignReal(fieldcache, positions[1])
assert (result1 == RESULT_OK) and (result2 == RESULT_OK), "Align: Failed to set up data for alignment to markers optimisation"
assert (result1 == RESULT_OK) and (result2 == RESULT_OK), \
"Align: Failed to set up data for alignment to markers optimisation"
for c in range(3):
modelMin[c] = min(modelx[c], modelMin[c])
modelMax[c] = max(modelx[c], modelMax[c])
Expand All @@ -314,8 +317,8 @@ def _optimiseAlignment(self, pointMap):
translationScaleFactor = 1.0
first = True
fieldcache = fieldmodule.createFieldcache()
modelCoordinatesTransformed, rotation, scale, translation = createFieldsTransformations(modelCoordinates,
scale_value=scaleFactor)
modelCoordinatesTransformed, rotation, scale, translation = \
createFieldsTransformations(modelCoordinates, scale_value=scaleFactor)

# create objective = sum of squares of vector from modelCoordinatesTransformed to dataCoordinates
markerDiff = fieldmodule.createFieldSubtract(dataCoordinates, modelCoordinatesTransformed)
Expand Down Expand Up @@ -345,41 +348,42 @@ def _optimiseAlignment(self, pointMap):
rotation.assignReal(fieldcache, minRotationAngles)
translation.assignReal(fieldcache, minTranslation)

assert objective.isValid(), "Align: Failed to set up objective function for alignment to markers optimisation"
assert objective.isValid(), \
"Align: Failed to set up objective function for alignment to markers optimisation"

optimisation = fieldmodule.createOptimisation()
optimisation.setMethod(Optimisation.METHOD_LEAST_SQUARES_QUASI_NEWTON)
#optimisation.setMethod(Optimisation.METHOD_QUASI_NEWTON)
# optimisation.setMethod(Optimisation.METHOD_QUASI_NEWTON)
optimisation.addObjectiveField(objective)
optimisation.addDependentField(rotation)
optimisation.addDependentField(scale)
optimisation.addDependentField(translation)

#FunctionTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE)
#GradientTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE)
#StepTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE)
#MaximumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP)
#MinimumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP)
#LinesearchTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE)
#TrustRegionSize = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE)

#tol_scale = dataScale*dataScale
#FunctionTolerance *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE, FunctionTolerance)
#GradientTolerance *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE, GradientTolerance)
#StepTolerance *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE, StepTolerance)
#MaximumStep *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP, MaximumStep)
#MinimumStep *= tol_scale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP, MinimumStep)
#LinesearchTolerance *= dataScale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE, LinesearchTolerance)
#TrustRegionSize *= dataScale
#optimisation.setAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE, TrustRegionSize)

#if self.getDiagnosticLevel() > 0:
# FunctionTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE)
# GradientTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE)
# StepTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE)
# MaximumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP)
# MinimumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP)
# LinesearchTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE)
# TrustRegionSize = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE)

# tol_scale = dataScale*dataScale
# FunctionTolerance *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE, FunctionTolerance)
# GradientTolerance *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE, GradientTolerance)
# StepTolerance *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE, StepTolerance)
# MaximumStep *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP, MaximumStep)
# MinimumStep *= tol_scale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP, MinimumStep)
# LinesearchTolerance *= dataScale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE, LinesearchTolerance)
# TrustRegionSize *= dataScale
# optimisation.setAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE, TrustRegionSize)

# if self.getDiagnosticLevel() > 0:
# print("Function Tolerance", FunctionTolerance)
# print("Gradient Tolerance", GradientTolerance)
# print("Step Tolerance", StepTolerance)
Expand All @@ -397,10 +401,12 @@ def _optimiseAlignment(self, pointMap):
result1, self._rotation = rotation.evaluateReal(fieldcache, 3)
result2, self._scale = scale.evaluateReal(fieldcache, 1)
result3, self._translation = translation.evaluateReal(fieldcache, 3)
self._translation = [ s*translationScaleFactor for s in self._translation ]
assert (result1 == RESULT_OK) and (result2 == RESULT_OK) and (result3 == RESULT_OK), "Align: Failed to evaluate transformation for alignment to markers"
self._translation = [s*translationScaleFactor for s in self._translation]
assert (result1 == RESULT_OK) and (result2 == RESULT_OK) and (result3 == RESULT_OK), \
"Align: Failed to evaluate transformation for alignment to markers"


def evaluate_field_mesh_integral(field : Field, coordinates : Field, mesh: Mesh, number_of_points = 4):
def evaluate_field_mesh_integral(field: Field, coordinates: Field, mesh: Mesh, number_of_points=4):
"""
Integrate value of a field over mesh using Gaussian Quadrature.
:param field: Field to integrate over mesh.
Expand Down
3 changes: 2 additions & 1 deletion src/scaffoldfitter/fitterstepconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scaffoldfitter.fitterstep import FitterStep
import sys


class FitterStepConfig(FitterStep):

_jsonTypeId = "_FitterStepConfig"
Expand All @@ -19,7 +20,7 @@ def __init__(self):
def getJsonTypeId(cls):
return cls._jsonTypeId

def decodeSettingsJSONDict(self, dctIn : dict):
def decodeSettingsJSONDict(self, dctIn: dict):
"""
Decode definition of step from JSON dict.
"""
Expand Down
Loading