Skip to content

Commit

Permalink
Merge pull request #57 from oesteban/enh/plotting-transparent
Browse files Browse the repository at this point in the history
ENH: Transparency on fieldmap plots!
  • Loading branch information
oesteban authored Nov 22, 2019
2 parents 065a2c4 + 930bf35 commit 4ac16d5
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 20 deletions.
56 changes: 43 additions & 13 deletions sdcflows/interfaces/reportlets.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Interfaces to generate speciality reportlets."""
import numpy as np
from nilearn.image import threshold_img, load_img
from niworkflows import NIWORKFLOWS_LOG
from niworkflows.viz.utils import cuts_from_bbox, compose_view
from nipype.interfaces.base import File, isdefined
from nipype.interfaces.base import File, isdefined, traits
from nipype.interfaces.mixins import reporting

from ..viz.utils import plot_registration, coolwarm_transparent


class _FieldmapReportletInputSpec(reporting.ReportCapableInputSpec):
reference = File(exists=True, mandatory=True, desc='input reference')
moving = File(exists=True, desc='input moving')
fieldmap = File(exists=True, mandatory=True, desc='input fieldmap')
max_alpha = traits.Float(0.7, usedefault=True, desc='maximum alpha channel')
mask = File(exists=True, desc='brain mask')
out_report = File('report.svg', usedefault=True,
desc='filename for the visual report')
show = traits.Enum(1, 0, 'both', usedefault=True,
desc='where the fieldmap should be shown')
reference_label = traits.Str('Reference', usedefault=True,
desc='a label name for the reference mosaic')
moving_label = traits.Str('Fieldmap (Hz)', usedefault=True,
desc='a label name for the reference mosaic')


class FieldmapReportlet(reporting.ReportCapableInterface):
Expand All @@ -37,30 +46,51 @@ def _generate_report(self):
"""Generate a reportlet."""
NIWORKFLOWS_LOG.info('Generating visual report')

refnii = load_img(self.inputs.reference)
movnii = refnii = load_img(self.inputs.reference)
fmapnii = load_img(self.inputs.fieldmap)
contour_nii = load_img(self.inputs.mask) if isdefined(self.inputs.mask) else None
mask_nii = threshold_img(refnii, 1e-3)

if isdefined(self.inputs.moving):
movnii = load_img(self.inputs.moving)

contour_nii = mask_nii = None
if isdefined(self.inputs.mask):
contour_nii = load_img(self.inputs.mask)
maskdata = contour_nii.get_fdata() > 0
else:
mask_nii = threshold_img(refnii, 1e-3)
maskdata = mask_nii.get_fdata() > 0
cuts = cuts_from_bbox(contour_nii or mask_nii, cuts=self._n_cuts)
fmapdata = fmapnii.get_fdata()
vmax = max(fmapdata.max(), abs(fmapdata.min()))
vmax = max(abs(np.percentile(fmapdata[maskdata], 99.8)),
abs(np.percentile(fmapdata[maskdata], 0.2)))

fmap_overlay = [{
'overlay': fmapnii,
'overlay_params': {
'cmap': coolwarm_transparent(max_alpha=self.inputs.max_alpha),
'vmax': vmax,
'vmin': -vmax,
}
}] * 2

if self.inputs.show != 'both':
fmap_overlay[not self.inputs.show] = {}

# Call composer
compose_view(
plot_registration(refnii, 'fixed-image',
plot_registration(movnii, 'moving-image',
estimate_brightness=True,
cuts=cuts,
label='reference',
label=self.inputs.moving_label,
contour=contour_nii,
compress=False),
plot_registration(fmapnii, 'moving-image',
compress=False,
**fmap_overlay[1]),
plot_registration(refnii, 'fixed-image',
estimate_brightness=True,
cuts=cuts,
label='fieldmap (Hz)',
label=self.inputs.reference_label,
contour=contour_nii,
compress=False,
plot_params={'cmap': coolwarm_transparent(),
'vmax': vmax,
'vmin': -vmax}),
**fmap_overlay[0]),
out_file=self._out_report
)
26 changes: 21 additions & 5 deletions sdcflows/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def plot_registration(anat_nii, div_id, plot_params=None,
order=('z', 'x', 'y'), cuts=None,
estimate_brightness=False, label=None, contour=None,
compress='auto'):
compress='auto', overlay=None, overlay_params=None):
"""
Plot the foreground and background views.
Expand All @@ -15,6 +15,7 @@ def plot_registration(anat_nii, div_id, plot_params=None,
from uuid import uuid4

from lxml import etree
import matplotlib.pyplot as plt
from nilearn.plotting import plot_anat
from svgutils.transform import SVGFigure
from niworkflows.viz.utils import robust_set_limits, extract_svg, SVGNS
Expand All @@ -41,6 +42,15 @@ def plot_registration(anat_nii, div_id, plot_params=None,

# Generate nilearn figure
display = plot_anat(anat_nii, **plot_params)
if overlay is not None:
_overlay_params = {
'vmin': overlay.get_fdata().min(),
'vmax': overlay.get_fdata().max(),
'cmap': plt.cm.gray,
'interpolation': 'nearest',
}
_overlay_params.update(overlay_params)
display.add_overlay(overlay, **_overlay_params)
if contour is not None:
display.add_contours(contour, colors='g', levels=[0.5],
linewidths=0.5)
Expand All @@ -60,7 +70,7 @@ def plot_registration(anat_nii, div_id, plot_params=None,
return out_files


def coolwarm_transparent():
def coolwarm_transparent(max_alpha=0.7, opaque_perc=30, transparent_perc=8):
"""Modify the coolwarm color scale to have full transparency around the middle."""
import numpy as np
import matplotlib.pylab as pl
Expand All @@ -72,9 +82,15 @@ def coolwarm_transparent():
# Get the colormap colors
my_cmap = cmap(np.arange(cmap.N))

_20perc = (cmap.N * opaque_perc) // 100
midpoint = cmap.N // 2 + 1
_10perc = (cmap.N * transparent_perc) // 100
# Set alpha
alpha = np.ones(cmap.N)
alpha[128:160] = np.linspace(0, 1, len(alpha[128:160]))
alpha[96:128] = np.linspace(1, 0, len(alpha[96:128]))
alpha = np.ones(cmap.N) * max_alpha
alpha[midpoint - _10perc:midpoint + _10perc] = 0
alpha[_20perc:midpoint - _10perc - 1] = np.linspace(
max_alpha, 0, len(alpha[_20perc:midpoint - _10perc - 1]))
alpha[midpoint + _10perc:-_20perc] = np.linspace(
0, max_alpha, len(alpha[midpoint + _10perc:-_20perc]))
my_cmap[:, -1] = alpha
return ListedColormap(my_cmap)
2 changes: 1 addition & 1 deletion sdcflows/workflows/tests/test_pepolar.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_pepolar_wf1(bids_layouts, output_path, dataset, workdir):
wf.inputs.inputnode.in_reference_brain = boldref.path
wf.inputs.inputnode.in_reference = boldref.path

rep = pe.Node(FieldmapReportlet(), 'simple_report')
rep = pe.Node(FieldmapReportlet(reference_label='EPI Reference'), 'simple_report')
rep.interface._always_run = True
dsink = pe.Node(DerivativesDataSink(
base_directory=str(output_path), keep_dtype=True,
Expand Down
2 changes: 1 addition & 1 deletion sdcflows/workflows/tests/test_phdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_workflow(bids_layouts, tmpdir, output_path, dataset, workdir):

if output_path:
from ...interfaces.reportlets import FieldmapReportlet
rep = pe.Node(FieldmapReportlet(), 'simple_report')
rep = pe.Node(FieldmapReportlet(reference_label='Magnitude'), 'simple_report')
rep.interface._always_run = True

dsink = pe.Node(DerivativesDataSink(
Expand Down

0 comments on commit 4ac16d5

Please sign in to comment.