Skip to content

Commit

Permalink
wip - State injection into setup and teardown
Browse files Browse the repository at this point in the history
  • Loading branch information
maxmynter committed Mar 22, 2024
1 parent 65fc45b commit 670a857
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 5 deletions.
26 changes: 22 additions & 4 deletions src/nnbench/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload

from nnbench.types import Benchmark
from nnbench.types import Benchmark, State
from nnbench.types.util import is_memo, is_memo_type


Expand Down Expand Up @@ -178,7 +178,9 @@ def decorator(fn: Callable) -> list[Benchmark]:
)
names.add(name)

bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
bm = Benchmark(
fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags
)
benchmarks.append(bm)
return benchmarks

Expand Down Expand Up @@ -224,7 +226,8 @@ def decorator(fn: Callable) -> list[Benchmark]:
benchmarks = []
names = set()
varnames = iterables.keys()
for values in itertools.product(*iterables.values()):
cartesian_product = itertools.product(*iterables.values())
for idx, values in enumerate(cartesian_product):
params = dict(zip(varnames, values))
_check_against_interface(params, fn)

Expand All @@ -235,8 +238,23 @@ def decorator(fn: Callable) -> list[Benchmark]:
f"Perhaps you specified a parameter configuration twice?"
)
names.add(name)
state = State(
name=name,
function=fn,
family=fn.__name__,
family_size=len(list(cartesian_product)),
family_index=idx,
)

bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
bm = Benchmark(
fn,
name=name,
params=params,
setUp=setUp,
tearDown=tearDown,
tags=tags,
state=state,
)
benchmarks.append(bm)
return benchmarks

Expand Down
2 changes: 1 addition & 1 deletion src/nnbench/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .types import Benchmark, BenchmarkRecord, Memo, Parameters
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State
41 changes: 41 additions & 0 deletions src/nnbench/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

from nnbench.context import Context
from nnbench import __version__

T = TypeVar("T")
Variable = tuple[str, type, Any]
Expand Down Expand Up @@ -101,6 +102,16 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
# context data.


@dataclass(frozen=True)
class State:
name: str
function: Callable
family: str
family_size: int
family_index: int
nnbench_version: str = __version__


class Memo(Generic[T]):
@functools.cache
# TODO: Swap this out for a local type-wide memo cache.
Expand Down Expand Up @@ -164,11 +175,41 @@ class Benchmark:
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
tags: tuple[str, ...] = field(repr=False, default=())
interface: Interface = field(init=False, repr=False)
state: State | None = field(default=None)

def __post_init__(self):
if not self.name:
super().__setattr__("name", self.fn.__name__)
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))
if not self.state:
super().__setattr__(
"state",
State(
name=self.name or "",
function=self.fn,
family=self.fn.__name__,
family_size=1,
family_index=0,
),
)

original_setUp = self.setUp

def wrapped_setUp(*args, **kwargs):
state = self.state
# TODO: setUp and Teardown logic
original_setUp(*args, **kwargs)

super().__setattr__("setUp", wrapped_setUp)

original_tearDown = self.tearDown

def wrapped_tearDown(*args, **kwargs):
state = self.state
# TODO: tearDown logic
original_tearDown(*args, **kwargs)

super().__setattr__("tearDown", wrapped_tearDown)


@dataclass(frozen=True)
Expand Down

0 comments on commit 670a857

Please sign in to comment.