Organize nnbench.types.types into different files
so that the module is shorter. The new files group related
types under fitting names.

Also refactored imports across the project and in tests accordingly.
maxmynter committed Mar 26, 2024
1 parent 636d18b commit 9f75085
Showing 7 changed files with 187 additions and 175 deletions.
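
For downstream code, the practical effect of this reorganization is the new module layout. A sketch of the import paths after this commit, based on the diffs below (the package-level names come from the updated nnbench/types/__init__.py):

# Package-level imports keep working via the re-exports in nnbench/types/__init__.py:
from nnbench.types import Benchmark, BenchmarkRecord, Memo, Parameters, State

# Submodule imports after the split (everything previously lived in nnbench.types.types):
from nnbench.types.benchmark import Benchmark, BenchmarkRecord, NoOp, Parameters, State
from nnbench.types.interface import Interface
from nnbench.types.memo import Memo  # location inferred from the __init__.py diff below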
10 changes: 5 additions & 5 deletions src/nnbench/core.py
@@ -10,7 +10,7 @@
from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload

from nnbench.types import Benchmark
from nnbench.types.types import NoOp
from nnbench.types.benchmark import NoOp
from nnbench.types.util import is_memo, is_memo_type


@@ -61,7 +61,7 @@ def _default_namegen(fn: Callable, **kwargs: Any) -> str:
@overload
def benchmark(
func: None = None,
name: str | None = None,
name: str = "",
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
@@ -75,7 +75,7 @@
@overload
def benchmark(
func: Callable[..., Any],
name: str | None = None,
name: str = "",
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
@@ -84,7 +84,7 @@

def benchmark(
func: Callable[..., Any] | None = None,
name: str | None = None,
name: str = "",
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
@@ -101,7 +101,7 @@
func: Callable[..., Any] | None
The function to benchmark. This slot only exists to allow application of the decorator
without parentheses, you should never fill it explicitly.
name: str | None
name: str
A display name to give to the benchmark. Useful in summaries and reports.
setUp: Callable[..., None]
A setup hook to run before the benchmark.
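
The core.py hunks only change the name parameter's default from str | None = None to str = "" across the decorator overloads and their docstring. A hypothetical usage sketch of the @benchmark decorator as typed in these overloads (the benchmark body and its parameters are illustrative, not from the repository):

from nnbench.core import benchmark  # module shown in this diff

@benchmark(name="vector_sum", tags=("smoke",))
def total(n: int = 1000) -> int:
    # trivial illustrative workload; setUp and tearDown default to NoOp
    return sum(range(n))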
3 changes: 2 additions & 1 deletion src/nnbench/types/__init__.py
@@ -1 +1,2 @@
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State
from .benchmark import Benchmark, BenchmarkRecord, Parameters, State
from .memo import Memo
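
The Memo base class and its cached_memo helper move out of the former types.py (see the deletions in the next file); judging by this __init__.py, Memo now lives in nnbench.types.memo. A hypothetical subclass sketch under that assumption, following the pattern in the removed code (override __call__ and reapply the caching decorator):

from nnbench.types.memo import Memo, cached_memo  # cached_memo location assumed alongside Memo

class ExpensiveValue:
    """Placeholder for an object that is costly to construct."""

class ValueMemo(Memo[ExpensiveValue]):
    def __init__(self, path: str) -> None:
        self.path = path

    @cached_memo  # computed once per instance, then served from the memo cache
    def __call__(self) -> ExpensiveValue:
        # stand-in for an expensive load, e.g. reading weights from self.path
        return ExpensiveValue()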
186 changes: 19 additions & 167 deletions src/nnbench/types/types.py → src/nnbench/types/benchmark.py
@@ -1,102 +1,26 @@
"""Useful type interfaces to override/subclass in benchmarking workflows."""
"""Type interfaces for benchmarks and benchmark collections."""

from __future__ import annotations

import copy
import functools
import inspect
import logging
import threading
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Callable, Generic, Literal, Mapping, Protocol, TypeVar
from typing import Any, Callable, Literal, Mapping

from nnbench.context import Context

T = TypeVar("T")
Variable = tuple[str, type, Any]

_memo_cache: dict[int, Any] = {}
_cache_lock = threading.Lock()

logger = logging.getLogger(__name__)


def memo_cache_size() -> int:
"""
Get the current size of the memo cache.
Returns
-------
int
The number of items currently stored in the memo cache.
"""
return len(_memo_cache)


def clear_memo_cache() -> None:
"""
Clear all items from the memo cache in a thread-safe manner.
"""
with _cache_lock:
_memo_cache.clear()


def evict_memo(_id: int) -> Any:
"""
Pop cached item with key `_id` from the memo cache.
Parameters
----------
_id : int
The unique identifier (usually the id assigned by the Python interpreter) of the item to be evicted.
Returns
-------
Any
The value that was associated with the removed cache entry. If no item is found with the given `_id`, a KeyError is raised.
"""
with _cache_lock:
return _memo_cache.pop(_id)


def cached_memo(fn: Callable) -> Callable:
"""
Decorator that caches the result of a method call based on the instance ID.
Parameters
----------
fn: Callable
The method to memoize.
Returns
-------
Callable
A wrapped version of the method that caches its result.
"""

@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
_tid = id(self)
with _cache_lock:
if _tid in _memo_cache:
logger.debug(f"Returning memoized value from cache with ID {_tid}")
return _memo_cache[_tid]
logger.debug(f"Computing value on memo with ID {_tid} (cache miss)")
value = fn(self, *args, **kwargs)
with _cache_lock:
_memo_cache[_tid] = value
return value

return wrapper
from nnbench.types.interface import Interface


def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
pass


class CallbackProtocol(Protocol):
def __call__(self, state: State, params: Mapping[str, Any]) -> None: ...
@dataclass(frozen=True)
class State:
name: str
family: str
family_size: int
family_index: int


@dataclass(frozen=True)
@@ -176,47 +100,6 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
# context data.


@dataclass(frozen=True)
class State:
name: str
family: str
family_size: int
family_index: int


class Memo(Generic[T]):
"""Abstract base class for memoized values in benchmark runs."""

# TODO: Make this better than the decorator application
# -> _Cached metaclass like in fsspec's AbstractFileSystem (maybe vendor with license)

@cached_memo
def __call__(self) -> T:
"""Placeholder to override when subclassing. The call should return the to be cached object."""
raise NotImplementedError

def __del__(self) -> None:
"""Delete the cached object and clear it from the cache."""
with _cache_lock:
sid = id(self)
if sid in _memo_cache:
logger.debug(f"Deleting cached value for memo with ID {sid}")
del _memo_cache[sid]


@dataclass(init=False, frozen=True)
class Parameters:
"""
A dataclass designed to hold benchmark parameters. This class is not functional
on its own, and needs to be subclassed according to your benchmarking workloads.
The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
"""

pass


@dataclass(frozen=True)
class Benchmark:
"""
@@ -246,7 +129,9 @@ class Benchmark:
name: str = field(default="")
params: dict[str, Any] = field(default_factory=dict)
setUp: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
tearDown: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
tearDown: Callable[[State, Mapping[str, Any]], None] = field(
repr=False, default=NoOp
)
tags: tuple[str, ...] = field(repr=False, default=())
interface: Interface = field(init=False, repr=False)

@@ -256,47 +141,14 @@ def __post_init__(self):
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))


@dataclass(frozen=True)
class Interface:
@dataclass(init=False, frozen=True)
class Parameters:
"""
Data model representing a function's interface. An instance of this class
is created using the `from_callable` class method.
A dataclass designed to hold benchmark parameters. This class is not functional
on its own, and needs to be subclassed according to your benchmarking workloads.
Parameters:
----------
names : tuple[str, ...]
Names of the function parameters.
types : tuple[type, ...]
Types of the function parameters.
defaults : tuple
A tuple of the function parameters' default values.
variables : tuple[Variable, ...]
A tuple of tuples, where each inner tuple contains the parameter name and type.
returntype: type
The function's return type annotation, or NoneType if left untyped.
The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
"""

names: tuple[str, ...]
types: tuple[type, ...]
defaults: tuple
variables: tuple[Variable, ...]
returntype: type

@classmethod
def from_callable(cls, fn: Callable, defaults: dict[str, Any]) -> Interface:
"""
Creates an interface instance from the given callable.
"""
# Set `follow_wrapped=False` to get the partially filled interfaces.
# Otherwise we get missing value errors for parameters supplied in benchmark decorators.
sig = inspect.signature(fn, follow_wrapped=False)
ret = sig.return_annotation
_defaults = {k: defaults.get(k, v.default) for k, v in sig.parameters.items()}
# defaults are the signature parameters, then the partial parametrization.
return cls(
tuple(sig.parameters.keys()),
tuple(p.annotation for p in sig.parameters.values()),
tuple(_defaults.values()),
tuple((k, v.annotation, _defaults[k]) for k, v in sig.parameters.items()),
type(ret) if ret is None else ret,
)
pass
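
Since Parameters is only useful once subclassed, a hypothetical sketch of the intended usage (field names and values are illustrative; only the Parameters base class comes from the diff above):

from dataclasses import dataclass

from nnbench.types.benchmark import Parameters

@dataclass(frozen=True)
class TrainingParams(Parameters):
    # typed, statically checkable fields instead of a loose dict, as the docstring suggests
    learning_rate: float = 1e-3
    epochs: int = 10

params = TrainingParams(learning_rate=3e-4, epochs=5)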
56 changes: 56 additions & 0 deletions src/nnbench/types/interface.py
@@ -0,0 +1,56 @@
"""Type interface for the function interface"""

from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Any, Callable, TypeVar

T = TypeVar("T")
Variable = tuple[str, type, Any]


@dataclass(frozen=True)
class Interface:
"""
Data model representing a function's interface. An instance of this class
is created using the `from_callable` class method.
Parameters:
----------
names : tuple[str, ...]
Names of the function parameters.
types : tuple[type, ...]
Types of the function parameters.
defaults : tuple
A tuple of the function parameters' default values.
variables : tuple[Variable, ...]
A tuple of tuples, where each inner tuple contains the parameter name and type.
returntype: type
The function's return type annotation, or NoneType if left untyped.
"""

names: tuple[str, ...]
types: tuple[type, ...]
defaults: tuple
variables: tuple[Variable, ...]
returntype: type

@classmethod
def from_callable(cls, fn: Callable, defaults: dict[str, Any]) -> Interface:
"""
Creates an interface instance from the given callable.
"""
# Set `follow_wrapped=False` to get the partially filled interfaces.
# Otherwise we get missing value errors for parameters supplied in benchmark decorators.
sig = inspect.signature(fn, follow_wrapped=False)
ret = sig.return_annotation
_defaults = {k: defaults.get(k, v.default) for k, v in sig.parameters.items()}
# defaults are the signature parameters, then the partial parametrization.
return cls(
tuple(sig.parameters.keys()),
tuple(p.annotation for p in sig.parameters.values()),
tuple(_defaults.values()),
tuple((k, v.annotation, _defaults[k]) for k, v in sig.parameters.items()),
type(ret) if ret is None else ret,
)
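
To make from_callable concrete, a small sketch of what the resulting Interface holds for an annotated function (the sample function is hypothetical; the field contents follow the implementation above):

from nnbench.types.interface import Interface

def accuracy(model: str, batch_size: int = 32) -> float:
    return 0.0  # illustrative stub

iface = Interface.from_callable(accuracy, defaults={})
# iface.names      -> ("model", "batch_size")
# iface.types      -> (str, int)
# iface.defaults   -> (inspect.Parameter.empty, 32)
# iface.variables  -> (("model", str, inspect.Parameter.empty), ("batch_size", int, 32))
# iface.returntype -> float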