Skip to content

Commit

Permalink
PERF: Remove unnecessary computations and parallelize function.
Browse files Browse the repository at this point in the history
  • Loading branch information
ntustison committed Oct 9, 2022
1 parent 76dd332 commit 0ceb091
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter
using PointSetPointer = typename PointSetType::Pointer;
using PointDataType = typename PointSetType::PixelType;
using PointDataContainerType = typename PointSetType::PointDataContainer;
using PointDataContainerPointer = typename PointDataContainerType::Pointer;

/** Other type alias. */
using RealType = float;
Expand Down Expand Up @@ -318,10 +319,6 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter
void
RefineControlPointLattice();

/** Determine the residuals after fitting to one level. */
void
UpdatePointSet();

/** This function is not used as it requires an evaluation of all
* (SplineOrder+1)^ImageDimensions B-spline weights for each evaluation. */
void
Expand All @@ -335,6 +332,10 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter
void
ThreadedGenerateDataForReconstruction(const RegionType &, ThreadIdType);

/** Update the input point set values with the residuals after fitting to a level. */
void
ThreadedGenerateDataForUpdatePointSetValues(const RegionType &, ThreadIdType);

/** Sub-function used by GenerateOutputImageFast() to generate the sampled
* B-spline object quickly. */
void
Expand Down Expand Up @@ -368,8 +369,7 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter

vnl_matrix<RealType> m_RefinedLatticeCoefficients[ImageDimension];

typename PointDataContainerType::Pointer m_InputPointData;
typename PointDataContainerType::Pointer m_OutputPointData;
PointDataContainerPointer m_InputPointData;

typename KernelType::Pointer m_Kernel[ImageDimension];

Expand All @@ -383,6 +383,7 @@ class ITK_TEMPLATE_EXPORT BSplineScatteredDataPointSetToImageFilter

RealType m_BSplineEpsilon{ static_cast<RealType>(1e-3) };
bool m_IsFittingComplete{ false };
bool m_DoUpdatePointSetValues{ false };
};
} // end namespace itk

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::BSpline
}

this->m_InputPointData = PointDataContainerType::New();
this->m_OutputPointData = PointDataContainerType::New();

this->m_PointWeights = WeightsContainerType::New();
}
Expand Down Expand Up @@ -237,7 +236,6 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::Generat
}

this->m_InputPointData->Initialize();
this->m_OutputPointData->Initialize();
if (inputPointSet->GetNumberOfPoints() > 0)
{
const auto & pointData = inputPointSet->GetPointData()->CastToSTLConstContainer();
Expand All @@ -247,7 +245,6 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::Generat
m_PointWeights->CastToSTLContainer().assign(pointData.size(), 1);
}
m_InputPointData->CastToSTLContainer() = pointData;
m_OutputPointData->CastToSTLContainer() = pointData;
}

this->m_CurrentLevel = 0;
Expand All @@ -264,76 +261,53 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::Generat
multiThreader->SetNumberOfWorkUnits(this->GetNumberOfWorkUnits());
multiThreader->SetSingleMethod(this->ThreaderCallback, &str1);

// Multithread the generation of the control point lattice.
// Multithread the generation of the control point lattice for the first level.
this->BeforeThreadedGenerateData();
multiThreader->SingleMethodExecute();
this->AfterThreadedGenerateData();

this->UpdatePointSet();

if (this->m_DoMultilevel)
{
this->m_PsiLattice->SetRegions(this->m_PhiLattice->GetLargestPossibleRegion());
this->m_PsiLattice->Allocate();
PointDataType P{};
this->m_PsiLattice->FillBuffer(P);
}

for (this->m_CurrentLevel = 1; this->m_CurrentLevel < this->m_MaximumNumberOfLevels; this->m_CurrentLevel++)
{
ImageRegionIterator<PointDataImageType> ItPsi(this->m_PsiLattice, this->m_PsiLattice->GetLargestPossibleRegion());
ImageRegionIterator<PointDataImageType> ItPhi(this->m_PhiLattice, this->m_PhiLattice->GetLargestPossibleRegion());
for (ItPsi.GoToBegin(), ItPhi.GoToBegin(); !ItPsi.IsAtEnd(); ++ItPsi, ++ItPhi)
for (this->m_CurrentLevel = 1; this->m_CurrentLevel < this->m_MaximumNumberOfLevels; this->m_CurrentLevel++)
{
ItPsi.Set(ItPhi.Get() + ItPsi.Get());
}
this->RefineControlPointLattice();

for (unsigned int i = 0; i < ImageDimension; ++i)
{
if (this->m_CurrentLevel < this->m_NumberOfLevels[i])
// Multithread updating the point set values
this->m_DoUpdatePointSetValues = true;
// this->BeforeThreadedGenerateData();
multiThreader->SingleMethodExecute();
// this->AfterThreadedGenerateData();
this->m_DoUpdatePointSetValues = false;

ImageRegionIterator<PointDataImageType> ItPsi(this->m_PsiLattice, this->m_PsiLattice->GetLargestPossibleRegion());
ImageRegionIterator<PointDataImageType> ItPhi(this->m_PhiLattice, this->m_PhiLattice->GetLargestPossibleRegion());
for (ItPsi.GoToBegin(), ItPhi.GoToBegin(); !ItPsi.IsAtEnd(); ++ItPsi, ++ItPhi)
{
this->m_CurrentNumberOfControlPoints[i] = 2 * this->m_CurrentNumberOfControlPoints[i] - this->m_SplineOrder[i];
ItPsi.Set(ItPhi.Get() + ItPsi.Get());
}
}
this->RefineControlPointLattice();

itkDebugMacro("Current Level = " << this->m_CurrentLevel);
itkDebugMacro(" Current number of control points = " << this->m_CurrentNumberOfControlPoints);

RealType averageDifference = 0.0;
RealType totalWeight = 0.0;

typename PointDataContainerType::Iterator ItIn = this->m_InputPointData->Begin();
typename PointDataContainerType::Iterator ItOut = this->m_OutputPointData->Begin();
while (ItIn != this->m_InputPointData->End())
{
this->m_InputPointData->CastToSTLContainer()[ItIn.Index()] = ItIn.Value() - ItOut.Value();

if (this->GetDebug())
for (unsigned int i = 0; i < ImageDimension; ++i)
{
RealType weight = this->m_PointWeights->GetElement(ItIn.Index());
averageDifference += (ItIn.Value() - ItOut.Value()).GetNorm() * weight;
totalWeight += weight;
if (this->m_CurrentLevel < this->m_NumberOfLevels[i])
{
this->m_CurrentNumberOfControlPoints[i] =
2 * this->m_CurrentNumberOfControlPoints[i] - this->m_SplineOrder[i];
}
}

++ItIn;
++ItOut;
}
if (totalWeight > 0)
{
itkDebugMacro("The average weighted difference norm of the point set is " << averageDifference / totalWeight);
}
itkDebugMacro("Current Level = " << this->m_CurrentLevel);
itkDebugMacro(" Current number of control points = " << this->m_CurrentNumberOfControlPoints);

// Multithread the generation of the control point lattice.
this->BeforeThreadedGenerateData();
multiThreader->SingleMethodExecute();
this->AfterThreadedGenerateData();

this->UpdatePointSet();
}
// Multithread the generation of the control point lattice.
this->BeforeThreadedGenerateData();
multiThreader->SingleMethodExecute();
this->AfterThreadedGenerateData();
}

