Skip to content

Commit

Permalink
Migrate to provider-supplied contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
pschanely committed Mar 6, 2024
1 parent b303a81 commit 494d042
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 51 deletions.
50 changes: 32 additions & 18 deletions hypothesis_crosshair_provider/crosshair_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,29 @@
from hypothesis.internal.conjecture.data import PrimitiveProvider
from hypothesis.internal.intervalsets import IntervalSet

_PREVIOUS_REALIZED_DRAWS = None


class CrossHairPrimitiveProvider(PrimitiveProvider):
"""An implementation of PrimitiveProvider based on CrossHair."""

def __init__(self, *_a, **_kw) -> None:
self.name_id = 0
self.iteration_number = 0
self.current_exit_stack: Optional[ExitStack] = None
self.search_root = RootNode()
if len(os.environ.get("DEBUG_CROSSHAIR", "")) > 1:
set_debug(os.environ["DEBUG_CROSSHAIR"].lower() not in ("0", "false"))
elif "-vv" in sys.argv:
set_debug(True)
self._previous_realized_draws = None

@contextmanager
def per_test_case_context_manager(self):
self.iteration_number += 1
if self.search_root.child.is_exhausted():
debug("Resetting search root")
# might be nice to signal that we're done somehow.
# But for now, just start over!
self.search_root = RootNode()
global _PREVIOUS_REALIZED_DRAWS
_PREVIOUS_REALIZED_DRAWS = None
self._previous_realized_draws = None
iter_start = monotonic()
options = DEFAULT_OPTIONS.overlay(AnalysisOptionSet(analysis_kind=[]))
per_path_timeout = options.get_per_path_timeout() # TODO: how to set this?
Expand All @@ -52,6 +51,10 @@ def per_test_case_context_manager(self):
search_root=self.search_root,
)
space._hypothesis_draws = [] # keep a log of drawn values
space._hypothesis_next_name_id = (
0 # something to uniqu-ify names for drawn values
)

try:
with (
condition_parser([]),
Expand All @@ -60,34 +63,34 @@ def per_test_case_context_manager(self):
COMPOSITE_TRACER,
):
try:
debug("start iter")
debug("starting iteration", self.iteration_number)
try:
yield
finally:
any_choices_made = bool(space.choices_made)
if any_choices_made:
space.detach_path()
_PREVIOUS_REALIZED_DRAWS = {
self._previous_realized_draws = {
id(symbolic): deep_realize(symbolic)
for symbolic in space._hypothesis_draws
}
else:
# TODO: I can't detach_path here because it will conflict with the
# top node of a prior "real" execution.
# Should I just generate a dummy concrete value for each of the draws?
_PREVIOUS_REALIZED_DRAWS = {}
debug("end iter (normal)")
self._previous_realized_draws = {}
debug("ended iteration (normal completion)")
except Exception as exc:
try:
exc.args = deep_realize(exc.args)
debug(
f"end iter (exception: {type(exc).__name__}: {exc})",
f"ended iteration (exception: {type(exc).__name__}: {exc})",
test_stack(exc.__traceback__),
)
except Exception:
exc.args = ()
debug(
f"end iter ({type(exc)} exception)",
f"ended iteration ({type(exc)} exception)",
test_stack(exc.__traceback__),
)
raise exc
Expand All @@ -101,14 +104,22 @@ def per_test_case_context_manager(self):
)
else:
debug("no decisions made; ignoring this iteration")

def _next_name(self, prefix: str) -> str:
self.name_id += 1
return f"{prefix}_{self.name_id:02d}"
space = context_statespace()
space._hypothesis_next_name_id += 1
return f"{prefix}_{space._hypothesis_next_name_id:02d}"

def _remember_draw(self, symbolic):
context_statespace()._hypothesis_draws.append(symbolic)

def draw_boolean(self, p: float = 0.5, *, forced: Optional[bool] = None) -> bool:
def draw_boolean(
self,
p: float = 0.5,
*,
forced: Optional[bool] = None,
fake_forced: bool = False,
) -> bool:
if forced is not None:
return forced

