Merge pull request #4247 from tybug/tcs-cache
Fully move to the typed choice sequence cache
tybug authored Jan 21, 2025
2 parents d3baf4e + 02067bb commit 7a880c9
Showing 8 changed files with 71 additions and 239 deletions.
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,3 @@
+RELEASE_TYPE: patch
+
+Improves our internal caching logic for test cases.
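For context: a "typed choice sequence" keys a test case by the typed choices it drew (integers, floats, strings, bytes, booleans) rather than by an underlying byte buffer. A minimal sketch of what keying a cache on such a sequence can look like; choice_key and cache_key here are hypothetical names for illustration, not Hypothesis's internal API:

import struct
from typing import Sequence, Union

ChoiceT = Union[int, float, str, bytes, bool]

def choice_key(choice: ChoiceT) -> tuple:
    # Floats cannot key a cache directly: nan != nan, and 0.0 == -0.0
    # would collide, so key on the IEEE 754 bit pattern instead.
    if isinstance(choice, float):
        return ("float", struct.pack("!d", choice))
    # Tag every value with its type name so that True and 1 (or "" and b"")
    # map to distinct keys.
    return (type(choice).__name__, choice)

def cache_key(choices: Sequence[ChoiceT]) -> tuple:
    return tuple(choice_key(c) for c in choices)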
6 changes: 2 additions & 4 deletions hypothesis-python/src/hypothesis/core.py
@@ -489,7 +489,7 @@ def execute_explicit_examples(state, wrapped_test, arguments, kwargs, original_s

with local_settings(state.settings):
fragments_reported = []
-empty_data = ConjectureData.for_buffer(b"")
+empty_data = ConjectureData.for_choices([])
try:
execute_example = partial(
state.execute_once,
@@ -1334,9 +1334,7 @@ def run_engine(self):
info = falsifying_example.extra_information
fragments = []

-ran_example = runner.new_conjecture_data_for_buffer(
-    falsifying_example.buffer
-)
+ran_example = runner.new_conjecture_data_ir(falsifying_example.choices)
ran_example.slice_comments = falsifying_example.slice_comments
tb = None
origin = None
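The hunk above replays the falsifying example from its recorded choices rather than from its byte buffer. A toy illustration of the idea, with hypothetical names: a stored choice list fed back verbatim reproduces the failure deterministically, with no byte-level parsing step in between.

def replay(choices, test):
    # Serve each recorded choice in order instead of drawing fresh values.
    test(iter(choices))

def failing_test(draws):
    n, m = next(draws), next(draws)
    assert n + m < 100, f"falsified with {n=}, {m=}"

# replay([60, 50], failing_test) raises the same AssertionError every time.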
37 changes: 17 additions & 20 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
@@ -1737,19 +1737,12 @@ def _draw(self, ir_type, kwargs, *, observe, forced, fake_forced):
debug_report(f"overrun because hit {self.max_length_ir=}")
self.mark_overrun()

-if self.ir_prefix is not None and observe:
-    if self.index_ir < len(self.ir_prefix):
-        choice = self._pop_choice(ir_type, kwargs, forced=forced)
-    else:
-        try:
-            choice = (
-                forced
-                if forced is not None
-                else draw_choice(ir_type, kwargs, random=self.__random)
-            )
-        except StopTest:
-            debug_report("overrun because draw_choice overran")
-            self.mark_overrun()
+if (
+    observe
+    and self.ir_prefix is not None
+    and self.index_ir < len(self.ir_prefix)
+):
+    choice = self._pop_choice(ir_type, kwargs, forced=forced)

if forced is None:
forced = choice
@@ -1764,7 +1757,7 @@ def _draw(self, ir_type, kwargs, *, observe, forced, fake_forced):
getattr(self.observer, f"draw_{ir_type}")(
value, kwargs=kwargs, was_forced=was_forced
)
-size = ir_size([value])
+size = 0 if self.provider.avoid_realization else ir_size([value])
if self.length_ir + size > self.max_length_ir:
debug_report(
f"overrun because {self.length_ir=} + {size=} > {self.max_length_ir=}"
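The avoid_realization guard above exists for alternative backends that produce symbolic values: measuring a symbolic value's size would force it to be realized into a concrete one, so the size accounting is skipped entirely. A sketch of the pattern, with hypothetical provider classes:

class Provider:
    # Concrete providers hand back real values; measuring them is cheap.
    avoid_realization = False

class SymbolicProvider(Provider):
    # Symbolic providers hand back placeholders; measuring one would
    # collapse it to a concrete value, defeating the backend's purpose.
    avoid_realization = True

def charge_size(provider, value, measure):
    # Mirror of the diff's pattern: skip the measurement entirely when
    # the provider asks us to avoid realization.
    return 0 if provider.avoid_realization else measure(value)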
@@ -1971,14 +1964,18 @@ def _pop_choice(
# node if the alternative is not "the entire data is an overrun".
assert self.index_ir == len(self.ir_prefix) - 1
if node.type == "simplest":
-try:
-    choice: ChoiceT = choice_from_index(0, ir_type, kwargs)
-except ChoiceTooLarge:
-    self.mark_overrun()
+if isinstance(self.provider, HypothesisProvider):
+    try:
+        choice: ChoiceT = choice_from_index(0, ir_type, kwargs)
+    except ChoiceTooLarge:
+        self.mark_overrun()
+else:
+    # give alternative backends control over these draws
+    choice = getattr(self.provider, f"draw_{ir_type}")(**kwargs)
else:
raise NotImplementedError

-node.size -= ir_size([choice])
+node.size -= 0 if self.provider.avoid_realization else ir_size([choice])
if node.size < 0:
self.mark_overrun()
return choice
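The getattr dispatch added above routes each draw to a method named after the choice type, which is what lets alternative backends take control of these values. A self-contained sketch of the same dispatch, with a hypothetical provider:

import random

class RandomProvider:
    def draw_integer(self, min_value=0, max_value=100):
        return random.randint(min_value, max_value)

    def draw_boolean(self, p=0.5):
        return random.random() < p

def draw(provider, ir_type, **kwargs):
    # "integer" -> provider.draw_integer(**kwargs), and so on; a backend
    # overrides these methods to decide what values come out.
    return getattr(provider, f"draw_{ir_type}")(**kwargs)

# draw(RandomProvider(), "integer", min_value=0, max_value=10)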
@@ -2261,7 +2258,7 @@ def draw_bits(
elif self._bytes_drawn < len(self.__prefix):
index = self._bytes_drawn
buf = self.__prefix[index : index + n_bytes]
-if len(buf) < n_bytes:
+if len(buf) < n_bytes:  # pragma: no cover # removing soon
assert self.__random is not None
buf += uniform(self.__random, n_bytes - len(buf))
else:
186 changes: 32 additions & 154 deletions hypothesis-python/src/hypothesis/internal/conjecture/engine.py
@@ -18,17 +18,7 @@
from datetime import timedelta
from enum import Enum
from random import Random, getrandbits
-from typing import (
-    Callable,
-    Final,
-    List,
-    Literal,
-    NoReturn,
-    Optional,
-    Union,
-    cast,
-    overload,
-)
+from typing import Callable, Final, List, Literal, NoReturn, Optional, Union, cast

import attr

@@ -283,7 +273,6 @@ def __init__(
# shrinking where we need to know about the structure of the
# executed test case.
self.__data_cache = LRUReusedCache(CACHE_SIZE)
-self.__data_cache_ir = LRUReusedCache(CACHE_SIZE)

self.reused_previously_shrunk_test_case = False

@@ -359,26 +348,8 @@ def _cache_key(self, choices: Sequence[ChoiceT]) -> tuple[ChoiceKeyT, ...]:

def _cache(self, data: ConjectureData) -> None:
result = data.as_result()
-self.__data_cache[data.buffer] = result
-
-# interesting buffer-based data can mislead the shrinker if we cache them.
-#
-# @given(st.integers())
-# def f(n):
-#     assert n < 100
-#
-# may generate two counterexamples, n=101 and n=m > 101, in that order,
-# where the buffer corresponding to n is large due to eg failed probes.
-# We shrink m and eventually try n=101, but it is cached to a large buffer
-# and so the best we can do is n=102, a non-ideal shrink.
-#
-# We can cache ir-based buffers fine, which always correspond to the
-# smallest buffer via forced=. The overhead here is small because almost
-# all interesting data are ir-based via the shrinker (and that overhead
-# will tend towards zero as we move generation to the ir).
-if data.ir_prefix is not None or data.status < Status.INTERESTING:
-    key = self._cache_key(data.choices)
-    self.__data_cache_ir[key] = result
+key = self._cache_key(data.choices)
+self.__data_cache[key] = result

def cached_test_function_ir(
self,
@@ -387,6 +358,14 @@ def cached_test_function_ir(
error_on_discard: bool = False,
extend: int = 0,
) -> Union[ConjectureResult, _Overrun]:
+"""
+If ``error_on_discard`` is set to True this will raise ``ContainsDiscard``
+in preference to running the actual test function. This is to allow us
+to skip test cases we expect to be redundant in some cases. Note that
+it may be the case that we don't raise ``ContainsDiscard`` even if the
+result has discards if we cannot determine from previous runs whether
+it will have a discard.
+"""
# node templates represent a not-yet-filled hole and therefore cannot
# be cached or retrieved from the cache.
if not any(isinstance(choice, NodeTemplate) for choice in choices):
@@ -395,7 +374,7 @@
choices = cast(Sequence[ChoiceT], choices)
key = self._cache_key(choices)
try:
-cached = self.__data_cache_ir[key]
+cached = self.__data_cache[key]
# if we have a cached overrun for this key, but we're allowing extensions
# of the nodes, it could in fact run to a valid data if we try.
if extend == 0 or cached.status is not Status.OVERRUN:
@@ -429,14 +408,19 @@ def kill_branch(self) -> NoReturn:
else:
trial_data.freeze()
key = self._cache_key(trial_data.choices)
-if trial_data.status is Status.OVERRUN:
+if trial_data.status > Status.OVERRUN:
+    try:
+        return self.__data_cache[key]
+    except KeyError:
+        pass
+else:
    # if we simulated to an overrun, then our result is certainly
    # an overrun; no need to consult the cache. (and we store this result
    # for simulation-less lookup later).
-    self.__data_cache_ir[key] = Overrun
+    self.__data_cache[key] = Overrun
    return Overrun
try:
-    return self.__data_cache_ir[key]
+    return self.__data_cache[key]
except KeyError:
    pass

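Taken together, the lookup logic in this hunk is: simulate the choices against the tree of previously seen behaviour; a simulated overrun is recorded and returned immediately, any other simulated outcome is served from the cache when possible, and only then does the engine actually execute the test. A condensed sketch with hypothetical names and simplified statuses:

from enum import IntEnum

class Status(IntEnum):
    OVERRUN = 0
    INVALID = 1
    VALID = 2
    INTERESTING = 3

def cached_run(choices, cache, simulate, run):
    key = tuple(choices)
    if simulate(choices) is Status.OVERRUN:
        # A simulated overrun is certainly an overrun: record it so later
        # lookups need no simulation at all.
        cache[key] = Status.OVERRUN
        return Status.OVERRUN
    if key in cache:
        return cache[key]
    result = run(choices)  # really execute the test function
    cache[key] = result
    return result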
@@ -567,8 +551,8 @@ def test_function(self, data: ConjectureData) -> None:

if data.status == Status.INTERESTING:
if not self.using_hypothesis_backend:
-# drive the ir tree through the test function to convert it
-# to a buffer
+# replay this failure on the hypothesis backend to ensure it still
+# finds a failure. otherwise, it is flaky.
initial_origin = data.interesting_origin
initial_traceback = getattr(
data.extra_information, "_expected_traceback", None
@@ -611,13 +595,13 @@ def test_function(self, data: ConjectureData) -> None:
if sort_key_ir(data.ir_nodes) < sort_key_ir(existing.ir_nodes):
self.shrinks += 1
self.downgrade_buffer(ir_to_bytes(existing.choices))
-self.__data_cache.unpin(existing.buffer)
+self.__data_cache.unpin(self._cache_key(existing.choices))
changed = True

if changed:
self.save_choices(data.choices)
self.interesting_examples[key] = data.as_result() # type: ignore
-self.__data_cache.pin(data.buffer, data.as_result())
+self.__data_cache.pin(self._cache_key(data.choices), data.as_result())
self.shrunk_examples.discard(key)

if self.shrinks >= MAX_SHRINKS:
@@ -969,11 +953,13 @@ def generate_new_examples(self) -> None:
self.debug("Generating new examples")

assert self.should_generate_more()
-zero_data = self.cached_test_function(bytes(BUFFER_SIZE))
+zero_data = self.cached_test_function_ir(
+    (NodeTemplate("simplest", size=BUFFER_SIZE),)
+)
if zero_data.status > Status.OVERRUN:
assert isinstance(zero_data, ConjectureResult)
self.__data_cache.pin(
-    zero_data.buffer, zero_data.as_result()
+    self._cache_key(zero_data.choices), zero_data.as_result()
) # Pin forever

if zero_data.status == Status.OVERRUN or (
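NodeTemplate("simplest", size=BUFFER_SIZE) above replaces the all-zeros buffer: a bounded hole that each draw fills with the minimal choice for its type. Roughly what "simplest" means per choice type, as an illustrative sketch rather than the exact internals:

def simplest_choice(ir_type, **kwargs):
    # The index-zero choice for each type: the value shrinking steers toward.
    if ir_type == "integer":
        return kwargs.get("shrink_towards", 0)
    if ir_type == "boolean":
        return False
    if ir_type == "float":
        return 0.0
    if ir_type == "string":
        return ""
    if ir_type == "bytes":
        return b""
    raise NotImplementedError(ir_type)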
@@ -1048,7 +1034,7 @@ def generate_new_examples(self) -> None:
# not whatever is specified by the backend. We can improve this
# once more things are on the ir.
if not self.using_hypothesis_backend:
-data = self.new_conjecture_data(prefix=b"", max_length=BUFFER_SIZE)
+data = self.new_conjecture_data_ir([], max_length=BUFFER_SIZE)
with suppress(BackendCannotProceed):
self.test_function(data)
continue
@@ -1228,7 +1214,7 @@ def generate_mutations_from(
assert isinstance(new_data, ConjectureResult)
if (
new_data.status >= data.status
-and data.buffer != new_data.buffer
+and choices_key(data.choices) != choices_key(new_data.choices)
and all(
k in new_data.target_observations
and new_data.target_observations[k] >= v
@@ -1332,32 +1318,6 @@ def new_conjecture_data_ir(
random=self.random,
)

-def new_conjecture_data(
-    self,
-    prefix: Union[bytes, bytearray],
-    max_length: int = BUFFER_SIZE,
-    observer: Optional[DataObserver] = None,
-) -> ConjectureData:
-    provider = (
-        HypothesisProvider if self._switch_to_hypothesis_provider else self.provider
-    )
-    observer = observer or self.tree.new_observer()
-    if not self.using_hypothesis_backend:
-        observer = DataObserver()
-
-    return ConjectureData(
-        prefix=prefix,
-        max_length=max_length,
-        random=self.random,
-        observer=observer,
-        provider=provider,
-    )
-
-def new_conjecture_data_for_buffer(
-    self, buffer: Union[bytes, bytearray]
-) -> ConjectureData:
-    return self.new_conjecture_data(buffer, max_length=len(buffer))
-
def shrink_interesting_examples(self) -> None:
"""If we've found interesting examples, try to replace each of them
with a minimal interesting example with the same interesting_origin.
@@ -1468,88 +1428,6 @@ def new_shrinker(
in_target_phase=self._current_phase == "target",
)

-def cached_test_function(
-    self,
-    buffer: Union[bytes, bytearray],
-    *,
-    extend: int = 0,
-) -> Union[ConjectureResult, _Overrun]:  # pragma: no cover # removing function soon
-    """Checks the tree to see if we've tested this buffer, and returns the
-    previous result if we have.
-    Otherwise we call through to ``test_function``, and return a
-    fresh result.
-    If ``error_on_discard`` is set to True this will raise ``ContainsDiscard``
-    in preference to running the actual test function. This is to allow us
-    to skip test cases we expect to be redundant in some cases. Note that
-    it may be the case that we don't raise ``ContainsDiscard`` even if the
-    result has discards if we cannot determine from previous runs whether
-    it will have a discard.
-    """
-    buffer = bytes(buffer)[:BUFFER_SIZE]
-
-    max_length = min(BUFFER_SIZE, len(buffer) + extend)
-
-    @overload
-    def check_result(result: _Overrun) -> _Overrun: ...
-    @overload
-    def check_result(result: ConjectureResult) -> ConjectureResult: ...
-    def check_result(
-        result: Union[_Overrun, ConjectureResult],
-    ) -> Union[_Overrun, ConjectureResult]:
-        assert result is Overrun or (
-            isinstance(result, ConjectureResult) and result.status != Status.OVERRUN
-        )
-        return result
-
-    try:
-        cached = check_result(self.__data_cache[buffer])
-        if cached.status > Status.OVERRUN or extend == 0:
-            return cached
-    except KeyError:
-        pass
-
-    observer = DataObserver()
-    dummy_data = self.new_conjecture_data(
-        prefix=buffer, max_length=max_length, observer=observer
-    )
-
-    if self.using_hypothesis_backend:
-        try:
-            self.tree.simulate_test_function(dummy_data)
-        except PreviouslyUnseenBehaviour:
-            pass
-        else:
-            if dummy_data.status > Status.OVERRUN:
-                dummy_data.freeze()
-                try:
-                    return self.__data_cache[dummy_data.buffer]
-                except KeyError:
-                    pass
-            else:
-                self.__data_cache[buffer] = Overrun
-                return Overrun
-
-    # We didn't find a match in the tree, so we need to run the test
-    # function normally. Note that test_function will automatically
-    # add this to the tree so we don't need to update the cache.
-
-    result = None
-
-    data = self.new_conjecture_data(
-        prefix=max((buffer, dummy_data.buffer), key=len), max_length=max_length
-    )
-    self.test_function(data)
-    result = check_result(data.as_result())
-    if extend == 0 or (
-        result is not Overrun
-        and not isinstance(result, _Overrun)
-        and len(result.buffer) <= len(buffer)
-    ):
-        self.__data_cache[buffer] = result
-    return result
-
def passing_choice_sequences(
self, prefix: Sequence[IRNode] = ()
) -> frozenset[bytes]:
@@ -1558,8 +1436,8 @@ def passing_choice_sequences(
"""
return frozenset(
result.ir_nodes
-for key in self.__data_cache_ir
-if (result := self.__data_cache_ir[key]).status is Status.VALID
+for key in self.__data_cache
+if (result := self.__data_cache[key]).status is Status.VALID
and startswith(result.ir_nodes, prefix)
)

