MRG, FIX: Fix epochs split size (#7740)
* FIX: Fix epochs split size

* FIX: Fix for sklearn dep

* FIX: Test
larsoner authored May 6, 2020
1 parent ef16a2d commit d0dc8f3
Showing 6 changed files with 87 additions and 20 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
@@ -88,6 +88,8 @@ Bug

- Fix bug with :func:`mne.read_epochs` when loading data in complex format with ``preload=False`` by `Eric Larson`_

+ - Fix bug with :meth:`mne.Epochs.save` where the file splitting calculations did not account for the sizes of non-data writes by `Eric Larson`_

- Fix bug with :class:`mne.Report` where the BEM section could not be toggled by `Eric Larson`_

- Fix bug where using :meth:`mne.Epochs.crop` to exclude the baseline period would break the :meth:`mne.Epochs.save` / :func:`mne.read_epochs` round-trip by `Eric Larson`_
6 changes: 4 additions & 2 deletions mne/decoding/transformer.py
@@ -122,10 +122,12 @@ def __init__(self, info=None, scalings=None, with_mean=True,
self._scaler = _ConstantScaler(info, scalings, self.with_std)
elif scalings == 'mean':
from sklearn.preprocessing import StandardScaler
- self._scaler = StandardScaler(self.with_mean, self.with_std)
+ self._scaler = StandardScaler(
+     with_mean=self.with_mean, with_std=self.with_std)
else: # scalings == 'median':
from sklearn.preprocessing import RobustScaler
- self._scaler = RobustScaler(self.with_mean, self.with_std)
+ self._scaler = RobustScaler(
+     with_centering=self.with_mean, with_scaling=self.with_std)

def fit(self, epochs_data, y=None):
"""Standardize data across channels.
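A note on why the keyword arguments matter (illustrative, not part of the commit): in scikit-learn releases of that era, the signature was StandardScaler(copy=True, with_mean=True, with_std=True), so the old positional call silently bound self.with_mean to copy and self.with_std to with_mean; later releases make these parameters keyword-only and reject positional use entirely. A minimal sketch of the pitfall under that assumed signature:

from sklearn.preprocessing import StandardScaler

with_mean, with_std = True, False
# Positional call as in the removed line: the values land in `copy` and
# `with_mean` instead (or raise TypeError on keyword-only releases).
bad = StandardScaler(with_mean, with_std)
good = StandardScaler(with_mean=with_mean, with_std=with_std)
print(bad.with_mean, bad.with_std)    # False True  -- silently wrong
print(good.with_mean, good.with_std)  # True False  -- as intended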
48 changes: 40 additions & 8 deletions mne/epochs.py
@@ -26,7 +26,7 @@
write_int, write_float, write_float_matrix,
write_double_matrix, write_complex_float_matrix,
write_complex_double_matrix, write_id, write_string,
- _get_split_size)
+ _get_split_size, _NEXT_FILE_BUFFER)
from .io.meas_info import read_meas_info, write_meas_info, _merge_info
from .io.open import fiff_open, _get_next_fname
from .io.tree import dir_tree_find
@@ -57,7 +57,7 @@
_check_event_id, _gen_events, _check_option,
_check_combine, ShiftTimeMixin, _build_data_frame,
_check_pandas_index_arguments, _convert_times,
- _scale_dataframe_data, _check_time_format)
+ _scale_dataframe_data, _check_time_format, object_size)
from .utils.docs import fill_doc


@@ -71,7 +71,11 @@ def _pack_reject_params(epochs):


def _save_split(epochs, fname, part_idx, n_parts, fmt):
"""Split epochs."""
"""Split epochs.
Anything new added to this function also needs to be added to
BaseEpochs.save to account for new file sizes.
"""
# insert index in filename
path, base = op.split(fname)
idx = base.find('.')
@@ -120,10 +124,7 @@ def _save_split(epochs, fname, part_idx, n_parts, fmt):

start_block(fid, FIFF.FIFFB_MNE_EVENTS)
write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, epochs.events.T)
- mapping_ = ';'.join([k + ':' + str(v) for k, v in
-     epochs.event_id.items()])
-
- write_string(fid, FIFF.FIFF_DESCRIPTION, mapping_)
+ write_string(fid, FIFF.FIFF_DESCRIPTION, _event_id_string(epochs.event_id))
end_block(fid, FIFF.FIFFB_MNE_EVENTS)

# Metadata
@@ -187,6 +188,10 @@ def _save_split(epochs, fname, part_idx, n_parts, fmt):
end_file(fid)


+ def _event_id_string(event_id):
+     return ';'.join([k + ':' + str(v) for k, v in event_id.items()])


def _merge_events(events, event_id, selection):
"""Merge repeated events."""
event_id = event_id.copy()
@@ -1595,7 +1600,34 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False,
self._check_consistency()
if fmt == "single":
total_size //= 2 # 64bit data converted to 32bit before writing.
- n_parts = max(int(np.ceil(total_size / float(split_size))), 1)
+ total_size += 32  # FIF tags
+ # Account for all the other things we write, too
+ # 1. meas_id block plus main epochs block
+ total_size += 132
+ # 2. measurement info (likely slight overestimate, but okay)
+ total_size += object_size(self.info)
+ # 3. events and event_id in its own block
+ total_size += (self.events.size * 4 +
+                len(_event_id_string(self.event_id)) + 72)
+ # 4. Metadata in a block of its own
+ if self.metadata is not None:
+     total_size += len(_prepare_write_metadata(self.metadata)) + 56
+ # 5. first sample, last sample, baseline
+ total_size += 40 + 40 * (self.baseline is not None)
+ # 6. drop log
+ total_size += len(json.dumps(self.drop_log)) + 16
+ # 7. reject params
+ reject_params = _pack_reject_params(self)
+ if reject_params:
+     total_size += len(json.dumps(reject_params)) + 16
+ # 8. selection
+ total_size += self.selection.size * 4 + 16
+ # 9. end of file tags
+ total_size += _NEXT_FILE_BUFFER

+ # This is like max(int(ceil(total_size / split_size)), 1) but cleaner
+ n_parts = (total_size - 1) // split_size + 1
+ assert n_parts >= 1
epoch_idxs = np.array_split(np.arange(len(self)), n_parts)

for part_idx, epoch_idx in enumerate(epoch_idxs):
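The new n_parts line uses the standard integer ceiling-division identity; a quick illustrative check (not from the commit) that it matches the old formula for any positive total_size:

import numpy as np

split_size = 2 ** 20  # 1 MiB, illustrative
for total_size in (1, split_size - 1, split_size, split_size + 1,
                   5 * split_size):
    old = max(int(np.ceil(total_size / float(split_size))), 1)
    new = (total_size - 1) // split_size + 1
    assert old == new  # identical for every total_size >= 1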
9 changes: 4 additions & 5 deletions mne/io/base.py
@@ -27,7 +27,7 @@
from .write import (start_file, end_file, start_block, end_block,
write_dau_pack16, write_float, write_double,
write_complex64, write_complex128, write_int,
- write_id, write_string, _get_split_size)
+ write_id, write_string, _get_split_size, _NEXT_FILE_BUFFER)

from ..annotations import (_annotations_starts_stops, _write_annotations,
_handle_meas_date)
@@ -1871,7 +1871,6 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,
'measurement information, you must use a larger '
'value for split size: %s plus enough bytes for '
'the chosen buffer_size' % pos_prev)
- next_file_buffer = 2 ** 20  # extra cushion for last few post-data tags

# Check to see if this has acquisition skips and, if so, if we can
# write out empty buffers instead of zeroes
@@ -1920,7 +1919,7 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,

pos = fid.tell()
this_buff_size_bytes = pos - pos_prev
- overage = pos - split_size + next_file_buffer
+ overage = pos - split_size + _NEXT_FILE_BUFFER
if overage > 0:
# This should occur on the first buffer write of the file, so
# we should mention the space required for the meas info
@@ -1930,12 +1929,12 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,
'by %s bytes after writing info (%s) and leaving enough space '
'for end tags (%s): decrease "buffer_size_sec" or increase '
'"split_size".' % (this_buff_size_bytes, split_size, overage,
- pos_prev, next_file_buffer))
+ pos_prev, _NEXT_FILE_BUFFER))

# Split files if necessary, leave some space for next file info
# make sure we check to make sure we actually *need* another buffer
# with the "and" check
- if pos >= split_size - this_buff_size_bytes - next_file_buffer and \
+ if pos >= split_size - this_buff_size_bytes - _NEXT_FILE_BUFFER and \
first + buffer_size < stop:
next_fname, next_idx = _write_raw(
fname, raw, info, picks, fmt,
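For intuition, the split condition fires once the current position could not absorb another buffer of the same size plus the _NEXT_FILE_BUFFER cushion for the post-data tags; a toy calculation with made-up sizes (not from the commit):

_NEXT_FILE_BUFFER = 2 ** 20      # 1 MiB cushion for post-data tags
split_size = 10 * 2 ** 20        # pretend each part may be 10 MiB
this_buff_size_bytes = 2 ** 20   # pretend each data buffer is 1 MiB

pos = 8 * 2 ** 20  # write position after the buffer just written
# One more buffer plus the end-of-file cushion would overrun split_size,
# so the writer starts the next part instead:
print(pos >= split_size - this_buff_size_bytes - _NEXT_FILE_BUFFER)  # True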
3 changes: 3 additions & 0 deletions mne/io/write.py
@@ -52,6 +52,9 @@ def _get_split_size(split_size):
return split_size


+ _NEXT_FILE_BUFFER = 1048576  # 2 ** 20 extra cushion for last post-data tags


def write_nop(fid, last=False):
"""Write a FIFF_NOP."""
fid.write(np.array(FIFF.FIFF_NOP, dtype='>i4').tobytes())
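_get_split_size converts a human-readable split_size string into a byte count using binary (1024-based) multipliers, as the '1MB' assertion in the tests below confirms; a short usage sketch (the 'GB' line assumes the same convention):

from mne.io.write import _get_split_size

assert _get_split_size('1MB') == 2 ** 20 == 1048576
assert _get_split_size('2GB') == 2 * 2 ** 30  # assumed binary multiplier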
39 changes: 34 additions & 5 deletions mne/tests/test_epochs.py
@@ -37,6 +37,7 @@

from mne.io import RawArray, read_raw_fif
from mne.io.proj import _has_eeg_average_ref_proj
+ from mne.io.write import _get_split_size
from mne.event import merge_events
from mne.io.constants import FIFF
from mne.datasets import testing
@@ -989,15 +990,35 @@ def test_epochs_io_preload(tmpdir, preload):
epochs = Epochs(raw, events, dict(foo=1, bar=999), tmin, tmax,
picks=picks, on_missing='ignore')
epochs.save(temp_fname, overwrite=True)
+ split_fname_1 = temp_fname[:-4] + '-1.fif'
+ split_fname_2 = temp_fname[:-4] + '-2.fif'
+ assert op.isfile(temp_fname)
+ assert not op.isfile(split_fname_1)
epochs_read = read_epochs(temp_fname, preload=preload)
assert_allclose(epochs.get_data(), epochs_read.get_data(), **tols)
assert_array_equal(epochs.events, epochs_read.events)
assert_equal(set(epochs.event_id.keys()),
{str(x) for x in epochs_read.event_id.keys()})

# test saving split epoch files
- epochs.save(temp_fname, split_size='7MB', overwrite=True)
+ split_size = '7MB'
+ # ensure that we're in a position where just the data itself could fit
+ # if that were all that we saved ...
+ split_size_bytes = _get_split_size(split_size)
+ assert epochs.get_data().nbytes // 2 < split_size_bytes
+ epochs.save(temp_fname, split_size=split_size, overwrite=True)
+ # ... but we correctly account for the other stuff we need to write,
+ # so end up with two files ...
+ assert op.isfile(temp_fname)
+ assert op.isfile(split_fname_1)
+ assert not op.isfile(split_fname_2)
epochs_read = read_epochs(temp_fname, preload=preload)
+ # ... and none of the files exceed our limit.
+ for fname in (temp_fname, split_fname_1):
+     with open(fname, 'r') as fid:
+         fid.seek(0, 2)
+         fsize = fid.tell()
+     assert fsize <= split_size_bytes
assert_allclose(epochs.get_data(), epochs_read.get_data(), **tols)
assert_array_equal(epochs.events, epochs_read.events)
assert_array_equal(epochs.selection, epochs_read.selection)
@@ -1025,12 +1046,20 @@ def test_split_saving(tmpdir):
events = mne.make_fixed_length_events(raw, 1)
epochs = mne.Epochs(raw, events)
epochs_data = epochs.get_data()
+ assert len(epochs) == 9
fname = op.join(tempdir, 'test-epo.fif')
epochs.save(fname, split_size='1MB', overwrite=True)
- assert op.isfile(fname)
- assert op.isfile(fname[:-4] + '-1.fif')
- assert op.isfile(fname[:-4] + '-2.fif')
- assert not op.isfile(fname[:-4] + '-3.fif')
+ size = _get_split_size('1MB')
+ assert size == 1048576 == 1024 * 1024
+ written_fnames = [fname] + [
+     fname[:-4] + '-%d.fif' % ii for ii in range(1, 4)]
+ for this_fname in written_fnames:
+     assert op.isfile(this_fname)
+     with open(this_fname, 'r') as fid:
+         fid.seek(0, 2)
+         file_size = fid.tell()
+     assert size * 0.5 < file_size <= size
+ assert not op.isfile(fname[:-4] + '-4.fif')
for preload in (True, False):
epochs2 = mne.read_epochs(fname, preload=preload)
assert_allclose(epochs2.get_data(), epochs_data)
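Both tests measure file size with the stdlib seek-to-end/tell idiom, equivalent to os.path.getsize on a flushed file; a self-contained sketch:

import os
import tempfile

with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b'x' * 1234)
with open(f.name, 'rb') as fid:
    fid.seek(0, 2)  # offset 0 from os.SEEK_END, i.e. jump to end of file
    assert fid.tell() == os.path.getsize(f.name) == 1234
os.remove(f.name)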
