Skip to content

Commit

Permalink
Add compression algorithms, selection in FileIO sink
Browse files Browse the repository at this point in the history
Works similarly to the file driver lookup, except the compression can
also be None, in which case a normal file descriptor is used.

Canonicalizes the extension of the resulting file based on the driver and
compression algorithm, regardless of the file name that was passed in. This
might be surprising behavior, but it simplifies driver and compression
inference significantly on roundtrips.

Scraps the custom registries in the FileIO constructor, because we intend to
expose register/deregister hooks for the default registries later on.
  • Loading branch information
nicholasjng committed Feb 22, 2024
1 parent e2890e0 commit 57429fd
Showing 1 changed file with 111 additions and 36 deletions.
147 changes: 111 additions & 36 deletions src/nnbench/reporter/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import IO, Any, Callable, Literal, Sequence
from typing import IO, Any, Callable, Literal, Sequence, cast

from nnbench.reporter.base import BenchmarkReporter
from nnbench.types import BenchmarkRecord
Expand Down Expand Up @@ -113,33 +113,63 @@ def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
# TODO: Add parquet support


class FileIO:
def __init__(
self,
drivers: dict[str, SerDe] | None = None,
compressions: dict[str, Callable] | None = None,
):
super().__init__()
self.drivers = drivers or _file_drivers
self.compressions = compressions or _compressions
def get_driver_implementation(name: str) -> SerDe:
    """
    Look up the file driver implementation registered under ``name``.

    Parameters
    ----------
    name: str
        Name of the file format, e.g. as taken from a file extension.

    Returns
    -------
    SerDe
        The implementation stored in the module-level driver registry.

    Raises
    ------
    KeyError
        If no driver is registered under the given name.
    """
    try:
        implementation = _file_drivers[name]
    except KeyError:
        # Swap the bare missing-key error for one that names the format.
        raise KeyError(f"unsupported file format {name!r}") from None
    else:
        return implementation


def gzip_compression(filename: str | os.PathLike[str], mode: Literal["r", "w"] = "r") -> IO:
    """
    Open a gzip-compressed file and return it as a file-like object.

    Like the builtin ``open()``, a mode that names neither ``"b"`` nor ``"t"``
    is opened in text mode. This way the returned stream carries ``str`` data,
    the same as the uncompressed ``open(file, mode)`` code path, so drivers
    that read or write text (e.g. csv) work on compressed files as well.

    Parameters
    ----------
    filename: str | os.PathLike[str]
        Path of the gzip file to open.
    mode: Literal["r", "w"]
        File mode. An explicit ``"b"`` or ``"t"`` in the mode is respected.

    Returns
    -------
    IO
        The opened (de)compressing stream, usable as a context manager.
    """
    import gzip

    # The gzip module treats a bare "r"/"w" as binary, unlike ``open()``,
    # so text mode has to be requested explicitly.
    if "b" in mode or "t" in mode:
        effective_mode = mode
    else:
        effective_mode = mode + "t"

    # gzip.open() yields a GzipFile or a TextIOWrapper. GzipFile does not
    # inherit from IO, but it has all required methods, so we allow it.
    return cast(IO, gzip.open(filename, mode=effective_mode))


def bz2_compression(filename: str | os.PathLike[str], mode: Literal["r", "w"] = "r") -> IO:
    """
    Open a bzip2-compressed file and return it as a file-like object.

    Like the builtin ``open()``, a mode that names neither ``"b"`` nor ``"t"``
    is opened in text mode. This way the returned stream carries ``str`` data,
    the same as the uncompressed ``open(file, mode)`` code path, so drivers
    that read or write text (e.g. csv) work on compressed files as well.

    Parameters
    ----------
    filename: str | os.PathLike[str]
        Path of the bzip2 file to open.
    mode: Literal["r", "w"]
        File mode. An explicit ``"b"`` or ``"t"`` in the mode is respected.

    Returns
    -------
    IO
        The opened (de)compressing stream, usable as a context manager.
    """
    import bz2

    # The bz2 module treats a bare "r"/"w" as binary, unlike ``open()``,
    # so text mode has to be requested explicitly.
    if "b" in mode or "t" in mode:
        effective_mode = mode
    else:
        effective_mode = mode + "t"

    return bz2.open(filename, mode=effective_mode)


# Register the built-in compression algorithms under their lookup names.
# The registry is only touched while holding its lock.
with _compression_lock:
    _compressions.update({"gzip": gzip_compression, "bz2": bz2_compression})


def get_compression_algorithm(name: str) -> Callable:
    """
    Look up the compression algorithm registered under ``name``.

    Parameters
    ----------
    name: str
        Name of the compression algorithm, e.g. ``"gzip"`` or ``"bz2"``.

    Returns
    -------
    Callable
        The callable stored in the module-level compression registry.

    Raises
    ------
    KeyError
        If no compression algorithm is registered under the given name.
    """
    try:
        opener = _compressions[name]
    except KeyError:
        # Swap the bare missing-key error for one that names the algorithm.
        raise KeyError(f"unsupported compression algorithm {name!r}") from None
    else:
        return opener


