Merge pull request #83 from aai-institute/reporter-refactor
Unify batching in file drivers, remove read/write stubs on base reporter
nicholasjng authored Feb 19, 2024
2 parents 7eee103 + 36aa427 commit e2890e0
Showing 5 changed files with 81 additions and 80 deletions.
4 changes: 2 additions & 2 deletions examples/mnist/mnist.py
@@ -217,10 +217,10 @@ def mnist_jax():

# the nnbench portion.
runner = nnbench.BenchmarkRunner()
reporter = nnbench.BenchmarkReporter()
reporter = nnbench.reporter.FileReporter()
params = MNISTTestParameters(params=state.params, data=data)
result = runner.run(HERE, params=params)
reporter.write(result)
reporter.write(result, "result.json")


if __name__ == "__main__":
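
As a usage note: `FileReporter` infers its file driver from the destination path's suffix, which is why the example can simply pass `"result.json"`. A minimal sketch of the updated flow, reusing `HERE` and `params` as defined in the example above (the `.yaml` target assumes `pyyaml` is installed and the YAML driver is registered under that suffix):

```python
import nnbench
from nnbench.reporter import FileReporter

# Fragment of the mnist example: HERE and params are defined there.
runner = nnbench.BenchmarkRunner()
reporter = FileReporter()

result = runner.run(HERE, params=params)
reporter.write(result, "result.json")  # suffix "json" selects the JSON driver
reporter.write(result, "result.yaml")  # suffix "yaml" selects the YAML driver

# Reading the record back uses the same suffix-based driver lookup.
record = reporter.read("result.json")
```
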
1 change: 1 addition & 0 deletions src/nnbench/reporter/__init__.py
@@ -5,3 +5,4 @@

from .base import BenchmarkReporter
from .duckdb_sql import DuckDBReporter
from .file import FileReporter
19 changes: 3 additions & 16 deletions src/nnbench/reporter/base.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import Any, Callable, Sequence
from typing import Any, Callable

from tabulate import tabulate

@@ -23,7 +23,8 @@ class BenchmarkReporter:
the database in ``report_result()``, with preprocessing if necessary.
"""

def __init__(self):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._initialized = False

def initialize(self):
@@ -120,17 +121,3 @@ def display(
filtered.append(filteredbm)

print(tabulate(filtered, headers="keys", tablefmt=tablefmt))

def read(self, *args: Any, **kwargs: Any) -> BenchmarkRecord:
raise NotImplementedError

def read_batched(self, *args: Any, **kwargs: Any) -> list[BenchmarkRecord]:
raise NotImplementedError

def write(self, record: BenchmarkRecord, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError

def write_batched(self, records: Sequence[BenchmarkRecord], *args: Any, **kwargs: Any) -> None:
# By default, just loop over the records and write() everything.
for record in records:
self.write(record, *args, **kwargs)
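
The `*args, **kwargs` passthrough added to `__init__` above makes the base class cooperative under multiple inheritance; this is what lets `FileReporter(FileIO, BenchmarkReporter)` at the bottom of this diff initialize both bases through a single `super()` chain. A standalone sketch of the pattern (the class names here are illustrative, not nnbench's):

```python
class Reporter:
    def __init__(self, *args, **kwargs):
        # Forward unconsumed arguments so sibling classes in the MRO
        # still receive theirs.
        super().__init__(*args, **kwargs)
        self._initialized = False


class IOMixin:
    def __init__(self, drivers: dict | None = None, **kwargs):
        super().__init__(**kwargs)
        self.drivers = drivers or {}


class FileBacked(IOMixin, Reporter):
    pass


r = FileBacked(drivers={"json": ...})
print(r._initialized, sorted(r.drivers))  # False ['json']
```
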
13 changes: 7 additions & 6 deletions src/nnbench/reporter/duckdb_sql.py
@@ -6,7 +6,6 @@
import tempfile
import weakref
from pathlib import Path
from typing import Any

from nnbench.context import Context

@@ -17,11 +16,11 @@
except ImportError:
DUCKDB_INSTALLED = False

from nnbench.reporter.file import FileIO
from nnbench.reporter.file import FileReporter
from nnbench.types import BenchmarkRecord


class DuckDBReporter(FileIO):
class DuckDBReporter(FileReporter):
"""
A reporter for streaming benchmark results to duckdb.
@@ -77,14 +76,13 @@ def __init__(

self.conn: duckdb.DuckDBPyConnection | None = None

self._initialized = False

@property
def directory(self) -> os.PathLike[str]:
return self._directory

def initialize(self):
self.conn = duckdb.connect(self.dbname, read_only=self.read_only)
self._initialized = True

def finalize(self):
if self.conn:
@@ -97,7 +95,6 @@ def read_sql(
self,
file: str | os.PathLike[str],
driver: str | None = None,
options: dict[str, Any] | None = None,
include: tuple[str, ...] | None = None,
alias: dict[str, str] | None = None,
limit: int | None = None,
@@ -132,3 +129,7 @@ def read_sql(
context.update(bm.pop("context", {}))

return BenchmarkRecord(context=context, benchmarks=benchmarks)

def raw_sql(self, query: str) -> duckdb.DuckDBPyRelation:
rel = self.conn.sql(query=query)
return rel
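
The new `raw_sql()` method is a thin escape hatch over the open connection, returning a `duckdb.DuckDBPyRelation` for ad-hoc queries. A hedged sketch of its use; the constructor argument and the `benchmarks` table name are assumptions for illustration, not taken from this diff:

```python
from nnbench.reporter import DuckDBReporter

# Hypothetical setup; see the DuckDBReporter constructor for real arguments.
reporter = DuckDBReporter("benchmarks.duckdb")
reporter.initialize()  # opens the duckdb connection, as shown above

rel = reporter.raw_sql("SELECT * FROM benchmarks LIMIT 5")
print(rel.fetchall())  # a DuckDBPyRelation supports the usual fetch APIs

reporter.finalize()  # closes the connection
```
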
124 changes: 68 additions & 56 deletions src/nnbench/reporter/file.py
@@ -6,6 +6,7 @@
from pathlib import Path
from typing import IO, Any, Callable, Literal, Sequence

from nnbench.reporter.base import BenchmarkReporter
from nnbench.types import BenchmarkRecord


@@ -19,8 +20,8 @@ class FileDriverOptions:

_Options = dict[str, Any]
SerDe = tuple[
Callable[[BenchmarkRecord, IO, FileDriverOptions], None],
Callable[[IO, FileDriverOptions], BenchmarkRecord],
Callable[[Sequence[BenchmarkRecord], IO, FileDriverOptions], None],
Callable[[IO, FileDriverOptions], list[BenchmarkRecord]],
]


@@ -30,51 +31,67 @@
_compression_lock = threading.Lock()


def yaml_save(record: BenchmarkRecord, fp: IO, fdoptions: FileDriverOptions) -> None:
def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
try:
import yaml
except ImportError:
raise ModuleNotFoundError("`pyyaml` is not installed")

bms = record.compact(mode=fdoptions.ctxmode)
bms = []
for r in records:
bms += r.compact(mode=fdoptions.ctxmode)
yaml.safe_dump(bms, fp, **fdoptions.options)


def yaml_load(fp: IO, fdoptions: FileDriverOptions) -> BenchmarkRecord:
def yaml_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
try:
import yaml
except ImportError:
raise ModuleNotFoundError("`pyyaml` is not installed")

# TODO: Use expandmany()
bms = yaml.safe_load(fp)
return BenchmarkRecord.expand(bms)
return [BenchmarkRecord.expand(bms)]


def json_save(record: BenchmarkRecord, fp: IO, fdoptions: FileDriverOptions) -> None:
def json_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
import json

benchmarks = record.compact(mode=fdoptions.ctxmode)
json.dump(benchmarks, fp, **fdoptions.options)
newline: bool = fdoptions.options.pop("newline", False)
bm = []
for r in records:
bm += r.compact(mode=fdoptions.ctxmode)
if newline:
fp.write("\n".join([json.dumps(b) for b in bm]))
else:
json.dump(bm, fp, **fdoptions.options)


def json_load(fp: IO, fdoptions: FileDriverOptions) -> BenchmarkRecord:
def json_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
import json

benchmarks: list[dict[str, Any]] = json.load(fp, **fdoptions.options)
return BenchmarkRecord.expand(benchmarks)
newline: bool = fdoptions.options.pop("newline", False)
benchmarks: list[dict[str, Any]]
if newline:
benchmarks = [json.loads(line, **fdoptions.options) for line in fp]
else:
benchmarks = json.load(fp, **fdoptions.options)
return [BenchmarkRecord.expand(benchmarks)]


def csv_save(record: BenchmarkRecord, fp: IO, fdoptions: FileDriverOptions) -> None:
def csv_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
import csv

benchmarks = record.compact(mode=fdoptions.ctxmode)
writer = csv.DictWriter(fp, fieldnames=benchmarks[0].keys(), **fdoptions.options)
bm = []
for r in records:
bm += r.compact(mode=fdoptions.ctxmode)
writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **fdoptions.options)

for bm in benchmarks:
writer.writerow(bm)
for b in bm:
writer.writerow(b)


def csv_load(fp: IO, fdoptions: FileDriverOptions) -> BenchmarkRecord:
def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
import csv

reader = csv.DictReader(fp, **fdoptions.options)
@@ -86,7 +103,7 @@ def csv_load(fp: IO, fdoptions: FileDriverOptions) -> BenchmarkRecord:
for bm in reader:
benchmarks.append(bm)

return BenchmarkRecord.expand(benchmarks)
return [BenchmarkRecord.expand(benchmarks)]


with _file_driver_lock:
@@ -102,37 +119,48 @@ def __init__(
drivers: dict[str, SerDe] | None = None,
compressions: dict[str, Callable] | None = None,
):
super().__init__()
self.drivers = drivers or _file_drivers
self.compressions = compressions or _compressions

def read(
self,
file: str | os.PathLike[str],
driver: str | None = None,
ctxmode: Literal["flatten", "inline", "omit"] = "inline",
options: dict[str, Any] | None = None,
) -> BenchmarkRecord:
"""
Writes a benchmark record to the given file path.
Greedy version of ``FileIO.read_batched()``, returning the first read record.
When reading a multi-record file, this uses as much memory as the batched version.
"""
records = self.read_batched(file=file, driver=driver, options=options)
return records[0]

def read_batched(
self,
file: str | os.PathLike[str],
driver: str | None = None,
options: dict[str, Any] | None = None,
) -> list[BenchmarkRecord]:
"""
Reads a set of benchmark records from the given file path.
The file driver is chosen based on the extension found on the ``file`` path.
Parameters
----------
file: str | os.PathLike[str]
The file name to write to.
The file name to read from.
driver: str | None
File driver implementation to use. If None, the file driver inferred from the
given file path's extension will be used.
ctxmode: Literal["flatten", "inline", "omit"]
How to handle the benchmark context when writing the record data.
options: dict[str, Any] | None
Options to pass to the respective file driver implementation.
Returns
-------
BenchmarkRecord
The benchmark record contained in the file.
list[BenchmarkRecord]
The benchmark records contained in the file.
Raises
------
@@ -144,27 +172,28 @@ def read(
try:
_, de = self.drivers[driver]
except KeyError:
raise ValueError(f"unsupported file format {driver!r}")
raise KeyError(f"unsupported file format {driver!r}") from None

fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
# dummy value, since the context mode is unused in read ops.
fdoptions = FileDriverOptions(ctxmode="omit", options=options or {})

with open(file, "w") as fp:
with open(file, "r") as fp:
return de(fp, fdoptions)

def read_batched(
def write(
self,
records: Sequence[BenchmarkRecord],
record: BenchmarkRecord,
file: str | os.PathLike[str],
driver: str | None = None,
ctxmode: Literal["flatten", "inline", "omit"] = "inline",
options: dict[str, Any] | None = None,
) -> None:
"""A batched version of ``FileIO.read()``."""
pass
"""Greedy version of ``FileIO.write_batched()``"""
self.write_batched([record], file=file, driver=driver, ctxmode=ctxmode, options=options)

def write(
def write_batched(
self,
record: BenchmarkRecord,
records: Sequence[BenchmarkRecord],
file: str | os.PathLike[str],
driver: str | None = None,
ctxmode: Literal["flatten", "inline", "omit"] = "inline",
@@ -177,7 +206,7 @@ def write(
Parameters
----------
record: BenchmarkRecord
records: Sequence[BenchmarkRecord]
The record to write to the database.
file: str | os.PathLike[str]
The file name to write to.
@@ -203,25 +232,8 @@ def write(

fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
with open(file, "w") as fp:
ser(record, fp, fdoptions)
ser(records, fp, fdoptions)

def write_batched(
self,
records: Sequence[BenchmarkRecord],
file: str | os.PathLike[str],
driver: str | None = None,
ctxmode: Literal["flatten", "inline", "omit"] = "inline",
options: dict[str, Any] | None = None,
) -> None:
"""A batched version of ``FileIO.write()``."""
driver = driver or Path(file).suffix.removeprefix(".")

try:
ser, _ = self.drivers[driver]
except KeyError:
raise KeyError(f"unsupported file format {driver!r}") from None

fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
with open(file, "a+") as fp:
for record in records:
ser(record, fp, fdoptions)
class FileReporter(FileIO, BenchmarkReporter):
pass
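
Taken together, every file driver now serializes a `Sequence[BenchmarkRecord]` and deserializes a `list[BenchmarkRecord]`, so `read()`/`write()` are thin single-record conveniences over the batched paths. A sketch of the round trip under that API, including the newline-delimited JSON option introduced above (the record contents are made up for illustration):

```python
from nnbench.reporter import FileReporter
from nnbench.types import BenchmarkRecord

reporter = FileReporter()

# Made-up records; real ones come out of a BenchmarkRunner.
records = [
    BenchmarkRecord(context={}, benchmarks=[{"name": "accuracy", "value": 0.98}]),
    BenchmarkRecord(context={}, benchmarks=[{"name": "loss", "value": 0.07}]),
]

# One serializer call handles the whole batch; options={"newline": True}
# switches the JSON driver to newline-delimited output. The ".jsonl"
# suffix has no registered driver, hence the explicit driver="json".
reporter.write_batched(records, "results.jsonl", driver="json", options={"newline": True})

# read_batched() returns list[BenchmarkRecord]; read() greedily keeps
# only the first record read back.
batch = reporter.read_batched("results.jsonl", driver="json", options={"newline": True})
first = reporter.read("results.jsonl", driver="json", options={"newline": True})
```
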
