Skip to content

Commit

Permalink
Towards async mocking
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk committed Oct 30, 2024
1 parent d712674 commit 76727d0
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 39 deletions.
2 changes: 1 addition & 1 deletion amaranth-stubs
42 changes: 26 additions & 16 deletions transactron/testing/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from transactron.core.keys import TransactionManagerKey
from transactron.core import TransactionModule
from transactron.utils import ModuleConnector, HasElaborate, auto_debug_signals, HasDebugSignals
from transactron.testing.sugar import MethodMock


T = TypeVar("T")
Expand Down Expand Up @@ -222,34 +223,43 @@ class TestCaseWithSimulator:
dependency_manager: DependencyManager

@contextmanager
def configure_dependency_context(self):
def _configure_dependency_context(self):
self.dependency_manager = DependencyManager()
with DependencyContext(self.dependency_manager):
yield Tick()

def add_class_mocks(self, sim: PysimSimulator) -> None:
def _add_mock(self, sim: PysimSimulator, val: MethodMock | Callable[[], TestGen[None]]):
if isinstance(val, MethodMock):
sim.add_process(val.output_process)
if val.validate_arguments is not None:
sim.add_process(val.validate_arguments_process)
sim.add_testbench(val.effect_process)
else:
sim.add_process(val)

def _add_class_mocks(self, sim: PysimSimulator) -> None:
for key in dir(self):
val = getattr(self, key)
if hasattr(val, "_transactron_testing_process"):
sim.add_process(val)
self._add_mock(sim, val)

def add_local_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None:
def _add_local_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None:
for key, val in frame_locals.items():
if hasattr(val, "_transactron_testing_process"):
sim.add_process(val)
self._add_mock(sim, val)

def add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None:
self.add_class_mocks(sim)
self.add_local_mocks(sim, frame_locals)
def _add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None:
self._add_class_mocks(sim)
self._add_local_mocks(sim, frame_locals)

def configure_traces(self):
def _configure_traces(self):
traces_file = None
if "__TRANSACTRON_DUMP_TRACES" in os.environ:
traces_file = self._transactron_current_output_file_name
self._transactron_infrastructure_traces_file = traces_file

@contextmanager
def configure_profiles(self):
def _configure_profiles(self):
profile = None
if "__TRANSACTRON_PROFILE" in os.environ:

Expand All @@ -274,7 +284,7 @@ def f():
profile.encode(f"{profile_dir}/{profile_file}.json")

@contextmanager
def configure_logging(self):
def _configure_logging(self):
def on_error():
assert False, "Simulation finished due to an error"

Expand Down Expand Up @@ -302,10 +312,10 @@ def reinitialize_fixtures(self):
self._transactron_base_output_file_name + "_" + str(self._transactron_hypothesis_iter_counter)
)
self._transactron_sim_processes_to_add: list[Callable[[], Optional[Callable]]] = []
with self.configure_dependency_context():
self.configure_traces()
with self.configure_profiles():
with self.configure_logging():
with self._configure_dependency_context():
self._configure_traces()
with self._configure_profiles():
with self._configure_logging():
yield
self._transactron_hypothesis_iter_counter += 1

Expand Down Expand Up @@ -333,7 +343,7 @@ def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_tra
traces_file=self._transactron_infrastructure_traces_file,
clk_period=clk_period,
)
self.add_all_mocks(sim, sys._getframe(2).f_locals)
self._add_all_mocks(sim, sys._getframe(2).f_locals)

yield sim

Expand Down
132 changes: 131 additions & 1 deletion transactron/testing/sugar.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,58 @@
import functools
from typing import Callable, Any, Optional
from .testbenchio import TestbenchIO, TestGen

from amaranth_types import AnySimulatorContext

from transactron.lib.adapters import Adapter
from transactron.utils.transactron_helpers import mock_def_helper
from .testbenchio import AsyncTestbenchIO, TestbenchIO, TestGen
from transactron.utils._typing import RecordIntDict


