From 36aa42705c25608ad5b781300da484e1800a0efc Mon Sep 17 00:00:00 2001
From: Nicholas Junge
Date: Mon, 19 Feb 2024 15:42:52 +0100
Subject: [PATCH] Unify batching in file drivers, remove read/write stubs on base reporter

Since the reporter interface is so dynamic (i.e., file I/O needs vastly
different arguments than DB I/O), we scrap the read/write APIs from the
base class and add them as needed in the child classes.

Adds batched (i.e., multi-record) read/write support and a `FileReporter`
class that is just the sum of the `FileIO` and `BenchmarkReporter`
interfaces.
---
 examples/mnist/mnist.py            |   4 +-
 src/nnbench/reporter/__init__.py   |   1 +
 src/nnbench/reporter/base.py       |  19 +----
 src/nnbench/reporter/duckdb_sql.py |  13 +--
 src/nnbench/reporter/file.py       | 124 ++++++++++++++++-------
 5 files changed, 81 insertions(+), 80 deletions(-)

diff --git a/examples/mnist/mnist.py b/examples/mnist/mnist.py
index d214c025..0c77789c 100644
--- a/examples/mnist/mnist.py
+++ b/examples/mnist/mnist.py
@@ -217,10 +217,10 @@ def mnist_jax():
 
     # the nnbench portion.
     runner = nnbench.BenchmarkRunner()
-    reporter = nnbench.BenchmarkReporter()
+    reporter = nnbench.reporter.FileReporter()
     params = MNISTTestParameters(params=state.params, data=data)
     result = runner.run(HERE, params=params)
-    reporter.write(result)
+    reporter.write(result, "result.json")
 
 
 if __name__ == "__main__":
diff --git a/src/nnbench/reporter/__init__.py b/src/nnbench/reporter/__init__.py
index 859d113b..7b08b416 100644
--- a/src/nnbench/reporter/__init__.py
+++ b/src/nnbench/reporter/__init__.py
@@ -5,3 +5,4 @@
 
 from .base import BenchmarkReporter
 from .duckdb_sql import DuckDBReporter
+from .file import FileReporter
diff --git a/src/nnbench/reporter/base.py b/src/nnbench/reporter/base.py
index fb822013..d2f3567e 100644
--- a/src/nnbench/reporter/base.py
+++ b/src/nnbench/reporter/base.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Callable, Sequence
+from typing import Any, Callable
 
 from tabulate import tabulate
 
@@ -23,7 +23,8 @@ class BenchmarkReporter:
     the database in ``report_result()``, with preprocessing if necessary.
     """
 
-    def __init__(self):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self._initialized = False
 
     def initialize(self):
@@ -120,17 +121,3 @@ def display(
             filtered.append(filteredbm)
 
         print(tabulate(filtered, headers="keys", tablefmt=tablefmt))
-
-    def read(self, *args: Any, **kwargs: Any) -> BenchmarkRecord:
-        raise NotImplementedError
-
-    def read_batched(self, *args: Any, **kwargs: Any) -> list[BenchmarkRecord]:
-        raise NotImplementedError
-
-    def write(self, record: BenchmarkRecord, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError
-
-    def write_batched(self, records: Sequence[BenchmarkRecord], *args: Any, **kwargs: Any) -> None:
-        # By default, just loop over the records and write() everything.
-        for record in records:
-            self.write(record, *args, **kwargs)
diff --git a/src/nnbench/reporter/duckdb_sql.py b/src/nnbench/reporter/duckdb_sql.py
index 31b50fe5..70180f0b 100644
--- a/src/nnbench/reporter/duckdb_sql.py
+++ b/src/nnbench/reporter/duckdb_sql.py
@@ -6,7 +6,6 @@
 import tempfile
 import weakref
 from pathlib import Path
-from typing import Any
 
 from nnbench.context import Context
 
@@ -17,11 +16,11 @@
 except ImportError:
     DUCKDB_INSTALLED = False
 
-from nnbench.reporter.file import FileIO
+from nnbench.reporter.file import FileReporter
 from nnbench.types import BenchmarkRecord
 
 
-class DuckDBReporter(FileIO):
+class DuckDBReporter(FileReporter):
     """
     A reporter for streaming benchmark results to duckdb.
 
@@ -77,14 +76,13 @@ def __init__(
 
         self.conn: duckdb.DuckDBPyConnection | None = None
 
-        self._initialized = False
-
     @property
     def directory(self) -> os.PathLike[str]:
         return self._directory
 
     def initialize(self):
         self.conn = duckdb.connect(self.dbname, read_only=self.read_only)
+        self._initialized = True
 
     def finalize(self):
         if self.conn:
@@ -97,7 +95,6 @@ def read_sql(
         self,
         file: str | os.PathLike[str],
         driver: str | None = None,
-        options: dict[str, Any] | None = None,
         include: tuple[str, ...] | None = None,
         alias: dict[str, str] | None = None,
         limit: int | None = None,
@@ -132,3 +129,7 @@ def read_sql(
             context.update(bm.pop("context", {}))
 
         return BenchmarkRecord(context=context, benchmarks=benchmarks)
+
+    def raw_sql(self, query: str) -> duckdb.DuckDBPyRelation:
+        rel = self.conn.sql(query=query)
+        return rel
diff --git a/src/nnbench/reporter/file.py b/src/nnbench/reporter/file.py
index 4f4d4dd1..b0dd7c30 100644
--- a/src/nnbench/reporter/file.py
+++ b/src/nnbench/reporter/file.py
@@ -6,6 +6,7 @@
 from pathlib import Path
 from typing import IO, Any, Callable, Literal, Sequence
 
+from nnbench.reporter.base import BenchmarkReporter
 from nnbench.types import BenchmarkRecord
 
 
@@ -19,8 +20,8 @@ class FileDriverOptions:
 
 _Options = dict[str, Any]
 SerDe = tuple[
-    Callable[[BenchmarkRecord, IO, FileDriverOptions], None],
-    Callable[[IO, FileDriverOptions], BenchmarkRecord],
+    Callable[[Sequence[BenchmarkRecord], IO, FileDriverOptions], None],
+    Callable[[IO, FileDriverOptions], list[BenchmarkRecord]],
 ]
 
 
@@ -30,51 +31,67 @@ class FileDriverOptions:
 _compression_lock = threading.Lock()
 
 
-def yaml_save(record: BenchmarkRecord, fp: IO, fdoptions: FileDriverOptions) -> None:
+def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
     try:
         import yaml
     except ImportError:
         raise ModuleNotFoundError("`pyyaml` is not installed")
 
-    bms = record.compact(mode=fdoptions.ctxmode)
+    bms = []
+    for r in records:
+        bms += r.compact(mode=fdoptions.ctxmode)
     yaml.safe_dump(bms, fp, **fdoptions.options)
 
 
-def yaml_load(fp: IO, fdoptions: FileDriverOptions) -> BenchmarkRecord:
+def yaml_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
     try:
         import yaml
     except ImportError:
         raise ModuleNotFoundError("`pyyaml` is not installed")
 
+    # TODO: Use expandmany()
     bms = yaml.safe_load(fp)
-    return BenchmarkRecord.expand(bms)
+    return [BenchmarkRecord.expand(bms)]
 
 
-def json_save(record: BenchmarkRecord, fp: IO, fdoptions: FileDriverOptions) -> None:
+def json_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
     import json
 
-    benchmarks = record.compact(mode=fdoptions.ctxmode)
-    json.dump(benchmarks, fp, **fdoptions.options)
+    newline: bool = fdoptions.options.pop("newline", False)
+    bm = []
+    for r in records:
+        bm += r.compact(mode=fdoptions.ctxmode)
+    if newline:
+        fp.write("\n".join([json.dumps(b) for b in bm]))
+    else:
+        json.dump(bm, fp, **fdoptions.options)
 
 
-def json_load(fp: IO, fdoptions: FileDriverOptions) -> BenchmarkRecord:
+def json_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
     import json
 
-    benchmarks: list[dict[str, Any]] = json.load(fp, **fdoptions.options)
-    return BenchmarkRecord.expand(benchmarks)
+    newline: bool = fdoptions.options.pop("newline", False)
+    benchmarks: list[dict[str, Any]]
+    if newline:
+        benchmarks = [json.loads(line, **fdoptions.options) for line in fp]
+    else:
+        benchmarks = json.load(fp, **fdoptions.options)
+    return [BenchmarkRecord.expand(benchmarks)]
 
 
-def csv_save(record: BenchmarkRecord, fp: IO, fdoptions: FileDriverOptions) -> None:
+def csv_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
     import csv
 
-    benchmarks = record.compact(mode=fdoptions.ctxmode)
-    writer = csv.DictWriter(fp, fieldnames=benchmarks[0].keys(), **fdoptions.options)
+    bm = []
+    for r in records:
+        bm += r.compact(mode=fdoptions.ctxmode)
+    writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **fdoptions.options)
 
-    for bm in benchmarks:
-        writer.writerow(bm)
+    for b in bm:
+        writer.writerow(b)
 
 
-def csv_load(fp: IO, fdoptions: FileDriverOptions) -> BenchmarkRecord:
+def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
     import csv
 
     reader = csv.DictReader(fp, **fdoptions.options)
@@ -86,7 +103,7 @@ def csv_load(
     for bm in reader:
         benchmarks.append(bm)
 
-    return BenchmarkRecord.expand(benchmarks)
+    return [BenchmarkRecord.expand(benchmarks)]
 
 
 with _file_driver_lock:
@@ -102,6 +119,7 @@ def __init__(
         drivers: dict[str, SerDe] | None = None,
         compressions: dict[str, Callable] | None = None,
     ):
+        super().__init__()
         self.drivers = drivers or _file_drivers
         self.compressions = compressions or _compressions
 
@@ -109,30 +127,40 @@ def read(
         self,
         file: str | os.PathLike[str],
         driver: str | None = None,
-        ctxmode: Literal["flatten", "inline", "omit"] = "inline",
         options: dict[str, Any] | None = None,
     ) -> BenchmarkRecord:
         """
-        Writes a benchmark record to the given file path.
+        Greedy version of ``FileIO.read_batched()``, returning the first read record.
+        When reading a multi-record file, this uses as much memory as the batched version.
+        """
+        records = self.read_batched(file=file, driver=driver, options=options)
+        return records[0]
+
+    def read_batched(
+        self,
+        file: str | os.PathLike[str],
+        driver: str | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> list[BenchmarkRecord]:
+        """
+        Reads a set of benchmark records from the given file path.
 
         The file driver is chosen based on the extension found on the ``file`` path.
 
         Parameters
         ----------
         file: str | os.PathLike[str]
-            The file name to write to.
+            The file name to read from.
         driver: str | None
             File driver implementation to use. If None, the file driver inferred from
             the given file path's extension will be used.
-        ctxmode: Literal["flatten", "inline", "omit"]
-            How to handle the benchmark context when writing the record data.
         options: dict[str, Any] | None
             Options to pass to the respective file driver implementation.
 
         Returns
         -------
-        BenchmarkRecord
-            The benchmark record contained in the file.
+        list[BenchmarkRecord]
+            The benchmark records contained in the file.
 
         Raises
         ------
@@ -144,27 +172,28 @@ def read(
         try:
             _, de = self.drivers[driver]
         except KeyError:
-            raise ValueError(f"unsupported file format {driver!r}")
+            raise KeyError(f"unsupported file format {driver!r}") from None
-        fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
+        # dummy value, since the context mode is unused in read ops.
+        fdoptions = FileDriverOptions(ctxmode="omit", options=options or {})
-        with open(file, "w") as fp:
+        with open(file, "r") as fp:
             return de(fp, fdoptions)
 
-    def read_batched(
+    def write(
         self,
-        records: Sequence[BenchmarkRecord],
+        record: BenchmarkRecord,
         file: str | os.PathLike[str],
         driver: str | None = None,
         ctxmode: Literal["flatten", "inline", "omit"] = "inline",
         options: dict[str, Any] | None = None,
     ) -> None:
-        """A batched version of ``FileIO.read()``."""
-        pass
+        """Greedy version of ``FileIO.write_batched()``"""
+        self.write_batched([record], file=file, driver=driver, ctxmode=ctxmode, options=options)
 
-    def write(
+    def write_batched(
        self,
-        record: BenchmarkRecord,
+        records: Sequence[BenchmarkRecord],
         file: str | os.PathLike[str],
         driver: str | None = None,
         ctxmode: Literal["flatten", "inline", "omit"] = "inline",
         options: dict[str, Any] | None = None,
     ) -> None:
@@ -177,7 +206,7 @@
 
         Parameters
         ----------
-        record: BenchmarkRecord
+        records: Sequence[BenchmarkRecord]
             The record to write to the database.
         file: str | os.PathLike[str]
             The file name to write to.
@@ -203,25 +232,8 @@
 
         fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
         with open(file, "w") as fp:
-            ser(record, fp, fdoptions)
+            ser(records, fp, fdoptions)
-
-    def write_batched(
-        self,
-        records: Sequence[BenchmarkRecord],
-        file: str | os.PathLike[str],
-        driver: str | None = None,
-        ctxmode: Literal["flatten", "inline", "omit"] = "inline",
-        options: dict[str, Any] | None = None,
-    ) -> None:
-        """A batched version of ``FileIO.write()``."""
-        driver = driver or Path(file).suffix.removeprefix(".")
-
-        try:
-            ser, _ = self.drivers[driver]
-        except KeyError:
-            raise KeyError(f"unsupported file format {driver!r}") from None
-
-        fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
-        with open(file, "a+") as fp:
-            for record in records:
-                ser(record, fp, fdoptions)
+
+
+class FileReporter(FileIO, BenchmarkReporter):
+    pass
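
A minimal usage sketch of the batched API introduced above, assuming an
illustrative benchmark path and file name (the `params` values are likewise
made up); `newline` is the option the JSON driver pops from `options` to
toggle newline-delimited output, as seen in `json_save`/`json_load` above:

    import nnbench
    from nnbench.reporter import FileReporter

    runner = nnbench.BenchmarkRunner()
    reporter = FileReporter()

    # Collect two records; the benchmark path and params are hypothetical.
    r1 = runner.run("benchmarks/", params={"n": 1})
    r2 = runner.run("benchmarks/", params={"n": 2})

    # Write both records into a single newline-delimited JSON file ...
    reporter.write_batched([r1, r2], "results.json", options={"newline": True})

    # ... and read them back. read() is the greedy variant: it parses the
    # whole file but returns only the first record.
    records = reporter.read_batched("results.json", options={"newline": True})
    first = reporter.read("results.json", options={"newline": True})

With `newline` set, each compacted benchmark lands on its own line, which is
what makes writing (and later re-reading) multi-record files practical.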