From 02a3b468319c1c25f05e9f5fad8f2df94c9725fa Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 13 Nov 2020 10:55:26 +0100 Subject: [PATCH] ENH: A 3D tensor B-Spline approximator and extrapolator This PR finally adds an implementation for B-Spline smoothing and extrapolation of fieldmaps. References: #71, #22. Resolves: #72. Resolves: #14. --- sdcflows/interfaces/bspline.py | 188 +++++++++++++++++++++++++++++++++ setup.cfg | 1 + 2 files changed, 189 insertions(+) create mode 100644 sdcflows/interfaces/bspline.py diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py new file mode 100644 index 0000000000..bf2e344b27 --- /dev/null +++ b/sdcflows/interfaces/bspline.py @@ -0,0 +1,188 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +""" +B-Spline filtering. + + .. testsetup:: + + >>> tmpdir = getfixture('tmpdir') + >>> tmp = tmpdir.chdir() # changing to a temporary directory + >>> nb.Nifti1Image(np.zeros((90, 90, 60)), None, None).to_filename( + ... tmpdir.join('epi.nii.gz').strpath) + +""" +from pathlib import Path +import numpy as np +import nibabel as nb + +from nipype.utils.filemanip import fname_presuffix +from nipype.interfaces.base import ( + BaseInterfaceInputSpec, + TraitedSpec, + File, + traits, + SimpleInterface, + InputMultiObject, + OutputMultiObject, +) + + +DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm +DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm +DEFAULT_HF_ZOOMS_MM = (16.0, 16.0, 10.0) # For human adults (high-frequency), in mm + + +class _BSplineApproxInputSpec(BaseInterfaceInputSpec): + in_data = File(exists=True, mandatory=True, desc="path to a fieldmap") + in_mask = File(exists=True, mandatory=True, desc="path to a brain mask") + bs_spacing = InputMultiObject( + [DEFAULT_ZOOMS_MM], + traits.Tuple(traits.Float, traits.Float, traits.Float), + usedefault=True, + desc="spacing between B-Spline control points", + ) + ridge_alpha = traits.Float( + 1e-4, usedefault=True, desc="controls the regularization" + ) + + +class _BSplineApproxOutputSpec(TraitedSpec): + out_field = File(exists=True) + out_coeff = OutputMultiObject(File(exists=True)) + + +class BSplineApprox(SimpleInterface): + """ + Approximate the field to smooth it removing spikes and extrapolating beyond the brain mask. + + Examples + -------- + + """ + + input_spec = _BSplineApproxInputSpec + output_spec = _BSplineApproxOutputSpec + + def _run_interface(self, runtime): + from gridbspline.maths import cubic + from sklearn import linear_model as lm + + _vbspl = np.vectorize(cubic) + + # Load in the fieldmap + fmapnii = nb.load(self.inputs.in_data) + data = fmapnii.get_fdata() + mask = nb.load(self.inputs.in_mask).get_fdata() > 0 + bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing] + + # Calculate B-Splines grid(s) + bs_levels = [] + for sp in bs_spacing: + bs_levels.append(bspline_grid(fmapnii, control_zooms_mm=sp)) + + # Calculate spatial location of voxels, and normalize per B-Spline grid + fmap_points = grid_coords(fmapnii) + sample_points = [] + for sp in bs_spacing: + sample_points.append((fmap_points / sp).astype("float32")) + + # Calculate the spatial location of control points + bs_x = [] + ncoeff = [] + for sp, level, points in zip(bs_spacing, bs_levels, sample_points): + ncoeff.append(level.dataobj.size) + control_points = grid_coords(level, control_zooms_mm=sp) + bs_x.append(control_points[:, np.newaxis, :] - points[np.newaxis, ...]) + + # Calculate the cubic spline weights per dimension and tensor-product + dist = np.vstack(bs_x) + dist_support = (np.abs(dist) < 2).all(axis=-1) + weights = _vbspl(dist[dist_support]).prod(axis=-1) + + # Compose the interpolation matrix + interp_mat = np.zeros(dist.shape[:2]) + interp_mat[dist_support] = weights + + # Fit the model + model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False) + model.fit( + interp_mat[..., mask.reshape(-1)].T, # Regress only within brainmask + data[mask], + ) + + # Store outputs + out_name = str( + Path( + fname_presuffix( + self.inputs.in_data, suffix="_field", newpath=runtime.cwd + ) + ).absolute() + ) + hdr = fmapnii.header.copy() + hdr.set_data_dtype("float32") + nb.Nifti1Image( + (model.intercept_ + np.array(model.coef_) @ interp_mat) + .astype("float32") # Interpolation + .reshape(data.shape), + fmapnii.affine, + hdr, + ).to_filename(out_name) + self._results["out_field"] = out_name + + index = 0 + self._results["out_coeff"] = [] + for i, (n, bsl) in enumerate(zip(ncoeff, bs_levels)): + out_level = out_name.replace("_field.", f"_coeff{i:03}.") + nb.Nifti1Image( + np.array(model.coef_, dtype="float32")[index : index + n].reshape( + bsl.shape + ), + bsl.affine, + bsl.header, + ).to_filename(out_level) + index += n + self._results["out_coeff"].append(out_level) + return runtime + + +def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM): + """Calculate a Nifti1Image object encoding the location of control points.""" + if isinstance(img, (str, Path)): + img = nb.load(img) + + im_zooms = np.array(img.header.get_zooms()) + im_shape = np.array(img.shape[:3]) + + # Calculate the direction cosines of the target image + dir_cos = img.affine[:3, :3] / im_zooms + + # Initialize the affine of the B-Spline grid + bs_affine = np.diag(np.hstack((np.array(control_zooms_mm) @ dir_cos, 1))) + bs_zooms = nb.affines.voxel_sizes(bs_affine) + + # Calculate the shape of the B-Spline grid + im_extent = im_zooms * (im_shape - 1) + bs_shape = (im_extent // bs_zooms + 3).astype(int) + + # Center both images + im_center = img.affine @ np.hstack((0.5 * (im_shape - 1), 1)) + bs_center = bs_affine @ np.hstack((0.5 * (bs_shape - 1), 1)) + bs_affine[:3, 3] = im_center[:3] - bs_center[:3] + + return nb.Nifti1Image(np.zeros(bs_shape, dtype="float32"), bs_affine) + + +def grid_coords(img, control_zooms_mm=None, dtype="float32"): + """Create a linear space of physical coordinates.""" + if isinstance(img, (str, Path)): + img = nb.load(img) + + grid = np.array( + np.meshgrid(*[range(s) for s in img.shape[:3]]), dtype=dtype + ).reshape(3, -1) + coords = (img.affine @ np.vstack((grid, np.ones(grid.shape[-1])))).T[..., :3] + + if control_zooms_mm is not None: + coords /= np.array(control_zooms_mm) + + return coords.astype(dtype) diff --git a/setup.cfg b/setup.cfg index f7ea5a792e..91f9e96443 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ setup_requires = setuptools_scm >= 3.4 toml install_requires = + gridbspline nibabel >=3.0.1 niflow-nipype1-workflows ~= 0.0.1 nipype >=1.3.1,<2.0