diff --git a/src/nnbench/reporter/file.py b/src/nnbench/reporter/file.py index b0dd7c30..9d9e74f1 100644 --- a/src/nnbench/reporter/file.py +++ b/src/nnbench/reporter/file.py @@ -4,7 +4,7 @@ import threading from dataclasses import dataclass, field from pathlib import Path -from typing import IO, Any, Callable, Literal, Sequence +from typing import IO, Any, Callable, Literal, Sequence, cast from nnbench.reporter.base import BenchmarkReporter from nnbench.types import BenchmarkRecord @@ -113,33 +113,63 @@ def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]: # TODO: Add parquet support -class FileIO: - def __init__( - self, - drivers: dict[str, SerDe] | None = None, - compressions: dict[str, Callable] | None = None, - ): - super().__init__() - self.drivers = drivers or _file_drivers - self.compressions = compressions or _compressions +def get_driver_implementation(name: str) -> SerDe: + try: + return _file_drivers[name] + except KeyError: + raise KeyError(f"unsupported file format {name!r}") from None + + +def gzip_compression(filename: str | os.PathLike[str], mode: Literal["r", "w"] = "r") -> IO: + import gzip + + # gzip.GzipFile does not inherit from IO[bytes], + # but it has all required methods, so we allow it. + return cast(IO[bytes], gzip.GzipFile(filename=filename, mode=mode)) + + +def bz2_compression(filename: str | os.PathLike[str], mode: Literal["r", "w"] = "r") -> IO: + import bz2 + + return bz2.BZ2File(filename=filename, mode=mode) + +with _compression_lock: + _compressions["gzip"] = gzip_compression + _compressions["bz2"] = bz2_compression + + +def get_compression_algorithm(name: str) -> Callable: + try: + return _compressions[name] + except KeyError: + raise KeyError(f"unsupported compression algorithm {name!r}") from None + + +class FileIO: def read( self, file: str | os.PathLike[str], + mode: str = "r", driver: str | None = None, + compression: str | None = None, options: dict[str, Any] | None = None, ) -> BenchmarkRecord: """ Greedy version of ``FileIO.read_batched()``, returning the first read record. When reading a multi-record file, this uses as much memory as the batched version. """ - records = self.read_batched(file=file, driver=driver, options=options) + records = self.read_batched( + file=file, mode=mode, driver=driver, compression=compression, options=options + ) return records[0] def read_batched( self, file: str | os.PathLike[str], + mode: str = "r", driver: str | None = None, + compression: str | None = None, options: dict[str, Any] | None = None, ) -> list[BenchmarkRecord]: """ @@ -151,9 +181,14 @@ def read_batched( ---------- file: str | os.PathLike[str] The file name to read from. + mode: str + File mode to use. Can be any of the modes used in builtin ``open()``. driver: str | None File driver implementation to use. If None, the file driver inferred from the given file path's extension will be used. + compression: str | None + Compression engine to use. If None, the compression inferred from the given + file path's extension will be used. options: dict[str, Any] | None Options to pass to the respective file driver implementation. @@ -161,41 +196,65 @@ def read_batched( ------- list[BenchmarkRecord] The benchmark records contained in the file. - - Raises - ------ - KeyError - If the given file does not have a driver implementation available. """ - driver = driver or Path(file).suffix.removeprefix(".") - - try: - _, de = self.drivers[driver] - except KeyError: - raise KeyError(f"unsupported file format {driver!r}") from None + fileext = Path(file).suffix.removeprefix(".") + # if the extension looks like FORMAT.COMPRESSION, we split. + if fileext.count(".") == 1: + # TODO: Are there file extensions with more than one meaningful part? + ext_driver, ext_compression = fileext.rsplit(".", 1) + else: + ext_driver, ext_compression = fileext, None + + driver = driver or ext_driver + compression = compression or ext_compression + + _, de = get_driver_implementation(driver) + + # canonicalize extension to make sure the file gets it correctly + # regardless of where driver and compression came from. + fullext = "." + driver + if compression is not None: + fullext += "." + compression + file = Path(file).with_suffix(fullext) + fd = get_compression_algorithm(compression)(file, mode) + else: + file = Path(file).with_suffix(fullext) + fd = open(file, mode) # dummy value, since the context mode is unused in read ops. fdoptions = FileDriverOptions(ctxmode="omit", options=options or {}) - with open(file, "r") as fp: + with fd as fp: return de(fp, fdoptions) def write( self, record: BenchmarkRecord, file: str | os.PathLike[str], + mode: str = "r", driver: str | None = None, + compression: str | None = None, ctxmode: Literal["flatten", "inline", "omit"] = "inline", options: dict[str, Any] | None = None, ) -> None: """Greedy version of ``FileIO.write_batched()``""" - self.write_batched([record], file=file, driver=driver, ctxmode=ctxmode, options=options) + self.write_batched( + [record], + file=file, + mode=mode, + driver=driver, + compression=compression, + ctxmode=ctxmode, + options=options, + ) def write_batched( self, records: Sequence[BenchmarkRecord], file: str | os.PathLike[str], + mode: str = "r", driver: str | None = None, + compression: str | None = None, ctxmode: Literal["flatten", "inline", "omit"] = "inline", options: dict[str, Any] | None = None, ) -> None: @@ -210,28 +269,44 @@ def write_batched( The record to write to the database. file: str | os.PathLike[str] The file name to write to. + mode: str + File mode to use. Can be any of the modes used in builtin ``open()``. driver: str | None File driver implementation to use. If None, the file driver inferred from the given file path's extension will be used. + compression: str | None + Compression engine to use. If None, the compression inferred from the given + file path's extension will be used. ctxmode: Literal["flatten", "inline", "omit"] How to handle the benchmark context when writing the record data. options: dict[str, Any] | None Options to pass to the respective file driver implementation. - - Raises - ------ - KeyError - If the given file does not have a driver implementation available. """ - driver = driver or Path(file).suffix.removeprefix(".") - - try: - ser, _ = self.drivers[driver] - except KeyError: - raise KeyError(f"unsupported file format {driver!r}") from None + fileext = Path(file).suffix.removeprefix(".") + # if the extension looks like FORMAT.COMPRESSION, we split. + if fileext.count(".") == 1: + ext_driver, ext_compression = fileext.rsplit(".", 1) + else: + ext_driver, ext_compression = fileext, None + + driver = driver or ext_driver + compression = compression or ext_compression + + ser, _ = get_driver_implementation(driver) + + # canonicalize extension to make sure the file gets it correctly + # regardless of where driver and compression came from. + fullext = "." + driver + if compression is not None: + fullext += "." + compression + file = Path(file).with_suffix(fullext) + fd = get_compression_algorithm(compression)(file, mode) + else: + file = Path(file).with_suffix(fullext) + fd = open(file, mode) fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {}) - with open(file, "w") as fp: + with fd as fp: ser(records, fp, fdoptions)