if (this->m_DoMultilevel)
{
ImageRegionIterator<PointDataImageType> ItPsi(this->m_PsiLattice, this->m_PsiLattice->GetLargestPossibleRegion());
ImageRegionIterator<PointDataImageType> ItPhi(this->m_PhiLattice, this->m_PhiLattice->GetLargestPossibleRegion());
for (ItPsi.GoToBegin(), ItPhi.GoToBegin(); !ItPsi.IsAtEnd(); ++ItPsi, ++ItPhi)
Expand All @@ -346,12 +320,11 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::Generat
duplicator->SetInputImage(this->m_PsiLattice);
duplicator->Update();
this->m_PhiLattice = duplicator->GetOutput();

this->UpdatePointSet();
}

this->m_IsFittingComplete = true;

// Multithread the reconstruction of the sampled B-spline object
if (this->m_GenerateOutputImage)
{
// this->BeforeThreadedGenerateData();
Expand Down Expand Up @@ -424,7 +397,14 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::Threade
{
if (!this->m_IsFittingComplete)
{
this->ThreadedGenerateDataForFitting(region, threadId);
if (this->m_DoUpdatePointSetValues)
{
this->ThreadedGenerateDataForUpdatePointSetValues(region, threadId);
}
else
{
this->ThreadedGenerateDataForFitting(region, threadId);
}
}
else
{
Expand Down Expand Up @@ -907,7 +887,9 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::RefineC

template <typename TInputPointSet, typename TOutputImage>
void
BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::UpdatePointSet()
BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::ThreadedGenerateDataForUpdatePointSetValues(
const RegionType & itkNotUsed(region),
ThreadIdType threadId)
{
const TInputPointSet * input = this->GetInput();
PointDataImagePointer collapsedPhiLattices[ImageDimension + 1];
Expand Down Expand Up @@ -960,14 +942,24 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::UpdateP

typename PointDataImageType::IndexType startPhiIndex = this->m_PhiLattice->GetLargestPossibleRegion().GetIndex();

this->m_OutputPointData->CastToSTLContainer().resize(this->m_InputPointData->Size());
typename PointDataContainerType::ConstIterator ItIn = this->m_InputPointData->Begin();
while (ItIn != this->m_InputPointData->End())
// Determine which points should be handled by this particular thread.

ThreadIdType numberOfWorkUnits = this->GetNumberOfWorkUnits();
auto numberOfPointsPerThread = static_cast<SizeValueType>(input->GetNumberOfPoints() / numberOfWorkUnits);

unsigned int start = threadId * numberOfPointsPerThread;
unsigned int end = start + numberOfPointsPerThread;
if (threadId == this->GetNumberOfWorkUnits() - 1)
{
end = input->GetNumberOfPoints();
}

for (unsigned int n = start; n < end; ++n)
{
PointType point;
point.Fill(0.0);

input->GetPoint(ItIn.Index(), &point);
input->GetPoint(n, &point);

for (unsigned int i = 0; i < ImageDimension; ++i)
{
Expand Down Expand Up @@ -1002,11 +994,11 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::UpdateP
break;
}
}
this->m_OutputPointData->CastToSTLContainer()[ItIn.Index()] = collapsedPhiLattices[0]->GetPixel(startPhiIndex);
++ItIn;
this->m_InputPointData->CastToSTLContainer()[n] -= collapsedPhiLattices[0]->GetPixel(startPhiIndex);
}
}


template <typename TInputPointSet, typename TOutputImage>
void
BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::CollapsePhiLattice(
Expand Down Expand Up @@ -1156,7 +1148,6 @@ BSplineScatteredDataPointSetToImageFilter<TInputPointSet, TOutputImage>::PrintSe
}

itkPrintSelfObjectMacro(InputPointData);
itkPrintSelfObjectMacro(OutputPointData);

os << indent << "Kernel: " << std::endl;
for (unsigned int i = 0; i < ImageDimension; ++i)
Expand Down

0 comments on commit 0ceb091

Please sign in to comment.