Skip to content

Commit

Permalink
ENH: Add masked/bspline fitting variant of Nyul histogram matching.
Browse files Browse the repository at this point in the history
  • Loading branch information
ntustison committed Apr 12, 2024
1 parent f79e9f3 commit d20897f
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ants/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .get_mask import get_mask
from .get_neighborhood import (get_neighborhood_in_mask,
get_neighborhood_at_voxel)
from .histogram_match_image import histogram_match_image
from .histogram_match_image import histogram_match_image, histogram_match_image2
from .histogram_equalize_image import histogram_equalize_image
from .hausdorff_distance import hausdorff_distance
from .image_similarity import image_similarity
Expand Down
108 changes: 105 additions & 3 deletions ants/utils/histogram_match_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@

__all__ = ['histogram_match_image']
__all__ = ['histogram_match_image',
'histogram_match_image2']

import math
import numpy as np

from ..core import ants_image as iio
from ..core import ants_image_io as iio
from .. import utils

from ..utils import fit_bspline_object_to_scattered_data


def histogram_match_image(source_image, reference_image, number_of_histogram_bins=255, number_of_match_points=64, use_threshold_at_mean_intensity=False):
"""
Expand Down Expand Up @@ -51,3 +54,102 @@ def histogram_match_image(source_image, reference_image, number_of_histogram_bin
return new_image


def histogram_match_image2(source_image, reference_image,
source_mask=None, reference_mask=None,
match_points=64,
transform_domain_size=255):
"""
Transform image intensities based on histogram mapping.
Apply B-spline 1-D maps to an input image for intensity warping.
Arguments
---------
source_image : ANTsImage
source image
reference_image : ANTsImage
reference image
source_mask : ANTsImage
source mask
reference_mask : ANTsImage
reference mask
match_points : integer or tuple
Parametric points at which the intensity transform displacements are
specified between [0, 1], i.e. quantiles. Alternatively, a single number
can be given and the sequence is linearly spaced in [0, 1].
transform_domain_size : integer
Defines the sampling resolution of the B-spline warping.
Returns
-------
ANTs image
Example
-------
>>> import ants
>>> src_img = ants.image_read(ants.get_data('r16'))
>>> ref_img = ants.image_read(ants.get_data('r64'))
>>> src_ref = ants.histogram_match_image(src_img, ref_img)
"""

if not isinstance(match_points, int):
if any(b < 0 for b in match_points) and any(b > 1 for b in match_points):
raise ValueError("If specifying match_points as a vector, values must be in the range [0, 1]")

# Use entire image if mask isn't specified
if source_mask is None:
source_mask = source_image * 0 + 1
if reference_mask is None:
reference_mask = reference_image * 0 + 1

source_array = source_image.numpy()
source_mask_array = source_mask.numpy()
source_masked_min = source_image[source_mask != 0].min()
source_masked_max = source_image[source_mask != 0].max()

reference_array = reference_image.numpy()
reference_mask_array = reference_mask.numpy()

parametric_points = None
if not isinstance(match_points, int):
parametric_points = match_points
else:
parametric_points = np.linspace(0, 1, match_points)

source_intensity_quantiles = np.quantile(source_array[source_mask_array != 0], parametric_points)
reference_intensity_quantiles = np.quantile(reference_array[reference_mask_array != 0], parametric_points)
displacements = reference_intensity_quantiles - source_intensity_quantiles

scattered_data = np.reshape(displacements, (len(displacements), 1))
parametric_data = np.reshape(parametric_points * (source_masked_max - source_masked_min) + source_masked_min, (len(parametric_points), 1))

transform_domain_origin = source_masked_min
transform_domain_spacing = (source_masked_max - transform_domain_origin) / (transform_domain_size - 1)

bspline_histogram_transform = fit_bspline_object_to_scattered_data(scattered_data,
parametric_data, [transform_domain_origin], [transform_domain_spacing], [transform_domain_size],
data_weights=None, is_parametric_dimension_closed=None, number_of_fitting_levels=8,
mesh_size=1, spline_order=3)

transform_domain = np.linspace(source_masked_min, source_masked_max, transform_domain_size)

transformed_source_array = source_image.numpy()
for i in range(len(transform_domain) - 1):
indices = np.where((source_array >= transform_domain[i]) & (source_array < transform_domain[i+1]))
intensities = source_array[indices]

alpha = (intensities - transform_domain[i])/(transform_domain[i+1] - transform_domain[i])
xfrm = alpha * (bspline_histogram_transform[i+1] - bspline_histogram_transform[i]) + bspline_histogram_transform[i]
transformed_source_array[indices] = intensities + xfrm

transformed_source_image = iio.from_numpy(transformed_source_array, origin=source_image.origin,
spacing=source_image.spacing, direction=source_image.direction)
transformed_source_image[source_mask == 0] = source_image[source_mask == 0]

return(transformed_source_image)

0 comments on commit d20897f

Please sign in to comment.