Skip to content

Commit

Permalink
Merge pull request #82 from aai-institute/reporter-refactor
Browse files Browse the repository at this point in the history
Improve file drivers, add dict roundtrip methods to `BenchmarkRecord`
  • Loading branch information
nicholasjng authored Feb 19, 2024
2 parents cf86afb + 5495757 commit 7eee103
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 120 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
types_or: [ python, pyi ]
args: [--ignore-missing-imports, --scripts-are-modules]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
rev: v0.2.2
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
Expand All @@ -29,6 +29,6 @@ repos:
args: [-c, pyproject.toml]
additional_dependencies: ["bandit[toml]"]
- repo: https://github.com/jsh9/pydoclint
rev: 0.4.0
rev: 0.4.1
hooks:
- id: pydoclint
20 changes: 2 additions & 18 deletions src/nnbench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A framework for organizing and running benchmark workloads on machine learning models."""

from importlib.metadata import PackageNotFoundError, entry_points, version
from importlib.metadata import PackageNotFoundError, version

try:
__version__ = version("nnbench")
Expand All @@ -9,22 +9,6 @@
pass

from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter, register_reporter
from .reporter import BenchmarkReporter
from .runner import BenchmarkRunner
from .types import Parameters


def add_reporters():
eps = entry_points()

if hasattr(eps, "select"): # Python 3.10+ / importlib.metadata >= 3.9.0
reporters = eps.select(group="nnbench.reporters")
else:
reporters = eps.get("nnbench.reporters", []) # type: ignore

for rep in reporters:
key, clsname = rep.name.split("=", 1)
register_reporter(key, clsname)


add_reporters()
18 changes: 11 additions & 7 deletions src/nnbench/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,15 @@ def __call__(self) -> dict[str, Any]:

class Context:
def __init__(self, data: dict[str, Any] | None = None) -> None:
self._ctx_dict: dict[str, Any] = data or {}
self._data: dict[str, Any] = data or {}

def __contains__(self, key: str) -> bool:
return key in self.keys()

@property
def data(self):
return self._data

@staticmethod
def _ctx_items(d: dict[str, Any], prefix: str, sep: str) -> Iterator[tuple[str, Any]]:
"""
Expand Down Expand Up @@ -234,7 +238,7 @@ def keys(self, sep: str = ".") -> Iterator[str]:
str
Iterator over the context dictionary keys.
"""
for k, v in self._ctx_items(d=self._ctx_dict, prefix="", sep=sep):
for k, v in self._ctx_items(d=self._data, prefix="", sep=sep):
yield k

def values(self) -> Iterator[Any]:
Expand All @@ -246,7 +250,7 @@ def values(self) -> Iterator[Any]:
Any
Iterator over all values in the context dictionary.
"""
for k, v in self._ctx_items(d=self._ctx_dict, prefix="", sep=""):
for k, v in self._ctx_items(d=self._data, prefix="", sep=""):
yield v

def items(self, sep: str = ".") -> Iterator[tuple[str, Any]]:
Expand All @@ -263,7 +267,7 @@ def items(self, sep: str = ".") -> Iterator[tuple[str, Any]]:
tuple[str, Any]
Iterator over the items of the context dictionary.
"""
yield from self._ctx_items(d=self._ctx_dict, prefix="", sep=sep)
yield from self._ctx_items(d=self._data, prefix="", sep=sep)

def update(self, other: ContextProvider | dict[str, Any] | "Context") -> None:
"""
Expand All @@ -278,8 +282,8 @@ def update(self, other: ContextProvider | dict[str, Any] | "Context") -> None:
if callable(other):
other = other()
elif isinstance(other, Context):
other = other._ctx_dict
self._ctx_dict.update(other)
other = other._data
self._data.update(other)

@staticmethod
def _flatten_dict(d: dict[str, Any], prefix: str = "", sep: str = ".") -> dict[str, Any]:
Expand Down Expand Up @@ -325,7 +329,7 @@ def flatten(self, sep: str = ".") -> dict[str, Any]:
The flattened context values as a Python dictionary.
"""

return self._flatten_dict(self._ctx_dict, prefix="", sep=sep)
return self._flatten_dict(self._data, prefix="", sep=sep)

@staticmethod
def unflatten(d: dict[str, Any], sep: str = ".") -> dict[str, Any]:
Expand Down
37 changes: 1 addition & 36 deletions src/nnbench/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,5 @@
"""
from __future__ import annotations

import importlib
import types

from .base import BenchmarkReporter

# internal, mutable
_reporter_registry: dict[str, type[BenchmarkReporter]] = {}

# external, immutable
reporter_registry: types.MappingProxyType[str, type[BenchmarkReporter]] = types.MappingProxyType(
_reporter_registry
)


def register_reporter(key: str, cls_or_name: str | type[BenchmarkReporter]) -> None:
"""
Register a reporter class by its fully qualified module path.
Parameters
----------
key: str
The key to register the reporter under. Subsequently, this key can be used in place
of reporter classes in code.
cls_or_name: str | type[BenchmarkReporter]
Name of or full module path to the reporter class. For example, when registering a class
``MyReporter`` located in ``my_module``, ``name`` should be ``my_module.MyReporter``.
"""

if isinstance(cls_or_name, str):
name = cls_or_name
modname, clsname = name.rsplit(".", 1)
mod = importlib.import_module(modname)
cls = getattr(mod, clsname)
_reporter_registry[key] = cls
else:
# name = cls_or_name.__module__ + "." + cls_or_name.__qualname__
_reporter_registry[key] = cls_or_name
from .duckdb_sql import DuckDBReporter
5 changes: 2 additions & 3 deletions src/nnbench/reporter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ def display(
A mapping of column names to custom formatters, i.e. functions formatting input
values for display in the console.
"""
ctx, benchmarks = record["context"], record["benchmarks"]

benchmarks = record.benchmarks
# This assumes a stable schema across benchmarks.
if include is None:
includes = set(benchmarks[0].keys())
Expand All @@ -108,7 +107,7 @@ def display(
continue
filteredctx = {
k: v
for k, v in ctx.flatten().items()
for k, v in record.context.items()
if any(k.startswith(i) for i in include_context)
}
filteredbm = {k: v for k, v in bm.items() if k in cols}
Expand Down
7 changes: 3 additions & 4 deletions src/nnbench/reporter/duckdb_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
except ImportError:
DUCKDB_INSTALLED = False

from nnbench.reporter.file import FileReporter
from nnbench.reporter.file import FileIO
from nnbench.types import BenchmarkRecord


class DuckDBReporter(FileReporter):
class DuckDBReporter(FileIO):
"""
A reporter for streaming benchmark results to duckdb.
Expand Down Expand Up @@ -85,7 +85,6 @@ def directory(self) -> os.PathLike[str]:

def initialize(self):
self.conn = duckdb.connect(self.dbname, read_only=self.read_only)
super().initialize()

def finalize(self):
if self.conn:
Expand All @@ -94,7 +93,7 @@ def finalize(self):
if self.delete:
shutil.rmtree(self._directory, ignore_errors=True)

def read(
def read_sql(
self,
file: str | os.PathLike[str],
driver: str | None = None,
Expand Down
Loading

0 comments on commit 7eee103

Please sign in to comment.