diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py
index ce7e957a127..dbfb4861731 100644
--- a/mne/tests/test_epochs.py
+++ b/mne/tests/test_epochs.py
@@ -4,7 +4,6 @@
 #
 # License: BSD-3-Clause
 
-import os
 import pickle
 from copy import deepcopy
 from datetime import timedelta
@@ -1476,85 +1475,133 @@ def test_epochs_io_preload(tmp_path, preload):
     assert_equal(epochs.get_data().shape[-1], 1)
 
 
-@pytest.mark.parametrize(
-    "split_size, n_epochs, n_files, size",
-    [
-        ("1.5MB", 9, 6, 1572864),
-        ("3MB", 18, 3, 3 * 1024 * 1024),
-    ],
-)
-@pytest.mark.parametrize("metadata", [False, True])
-@pytest.mark.parametrize("concat", (False, True))
-def test_split_saving(tmp_path, split_size, n_epochs, n_files, size, metadata, concat):
-    """Test saving split epochs."""
-    if metadata:
-        pytest.importorskip("pandas")
-    # See gh-5102
-    fs = 1000.0
-    n_times = int(round(fs * (n_epochs + 1)))
-    raw = mne.io.RawArray(
-        np.random.RandomState(0).randn(100, n_times), mne.create_info(100, 1000.0)
-    )
-    events = mne.make_fixed_length_events(raw, 1)
-    epochs = mne.Epochs(raw, events)
-    if split_size == "2MB" and (metadata or concat):
-        n_files += 1
-    if metadata:
-        from pandas import DataFrame
-
-        junk = ["*" * 10000 for _ in range(len(events))]
-        metadata = DataFrame(
-            {
-                "event_time": events[:, 0] / raw.info["sfreq"],
-                "trial_number": range(len(events)),
-                "junk": junk,
-            }
+@pytest.fixture(scope="session")
+def epochs_factory():
+    """Function to create fake Epochs object."""  # noqa: D401 (imperative mood)
+
+    def factory(n_epochs, metadata=False, concat=False):
+        if metadata:
+            pytest.importorskip("pandas")
+        # See gh-5102
+        fs = 1000.0
+        n_times = int(round(fs * (n_epochs + 1)))
+        raw = mne.io.RawArray(
+            np.random.RandomState(0).randn(100, n_times), mne.create_info(100, 1000.0)
         )
-        epochs.metadata = metadata
-    if concat:
+        events = mne.make_fixed_length_events(raw, 1)
+        epochs = mne.Epochs(raw, events)
+        if metadata:
+            from pandas import DataFrame
+
+            junk = ["*" * 10000 for _ in range(len(events))]
+            metadata = DataFrame(
+                {
+                    "event_time": events[:, 0] / raw.info["sfreq"],
+                    "trial_number": range(len(events)),
+                    "junk": junk,
+                }
+            )
+            epochs.metadata = metadata
         epochs.drop_bad()
-        epochs = concatenate_epochs([epochs[ii] for ii in range(len(epochs))])
+        if concat:
+            epochs = concatenate_epochs([epochs[ii] for ii in range(len(epochs))])
+        assert len(epochs) == n_epochs
+        return epochs
+
+    return factory
+
+
+@pytest.fixture(
+    params=[
+        ("1.5MB", 9, True, True, 6),
+        ("1.5MB", 9, True, False, 6),
+        ("1.5MB", 9, False, True, 6),
+        ("1.5MB", 9, False, False, 6),
+        ("3MB", 18, True, True, 3),
+        ("3MB", 18, True, False, 3),
+        ("3MB", 18, False, True, 3),
+        ("3MB", 18, False, False, 3),
+    ]
+)
+def epochs_to_split(request, epochs_factory):
+    """Epochs tailored to produce specific number of splits when saving.
+
+    We're specifically interested in boundary cases, when a small size
+    excess triggers creation of a new split: gh-7897
+
+    """
+    split_size, n_epochs, metadata, concat, n_files = request.param
+    epochs = epochs_factory(n_epochs, metadata, concat)
+    return epochs, split_size, n_files
+
+
+@pytest.mark.parametrize("preload", [True, False], ids=["preload", "no_preload"])
+def test_split_saving(tmp_path, epochs_to_split, preload):
+    """Test saving split epochs."""
+    epochs, split_size, n_files = epochs_to_split
     epochs_data = epochs.get_data()
-    assert len(epochs) == n_epochs
     fname = tmp_path / "test-epo.fif"
-    epochs.save(fname, split_size=split_size, overwrite=True)
     got_size = _get_split_size(split_size)
-    assert got_size == size
-    _assert_splits(fname, n_files, size)
+
+    epochs.save(fname, split_size=split_size, overwrite=True)
+    epochs2 = mne.read_epochs(fname, preload=preload)
+
+    _assert_splits(fname, n_files, got_size)
     assert not fname.with_name(f"{fname.stem}-{n_files + 1}{fname.suffix}").is_file()
-    for preload in (True, False):
-        epochs2 = mne.read_epochs(fname, preload=preload)
-        assert_allclose(epochs2.get_data(), epochs_data)
-        assert_array_equal(epochs.events, epochs2.events)
+    assert_allclose(epochs2.get_data(), epochs_data)
+    assert_array_equal(epochs.events, epochs2.events)
+
 
+@pytest.mark.parametrize(
+    "split_naming, split_fname, split_fname_part1",
+    [
+        ("neuromag", "test_epo.fif", lambda n: f"test_epo-{n + 1}.fif"),
+        ("bids", "test_epo.fif", lambda n: f"test_split-{n + 1:02d}_epo.fif"),
+    ],
+)
+def test_split_naming(
+    tmp_path, epochs_to_split, split_naming, split_fname, split_fname_part1
+):
+    """Test naming of the split files."""
+    epochs, _, n_files = epochs_to_split
+    split_fpath = tmp_path / split_fname
+    # we don't test for reserved files as it's not implemented here
+
+    epochs.save(
+        split_fpath, split_size="1.4MB", split_naming=split_naming, verbose=True
+    )
+
+    # check that the filenames match the intended pattern
+    assert split_fpath.is_file()
+    assert (tmp_path / split_fname_part1(n_files)).is_file()
+
+
+def test_saved_fname_no_splitting(tmp_path, epochs_to_split):
+    """Test saved fname doesn't get split suffix when splitting not needed."""
     # Check that if BIDS is used and no split is needed it defaults to
     # simple writing without _split- entity.
+    epochs, _, n_files = epochs_to_split
     split_fname = tmp_path / "test_epo.fif"
-    split_fname_neuromag_part1 = tmp_path / f"test_epo-{n_files + 1}.fif"
     split_fname_bids_part1 = tmp_path / f"test_split-{n_files + 1:02d}_epo.fif"
     epochs.save(split_fname, split_naming="bids", verbose=True)
+
     assert split_fname.is_file()
     assert not split_fname_bids_part1.is_file()
-    for split_naming in ("neuromag", "bids"):
-        with pytest.raises(FileExistsError, match="Destination file"):
-            epochs.save(split_fname, split_naming=split_naming, verbose=True)
-        os.remove(split_fname)
-    # we don't test for reserved files as it's not implemented here
-    epochs.save(split_fname, split_size="1.4MB", verbose=True)
-    # check that the filenames match the intended pattern
-    assert split_fname.is_file()
-    assert split_fname_neuromag_part1.is_file()
-    # check that filenames are being formatted correctly for BIDS
-    epochs.save(
-        split_fname,
-        split_size="1.4MB",
-        split_naming="bids",
-        overwrite=True,
-        verbose=True,
-    )
-    assert split_fname_bids_part1.is_file()
+
 
+@pytest.mark.parametrize("split_naming", ["neuromag", "bids"])
+def test_saving_fails_with_not_permitted_overwrite(
+    tmp_path, epochs_factory, split_naming
+):
+    """Check exception is raised when overwriting without explicit flag."""
+    dst_fpath = tmp_path / "test-epo.fif"
+    epochs = epochs_factory(n_epochs=5)
+
+    epochs.save(dst_fpath, split_naming=split_naming, verbose=True)
+
+    with pytest.raises(FileExistsError, match="Destination file"):
+        epochs.save(dst_fpath, split_naming=split_naming, verbose=True)
 
 
 @pytest.mark.slowtest