class MethodMock:
def __init__(
self,
adapter: Adapter,
function: Callable[..., Optional[RecordIntDict]],
*,
validate_arguments: Optional[Callable[..., bool]] = None,
enable: Callable[[], bool] = lambda: True,
):
self.adapter = adapter
self.function = function
self.validate_arguments = validate_arguments
self.enable = enable
self._effects: list[Callable[[], None]] = []

def effect(self, effect: Callable[[], None]):
self._effects.append(effect)

async def output_process(
self,
sim: AnySimulatorContext,
) -> None:
async for *_, done, arg in sim.changed(self.adapter.done, self.adapter.data_out):
if not done:
continue
self._effects = []
sim.set(self.adapter.data_in, mock_def_helper(self, self.function, arg))

async def validate_arguments_process(self, sim: AnySimulatorContext) -> None:
assert self.validate_arguments is not None
async for args in sim.changed(*(a for a, _ in self.adapter.validators)):
assert len(args) == len(self.adapter.validators) # TODO: remove later
for arg, r in zip(args, (r for _, r in self.adapter.validators)):
sim.set(r, mock_def_helper(self, self.validate_arguments, arg))

async def effect_process(self, sim: AnySimulatorContext) -> None:
async for *_, done in sim.tick().sample(self.adapter.done):
with sim.critical():
if done:
for eff in self._effects:
eff()
sim.set(self.adapter.en, self.enable())


def def_method_mock(
tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], sched_prio: int = 0, **kwargs
) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[], TestGen[None]]]:
Expand Down Expand Up @@ -80,3 +129,84 @@ def mock(func_self=None, /) -> TestGen[None]:
return mock

return decorator


def async_def_method_mock(
tb_getter: Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO], **kwargs
) -> Callable[[Callable[..., Optional[RecordIntDict]]], Callable[[AnySimulatorContext], MethodMock]]:
"""
TODO: better description!
Decorator function to create method mock handlers. It should be applied on
a function which describes functionality which we want to invoke on method call.
Such function will be wrapped by `method_handle_loop` and called on each
method invocation.
Function `f` should take only one argument `arg` - data used in function
invocation - and should return data to be sent as response to the method call.
Function `f` can also be a method and take two arguments `self` and `arg`,
the data to be passed on to invoke a method. It should return data to be sent
as response to the method call.
Instead of the `arg` argument, the data can be split into keyword arguments.
Make sure to defer accessing state, since decorators are evaluated eagerly
during function declaration.
Parameters
----------
tb_getter : Callable[[], TestbenchIO] | Callable[[Any], TestbenchIO]
Function to get the TestbenchIO providing appropriate `method_handle_loop`.
**kwargs
Arguments passed to `method_handle_loop`.
Example
-------
```
m = TestCircuit()
def target_process(k: int):
@def_method_mock(lambda: m.target[k])
def process(arg):
return {"data": arg["data"] + k}
return process
```
or equivalently
```
m = TestCircuit()
def target_process(k: int):
@def_method_mock(lambda: m.target[k], settle=1, enable=False)
def process(data):
return {"data": data + k}
return process
```
or for class methods
```
@def_method_mock(lambda self: self.target[k], settle=1, enable=False)
def process(self, data):
return {"data": data + k}
```
"""

def decorator(func: Callable[..., Optional[RecordIntDict]]) -> Callable[[AnySimulatorContext], MethodMock]:
@functools.wraps(func)
def mock(func_self=None, /) -> MethodMock:
f = func
getter: Any = tb_getter
kw = kwargs
if func_self is not None:
getter = getter.__get__(func_self)
f = f.__get__(func_self)
kw = {}
for k, v in kwargs.items():
bind = getattr(v, "__get__", None)
kw[k] = bind(func_self) if bind else v
tb = getter()
assert isinstance(tb, AsyncTestbenchIO)
assert isinstance(tb.adapter, Adapter)
return MethodMock(tb.adapter, f, **kw)

