diff --git a/setup.py b/setup.py index 29bdfcd..a7dfb66 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ def readfile(filename, split=False): # into the 'requirements.txt' file. requires = [ # minimal requirements listing - "opencmiss.math", + "opencmiss.maths", "opencmiss.utils >= 0.3", "opencmiss.zinc >= 3.4" ] diff --git a/src/scaffoldfitter/fitter.py b/src/scaffoldfitter/fitter.py index 52cb82d..10c0730 100644 --- a/src/scaffoldfitter/fitter.py +++ b/src/scaffoldfitter/fitter.py @@ -3,12 +3,12 @@ """ import json -import sys from opencmiss.maths.vectorops import sub -from opencmiss.utils.zinc.field import assignFieldParameters, createFieldFiniteElementClone, getGroupList, getManagedFieldNames, \ - findOrCreateFieldFiniteElement, findOrCreateFieldGroup, findOrCreateFieldNodeGroup, findOrCreateFieldStoredMeshLocation, \ +from opencmiss.utils.zinc.field import assignFieldParameters, createFieldFiniteElementClone, getGroupList, \ + findOrCreateFieldFiniteElement, findOrCreateFieldStoredMeshLocation, \ getUniqueFieldName, orphanFieldByName -from opencmiss.utils.zinc.finiteelement import evaluateFieldNodesetMean, evaluateFieldNodesetRange, findNodeWithName, getMaximumNodeIdentifier +from opencmiss.utils.zinc.finiteelement import evaluateFieldNodesetMean, evaluateFieldNodesetRange, findNodeWithName, \ + getMaximumNodeIdentifier from opencmiss.utils.zinc.general import ChangeManager from opencmiss.zinc.context import Context from opencmiss.zinc.element import Elementbasis, Elementfieldtemplate @@ -16,11 +16,12 @@ from opencmiss.zinc.result import RESULT_OK, RESULT_WARNING_PART_DONE from scaffoldfitter.fitterstep import FitterStep from scaffoldfitter.fitterstepconfig import FitterStepConfig +from scaffoldfitter.fitterstepfit import FitterStepFit class Fitter: - def __init__(self, zincModelFileName : str, zincDataFileName : str): + def __init__(self, zincModelFileName: str, zincDataFileName: str): """ :param zincModelFileName: Name of zinc file supplying model to fit. :param zincDataFileName: Name of zinc filed supplying data to fit to. @@ -42,7 +43,7 @@ def __init__(self, zincModelFileName : str, zincDataFileName : str): self._fibreField = None self._fibreFieldName = None self._mesh = [] # [dimension - 1] - self._dataHostLocationField = None # stored mesh location field in highest dimension mesh for all data and markers + self._dataHostLocationField = None # stored mesh location field in highest dimension mesh for all data, markers self._dataHostCoordinatesField = None # embedded field giving host coordinates at data location self._dataDeltaField = None # self._dataHostCoordinatesField - self._markerDataCoordinatesField self._dataErrorField = None # magnitude of _dataDeltaField @@ -68,13 +69,15 @@ def __init__(self, zincModelFileName : str, zincDataFileName : str): self._strainActiveMeshGroup = None # group owning active elements with strain penalties self._curvaturePenaltyField = None # field storing curvature penalty as per-element constant self._curvatureActiveMeshGroup = None # group owning active elements with curvature penalties + self._dataCentre = [0.0, 0.0, 0.0] + self._dataScale = 1.0 self._diagnosticLevel = 0 # must always have an initial FitterStepConfig - which can never be removed self._fitterSteps = [] fitterStep = FitterStepConfig() self.addFitterStep(fitterStep) - def decodeSettingsJSON(self, s : str, decoder): + def decodeSettingsJSON(self, s: str, decoder): """ Define Fitter from JSON serialisation output by encodeSettingsJSON. :param s: String of JSON encoded Fitter settings. @@ -83,16 +86,16 @@ def decodeSettingsJSON(self, s : str, decoder): # clear fitter steps and load from json. Later assert there is an initial config step oldFitterSteps = self._fitterSteps self._fitterSteps = [] - dct = json.loads(s, object_hook=lambda dct: decoder(self, dct)) + settings = json.loads(s, object_hook=lambda dct: decoder(self, dct)) # self._fitterSteps will already be populated by decoder # ensure there is a first config step: if (len(self._fitterSteps) > 0) and isinstance(self._fitterSteps[0], FitterStepConfig): # field names are read (default to None), fields are found on load - self._modelCoordinatesFieldName = dct.get("modelCoordinatesField") - self._dataCoordinatesFieldName = dct.get("dataCoordinatesField") - self._fibreFieldName = dct.get("fibreField") - self._markerGroupName = dct.get("markerGroup") - self._diagnosticLevel = dct["diagnosticLevel"] + self._modelCoordinatesFieldName = settings.get("modelCoordinatesField") + self._dataCoordinatesFieldName = settings.get("dataCoordinatesField") + self._fibreFieldName = settings.get("fibreField") + self._markerGroupName = settings.get("markerGroup") + self._diagnosticLevel = settings["diagnosticLevel"] else: self._fitterSteps = oldFitterSteps assert False, "Missing initial config step" @@ -102,12 +105,12 @@ def encodeSettingsJSON(self) -> str: :return: String JSON encoding of Fitter settings. """ dct = { - "modelCoordinatesField" : self._modelCoordinatesFieldName, - "dataCoordinatesField" : self._dataCoordinatesFieldName, - "fibreField" : self._fibreFieldName, - "markerGroup" : self._markerGroupName, - "diagnosticLevel" : self._diagnosticLevel, - "fitterSteps" : [ fitterStep.encodeSettingsJSONDict() for fitterStep in self._fitterSteps ] + "modelCoordinatesField": self._modelCoordinatesFieldName, + "dataCoordinatesField": self._dataCoordinatesFieldName, + "fibreField": self._fibreFieldName, + "markerGroup": self._markerGroupName, + "diagnosticLevel": self._diagnosticLevel, + "fitterSteps": [fitterStep.encodeSettingsJSONDict() for fitterStep in self._fitterSteps] } return json.dumps(dct, sort_keys=False, indent=4) @@ -117,7 +120,7 @@ def getInitialFitterStepConfig(self): """ return self._fitterSteps[0] - def getInheritFitterStep(self, refFitterStep : FitterStep): + def getInheritFitterStep(self, refFitterStep: FitterStep): """ Get last FitterStep of same type as refFitterStep or None if refFitterStep is the first. @@ -128,7 +131,7 @@ def getInheritFitterStep(self, refFitterStep : FitterStep): return self._fitterSteps[index] return None - def getInheritFitterStepConfig(self, refFitterStep : FitterStep): + def getInheritFitterStepConfig(self, refFitterStep: FitterStep): """ Get last FitterStepConfig applicable to refFitterStep or None if refFitterStep is the first. @@ -138,7 +141,7 @@ def getInheritFitterStepConfig(self, refFitterStep : FitterStep): return self._fitterSteps[index] return None - def getActiveFitterStepConfig(self, refFitterStep : FitterStep): + def getActiveFitterStepConfig(self, refFitterStep: FitterStep): """ Get latest FitterStepConfig applicable to refFitterStep. Can be itself. @@ -148,18 +151,19 @@ def getActiveFitterStepConfig(self, refFitterStep : FitterStep): return self._fitterSteps[index] assert False, "getActiveFitterStepConfig. Could not find config." - def addFitterStep(self, fitterStep : FitterStep, refFitterStep=None): + def addFitterStep(self, fitterStep: FitterStep, refFitterStep=None): """ + :param fitterStep: FitterStep to add. :param refFitterStep: FitterStep to insert after, or None to append. """ - assert fitterStep.getFitter() == None + assert fitterStep.getFitter() is None if refFitterStep: self._fitterSteps.insert(self._fitterSteps.index(refFitterStep) + 1, fitterStep) else: self._fitterSteps.append(fitterStep) - fitterStep._setFitter(self) + fitterStep.setFitter(self) - def removeFitterStep(self, fitterStep : FitterStep): + def removeFitterStep(self, fitterStep: FitterStep): """ Remove fitterStep from Fitter. :param fitterStep: FitterStep to remove. Must not be initial config. @@ -168,7 +172,7 @@ def removeFitterStep(self, fitterStep : FitterStep): assert fitterStep is not self.getInitialFitterStepConfig() index = self._fitterSteps.index(fitterStep) self._fitterSteps.remove(fitterStep) - fitterStep._setFitter(None) + fitterStep.setFitter(None) if index >= len(self._fitterSteps): index = -1 return self._fitterSteps[index] @@ -179,7 +183,7 @@ def _clearFields(self): self._dataCoordinatesField = None self._fibreField = None self._mesh = [] # [dimension - 1] - self._dataHostLocationField = None # stored mesh location field in highest dimension mesh for all data and markers + self._dataHostLocationField = None # stored mesh location field in highest dimension mesh for all data, markers self._dataHostCoordinatesField = None # embedded field giving host coordinates at data location self._dataDeltaField = None # self._dataHostCoordinatesField - self._markerDataCoordinatesField self._dataErrorField = None # magnitude of _dataDeltaField @@ -220,7 +224,7 @@ def load(self): # get centre and scale of data coordinates to manage fitting tolerances and steps datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) minimums, maximums = evaluateFieldNodesetRange(self._dataCoordinatesField, datapoints) - self._dataCentre = [ 0.5*(minimums[c] + maximums[c]) for c in range(3) ] + self._dataCentre = [0.5*(minimums[c] + maximums[c]) for c in range(3)] self._dataScale = max((maximums[c] - minimums[c]) for c in range(3)) if self._diagnosticLevel > 0: print("Load data: data coordinates centre ", self._dataCentre) @@ -253,11 +257,13 @@ def _defineCommonMeshFields(self): print("Scaffoldfitter: dimension < 2. Invalid model?") return with ChangeManager(self._fieldmodule): - self._strainPenaltyField = findOrCreateFieldFiniteElement(self._fieldmodule, "strain_penalty", components_count=(9 if (dimension == 3) else 4)) - self._curvaturePenaltyField = findOrCreateFieldFiniteElement(self._fieldmodule, "curvature_penalty", components_count=(27 if (dimension == 3) else 8)) + self._strainPenaltyField = findOrCreateFieldFiniteElement( + self._fieldmodule, "strain_penalty", components_count=(9 if (dimension == 3) else 4)) + self._curvaturePenaltyField = findOrCreateFieldFiniteElement( + self._fieldmodule, "curvature_penalty", components_count=(27 if (dimension == 3) else 8)) activeMeshGroups = [] for defname in ["deform", "strain", "curvature"]: - activeMeshName = defname + "_active_group." + mesh.getName() + activeMeshName = defname + "_active_group." + meshName activeElementGroup = self._fieldmodule.findFieldByName(activeMeshName).castElementGroup() if not activeElementGroup.isValid(): activeElementGroup = self._fieldmodule.createFieldElementGroup(mesh) @@ -276,7 +282,7 @@ def _defineCommonMeshFields(self): element = elemIter.next() zeroValues = [0.0]*27 while element.isValid(): - result = element.merge(elementtemplate) + element.merge(elementtemplate) fieldcache.setElement(element) self._strainPenaltyField.assignReal(fieldcache, zeroValues) self._curvaturePenaltyField.assignReal(fieldcache, zeroValues) @@ -293,7 +299,7 @@ def getCurvaturePenaltyField(self): def _loadModel(self): result = self._region.readFile(self._zincModelFileName) assert result == RESULT_OK, "Failed to load model file" + str(self._zincModelFileName) - self._mesh = [ self._fieldmodule.findMeshByDimension(d + 1) for d in range(3) ] + self._mesh = [self._fieldmodule.findMeshByDimension(d + 1) for d in range(3)] self._discoverModelCoordinatesField() self._discoverFibreField() self._defineCommonMeshFields() @@ -311,14 +317,16 @@ def _defineCommonDataFields(self): with ChangeManager(self._fieldmodule): mesh = self.getHighestDimensionMesh() datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) - self._dataHostLocationField = findOrCreateFieldStoredMeshLocation(self._fieldmodule, mesh, "data_location_" + mesh.getName(), managed=False) - self._dataHostCoordinatesField = self._fieldmodule.createFieldEmbedded(self._modelCoordinatesField, self._dataHostLocationField) + self._dataHostLocationField = findOrCreateFieldStoredMeshLocation( + self._fieldmodule, mesh, "data_location_" + mesh.getName(), managed=False) + self._dataHostCoordinatesField = self._fieldmodule.createFieldEmbedded( + self._modelCoordinatesField, self._dataHostLocationField) self._dataHostCoordinatesField.setName(getUniqueFieldName(self._fieldmodule, "data_host_coordinates")) self._dataDeltaField = self._dataHostCoordinatesField - self._dataCoordinatesField self._dataDeltaField.setName(getUniqueFieldName(self._fieldmodule, "data_delta")) self._dataErrorField = self._fieldmodule.createFieldMagnitude(self._dataDeltaField) self._dataErrorField.setName(getUniqueFieldName(self._fieldmodule, "data_error")) - # store weights per-point so can maintain variable weights by for marker and data by group, dimension of host + # store weights per-point so can maintain variable weights for marker and data by group, dimension of host self._dataWeightField = findOrCreateFieldFiniteElement(self._fieldmodule, "data_weight", components_count=1) activeDataName = "active_data.datapoints" activeDataGroup = self._fieldmodule.findFieldByName(activeDataName).castNodeGroup() @@ -339,7 +347,7 @@ def _loadData(self): with ChangeManager(fieldmodule): # rename data groups to match model # future: match with annotation terms - modelGroupNames = [ group.getName() for group in getGroupList(self._fieldmodule) ] + modelGroupNames = [group.getName() for group in getGroupList(self._fieldmodule)] writeDiagnostics = self.getDiagnosticLevel() > 0 for dataGroup in getGroupList(fieldmodule): dataGroupName = dataGroup.getName() @@ -353,9 +361,11 @@ def _loadData(self): result = dataGroup.setName(modelGroupName) if result == RESULT_OK: if writeDiagnostics: - print("Load data: Data group '" + dataGroupName + "' found in model as '" + modelGroupName + "'. Renaming to match.") + print("Load data: Data group '" + dataGroupName + "' found in model as '" + + modelGroupName + "'. Renaming to match.") else: - print("Error: Load data: Data group '" + dataGroupName + "' found in model as '" + modelGroupName + "'. Renaming to match FAILED.") + print("Error: Load data: Data group '" + dataGroupName + "' found in model as '" + + modelGroupName + "'. Renaming to match FAILED.") if fieldmodule.findFieldByName(modelGroupName).isValid(): print(" Reason: field of that name already exists.") break @@ -391,7 +401,7 @@ def _loadData(self): assert result == RESULT_OK, "Failed to write nodes" buffer = buffer.replace(bytes("!#nodeset nodes", "utf-8"), bytes("!#nodeset datapoints", "utf-8")) sir = self._region.createStreaminformationRegion() - srm = sir.createStreamresourceMemoryBuffer(buffer) + sir.createStreamresourceMemoryBuffer(buffer) result = self._region.read(sir) assert result == RESULT_OK, "Failed to load nodes as datapoints" # transfer datapoints to self._region @@ -402,7 +412,7 @@ def _loadData(self): result, buffer = srm.getBuffer() assert result == RESULT_OK, "Failed to write datapoints" sir = self._region.createStreaminformationRegion() - srm = sir.createStreamresourceMemoryBuffer(buffer) + sir.createStreamresourceMemoryBuffer(buffer) result = self._region.read(sir) assert result == RESULT_OK, "Failed to load datapoints" self._discoverDataCoordinatesField() @@ -412,7 +422,7 @@ def run(self, endStep=None, modelFileNameStem=None): """ Run either all remaining fitter steps or up to specified end step. :param endStep: Last fitter step to run, or None to run all. - :param modelFilename: Filename stem for writing intermediate model files. + :param modelFileNameStem: File name stem for writing intermediate model files. :return: True if reloaded (so scene changed), False if not. """ if not endStep: @@ -437,7 +447,7 @@ def run(self, endStep=None, modelFileNameStem=None): def getDataCoordinatesField(self): return self._dataCoordinatesField - def setDataCoordinatesField(self, dataCoordinatesField : Field): + def setDataCoordinatesField(self, dataCoordinatesField: Field): if dataCoordinatesField == self._dataCoordinatesField: return finiteElementField = dataCoordinatesField.castFiniteElement() @@ -458,7 +468,7 @@ def _discoverDataCoordinatesField(self): field = None if self._dataCoordinatesFieldName: field = self._fieldmodule.findFieldByName(self._dataCoordinatesFieldName) - if not ((field) and field.isValid()): + if not (field and field.isValid()): datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) datapoint = datapoints.createNodeiterator().next() if datapoint.isValid(): @@ -467,7 +477,8 @@ def _discoverDataCoordinatesField(self): fielditer = self._fieldmodule.createFielditerator() field = fielditer.next() while field.isValid(): - if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and (field.castFiniteElement().isValid()): + if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and \ + (field.castFiniteElement().isValid()): if field.isDefinedAtLocation(fieldcache): break field = fielditer.next() @@ -478,7 +489,7 @@ def _discoverDataCoordinatesField(self): def getMarkerGroup(self): return self._markerGroup - def setMarkerGroup(self, markerGroup : Field): + def setMarkerGroup(self, markerGroup: Field): self._markerGroup = None self._markerGroupName = None self._markerNodeGroup = None @@ -536,7 +547,7 @@ def setMarkerGroup(self, markerGroup : Field): self._markerDataGroup = None self._calculateMarkerDataLocations() - def assignDataWeights(self, fitterStepFit : FitterStep): + def assignDataWeights(self, fitterStepFit: FitterStepFit): """ Assign values of the weight field for all data and marker points. """ @@ -550,10 +561,10 @@ def assignDataWeights(self, fitterStepFit : FitterStep): dataGroup = self.getGroupDataProjectionNodesetGroup(group) if not dataGroup: continue - meshGroup = self.getGroupDataProjectionMeshGroup(group) - dimension = meshGroup.getDimension() + # meshGroup = self.getGroupDataProjectionMeshGroup(group) + # dimension = meshGroup.getDimension() dataWeight = fitterStepFit.getGroupDataWeight(groupName)[0] - #print("group", groupName, "dimension", dimension, "weight", dataWeight) + # print("group", groupName, "dimension", dimension, "weight", dataWeight) fieldassignment = self._dataWeightField.createFieldassignment( self._fieldmodule.createFieldConstant(dataWeight)) fieldassignment.setNodeset(dataGroup) @@ -562,7 +573,7 @@ def assignDataWeights(self, fitterStepFit : FitterStep): print("Incomplete assignment of data weight for group", groupName, "Result", result) if self._markerDataLocationGroup: markerWeight = fitterStepFit.getGroupDataWeight(self._markerGroupName)[0] - #print("marker weight", markerWeight) + # print("marker weight", markerWeight) fieldassignment = self._dataWeightField.createFieldassignment( self._fieldmodule.createFieldConstant(markerWeight)) fieldassignment.setNodeset(self._markerDataLocationGroup) @@ -571,7 +582,7 @@ def assignDataWeights(self, fitterStepFit : FitterStep): print('Incomplete assignment of marker data weight', result) del fieldassignment - def assignDeformationPenalties(self, fitterStepFit : FitterStep): + def assignDeformationPenalties(self, fitterStepFit: FitterStepFit): """ Assign per-element strain and curvature penalty values and build groups of elements for which they are non-zero. @@ -601,13 +612,16 @@ def assignDeformationPenalties(self, fitterStepFit : FitterStep): else: meshGroup = None groupName = None - groupStrainPenalty, setLocally, inheritable = fitterStepFit.getGroupStrainPenalty(groupName, strainComponents) + groupStrainPenalty, setLocally, inheritable = \ + fitterStepFit.getGroupStrainPenalty(groupName, strainComponents) groupStrainPenaltyNonZero = any((s > 0.0) for s in groupStrainPenalty) - groupStrainSet = setLocally or ((setLocally == False) and inheritable) - groupCurvaturePenalty, setLocally, inheritable = fitterStepFit.getGroupCurvaturePenalty(groupName, curvatureComponents) + groupStrainSet = setLocally or ((setLocally is False) and inheritable) + groupCurvaturePenalty, setLocally, inheritable = \ + fitterStepFit.getGroupCurvaturePenalty(groupName, curvatureComponents) groupCurvaturePenaltyNonZero = any((s > 0.0) for s in groupCurvaturePenalty) - groupCurvatureSet = setLocally or ((setLocally == False) and inheritable) - groups.append( (group, groupName, meshGroup, groupStrainPenalty, groupStrainPenaltyNonZero, groupStrainSet, groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet) ) + groupCurvatureSet = setLocally or ((setLocally is False) and inheritable) + groups.append((group, groupName, meshGroup, groupStrainPenalty, groupStrainPenaltyNonZero, groupStrainSet, + groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet)) with ChangeManager(self._fieldmodule): self._deformActiveMeshGroup.removeAllElements() self._strainActiveMeshGroup.removeAllElements() @@ -621,7 +635,8 @@ def assignDeformationPenalties(self, fitterStepFit : FitterStep): strainPenaltyNonZero = False curvaturePenalty = None curvaturePenaltyNonZero = False - for (group, groupName, meshGroup, groupStrainPenalty, groupStrainPenaltyNonZero, groupStrainSet, groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet) in groups: + for (group, groupName, meshGroup, groupStrainPenalty, groupStrainPenaltyNonZero, groupStrainSet, + groupCurvaturePenalty, groupCurvaturePenaltyNonZero, groupCurvatureSet) in groups: if (not group) or meshGroup.containsElement(element): if (not strainPenalty) and (groupStrainSet or (not group)): strainPenalty = groupStrainPenalty @@ -696,7 +711,8 @@ def _calculateMarkerDataLocations(self): """ self._markerDataLocationGroupField = None self._markerDataLocationGroup = None - if not (self._markerDataGroup and self._markerDataNameField and self._markerNodeGroup and self._markerLocationField and self._markerNameField): + if not (self._markerDataGroup and self._markerDataNameField and self._markerNodeGroup and + self._markerLocationField and self._markerNameField): return markerPrefix = self._markerGroupName @@ -707,7 +723,8 @@ def _calculateMarkerDataLocations(self): with ChangeManager(self._fieldmodule): fieldcache = self._fieldmodule.createFieldcache() self._markerDataLocationGroupField = self._fieldmodule.createFieldNodeGroup(datapoints) - self._markerDataLocationGroupField.setName(getUniqueFieldName(self._fieldmodule, markerPrefix + "_data_location_group")) + self._markerDataLocationGroupField.setName( + getUniqueFieldName(self._fieldmodule, markerPrefix + "_data_location_group")) self._markerDataLocationGroup = self._markerDataLocationGroupField.getNodesetGroup() nodetemplate = self._markerDataGroup.createNodetemplate() nodetemplate.defineField(self._dataHostLocationField) @@ -725,9 +742,11 @@ def _calculateMarkerDataLocations(self): fieldcache.setNode(datapoint) name = self._markerDataNameField.evaluateString(fieldcache) # if this is the only datapoint with name: - if name and findNodeWithName(self._markerDataGroup, self._markerDataNameField, name, ignore_case=True, strip_whitespace=True): + if name and findNodeWithName(self._markerDataGroup, self._markerDataNameField, name, ignore_case=True, + strip_whitespace=True): result, dataCoordinates = self._markerDataCoordinatesField.evaluateReal(fieldcache, componentsCount) - node = findNodeWithName(self._markerNodeGroup, self._markerNameField, name, ignore_case=True, strip_whitespace=True) + node = findNodeWithName(self._markerNodeGroup, self._markerNameField, name, ignore_case=True, + strip_whitespace=True) if (result == RESULT_OK) and node: fieldcache.setNode(node) element, xi = self._markerLocationField.evaluateMeshLocation(fieldcache, meshDimension) @@ -736,7 +755,7 @@ def _calculateMarkerDataLocations(self): fieldcache.setNode(datapoint) self._dataHostLocationField.assignMeshLocation(fieldcache, element, xi) if defineDataCoordinates: - result = self._dataCoordinatesField.assignReal(fieldcache, dataCoordinates) + self._dataCoordinatesField.assignReal(fieldcache, dataCoordinates) self._markerDataLocationGroup.addNode(datapoint) datapoint = datapointIter.next() del fieldcache @@ -746,9 +765,11 @@ def _calculateMarkerDataLocations(self): markerNodeGroupSize = self._markerNodeGroup.getSize() if self.getDiagnosticLevel() > 0: if markerDataLocationGroupSize < markerDataGroupSize: - print("Warning: Only " + str(markerDataLocationGroupSize) + " of " + str(markerDataGroupSize) + " marker data points have model locations") + print("Warning: Only " + str(markerDataLocationGroupSize) + + " of " + str(markerDataGroupSize) + " marker data points have model locations") if markerDataLocationGroupSize < markerNodeGroupSize: - print("Warning: Only " + str(markerDataLocationGroupSize) + " of " + str(markerNodeGroupSize) + " marker model locations used") + print("Warning: Only " + str(markerDataLocationGroupSize) + + " of " + str(markerNodeGroupSize) + " marker model locations used") def _discoverMarkerGroup(self): self._markerGroup = None @@ -756,7 +777,8 @@ def _discoverMarkerGroup(self): self._markerLocationField = None self._markerNameField = None self._markerCoordinatesField = None - markerGroup = self._fieldmodule.findFieldByName(self._markerGroupName if self._markerGroupName else "marker").castGroup() + markerGroupName = self._markerGroupName if self._markerGroupName else "marker" + markerGroup = self._fieldmodule.findFieldByName(markerGroupName).castGroup() if not markerGroup.isValid(): markerGroup = None self.setMarkerGroup(markerGroup) @@ -765,8 +787,10 @@ def _updateMarkerCoordinatesField(self): if self._modelCoordinatesField and self._markerLocationField: with ChangeManager(self._fieldmodule): markerPrefix = self._markerGroup.getName() - self._markerCoordinatesField = self._fieldmodule.createFieldEmbedded(self._modelCoordinatesField, self._markerLocationField) - self._markerCoordinatesField.setName(getUniqueFieldName(self._fieldmodule, markerPrefix + "_coordinates")) + self._markerCoordinatesField = \ + self._fieldmodule.createFieldEmbedded(self._modelCoordinatesField, self._markerLocationField) + self._markerCoordinatesField.setName( + getUniqueFieldName(self._fieldmodule, markerPrefix + "_coordinates")) else: self._markerCoordinatesField = None @@ -776,7 +800,7 @@ def getModelCoordinatesField(self): def getModelReferenceCoordinatesField(self): return self._modelReferenceCoordinatesField - def setModelCoordinatesField(self, modelCoordinatesField : Field): + def setModelCoordinatesField(self, modelCoordinatesField: Field): if modelCoordinatesField == self._modelCoordinatesField: return finiteElementField = modelCoordinatesField.castFiniteElement() @@ -785,7 +809,8 @@ def setModelCoordinatesField(self, modelCoordinatesField : Field): self._modelCoordinatesFieldName = modelCoordinatesField.getName() modelReferenceCoordinatesFieldName = "reference_" + self._modelCoordinatesField.getName() orphanFieldByName(self._fieldmodule, modelReferenceCoordinatesFieldName) - self._modelReferenceCoordinatesField = createFieldFiniteElementClone(self._modelCoordinatesField, modelReferenceCoordinatesFieldName) + self._modelReferenceCoordinatesField = \ + createFieldFiniteElementClone(self._modelCoordinatesField, modelReferenceCoordinatesFieldName) self._defineCommonDataFields() self._updateMarkerCoordinatesField() @@ -810,7 +835,8 @@ def _discoverModelCoordinatesField(self): fielditer = self._fieldmodule.createFielditerator() field = fielditer.next() while field.isValid(): - if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and (field.castFiniteElement().isValid()): + if field.isTypeCoordinate() and (field.getNumberOfComponents() == 3) and \ + (field.castFiniteElement().isValid()): if field.isDefinedAtLocation(fieldcache): break field = fielditer.next() @@ -822,16 +848,17 @@ def _discoverModelCoordinatesField(self): def getFibreField(self): return self._fibreField - def setFibreField(self, fibreField : Field): + def setFibreField(self, fibreField: Field): """ Set field used to orient strain and curvature penalties relative to element. :param fibreField: Fibre angles field available on elements, or None to use global x, y, z axes. """ - assert (fibreField is None) or ((fibreField.getValueType() == Field.VALUE_TYPE_REAL) and \ - (fibreField.getNumberOfComponents() <= 3)), "Scaffoldfitter: Invalid fibre field" + assert (fibreField is None) or \ + ((fibreField.getValueType() == Field.VALUE_TYPE_REAL) and (fibreField.getNumberOfComponents() <= 3)), \ + "Scaffoldfitter: Invalid fibre field" self._fibreField = fibreField - self._fibreFieldName = fibreField.getName() if (fibreField) else None + self._fibreFieldName = fibreField.getName() if fibreField else None def _discoverFibreField(self): """ @@ -865,12 +892,14 @@ def _defineDataProjectionFields(self): field.setName(getUniqueFieldName(self._fieldmodule, "data_projection_group_" + mesh.getName())) self._dataProjectionNodeGroupFields.append(field) self._dataProjectionNodesetGroups.append(field.getNodesetGroup()) - self._dataProjectionDirectionField = findOrCreateFieldFiniteElement(self._fieldmodule, "data_projection_direction", - components_count = 3, component_names = [ "x", "y", "z" ]) + self._dataProjectionDirectionField = findOrCreateFieldFiniteElement( + self._fieldmodule, "data_projection_direction", components_count=3, component_names=["x", "y", "z"]) - def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup, meshLocation, activeFitterStepConfig : FitterStepConfig): + def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup, meshLocation, + activeFitterStepConfig: FitterStepConfig): """ Project data points for group. Assumes called while ChangeManager is active for fieldmodule. + :param fieldcache: Fieldcache for zinc field evaluations in region. :param group: The FieldGroup being fitted (parent of dataGroup, meshGroup). :param dataGroup: Nodeset group containing data points to project. :param meshGroup: MeshGroup containing surfaces/lines to project onto. @@ -891,26 +920,30 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup, if result != RESULT_OK: print("Error: Centre Groups projection failed to get mean coordinates of data for group " + groupName) return - #print("Centre Groups dataCentre", dataCentre) + # print("Centre Groups dataCentre", dataCentre) # get geometric centre of meshGroup - meshGroupCoordinatesIntegral = self._fieldmodule.createFieldMeshIntegral(self._modelCoordinatesField, self._modelCoordinatesField, meshGroup) + meshGroupCoordinatesIntegral = self._fieldmodule.createFieldMeshIntegral( + self._modelCoordinatesField, self._modelCoordinatesField, meshGroup) meshGroupCoordinatesIntegral.setNumbersOfPoints([3]) - meshGroupArea = self._fieldmodule.createFieldMeshIntegral(self._fieldmodule.createFieldConstant([1.0]), self._modelCoordinatesField, meshGroup) + meshGroupArea = self._fieldmodule.createFieldMeshIntegral( + self._fieldmodule.createFieldConstant([1.0]), self._modelCoordinatesField, meshGroup) meshGroupArea.setNumbersOfPoints([3]) - result1, coordinatesIntegral = meshGroupCoordinatesIntegral.evaluateReal(fieldcache, self._modelCoordinatesField.getNumberOfComponents()) + result1, coordinatesIntegral = meshGroupCoordinatesIntegral.evaluateReal( + fieldcache, self._modelCoordinatesField.getNumberOfComponents()) result2, area = meshGroupArea.evaluateReal(fieldcache, 1) if (result1 != RESULT_OK) or (result2 != RESULT_OK) or (area <= 0.0): print("Error: Centre Groups projection failed to get mean coordinates of mesh for group " + groupName) return - meshCentre = [ s/area for s in coordinatesIntegral ] - #print("Centre Groups meshCentre", meshCentre) + meshCentre = [s/area for s in coordinatesIntegral] + # print("Centre Groups meshCentre", meshCentre) # offset dataCoordinates to make dataCentre coincide with meshCentre dataCoordinates = dataCoordinates + self._fieldmodule.createFieldConstant(sub(meshCentre, dataCentre)) # find nearest locations on 1-D or 2-D feature but store on highest dimension mesh highestDimensionMesh = self.getHighestDimensionMesh() highestDimension = highestDimensionMesh.getDimension() - findLocation = self._fieldmodule.createFieldFindMeshLocation(dataCoordinates, self._modelCoordinatesField, highestDimensionMesh) + findLocation = self._fieldmodule.createFieldFindMeshLocation(dataCoordinates, self._modelCoordinatesField, + highestDimensionMesh) assert RESULT_OK == findLocation.setSearchMesh(meshGroup) findLocation.setSearchMode(FieldFindMeshLocation.SEARCH_MODE_NEAREST) nodeIter = dataGroup.createNodeiterator() @@ -924,18 +957,20 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup, element, xi = findLocation.evaluateMeshLocation(fieldcache, highestDimension) if element.isValid(): result = meshLocation.assignMeshLocation(fieldcache, element, xi) - assert result == RESULT_OK, "Error: Failed to assign data projection mesh location for group " + groupName + assert result == RESULT_OK, \ + "Error: Failed to assign data projection mesh location for group " + groupName dataProjectionNodesetGroup.addNode(node) node = nodeIter.next() pointsProjected = dataProjectionNodesetGroup.getSize() - sizeBefore if pointsProjected < dataGroup.getSize(): if self.getDiagnosticLevel() > 0: - print("Warning: Only " + str(pointsProjected) + " of " + str(dataGroup.getSize()) + " data points projected for group " + groupName) + print("Warning: Only " + str(pointsProjected) + " of " + str(dataGroup.getSize()) + + " data points projected for group " + groupName) # add to active group self._activeDataNodesetGroup.addNodesConditional(self._dataProjectionNodeGroupFields[dimension - 1]) return - def getGroupDataProjectionNodesetGroup(self, group : FieldGroup): + def getGroupDataProjectionNodesetGroup(self, group: FieldGroup): """ :return: Data NodesetGroup containing points for projection of group, otherwise None. """ @@ -947,7 +982,7 @@ def getGroupDataProjectionNodesetGroup(self, group : FieldGroup): return dataGroup return None - def getGroupDataProjectionMeshGroup(self, group : FieldGroup): + def getGroupDataProjectionMeshGroup(self, group: FieldGroup): """ :return: 2D if not 1D meshGroup containing elements for projecting data in group, otherwise None. """ @@ -959,7 +994,7 @@ def getGroupDataProjectionMeshGroup(self, group : FieldGroup): return meshGroup return None - def calculateDataProjections(self, fitterStep : FitterStep): + def calculateDataProjections(self, fitterStep: FitterStep): """ Find projections of datapoints' coordinates onto model coordinates, by groups i.e. from datapoints group onto matching 2-D or 1-D mesh group. @@ -973,7 +1008,6 @@ def calculateDataProjections(self, fitterStep : FitterStep): if self._markerDataLocationGroupField: self._activeDataNodesetGroup.addNodesConditional(self._markerDataLocationGroupField) - findMeshLocation = None datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) fieldcache = self._fieldmodule.createFieldcache() for d in range(2): @@ -990,13 +1024,13 @@ def calculateDataProjections(self, fitterStep : FitterStep): if group != self._markerGroup: print("Warning: Cannot project data for group " + groupName + " as no matching mesh group") continue - dimension = meshGroup.getDimension() if groupName not in self._dataProjectionGroupNames: self._dataProjectionGroupNames.append(groupName) # so only define mesh location, or warn once fieldcache.setNode(dataGroup.createNodeiterator().next()) if not self._dataCoordinatesField.isDefinedAtLocation(fieldcache): if self.getDiagnosticLevel() > 0: - print("Warning: Cannot project data for group " + groupName + " as field " + self._dataCoordinatesField.getName() + " is not defined on data") + print("Warning: Cannot project data for group " + groupName + + " as field " + self._dataCoordinatesField.getName() + " is not defined on data") continue # define self._dataHostLocationField and self._dataProjectionDirectionField on data Group: nodetemplate = datapoints.createNodetemplate() @@ -1007,11 +1041,12 @@ def calculateDataProjections(self, fitterStep : FitterStep): nodeIter = dataGroup.createNodeiterator() node = nodeIter.next() while node.isValid(): - result = node.merge(nodetemplate) - #print("node",node.getIdentifier(),"result",result) + node.merge(nodetemplate) + # print("node",node.getIdentifier(),"result",result) node = nodeIter.next() del nodetemplate - self.calculateGroupDataProjections(fieldcache, group, dataGroup, meshGroup, self._dataHostLocationField, activeFitterStepConfig) + self.calculateGroupDataProjections(fieldcache, group, dataGroup, meshGroup, self._dataHostLocationField, + activeFitterStepConfig) # Store data projection directions for dimension in range(1, 3): @@ -1021,19 +1056,21 @@ def calculateDataProjections(self, fitterStep : FitterStep): self._fieldmodule.createFieldNormalise(self._dataDeltaField)) fieldassignment.setNodeset(nodesetGroup) result = fieldassignment.assign() - assert result in [ RESULT_OK, RESULT_WARNING_PART_DONE ], \ + assert result in [RESULT_OK, RESULT_WARNING_PART_DONE], \ "Error: Failed to assign data projection directions for dimension " + str(dimension) del fieldassignment if self.getDiagnosticLevel() > 0: # Warn about unprojected points unprojectedDatapoints = self._fieldmodule.createFieldNodeGroup(datapoints).getNodesetGroup() - unprojectedDatapoints.addNodesConditional(self._fieldmodule.createFieldIsDefined(self._dataCoordinatesField)) + unprojectedDatapoints.addNodesConditional( + self._fieldmodule.createFieldIsDefined(self._dataCoordinatesField)) for d in range(2): unprojectedDatapoints.removeNodesConditional(self._dataProjectionNodeGroupFields[d]) unprojectedCount = unprojectedDatapoints.getSize() if unprojectedCount > 0: - print("Warning: " + str(unprojectedCount) + " data points with data coordinates have not been projected") + print("Warning: " + str(unprojectedCount) + + " data points with data coordinates have not been projected") del unprojectedDatapoints # remove temporary objects before ChangeManager exits @@ -1085,7 +1122,10 @@ def getMarkerDataLocationNodesetGroup(self): return self._markerDataLocationGroup def getMarkerDataLocationField(self): - return self._markerDataLocationField + """ + Same as for all other data points. + """ + return self._dataHostLocationField def getContext(self): return self._context @@ -1119,10 +1159,11 @@ def getHighestDimensionMesh(self): return mesh return None - def evaluateNodeGroupMeanCoordinates(self, groupName, coordinatesFieldName, isData = False): + def evaluateNodeGroupMeanCoordinates(self, groupName, coordinatesFieldName, isData=False): group = self._fieldmodule.findFieldByName(groupName).castGroup() assert group.isValid() - nodeset = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS if isData else Field.DOMAIN_TYPE_NODES) + nodeset = self._fieldmodule.findNodesetByFieldDomainType( + Field.DOMAIN_TYPE_DATAPOINTS if isData else Field.DOMAIN_TYPE_NODES) nodesetGroup = group.getFieldNodeGroup(nodeset).getNodesetGroup() assert nodesetGroup.isValid() coordinates = self._fieldmodule.findFieldByName(coordinatesFieldName) @@ -1133,7 +1174,8 @@ def getDiagnosticLevel(self): def setDiagnosticLevel(self, diagnosticLevel): """ - :param diagnosticLevel: 0 = no diagnostic messages. 1 = Information and warning messages. 2 = Also optimisation reports. + :param diagnosticLevel: 0 = no diagnostic messages. 1 = Information and warning messages. + 2 = Also optimisation reports. """ assert diagnosticLevel >= 0 self._diagnosticLevel = diagnosticLevel @@ -1144,19 +1186,31 @@ def updateModelReferenceCoordinates(self): def writeModel(self, modelFileName): """ Write model nodes and elements with model coordinates field to file. + Note: Output field name is prefixed with "fitted ". """ - sir = self._region.createStreaminformationRegion() - sir.setRecursionMode(sir.RECURSION_MODE_OFF) - srf = sir.createStreamresourceFile(modelFileName) - sir.setResourceFieldNames(srf, [self._modelCoordinatesFieldName]) - sir.setResourceDomainTypes(srf, Field.DOMAIN_TYPE_NODES | Field.DOMAIN_TYPE_MESH1D | Field.DOMAIN_TYPE_MESH2D | Field.DOMAIN_TYPE_MESH3D) - result = self._region.write(sir) - #loggerMessageCount = self._logger.getNumberOfMessages() - #if loggerMessageCount > 0: - # for i in range(1, loggerMessageCount + 1): - # print(self._logger.getMessageTypeAtIndex(i), self._logger.getMessageTextAtIndex(i)) - # self._logger.removeAllMessages() - assert result == RESULT_OK + with ChangeManager(self._fieldmodule): + # temporarily rename model coordinates field to prefix with "fitted " + # so can be used along with original coordinates in later steps + outputCoordinatesFieldName = "fitted " + self._modelCoordinatesFieldName; + self._modelCoordinatesField.setName(outputCoordinatesFieldName) + + sir = self._region.createStreaminformationRegion() + sir.setRecursionMode(sir.RECURSION_MODE_OFF) + srf = sir.createStreamresourceFile(modelFileName) + sir.setResourceFieldNames(srf, [outputCoordinatesFieldName]) + sir.setResourceDomainTypes(srf, Field.DOMAIN_TYPE_NODES | + Field.DOMAIN_TYPE_MESH1D | Field.DOMAIN_TYPE_MESH2D | Field.DOMAIN_TYPE_MESH3D) + result = self._region.write(sir) + # loggerMessageCount = self._logger.getNumberOfMessages() + # if loggerMessageCount > 0: + # for i in range(1, loggerMessageCount + 1): + # print(self._logger.getMessageTypeAtIndex(i), self._logger.getMessageTextAtIndex(i)) + # self._logger.removeAllMessages() + + # restore original name + self._modelCoordinatesField.setName(self._modelCoordinatesFieldName) + + assert result == RESULT_OK def writeData(self, fileName): sir = self._region.createStreaminformationRegion() diff --git a/src/scaffoldfitter/fitterstep.py b/src/scaffoldfitter/fitterstep.py index 17e4aa6..4bd4a4f 100644 --- a/src/scaffoldfitter/fitterstep.py +++ b/src/scaffoldfitter/fitterstep.py @@ -1,6 +1,7 @@ """ Base class for fitter steps. """ +import abc class FitterStep: @@ -33,13 +34,18 @@ def getDefaultGroupName(cls): def getFitter(self): return self._fitter - def _setFitter(self, fitter): - ''' + def setFitter(self, fitter): + """ Should only be called by Fitter when adding or removing from it. - ''' + """ self._fitter = fitter - def decodeSettingsJSONDict(self, dctIn : dict): + @classmethod + @abc.abstractmethod + def getJsonTypeId(cls): + pass + + def decodeSettingsJSONDict(self, dctIn: dict): """ Decode definition of step from JSON dict. """ @@ -55,8 +61,8 @@ def encodeSettingsJSONDict(self) -> dict: :return: Settings in a dict ready for passing to json.dump. """ return { - self.getJsonTypeId() : True, - "groupSettings" : self._groupSettings + self.getJsonTypeId(): True, + "groupSettings": self._groupSettings } def getGroupSettingsNames(self): @@ -65,7 +71,7 @@ def getGroupSettingsNames(self): """ return list(self._groupSettings.keys()) - def clearGroupSetting(self, groupName : str, settingName : str): + def clearGroupSetting(self, groupName: str, settingName: str): """ Clear setting for group, removing group settings dict if empty. :param groupName: Exact model group name, or None for default group. @@ -79,7 +85,7 @@ def clearGroupSetting(self, groupName : str, settingName : str): if len(groupSettings) == 0: self._groupSettings.pop(groupName) - def _getInheritedGroupSetting(self, groupName : str, settingName : str): + def _getInheritedGroupSetting(self, groupName: str, settingName: str): """ :param groupName: Exact model group name, or None for default group. :param settingName: Exact setting name. @@ -98,7 +104,7 @@ def _getInheritedGroupSetting(self, groupName : str, settingName : str): if inheritedValue != "": return inheritedValue - def getGroupSetting(self, groupName : str, settingName : str, defaultValue): + def getGroupSetting(self, groupName: str, settingName: str, defaultValue): """ Get group setting of supplied name, with reset & inherit ability. :param groupName: Exact model group name, or None for default group. @@ -130,7 +136,7 @@ def getGroupSetting(self, groupName : str, settingName : str, defaultValue): value = defaultValue return value, setLocally, inheritable - def setGroupSetting(self, groupName : str, settingName : str, value): + def setGroupSetting(self, groupName: str, settingName: str, value): """ Set value of setting or None to reset to default. :param groupName: Exact model group name, or None for default group. diff --git a/src/scaffoldfitter/fitterstepalign.py b/src/scaffoldfitter/fitterstepalign.py index ee17220..1357f2c 100644 --- a/src/scaffoldfitter/fitterstepalign.py +++ b/src/scaffoldfitter/fitterstepalign.py @@ -5,7 +5,7 @@ import copy import math from opencmiss.maths.vectorops import div, euler_to_rotation_matrix, matrix_vector_mult, mult, sub -from opencmiss.utils.zinc.field import assignFieldParameters, get_group_list, create_field_euler_angles_rotation_matrix +from opencmiss.utils.zinc.field import get_group_list, create_field_euler_angles_rotation_matrix from opencmiss.utils.zinc.finiteelement import evaluate_field_nodeset_mean, getNodeNameCentres from opencmiss.utils.zinc.general import ChangeManager from opencmiss.zinc.element import Mesh @@ -15,8 +15,8 @@ from scaffoldfitter.fitterstep import FitterStep -def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_value=1.0, \ - translation_offsets=None, translation_scale_factor=1.0): +def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_value=1.0, + translation_offsets=None, translation_scale_factor=1.0): """ Create constant fields for rotation, scale and translation containing the supplied values, plus the transformed coordinates applying them in the supplied order. @@ -44,11 +44,13 @@ def createFieldsTransformations(coordinates: Field, rotation_angles=None, scale_ translation = fieldmodule.createFieldConstant(translation_offsets) rotation_matrix = create_field_euler_angles_rotation_matrix(fieldmodule, rotation) rotated_coordinates = fieldmodule.createFieldMatrixMultiply(components_count, rotation_matrix, coordinates) - transformed_coordinates = rotated_coordinates*scale + (translation if (translation_scale_factor == 1.0) else \ - translation*fieldmodule.createFieldConstant([ translation_scale_factor ]*components_count)) + transformed_coordinates = rotated_coordinates*scale + ( + translation if (translation_scale_factor == 1.0) else + translation*fieldmodule.createFieldConstant([translation_scale_factor]*components_count)) assert transformed_coordinates.isValid() return transformed_coordinates, rotation, scale, translation + class FitterStepAlign(FitterStep): _jsonTypeId = "_FitterStepAlign" @@ -57,15 +59,15 @@ def __init__(self): super(FitterStepAlign, self).__init__() self._alignGroups = False self._alignMarkers = False - self._rotation = [ 0.0, 0.0, 0.0 ] + self._rotation = [0.0, 0.0, 0.0] self._scale = 1.0 - self._translation = [ 0.0, 0.0, 0.0 ] + self._translation = [0.0, 0.0, 0.0] @classmethod def getJsonTypeId(cls): return cls._jsonTypeId - def decodeSettingsJSONDict(self, dctIn : dict): + def decodeSettingsJSONDict(self, dctIn: dict): """ Decode definition of step from JSON dict. """ @@ -86,11 +88,11 @@ def encodeSettingsJSONDict(self) -> dict: """ dct = super().encodeSettingsJSONDict() dct.update({ - "alignGroups" : self._alignGroups, - "alignMarkers" : self._alignMarkers, - "rotation" : self._rotation, - "scale" : self._scale, - "translation" : self._translation + "alignGroups": self._alignGroups, + "alignMarkers": self._alignMarkers, + "rotation": self._rotation, + "scale": self._scale, + "translation": self._translation }) return dct @@ -177,14 +179,14 @@ def run(self, modelFileNameStem=None): assert modelCoordinates, "Align: Missing model coordinates" if self._alignGroups or self._alignMarkers: self._doAutoAlign() - fieldmodule = self._fitter._fieldmodule + fieldmodule = self._fitter.getFieldmodule() with ChangeManager(fieldmodule): # rotate, scale and translate model modelCoordinatesTransformed = createFieldsTransformations( modelCoordinates, self._rotation, self._scale, self._translation)[0] - fieldassignment = self._fitter._modelCoordinatesField.createFieldassignment(modelCoordinatesTransformed) + fieldassignment = self._fitter.getModelCoordinatesField().createFieldassignment(modelCoordinatesTransformed) result = fieldassignment.assign() - assert result in [ RESULT_OK, RESULT_WARNING_PART_DONE ], "Align: Failed to transform model" + assert result in [RESULT_OK, RESULT_WARNING_PART_DONE], "Align: Failed to transform model" self._fitter.updateModelReferenceCoordinates() del fieldassignment del modelCoordinatesTransformed @@ -198,12 +200,10 @@ def _doAutoAlign(self): Perform auto alignment to groups and/or markers. """ modelCoordinates = self._fitter.getModelCoordinatesField() - componentsCount = modelCoordinates.getNumberOfComponents() pointMap = {} # dict group/marker name -> (modelCoordinates, dataCoordinates) if self._alignGroups: - fieldmodule = self._fitter._fieldmodule - datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + fieldmodule = self._fitter.getFieldmodule() dataCoordinates = self._fitter.getDataCoordinatesField() groups = get_group_list(fieldmodule) with ChangeManager(fieldmodule): @@ -220,20 +220,21 @@ def _doAutoAlign(self): coordinates_integral = evaluate_field_mesh_integral(modelCoordinates, modelCoordinates, meshGroup) mass = evaluate_field_mesh_integral(one, modelCoordinates, meshGroup) meanModelCoordinates = div(coordinates_integral, mass) - pointMap[groupName] = ( meanModelCoordinates, meanDataCoordinates ) + pointMap[groupName] = (meanModelCoordinates, meanDataCoordinates) del one if self._alignMarkers: markerGroup = self._fitter.getMarkerGroup() assert markerGroup, "Align: No marker group to align with" - markerPrefix = markerGroup.getName() markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields() - assert markerNodeGroup and markerCoordinates and markerName, "Align: No marker group, coordinates or name fields" + assert markerNodeGroup and markerCoordinates and markerName, \ + "Align: No marker group, coordinates or name fields" modelMarkers = getNodeNameCentres(markerNodeGroup, markerCoordinates, markerName) markerDataGroup, markerDataCoordinates, markerDataName = self._fitter.getMarkerDataFields() - assert markerDataGroup and markerDataCoordinates and markerDataName, "Align: No marker data group, coordinates or name fields" + assert markerDataGroup and markerDataCoordinates and markerDataName, \ + "Align: No marker data group, coordinates or name fields" dataMarkers = getNodeNameCentres(markerDataGroup, markerDataCoordinates, markerDataName) # match model and data markers, warn of unmatched markers @@ -243,9 +244,10 @@ def _doAutoAlign(self): matchName = modelName.strip().casefold() for dataName in dataMarkers: if dataName.strip().casefold() == matchName: - pointMap[modelName] = ( modelMarkers[modelName], dataMarkers[dataName] ) + pointMap[modelName] = (modelMarkers[modelName], dataMarkers[dataName]) if writeDiagnostics: - print("Align: Model marker '" + modelName + "' found in data" + (" as '" + dataName +"'" if (dataName != modelName) else "")) + print("Align: Model marker '" + modelName + "' found in data" + + (" as '" + dataName + "'" if (dataName != modelName) else "")) dataMarkers.pop(dataName) break else: @@ -265,7 +267,7 @@ def _optimiseAlignment(self, pointMap): :param pointMap: dict name -> (modelCoordinates, dataCoordinates) """ assert len(pointMap) >= 3, "Align: Only " + str(len(pointMap)) + " points - need at least 3" - region = self._fitter._context.createRegion() + region = self._fitter.getContext().createRegion() fieldmodule = region.getFieldmodule() with ChangeManager(fieldmodule): @@ -281,7 +283,7 @@ def _optimiseAlignment(self, pointMap): datasum = [0.0, 0.0, 0.0] modelMin = copy.deepcopy(list(pointMap.values())[0][0]) modelMax = copy.deepcopy(list(pointMap.values())[0][0]) - dataMin = copy.deepcopy(list(pointMap.values())[0][1]) + dataMin = copy.deepcopy(list(pointMap.values())[0][1]) dataMax = copy.deepcopy(list(pointMap.values())[0][1]) for name, positions in pointMap.items(): modelx = positions[0] @@ -292,7 +294,8 @@ def _optimiseAlignment(self, pointMap): fieldcache.setNode(node) result1 = modelCoordinates.assignReal(fieldcache, positions[0]) result2 = dataCoordinates.assignReal(fieldcache, positions[1]) - assert (result1 == RESULT_OK) and (result2 == RESULT_OK), "Align: Failed to set up data for alignment to markers optimisation" + assert (result1 == RESULT_OK) and (result2 == RESULT_OK), \ + "Align: Failed to set up data for alignment to markers optimisation" for c in range(3): modelMin[c] = min(modelx[c], modelMin[c]) modelMax[c] = max(modelx[c], modelMax[c]) @@ -314,8 +317,8 @@ def _optimiseAlignment(self, pointMap): translationScaleFactor = 1.0 first = True fieldcache = fieldmodule.createFieldcache() - modelCoordinatesTransformed, rotation, scale, translation = createFieldsTransformations(modelCoordinates, - scale_value=scaleFactor) + modelCoordinatesTransformed, rotation, scale, translation = \ + createFieldsTransformations(modelCoordinates, scale_value=scaleFactor) # create objective = sum of squares of vector from modelCoordinatesTransformed to dataCoordinates markerDiff = fieldmodule.createFieldSubtract(dataCoordinates, modelCoordinatesTransformed) @@ -345,41 +348,42 @@ def _optimiseAlignment(self, pointMap): rotation.assignReal(fieldcache, minRotationAngles) translation.assignReal(fieldcache, minTranslation) - assert objective.isValid(), "Align: Failed to set up objective function for alignment to markers optimisation" + assert objective.isValid(), \ + "Align: Failed to set up objective function for alignment to markers optimisation" optimisation = fieldmodule.createOptimisation() optimisation.setMethod(Optimisation.METHOD_LEAST_SQUARES_QUASI_NEWTON) - #optimisation.setMethod(Optimisation.METHOD_QUASI_NEWTON) + # optimisation.setMethod(Optimisation.METHOD_QUASI_NEWTON) optimisation.addObjectiveField(objective) optimisation.addDependentField(rotation) optimisation.addDependentField(scale) optimisation.addDependentField(translation) - #FunctionTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE) - #GradientTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE) - #StepTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE) - #MaximumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP) - #MinimumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP) - #LinesearchTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE) - #TrustRegionSize = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE) - - #tol_scale = dataScale*dataScale - #FunctionTolerance *= tol_scale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE, FunctionTolerance) - #GradientTolerance *= tol_scale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE, GradientTolerance) - #StepTolerance *= tol_scale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE, StepTolerance) - #MaximumStep *= tol_scale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP, MaximumStep) - #MinimumStep *= tol_scale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP, MinimumStep) - #LinesearchTolerance *= dataScale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE, LinesearchTolerance) - #TrustRegionSize *= dataScale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE, TrustRegionSize) - - #if self.getDiagnosticLevel() > 0: + # FunctionTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE) + # GradientTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE) + # StepTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE) + # MaximumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP) + # MinimumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP) + # LinesearchTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE) + # TrustRegionSize = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE) + + # tol_scale = dataScale*dataScale + # FunctionTolerance *= tol_scale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE, FunctionTolerance) + # GradientTolerance *= tol_scale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE, GradientTolerance) + # StepTolerance *= tol_scale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE, StepTolerance) + # MaximumStep *= tol_scale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP, MaximumStep) + # MinimumStep *= tol_scale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP, MinimumStep) + # LinesearchTolerance *= dataScale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE, LinesearchTolerance) + # TrustRegionSize *= dataScale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE, TrustRegionSize) + + # if self.getDiagnosticLevel() > 0: # print("Function Tolerance", FunctionTolerance) # print("Gradient Tolerance", GradientTolerance) # print("Step Tolerance", StepTolerance) @@ -397,10 +401,12 @@ def _optimiseAlignment(self, pointMap): result1, self._rotation = rotation.evaluateReal(fieldcache, 3) result2, self._scale = scale.evaluateReal(fieldcache, 1) result3, self._translation = translation.evaluateReal(fieldcache, 3) - self._translation = [ s*translationScaleFactor for s in self._translation ] - assert (result1 == RESULT_OK) and (result2 == RESULT_OK) and (result3 == RESULT_OK), "Align: Failed to evaluate transformation for alignment to markers" + self._translation = [s*translationScaleFactor for s in self._translation] + assert (result1 == RESULT_OK) and (result2 == RESULT_OK) and (result3 == RESULT_OK), \ + "Align: Failed to evaluate transformation for alignment to markers" + -def evaluate_field_mesh_integral(field : Field, coordinates : Field, mesh: Mesh, number_of_points = 4): +def evaluate_field_mesh_integral(field: Field, coordinates: Field, mesh: Mesh, number_of_points=4): """ Integrate value of a field over mesh using Gaussian Quadrature. :param field: Field to integrate over mesh. diff --git a/src/scaffoldfitter/fitterstepconfig.py b/src/scaffoldfitter/fitterstepconfig.py index 59b75a0..2a152f0 100644 --- a/src/scaffoldfitter/fitterstepconfig.py +++ b/src/scaffoldfitter/fitterstepconfig.py @@ -5,6 +5,7 @@ from scaffoldfitter.fitterstep import FitterStep import sys + class FitterStepConfig(FitterStep): _jsonTypeId = "_FitterStepConfig" @@ -19,7 +20,7 @@ def __init__(self): def getJsonTypeId(cls): return cls._jsonTypeId - def decodeSettingsJSONDict(self, dctIn : dict): + def decodeSettingsJSONDict(self, dctIn: dict): """ Decode definition of step from JSON dict. """ diff --git a/src/scaffoldfitter/fitterstepfit.py b/src/scaffoldfitter/fitterstepfit.py index 2aaa575..4ee5073 100644 --- a/src/scaffoldfitter/fitterstepfit.py +++ b/src/scaffoldfitter/fitterstepfit.py @@ -2,14 +2,13 @@ Fit step for gross alignment and scale. """ -from opencmiss.utils.zinc.field import assignFieldParameters, createFieldsDisplacementGradients from opencmiss.utils.zinc.general import ChangeManager -from opencmiss.zinc.field import Field, FieldFindMeshLocation from opencmiss.zinc.optimisation import Optimisation from opencmiss.zinc.result import RESULT_OK from scaffoldfitter.fitterstep import FitterStep import sys + class FitterStepFit(FitterStep): _jsonTypeId = "_FitterStepFit" @@ -27,7 +26,7 @@ def __init__(self): def getJsonTypeId(cls): return cls._jsonTypeId - def decodeSettingsJSONDict(self, dctIn : dict): + def decodeSettingsJSONDict(self, dctIn: dict): """ Decode definition of step from JSON dict. """ @@ -41,10 +40,10 @@ def decodeSettingsJSONDict(self, dctIn : dict): # migrate legacy settings lineWeight = dct.get("lineWeight") if lineWeight is not None: - self.setLineWeight(lineWeight) + print("Legacy lineWeight attribute ignored as feature removed", file=sys.stderr) markerWeight = dct.get("markerWeight") if markerWeight is not None: - self.setMarkerWeight(markerWeight) + print("Legacy markerWeight attribute ignored as feature removed", file=sys.stderr) # convert legacy single-valued strain and curvature penalty weights to list: strainPenaltyWeight = dct.get("strainPenaltyWeight") if strainPenaltyWeight is not None: @@ -60,20 +59,20 @@ def encodeSettingsJSONDict(self) -> dict: """ dct = super().encodeSettingsJSONDict() dct.update({ - "numberOfIterations" : self._numberOfIterations, - "maximumSubIterations" : self._maximumSubIterations, - "updateReferenceState" : self._updateReferenceState + "numberOfIterations": self._numberOfIterations, + "maximumSubIterations": self._maximumSubIterations, + "updateReferenceState": self._updateReferenceState }) return dct - def clearGroupDataWeight(self, groupName: str): + def clearGroupDataWeight(self, groupName): """ Clear group data weight so fall back to last fit or global default. :param groupName: Exact model group name, or None for default group. """ self.clearGroupSetting(groupName, self._dataWeightToken) - def getGroupDataWeight(self, groupName: str): + def getGroupDataWeight(self, groupName): """ Get group data weight to apply in fit, and associated flags. If not set or inherited, gets value from default group. @@ -86,10 +85,10 @@ def getGroupDataWeight(self, groupName: str): """ return self.getGroupSetting(groupName, self._dataWeightToken, 1.0) - def setGroupDataWeight(self, groupName: str, weight): + def setGroupDataWeight(self, groupName, weight): """ Set group data weight to apply in fit, or reset to use default. - :param groupName: Exact model group name, or default group name. + :param groupName: Exact model group name, or None for default group. :param weight: Float valued weight >= 0.0, or None to reset to global default. Function ensures value is valid. """ @@ -100,58 +99,6 @@ def setGroupDataWeight(self, groupName: str, weight): weight = 0.0 self.setGroupSetting(groupName, self._dataWeightToken, weight) - def getLineWeight(self): - """ - :deprecated: Use getGroupDataWeight(). - """ - print("Fit getLineWeight is deprecated", file=sys.stderr) - groupNames = self._fitter.getDataProjectionGroupNames() - fieldmodule = self._fitter.getFieldmodule() - for groupName in groupNames: - group = fieldmodule.findFieldByName(groupName).castGroup() - meshGroup = self.getGroupDataProjectionMeshGroup(group) - dimension = meshGroup.getDimension() - if dimension == 1: - return self.getGroupDataWeight(groupName, lineWeight)[0] - return 0.0 - - def setLineWeight(self, lineWeight): - """ - :deprecated: Use setGroupDataWeight(). - """ - print("Fit setLineWeight is deprecated", file=sys.stderr) - groupNames = self._fitter.getDataProjectionGroupNames() - fieldmodule = self._fitter.getFieldmodule() - for groupName in groupNames: - group = fieldmodule.findFieldByName(groupName).castGroup() - meshGroup = self.getGroupDataProjectionMeshGroup(group) - dimension = meshGroup.getDimension() - if dimension == 1: - self.setGroupDataWeight(groupName, lineWeight) - - def getMarkerWeight(self): - """ - :deprecated: Use getGroupDataWeight(). - """ - print("Fit getMarkerWeight is deprecated", file=sys.stderr) - markerGroup = self._fitter.getMarkerGroup() - if markerGroup: - return self.getGroupDataWeight(markerGroup.getName(), markerWeight)[0] - else: - print("Fit getMarkerWeight. Missing marker group", file=sys.stderr) - return 0.0 - - def setMarkerWeight(self, markerWeight): - """ - :deprecated: Use setGroupDataWeight(). - """ - print("Fit setMarkerWeight is deprecated", file=sys.stderr) - markerGroup = self._fitter.getMarkerGroup() - if markerGroup: - self.setGroupDataWeight(markerGroup.getName(), markerWeight) - else: - print("Fit setMarkerWeight. Missing marker group: need to set marker group data weight to ", markerWeight, file=sys.stderr) - def clearGroupStrainPenalty(self, groupName: str): """ Clear local group strain penalty so fall back to last fit or global default. @@ -159,11 +106,11 @@ def clearGroupStrainPenalty(self, groupName: str): """ self.clearGroupSetting(groupName, self._strainPenaltyToken) - def getGroupStrainPenalty(self, groupName: str, count=None): + def getGroupStrainPenalty(self, groupName, count=None): """ Get list of strain penalty factors used to scale first deformation gradient components in group. Up to 9 components possible in 3-D. - :param groupName: Exact model group name, or default group name. + :param groupName: Exact model group name, or None for default group. :param count: Optional number of factors to limit or enlarge list to. If enlarging, values are padded with the last stored value. If None, the number stored is requested. @@ -188,10 +135,10 @@ def getGroupStrainPenalty(self, groupName: str, count=None): strainPenalty = strainPenalty[:] # shallow copy return strainPenalty, setLocally, inheritable - def setGroupStrainPenalty(self, groupName: str, strainPenalty): + def setGroupStrainPenalty(self, groupName, strainPenalty): """ - :param groupName: Exact model group name, or default group name. - :param factors: List of 1-9 float-value strain penalty factors to scale + :param groupName: Exact model group name, or None for default group. + :param strainPenalty: List of 1-9 float-value strain penalty factors to scale first deformation gradient components, or None to reset to inherited or default value. If fewer than 9 values are supplied in the list, the last value is used for all remaining components. @@ -202,42 +149,24 @@ def setGroupStrainPenalty(self, groupName: str, strainPenalty): count = len(strainPenalty) assert count > 0, "FitterStepFit: setGroupStrainPenalty requires a list of at least 1 float" for i in range(count): - assert isinstance(strainPenalty[i], float), "FitterStepFit: setGroupStrainPenalty requires a list of float" + assert isinstance(strainPenalty[i], float), \ + "FitterStepFit: setGroupStrainPenalty requires a list of float" if strainPenalty[i] < 0.0: strainPenalty[i] = 0.0 self.setGroupSetting(groupName, self._strainPenaltyToken, strainPenalty) - def getStrainPenaltyWeight(self) -> float: - """ - :deprecated: use getGroupStrainPenalty[default group name] - :return: Single strain penalty weight. - """ - print("Fit getStrainPenaltyWeight is deprecated", file=sys.stderr) - strainPenalty = self.getGroupStrainPenalty(None)[0] - if len(strainPenalty) > 1: - print("Warning: Calling deprecated getStrainPenaltyWeight while multiple factors", file=sys.stderr) - return strainPenalty[0] - - def setStrainPenaltyWeight(self, weight : float): - """ - :deprecated: use setGroupStrainPenalty. - :param weight: penalty factor to apply to all first deformation gradient components. - """ - print("Fit setStrainPenaltyWeight is deprecated", file=sys.stderr) - self.setGroupStrainPenalty(None, [weight]) - - def clearGroupCurvaturePenalty(self, groupName: str): + def clearGroupCurvaturePenalty(self, groupName): """ Clear local group curvature penalty so fall back to last fit or global default. :param groupName: Exact model group name, or None for default group. """ self.clearGroupSetting(groupName, self._curvaturePenaltyToken) - def getGroupCurvaturePenalty(self, groupName: str, count=None): + def getGroupCurvaturePenalty(self, groupName, count=None): """ Get list of curvature penalty factors used to scale second deformation gradient components in group. Up to 27 components possible in 3-D. - :param groupName: Exact model group name, or default group name. + :param groupName: Exact model group name, or None for default group. :param count: Optional number of factors to limit or enlarge list to. If enlarging, values are padded with the last stored value. If None, the number stored is requested. @@ -261,9 +190,9 @@ def getGroupCurvaturePenalty(self, groupName: str, count=None): curvaturePenalty = curvaturePenalty[:] # shallow copy return curvaturePenalty, setLocally, inheritable - def setGroupCurvaturePenalty(self, groupName: str, curvaturePenalty): + def setGroupCurvaturePenalty(self, groupName, curvaturePenalty): """ - :param groupName: Exact model group name, or default group name. + :param groupName: Exact model group name, or None for default group. :param curvaturePenalty: List of 1-27 float-value curvature penalty factors to scale first deformation gradient components, or None to reset to inherited or default value. If fewer than 27 values are @@ -271,42 +200,18 @@ def setGroupCurvaturePenalty(self, groupName: str, curvaturePenalty): components. """ if curvaturePenalty is not None: - assert isinstance(curvaturePenalty, list), "FitterStepFit: setGroupCurvaturePenalty requires a list of float" + assert isinstance(curvaturePenalty, list), \ + "FitterStepFit: setGroupCurvaturePenalty requires a list of float" curvaturePenalty = curvaturePenalty[:27] # shallow copy, limiting size count = len(curvaturePenalty) assert count > 0, "FitterStepFit: setGroupCurvaturePenalty requires a list of at least 1 float" for i in range(count): - assert isinstance(curvaturePenalty[i], float), "FitterStepFit: setGroupCurvaturePenalty requires a list of float" + assert isinstance(curvaturePenalty[i], float), \ + "FitterStepFit: setGroupCurvaturePenalty requires a list of float" if curvaturePenalty[i] < 0.0: curvaturePenalty[i] = 0.0 self.setGroupSetting(groupName, self._curvaturePenaltyToken, curvaturePenalty) - def getCurvaturePenaltyWeight(self) -> float: - """ - :deprecated: use getGroupCurvaturePenalty[default group name] - :return: Single curvature penalty weight. - """ - print("Fit getCurvaturePenaltyWeight is deprecated", file=sys.stderr) - curvaturePenalty = self.getGroupCurvaturePenalty(None)[0] - if len(curvaturePenalty) > 1: - print("Warning: Calling deprecated getCurvaturePenaltyWeight while multiple factors", file=sys.stderr) - return curvaturePenalty[0] - - def setCurvaturePenaltyWeight(self, weight : float): - """ - :deprecated: use setCurvaturePenaltyFactors. - :param weight: penalty factor to apply to all first deformation gradient components. - """ - print("Fit setCurvaturePenaltyWeight is deprecated", file=sys.stderr) - self.setGroupCurvaturePenalty(None, [weight]) - - def getEdgeDiscontinuityPenaltyWeight(self): - print("Fit getEdgeDiscontinuityPenaltyWeight: feature removed", file=sys.stderr) - return 0.0 - - def setEdgeDiscontinuityPenaltyWeight(self, weight): - print("Fit setEdgeDiscontinuityPenaltyWeight: feature removed", file=sys.stderr) - def getNumberOfIterations(self): return self._numberOfIterations @@ -342,40 +247,41 @@ def run(self, modelFileNameStem=None): :param modelFileNameStem: Optional name stem of intermediate output file to write. """ self._fitter.assignDataWeights(self) - deformActiveMeshGroup, strainActiveMeshGroup, curvatureActiveMeshGroup = self._fitter.assignDeformationPenalties(self) + deformActiveMeshGroup, strainActiveMeshGroup, curvatureActiveMeshGroup = \ + self._fitter.assignDeformationPenalties(self) - fieldmodule = self._fitter._region.getFieldmodule() + fieldmodule = self._fitter.getFieldmodule() optimisation = fieldmodule.createOptimisation() optimisation.setMethod(Optimisation.METHOD_NEWTON) optimisation.addDependentField(self._fitter.getModelCoordinatesField()) optimisation.setAttributeInteger(Optimisation.ATTRIBUTE_MAXIMUM_ITERATIONS, self._maximumSubIterations) - #FunctionTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE) - #GradientTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE) - #StepTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE) + # FunctionTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE) + # GradientTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE) + # StepTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE) MaximumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP) MinimumStep = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP) - #LinesearchTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE) - #TrustRegionSize = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE) + # LinesearchTolerance = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE) + # TrustRegionSize = optimisation.getAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE) dataScale = self._fitter.getDataScale() - #tol_scale = dataScale # *dataScale - #FunctionTolerance *= tol_scale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE, FunctionTolerance) - #GradientTolerance /= tol_scale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE, GradientTolerance) - #StepTolerance *= tol_scale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE, StepTolerance) + # tol_scale = dataScale # *dataScale + # FunctionTolerance *= tol_scale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_FUNCTION_TOLERANCE, FunctionTolerance) + # GradientTolerance /= tol_scale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_GRADIENT_TOLERANCE, GradientTolerance) + # StepTolerance *= tol_scale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_STEP_TOLERANCE, StepTolerance) MaximumStep *= dataScale optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MAXIMUM_STEP, MaximumStep) MinimumStep *= dataScale optimisation.setAttributeReal(Optimisation.ATTRIBUTE_MINIMUM_STEP, MinimumStep) - #LinesearchTolerance *= dataScale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE, LinesearchTolerance) - #TrustRegionSize *= dataScale - #optimisation.setAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE, TrustRegionSize) + # LinesearchTolerance *= dataScale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_LINESEARCH_TOLERANCE, LinesearchTolerance) + # TrustRegionSize *= dataScale + # optimisation.setAttributeReal(Optimisation.ATTRIBUTE_TRUST_REGION_SIZE, TrustRegionSize) - #if self.getDiagnosticLevel() > 0: + # if self.getDiagnosticLevel() > 0: # print("Function Tolerance", FunctionTolerance) # print("Gradient Tolerance", GradientTolerance) # print("Step Tolerance", StepTolerance) @@ -384,34 +290,29 @@ def run(self, modelFileNameStem=None): # print("Linesearch Tolerance", LinesearchTolerance) # print("Trust Region Size", TrustRegionSize) - dataObjective = None deformationPenaltyObjective = None - edgeDiscontinuityPenaltyObjective = None with ChangeManager(fieldmodule): dataObjective = self.createDataObjectiveField() result = optimisation.addObjectiveField(dataObjective) assert result == RESULT_OK, "Fit Geometry: Could not add data objective field" if deformActiveMeshGroup.getSize() > 0: - deformationPenaltyObjective = self.createDeformationPenaltyObjectiveField(deformActiveMeshGroup, strainActiveMeshGroup, curvatureActiveMeshGroup) + deformationPenaltyObjective = self.createDeformationPenaltyObjectiveField( + deformActiveMeshGroup, strainActiveMeshGroup, curvatureActiveMeshGroup) result = optimisation.addObjectiveField(deformationPenaltyObjective) assert result == RESULT_OK, "Fit Geometry: Could not add strain/curvature penalty objective field" - #if self._edgeDiscontinuityPenaltyWeight > 0.0: - # print("WARNING! Edge discontinuity penalty is not supported by NEWTON solver - skipping") - # #edgeDiscontinuityPenaltyObjective = self.createEdgeDiscontinuityPenaltyObjectiveField() - # #result = optimisation.addObjectiveField(edgeDiscontinuityPenaltyObjective) - # #assert result == RESULT_OK, "Fit Geometry: Could not add edge discontinuity penalty objective field" fieldcache = fieldmodule.createFieldcache() objectiveFormat = "{:12e}" - for iter in range(self._numberOfIterations): - iterName = str(iter + 1) + for iterationIndex in range(self._numberOfIterations): + iterName = str(iterationIndex + 1) if self.getDiagnosticLevel() > 0: print("-------- Iteration " + iterName) if self.getDiagnosticLevel() > 0: result, objective = dataObjective.evaluateReal(fieldcache, 1) print(" Data objective", objectiveFormat.format(objective)) if deformationPenaltyObjective: - result, objective = deformationPenaltyObjective.evaluateReal(fieldcache, deformationPenaltyObjective.getNumberOfComponents()) + result, objective = deformationPenaltyObjective.evaluateReal( + fieldcache, deformationPenaltyObjective.getNumberOfComponents()) print(" Deformation penalty objective", objectiveFormat.format(objective)) result = optimisation.optimise() if self.getDiagnosticLevel() > 1: @@ -427,7 +328,8 @@ def run(self, modelFileNameStem=None): result, objective = dataObjective.evaluateReal(fieldcache, 1) print(" END Data objective", objectiveFormat.format(objective)) if deformationPenaltyObjective: - result, objective = deformationPenaltyObjective.evaluateReal(fieldcache, deformationPenaltyObjective.getNumberOfComponents()) + result, objective = deformationPenaltyObjective.evaluateReal( + fieldcache, deformationPenaltyObjective.getNumberOfComponents()) print(" END Deformation penalty objective", objectiveFormat.format(objective)) if self._updateReferenceState: @@ -445,58 +347,67 @@ def createDataObjectiveField(self): delta = self._fitter.getDataDeltaField() weight = self._fitter.getDataWeightField() deltaSq = fieldmodule.createFieldDotProduct(delta, delta) - #dataProjectionInDirection = fieldmodule.createFieldDotProduct(dataProjectionDelta, self._fitter.getDataProjectionDirectionField()) - #dataProjectionInDirection = fieldmodule.createFieldMagnitude(dataProjectionDelta) - #dataProjectionInDirection = dataProjectionDelta - #dataProjectionInDirection = fieldmodule.createFieldConstant([ weight/dataScale ]*dataProjectionDelta.getNumberOfComponents()) * dataProjectionDelta - dataProjectionObjective = fieldmodule.createFieldNodesetSum(weight*deltaSq, self._fitter.getActiveDataNodesetGroup()) + # dataProjectionInDirection = fieldmodule.createFieldDotProduct( + # dataProjectionDelta, self._fitter.getDataProjectionDirectionField()) + # dataProjectionInDirection = fieldmodule.createFieldMagnitude(dataProjectionDelta) + # dataProjectionInDirection = dataProjectionDelta + # dataProjectionInDirection = fieldmodule.createFieldConstant( + # [ weight/dataScale ]*dataProjectionDelta.getNumberOfComponents()) * dataProjectionDelta + dataProjectionObjective = fieldmodule.createFieldNodesetSum( + weight*deltaSq, self._fitter.getActiveDataNodesetGroup()) dataProjectionObjective.setElementMapField(self._fitter.getDataHostLocationField()) return dataProjectionObjective - def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainActiveMeshGroup, curvatureActiveMeshGroup): + def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainActiveMeshGroup, + curvatureActiveMeshGroup): """ Only call for non-zero strain or curvature penalty values. - :param deformActiveMeshGroup, strainActiveMeshGroup, curvatureActiveMeshGroup: - Mesh groups over which to apply combined, strain or curvature penalties. + :param deformActiveMeshGroup: Mesh group over which either penalties is applied. + :param strainActiveMeshGroup: Mesh group over which strain penalty is applied. + :param curvatureActiveMeshGroup: Mesh group over which curvature penalty is applied. :return: Zinc field, or None if not weighted. Assumes ChangeManager(fieldmodule) is in effect. """ numberOfGaussPoints = 3 fieldmodule = self._fitter.getFieldmodule() mesh = self._fitter.getHighestDimensionMesh() - dataScale = 1.0 # future: eliminate effect of model scale - #dimension = mesh.getDimension() - #linearDataScale = self._fitter.getDataScale() - #for d in range(dimension): + # dataScale = 1.0 + # dimension = mesh.getDimension() + # linearDataScale = self._fitter.getDataScale() + # for d in range(dimension): # dataScale /= linearDataScale modelCoordinates = self._fitter.getModelCoordinatesField() modelReferenceCoordinates = self._fitter.getModelReferenceCoordinatesField() fibreField = self._fitter.getFibreField() dimension = mesh.getDimension() coordinatesCount = modelCoordinates.getNumberOfComponents() - zincVersion = self._fitter.getZincVersion() - #zincVersion34 = (zincVersion[0] > 3) or ((zincVersion[0] == 3) and (zincVersion[1] >= 4)) + # zincVersion = self._fitter.getZincVersion() + # zincVersion34 = (zincVersion[0] > 3) or ((zincVersion[0] == 3) and (zincVersion[1] >= 4)) assert coordinatesCount == dimension, \ "Fit strain/curvature penalties cannot be applied as element dimension < coordinate components. " displacement = modelCoordinates - modelReferenceCoordinates - displacementGradient1 = displacementGradient1raw = fieldmodule.createFieldGradient(displacement, modelReferenceCoordinates) + displacementGradient1 = displacementGradient1raw =\ + fieldmodule.createFieldGradient(displacement, modelReferenceCoordinates) + fibreAxesT = None if fibreField: # convert to local fibre directions, with possible dimension reduction for 2D, 1D fibreAxes = fieldmodule.createFieldFibreAxes(fibreField, modelReferenceCoordinates) if dimension == 3: fibreAxesT = fieldmodule.createFieldTranspose(3, fibreAxes) elif dimension == 2: - fibreAxesT = fieldmodule.createFieldComponent(fibreAxes, \ - [1, 4, 2, 5, 3, 6] if (coordinatesCount == 3) else [1, 4, 2, 5]) + fibreAxesT = fieldmodule.createFieldComponent( + fibreAxes, [1, 4, 2, 5, 3, 6] if (coordinatesCount == 3) else [1, 4, 2, 5]) else: # dimension == 1 - fibreAxesT = fieldmodule.createFieldComponent(fibreAxes, \ - [1, 2, 3] if (coordinatesCount == 3) else [1, 2] if (coordinatesCount == 2) else [1]) - displacementGradient1 = fieldmodule.createFieldMatrixMultiply(coordinatesCount, displacementGradient1, fibreAxesT) + fibreAxesT = fieldmodule.createFieldComponent( + fibreAxes, [1, 2, 3] if (coordinatesCount == 3) else [1, 2] if (coordinatesCount == 2) else [1]) + displacementGradient1 = \ + fieldmodule.createFieldMatrixMultiply(coordinatesCount, displacementGradient1, fibreAxesT) deformationTerm = None if strainActiveMeshGroup.getSize() > 0: alpha = self._fitter.getStrainPenaltyField() - wtSqDeformationGradient1 = fieldmodule.createFieldDotProduct(alpha, displacementGradient1*displacementGradient1) + wtSqDeformationGradient1 = \ + fieldmodule.createFieldDotProduct(alpha, displacementGradient1*displacementGradient1) assert wtSqDeformationGradient1.isValid() deformationTerm = wtSqDeformationGradient1 if curvatureActiveMeshGroup.getSize() > 0: @@ -504,41 +415,34 @@ def createDeformationPenaltyObjectiveField(self, deformActiveMeshGroup, strainAc displacementGradient2 = fieldmodule.createFieldGradient(displacementGradient1raw, modelReferenceCoordinates) if fibreField: # convert to local fibre directions - displacementGradient2a = fieldmodule.createFieldMatrixMultiply(coordinatesCount*coordinatesCount, displacementGradient2, fibreAxesT) + displacementGradient2a = fieldmodule.createFieldMatrixMultiply(coordinatesCount*coordinatesCount, + displacementGradient2, fibreAxesT) # transpose each displacement component of displacementGradient2a to remultiply by fibreAxesT if dimension == 1: displacementGradient2aT = displacementGradient2a else: + transposeComponents = None if coordinatesCount == 3: if dimension == 3: - transposeComponents = [1, 4, 7, 2, 5, 8, 3, 6, 9, 10, 13, 16, 11, 14, 17, 12, 15, 18, 19, 22, 25, 20, 23, 26, 21, 24, 27] + transposeComponents = [1, 4, 7, 2, 5, 8, 3, 6, 9, + 10, 13, 16, 11, 14, 17, 12, 15, 18, + 19, 22, 25, 20, 23, 26, 21, 24, 27] elif dimension == 2: transposeComponents = [1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12, 13, 15, 17, 14, 16, 18] elif coordinatesCount == 2: transposeComponents = [1, 3, 2, 4, 5, 7, 6, 8] - displacementGradient2aT = fieldmodule.createFieldComponent(displacementGradient2a, transposeComponents) - displacementGradient2 = fieldmodule.createFieldMatrixMultiply(dimension*coordinatesCount, displacementGradient2aT, fibreAxesT) + displacementGradient2aT = \ + fieldmodule.createFieldComponent(displacementGradient2a, transposeComponents) + displacementGradient2 = fieldmodule.createFieldMatrixMultiply(dimension*coordinatesCount, + displacementGradient2aT, fibreAxesT) beta = self._fitter.getCurvaturePenaltyField() - wtSqDeformationGradient2 = fieldmodule.createFieldDotProduct(beta, displacementGradient2*displacementGradient2) + wtSqDeformationGradient2 = \ + fieldmodule.createFieldDotProduct(beta, displacementGradient2*displacementGradient2) assert wtSqDeformationGradient2.isValid() - deformationTerm = (deformationTerm + wtSqDeformationGradient2) if deformationTerm else wtSqDeformationGradient2 + deformationTerm = (deformationTerm + wtSqDeformationGradient2) if deformationTerm \ + else wtSqDeformationGradient2 - deformationPenaltyObjective = fieldmodule.createFieldMeshIntegral(deformationTerm, self._fitter.getModelReferenceCoordinatesField(), deformActiveMeshGroup) + deformationPenaltyObjective = fieldmodule.createFieldMeshIntegral( + deformationTerm, self._fitter.getModelReferenceCoordinatesField(), deformActiveMeshGroup) deformationPenaltyObjective.setNumbersOfPoints(numberOfGaussPoints) return deformationPenaltyObjective - - #def createEdgeDiscontinuityPenaltyObjectiveField(self): - # """ - # Only call if self._edgeDiscontinuityPenaltyWeight > 0.0 - # Assumes ChangeManager(fieldmodule) is in effect. - # :return: Zinc FieldMeshIntegralSquares, or None if not weighted. - # """ - # numberOfGaussPoints = 3 - # fieldmodule = self._fitter.getFieldmodule() - # lineMesh = fieldmodule.findMeshByDimension(1) - # edgeDiscontinuity = fieldmodule.createFieldEdgeDiscontinuity(self._fitter.getModelCoordinatesField()) - # dataScale = self._fitter.getDataScale() - # weightedEdgeDiscontinuity = edgeDiscontinuity*fieldmodule.createFieldConstant(self._edgeDiscontinuityPenaltyWeight/dataScale) - # edgeDiscontinuityPenaltyObjective = fieldmodule.createFieldMeshIntegralSquares(weightedEdgeDiscontinuity, self._fitter.getModelReferenceCoordinatesField(), lineMesh) - # edgeDiscontinuityPenaltyObjective.setNumbersOfPoints(numberOfGaussPoints) - # return edgeDiscontinuityPenaltyObjective diff --git a/tests/test_fitcube.py b/tests/test_fitcube.py index 2c3cf1d..f15af72 100644 --- a/tests/test_fitcube.py +++ b/tests/test_fitcube.py @@ -13,11 +13,13 @@ here = os.path.abspath(os.path.dirname(__file__)) + def assertAlmostEqualList(testcase, actualList, expectedList, delta): assert len(actualList) == len(expectedList) for actual, expected in zip(actualList, expectedList): testcase.assertAlmostEqual(actual, expected, delta=delta) + def getRotationMatrix(eulerAngles): """ From OpenCMISS-Zinc graphics_library.cpp, transposed. @@ -42,10 +44,11 @@ def getRotationMatrix(eulerAngles): sin_azimuth*sin_elevation*cos_roll - cos_azimuth*sin_roll, -sin_elevation, cos_elevation*sin_roll, - cos_elevation*cos_roll, + cos_elevation*cos_roll ] -def transformCoordinatesList(xIn : list, transformationMatrix, translation): + +def transformCoordinatesList(xIn: list, transformationMatrix, translation): """ Transforms coordinates by multiplying by 9-component transformationMatrix then offsetting by translation. @@ -64,7 +67,8 @@ def transformCoordinatesList(xIn : list, transformationMatrix, translation): xOut.append(x2) return xOut -def getNodesetConditionalSize(nodeset : Nodeset, conditionalField : Field): + +def getNodesetConditionalSize(nodeset: Nodeset, conditionalField: Field): """ :return: Number of objects in nodeset for which conditionalField is True. """ @@ -82,6 +86,7 @@ def getNodesetConditionalSize(nodeset : Nodeset, conditionalField : Field): node = nodeiterator.next() return size + class FitCubeToSphereTestCase(unittest.TestCase): def test_alignFixedRandomData(self): @@ -99,29 +104,29 @@ def test_alignFixedRandomData(self): self.assertEqual(fitter.getModelCoordinatesField().getName(), "coordinates") self.assertEqual(fitter.getDataCoordinatesField().getName(), "data_coordinates") self.assertEqual(fitter.getMarkerGroup().getName(), "marker") - bottomCentre1 = fitter.evaluateNodeGroupMeanCoordinates("bottom", "coordinates", isData = False) - sidesCentre1 = fitter.evaluateNodeGroupMeanCoordinates("sides", "coordinates", isData = False) - topCentre1 = fitter.evaluateNodeGroupMeanCoordinates("top", "coordinates", isData = False) - assertAlmostEqualList(self, bottomCentre1, [ 0.5, 0.5, 0.0 ], delta=1.0E-7) - assertAlmostEqualList(self, sidesCentre1, [ 0.5, 0.5, 0.5 ], delta=1.0E-7) - assertAlmostEqualList(self, topCentre1, [ 0.5, 0.5, 1.0 ], delta=1.0E-7) + bottomCentre1 = fitter.evaluateNodeGroupMeanCoordinates("bottom", "coordinates", isData=False) + sidesCentre1 = fitter.evaluateNodeGroupMeanCoordinates("sides", "coordinates", isData=False) + topCentre1 = fitter.evaluateNodeGroupMeanCoordinates("top", "coordinates", isData=False) + assertAlmostEqualList(self, bottomCentre1, [0.5, 0.5, 0.0], delta=1.0E-7) + assertAlmostEqualList(self, sidesCentre1, [0.5, 0.5, 0.5], delta=1.0E-7) + assertAlmostEqualList(self, topCentre1, [0.5, 0.5, 1.0], delta=1.0E-7) align = FitterStepAlign() fitter.addFitterStep(align) align.setScale(1.1) - align.setTranslation([ 0.1, -0.2, 0.3 ]) - align.setRotation([ math.pi/4.0, math.pi/8.0, math.pi/2.0 ]) + align.setTranslation([0.1, -0.2, 0.3]) + align.setRotation([math.pi/4.0, math.pi/8.0, math.pi/2.0]) self.assertFalse(align.isAlignMarkers()) align.run() rotation = align.getRotation() scale = align.getScale() translation = align.getTranslation() rotationMatrix = getRotationMatrix(rotation) - transformationMatrix = [ v*scale for v in rotationMatrix ] + transformationMatrix = [v*scale for v in rotationMatrix] bottomCentre2Expected, sidesCentre2Expected, topCentre2Expected = transformCoordinatesList( - [ bottomCentre1, sidesCentre1, topCentre1 ], transformationMatrix, translation) - bottomCentre2 = fitter.evaluateNodeGroupMeanCoordinates("bottom", "coordinates", isData = False) - sidesCentre2 = fitter.evaluateNodeGroupMeanCoordinates("sides", "coordinates", isData = False) - topCentre2 = fitter.evaluateNodeGroupMeanCoordinates("top", "coordinates", isData = False) + [bottomCentre1, sidesCentre1, topCentre1], transformationMatrix, translation) + bottomCentre2 = fitter.evaluateNodeGroupMeanCoordinates("bottom", "coordinates", isData=False) + sidesCentre2 = fitter.evaluateNodeGroupMeanCoordinates("sides", "coordinates", isData=False) + topCentre2 = fitter.evaluateNodeGroupMeanCoordinates("top", "coordinates", isData=False) assertAlmostEqualList(self, bottomCentre2, bottomCentre2Expected, delta=1.0E-7) assertAlmostEqualList(self, sidesCentre2, sidesCentre2Expected, delta=1.0E-7) assertAlmostEqualList(self, topCentre2, topCentre2Expected, delta=1.0E-7) @@ -143,7 +148,7 @@ def test_alignMarkersFitRegularData(self): self.assertEqual(coordinates.getName(), "coordinates") self.assertEqual(fitter.getDataCoordinatesField().getName(), "data_coordinates") self.assertEqual(fitter.getMarkerGroup().getName(), "marker") - #fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry1.exf")) + # fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry1.exf")) fieldmodule = fitter.getFieldmodule() surfaceAreaField = createFieldMeshIntegral(coordinates, fitter.getMesh(2), number_of_points=4) volumeField = createFieldMeshIntegral(coordinates, fitter.getMesh(3), number_of_points=3) @@ -156,10 +161,10 @@ def test_alignMarkersFitRegularData(self): self.assertAlmostEqual(volume, 1.0, delta=1.0E-7) activeNodeset = fitter.getActiveDataNodesetGroup() self.assertEqual(292, activeNodeset.getSize()) - self.assertEqual(72, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("bottom"))) - self.assertEqual(144, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("sides"))) - self.assertEqual(72, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("top"))) - self.assertEqual(4, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("marker"))) + groupSizes = {"bottom": 72, "sides": 144, "top": 72, "marker": 4} + for groupName, count in groupSizes.items(): + self.assertEqual(count, getNodesetConditionalSize( + activeNodeset, fitter.getFieldmodule().findFieldByName(groupName))) align = FitterStepAlign() fitter.addFitterStep(align) @@ -167,13 +172,14 @@ def test_alignMarkersFitRegularData(self): self.assertTrue(align.setAlignMarkers(True)) self.assertTrue(align.isAlignMarkers()) align.run() - #fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry2.exf")) + # fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry2.exf")) rotation = align.getRotation() scale = align.getScale() translation = align.getTranslation() - assertAlmostEqualList(self, rotation, [ -0.25*math.pi, 0.0, 0.0 ], delta=1.0E-4) + assertAlmostEqualList(self, rotation, [-0.25*math.pi, 0.0, 0.0], delta=1.0E-4) self.assertAlmostEqual(scale, 0.8047378476539072, places=5) - assertAlmostEqualList(self, translation, [ -0.5690355950594247, 1.1068454682130484e-05, -0.4023689233125251 ], delta=1.0E-6) + assertAlmostEqualList(self, translation, + [-0.5690355950594247, 1.1068454682130484e-05, -0.4023689233125251], delta=1.0E-6) result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) self.assertAlmostEqual(surfaceArea, 3.885618020657802, delta=1.0E-6) @@ -184,12 +190,12 @@ def test_alignMarkersFitRegularData(self): fit1 = FitterStepFit() fitter.addFitterStep(fit1) self.assertEqual(3, len(fitter.getFitterSteps())) - fit1.setMarkerWeight(1.0) - fit1.setCurvaturePenaltyWeight(0.01) + fit1.setGroupDataWeight("marker", 1.0) + fit1.setGroupCurvaturePenalty(None, [0.01]) fit1.setNumberOfIterations(3) fit1.setUpdateReferenceState(True) fit1.run() - #fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry3.exf")) + # fitter.getRegion().writeFile(os.path.join(here, "resources", "km_fitgeometry3.exf")) result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) @@ -207,8 +213,8 @@ def test_alignMarkersFitRegularData(self): self.assertTrue(isinstance(fitterSteps[0], FitterStepConfig)) self.assertTrue(isinstance(fitterSteps[1], FitterStepAlign)) self.assertTrue(isinstance(fitterSteps[2], FitterStepFit)) - #fitter2.load() - #for fitterStep in fitterSteps: + # fitter2.load() + # for fitterStep in fitterSteps: # fitterStep.run() s2 = fitter.encodeSettingsJSON() self.assertEqual(s, s2) @@ -237,7 +243,6 @@ def test_alignGroupsFitEllipsoidRegularData(self): result, volume = volumeField.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) self.assertAlmostEqual(volume, 2.0, delta=1.0E-6) - activeNodeset = fitter.getActiveDataNodesetGroup() align = FitterStepAlign() fitter.addFitterStep(align) @@ -250,7 +255,8 @@ def test_alignGroupsFitEllipsoidRegularData(self): translation = align.getTranslation() assertAlmostEqualList(self, rotation, [0.0, 0.0, 0.0], delta=1.0E-5) self.assertAlmostEqual(scale, 1.040599599095245, places=5) - assertAlmostEqualList(self, translation, [-1.0405995643008867, -0.5202997843515198, -0.5202997827678563], delta=1.0E-6) + assertAlmostEqualList(self, translation, [-1.0405995643008867, -0.5202997843515198, -0.5202997827678563], + delta=1.0E-6) result, surfaceArea = surfaceAreaField.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) self.assertAlmostEqual(surfaceArea, 11.0*scale*scale, delta=1.0E-6) @@ -382,7 +388,6 @@ def test_alignGroupsFitEllipsoidRegularData(self): s2 = fitter.encodeSettingsJSON() self.assertEqual(s, s2) - def test_fitRegularDataGroupWeight(self): """ Test automatic alignment of model and data using fiducial markers. @@ -416,14 +421,13 @@ def test_fitRegularDataGroupWeight(self): self.assertEqual(2, len(groupNames)) self.assertEqual((0.5, True, False), fit1.getGroupDataWeight("bottom")) self.assertEqual((0.1, True, False), fit1.getGroupDataWeight("sides")) - fit1.setCurvaturePenaltyWeight(0.01) + fit1.setGroupCurvaturePenalty(None, [0.01]) fit1.setNumberOfIterations(3) fit1.setUpdateReferenceState(True) fit1.run() dataWeightField = fieldmodule.findFieldByName("data_weight").castFiniteElement() self.assertTrue(dataWeightField.isValid()) - groupData = { "bottom" : ( 72, 0.5 ), "sides" : ( 144, 0.1 ), "top" : ( 72, 1.0 ) } - mesh2d = fitter.getMesh(2) + groupData = {"bottom": (72, 0.5), "sides": (144, 0.1), "top": (72, 1.0)} for groupName in groupData.keys(): expectedSize, expectedWeight = groupData[groupName] group = fieldmodule.findFieldByName(groupName).castGroup() @@ -494,10 +498,10 @@ def test_groupSettings(self): fitter.load() activeNodeset = fitter.getActiveDataNodesetGroup() self.assertEqual(141, activeNodeset.getSize()) - self.assertEqual(72, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("bottom"))) - self.assertEqual(36, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("sides"))) - self.assertEqual(29, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("top"))) - self.assertEqual(4, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("marker"))) + groupSizes = {"bottom": 72, "sides": 36, "top": 29, "marker": 4} + for groupName, count in groupSizes.items(): + self.assertEqual(count, getNodesetConditionalSize( + activeNodeset, fitter.getFieldmodule().findFieldByName(groupName))) # test override and inherit config2 = FitterStepConfig() fitter.addFitterStep(config2) @@ -511,10 +515,10 @@ def test_groupSettings(self): config2.run() activeNodeset = fitter.getActiveDataNodesetGroup() self.assertEqual(184, activeNodeset.getSize()) - self.assertEqual(72, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("bottom"))) - self.assertEqual(36, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("sides"))) - self.assertEqual(72, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("top"))) - self.assertEqual(4, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("marker"))) + groupSizes = {"bottom": 72, "sides": 36, "top": 72, "marker": 4} + for groupName, count in groupSizes.items(): + self.assertEqual(count, getNodesetConditionalSize( + activeNodeset, fitter.getFieldmodule().findFieldByName(groupName))) # test inherit through 2 previous configs and cancel/None in config2 config3 = FitterStepConfig() fitter.addFitterStep(config3) @@ -525,10 +529,9 @@ def test_groupSettings(self): config3.run() activeNodeset = fitter.getActiveDataNodesetGroup() self.assertEqual(184, activeNodeset.getSize()) - self.assertEqual(72, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("bottom"))) - self.assertEqual(36, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("sides"))) - self.assertEqual(72, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("top"))) - self.assertEqual(4, getNodesetConditionalSize(activeNodeset, fitter.getFieldmodule().findFieldByName("marker"))) + for groupName, count in groupSizes.items(): + self.assertEqual(count, getNodesetConditionalSize( + activeNodeset, fitter.getFieldmodule().findFieldByName(groupName))) del config1 del config2 del config3 @@ -568,17 +571,18 @@ def test_preAlignment(self): fitter.setDiagnosticLevel(1) # Rotation, scale, translation - transformationList = [[[ 0.0, 0.0, 0.0 ], 1.0, [ 0.0, 0.0, 0.0 ]], - [[ math.pi * 20/180, 0.0, 0.0 ], 1.0, [ 0.0, 0.0, 0.0 ]], - [[ math.pi * 135/180, 0.0, 0.0 ], 1.0, [ 0.0, 0.0, 0.0 ]], - [[ math.pi * 250/180, math.pi * -45/180, 0.0 ], 1.0, [ 0.0, 0.0, 0.0 ]], - [[ math.pi * 45/180, math.pi * 45/180, math.pi * 45/180 ], 1.0, [ 0.0, 0.0, 0.0 ]], - [[ 0.0, 0.0, 0.0 ], 0.05, [ 0.0, 0.0, 0.0]], - [[ math.pi * 70/180, math.pi * 10/180, math.pi * -300/180 ], 0.2, [ 0.0, 0.0, 0.0 ]], - [[ 0.0, 0.0, 0.0 ], 1.0, [ 15.0, 15.0, 15.0 ]], - [[ 0.0, 0.0, 0.0 ], 20.0, [ 50.0, 0.0, 10.0 ]], - [[ math.pi * 90/180, math.pi * 200/180, math.pi * 5/180 ], 1.0, [ -10.0, -20.0, 100.0 ]], - [[ math.pi * -45/180, math.pi * 120/180, math.pi * 10/180 ], 500.0, [ 100.0, 100.0, 100.0 ]]] + transformationList = [ + [[0.0, 0.0, 0.0], 1.0, [0.0, 0.0, 0.0]], + [[math.pi * 20/180, 0.0, 0.0], 1.0, [0.0, 0.0, 0.0]], + [[math.pi * 135/180, 0.0, 0.0], 1.0, [0.0, 0.0, 0.0]], + [[math.pi * 250/180, math.pi * -45/180, 0.0], 1.0, [0.0, 0.0, 0.0]], + [[math.pi * 45/180, math.pi * 45/180, math.pi * 45/180], 1.0, [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], 0.05, [0.0, 0.0, 0.0]], + [[math.pi * 70/180, math.pi * 10/180, math.pi * -300/180], 0.2, [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], 1.0, [15.0, 15.0, 15.0]], + [[0.0, 0.0, 0.0], 20.0, [50.0, 0.0, 10.0]], + [[math.pi * 90/180, math.pi * 200/180, math.pi * 5/180], 1.0, [-10.0, -20.0, 100.0]], + [[math.pi * -45/180, math.pi * 120/180, math.pi * 10/180], 500.0, [100.0, 100.0, 100.0]]] expectedAlignedNodes = [[-0.5690355951820659, 1.1070979208244695e-05, -0.40236892417087866], [-1.1077595833408616e-05, -0.5690355904946871, -0.4023689227447479], @@ -617,5 +621,6 @@ def test_preAlignment(self): result, x = modelCoordinates.getNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, 3) assertAlmostEqualList(self, x, expectedAlignedNodes[nodeIdentifier - 1], delta=1.0E-3) + if __name__ == "__main__": unittest.main()