Skip to content

Commit

Permalink
refactor _save_split()
Browse files Browse the repository at this point in the history
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
  • Loading branch information
dmalt committed Aug 17, 2023
1 parent f1c06e8 commit 0097cfb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
26 changes: 9 additions & 17 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np

from .io.utils import _construct_bids_filename
from .io.utils import _make_split_fnames
from .io.write import (
start_and_end_file,
start_block,
Expand Down Expand Up @@ -119,30 +119,22 @@ def _pack_reject_params(epochs):
return reject_params


def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite):
def _save_split(epochs, split_fnames, part_idx, n_parts, fmt, overwrite):
"""Split epochs.
Anything new added to this function also needs to be added to
BaseEpochs.save to account for new file sizes.
"""
# insert index in filename
base, ext = op.splitext(fname)
if part_idx > 0 and split_naming == "neuromag":
fname = f"{base}-{part_idx:d}{ext}"
elif split_naming == "bids" and n_parts > 1:
fname = _construct_bids_filename(base, ext, part_idx + 1)
_check_fname(fname, overwrite=overwrite)
this_fname = split_fnames[part_idx]
_check_fname(this_fname, overwrite=overwrite)

next_fname, next_idx = None, None
if part_idx < n_parts - 1:
next_idx = part_idx + 1
if split_naming == "neuromag":
next_fname = f"{base}-{next_idx:d}{ext}"
else:
assert split_naming == "bids"
next_fname = _construct_bids_filename(base, ext, next_idx + 1)
next_fname = split_fnames[next_idx]

with start_and_end_file(fname) as fid:
with start_and_end_file(this_fname) as fid:
_save_part(fid, epochs, fmt, n_parts, next_fname, next_idx)


Expand Down Expand Up @@ -2143,13 +2135,13 @@ def save(

epoch_idxs = np.array_split(np.arange(n_epochs), n_parts)

split_fnames = _make_split_fnames(fname, n_parts, split_naming)
for part_idx, epoch_idx in enumerate(epoch_idxs):
this_epochs = self[epoch_idx] if n_parts > 1 else self
# avoid missing event_ids in splits
this_epochs.event_id = self.event_id
_save_split(
this_epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite
)

_save_split(this_epochs, split_fnames, part_idx, n_parts, fmt, overwrite)

@verbose
def export(self, fname, fmt="auto", *, overwrite=False, verbose=None):
Expand Down
16 changes: 16 additions & 0 deletions mne/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,19 @@ def _construct_bids_filename(base, ext, part_idx, validate=True):
if dirname:
use_fname = op.join(dirname, use_fname)
return use_fname


def _make_split_fnames(fname, n_splits, split_naming):
"""Make a list of split filenames."""
if n_splits == 1:
return [fname]
res = []
base, ext = op.splitext(fname)
for i in range(n_splits):
if split_naming == "neuromag":
res.append(f"{base}-{i:d}{ext}" if i else fname)
elif split_naming == "bids":
res.append(_construct_bids_filename(base, ext, i + 1))
else:
raise NotImplementedError(f"{split_naming=} is not supported")
return res

0 comments on commit 0097cfb

Please sign in to comment.