Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Allow gradient compensated data in maxwell_filter #10554

Merged
merged 5 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Enhancements
- Allow an image with intracranial electrode contacts (e.g. computed tomography) to be used without the freesurfer recon-all surfaces to locate contacts so that it doesn't have to be downsampled to freesurfer dimensions (for microelectrodes) and show an example :ref:`ex-ieeg-micro` with :func:`mne.transforms.apply_volume_registration_points` added to aid this transform (:gh:`11567` by `Alex Rockhill`_)
- Use new :meth:`dipy.workflows.align.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_)
- Improve performance of raw data browsing with many annotations (:gh:`11614` by `Eric Larson`_)
- Add support for :func:`mne.preprocessing.maxwell_filter` with gradient-compensated CTF data, e.g., for tSSS-only mode (:gh:`10554` by `Eric Larson`_)
- Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_)

Bugs
Expand Down
11 changes: 8 additions & 3 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3671,7 +3671,8 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
_check_usable, _col_norm_pinv,
_get_n_moments, _get_mf_picks_fix_mags,
_prep_mf_coils, _check_destination,
_remove_meg_projs, _get_coil_scale)
_remove_meg_projs_comps,
_get_coil_scale, _get_sensor_operator)
if head_pos is None:
raise TypeError('head_pos must be provided and cannot be None')
from .chpi import head_pos_to_trans_rot_t
Expand All @@ -3684,7 +3685,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
head_pos = head_pos_to_trans_rot_t(head_pos)
trn, rot, t = head_pos
del head_pos
_check_usable(epochs)
_check_usable(epochs, ignore_ref)
origin = _check_origin(origin, epochs.info, 'head')
recon_trans = _check_destination(destination, epochs.info, True)

Expand All @@ -3697,6 +3698,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
_get_mf_picks_fix_mags(info_to, int_order, ext_order, ignore_ref)
coil_scale, mag_scale = _get_coil_scale(
meg_picks, mag_picks, grad_picks, mag_scale, info_to)
mult = _get_sensor_operator(epochs, meg_picks)
n_channels, n_times = len(epochs.ch_names), len(epochs.times)
other_picks = np.setdiff1d(np.arange(n_channels), meg_picks)
data = np.zeros((n_channels, n_times))
Expand Down Expand Up @@ -3761,6 +3763,9 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
# (We would need to include external here for regularization to work)
exp['ext_order'] = 0
S_recon = _trans_sss_basis(exp, all_coils_recon, recon_trans)
if mult is not None:
S_decomp = mult @ S_decomp
S_recon = mult @ S_recon
exp['ext_order'] = ext_order
# We could determine regularization on basis of destination basis
# matrix, restricted to good channels, as regularizing individual
Expand All @@ -3779,7 +3784,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
evoked = epochs._evoked_from_epoch_data(data, info_to, picks,
n_events=count, kind='average',
comment=epochs._name)
_remove_meg_projs(evoked) # remove MEG projectors, they won't apply now
_remove_meg_projs_comps(evoked, ignore_ref)
logger.info('Created Evoked dataset from %s epochs' % (count,))
return (evoked, mapping) if return_mapping else evoked

Expand Down
37 changes: 15 additions & 22 deletions mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ def _dtype(self):
return self._dtype_

@verbose
def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None,
projector=None, verbose=None):
def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, *,
verbose=None):
"""Read a chunk of raw data.

Parameters
Expand Down Expand Up @@ -344,26 +344,22 @@ def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None,

# set up cals and mult (cals, compensation, and projector)
n_out = len(np.arange(len(self.ch_names))[idx])
cals = self._cals.ravel()[np.newaxis, :]
if projector is not None:
assert projector.shape[0] == projector.shape[1] == cals.shape[1]
if self._comp is not None:
cals = self._cals.ravel()
projector, comp = self._projector, self._comp
if comp is not None:
mult = comp
if projector is not None:
mult = self._comp * cals
mult = np.dot(projector[idx], mult)
else:
mult = self._comp[idx] * cals
elif projector is not None:
mult = projector[idx] * cals
mult = projector @ mult
else:
mult = None
del projector
mult = projector
del projector, comp

if mult is None:
cals = cals.T[idx]
cals = cals[idx, np.newaxis]
assert cals.shape == (n_out, 1)
need_idx = idx # sufficient just to read the given channels
else:
mult = mult[idx] * cals
cals = None # shouldn't be used
assert mult.shape == (n_out, len(self.ch_names))
# read all necessary for proj
Expand Down Expand Up @@ -504,8 +500,7 @@ def _preload_data(self, preload):
data_buffer = None
logger.info('Reading %d ... %d = %9.3f ... %9.3f secs...' %
(0, len(self.times) - 1, 0., self.times[-1]))
self._data = self._read_segment(
data_buffer=data_buffer, projector=self._projector)
self._data = self._read_segment(data_buffer=data_buffer)
assert len(self._data) == self.info['nchan']
self.preload = True
self._comp = None # no longer needed
Expand Down Expand Up @@ -752,8 +747,7 @@ def _getitem(self, item, return_times=True):
if self.preload:
data = self._data[sel, start:stop]
else:
data = self._read_segment(start=start, stop=stop, sel=sel,
projector=self._projector)
data = self._read_segment(start=start, stop=stop, sel=sel)

if return_times:
# Rather than compute the entire thing just compute the subset
Expand Down Expand Up @@ -1669,7 +1663,7 @@ def append(self, raws, preload=None):
nsamp = c_ns[-1]

if not self.preload:
this_data = self._read_segment(projector=self._projector)
this_data = self._read_segment()
else:
this_data = self._data

Expand All @@ -1681,8 +1675,7 @@ def append(self, raws, preload=None):
if not raws[ri].preload:
# read the data directly into the buffer
data_buffer = _data[:, c_ns[ri]:c_ns[ri + 1]]
raws[ri]._read_segment(data_buffer=data_buffer,
projector=self._projector)
raws[ri]._read_segment(data_buffer=data_buffer)
else:
_data[:, c_ns[ri]:c_ns[ri + 1]] = raws[ri]._data
self._data = _data
Expand Down
51 changes: 38 additions & 13 deletions mne/preprocessing/maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
quat_to_rot, rot_to_quat)
from ..forward import _concatenate_coils, _prep_meg_channels, _create_meg_coils
from ..surface import _normalize_vectors
from ..io.compensator import make_compensator
from ..io.constants import FIFF, FWD
from ..io.meas_info import _simplify_info, Info
from ..io.proc_history import _read_ctc
Expand Down Expand Up @@ -376,7 +377,7 @@ def _prep_maxwell_filter(

# triage inputs ASAP to avoid late-thrown errors
_validate_type(raw, BaseRaw, 'raw')
_check_usable(raw)
_check_usable(raw, ignore_ref)
_check_regularize(regularize)
st_correlation = float(st_correlation)
if st_correlation <= 0. or st_correlation > 1.:
Expand Down Expand Up @@ -478,7 +479,6 @@ def _prep_maxwell_filter(
exp['extended_proj'] = extended_proj
del extended_proj
# Reconstruct data from internal space only (Eq. 38), and rescale S_recon
S_recon /= coil_scale
if recon_trans is not None:
# warn if we have translated too far
diff = 1000 * (info['dev_head_t']['trans'][:3, 3] -
Expand Down Expand Up @@ -520,13 +520,20 @@ def _prep_maxwell_filter(
np.zeros(3)])
else:
this_pos_quat = None

# Figure out our linear operator
mult = _get_sensor_operator(raw, meg_picks)
if mult is not None:
S_recon = mult @ S_recon
S_recon /= coil_scale

_get_this_decomp_trans = partial(
_get_decomp, all_coils=all_coils,
cal=calibration, regularize=regularize,
exp=exp, ignore_ref=ignore_ref, coil_scale=coil_scale,
grad_picks=grad_picks, mag_picks=mag_picks, good_mask=good_mask,
mag_or_fine=mag_or_fine, bad_condition=bad_condition,
mag_scale=mag_scale)
mag_scale=mag_scale, mult=mult)
update_kwargs.update(
nchan=good_mask.sum(), st_only=st_only, recon_trans=recon_trans)
params = dict(
Expand All @@ -536,15 +543,15 @@ def _prep_maxwell_filter(
this_pos_quat=this_pos_quat, meg_picks=meg_picks,
good_mask=good_mask, grad_picks=grad_picks, head_pos=head_pos,
info=info, _get_this_decomp_trans=_get_this_decomp_trans,
S_recon=S_recon, update_kwargs=update_kwargs)
S_recon=S_recon, update_kwargs=update_kwargs, ignore_ref=ignore_ref)
return params


def _run_maxwell_filter(
raw, skip_by_annotation, st_duration, st_correlation, st_only,
st_when, ctc, coil_scale, this_pos_quat, meg_picks, good_mask,
grad_picks, head_pos, info, _get_this_decomp_trans, S_recon,
update_kwargs,
update_kwargs, *, ignore_ref=False,
reconstruct='in', copy=True):
# Eventually find_bad_channels_maxwell could be sped up by moving this
# outside the loop (e.g., in the prep function) but regularization depends
Expand All @@ -564,7 +571,7 @@ def _run_maxwell_filter(
del raw
if not st_only:
# remove MEG projectors, they won't apply now
_remove_meg_projs(raw_sss)
_remove_meg_projs_comps(raw_sss, ignore_ref)
# Figure out which segments of data we can use
onsets, ends = _annotations_starts_stops(
raw_sss, skip_by_annotation, invert=True)
Expand Down Expand Up @@ -745,7 +752,19 @@ def _get_coil_scale(meg_picks, mag_picks, grad_picks, mag_scale, info):
return coil_scale, mag_scale


def _remove_meg_projs(inst):
def _get_sensor_operator(raw, meg_picks):
comp = raw.compensation_grade
if comp not in (0, None):
mult = make_compensator(raw.info, 0, comp)
logger.info(f' Accounting for compensation grade {comp}')
assert mult.shape[0] == mult.shape[1] == len(raw.ch_names)
mult = mult[np.ix_(meg_picks, meg_picks)]
else:
mult = None
return mult


def _remove_meg_projs_comps(inst, ignore_ref):
"""Remove inplace existing MEG projectors (assumes inactive)."""
meg_picks = pick_types(inst.info, meg=True, exclude=[])
meg_channels = [inst.ch_names[pi] for pi in meg_picks]
Expand All @@ -754,6 +773,10 @@ def _remove_meg_projs(inst):
if not any(c in meg_channels for c in proj['data']['col_names']):
non_meg_proj.append(proj)
inst.add_proj(non_meg_proj, remove_existing=True, verbose=False)
if ignore_ref and inst.info['comps']:
assert inst.compensation_grade in (None, 0)
with inst.info._unlock():
inst.info['comps'] = []


def _check_destination(destination, info, head_frame):
Expand Down Expand Up @@ -959,9 +982,9 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq):
return pos


def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref,
def _get_decomp(trans, *, all_coils, cal, regularize, exp, ignore_ref,
coil_scale, grad_picks, mag_picks, good_mask, mag_or_fine,
bad_condition, t, mag_scale):
bad_condition, t, mag_scale, mult):
"""Get a decomposition matrix and pseudoinverse matrices."""
from scipy import linalg
#
Expand All @@ -970,6 +993,8 @@ def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref,
S_decomp_full = _get_s_decomp(
exp, all_coils, trans, coil_scale, cal, ignore_ref, grad_picks,
mag_picks, mag_scale)
if mult is not None:
S_decomp_full = mult @ S_decomp_full
S_decomp = S_decomp_full[good_mask]
#
# Extended SSS basis (eSSS)
Expand Down Expand Up @@ -1143,16 +1168,16 @@ def _check_regularize(regularize):
raise ValueError('regularize must be None or "in"')


def _check_usable(inst):
def _check_usable(inst, ignore_ref):
"""Ensure our data are clean."""
if inst.proj:
raise RuntimeError('Projectors cannot be applied to data during '
'Maxwell filtering.')
current_comp = inst.compensation_grade
if current_comp not in (0, None):
if current_comp not in (0, None) and ignore_ref:
raise RuntimeError('Maxwell filter cannot be done on compensated '
'channels, but data have been compensated with '
'grade %s.' % current_comp)
'channels (data have been compensated with '
'grade {current_comp}) when ignore_ref=True')


def _col_norm_pinv(x):
Expand Down
48 changes: 35 additions & 13 deletions mne/preprocessing/tests/test_maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,27 +303,49 @@ def test_other_systems():
_assert_shielding(raw_sss_auto, power, 0.7)

# CTF
raw_ctf = read_crop(fname_ctf_raw)
assert raw_ctf.compensation_grade == 3
with pytest.raises(RuntimeError, match='compensated'):
maxwell_filter(raw_ctf)
raw_ctf.apply_gradient_compensation(0)
raw_ctf_3 = read_crop(fname_ctf_raw)
assert raw_ctf_3.compensation_grade == 3
raw_ctf_0 = raw_ctf_3.copy().apply_gradient_compensation(0)
assert raw_ctf_0.compensation_grade == 0
# 3rd-order gradient compensation works really well (better than MF here)
_assert_shielding(raw_ctf_3, raw_ctf_0, 20, 21)
origin = (0., 0., 0.04)
raw_sss_3 = maxwell_filter(raw_ctf_3, origin=origin, verbose=True)
_assert_n_free(raw_sss_3, 70)
_assert_shielding(raw_sss_3, raw_ctf_3, 0.12, 0.14)
_assert_shielding(raw_sss_3, raw_ctf_0, 2.63, 2.66)
assert raw_sss_3.compensation_grade == 3
raw_sss_3.apply_gradient_compensation(0)
assert raw_sss_3.compensation_grade == 0
_assert_shielding(raw_sss_3, raw_ctf_3, 0.15, 0.17)
_assert_shielding(raw_sss_3, raw_ctf_0, 3.18, 3.20)
with pytest.raises(ValueError, match='digitization points'):
maxwell_filter(raw_ctf)
raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04))
_assert_n_free(raw_sss, 68)
_assert_shielding(raw_sss, raw_ctf, 1.8)
maxwell_filter(raw_ctf_0)
raw_sss_0 = maxwell_filter(raw_ctf_0, origin=origin, verbose=True)
_assert_n_free(raw_sss_0, 68)
_assert_shielding(raw_sss_0, raw_ctf_3, 0.07, 0.09)
_assert_shielding(raw_sss_0, raw_ctf_0, 1.8, 1.9)
raw_sss_0.apply_gradient_compensation(3)
_assert_shielding(raw_sss_0, raw_ctf_3, 0.07, 0.09)
_assert_shielding(raw_sss_0, raw_ctf_0, 1.63, 1.67)
with pytest.raises(RuntimeError, match='ignore_ref'):
maxwell_filter(raw_ctf_3, ignore_ref=True)
# ignoring ref outperforms including it in maxwell filtering
with catch_logging() as log:
raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04),
raw_sss = maxwell_filter(raw_ctf_0, origin=origin,
ignore_ref=True, verbose=True)
assert ', 12/15 out' in log.getvalue() # homogeneous fields removed
_assert_n_free(raw_sss, 70)
_assert_shielding(raw_sss, raw_ctf, 12)
raw_sss_auto = maxwell_filter(raw_ctf, origin=(0., 0., 0.04),
_assert_shielding(raw_sss, raw_ctf_0, 12, 13)
# if ignore_ref=True, we remove compensators because they will not
# work the way people expect (it puts noise back in the data!)
with pytest.raises(ValueError, match='Desired compensation.*not found'):
raw_sss.copy().apply_gradient_compensation(3)
raw_sss_auto = maxwell_filter(raw_ctf_0, origin=origin,
ignore_ref=True, mag_scale='auto')
assert_allclose(raw_sss._data, raw_sss_auto._data)
with catch_logging() as log:
maxwell_filter(raw_ctf, origin=(0., 0., 0.04), regularize=None,
maxwell_filter(raw_ctf_0, origin=origin, regularize=None,
ignore_ref=True, verbose=True)
assert '80/80 in, 12/15 out' in log.getvalue() # homogeneous fields

Expand Down