Proof of concept: add support for zipfile.Path when loading raw data #11924

Draft — wants to merge 8 commits into base: main
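Going by the tests added at the bottom of this diff, the proposed feature lets raw FIF data be saved into and read back out of a zip archive through `zipfile.Path`. A rough usage sketch on this branch (file names here are purely illustrative):

```python
import zipfile

from mne.io import read_raw_fif

fif_name = "sample_audvis_raw.fif"      # illustrative: any raw FIF file
zip_name = "sample_audvis_raw.fif.zip"  # illustrative archive name

raw = read_raw_fif(fif_name)

# Save the recording straight into the archive as a member named fif_name.
with zipfile.ZipFile(zip_name, "w") as zf:
    raw.save(zipfile.Path(zf, fif_name))

# Read it back from the archive the same way.
with zipfile.ZipFile(zip_name) as zf:
    raw_from_zip = read_raw_fif(zipfile.Path(zf, fif_name), preload=True)
```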
23 changes: 11 additions & 12 deletions mne/_fiff/open.py
@@ -3,14 +3,14 @@
#
# License: BSD-3-Clause

import os.path as op
from gzip import GzipFile
from io import SEEK_SET, BytesIO

import numpy as np
from scipy.sparse import issparse

from ..utils import _file_like, logger, verbose, warn
from pathlib import Path
from .constants import FIFF
from .tag import Tag, _call_dict_names, _matrix_info, read_tag, read_tag_info
from .tree import dir_tree_find, make_dir_tree
@@ -44,20 +44,23 @@ def _fiff_get_fid(fname):
fid = _NoCloseRead(fname)
fid.seek(0)
else:
fname = str(fname)
if op.splitext(fname)[1].lower() == ".gz":
if isinstance(fname, str):
fname = Path(fname)
if fname.suffix.lower() == ".gz":
logger.debug("Using gzip")
fid = GzipFile(fname, "rb") # Open in binary mode
else:
logger.debug("Using normal I/O")
fid = open(fname, "rb") # Open in binary mode
fid = fname.open("rb") # Open in binary mode
return fid


def _get_next_fname(fid, fname, tree):
"""Get the next filename in split files."""
nodes_list = dir_tree_find(tree, FIFF.FIFFB_REF)
next_fname = None
if fname == "File-like":
fname = Path()
for nodes in nodes_list:
next_fname = None
for ent in nodes["directory"]:
@@ -69,14 +72,14 @@ def _get_next_fname(fid, fname, tree):
break
if ent.kind == FIFF.FIFF_REF_FILE_NAME:
tag = read_tag(fid, ent.pos)
next_fname = op.join(op.dirname(fname), tag.data)
next_fname = fname.parent / tag.data
if ent.kind == FIFF.FIFF_REF_FILE_NUM:
# Some files don't have the name, just the number. So
# we construct the name from the current name.
if next_fname is not None:
continue
next_num = read_tag(fid, ent.pos).data.item()
path, base = op.split(fname)
path, base = fname.parent, fname.name
idx = base.find(".")
idx2 = base.rfind("-")
num_str = base[idx2 + 1 : idx]
@@ -85,14 +88,10 @@

if idx2 < 0 and next_num == 1:
# this is the first file, which may not be numbered
next_fname = op.join(
path, "%s-%d.%s" % (base[:idx], next_num, base[idx + 1 :])
)
next_fname = path / f"{base[:idx]}-{next_num:d}.{base[idx + 1 :]}"
continue

next_fname = op.join(
path, "%s-%d.%s" % (base[:idx2], next_num, base[idx + 1 :])
)
next_fname = path / f"{base[:idx2]}-{next_num:d}.{base[idx + 1 :]}"
if next_fname is not None:
break
return next_fname
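The `_get_next_fname` hunks above swap `os.path` string handling for `pathlib` while keeping the split-file naming convention. The helper below is a hypothetical, trimmed-down sketch of just the name construction; the real function reads the next file number from FIFF tags and also validates the number embedded in the current name:

```python
from pathlib import Path


def _next_split_name(fname, next_num):
    # Mirrors the name construction in _get_next_fname above (pathlib version).
    path, base = fname.parent, fname.name
    idx = base.find(".")    # start of the extension(s)
    idx2 = base.rfind("-")  # split-number separator, -1 when absent
    if idx2 < 0 and next_num == 1:
        # the first file of a split recording may be unnumbered
        return path / f"{base[:idx]}-{next_num:d}.{base[idx + 1:]}"
    return path / f"{base[:idx2]}-{next_num:d}.{base[idx + 1:]}"


print(_next_split_name(Path("data/run_raw.fif"), 1))    # data/run_raw-1.fif
print(_next_split_name(Path("data/run_raw-1.fif"), 2))  # data/run_raw-2.fif
```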
37 changes: 31 additions & 6 deletions mne/_fiff/write.py
@@ -3,12 +3,14 @@
#
# License: BSD-3-Clause

import os.path as op
from contextlib import contextmanager
from gzip import GzipFile
import re
import time
import uuid
from contextlib import contextmanager
from gzip import GzipFile
from pathlib import Path
import zipfile
import io

import numpy as np
from scipy.sparse import csc_matrix, csr_matrix
@@ -275,6 +277,26 @@ def end_block(fid, kind):
write_int(fid, FIFF.FIFF_BLOCK_END, kind)


class SeekableZipWriteFile(io.BufferedIOBase):
def __init__(self, fid):
self._fid = fid
self._seek = 0

def write(self, data):
n_bytes = self._fid.write(data)
self._seek += n_bytes
return n_bytes

def close(self):
self._fid.close()

def writable(self):
return self._fid.writable()

def tell(self):
return self._seek


def start_file(fname, id_=None):
"""Open a fif file for writing and writes the compulsory header tags.

@@ -292,15 +314,18 @@ def start_file(fname, id_=None):
fid = fname
fid.seek(0)
else:
fname = str(fname)
if op.splitext(fname)[1].lower() == ".gz":
if isinstance(fname, str):
fname = Path(fname)
if str(fname).lower().endswith(".gz"):
logger.debug("Writing using gzip")
# defaults to compression level 9, which is barely smaller but much
# slower. 2 offers a good compromise.
fid = GzipFile(fname, "wb", compresslevel=2)
else:
logger.debug("Writing using normal I/O")
fid = open(fname, "wb")
fid = fname.open("wb") # Open in binary mode
if isinstance(fname, zipfile.Path):
fid = SeekableZipWriteFile(fid)
# Write the compulsory items
write_id(fid, FIFF.FIFF_FILE_ID, id_)
write_int(fid, FIFF.FIFF_DIR_POINTER, -1)
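`SeekableZipWriteFile` exists because FIF writing records byte offsets in the output stream (tag positions go into the file's directory), so `start_file` needs a handle that can at least answer `tell()`. A zip member opened for writing is forward-only and, depending on the Python version, may not report its position, so the wrapper simply counts the bytes it has written. A minimal sketch, assuming this PR's branch where the class lives in `mne._fiff.write`:

```python
import zipfile

# Only available on this PR's branch.
from mne._fiff.write import SeekableZipWriteFile

with zipfile.ZipFile("demo.zip", "w") as zf:  # illustrative archive name
    # The raw handle for a zip member opened for writing is forward-only.
    raw_fid = zipfile.Path(zf, "demo.fif").open("wb")
    fid = SeekableZipWriteFile(raw_fid)
    assert fid.writable()
    fid.write(b"FIFF")      # 4 bytes go to the archive...
    assert fid.tell() == 4  # ...and the wrapper reports the running offset
    fid.close()             # closes the underlying zip member
```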
2 changes: 1 addition & 1 deletion mne/epochs.py
@@ -3778,7 +3778,7 @@ def __init__(self, fname, proj=True, preload=True, verbose=None): # noqa: D102
filetype="epochs",
endings=("-epo.fif", "-epo.fif.gz", "_epo.fif", "_epo.fif.gz"),
)
fname = str(_check_fname(fname=fname, must_exist=True, overwrite="read"))
fname = _check_fname(fname=fname, must_exist=True, overwrite="read")
elif not preload:
raise ValueError("preload must be used with file-like objects")

2 changes: 1 addition & 1 deletion mne/export/_eeglab.py
@@ -18,7 +18,7 @@ def _export_raw(fname, raw):
# remove extra epoc and STI channels
drop_chs = ["epoc"]
# filenames attribute of RawArray is filled with None
if raw.filenames[0] and not (raw.filenames[0].endswith(".fif")):
if raw.filenames[0] and not (str(raw.filenames[0]).endswith(".fif")):
drop_chs.append("STI 014")

ch_names = [ch for ch in raw.ch_names if ch not in drop_chs]
2 changes: 1 addition & 1 deletion mne/io/base.py
@@ -2043,7 +2043,7 @@ def copy(self):

def __repr__(self): # noqa: D105
name = self.filenames[0]
name = "" if name is None else op.basename(name) + ", "
name = "" if name is None else op.basename(str(name)) + ", "
size_str = str(sizeof_fmt(self._size)) # str in case it fails -> None
size_str += ", data%s loaded" % ("" if self.preload else " not")
s = "%s%s x %s (%0.1f s), ~%s" % (
11 changes: 5 additions & 6 deletions mne/io/fiff/raw.py
@@ -7,8 +7,8 @@
# License: BSD-3-Clause

import copy
import os
import os.path as op
from pathlib import Path

import numpy as np

@@ -99,15 +98,15 @@ def __init__(
): # noqa: D102
raws = []
do_check_ext = not _file_like(fname)
next_fname = fname
next_fname = Path(fname) if isinstance(fname, str) else fname
while next_fname is not None:
raw, next_fname, buffer_size_sec = self._read_raw_file(
next_fname, allow_maxshield, preload, do_check_ext
)
do_check_ext = False
raws.append(raw)
if next_fname is not None:
if not op.exists(next_fname):
if not next_fname.exists():
msg = (
f"Split raw file detected but next file {next_fname} "
"does not exist. Ensure all files were transferred "
@@ -183,8 +182,8 @@ def _read_raw_file(
endings += tuple([f"{e}.gz" for e in endings])
check_fname(fname, "raw", endings)
# filename
fname = str(_check_fname(fname, "read", True, "fname"))
ext = os.path.splitext(fname)[1].lower()
fname = _check_fname(fname, "read", True, "fname")
ext = fname.suffix.lower()
whole_file = preload if ".gz" in ext else False
del ext
else:
41 changes: 41 additions & 0 deletions mne/io/fiff/tests/test_raw_fiff.py
@@ -12,6 +12,7 @@
from functools import partial
from io import BytesIO
from pathlib import Path
import zipfile

import numpy as np
import pytest
@@ -2093,3 +2094,43 @@ def test_expand_user(tmp_path, monkeypatch):

raw = read_raw_fif(fname=path_home, preload=True)
raw.save(fname=path_home, overwrite=True)


@testing.requires_testing_data
@pytest.mark.parametrize("split_size", ["2GB", "5MB"])
def test_zip_io(tmp_path_factory, split_size):
"""Test that writing to a zip and reading back preserves data."""
fname = fif_fname.name
zip_fname = tmp_path_factory.mktemp("zipfile_reading") / (fname + ".zip")
saved_raw = read_raw_fif(fif_fname).crop(0, 1)

with zipfile.ZipFile(zip_fname, "w") as zip_:
saved_raw.save(zipfile.Path(zip_, fname), split_size=split_size)

with zipfile.ZipFile(zip_fname) as zip_:
loaded_raw = read_raw_fif(zipfile.Path(zip_, fname))

assert_object_equal(saved_raw.get_data(), loaded_raw.get_data())
assert saved_raw.info["ch_names"] == loaded_raw.info["ch_names"]
assert_array_equal(saved_raw.times, loaded_raw.times)


@testing.requires_testing_data
@pytest.mark.parametrize("split_size", ["5MB"])
def test_zip_splits_number(tmp_path_factory, split_size):
"""Test that saving to a zip produces the same number of splits as a regular save."""
dst_dir_reg = tmp_path_factory.mktemp("zipfile_splits_reg")
dst_dir_zip = tmp_path_factory.mktemp("zipfile_splits_zip")
fname = fif_fname.name
zip_fname = dst_dir_zip / (fname + ".zip")
saved_raw = read_raw_fif(fif_fname).crop(0, 3)

saved_raw.save(dst_dir_reg / fname, split_size=split_size, buffer_size_sec=1)
with zipfile.ZipFile(zip_fname, "w") as zip_:
saved_raw.save(
zipfile.Path(zip_, fname), split_size=split_size, buffer_size_sec=1
)

assert len(list(dst_dir_reg.iterdir())) > 1
with zipfile.ZipFile(zip_fname, "r") as zip_:
assert len(list(dst_dir_reg.iterdir())) == len(zip_.namelist())
16 changes: 7 additions & 9 deletions mne/report/report.py
@@ -2765,14 +2765,12 @@ def parse_folder(
# iterate through the possible patterns
fnames = list()
for p in pattern:
data_path = str(
_check_fname(
fname=self.data_path,
overwrite="read",
must_exist=True,
name="Directory or folder",
need_dir=True,
)
data_path = _check_fname(
fname=self.data_path,
overwrite="read",
must_exist=True,
name="Directory or folder",
need_dir=True,
)
fnames.extend(sorted(_recursive_search(data_path, p)))

@@ -2789,7 +2787,7 @@
inst = read_raw(**kwargs)

if len(inst.filenames) > 1:
fnames_to_remove.extend(inst.filenames[1:])
fnames_to_remove.extend([str(f) for f in inst.filenames[1:]])
# For STCs, only keep one hemisphere
elif fname.endswith("-lh.stc") or fname.endswith("-rh.stc"):
first_hemi_fname = fname
10 changes: 7 additions & 3 deletions mne/utils/check.py
@@ -13,6 +13,9 @@
from importlib import import_module
from importlib.metadata import version
from pathlib import Path
import re
import numbers
import zipfile

import numpy as np
from packaging.version import parse
@@ -236,7 +239,8 @@ def _check_fname(
):
"""Check for file existence, and return its absolute path."""
_validate_type(fname, "path-like", name)
fname = Path(fname).expanduser().absolute()
if not isinstance(fname, zipfile.Path):
fname = Path(fname).expanduser().absolute()

if fname.exists():
if not overwrite:
@@ -257,7 +261,7 @@
raise OSError(
f"Need a file for {name} but found a directory " f"at {fname}"
)
if not os.access(fname, os.R_OK):
if not isinstance(fname, zipfile.Path) and not os.access(fname, os.R_OK):
raise PermissionError(f"{name} does not have read permissions: {fname}")
elif must_exist:
raise FileNotFoundError(f'{name} does not exist: "{fname}"')
@@ -527,7 +531,7 @@ def __instancecheck__(cls, other):


int_like = _IntLike()
path_like = (str, Path, os.PathLike)
path_like = (str, Path, os.PathLike, zipfile.Path)


class _Callable:
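Finally, the `check.py` changes above widen `path_like` and make `_check_fname` leave `zipfile.Path` objects alone instead of coercing them to a filesystem `Path`. A small hedged sketch of what that buys, assuming this PR's branch (the archive and member names are illustrative):

```python
import zipfile

# Internal MNE helpers; path_like now includes zipfile.Path on this branch.
from mne.utils.check import _validate_type, path_like

with zipfile.ZipFile("demo.zip", "w") as zf:
    p = zipfile.Path(zf, "demo_raw.fif")
    assert isinstance(p, path_like)          # tuple now includes zipfile.Path
    _validate_type(p, "path-like", "fname")  # passes instead of raising TypeError
```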