class FileIO:
def read(
self,
file: str | os.PathLike[str],
mode: str = "r",
driver: str | None = None,
compression: str | None = None,
options: dict[str, Any] | None = None,
) -> BenchmarkRecord:
"""
Greedy version of ``FileIO.read_batched()``, returning the first read record.
When reading a multi-record file, this uses as much memory as the batched version.
"""
records = self.read_batched(file=file, driver=driver, options=options)
records = self.read_batched(
file=file, mode=mode, driver=driver, compression=compression, options=options
)
return records[0]

def read_batched(
self,
file: str | os.PathLike[str],
mode: str = "r",
driver: str | None = None,
compression: str | None = None,
options: dict[str, Any] | None = None,
) -> list[BenchmarkRecord]:
"""
Expand All @@ -151,51 +181,80 @@ def read_batched(
----------
file: str | os.PathLike[str]
The file name to read from.
mode: str
File mode to use. Can be any of the modes used in builtin ``open()``.
driver: str | None
File driver implementation to use. If None, the file driver inferred from the
given file path's extension will be used.
compression: str | None
Compression engine to use. If None, the compression inferred from the given
file path's extension will be used.
options: dict[str, Any] | None
Options to pass to the respective file driver implementation.
Returns
-------
list[BenchmarkRecord]
The benchmark records contained in the file.
Raises
------
KeyError
If the given file does not have a driver implementation available.
"""
driver = driver or Path(file).suffix.removeprefix(".")

try:
_, de = self.drivers[driver]
except KeyError:
raise KeyError(f"unsupported file format {driver!r}") from None
fileext = Path(file).suffix.removeprefix(".")
# if the extension looks like FORMAT.COMPRESSION, we split.
if fileext.count(".") == 1:
# TODO: Are there file extensions with more than one meaningful part?
ext_driver, ext_compression = fileext.rsplit(".", 1)
else:
ext_driver, ext_compression = fileext, None

driver = driver or ext_driver
compression = compression or ext_compression

_, de = get_driver_implementation(driver)

# canonicalize extension to make sure the file gets it correctly
# regardless of where driver and compression came from.
fullext = "." + driver
if compression is not None:
fullext += "." + compression
file = Path(file).with_suffix(fullext)
fd = get_compression_algorithm(compression)(file, mode)
else:
file = Path(file).with_suffix(fullext)
fd = open(file, mode)

# dummy value, since the context mode is unused in read ops.
fdoptions = FileDriverOptions(ctxmode="omit", options=options or {})

with open(file, "r") as fp:
with fd as fp:
return de(fp, fdoptions)

def write(
self,
record: BenchmarkRecord,
file: str | os.PathLike[str],
mode: str = "r",
driver: str | None = None,
compression: str | None = None,
ctxmode: Literal["flatten", "inline", "omit"] = "inline",
options: dict[str, Any] | None = None,
) -> None:
"""Greedy version of ``FileIO.write_batched()``"""
self.write_batched([record], file=file, driver=driver, ctxmode=ctxmode, options=options)
self.write_batched(
[record],
file=file,
mode=mode,
driver=driver,
compression=compression,
ctxmode=ctxmode,
options=options,
)

def write_batched(
self,
records: Sequence[BenchmarkRecord],
file: str | os.PathLike[str],
mode: str = "r",
driver: str | None = None,
compression: str | None = None,
ctxmode: Literal["flatten", "inline", "omit"] = "inline",
options: dict[str, Any] | None = None,
) -> None:
Expand All @@ -210,28 +269,44 @@ def write_batched(
The record to write to the database.
file: str | os.PathLike[str]
The file name to write to.
mode: str
File mode to use. Can be any of the modes used in builtin ``open()``.
driver: str | None
File driver implementation to use. If None, the file driver inferred from the
given file path's extension will be used.
compression: str | None
Compression engine to use. If None, the compression inferred from the given
file path's extension will be used.
ctxmode: Literal["flatten", "inline", "omit"]
How to handle the benchmark context when writing the record data.
options: dict[str, Any] | None
Options to pass to the respective file driver implementation.
Raises
------
KeyError
If the given file does not have a driver implementation available.
"""
driver = driver or Path(file).suffix.removeprefix(".")

try:
ser, _ = self.drivers[driver]
except KeyError:
raise KeyError(f"unsupported file format {driver!r}") from None
fileext = Path(file).suffix.removeprefix(".")
# if the extension looks like FORMAT.COMPRESSION, we split.
if fileext.count(".") == 1:
ext_driver, ext_compression = fileext.rsplit(".", 1)
else:
ext_driver, ext_compression = fileext, None

driver = driver or ext_driver
compression = compression or ext_compression

ser, _ = get_driver_implementation(driver)

# canonicalize extension to make sure the file gets it correctly
# regardless of where driver and compression came from.
fullext = "." + driver
if compression is not None:
fullext += "." + compression
file = Path(file).with_suffix(fullext)
fd = get_compression_algorithm(compression)(file, mode)
else:
file = Path(file).with_suffix(fullext)
fd = open(file, mode)

fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
with open(file, "w") as fp:
with fd as fp:
ser(records, fp, fdoptions)


Expand Down

0 comments on commit 57429fd

Please sign in to comment.