Expand All @@ -125,6 +136,7 @@ def draw_integer(
weights: Optional[Sequence[float]] = None,
shrink_towards: int = 0,
forced: Optional[int] = None,
fake_forced: bool = False,
) -> int:
if forced is not None:
return forced
Expand All @@ -151,6 +163,7 @@ def draw_float(
# width: Literal[16, 32, 64] = 64,
# exclude_min and exclude_max handled higher up
forced: Optional[float] = None,
fake_forced: bool = False,
) -> float:
# TODO: all of this is a bit of a ruse - at present, CrossHair approximates
# floats as real numbers. (though it will attempt +/-inf & nan)
Expand Down Expand Up @@ -187,6 +200,7 @@ def draw_string(
min_size: int = 0,
max_size: Optional[int] = None,
forced: Optional[str] = None,
fake_forced: bool = False,
) -> str:
with NoTracing():
if forced is not None:
Expand All @@ -202,6 +216,7 @@ def draw_bytes(
self,
size: int,
forced: Optional[bytes] = None,
fake_forced: bool = False,
) -> bytes:
if forced is not None:
return forced
Expand All @@ -215,11 +230,10 @@ def export_value(self, value):
if is_tracing():
return deep_realize(value)
else:
global _PREVIOUS_REALIZED_DRAWS
if _PREVIOUS_REALIZED_DRAWS is None:
if self._previous_realized_draws is None:
debug("WARNING: export_value() requested at wrong time", test_stack())
return None
return _PREVIOUS_REALIZED_DRAWS.get(id(value))
return value
return self._previous_realized_draws.get(id(value), value)

def post_test_case_hook(self, val):
return self.export_value(val)
67 changes: 34 additions & 33 deletions hypothesis_crosshair_provider/crosshair_provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from hypothesis.internal.intervalsets import IntervalSet

from hypothesis_crosshair_provider.crosshair_provider import (
CrossHairPrimitiveProvider,
)
from hypothesis_crosshair_provider.crosshair_provider import \
CrossHairPrimitiveProvider


class TargetException(Exception):
pass


def _example_user_code(s_bool, s_int, s_float, s_str, s_bytes):
Expand All @@ -13,37 +16,35 @@ def _example_user_code(s_bool, s_int, s_float, s_str, s_bytes):
if s_float < 2.0:
if s_str == "foo":
if s_bytes == b"b":
raise Exception("uh oh")
raise TargetException


def test_end_to_end():
provider = CrossHairPrimitiveProvider()
with provider.per_test_case_context_manager() as per_run_mgr:
found_ct = 0
for _ in range(30):
try:
with per_run_mgr():
s_bool = provider.draw_boolean()
s_int = provider.draw_integer()
s_float = provider.draw_float()
s_str = provider.draw_string(
IntervalSet.from_string("abcdefghijklmnopqrstuvwxyz")
)
s_bytes = provider.draw_bytes(1)
assert type(s_bool) == bool
assert type(s_int) == int
assert type(s_float) == float
assert type(s_str) == str
assert type(s_bytes) == bytes
_example_user_code(s_bool, s_int, s_float, s_str, s_bytes)
assert type(provider.export_value(s_bool)) == bool
assert type(provider.export_value(s_int)) == int
assert type(provider.export_value(s_float)) == float
assert type(provider.export_value(s_str)) == str
# NOTE: draw_bytes can raise IgnoreAttempt, which will leave the bytes
# symbolic without a concrete value:
assert type(provider.export_value(s_bytes)) in (bytes, types.NoneType)
except Exception as exc:
assert str(exc) == "uh oh"
found_ct += 1
assert found_ct > 0, "CrossHair could not find the exception"
found_ct = 0
for _ in range(30):
try:
with provider.per_test_case_context_manager():
s_bool = provider.draw_boolean()
s_int = provider.draw_integer()
s_float = provider.draw_float()
s_str = provider.draw_string(
IntervalSet.from_string("abcdefghijklmnopqrstuvwxyz")
)
s_bytes = provider.draw_bytes(1)
assert type(s_bool) == bool
assert type(s_int) == int
assert type(s_float) == float
assert type(s_str) == str
assert type(s_bytes) == bytes
_example_user_code(s_bool, s_int, s_float, s_str, s_bytes)
assert type(provider.export_value(s_bool)) == bool
assert type(provider.export_value(s_int)) == int
assert type(provider.export_value(s_float)) == float
assert type(provider.export_value(s_str)) == str
# NOTE: draw_bytes can raise IgnoreAttempt, which will leave the bytes
# symbolic without a concrete value:
assert type(provider.export_value(s_bytes)) in (bytes, types.NoneType)
except TargetException:
found_ct += 1
assert found_ct > 0, "CrossHair could not find the exception"

0 comments on commit 494d042

Please sign in to comment.