From df96738458597f48f4d7b4daf802dceba5049149 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 18 Sep 2023 15:57:01 -0500 Subject: [PATCH 1/7] add get_hash --- pytato/analysis/__init__.py | 34 ++++++++++++++++++++++++++++++++++ test/test_pytato.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d0fd6ef1e..bf498d74e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -51,6 +51,8 @@ .. autofunction:: get_num_call_sites +.. autofunction:: get_hash + .. autoclass:: DirectPredecessorsGetter """ @@ -453,4 +455,36 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} + +# {{{ get_hash + +class HashMapper(CachedWalkMapper): + """ + A mapper that generates a hash for a given DAG. + """ + def __init__(self) -> None: + super().__init__() + import hashlib + self.hash = hashlib.sha256() + + def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: + return expr + + def post_visit(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None: + self.hash.update(str(hash(expr)).encode("ascii")) + + +def get_hash(outputs: Union[Array, DictOfNamedArrays]) -> str: + """Returns a hash of the DAG *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + hm = HashMapper() + hm(outputs) + + return hm.hash.hexdigest() + +# }}} + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 98393b95b..dfd1bc307 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1116,6 +1116,34 @@ def test_dot_visualizers(): # }}} +def test_get_hash(): + from pytato.analysis import get_hash + + axis_len = 5 + + seen_hashes = set() + + for i in range(100): + rdagc1 = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=False) + rdagc2 = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=False) + rdagc3 = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=False) + + dag1 = make_random_dag(rdagc1) + dag2 = make_random_dag(rdagc2) + dag3 = make_random_dag(rdagc3) + + h1 = get_hash(dag1) + h2 = get_hash(dag2) + h3 = get_hash(dag3) + + assert h1 == h2 == h3 + assert h1 not in seen_hashes + seen_hashes.add(h1) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From 8b8e66f123dcfa131b79216c1148eb3306bf2474 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 18 Sep 2023 16:21:56 -0500 Subject: [PATCH 2/7] make DataWrapper hash stable --- pytato/array.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 1271631fe..990b5f8ac 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1675,10 +1675,17 @@ def _is_eq_valid(self) -> bool: # and valid by returning True return True + @memoize_method def __hash__(self) -> int: - # It would be better to hash the data, but we have no way of getting to - # it. - return id(self) + import hashlib + + if hasattr(self.data, "get"): + d = self.data.get() + else: + d = self.data + + return hash((hashlib.sha256(d).hexdigest(), self._shape, + self.axes, Taggable.__hash__(self))) @property def shape(self) -> ShapeType: From ef56356c8a8a0e4bbb675d62a55c62fdcea79166 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 18 Sep 2023 16:55:43 -0500 Subject: [PATCH 3/7] add simple cross-invocation hash stability test --- test/test_pytato.py | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index dfd1bc307..a84e971a3 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1144,8 +1144,47 @@ def test_get_hash(): seen_hashes.add(h1) +def run_test_with_hashseed(f, *args): + from pickle import dumps + from base64 import b64encode + + from subprocess import check_call + import os + + os.environ["RUN_WITH_SET_PYTHONHASHSEED"] = "1" + os.environ["PYTHONHASHSEED"] = "1" + os.environ["INVOCATION_INFO"] = b64encode(dumps((f, args))).decode() + + check_call([sys.executable, __file__]) + + +def run_test_with_hashseed_inner(): + from pickle import loads + from base64 import b64decode + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) + + f(*args) + + +def test_get_hash_stable_with_hashseed(): + run_test_with_hashseed(_do_test_get_hash_stable_with_hashseed) + + +def _do_test_get_hash_stable_with_hashseed(): + rdagc = RandomDAGContext(np.random.default_rng(seed=42), + axis_len=5, use_numpy=False) + + dag = make_random_dag(rdagc) + assert hash(dag) == 6313690382525417492 + +# }}} + + if __name__ == "__main__": - if len(sys.argv) > 1: + import os + if "RUN_WITH_SET_PYTHONHASHSEED" in os.environ: + run_test_with_hashseed_inner() + elif len(sys.argv) > 1: exec(sys.argv[1]) else: from pytest import main From 1bfc5a260c310c8cb967c8899aea632de1703620 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 21 Sep 2023 15:12:24 -0500 Subject: [PATCH 4/7] create PytatoKeyBuilder --- pytato/analysis/__init__.py | 84 ++++++++++++++++++++++++++++++------- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index bf498d74e..5e9fed715 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -36,6 +36,11 @@ from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method +from pytools.persistent_dict import KeyBuilder + +from pymbolic.mapper.persistent_hash import PersistentHashWalkMapper + +from immutables import Map if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder @@ -458,20 +463,72 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # {{{ get_hash -class HashMapper(CachedWalkMapper): - """ - A mapper that generates a hash for a given DAG. +class PytatoKeyBuilder(KeyBuilder): + """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass + for objects within :mod:`pytato`. """ - def __init__(self) -> None: - super().__init__() - import hashlib - self.hash = hashlib.sha256() - def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: - return expr + update_for_list = KeyBuilder.update_for_tuple + update_for_set = KeyBuilder.update_for_frozenset + + def update_for_dict(self, key_hash, key): + from pytools import unordered_hash + unordered_hash( + key_hash, + (self.rec(self.new_hash(), (k, v)).digest() + for k, v in key.items())) + + update_for_defaultdict = update_for_dict + + def update_for_ndarray(self, key_hash, key): + self.rec(key_hash, hash(key.data.tobytes())) + + def update_for_frozenset(self, key_hash, key): + for set_key in sorted(key, + key=lambda obj: type(obj).__name__ + str(obj)): + self.rec(key_hash, set_key) + + def update_for_BasicSet(self, key_hash, key): # noqa + from islpy import Printer + prn = Printer.to_str(key.get_ctx()) + getattr(prn, "print_"+key._base_name)(key) + key_hash.update(prn.get_str().encode("utf8")) + + def update_for_Map(self, key_hash, key): # noqa + import islpy as isl + if isinstance(key, Map): + self.update_for_dict(key_hash, key) + elif isinstance(key, isl.Map): + self.update_for_BasicSet(key_hash, key) + else: + raise AssertionError() + + def update_for_pymbolic_expression(self, key_hash, key): + if key is None: + self.update_for_NoneType(key_hash, key) + else: + PersistentHashWalkMapper(key_hash)(key) + + def update_for_Product(self, key_hash, key): + PersistentHashWalkMapper(key_hash)(key) + + def update_for_Array(self, key_hash, key): + self.update_for_ndarray(key_hash, key.get()) + + def update_for_int64(self, key_hash, key): + self.rec(key_hash, int(key)) + + def update_for_Reduce(self, key_hash, key): + self.rec(key_hash, hash(key)) - def post_visit(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None: - self.hash.update(str(hash(expr)).encode("ascii")) + update_for_Sum = update_for_Product + update_for_If = update_for_Product + update_for_LogicalOr = update_for_Product + update_for_Call = update_for_Product + update_for_Comparison = update_for_Product + update_for_Quotient = update_for_Product + update_for_Power = update_for_Product + update_for_PMap = update_for_dict # noqa: N815 def get_hash(outputs: Union[Array, DictOfNamedArrays]) -> str: @@ -480,10 +537,9 @@ def get_hash(outputs: Union[Array, DictOfNamedArrays]) -> str: from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - hm = HashMapper() - hm(outputs) + hm = PytatoKeyBuilder() - return hm.hash.hexdigest() + return hm(outputs) # }}} From 97966c5f98f4170dc9603c0990e9354865f66708 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 21 Sep 2023 16:35:44 -0500 Subject: [PATCH 5/7] Revert "add simple cross-invocation hash stability test" This reverts commit ef56356c8a8a0e4bbb675d62a55c62fdcea79166. --- test/test_pytato.py | 41 +---------------------------------------- 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index a84e971a3..dfd1bc307 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1144,47 +1144,8 @@ def test_get_hash(): seen_hashes.add(h1) -def run_test_with_hashseed(f, *args): - from pickle import dumps - from base64 import b64encode - - from subprocess import check_call - import os - - os.environ["RUN_WITH_SET_PYTHONHASHSEED"] = "1" - os.environ["PYTHONHASHSEED"] = "1" - os.environ["INVOCATION_INFO"] = b64encode(dumps((f, args))).decode() - - check_call([sys.executable, __file__]) - - -def run_test_with_hashseed_inner(): - from pickle import loads - from base64 import b64decode - f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) - - f(*args) - - -def test_get_hash_stable_with_hashseed(): - run_test_with_hashseed(_do_test_get_hash_stable_with_hashseed) - - -def _do_test_get_hash_stable_with_hashseed(): - rdagc = RandomDAGContext(np.random.default_rng(seed=42), - axis_len=5, use_numpy=False) - - dag = make_random_dag(rdagc) - assert hash(dag) == 6313690382525417492 - -# }}} - - if __name__ == "__main__": - import os - if "RUN_WITH_SET_PYTHONHASHSEED" in os.environ: - run_test_with_hashseed_inner() - elif len(sys.argv) > 1: + if len(sys.argv) > 1: exec(sys.argv[1]) else: from pytest import main From ff4d0410ebe653e2d56d87d79f3415190af6276e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 21 Sep 2023 17:14:02 -0500 Subject: [PATCH 6/7] cleanups for KeyBuilder --- pytato/analysis/__init__.py | 24 ++++++++---------------- pytato/array.py | 3 +++ 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5e9fed715..35e3c3f31 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -509,25 +509,17 @@ def update_for_pymbolic_expression(self, key_hash, key): else: PersistentHashWalkMapper(key_hash)(key) - def update_for_Product(self, key_hash, key): - PersistentHashWalkMapper(key_hash)(key) - - def update_for_Array(self, key_hash, key): - self.update_for_ndarray(key_hash, key.get()) - - def update_for_int64(self, key_hash, key): - self.rec(key_hash, int(key)) - def update_for_Reduce(self, key_hash, key): self.rec(key_hash, hash(key)) - update_for_Sum = update_for_Product - update_for_If = update_for_Product - update_for_LogicalOr = update_for_Product - update_for_Call = update_for_Product - update_for_Comparison = update_for_Product - update_for_Quotient = update_for_Product - update_for_Power = update_for_Product + update_for_Product = update_for_pymbolic_expression + update_for_Sum = update_for_pymbolic_expression + update_for_If = update_for_pymbolic_expression + update_for_LogicalOr = update_for_pymbolic_expression + update_for_Call = update_for_pymbolic_expression + update_for_Comparison = update_for_pymbolic_expression + update_for_Quotient = update_for_pymbolic_expression + update_for_Power = update_for_pymbolic_expression update_for_PMap = update_for_dict # noqa: N815 diff --git a/pytato/array.py b/pytato/array.py index 990b5f8ac..550900e61 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -459,6 +459,9 @@ def copy(self: ArrayT, **kwargs: Any) -> ArrayT: def _with_new_tags(self: ArrayT, tags: FrozenSet[Tag]) -> ArrayT: return attrs.evolve(self, tags=tags) + def update_persistent_hash(self, key_hash: Any, key_builder: Any) -> None: + key_builder.rec(key_hash, hash(self)) + if TYPE_CHECKING: @property def shape(self) -> ShapeType: From 0081ce4c9f29dae62d6a972effd6ec2f97bc6921 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 21 Sep 2023 17:35:35 -0500 Subject: [PATCH 7/7] lint --- pytato/analysis/__init__.py | 45 +++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 35e3c3f31..76bcbd360 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -471,30 +471,31 @@ class PytatoKeyBuilder(KeyBuilder): update_for_list = KeyBuilder.update_for_tuple update_for_set = KeyBuilder.update_for_frozenset - def update_for_dict(self, key_hash, key): + def update_for_dict(self, key_hash: Any, key: Any) -> None: from pytools import unordered_hash unordered_hash( key_hash, - (self.rec(self.new_hash(), (k, v)).digest() + (self.rec(self.new_hash(), # type: ignore[misc] + (k, v)).digest() # type: ignore[no-untyped-call] for k, v in key.items())) update_for_defaultdict = update_for_dict - def update_for_ndarray(self, key_hash, key): - self.rec(key_hash, hash(key.data.tobytes())) + def update_for_ndarray(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, hash(key.data.tobytes())) # type: ignore[no-untyped-call] - def update_for_frozenset(self, key_hash, key): + def update_for_frozenset(self, key_hash: Any, key: Any) -> None: for set_key in sorted(key, key=lambda obj: type(obj).__name__ + str(obj)): - self.rec(key_hash, set_key) + self.rec(key_hash, set_key) # type: ignore[no-untyped-call] - def update_for_BasicSet(self, key_hash, key): # noqa + def update_for_BasicSet(self, key_hash: Any, key: Any) -> None: from islpy import Printer prn = Printer.to_str(key.get_ctx()) getattr(prn, "print_"+key._base_name)(key) key_hash.update(prn.get_str().encode("utf8")) - def update_for_Map(self, key_hash, key): # noqa + def update_for_Map(self, key_hash: Any, key: Any) -> None: import islpy as isl if isinstance(key, Map): self.update_for_dict(key_hash, key) @@ -503,23 +504,23 @@ def update_for_Map(self, key_hash, key): # noqa else: raise AssertionError() - def update_for_pymbolic_expression(self, key_hash, key): + def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: if key is None: - self.update_for_NoneType(key_hash, key) + self.update_for_NoneType(key_hash, key) # type: ignore[no-untyped-call] else: PersistentHashWalkMapper(key_hash)(key) - def update_for_Reduce(self, key_hash, key): - self.rec(key_hash, hash(key)) - - update_for_Product = update_for_pymbolic_expression - update_for_Sum = update_for_pymbolic_expression - update_for_If = update_for_pymbolic_expression - update_for_LogicalOr = update_for_pymbolic_expression - update_for_Call = update_for_pymbolic_expression - update_for_Comparison = update_for_pymbolic_expression - update_for_Quotient = update_for_pymbolic_expression - update_for_Power = update_for_pymbolic_expression + def update_for_Reduce(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, hash(key)) # type: ignore[no-untyped-call] + + update_for_Product = update_for_pymbolic_expression # noqa: N815 + update_for_Sum = update_for_pymbolic_expression # noqa: N815 + update_for_If = update_for_pymbolic_expression # noqa: N815 + update_for_LogicalOr = update_for_pymbolic_expression # noqa: N815 + update_for_Call = update_for_pymbolic_expression # noqa: N815 + update_for_Comparison = update_for_pymbolic_expression # noqa: N815 + update_for_Quotient = update_for_pymbolic_expression # noqa: N815 + update_for_Power = update_for_pymbolic_expression # noqa: N815 update_for_PMap = update_for_dict # noqa: N815 @@ -531,7 +532,7 @@ def get_hash(outputs: Union[Array, DictOfNamedArrays]) -> str: hm = PytatoKeyBuilder() - return hm(outputs) + return hm(outputs) # type: ignore[no-any-return] # }}}