MRG, FIX: Fix epochs split size #7740

Merged · 3 commits · May 6, 2020
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
@@ -88,6 +88,8 @@ Bug
 
 - Fix bug with :func:`mne.read_epochs` when loading data in complex format with ``preload=False`` by `Eric Larson`_
 
+- Fix bug with :meth:`mne.Epochs.save` where the file splitting calculations did not account for the sizes of non-data writes by `Eric Larson`_
+
 - Fix bug with :class:`mne.Report` where the BEM section could not be toggled by `Eric Larson`_
 
 - Fix bug where using :meth:`mne.Epochs.crop` to exclude the baseline period would break :func:`mne.Epochs.save` / :func:`mne.read_epochs` round-trip by `Eric Larson`_
6 changes: 4 additions & 2 deletions mne/decoding/transformer.py
@@ -122,10 +122,12 @@ def __init__(self, info=None, scalings=None, with_mean=True,
             self._scaler = _ConstantScaler(info, scalings, self.with_std)
         elif scalings == 'mean':
             from sklearn.preprocessing import StandardScaler
-            self._scaler = StandardScaler(self.with_mean, self.with_std)
+            self._scaler = StandardScaler(
+                with_mean=self.with_mean, with_std=self.with_std)
         else:  # scalings == 'median':
             from sklearn.preprocessing import RobustScaler
-            self._scaler = RobustScaler(self.with_mean, self.with_std)
+            self._scaler = RobustScaler(
+                with_centering=self.with_mean, with_scaling=self.with_std)
 
     def fit(self, epochs_data, y=None):
         """Standardize data across channels.
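Why the keyword arguments matter here: scikit-learn's StandardScaler takes copy as its first parameter, so the old positional call appears to have bound with_mean to copy, and recent scikit-learn releases make these parameters keyword-only so positional calls raise a TypeError. A minimal sketch, assuming scikit-learn's documented signatures (not part of this diff):

    from sklearn.preprocessing import RobustScaler, StandardScaler

    # keyword form binds each flag to the intended parameter
    scaler = StandardScaler(with_mean=True, with_std=False)
    assert scaler.with_mean and not scaler.with_std
    robust = RobustScaler(with_centering=True, with_scaling=False)
    assert robust.with_centering and not robust.with_scaling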
48 changes: 40 additions & 8 deletions mne/epochs.py
@@ -26,7 +26,7 @@
                        write_int, write_float, write_float_matrix,
                        write_double_matrix, write_complex_float_matrix,
                        write_complex_double_matrix, write_id, write_string,
-                       _get_split_size)
+                       _get_split_size, _NEXT_FILE_BUFFER)
 from .io.meas_info import read_meas_info, write_meas_info, _merge_info
 from .io.open import fiff_open, _get_next_fname
 from .io.tree import dir_tree_find
@@ -57,7 +57,7 @@
                     _check_event_id, _gen_events, _check_option,
                     _check_combine, ShiftTimeMixin, _build_data_frame,
                     _check_pandas_index_arguments, _convert_times,
-                    _scale_dataframe_data, _check_time_format)
+                    _scale_dataframe_data, _check_time_format, object_size)
 from .utils.docs import fill_doc
 
 
@@ -71,7 +71,11 @@ def _pack_reject_params(epochs):
 
 
 def _save_split(epochs, fname, part_idx, n_parts, fmt):
-    """Split epochs."""
+    """Split epochs.
+
+    Anything new added to this function also needs to be added to
+    BaseEpochs.save to account for new file sizes.
+    """
     # insert index in filename
     path, base = op.split(fname)
     idx = base.find('.')
@@ -120,10 +124,7 @@ def _save_split(epochs, fname, part_idx, n_parts, fmt):
 
     start_block(fid, FIFF.FIFFB_MNE_EVENTS)
     write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, epochs.events.T)
-    mapping_ = ';'.join([k + ':' + str(v) for k, v in
-                         epochs.event_id.items()])
-
-    write_string(fid, FIFF.FIFF_DESCRIPTION, mapping_)
+    write_string(fid, FIFF.FIFF_DESCRIPTION, _event_id_string(epochs.event_id))
    end_block(fid, FIFF.FIFFB_MNE_EVENTS)
 
     # Metadata
@@ -187,6 +188,10 @@ def _save_split(epochs, fname, part_idx, n_parts, fmt):
     end_file(fid)
 
 
+def _event_id_string(event_id):
+    return ';'.join([k + ':' + str(v) for k, v in event_id.items()])
+
+
 def _merge_events(events, event_id, selection):
     """Merge repeated events."""
     event_id = event_id.copy()
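For reference, the new _event_id_string helper just serializes the event_id mapping into the semicolon-separated form stored in FIFF_DESCRIPTION, and save() below reuses it to count the exact bytes that string will occupy. A tiny illustrative check with hypothetical event names:

    event_id = {'auditory': 1, 'visual': 3}
    assert _event_id_string(event_id) == 'auditory:1;visual:3'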
Expand Down Expand Up @@ -1595,7 +1600,34 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False,
self._check_consistency()
if fmt == "single":
total_size //= 2 # 64bit data converted to 32bit before writing.
n_parts = max(int(np.ceil(total_size / float(split_size))), 1)
total_size += 32 # FIF tags
# Account for all the other things we write, too
# 1. meas_id block plus main epochs block
total_size += 132
# 2. measurement info (likely slight overestimate, but okay)
total_size += object_size(self.info)
# 3. events and event_id in its own block
total_size += (self.events.size * 4 +
len(_event_id_string(self.event_id)) + 72)
# 4. Metadata in a block of its own
if self.metadata is not None:
total_size += len(_prepare_write_metadata(self.metadata)) + 56
# 5. first sample, last sample, baseline
total_size += 40 + 40 * (self.baseline is not None)
# 6. drop log
total_size += len(json.dumps(self.drop_log)) + 16
# 7. reject params
reject_params = _pack_reject_params(self)
if reject_params:
total_size += len(json.dumps(reject_params)) + 16
# 8. selection
total_size += self.selection.size * 4 + 16
# 9. end of file tags
total_size += _NEXT_FILE_BUFFER

# This is like max(int(ceil(total_size / split_size)), 1) but cleaner
n_parts = (total_size - 1) // split_size + 1
assert n_parts >= 1
epoch_idxs = np.array_split(np.arange(len(self)), n_parts)

for part_idx, epoch_idx in enumerate(epoch_idxs):
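The (total_size - 1) // split_size + 1 expression relies on the standard integer ceiling-division identity: for positive integers n and d, (n - 1) // d + 1 == ceil(n / d), and it naturally yields one part when everything fits in a single file. A quick check with illustrative sizes (not real file sizes):

    import math

    for total_size, split_size in [(1, 10), (10, 10), (11, 10), (25, 10)]:
        n_parts = (total_size - 1) // split_size + 1
        assert n_parts == max(math.ceil(total_size / split_size), 1)
    # e.g. 25 bytes into 10-byte splits -> (25 - 1) // 10 + 1 == 3 parts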
9 changes: 4 additions & 5 deletions mne/io/base.py
@@ -27,7 +27,7 @@
 from .write import (start_file, end_file, start_block, end_block,
                     write_dau_pack16, write_float, write_double,
                     write_complex64, write_complex128, write_int,
-                    write_id, write_string, _get_split_size)
+                    write_id, write_string, _get_split_size, _NEXT_FILE_BUFFER)
 
 from ..annotations import (_annotations_starts_stops, _write_annotations,
                            _handle_meas_date)
@@ -1871,7 +1871,6 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,
                          'measurement information, you must use a larger '
                          'value for split size: %s plus enough bytes for '
                          'the chosen buffer_size' % pos_prev)
-    next_file_buffer = 2 ** 20  # extra cushion for last few post-data tags
 
     # Check to see if this has acquisition skips and, if so, if we can
     # write out empty buffers instead of zeroes
@@ -1920,7 +1919,7 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,
 
             pos = fid.tell()
             this_buff_size_bytes = pos - pos_prev
-            overage = pos - split_size + next_file_buffer
+            overage = pos - split_size + _NEXT_FILE_BUFFER
             if overage > 0:
                 # This should occur on the first buffer write of the file, so
                 # we should mention the space required for the meas info
@@ -1930,12 +1929,12 @@ def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start,
                 'by %s bytes after writing info (%s) and leaving enough space '
                 'for end tags (%s): decrease "buffer_size_sec" or increase '
                 '"split_size".' % (this_buff_size_bytes, split_size, overage,
-                                   pos_prev, next_file_buffer))
+                                   pos_prev, _NEXT_FILE_BUFFER))
 
             # Split files if necessary, leave some space for next file info
             # make sure we check to make sure we actually *need* another buffer
             # with the "and" check
-            if pos >= split_size - this_buff_size_bytes - next_file_buffer and \
+            if pos >= split_size - this_buff_size_bytes - _NEXT_FILE_BUFFER and \
                     first + buffer_size < stop:
                 next_fname, next_idx = _write_raw(
                     fname, raw, info, picks, fmt,
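The raw writer's split condition is unchanged in spirit: start a new file once the bytes written so far, plus room for another buffer of the same size and the 1 MiB cushion for post-data tags, would overflow split_size. A simplified standalone sketch (hypothetical helper, not MNE API):

    _NEXT_FILE_BUFFER = 2 ** 20  # same cushion the writer reserves

    def needs_split(pos, this_buff_size_bytes, split_size, more_data):
        # split when another buffer plus the end-of-file cushion cannot fit
        return (pos >= split_size - this_buff_size_bytes - _NEXT_FILE_BUFFER
                and more_data)

    assert not needs_split(10 * 2 ** 20, 2 ** 20, 2 ** 30, more_data=True)
    assert needs_split(2 ** 30 - 2 ** 20, 2 ** 20, 2 ** 30, more_data=True)
    assert not needs_split(2 ** 30 - 2 ** 20, 2 ** 20, 2 ** 30, more_data=False)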
3 changes: 3 additions & 0 deletions mne/io/write.py
@@ -52,6 +52,9 @@ def _get_split_size(split_size):
     return split_size
 
 
+_NEXT_FILE_BUFFER = 1048576  # 2 ** 20 extra cushion for last post-data tags
+
+
 def write_nop(fid, last=False):
     """Write a FIFF_NOP."""
     fid.write(np.array(FIFF.FIFF_NOP, dtype='>i4').tobytes())
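Since both the epochs and raw writers now share this constant, a quick consistency check (same value the raw writer previously hard-coded inline as 2 ** 20):

    from mne.io.write import _NEXT_FILE_BUFFER

    assert _NEXT_FILE_BUFFER == 2 ** 20 == 1048576  # 1 MiB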
39 changes: 34 additions & 5 deletions mne/tests/test_epochs.py
@@ -37,6 +37,7 @@
 
 from mne.io import RawArray, read_raw_fif
 from mne.io.proj import _has_eeg_average_ref_proj
+from mne.io.write import _get_split_size
 from mne.event import merge_events
 from mne.io.constants import FIFF
 from mne.datasets import testing
@@ -989,15 +990,35 @@ def test_epochs_io_preload(tmpdir, preload):
     epochs = Epochs(raw, events, dict(foo=1, bar=999), tmin, tmax,
                     picks=picks, on_missing='ignore')
     epochs.save(temp_fname, overwrite=True)
+    split_fname_1 = temp_fname[:-4] + '-1.fif'
+    split_fname_2 = temp_fname[:-4] + '-2.fif'
+    assert op.isfile(temp_fname)
+    assert not op.isfile(split_fname_1)
     epochs_read = read_epochs(temp_fname, preload=preload)
     assert_allclose(epochs.get_data(), epochs_read.get_data(), **tols)
     assert_array_equal(epochs.events, epochs_read.events)
     assert_equal(set(epochs.event_id.keys()),
                  {str(x) for x in epochs_read.event_id.keys()})
 
     # test saving split epoch files
-    epochs.save(temp_fname, split_size='7MB', overwrite=True)
+    split_size = '7MB'
+    # ensure that we're in a position where just the data itself could fit
+    # if that were all that we saved ...
+    split_size_bytes = _get_split_size(split_size)
+    assert epochs.get_data().nbytes // 2 < split_size_bytes
+    epochs.save(temp_fname, split_size=split_size, overwrite=True)
+    # ... but we correctly account for the other stuff we need to write,
+    # so end up with two files ...
+    assert op.isfile(temp_fname)
+    assert op.isfile(split_fname_1)
+    assert not op.isfile(split_fname_2)
     epochs_read = read_epochs(temp_fname, preload=preload)
+    # ... and none of the files exceed our limit.
+    for fname in (temp_fname, split_fname_1):
+        with open(fname, 'r') as fid:
+            fid.seek(0, 2)
+            fsize = fid.tell()
+        assert fsize <= split_size_bytes
     assert_allclose(epochs.get_data(), epochs_read.get_data(), **tols)
     assert_array_equal(epochs.events, epochs_read.events)
     assert_array_equal(epochs.selection, epochs_read.selection)
@@ -1025,12 +1046,20 @@ def test_split_saving(tmpdir):
     events = mne.make_fixed_length_events(raw, 1)
     epochs = mne.Epochs(raw, events)
     epochs_data = epochs.get_data()
+    assert len(epochs) == 9
     fname = op.join(tempdir, 'test-epo.fif')
     epochs.save(fname, split_size='1MB', overwrite=True)
-    assert op.isfile(fname)
-    assert op.isfile(fname[:-4] + '-1.fif')
-    assert op.isfile(fname[:-4] + '-2.fif')
-    assert not op.isfile(fname[:-4] + '-3.fif')
+    size = _get_split_size('1MB')
+    assert size == 1048576 == 1024 * 1024
+    written_fnames = [fname] + [
+        fname[:-4] + '-%d.fif' % ii for ii in range(1, 4)]
+    for this_fname in written_fnames:
+        assert op.isfile(this_fname)
+        with open(this_fname, 'r') as fid:
+            fid.seek(0, 2)
+            file_size = fid.tell()
+        assert size * 0.5 < file_size <= size
+    assert not op.isfile(fname[:-4] + '-4.fif')
     for preload in (True, False):
         epochs2 = mne.read_epochs(fname, preload=preload)
         assert_allclose(epochs2.get_data(), epochs_data)
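The seek(0, 2)/tell() pattern in these tests measures file size; os.path.getsize expresses the same bound check more directly. A sketch over the same file names the tests construct (not part of the diff):

    import os

    def check_split_sizes(fnames, size):
        for this_fname in fnames:
            file_size = os.path.getsize(this_fname)
            # each part must fit under the limit and be more than half full,
            # i.e. the splitter is not spilling into unnecessary extra files
            assert size * 0.5 < file_size <= size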