Skip to content

Commit

Permalink
Inject state class into setUp and tearDown tasks
Browse files Browse the repository at this point in the history
in the `runner.run()` method. This is necessary such that the setUp and
tearDown tasks know the benchmark states. Namely, how many bench-
marks are in the benchmark family and the index of the current.

In a follow up implementation of a cache we will use the index
and family lenght to compute a condition to empty the cache.
  • Loading branch information
maxmynter committed Mar 25, 2024
1 parent 65fc45b commit 7863925
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 12 deletions.
18 changes: 16 additions & 2 deletions src/nnbench/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
)
names.add(name)

bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
bm = Benchmark(
fn,
name=name,
params=params,
setUp=setUp,
tearDown=tearDown,
tags=tags,
)
benchmarks.append(bm)
return benchmarks

Expand Down Expand Up @@ -236,7 +243,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
)
names.add(name)

bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
bm = Benchmark(
fn,
name=name,
params=params,
setUp=setUp,
tearDown=tearDown,
tags=tags,
)
benchmarks.append(bm)
return benchmarks

Expand Down
20 changes: 18 additions & 2 deletions src/nnbench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import collections
import contextlib
import inspect
import logging
Expand All @@ -16,6 +17,7 @@

from nnbench.context import Context, ContextProvider
from nnbench.types import Benchmark, BenchmarkRecord, Parameters
from nnbench.types.types import State
from nnbench.types.util import is_memo, is_memo_type
from nnbench.util import import_file_as_module, ismodule

Expand Down Expand Up @@ -247,6 +249,12 @@ def run(
if not self.benchmarks:
self.collect(path_or_module, tags)

family_sizes: dict[str, Any] = collections.defaultdict(int)
for bm in self.benchmarks:
family_sizes[bm.fn.__name__] += 1

family_indices: dict[str, Any] = collections.defaultdict(int)

if isinstance(context, Context):
ctx = context
else:
Expand Down Expand Up @@ -274,6 +282,14 @@ def _maybe_dememo(v, expected_type):
return v

for benchmark in self.benchmarks:
bm_family = benchmark.fn.__name__
bm_state = State(
name=benchmark.name or bm_family,
family=bm_family,
family_size=family_sizes[bm_family],
family_index=family_indices[bm_family],
)
family_indices[bm_family] += 1
bmtypes = dict(zip(benchmark.interface.names, benchmark.interface.types))
bmparams = dict(zip(benchmark.interface.names, benchmark.interface.defaults))
# TODO: Does this need a copy.deepcopy()?
Expand All @@ -291,14 +307,14 @@ def _maybe_dememo(v, expected_type):
"parameters": bmparams,
}
try:
benchmark.setUp(**bmparams)
benchmark.setUp(bm_state, **bmparams)
with timer(res):
res["value"] = benchmark.fn(**bmparams)
except Exception as e:
res["error_occurred"] = True
res["error_message"] = str(e)
finally:
benchmark.tearDown(**bmparams)
benchmark.tearDown(bm_state, **bmparams)
results.append(res)

return BenchmarkRecord(
Expand Down
2 changes: 1 addition & 1 deletion src/nnbench/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .types import Benchmark, BenchmarkRecord, Memo, Parameters
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State
34 changes: 27 additions & 7 deletions src/nnbench/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@
import functools
import inspect
from dataclasses import dataclass, field
from typing import (
Any,
Callable,
Generic,
Literal,
TypeVar,
)
from typing import Any, Callable, Generic, Literal, TypeVar

from nnbench.context import Context

Expand Down Expand Up @@ -101,6 +95,14 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
# context data.


@dataclass(frozen=True)
class State:
name: str
family: str
family_size: int
family_index: int


class Memo(Generic[T]):
@functools.cache
# TODO: Swap this out for a local type-wide memo cache.
Expand Down Expand Up @@ -170,6 +172,24 @@ def __post_init__(self):
super().__setattr__("name", self.fn.__name__)
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))

original_setUp = self.setUp

def wrapped_setUp(state: State, /, *args: Any, **kwargs: Any) -> None:
# TODO: setUp logic
print("SetUp: ", state)
original_setUp(*args, **kwargs)

super().__setattr__("setUp", wrapped_setUp)

original_tearDown = self.tearDown

def wrapped_tearDown(state: State, /, *args: Any, **kwargs: Any) -> None:
# TODO: tearDown logic
print("tearDown: ", state)
original_tearDown(*args, **kwargs)

super().__setattr__("tearDown", wrapped_tearDown)


@dataclass(frozen=True)
class Interface:
Expand Down

0 comments on commit 7863925

Please sign in to comment.