From 0ceb091d0097db12fffd53bca1aeea4ce0d8366b Mon Sep 17 00:00:00 2001 From: Nick Tustison Date: Sun, 9 Oct 2022 13:47:27 -0700 Subject: [PATCH] PERF: Remove unnecessary computations and parallelize function. --- ...SplineScatteredDataPointSetToImageFilter.h | 13 +- ...lineScatteredDataPointSetToImageFilter.hxx | 119 ++++++++---------- 2 files changed, 62 insertions(+), 70 deletions(-) diff --git a/Modules/Filtering/ImageGrid/include/itkBSplineScatteredDataPointSetToImageFilter.h b/Modules/Filtering/ImageGrid/include/itkBSplineScatteredDataPointSetToImageFilter.h index b82f6ab6ed5..5b9958394f0 100644 --- a/Modules/Filtering/ImageGrid/include/itkBSplineScatteredDataPointSetToImageFilter.h +++ b/Modules/Filtering/ImageGrid/include/itkBSplineScatteredDataPointSetToImageFilter.h @@ -165,6 +165,7 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter using PointSetPointer = typename PointSetType::Pointer; using PointDataType = typename PointSetType::PixelType; using PointDataContainerType = typename PointSetType::PointDataContainer; + using PointDataContainerPointer = typename PointDataContainerType::Pointer; /** Other type alias. */ using RealType = float; @@ -318,10 +319,6 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter void RefineControlPointLattice(); - /** Determine the residuals after fitting to one level. */ - void - UpdatePointSet(); - /** This function is not used as it requires an evaluation of all * (SplineOrder+1)^ImageDimensions B-spline weights for each evaluation. */ void @@ -335,6 +332,10 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter void ThreadedGenerateDataForReconstruction(const RegionType &, ThreadIdType); + /** Update the input point set values with the residuals after fitting to a level. */ + void + ThreadedGenerateDataForUpdatePointSetValues(const RegionType &, ThreadIdType); + /** Sub-function used by GenerateOutputImageFast() to generate the sampled * B-spline object quickly. */ void @@ -368,8 +369,7 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter vnl_matrix m_RefinedLatticeCoefficients[ImageDimension]; - typename PointDataContainerType::Pointer m_InputPointData; - typename PointDataContainerType::Pointer m_OutputPointData; + PointDataContainerPointer m_InputPointData; typename KernelType::Pointer m_Kernel[ImageDimension]; @@ -383,6 +383,7 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter RealType m_BSplineEpsilon{ static_cast(1e-3) }; bool m_IsFittingComplete{ false }; + bool m_DoUpdatePointSetValues{ false }; }; } // end namespace itk diff --git a/Modules/Filtering/ImageGrid/include/itkBSplineScatteredDataPointSetToImageFilter.hxx b/Modules/Filtering/ImageGrid/include/itkBSplineScatteredDataPointSetToImageFilter.hxx index 77b03198df3..646dc295baa 100644 --- a/Modules/Filtering/ImageGrid/include/itkBSplineScatteredDataPointSetToImageFilter.hxx +++ b/Modules/Filtering/ImageGrid/include/itkBSplineScatteredDataPointSetToImageFilter.hxx @@ -65,7 +65,6 @@ BSplineScatteredDataPointSetToImageFilter::BSpline } this->m_InputPointData = PointDataContainerType::New(); - this->m_OutputPointData = PointDataContainerType::New(); this->m_PointWeights = WeightsContainerType::New(); } @@ -237,7 +236,6 @@ BSplineScatteredDataPointSetToImageFilter::Generat } this->m_InputPointData->Initialize(); - this->m_OutputPointData->Initialize(); if (inputPointSet->GetNumberOfPoints() > 0) { const auto & pointData = inputPointSet->GetPointData()->CastToSTLConstContainer(); @@ -247,7 +245,6 @@ BSplineScatteredDataPointSetToImageFilter::Generat m_PointWeights->CastToSTLContainer().assign(pointData.size(), 1); } m_InputPointData->CastToSTLContainer() = pointData; - m_OutputPointData->CastToSTLContainer() = pointData; } this->m_CurrentLevel = 0; @@ -264,76 +261,53 @@ BSplineScatteredDataPointSetToImageFilter::Generat multiThreader->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits()); multiThreader->SetSingleMethod(this->ThreaderCallback, &str1); - // Multithread the generation of the control point lattice. + // Multithread the generation of the control point lattice for the first level. this->BeforeThreadedGenerateData(); multiThreader->SingleMethodExecute(); this->AfterThreadedGenerateData(); - this->UpdatePointSet(); - if (this->m_DoMultilevel) { this->m_PsiLattice->SetRegions(this->m_PhiLattice->GetLargestPossibleRegion()); this->m_PsiLattice->Allocate(); PointDataType P{}; this->m_PsiLattice->FillBuffer(P); - } - for (this->m_CurrentLevel = 1; this->m_CurrentLevel < this->m_MaximumNumberOfLevels; this->m_CurrentLevel++) - { - ImageRegionIterator ItPsi(this->m_PsiLattice, this->m_PsiLattice->GetLargestPossibleRegion()); - ImageRegionIterator ItPhi(this->m_PhiLattice, this->m_PhiLattice->GetLargestPossibleRegion()); - for (ItPsi.GoToBegin(), ItPhi.GoToBegin(); !ItPsi.IsAtEnd(); ++ItPsi, ++ItPhi) + for (this->m_CurrentLevel = 1; this->m_CurrentLevel < this->m_MaximumNumberOfLevels; this->m_CurrentLevel++) { - ItPsi.Set(ItPhi.Get() + ItPsi.Get()); - } - this->RefineControlPointLattice(); - - for (unsigned int i = 0; i < ImageDimension; ++i) - { - if (this->m_CurrentLevel < this->m_NumberOfLevels[i]) + // Multithread updating the point set values + this->m_DoUpdatePointSetValues = true; + // this->BeforeThreadedGenerateData(); + multiThreader->SingleMethodExecute(); + // this->AfterThreadedGenerateData(); + this->m_DoUpdatePointSetValues = false; + + ImageRegionIterator ItPsi(this->m_PsiLattice, this->m_PsiLattice->GetLargestPossibleRegion()); + ImageRegionIterator ItPhi(this->m_PhiLattice, this->m_PhiLattice->GetLargestPossibleRegion()); + for (ItPsi.GoToBegin(), ItPhi.GoToBegin(); !ItPsi.IsAtEnd(); ++ItPsi, ++ItPhi) { - this->m_CurrentNumberOfControlPoints[i] = 2 * this->m_CurrentNumberOfControlPoints[i] - this->m_SplineOrder[i]; + ItPsi.Set(ItPhi.Get() + ItPsi.Get()); } - } + this->RefineControlPointLattice(); - itkDebugMacro("Current Level = " << this->m_CurrentLevel); - itkDebugMacro(" Current number of control points = " << this->m_CurrentNumberOfControlPoints); - - RealType averageDifference = 0.0; - RealType totalWeight = 0.0; - - typename PointDataContainerType::Iterator ItIn = this->m_InputPointData->Begin(); - typename PointDataContainerType::Iterator ItOut = this->m_OutputPointData->Begin(); - while (ItIn != this->m_InputPointData->End()) - { - this->m_InputPointData->CastToSTLContainer()[ItIn.Index()] = ItIn.Value() - ItOut.Value(); - - if (this->GetDebug()) + for (unsigned int i = 0; i < ImageDimension; ++i) { - RealType weight = this->m_PointWeights->GetElement(ItIn.Index()); - averageDifference += (ItIn.Value() - ItOut.Value()).GetNorm() * weight; - totalWeight += weight; + if (this->m_CurrentLevel < this->m_NumberOfLevels[i]) + { + this->m_CurrentNumberOfControlPoints[i] = + 2 * this->m_CurrentNumberOfControlPoints[i] - this->m_SplineOrder[i]; + } } - ++ItIn; - ++ItOut; - } - if (totalWeight > 0) - { - itkDebugMacro("The average weighted difference norm of the point set is " << averageDifference / totalWeight); - } + itkDebugMacro("Current Level = " << this->m_CurrentLevel); + itkDebugMacro(" Current number of control points = " << this->m_CurrentNumberOfControlPoints); - // Multithread the generation of the control point lattice. - this->BeforeThreadedGenerateData(); - multiThreader->SingleMethodExecute(); - this->AfterThreadedGenerateData(); - - this->UpdatePointSet(); - } + // Multithread the generation of the control point lattice. + this->BeforeThreadedGenerateData(); + multiThreader->SingleMethodExecute(); + this->AfterThreadedGenerateData(); + } - if (this->m_DoMultilevel) - { ImageRegionIterator ItPsi(this->m_PsiLattice, this->m_PsiLattice->GetLargestPossibleRegion()); ImageRegionIterator ItPhi(this->m_PhiLattice, this->m_PhiLattice->GetLargestPossibleRegion()); for (ItPsi.GoToBegin(), ItPhi.GoToBegin(); !ItPsi.IsAtEnd(); ++ItPsi, ++ItPhi) @@ -346,12 +320,11 @@ BSplineScatteredDataPointSetToImageFilter::Generat duplicator->SetInputImage(this->m_PsiLattice); duplicator->Update(); this->m_PhiLattice = duplicator->GetOutput(); - - this->UpdatePointSet(); } this->m_IsFittingComplete = true; + // Multithread the reconstruction of the sampled B-spline object if (this->m_GenerateOutputImage) { // this->BeforeThreadedGenerateData(); @@ -424,7 +397,14 @@ BSplineScatteredDataPointSetToImageFilter::Threade { if (!this->m_IsFittingComplete) { - this->ThreadedGenerateDataForFitting(region, threadId); + if (this->m_DoUpdatePointSetValues) + { + this->ThreadedGenerateDataForUpdatePointSetValues(region, threadId); + } + else + { + this->ThreadedGenerateDataForFitting(region, threadId); + } } else { @@ -907,7 +887,9 @@ BSplineScatteredDataPointSetToImageFilter::RefineC template void -BSplineScatteredDataPointSetToImageFilter::UpdatePointSet() +BSplineScatteredDataPointSetToImageFilter::ThreadedGenerateDataForUpdatePointSetValues( + const RegionType & itkNotUsed(region), + ThreadIdType threadId) { const TInputPointSet * input = this->GetInput(); PointDataImagePointer collapsedPhiLattices[ImageDimension + 1]; @@ -960,14 +942,24 @@ BSplineScatteredDataPointSetToImageFilter::UpdateP typename PointDataImageType::IndexType startPhiIndex = this->m_PhiLattice->GetLargestPossibleRegion().GetIndex(); - this->m_OutputPointData->CastToSTLContainer().resize(this->m_InputPointData->Size()); - typename PointDataContainerType::ConstIterator ItIn = this->m_InputPointData->Begin(); - while (ItIn != this->m_InputPointData->End()) + // Determine which points should be handled by this particular thread. + + ThreadIdType numberOfWorkUnits = this->GetNumberOfWorkUnits(); + auto numberOfPointsPerThread = static_cast(input->GetNumberOfPoints() / numberOfWorkUnits); + + unsigned int start = threadId * numberOfPointsPerThread; + unsigned int end = start + numberOfPointsPerThread; + if (threadId == this->GetNumberOfWorkUnits() - 1) + { + end = input->GetNumberOfPoints(); + } + + for (unsigned int n = start; n < end; ++n) { PointType point; point.Fill(0.0); - input->GetPoint(ItIn.Index(), &point); + input->GetPoint(n, &point); for (unsigned int i = 0; i < ImageDimension; ++i) { @@ -1002,11 +994,11 @@ BSplineScatteredDataPointSetToImageFilter::UpdateP break; } } - this->m_OutputPointData->CastToSTLContainer()[ItIn.Index()] = collapsedPhiLattices[0]->GetPixel(startPhiIndex); - ++ItIn; + this->m_InputPointData->CastToSTLContainer()[n] -= collapsedPhiLattices[0]->GetPixel(startPhiIndex); } } + template void BSplineScatteredDataPointSetToImageFilter::CollapsePhiLattice( @@ -1156,7 +1148,6 @@ BSplineScatteredDataPointSetToImageFilter::PrintSe } itkPrintSelfObjectMacro(InputPointData); - itkPrintSelfObjectMacro(OutputPointData); os << indent << "Kernel: " << std::endl; for (unsigned int i = 0; i < ImageDimension; ++i)