Use a wrapper class for cache hash to prevent cache serialization #620

Merged · 8 commits · Feb 10, 2020
1 change: 0 additions & 1 deletion changelog.d/611.change.rst

This file was deleted.

8 changes: 8 additions & 0 deletions changelog.d/620.change.rst
@@ -0,0 +1,8 @@
Fixed serialization behavior of non-slots classes with ``cache_hash=True``.
The hash cache is now cleared on operations which make "deep copies" of instances of classes with hash caching,
though it is not cleared by shallow copies like those made by ``copy.copy()``.

Previously, ``copy.deepcopy()`` or serialization and deserialization with ``pickle`` would result in an uninitialized object.

This change also allows the creation of ``cache_hash=True`` classes with a custom ``__setstate__``,
which was previously forbidden (`#494 <https://github.com/python-attrs/attrs/issues/494>`_).
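
To make the new behavior concrete, here is a minimal sketch of the copy semantics described above (assuming an attrs version that includes this patch; ``Point`` is a hypothetical example class):

import copy
import pickle

import attr

@attr.s(frozen=True, cache_hash=True)
class Point(object):
    x = attr.ib()

p = Point(1)
hash(p)                            # first call computes the hash and caches it

q = copy.deepcopy(p)               # deep copy: the cache comes back cleared,
hash(q)                            # so this recomputes the hash from q's fields

r = pickle.loads(pickle.dumps(p))  # a pickle round-trip behaves like a deep copy

s = copy.copy(p)                   # shallow copy: the cached value is carried over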
88 changes: 52 additions & 36 deletions src/attr/_make.py
@@ -70,6 +70,31 @@ def __repr__(self):
"""


class _CacheHashWrapper(int):
"""
An integer subclass that pickles / copies as None.

This is used for non-slots classes with ``cache_hash=True``, to avoid
serializing a potentially (even likely) invalid hash value. Since ``None``
is the default value for uncalculated hashes, whenever this is copied,
the copy's value for the hash should automatically reset.

See GH #613 for more details.
"""

if PY2:
# For some reason `type(None)` isn't callable in Python 2, but we don't
# actually need a constructor for None objects, we just need any
# available function that returns None.
def __reduce__(self, _none_constructor=getattr, _args=(0, "", None)):
return _none_constructor, _args

else:

def __reduce__(self, _none_constructor=type(None), _args=()):
return _none_constructor, _args
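
To see the trick in isolation, here is a minimal sketch (Python 3; ``_CacheHashWrapper`` is a private name, imported here only for illustration):

import copy
import pickle

from attr._make import _CacheHashWrapper

w = _CacheHashWrapper(12345)
assert w == 12345                             # behaves like a plain int
assert copy.deepcopy(w) is None               # __reduce__ rebuilds it as None
assert pickle.loads(pickle.dumps(w)) is None  # pickle special-cases type(None)

Because attrs stores the wrapper in the instance ``__dict__``, any copy that serializes state gets ``None`` back and treats the hash as not yet computed.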


def attrib(
default=NOTHING,
validator=None,
@@ -523,34 +548,6 @@ def _patch_original_class(self):
for name, value in self._cls_dict.items():
setattr(cls, name, value)

# Attach __setstate__. This is necessary to clear the hash code
# cache on deserialization. See issue
# https://github.com/python-attrs/attrs/issues/482 .
# Note that this code only handles setstate for dict classes.
# For slotted classes, see similar code in _create_slots_class .
if self._cache_hash:
existing_set_state_method = getattr(cls, "__setstate__", None)
if existing_set_state_method:
raise NotImplementedError(
"Currently you cannot use hash caching if "
"you specify your own __setstate__ method."
"See https://github.com/python-attrs/attrs/issues/494 ."
)

# Clears the cached hash state on serialization; for frozen
# classes we need to bypass the class's setattr method.
if self._frozen:

def cache_hash_set_state(chss_self, _):
object.__setattr__(chss_self, _hash_cache_field, None)

else:

def cache_hash_set_state(chss_self, _):
setattr(chss_self, _hash_cache_field, None)

cls.__setstate__ = cache_hash_set_state

return cls

def _create_slots_class(self):
@@ -612,11 +609,10 @@ def slots_setstate(self, state):
__bound_setattr = _obj_setattr.__get__(self, Attribute)
for name, value in zip(state_attr_names, state):
__bound_setattr(name, value)
# Clearing the hash code cache on deserialization is needed
# because hash codes can change from run to run. See issue
# https://github.com/python-attrs/attrs/issues/482 .
# Note that this code only handles setstate for slotted classes.
# For dict classes, see similar code in _patch_original_class .

# The hash code cache is not included when the object is
# serialized, but it still needs to be initialized to None to
# indicate that the first call to __hash__ should be a cache miss.
if hash_caching_enabled:
__bound_setattr(_hash_cache_field, None)

@@ -1103,22 +1099,42 @@ def _make_hash(cls, attrs, frozen, cache_hash):
unique_filename = _generate_unique_filename(cls, "hash")
type_hash = hash(unique_filename)

method_lines = ["def __hash__(self):"]
hash_def = "def __hash__(self"
hash_func = "hash(("
closing_braces = "))"
if not cache_hash:
hash_def += "):"
else:
if not PY2:
hash_def += ", *"

hash_def += (
", _cache_wrapper="
+ "__import__('attr._make')._make._CacheHashWrapper):"
)
hash_func = "_cache_wrapper(" + hash_func
closing_braces += ")"

method_lines = [hash_def]

def append_hash_computation_lines(prefix, indent):
"""
Generate the code for actually computing the hash code.
Depending on the value of cache_hash, the result is either returned
directly or used to compute a value which is then cached.
"""

method_lines.extend(
[indent + prefix + "hash((", indent + " %d," % (type_hash,)]
[
indent + prefix + hash_func,
indent + " %d," % (type_hash,),
]
)

for a in attrs:
method_lines.append(indent + " self.%s," % a.name)

method_lines.append(indent + " ))")
method_lines.append(indent + " " + closing_braces)

if cache_hash:
method_lines.append(tab + "if self.%s is None:" % _hash_cache_field)
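Putting the pieces together, the emitted method for a ``cache_hash=True`` class with a single attribute ``x`` comes out roughly like this (a sketch only: the cache-field name and the type-hash constant are illustrative, and the assignment in the cached branch, which is elided from this hunk, goes through ``object.__setattr__`` on frozen classes):

# Sketch of the generated __hash__ (Python 3); "_attrs_cached_hash" and 42
# stand in for the real generated field name and per-class type hash.
def __hash__(
    self, *, _cache_wrapper=__import__('attr._make')._make._CacheHashWrapper
):
    if self._attrs_cached_hash is None:
        # The wrapper makes the cached int pickle/copy as None, so a
        # deep-copied instance takes this branch again on its first hash().
        self._attrs_cached_hash = _cache_wrapper(hash((
            42,      # type_hash: hash of a per-class unique filename
            self.x,
        )))
    return self._attrs_cached_hash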
159 changes: 94 additions & 65 deletions tests/test_dunders.py
@@ -310,6 +310,33 @@ def test_str_no_repr(self):
) == e.value.args[0]


# These are for use in TestAddHash.test_cache_hash_serialization_hash_cleared;
# they need to be defined at module level so they can be unpickled.
@attr.attrs(hash=True, cache_hash=False)
class HashCacheSerializationTestUncached(object):
foo_value = attr.ib()


@attr.attrs(hash=True, cache_hash=True)
class HashCacheSerializationTestCached(object):
foo_value = attr.ib()


@attr.attrs(slots=True, hash=True, cache_hash=True)
class HashCacheSerializationTestCachedSlots(object):
foo_value = attr.ib()


class IncrementingHasher(object):

Reviewer comment: This is a nice solution to the testing need below.

def __init__(self):
self.hash_value = 100

def __hash__(self):
rv = self.hash_value
self.hash_value += 1
return rv


class TestAddHash(object):
"""
Tests for `_add_hash`.
@@ -492,85 +519,87 @@ def __hash__(self):
assert 2 == uncached_instance.hash_counter.times_hash_called
assert 1 == cached_instance.hash_counter.times_hash_called

def test_cache_hash_serialization(self):
@pytest.mark.parametrize("cache_hash", [True, False])
@pytest.mark.parametrize("frozen", [True, False])
@pytest.mark.parametrize("slots", [True, False])
def test_copy_hash_cleared(self, cache_hash, frozen, slots):
"""
Tests that the hash cache is cleared on deserialization to fix
https://github.com/python-attrs/attrs/issues/482 .
Test that the default hash is recalculated after a copy operation.
"""

# First, check that our fix didn't break serialization without
# hash caching.
# We don't care about the result of this; we just want to make sure we
# can do it without exceptions.
hash(pickle.loads(pickle.dumps(HashCacheSerializationTestUncached())))

def assert_hash_code_not_cached_across_serialization(original):
# Now check our fix for #482 for when hash caching is enabled.
original_hash = hash(original)
round_tripped = pickle.loads(pickle.dumps(original))
# What we want to guard against is having a stale hash code
# when a field's hash code differs in a new interpreter after
# deserialization. This is tricky to test because we are,
# of course, still running in the same interpreter. So
# after deserialization we reach in and change the value of
# a field to simulate the field changing its hash code. We then
# check that the object's hash code changes, indicating that we
# don't have a stale hash code.
# This could fail in two ways: (1) pickle.loads could get the hash
# code of the deserialized value (triggering it to cache) before
# we alter the field value. This doesn't happen in our tested
# Python versions. (2) "foo" and "something different" could
# have a hash collision on this interpreter run. But this is
# extremely improbable and would just result in one buggy test run.
round_tripped.foo_string = "something different"
assert original_hash != hash(round_tripped)

# Slotted and dict classes implement __setstate__ differently,
# so we need to test both cases.
assert_hash_code_not_cached_across_serialization(
HashCacheSerializationTestCached()
)
assert_hash_code_not_cached_across_serialization(
HashCacheSerializationTestCachedSlots()
)
kwargs = dict(frozen=frozen, slots=slots, cache_hash=cache_hash,)

# Give it an explicit hash if we don't have an implicit one
if not frozen:
kwargs["hash"] = True

@attr.s(**kwargs)
class C(object):
x = attr.ib()

a = C(IncrementingHasher())
# Ensure that any hash cache would be calculated before copy
orig_hash = hash(a)
b = copy.deepcopy(a)

def test_caching_and_custom_setstate(self):
if kwargs["cache_hash"]:
# For cache_hash classes, this call is cached
assert orig_hash == hash(a)

assert orig_hash != hash(b)

@pytest.mark.parametrize(
"klass,cached",
[
(HashCacheSerializationTestUncached, False),
(HashCacheSerializationTestCached, True),
(HashCacheSerializationTestCachedSlots, True),
],
)
def test_cache_hash_serialization_hash_cleared(self, klass, cached):
"""
The combination of a custom __setstate__ and cache_hash=True is caught
with a helpful message.
Tests that the hash cache is cleared on deserialization to fix
https://github.com/python-attrs/attrs/issues/482 .

This is needed because we handle clearing the cache after
deserialization with a custom __setstate__. It is possible to make both
work, but it requires some thought about how to go about it, so it has
not yet been implemented.
This test is intended to guard against a stale hash code surviving
across serialization (which may cause problems when the hash value
is different in different interpreters).
"""
with pytest.raises(
NotImplementedError,
match="Currently you cannot use hash caching if you "
"specify your own __setstate__ method.",
):

@attr.attrs(hash=True, cache_hash=True)
class NoCacheHashAndCustomSetState(object):
def __setstate__(self, state):
pass
obj = klass(IncrementingHasher())
original_hash = hash(obj)
obj_rt = self._roundtrip_pickle(obj)

if cached:
assert original_hash == hash(obj)

# these are for use in TestAddHash.test_cache_hash_serialization
# they need to be out here so they can be un-pickled
@attr.attrs(hash=True, cache_hash=False)
class HashCacheSerializationTestUncached(object):
foo_string = attr.ib(default="foo")
assert original_hash != hash(obj_rt)

@pytest.mark.parametrize("frozen", [True, False])
def test_copy_two_arg_reduce(self, frozen):
"""
If __getstate__ returns None, the tuple returned by object.__reduce__
won't contain the state dictionary; this test ensures that the custom
__reduce__ generated when cache_hash=True works in that case.
"""

@attr.attrs(hash=True, cache_hash=True)
class HashCacheSerializationTestCached(object):
foo_string = attr.ib(default="foo")
@attr.s(frozen=frozen, cache_hash=True, hash=True)
class C(object):
x = attr.ib()

def __getstate__(self):
return None

@attr.attrs(slots=True, hash=True, cache_hash=True)
class HashCacheSerializationTestCachedSlots(object):
foo_string = attr.ib(default="foo")
# By the nature of this test it doesn't really create an object that's
# in a valid state - it basically does the equivalent of
# `object.__new__(C)`, so it doesn't make much sense to assert anything
# about the result of the copy. This test will just check that it
# doesn't raise an *error*.
copy.deepcopy(C(1))

def _roundtrip_pickle(self, obj):
pickle_str = pickle.dumps(obj)
return pickle.loads(pickle_str)


class TestAddInit(object):
Expand Down
60 changes: 55 additions & 5 deletions tests/test_make.py
@@ -1466,16 +1466,66 @@ class C2(C):

assert [C2] == C.__subclasses__()

def test_cache_hash_with_frozen_serializes(self):
def _get_copy_kwargs(include_slots=True):
"""
Frozen classes with cache_hash should be serializable.
Generate a list of compatible attr.s arguments for the `copy` tests.
"""
options = ["frozen", "hash", "cache_hash"]

@attr.s(cache_hash=True, frozen=True)
if include_slots:
options.extend(["slots", "weakref_slot"])

out_kwargs = []
for args in itertools.product([True, False], repeat=len(options)):
kwargs = dict(zip(options, args))

kwargs["hash"] = kwargs["hash"] or None

if kwargs["cache_hash"] and not (
kwargs["frozen"] or kwargs["hash"]
):
continue

out_kwargs.append(kwargs)

return out_kwargs
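
For orientation, the helper yields kwargs dicts like the following (hand-expanded; ``hash=False`` is normalized to ``None``, and any combination where ``cache_hash=True`` has neither ``frozen`` nor ``hash`` set is skipped, since attr.s rejects it):

# Two combinations _get_copy_kwargs() yields, and one it filters out:
{"frozen": True, "hash": True, "cache_hash": True, "slots": True, "weakref_slot": True}
{"frozen": False, "hash": None, "cache_hash": False, "slots": False, "weakref_slot": False}
# skipped: {"frozen": False, "hash": None, "cache_hash": True, ...}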

@pytest.mark.parametrize("kwargs", _get_copy_kwargs())
def test_copy(self, kwargs):
"""
Ensure that an attrs class can be copied successfully.
"""

@attr.s(eq=True, **kwargs)
class C(object):
pass
x = attr.ib()

a = C(1)
b = copy.deepcopy(a)

assert a == b

@pytest.mark.parametrize("kwargs", _get_copy_kwargs(include_slots=False))
def test_copy_custom_setstate(self, kwargs):
"""
Ensure that non-slots classes respect a custom __setstate__.
"""

@attr.s(eq=True, **kwargs)
class C(object):
x = attr.ib()

def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
state["x"] *= 5
self.__dict__.update(state)

expected = C(25)
actual = copy.copy(C(5))

copy.deepcopy(C())
assert actual == expected


class TestMakeOrder: