Skip to content

Commit

Permalink
enh: refactor the distortion estimation workflow
Browse files Browse the repository at this point in the history
Close #16.
  • Loading branch information
oesteban committed Nov 20, 2019
1 parent 9c520a6 commit c2f48fd
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 195 deletions.
248 changes: 122 additions & 126 deletions sdcflows/workflows/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""SDC workflows coordination."""
from collections import defaultdict

from nipype.pipeline import engine as pe
from nipype.interfaces import utility as niu
from nipype import logging

from niworkflows.engine.workflows import LiterateWorkflow as Workflow

# Fieldmap workflows
from .pepolar import init_pepolar_unwarp_wf

LOGGER = logging.getLogger('nipype.workflow')
FMAP_PRIORITY = {
Expand All @@ -23,12 +19,12 @@
DEFAULT_MEMORY_MIN_GB = 0.01


def init_sdc_wf(distorted_ref, omp_nthreads=1, debug=False, ignore=None):
def init_sdc_estimate_wf(fmaps, epi_meta, omp_nthreads=1, debug=False, ignore=None):
"""
Build a :abbr:`SDC (susceptibility distortion correction)` workflow.
This workflow implements the heuristics to choose a
:abbr:`SDC (susceptibility distortion correction)` strategy.
This workflow implements the heuristics to choose an estimation
methodology for :abbr:`SDC (susceptibility distortion correction)`.
When no field map information is present within the BIDS inputs,
the EXPERIMENTAL "fieldmap-less SyN" can be performed, using
the ``--use-syn`` argument. When ``--force-syn`` is specified,
Expand All @@ -41,41 +37,40 @@ def init_sdc_wf(distorted_ref, omp_nthreads=1, debug=False, ignore=None):
Parameters
----------
distorted_ref : pybids.BIDSFile
A BIDSFile object with suffix ``bold``, ``sbref`` or ``dwi``.
fmaps : list of pybids dicts
A list of dictionaries with the available fieldmaps
(and their metadata using the key ``'metadata'`` for the
case of :abbr:`PEPOLAR (Phase-Encoding POLARity)` fieldmaps).
epi_meta : dict
BIDS metadata dictionary corresponding to the
:abbr:`EPI (echo-planar imaging)` run (i.e., suffix ``bold``,
``sbref``, or ``dwi``) for which the fieldmap is being estimated.
omp_nthreads : int
Maximum number of threads an individual process may use
debug : bool
Enable debugging outputs
Inputs
------
distorted_ref
epi_file
A reference image calculated at a previous stage
ref_brain
epi_brain
Same as above, but brain-masked
ref_mask
epi_mask
Brain mask for the run
t1_brain
t1w_brain
T1w image, brain-masked, for the fieldmap-less SyN method
std2anat_xfm
List of standard-to-T1w transforms generated during spatial
Standard-to-T1w transform generated during spatial
normalization (only for the fieldmap-less SyN method).
template : str
Name of template from which prior knowledge will be mapped
into the subject's T1w reference
(only for the fieldmap-less SyN method)
templates : str
Name of templates that index the ``std2anat_xfm`` input list
(only for the fieldmap-less SyN method).
Outputs
-------
distorted_ref
epi_file
An unwarped BOLD reference
bold_mask
epi_mask
The corresponding new mask after unwarping
bold_ref_brain
epi_brain
Brain-extracted, unwarped BOLD reference
out_warp
The deformation field to unwarp the susceptibility distortions
Expand All @@ -90,19 +85,16 @@ def init_sdc_wf(distorted_ref, omp_nthreads=1, debug=False, ignore=None):
if not isinstance(ignore, (list, tuple)):
ignore = tuple(ignore)

fmaps = defaultdict(list, [])
for associated in distorted_ref.get_associations(kind='InformedBy'):
if associated.suffix in list(FMAP_PRIORITY.keys()):
fmaps[associated.suffix].append(associated)
# TODO: To be removed (filter out unsupported fieldmaps):
fmaps = [fmap for fmap in fmaps if fmap['suffix'] in FMAP_PRIORITY]

workflow = Workflow(name='sdc_wf' if distorted_ref else 'sdc_bypass_wf')
workflow = Workflow(name='sdc_estimate_wf' if fmaps else 'sdc_bypass_wf')
inputnode = pe.Node(niu.IdentityInterface(
fields=['distorted_ref', 'ref_brain', 'ref_mask',
't1_brain', 'std2anat_xfm', 'template', 'templates']),
fields=['epi_file', 'epi_brain', 'epi_mask', 't1w_brain', 'std2anat_xfm']),
name='inputnode')

outputnode = pe.Node(niu.IdentityInterface(
fields=['output_ref', 'ref_mask', 'ref_brain',
fields=['output_ref', 'epi_mask', 'epi_brain',
'out_warp', 'syn_ref', 'method']),
name='outputnode')

Expand All @@ -115,121 +107,125 @@ def init_sdc_wf(distorted_ref, omp_nthreads=1, debug=False, ignore=None):
"""
outputnode.inputs.method = 'None'
workflow.connect([
(inputnode, outputnode, [('distorted_ref', 'output_ref'),
('ref_mask', 'ref_mask'),
('ref_brain', 'ref_brain')]),
(inputnode, outputnode, [('epi_file', 'output_ref'),
('epi_mask', 'epi_mask'),
('epi_brain', 'epi_brain')]),
])
return workflow

workflow.__postdesc__ = """\
Based on the estimated susceptibility distortion, an
unwarped BOLD reference was calculated for a more accurate
co-registration with the anatomical reference.
Based on the estimated susceptibility distortion, an unwarped
EPI (echo-planar imaging) reference was calculated for a more
accurate co-registration with the anatomical reference.
"""

# In case there are multiple fieldmaps prefer EPI
fmaps.sort(key=lambda fmap: FMAP_PRIORITY[fmap['suffix']])
fmap = fmaps[0]

# PEPOLAR path
if 'epi' in fmaps:
from .pepolar import init_pepolar_unwarp_wf, check_pes
outputnode.inputs.method = 'PEB/PEPOLAR (phase-encoding based / PE-POLARity)'

# Filter out EPI fieldmaps to be used
fmaps_epi = [(epi.path, epi.get_metadata()['PhaseEncodingDirection'])
for epi in fmaps['epi']]

# Find matched PE directions
matched_pe = check_pes(fmaps_epi, epi_meta['PhaseEncodingDirection'])

# Get EPI polarities and their metadata
sdc_unwarp_wf = init_pepolar_unwarp_wf(
bold_meta=distorted_ref.get_metadata(),
epi_fmaps=[(fmap, fmap.get_metadata()["PhaseEncodingDirection"])
for fmap in fmaps['epi']],
omp_nthreads=omp_nthreads,
name='pepolar_unwarp_wf')
matched_pe=matched_pe,
omp_nthreads=omp_nthreads)
sdc_unwarp_wf.inputs.inputnode.epi_pe_dir = epi_meta['PhaseEncodingDirection']
sdc_unwarp_wf.inputs.inputnode.fmaps_epi = fmaps_epi

workflow.connect([
(inputnode, sdc_unwarp_wf, [
('distorted_ref', 'inputnode.in_reference'),
('bold_mask', 'inputnode.in_mask'),
('bold_ref_brain', 'inputnode.in_reference_brain')]),
('epi_file', 'inputnode.in_reference'),
('epi_brain', 'inputnode.in_reference_brain'),
('epi_mask', 'inputnode.in_mask')]),
])

# FIELDMAP path
# elif 'fieldmap' in fmaps:
# # Import specific workflows here, so we don't break everything with one
# # unused workflow.
# suffices = {f.suffix for f in fmaps['fieldmap']}
# if 'fieldmap' in suffices:
# from .fmap import init_fmap_wf
# outputnode.inputs.method = 'FMB (fieldmap-based)'
# fmap_estimator_wf = init_fmap_wf(
# omp_nthreads=omp_nthreads,
# fmap_bspline=False)
# # set inputs
# fmap_estimator_wf.inputs.inputnode.fieldmap = fmap['fieldmap']
# fmap_estimator_wf.inputs.inputnode.magnitude = fmap['magnitude']

# if fmap['suffix'] == 'phasediff':
# from .phdiff import init_phdiff_wf
# fmap_estimator_wf = init_phdiff_wf(omp_nthreads=omp_nthreads)
# # set inputs
# fmap_estimator_wf.inputs.inputnode.phasediff = fmap['phasediff']
# fmap_estimator_wf.inputs.inputnode.magnitude = [
# fmap_ for key, fmap_ in sorted(fmap.items())
# if key.startswith("magnitude")
# ]

# sdc_unwarp_wf = init_sdc_unwarp_wf(
# omp_nthreads=omp_nthreads,
# fmap_demean=fmap_demean,
# debug=debug,
# name='sdc_unwarp_wf')
# sdc_unwarp_wf.inputs.inputnode.metadata = bold_meta

# workflow.connect([
# (inputnode, sdc_unwarp_wf, [
# ('distorted_ref', 'inputnode.in_reference'),
# ('bold_ref_brain', 'inputnode.in_reference_brain'),
# ('bold_mask', 'inputnode.in_mask')]),
# (fmap_estimator_wf, sdc_unwarp_wf, [
# ('outputnode.fmap', 'inputnode.fmap'),
# ('outputnode.fmap_ref', 'inputnode.fmap_ref'),
# ('outputnode.fmap_mask', 'inputnode.fmap_mask')]),
# ])

# # FIELDMAP-less path
# if any(fm['suffix'] == 'syn' for fm in fmaps):
# # Select template
# sdc_select_std = pe.Node(KeySelect(
# fields=['std2anat_xfm']),
# name='sdc_select_std', run_without_submitting=True)

# syn_sdc_wf = init_syn_sdc_wf(
# bold_pe=bold_meta.get('PhaseEncodingDirection', None),
# omp_nthreads=omp_nthreads)

# workflow.connect([
# (inputnode, sdc_select_std, [
# ('template', 'key'),
# ('templates', 'keys'),
# ('std2anat_xfm', 'std2anat_xfm')]),
# (sdc_select_std, syn_sdc_wf, [
# ('std2anat_xfm', 'inputnode.std2anat_xfm')]),
# (inputnode, syn_sdc_wf, [
# ('t1_brain', 'inputnode.t1_brain'),
# ('distorted_ref', 'inputnode.distorted_ref'),
# ('bold_ref_brain', 'inputnode.bold_ref_brain'),
# ('template', 'inputnode.template')]),
# ])

# # XXX Eliminate branch when forcing isn't an option
# if fmap['suffix'] == 'syn': # No fieldmaps, but --use-syn
# outputnode.inputs.method = 'FLB ("fieldmap-less", SyN-based)'
# sdc_unwarp_wf = syn_sdc_wf
# else: # --force-syn was called when other fieldmap was present
# sdc_unwarp_wf.__desc__ = None
# workflow.connect([
# (syn_sdc_wf, outputnode, [
# ('outputnode.out_reference', 'syn_bold_ref')]),
# ])
elif 'fieldmap' in fmaps:
from .unwarp import init_sdc_unwarp_wf
# Import specific workflows here, so we don't break everything with one
# unused workflow.
suffices = {f.suffix for f in fmaps['fieldmap']}
if 'fieldmap' in suffices:
from .fmap import init_fmap_wf
outputnode.inputs.method = 'FMB (fieldmap-based)'
fmap_wf = init_fmap_wf(
omp_nthreads=omp_nthreads,
fmap_bspline=False)
# set inputs
fmap_wf.inputs.inputnode.magnitude = fmap['magnitude']
fmap_wf.inputs.inputnode.fieldmap = fmap['fieldmap']
elif 'phasediff' in suffices:
from .phdiff import init_phdiff_wf
fmap_wf = init_phdiff_wf(omp_nthreads=omp_nthreads)
# set inputs
fmap_wf.inputs.inputnode.phasediff = fmap['phasediff']
fmap_wf.inputs.inputnode.magnitude = [
fmap_ for key, fmap_ in sorted(fmap.items())
if key.startswith("magnitude")
]
else:
raise ValueError('Fieldmaps of types %s are not supported' %
', '.join(['"%s"' % f for f in suffices]))

sdc_unwarp_wf = init_sdc_unwarp_wf(
omp_nthreads=omp_nthreads,
debug=debug,
name='sdc_unwarp_wf')
sdc_unwarp_wf.inputs.inputnode.metadata = epi_meta

workflow.connect([
(inputnode, sdc_unwarp_wf, [
('epi_file', 'inputnode.in_reference'),
('epi_brain', 'inputnode.in_reference_brain'),
('epi_mask', 'inputnode.in_mask')]),
(fmap_wf, sdc_unwarp_wf, [
('outputnode.fmap', 'inputnode.fmap'),
('outputnode.fmap_ref', 'inputnode.fmap_ref'),
('outputnode.fmap_mask', 'inputnode.fmap_mask')]),
])

# FIELDMAP-less path
if any(fm['suffix'] == 'syn' for fm in fmaps):
from .syn import init_syn_sdc_wf
syn_sdc_wf = init_syn_sdc_wf(
bold_pe=epi_meta.get('PhaseEncodingDirection', None),
omp_nthreads=omp_nthreads)

workflow.connect([
(inputnode, syn_sdc_wf, [
('t1w_brain', 'inputnode.t1w_brain'),
('epi_file', 'inputnode.epi_file'),
('epi_brain', 'inputnode.epi_brain'),
('std2anat_xfm', 'inputnode.std2anat_xfm')]),
])

# XXX Eliminate branch when forcing isn't an option
if fmap['suffix'] == 'syn': # No fieldmaps, but --use-syn
outputnode.inputs.method = 'FLB ("fieldmap-less", SyN-based)'
sdc_unwarp_wf = syn_sdc_wf
else: # --force-syn was called when other fieldmap was present
sdc_unwarp_wf.__desc__ = None
workflow.connect([
(syn_sdc_wf, outputnode, [
('outputnode.out_reference', 'syn_bold_ref')]),
])

workflow.connect([
(sdc_unwarp_wf, outputnode, [
('outputnode.out_warp', 'out_warp'),
('outputnode.out_reference', 'distorted_ref'),
('outputnode.out_reference_brain', 'bold_ref_brain'),
('outputnode.out_mask', 'bold_mask')]),
('outputnode.out_reference', 'epi_file'),
('outputnode.out_reference_brain', 'epi_brain'),
('outputnode.out_mask', 'epi_mask')]),
])

return workflow
12 changes: 6 additions & 6 deletions sdcflows/workflows/pepolar.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ def init_pepolar_unwarp_wf(omp_nthreads=1, matched_pe=False,
fmaps_epi : list of tuple(pathlike, str)
The list of EPI images that will be used in PE-Polar correction, and
their corresponding ``PhaseEncodingDirection`` metadata.
The workflow will use the ``bold_pe_dir`` input to separate out those
The workflow will use the ``epi_pe_dir`` input to separate out those
EPI acquisitions with opposed PE blips and those with matched PE blips
(the latter could be none, and ``in_reference_brain`` would then be
used). The workflow raises a ``ValueError`` when no images with
opposed PE blips are found.
bold_pe_dir : str
epi_pe_dir : str
The baseline PE direction.
in_reference : pathlike
The baseline reference image (must correspond to ``bold_pe_dir``).
The baseline reference image (must correspond to ``epi_pe_dir``).
in_reference_brain : pathlike
The reference image above, but skullstripped.
in_mask : pathlike
Expand Down Expand Up @@ -110,7 +110,7 @@ def init_pepolar_unwarp_wf(omp_nthreads=1, matched_pe=False,

inputnode = pe.Node(niu.IdentityInterface(
fields=['fmaps_epi', 'in_reference', 'in_reference_brain',
'in_mask', 'bold_pe_dir']), name='inputnode')
'in_mask', 'epi_pe_dir']), name='inputnode')

outputnode = pe.Node(niu.IdentityInterface(
fields=['out_reference', 'out_reference_brain', 'out_warp', 'out_mask']),
Expand Down Expand Up @@ -140,11 +140,11 @@ def init_pepolar_unwarp_wf(omp_nthreads=1, matched_pe=False,
omp_nthreads=omp_nthreads)

workflow.connect([
(inputnode, qwarp, [(('bold_pe_dir', _qwarp_args), 'args')]),
(inputnode, qwarp, [(('epi_pe_dir', _qwarp_args), 'args')]),
(inputnode, cphdr_warp, [('in_reference', 'hdr_file')]),
(inputnode, prepare_epi_wf, [
('fmaps_epi', 'inputnode.maps_pe'),
('bold_pe_dir', 'inputnode.epi_pe'),
('epi_pe_dir', 'inputnode.epi_pe'),
('in_reference_brain', 'inputnode.ref_brain')]),
(prepare_epi_wf, qwarp, [('outputnode.opposed_pe', 'base_file'),
('outputnode.matched_pe', 'in_file')]),
Expand Down
Loading

0 comments on commit c2f48fd

Please sign in to comment.