From fc981bd93b7a485480d0d074fe29fbfb23443609 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 14 Apr 2023 10:59:52 -0400 Subject: [PATCH] ENH: Allow gradient compensated data in maxwell_filter (#10554) --- doc/changes/latest.inc | 1 + mne/epochs.py | 11 ++++-- mne/io/base.py | 37 ++++++++---------- mne/preprocessing/maxwell.py | 51 ++++++++++++++++++------- mne/preprocessing/tests/test_maxwell.py | 48 ++++++++++++++++------- 5 files changed, 97 insertions(+), 51 deletions(-) diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 67b9e61f0ff..b1b3f840a7b 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -41,6 +41,7 @@ Enhancements - Allow an image with intracranial electrode contacts (e.g. computed tomography) to be used without the freesurfer recon-all surfaces to locate contacts so that it doesn't have to be downsampled to freesurfer dimensions (for microelectrodes) and show an example :ref:`ex-ieeg-micro` with :func:`mne.transforms.apply_volume_registration_points` added to aid this transform (:gh:`11567` by `Alex Rockhill`_) - Use new :meth:`dipy.workflows.align.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_) - Improve performance of raw data browsing with many annotations (:gh:`11614` by `Eric Larson`_) +- Add support for :func:`mne.preprocessing.maxwell_filter` with gradient-compensated CTF data, e.g., for tSSS-only mode (:gh:`10554` by `Eric Larson`_) - Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_) Bugs diff --git a/mne/epochs.py b/mne/epochs.py index ae0f6736564..099d4651009 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3671,7 +3671,8 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, _check_usable, _col_norm_pinv, _get_n_moments, _get_mf_picks_fix_mags, _prep_mf_coils, _check_destination, - _remove_meg_projs, _get_coil_scale) + _remove_meg_projs_comps, + _get_coil_scale, _get_sensor_operator) if head_pos is None: raise TypeError('head_pos must be provided and cannot be None') from .chpi import head_pos_to_trans_rot_t @@ -3684,7 +3685,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, head_pos = head_pos_to_trans_rot_t(head_pos) trn, rot, t = head_pos del head_pos - _check_usable(epochs) + _check_usable(epochs, ignore_ref) origin = _check_origin(origin, epochs.info, 'head') recon_trans = _check_destination(destination, epochs.info, True) @@ -3697,6 +3698,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, _get_mf_picks_fix_mags(info_to, int_order, ext_order, ignore_ref) coil_scale, mag_scale = _get_coil_scale( meg_picks, mag_picks, grad_picks, mag_scale, info_to) + mult = _get_sensor_operator(epochs, meg_picks) n_channels, n_times = len(epochs.ch_names), len(epochs.times) other_picks = np.setdiff1d(np.arange(n_channels), meg_picks) data = np.zeros((n_channels, n_times)) @@ -3761,6 +3763,9 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, # (We would need to include external here for regularization to work) exp['ext_order'] = 0 S_recon = _trans_sss_basis(exp, all_coils_recon, recon_trans) + if mult is not None: + S_decomp = mult @ S_decomp + S_recon = mult @ S_recon exp['ext_order'] = ext_order # We could determine regularization on basis of destination basis # matrix, restricted to good channels, as regularizing individual @@ -3779,7 +3784,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, evoked = epochs._evoked_from_epoch_data(data, info_to, picks, n_events=count, kind='average', comment=epochs._name) - _remove_meg_projs(evoked) # remove MEG projectors, they won't apply now + _remove_meg_projs_comps(evoked, ignore_ref) logger.info('Created Evoked dataset from %s epochs' % (count,)) return (evoked, mapping) if return_mapping else evoked diff --git a/mne/io/base.py b/mne/io/base.py index 05645de8cf3..96e4e0f2549 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -281,8 +281,8 @@ def _dtype(self): return self._dtype_ @verbose - def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, - projector=None, verbose=None): + def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, *, + verbose=None): """Read a chunk of raw data. Parameters @@ -344,26 +344,22 @@ def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, # set up cals and mult (cals, compensation, and projector) n_out = len(np.arange(len(self.ch_names))[idx]) - cals = self._cals.ravel()[np.newaxis, :] - if projector is not None: - assert projector.shape[0] == projector.shape[1] == cals.shape[1] - if self._comp is not None: + cals = self._cals.ravel() + projector, comp = self._projector, self._comp + if comp is not None: + mult = comp if projector is not None: - mult = self._comp * cals - mult = np.dot(projector[idx], mult) - else: - mult = self._comp[idx] * cals - elif projector is not None: - mult = projector[idx] * cals + mult = projector @ mult else: - mult = None - del projector + mult = projector + del projector, comp if mult is None: - cals = cals.T[idx] + cals = cals[idx, np.newaxis] assert cals.shape == (n_out, 1) need_idx = idx # sufficient just to read the given channels else: + mult = mult[idx] * cals cals = None # shouldn't be used assert mult.shape == (n_out, len(self.ch_names)) # read all necessary for proj @@ -504,8 +500,7 @@ def _preload_data(self, preload): data_buffer = None logger.info('Reading %d ... %d = %9.3f ... %9.3f secs...' % (0, len(self.times) - 1, 0., self.times[-1])) - self._data = self._read_segment( - data_buffer=data_buffer, projector=self._projector) + self._data = self._read_segment(data_buffer=data_buffer) assert len(self._data) == self.info['nchan'] self.preload = True self._comp = None # no longer needed @@ -752,8 +747,7 @@ def _getitem(self, item, return_times=True): if self.preload: data = self._data[sel, start:stop] else: - data = self._read_segment(start=start, stop=stop, sel=sel, - projector=self._projector) + data = self._read_segment(start=start, stop=stop, sel=sel) if return_times: # Rather than compute the entire thing just compute the subset @@ -1669,7 +1663,7 @@ def append(self, raws, preload=None): nsamp = c_ns[-1] if not self.preload: - this_data = self._read_segment(projector=self._projector) + this_data = self._read_segment() else: this_data = self._data @@ -1681,8 +1675,7 @@ def append(self, raws, preload=None): if not raws[ri].preload: # read the data directly into the buffer data_buffer = _data[:, c_ns[ri]:c_ns[ri + 1]] - raws[ri]._read_segment(data_buffer=data_buffer, - projector=self._projector) + raws[ri]._read_segment(data_buffer=data_buffer) else: _data[:, c_ns[ri]:c_ns[ri + 1]] = raws[ri]._data self._data = _data diff --git a/mne/preprocessing/maxwell.py b/mne/preprocessing/maxwell.py index d270d716e5c..b6dba1fc21a 100644 --- a/mne/preprocessing/maxwell.py +++ b/mne/preprocessing/maxwell.py @@ -23,6 +23,7 @@ quat_to_rot, rot_to_quat) from ..forward import _concatenate_coils, _prep_meg_channels, _create_meg_coils from ..surface import _normalize_vectors +from ..io.compensator import make_compensator from ..io.constants import FIFF, FWD from ..io.meas_info import _simplify_info, Info from ..io.proc_history import _read_ctc @@ -376,7 +377,7 @@ def _prep_maxwell_filter( # triage inputs ASAP to avoid late-thrown errors _validate_type(raw, BaseRaw, 'raw') - _check_usable(raw) + _check_usable(raw, ignore_ref) _check_regularize(regularize) st_correlation = float(st_correlation) if st_correlation <= 0. or st_correlation > 1.: @@ -478,7 +479,6 @@ def _prep_maxwell_filter( exp['extended_proj'] = extended_proj del extended_proj # Reconstruct data from internal space only (Eq. 38), and rescale S_recon - S_recon /= coil_scale if recon_trans is not None: # warn if we have translated too far diff = 1000 * (info['dev_head_t']['trans'][:3, 3] - @@ -520,13 +520,20 @@ def _prep_maxwell_filter( np.zeros(3)]) else: this_pos_quat = None + + # Figure out our linear operator + mult = _get_sensor_operator(raw, meg_picks) + if mult is not None: + S_recon = mult @ S_recon + S_recon /= coil_scale + _get_this_decomp_trans = partial( _get_decomp, all_coils=all_coils, cal=calibration, regularize=regularize, exp=exp, ignore_ref=ignore_ref, coil_scale=coil_scale, grad_picks=grad_picks, mag_picks=mag_picks, good_mask=good_mask, mag_or_fine=mag_or_fine, bad_condition=bad_condition, - mag_scale=mag_scale) + mag_scale=mag_scale, mult=mult) update_kwargs.update( nchan=good_mask.sum(), st_only=st_only, recon_trans=recon_trans) params = dict( @@ -536,7 +543,7 @@ def _prep_maxwell_filter( this_pos_quat=this_pos_quat, meg_picks=meg_picks, good_mask=good_mask, grad_picks=grad_picks, head_pos=head_pos, info=info, _get_this_decomp_trans=_get_this_decomp_trans, - S_recon=S_recon, update_kwargs=update_kwargs) + S_recon=S_recon, update_kwargs=update_kwargs, ignore_ref=ignore_ref) return params @@ -544,7 +551,7 @@ def _run_maxwell_filter( raw, skip_by_annotation, st_duration, st_correlation, st_only, st_when, ctc, coil_scale, this_pos_quat, meg_picks, good_mask, grad_picks, head_pos, info, _get_this_decomp_trans, S_recon, - update_kwargs, + update_kwargs, *, ignore_ref=False, reconstruct='in', copy=True): # Eventually find_bad_channels_maxwell could be sped up by moving this # outside the loop (e.g., in the prep function) but regularization depends @@ -564,7 +571,7 @@ def _run_maxwell_filter( del raw if not st_only: # remove MEG projectors, they won't apply now - _remove_meg_projs(raw_sss) + _remove_meg_projs_comps(raw_sss, ignore_ref) # Figure out which segments of data we can use onsets, ends = _annotations_starts_stops( raw_sss, skip_by_annotation, invert=True) @@ -745,7 +752,19 @@ def _get_coil_scale(meg_picks, mag_picks, grad_picks, mag_scale, info): return coil_scale, mag_scale -def _remove_meg_projs(inst): +def _get_sensor_operator(raw, meg_picks): + comp = raw.compensation_grade + if comp not in (0, None): + mult = make_compensator(raw.info, 0, comp) + logger.info(f' Accounting for compensation grade {comp}') + assert mult.shape[0] == mult.shape[1] == len(raw.ch_names) + mult = mult[np.ix_(meg_picks, meg_picks)] + else: + mult = None + return mult + + +def _remove_meg_projs_comps(inst, ignore_ref): """Remove inplace existing MEG projectors (assumes inactive).""" meg_picks = pick_types(inst.info, meg=True, exclude=[]) meg_channels = [inst.ch_names[pi] for pi in meg_picks] @@ -754,6 +773,10 @@ def _remove_meg_projs(inst): if not any(c in meg_channels for c in proj['data']['col_names']): non_meg_proj.append(proj) inst.add_proj(non_meg_proj, remove_existing=True, verbose=False) + if ignore_ref and inst.info['comps']: + assert inst.compensation_grade in (None, 0) + with inst.info._unlock(): + inst.info['comps'] = [] def _check_destination(destination, info, head_frame): @@ -959,9 +982,9 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq): return pos -def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref, +def _get_decomp(trans, *, all_coils, cal, regularize, exp, ignore_ref, coil_scale, grad_picks, mag_picks, good_mask, mag_or_fine, - bad_condition, t, mag_scale): + bad_condition, t, mag_scale, mult): """Get a decomposition matrix and pseudoinverse matrices.""" from scipy import linalg # @@ -970,6 +993,8 @@ def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref, S_decomp_full = _get_s_decomp( exp, all_coils, trans, coil_scale, cal, ignore_ref, grad_picks, mag_picks, mag_scale) + if mult is not None: + S_decomp_full = mult @ S_decomp_full S_decomp = S_decomp_full[good_mask] # # Extended SSS basis (eSSS) @@ -1143,16 +1168,16 @@ def _check_regularize(regularize): raise ValueError('regularize must be None or "in"') -def _check_usable(inst): +def _check_usable(inst, ignore_ref): """Ensure our data are clean.""" if inst.proj: raise RuntimeError('Projectors cannot be applied to data during ' 'Maxwell filtering.') current_comp = inst.compensation_grade - if current_comp not in (0, None): + if current_comp not in (0, None) and ignore_ref: raise RuntimeError('Maxwell filter cannot be done on compensated ' - 'channels, but data have been compensated with ' - 'grade %s.' % current_comp) + 'channels (data have been compensated with ' + 'grade {current_comp}) when ignore_ref=True') def _col_norm_pinv(x): diff --git a/mne/preprocessing/tests/test_maxwell.py b/mne/preprocessing/tests/test_maxwell.py index 06d4c193c70..233879e173a 100644 --- a/mne/preprocessing/tests/test_maxwell.py +++ b/mne/preprocessing/tests/test_maxwell.py @@ -303,27 +303,49 @@ def test_other_systems(): _assert_shielding(raw_sss_auto, power, 0.7) # CTF - raw_ctf = read_crop(fname_ctf_raw) - assert raw_ctf.compensation_grade == 3 - with pytest.raises(RuntimeError, match='compensated'): - maxwell_filter(raw_ctf) - raw_ctf.apply_gradient_compensation(0) + raw_ctf_3 = read_crop(fname_ctf_raw) + assert raw_ctf_3.compensation_grade == 3 + raw_ctf_0 = raw_ctf_3.copy().apply_gradient_compensation(0) + assert raw_ctf_0.compensation_grade == 0 + # 3rd-order gradient compensation works really well (better than MF here) + _assert_shielding(raw_ctf_3, raw_ctf_0, 20, 21) + origin = (0., 0., 0.04) + raw_sss_3 = maxwell_filter(raw_ctf_3, origin=origin, verbose=True) + _assert_n_free(raw_sss_3, 70) + _assert_shielding(raw_sss_3, raw_ctf_3, 0.12, 0.14) + _assert_shielding(raw_sss_3, raw_ctf_0, 2.63, 2.66) + assert raw_sss_3.compensation_grade == 3 + raw_sss_3.apply_gradient_compensation(0) + assert raw_sss_3.compensation_grade == 0 + _assert_shielding(raw_sss_3, raw_ctf_3, 0.15, 0.17) + _assert_shielding(raw_sss_3, raw_ctf_0, 3.18, 3.20) with pytest.raises(ValueError, match='digitization points'): - maxwell_filter(raw_ctf) - raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04)) - _assert_n_free(raw_sss, 68) - _assert_shielding(raw_sss, raw_ctf, 1.8) + maxwell_filter(raw_ctf_0) + raw_sss_0 = maxwell_filter(raw_ctf_0, origin=origin, verbose=True) + _assert_n_free(raw_sss_0, 68) + _assert_shielding(raw_sss_0, raw_ctf_3, 0.07, 0.09) + _assert_shielding(raw_sss_0, raw_ctf_0, 1.8, 1.9) + raw_sss_0.apply_gradient_compensation(3) + _assert_shielding(raw_sss_0, raw_ctf_3, 0.07, 0.09) + _assert_shielding(raw_sss_0, raw_ctf_0, 1.63, 1.67) + with pytest.raises(RuntimeError, match='ignore_ref'): + maxwell_filter(raw_ctf_3, ignore_ref=True) + # ignoring ref outperforms including it in maxwell filtering with catch_logging() as log: - raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04), + raw_sss = maxwell_filter(raw_ctf_0, origin=origin, ignore_ref=True, verbose=True) assert ', 12/15 out' in log.getvalue() # homogeneous fields removed _assert_n_free(raw_sss, 70) - _assert_shielding(raw_sss, raw_ctf, 12) - raw_sss_auto = maxwell_filter(raw_ctf, origin=(0., 0., 0.04), + _assert_shielding(raw_sss, raw_ctf_0, 12, 13) + # if ignore_ref=True, we remove compensators because they will not + # work the way people expect (it puts noise back in the data!) + with pytest.raises(ValueError, match='Desired compensation.*not found'): + raw_sss.copy().apply_gradient_compensation(3) + raw_sss_auto = maxwell_filter(raw_ctf_0, origin=origin, ignore_ref=True, mag_scale='auto') assert_allclose(raw_sss._data, raw_sss_auto._data) with catch_logging() as log: - maxwell_filter(raw_ctf, origin=(0., 0., 0.04), regularize=None, + maxwell_filter(raw_ctf_0, origin=origin, regularize=None, ignore_ref=True, verbose=True) assert '80/80 in, 12/15 out' in log.getvalue() # homogeneous fields