diff --git a/sdcflows/interfaces/reportlets.py b/sdcflows/interfaces/reportlets.py index 90b7b99fa4..caeabc355b 100644 --- a/sdcflows/interfaces/reportlets.py +++ b/sdcflows/interfaces/reportlets.py @@ -1,10 +1,11 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Interfaces to generate speciality reportlets.""" +import numpy as np from nilearn.image import threshold_img, load_img from niworkflows import NIWORKFLOWS_LOG from niworkflows.viz.utils import cuts_from_bbox, compose_view -from nipype.interfaces.base import File, isdefined +from nipype.interfaces.base import File, isdefined, traits from nipype.interfaces.mixins import reporting from ..viz.utils import plot_registration, coolwarm_transparent @@ -12,10 +13,14 @@ class _FieldmapReportletInputSpec(reporting.ReportCapableInputSpec): reference = File(exists=True, mandatory=True, desc='input reference') + moving = File(exists=True, desc='input moving') fieldmap = File(exists=True, mandatory=True, desc='input fieldmap') + max_alpha = traits.Float(0.7, usedefault=True, desc='maximum alpha channel') mask = File(exists=True, desc='brain mask') out_report = File('report.svg', usedefault=True, desc='filename for the visual report') + show = traits.Enum(1, 0, 'both', usedefault=True, + desc='where the fieldmap should be shown') class FieldmapReportlet(reporting.ReportCapableInterface): @@ -39,28 +44,46 @@ def _generate_report(self): refnii = load_img(self.inputs.reference) fmapnii = load_img(self.inputs.fieldmap) - contour_nii = load_img(self.inputs.mask) if isdefined(self.inputs.mask) else None - mask_nii = threshold_img(refnii, 1e-3) + + contour_nii = mask_nii = None + if isdefined(self.inputs.mask): + contour_nii = load_img(self.inputs.mask) + maskdata = contour_nii.get_fdata() > 0 + else: + mask_nii = threshold_img(refnii, 1e-3) + maskdata = mask_nii.get_fdata() > 0 cuts = cuts_from_bbox(contour_nii or mask_nii, cuts=self._n_cuts) fmapdata = fmapnii.get_fdata() - vmax = max(fmapdata.max(), abs(fmapdata.min())) + vmax = max(abs(np.percentile(fmapdata[maskdata], 99.8)), + abs(np.percentile(fmapdata[maskdata], 0.2))) + + fmap_overlay = [{ + 'overlay': fmapnii, + 'overlay_params': { + 'cmap': coolwarm_transparent(max_alpha=self.inputs.max_alpha), + 'vmax': vmax, + 'vmin': -vmax, + } + }] * 2 + + if self.inputs.show != 'both': + fmap_overlay[not self.inputs.show] = {} # Call composer compose_view( - plot_registration(refnii, 'fixed-image', + plot_registration(refnii, 'moving-image', estimate_brightness=True, cuts=cuts, - label='reference', + label='fieldmap (Hz)', contour=contour_nii, - compress=False), - plot_registration(fmapnii, 'moving-image', + compress=False, + **fmap_overlay[1]), + plot_registration(refnii, 'fixed-image', estimate_brightness=True, cuts=cuts, - label='fieldmap (Hz)', + label='reference', contour=contour_nii, compress=False, - plot_params={'cmap': coolwarm_transparent(), - 'vmax': vmax, - 'vmin': -vmax}), + **fmap_overlay[0]), out_file=self._out_report ) diff --git a/sdcflows/viz/utils.py b/sdcflows/viz/utils.py index d17cb38e36..af8cffb648 100644 --- a/sdcflows/viz/utils.py +++ b/sdcflows/viz/utils.py @@ -6,7 +6,7 @@ def plot_registration(anat_nii, div_id, plot_params=None, order=('z', 'x', 'y'), cuts=None, estimate_brightness=False, label=None, contour=None, - compress='auto'): + compress='auto', overlay=None, overlay_params=None): """ Plot the foreground and background views. @@ -15,6 +15,7 @@ def plot_registration(anat_nii, div_id, plot_params=None, from uuid import uuid4 from lxml import etree + import matplotlib.pyplot as plt from nilearn.plotting import plot_anat from svgutils.transform import SVGFigure from niworkflows.viz.utils import robust_set_limits, extract_svg, SVGNS @@ -41,6 +42,15 @@ def plot_registration(anat_nii, div_id, plot_params=None, # Generate nilearn figure display = plot_anat(anat_nii, **plot_params) + if overlay is not None: + _overlay_params = { + 'vmin': overlay.get_fdata().min(), + 'vmax': overlay.get_fdata().max(), + 'cmap': plt.cm.gray, + 'interpolation': 'nearest', + } + _overlay_params.update(overlay_params) + display.add_overlay(overlay, **_overlay_params) if contour is not None: display.add_contours(contour, colors='g', levels=[0.5], linewidths=0.5) @@ -60,7 +70,7 @@ def plot_registration(anat_nii, div_id, plot_params=None, return out_files -def coolwarm_transparent(): +def coolwarm_transparent(max_alpha=0.7, opaque_perc=30, transparent_perc=8): """Modify the coolwarm color scale to have full transparency around the middle.""" import numpy as np import matplotlib.pylab as pl @@ -72,9 +82,15 @@ def coolwarm_transparent(): # Get the colormap colors my_cmap = cmap(np.arange(cmap.N)) + _20perc = (cmap.N * opaque_perc) // 100 + midpoint = cmap.N // 2 + 1 + _10perc = (cmap.N * transparent_perc) // 100 # Set alpha - alpha = np.ones(cmap.N) - alpha[128:160] = np.linspace(0, 1, len(alpha[128:160])) - alpha[96:128] = np.linspace(1, 0, len(alpha[96:128])) + alpha = np.ones(cmap.N) * max_alpha + alpha[midpoint - _10perc:midpoint + _10perc] = 0 + alpha[_20perc:midpoint - _10perc - 1] = np.linspace( + max_alpha, 0, len(alpha[_20perc:midpoint - _10perc - 1])) + alpha[midpoint + _10perc:-_20perc] = np.linspace( + 0, max_alpha, len(alpha[midpoint + _10perc:-_20perc])) my_cmap[:, -1] = alpha return ListedColormap(my_cmap)