Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add get_hash, make DAGs deterministic #457

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
from pytato.loopy import LoopyCall
from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method
from pytools.persistent_dict import KeyBuilder

from pymbolic.mapper.persistent_hash import PersistentHashWalkMapper

from immutables import Map

if TYPE_CHECKING:
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
Expand All @@ -51,6 +56,8 @@

.. autofunction:: get_num_call_sites

.. autofunction:: get_hash

.. autoclass:: DirectPredecessorsGetter
"""

Expand Down Expand Up @@ -453,4 +460,80 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int:

# }}}


# {{{ get_hash

class PytatoKeyBuilder(KeyBuilder):
"""A custom :class:`pytools.persistent_dict.KeyBuilder` subclass
for objects within :mod:`pytato`.
"""

update_for_list = KeyBuilder.update_for_tuple
update_for_set = KeyBuilder.update_for_frozenset

def update_for_dict(self, key_hash: Any, key: Any) -> None:
from pytools import unordered_hash
unordered_hash(
key_hash,
(self.rec(self.new_hash(), # type: ignore[misc]
(k, v)).digest() # type: ignore[no-untyped-call]
for k, v in key.items()))

update_for_defaultdict = update_for_dict

def update_for_ndarray(self, key_hash: Any, key: Any) -> None:
self.rec(key_hash, hash(key.data.tobytes())) # type: ignore[no-untyped-call]

def update_for_frozenset(self, key_hash: Any, key: Any) -> None:
for set_key in sorted(key,
key=lambda obj: type(obj).__name__ + str(obj)):
self.rec(key_hash, set_key) # type: ignore[no-untyped-call]

def update_for_BasicSet(self, key_hash: Any, key: Any) -> None:
from islpy import Printer
prn = Printer.to_str(key.get_ctx())
getattr(prn, "print_"+key._base_name)(key)
key_hash.update(prn.get_str().encode("utf8"))

def update_for_Map(self, key_hash: Any, key: Any) -> None:
import islpy as isl
if isinstance(key, Map):
self.update_for_dict(key_hash, key)
elif isinstance(key, isl.Map):
self.update_for_BasicSet(key_hash, key)
else:
raise AssertionError()

def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None:
if key is None:
self.update_for_NoneType(key_hash, key) # type: ignore[no-untyped-call]
else:
PersistentHashWalkMapper(key_hash)(key)

def update_for_Reduce(self, key_hash: Any, key: Any) -> None:
self.rec(key_hash, hash(key)) # type: ignore[no-untyped-call]

update_for_Product = update_for_pymbolic_expression # noqa: N815
update_for_Sum = update_for_pymbolic_expression # noqa: N815
update_for_If = update_for_pymbolic_expression # noqa: N815
update_for_LogicalOr = update_for_pymbolic_expression # noqa: N815
update_for_Call = update_for_pymbolic_expression # noqa: N815
update_for_Comparison = update_for_pymbolic_expression # noqa: N815
update_for_Quotient = update_for_pymbolic_expression # noqa: N815
update_for_Power = update_for_pymbolic_expression # noqa: N815
update_for_PMap = update_for_dict # noqa: N815


def get_hash(outputs: Union[Array, DictOfNamedArrays]) -> str:
"""Returns a hash of the DAG *outputs*."""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

hm = PytatoKeyBuilder()

return hm(outputs) # type: ignore[no-any-return]

# }}}

# vim: fdm=marker
16 changes: 13 additions & 3 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ def copy(self: ArrayT, **kwargs: Any) -> ArrayT:
def _with_new_tags(self: ArrayT, tags: FrozenSet[Tag]) -> ArrayT:
return attrs.evolve(self, tags=tags)

def update_persistent_hash(self, key_hash: Any, key_builder: Any) -> None:
key_builder.rec(key_hash, hash(self))

if TYPE_CHECKING:
@property
def shape(self) -> ShapeType:
Expand Down Expand Up @@ -1675,10 +1678,17 @@ def _is_eq_valid(self) -> bool:
# and valid by returning True
return True

@memoize_method
def __hash__(self) -> int:
# It would be better to hash the data, but we have no way of getting to
# it.
return id(self)
import hashlib

if hasattr(self.data, "get"):
d = self.data.get()
else:
d = self.data

return hash((hashlib.sha256(d).hexdigest(), self._shape,
self.axes, Taggable.__hash__(self)))

@property
def shape(self) -> ShapeType:
Expand Down
28 changes: 28 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,34 @@ def test_dot_visualizers():
# }}}


def test_get_hash():
from pytato.analysis import get_hash

axis_len = 5

seen_hashes = set()

for i in range(100):
rdagc1 = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=False)
rdagc2 = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=False)
rdagc3 = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=False)

dag1 = make_random_dag(rdagc1)
dag2 = make_random_dag(rdagc2)
dag3 = make_random_dag(rdagc3)

h1 = get_hash(dag1)
h2 = get_hash(dag2)
h3 = get_hash(dag3)

assert h1 == h2 == h3
assert h1 not in seen_hashes
seen_hashes.add(h1)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down