Skip to content

Commit

Permalink
Refactor test_epochs.py::test_split_saving (1 out of 2) (#11880)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
3 people authored Aug 14, 2023
1 parent 2e357a6 commit 039122a
Showing 1 changed file with 111 additions and 64 deletions.
175 changes: 111 additions & 64 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#
# License: BSD-3-Clause

import os
import pickle
from copy import deepcopy
from datetime import timedelta
Expand Down Expand Up @@ -1476,85 +1475,133 @@ def test_epochs_io_preload(tmp_path, preload):
assert_equal(epochs.get_data().shape[-1], 1)


@pytest.mark.parametrize(
"split_size, n_epochs, n_files, size",
[
("1.5MB", 9, 6, 1572864),
("3MB", 18, 3, 3 * 1024 * 1024),
],
)
@pytest.mark.parametrize("metadata", [False, True])
@pytest.mark.parametrize("concat", (False, True))
def test_split_saving(tmp_path, split_size, n_epochs, n_files, size, metadata, concat):
"""Test saving split epochs."""
if metadata:
pytest.importorskip("pandas")
# See gh-5102
fs = 1000.0
n_times = int(round(fs * (n_epochs + 1)))
raw = mne.io.RawArray(
np.random.RandomState(0).randn(100, n_times), mne.create_info(100, 1000.0)
)
events = mne.make_fixed_length_events(raw, 1)
epochs = mne.Epochs(raw, events)
if split_size == "2MB" and (metadata or concat):
n_files += 1
if metadata:
from pandas import DataFrame

junk = ["*" * 10000 for _ in range(len(events))]
metadata = DataFrame(
{
"event_time": events[:, 0] / raw.info["sfreq"],
"trial_number": range(len(events)),
"junk": junk,
}
@pytest.fixture(scope="session")
def epochs_factory():
    """Return a factory that builds synthetic Epochs objects.

    The returned callable has the signature
    ``factory(n_epochs, metadata=False, concat=False)``:

    - ``n_epochs``: number of epochs the result must contain; this is
      asserted before the object is returned.
    - ``metadata``: if True, attach a pandas ``DataFrame`` with one row per
      event (the calling test is skipped when pandas cannot be imported).
    - ``concat``: if True, rebuild the epochs by concatenating
      single-epoch selections with ``concatenate_epochs``.
    """

    def factory(n_epochs, metadata=False, concat=False):
        if metadata:
            pytest.importorskip("pandas")
        # See gh-5102
        fs = 1000.0
        n_times = int(round(fs * (n_epochs + 1)))
        raw = mne.io.RawArray(
            np.random.RandomState(0).randn(100, n_times), mne.create_info(100, fs)
        )
        events = mne.make_fixed_length_events(raw, 1)
        epochs = mne.Epochs(raw, events)
        if metadata:
            from pandas import DataFrame

            # One long string per event; presumably there to inflate the
            # size of the stored metadata -- TODO confirm.
            junk = ["*" * 10000 for _ in range(len(events))]
            epochs.metadata = DataFrame(
                {
                    "event_time": events[:, 0] / raw.info["sfreq"],
                    "trial_number": range(len(events)),
                    "junk": junk,
                }
            )
        epochs.drop_bad()
        if concat:
            epochs = concatenate_epochs([epochs[ii] for ii in range(len(epochs))])
        assert len(epochs) == n_epochs
        return epochs

    return factory


@pytest.fixture(
    params=[
        ("1.5MB", 9, True, True, 6),
        ("1.5MB", 9, True, False, 6),
        ("1.5MB", 9, False, True, 6),
        ("1.5MB", 9, False, False, 6),
        ("3MB", 18, True, True, 3),
        ("3MB", 18, True, False, 3),
        ("3MB", 18, False, True, 3),
        ("3MB", 18, False, False, 3),
    ]
)
def epochs_to_split(request, epochs_factory):
    """Build epochs meant to be saved as a known number of split files.

    Each parameter set is ``(split_size, n_epochs, metadata, concat,
    n_files)``. The cases target boundary conditions where a small size
    excess triggers creation of a new split: gh-7897

    Returns a tuple ``(epochs, split_size, n_files)``.
    """
    split_size, n_epochs, with_metadata, do_concat, n_files = request.param
    built = epochs_factory(n_epochs, with_metadata, do_concat)
    return built, split_size, n_files


@pytest.mark.parametrize("preload", [True, False], ids=["preload", "no_preload"])
def test_split_saving(tmp_path, epochs_to_split, preload):
    """Test saving split epochs.

    Saves the epochs with the requested ``split_size``, checks the split
    files on disk via ``_assert_splits``, and verifies that reading the
    files back (with or without preloading) returns the same data and
    events.
    """
    epochs, split_size, n_files = epochs_to_split
    epochs_data = epochs.get_data()
    fname = tmp_path / "test-epo.fif"
    epochs.save(fname, split_size=split_size, overwrite=True)
    got_size = _get_split_size(split_size)

    # Save a second time on top of the files that were just written.
    epochs.save(fname, split_size=split_size, overwrite=True)
    epochs2 = mne.read_epochs(fname, preload=preload)

    _assert_splits(fname, n_files, got_size)
    # No file beyond the expected number of splits may exist.
    assert not fname.with_name(f"{fname.stem}-{n_files + 1}{fname.suffix}").is_file()
    assert_allclose(epochs2.get_data(), epochs_data)
    assert_array_equal(epochs.events, epochs2.events)


@pytest.mark.parametrize(
    "split_naming, split_fname, split_fname_part1",
    [
        ("neuromag", "test_epo.fif", lambda n: f"test_epo-{n + 1}.fif"),
        ("bids", "test_epo.fif", lambda n: f"test_split-{n + 1:02d}_epo.fif"),
    ],
)
def test_split_naming(
    tmp_path, epochs_to_split, split_naming, split_fname, split_fname_part1
):
    """Check split file names for each naming scheme.

    ``split_fname_part1`` maps the expected file count to the name of the
    split file that must exist after saving.
    """
    epochs, _, n_files = epochs_to_split
    first_fpath = tmp_path / split_fname
    expected_part_fpath = tmp_path / split_fname_part1(n_files)
    # Reserved files are not tested because that is not implemented here.

    save_options = dict(split_size="1.4MB", split_naming=split_naming, verbose=True)
    epochs.save(first_fpath, **save_options)

    # The file names on disk must match the intended pattern.
    assert first_fpath.is_file()
    assert expected_part_fpath.is_file()


def test_saved_fname_no_splitting(tmp_path, epochs_to_split):
    """Test saved fname doesn't get split suffix when splitting not needed.

    The epochs are saved without ``split_size``, so only the requested file
    name is expected on disk, with no neuromag- or BIDS-style split file
    next to it.
    """
    # Check that if BIDS is used and no split is needed it defaults to
    # simple writing without _split- entity.
    epochs, _, n_files = epochs_to_split
    split_fname = tmp_path / "test_epo.fif"
    split_fname_neuromag_part1 = tmp_path / f"test_epo-{n_files + 1}.fif"
    split_fname_bids_part1 = tmp_path / f"test_split-{n_files + 1:02d}_epo.fif"

    epochs.save(split_fname, split_naming="bids", verbose=True)

    assert split_fname.is_file()
    assert not split_fname_bids_part1.is_file()
    assert not split_fname_neuromag_part1.is_file()

@pytest.mark.parametrize("split_naming", ["neuromag", "bids"])
def test_saving_fails_with_not_permitted_overwrite(
    tmp_path, epochs_factory, split_naming
):
    """Saving onto an existing file without ``overwrite`` must raise."""
    save_options = {"split_naming": split_naming, "verbose": True}
    target = tmp_path / "test-epo.fif"
    epochs = epochs_factory(n_epochs=5)

    # The first save creates the destination file.
    epochs.save(target, **save_options)

    # The second save to the same path is expected to be rejected.
    with pytest.raises(FileExistsError, match="Destination file"):
        epochs.save(target, **save_options)


@pytest.mark.slowtest
Expand Down

0 comments on commit 039122a

Please sign in to comment.