Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Organize nnbench.types.types into different files #135

Merged
merged 4 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/nnbench/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload

from nnbench.types import Benchmark
from nnbench.types.types import NoOp
from nnbench.types.util import is_memo, is_memo_type
from nnbench.types.benchmark import NoOp
from nnbench.types.memo import is_memo, is_memo_type


def _check_against_interface(params: dict[str, Any], fun: Callable) -> None:
Expand Down Expand Up @@ -61,7 +61,7 @@ def _default_namegen(fn: Callable, **kwargs: Any) -> str:
@overload
def benchmark(
func: None = None,
name: str | None = None,
name: str = "",
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
Expand All @@ -75,7 +75,7 @@ def benchmark(
@overload
def benchmark(
func: Callable[..., Any],
name: str | None = None,
name: str = "",
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
Expand All @@ -84,7 +84,7 @@ def benchmark(

def benchmark(
func: Callable[..., Any] | None = None,
name: str | None = None,
name: str = "",
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
Expand All @@ -101,7 +101,7 @@ def benchmark(
func: Callable[..., Any] | None
The function to benchmark. This slot only exists to allow application of the decorator
without parentheses, you should never fill it explicitly.
name: str | None
name: str
A display name to give to the benchmark. Useful in summaries and reports.
setUp: Callable[..., None]
A setup hook to run before the benchmark.
Expand Down
2 changes: 1 addition & 1 deletion src/nnbench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from nnbench.context import Context, ContextProvider
from nnbench.types import Benchmark, BenchmarkRecord, Parameters, State
from nnbench.types.util import is_memo, is_memo_type
from nnbench.types.memo import is_memo, is_memo_type
from nnbench.util import import_file_as_module, ismodule

logger = logging.getLogger(__name__)
Expand Down
3 changes: 2 additions & 1 deletion src/nnbench/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State
from .benchmark import Benchmark, BenchmarkRecord, Parameters, State
from .memo import Memo, cached_memo
184 changes: 17 additions & 167 deletions src/nnbench/types/types.py → src/nnbench/types/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,26 @@
"""Useful type interfaces to override/subclass in benchmarking workflows."""
"""Type interfaces for benchmarks and benchmark collections."""

from __future__ import annotations

import copy
import functools
import inspect
import logging
import threading
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Callable, Generic, Literal, Mapping, Protocol, TypeVar
from typing import Any, Callable, Literal, Mapping

from nnbench.context import Context

T = TypeVar("T")
Variable = tuple[str, type, Any]

_memo_cache: dict[int, Any] = {}
_cache_lock = threading.Lock()

logger = logging.getLogger(__name__)


def memo_cache_size() -> int:
"""
Get the current size of the memo cache.

Returns
-------
int
The number of items currently stored in the memo cache.
"""
return len(_memo_cache)


def clear_memo_cache() -> None:
"""
Clear all items from memo cache in a thread_safe manner.
"""
with _cache_lock:
_memo_cache.clear()


def evict_memo(_id: int) -> Any:
"""
Pop cached item with key `_id` from the memo cache.

Parameters
----------
_id : int
The unique identifier (usually the id assigned by the Python interpreter) of the item to be evicted.

Returns
-------
Any
The value that was associated with the removed cache entry. If no item is found with the given `_id`, a KeyError is raised.
"""
with _cache_lock:
return _memo_cache.pop(_id)


def cached_memo(fn: Callable) -> Callable:
"""
Decorator that caches the result of a method call based on the instance ID.

Parameters
----------
fn: Callable
The method to memoize.

Returns
-------
Callable
A wrapped version of the method that caches its result.
"""

@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
_tid = id(self)
with _cache_lock:
if _tid in _memo_cache:
logger.debug(f"Returning memoized value from cache with ID {_tid}")
return _memo_cache[_tid]
logger.debug(f"Computing value on memo with ID {_tid} (cache miss)")
value = fn(self, *args, **kwargs)
with _cache_lock:
_memo_cache[_tid] = value
return value

return wrapper
from nnbench.types.interface import Interface


def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
pass


class CallbackProtocol(Protocol):
def __call__(self, state: State, params: Mapping[str, Any]) -> None: ...
@dataclass(frozen=True)
class State:
name: str
family: str
family_size: int
family_index: int


@dataclass(frozen=True)
Expand Down Expand Up @@ -176,47 +100,6 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
# context data.


@dataclass(frozen=True)
class State:
name: str
family: str
family_size: int
family_index: int


class Memo(Generic[T]):
"""Abstract base class for memoized values in benchmark runs."""

# TODO: Make this better than the decorator application
# -> _Cached metaclass like in fsspec's AbstractFileSystem (maybe vendor with license)

@cached_memo
def __call__(self) -> T:
"""Placeholder to override when subclassing. The call should return the to be cached object."""
raise NotImplementedError

def __del__(self) -> None:
"""Delete the cached object and clear it from the cache."""
with _cache_lock:
sid = id(self)
if sid in _memo_cache:
logger.debug(f"Deleting cached value for memo with ID {sid}")
del _memo_cache[sid]


@dataclass(init=False, frozen=True)
class Parameters:
"""
A dataclass designed to hold benchmark parameters. This class is not functional
on its own, and needs to be subclassed according to your benchmarking workloads.

The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
"""

pass


@dataclass(frozen=True)
class Benchmark:
"""
Expand All @@ -241,7 +124,7 @@ class Benchmark:
"""

fn: Callable[..., Any]
name: str = field(default="")
name: str = ""
params: dict[str, Any] = field(default_factory=dict)
setUp: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
tearDown: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
Expand All @@ -254,47 +137,14 @@ def __post_init__(self):
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))


@dataclass(frozen=True)
class Interface:
@dataclass(init=False, frozen=True)
class Parameters:
"""
Data model representing a function's interface. An instance of this class
is created using the `from_callable` class method.
A dataclass designed to hold benchmark parameters. This class is not functional
on its own, and needs to be subclassed according to your benchmarking workloads.

Parameters:
----------
names : tuple[str, ...]
Names of the function parameters.
types : tuple[type, ...]
Types of the function parameters.
defaults : tuple
A tuple of the function parameters' default values.
variables : tuple[Variable, ...]
A tuple of tuples, where each inner tuple contains the parameter name and type.
returntype: type
The function's return type annotation, or NoneType if left untyped.
The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
"""

names: tuple[str, ...]
types: tuple[type, ...]
defaults: tuple
variables: tuple[Variable, ...]
returntype: type

@classmethod
def from_callable(cls, fn: Callable, defaults: dict[str, Any]) -> Interface:
"""
Creates an interface instance from the given callable.
"""
# Set `follow_wrapped=False` to get the partially filled interfaces.
# Otherwise we get missing value errors for parameters supplied in benchmark decorators.
sig = inspect.signature(fn, follow_wrapped=False)
ret = sig.return_annotation
_defaults = {k: defaults.get(k, v.default) for k, v in sig.parameters.items()}
# defaults are the signature parameters, then the partial parametrization.
return cls(
tuple(sig.parameters.keys()),
tuple(p.annotation for p in sig.parameters.values()),
tuple(_defaults.values()),
tuple((k, v.annotation, _defaults[k]) for k, v in sig.parameters.items()),
type(ret) if ret is None else ret,
)
pass
56 changes: 56 additions & 0 deletions src/nnbench/types/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Type interface for the function interface"""

from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Any, Callable, TypeVar

T = TypeVar("T")
Variable = tuple[str, type, Any]


@dataclass(frozen=True)
class Interface:
"""
Data model representing a function's interface. An instance of this class
is created using the `from_callable` class method.

Parameters:
----------
names : tuple[str, ...]
Names of the function parameters.
types : tuple[type, ...]
Types of the function parameters.
defaults : tuple
A tuple of the function parameters' default values.
variables : tuple[Variable, ...]
A tuple of tuples, where each inner tuple contains the parameter name and type.
returntype: type
The function's return type annotation, or NoneType if left untyped.
"""

names: tuple[str, ...]
types: tuple[type, ...]
defaults: tuple
variables: tuple[Variable, ...]
returntype: type

@classmethod
def from_callable(cls, fn: Callable, defaults: dict[str, Any]) -> Interface:
"""
Creates an interface instance from the given callable.
"""
# Set `follow_wrapped=False` to get the partially filled interfaces.
# Otherwise we get missing value errors for parameters supplied in benchmark decorators.
sig = inspect.signature(fn, follow_wrapped=False)
ret = sig.return_annotation
_defaults = {k: defaults.get(k, v.default) for k, v in sig.parameters.items()}
# defaults are the signature parameters, then the partial parametrization.
return cls(
tuple(sig.parameters.keys()),
tuple(p.annotation for p in sig.parameters.values()),
tuple(_defaults.values()),
tuple((k, v.annotation, _defaults[k]) for k, v in sig.parameters.items()),
type(ret) if ret is None else ret,
)
Loading