mock._transactron_testing_process = 1 # type: ignore
return mock

return decorator
35 changes: 14 additions & 21 deletions transactron/testing/testbenchio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from amaranth import *
from amaranth.sim import Settle, Passive, Tick
from amaranth.sim._async import ProcessContext
from typing import Optional, Callable

from amaranth_types import AnySimulatorContext
from transactron.lib import AdapterBase
from transactron.lib.adapters import Adapter
from transactron.utils import ValueLike, SignalBundle, mock_def_helper, assign
Expand All @@ -20,13 +21,13 @@ def elaborate(self, platform):

# Low-level operations

def set_enable(self, sim: ProcessContext, en):
def set_enable(self, sim: AnySimulatorContext, en):
sim.set(self.adapter.en, 1 if en else 0)

def enable(self, sim: ProcessContext):
def enable(self, sim: AnySimulatorContext):
self.set_enable(sim, True)

def disable(self, sim: ProcessContext):
def disable(self, sim: AnySimulatorContext):
self.set_enable(sim, False)

@property
Expand All @@ -37,40 +38,40 @@ def done(self):
def outputs(self):
return self.adapter.data_out

def set_inputs(self, sim: ProcessContext, data):
def set_inputs(self, sim: AnySimulatorContext, data):
sim.set(self.adapter.data_in, data)

def sample_outputs(self, sim: ProcessContext):
def sample_outputs(self, sim: AnySimulatorContext):
return sim.tick().sample(self.adapter.data_out)

def sample_outputs_until_done(self, sim: ProcessContext):
def sample_outputs_until_done(self, sim: AnySimulatorContext):
return self.sample_outputs(sim).until(self.adapter.done)

def sample_outputs_done(self, sim: ProcessContext):
def sample_outputs_done(self, sim: AnySimulatorContext):
return sim.tick().sample(self.adapter.data_out, self.adapter.done)

# Operations for AdapterTrans

def call_init(self, sim: ProcessContext, data={}, /, **kwdata):
def call_init(self, sim: AnySimulatorContext, data={}, /, **kwdata):
if data and kwdata:
raise TypeError("call_init() takes either a single dict or keyword arguments")
if not data:
data = kwdata
self.enable(sim)
self.set_inputs(sim, data)

async def call_result(self, sim: ProcessContext):
async def call_result(self, sim: AnySimulatorContext):
*_, data, done = await self.sample_outputs_done(sim)
if done:
return data
return None

async def call_do(self, sim: ProcessContext):
async def call_do(self, sim: AnySimulatorContext):
*_, outputs = await self.sample_outputs_until_done(sim)
self.disable(sim)
return outputs

async def call_try(self, sim: ProcessContext, data={}, /, **kwdata):
async def call_try(self, sim: AnySimulatorContext, data={}, /, **kwdata):
if data and kwdata:
raise TypeError("call_try() takes either a single dict or keyword arguments")
if not data:
Expand All @@ -80,22 +81,14 @@ async def call_try(self, sim: ProcessContext, data={}, /, **kwdata):
self.disable(sim)
return outputs

async def call(self, sim: ProcessContext, data={}, /, **kwdata):
async def call(self, sim: AnySimulatorContext, data={}, /, **kwdata):
if data and kwdata:
raise TypeError("call() takes either a single dict or keyword arguments")
if not data:
data = kwdata
self.call_init(sim, data)
return await self.call_do(sim)

# Operations for Adapter

def method_argument(self, sim: ProcessContext):
return self.call_result(sim)

def method_return(self, sim: ProcessContext, data):
self.set_inputs(sim, data)


class TestbenchIO(Elaboratable):
def __init__(self, adapter: AdapterBase):
Expand Down

0 comments on commit 76727d0

Please sign in to comment.