diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 572cbe028..1f5e1cab3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,7 @@ jobs: python-version: '3.x' - name: "Main Script" run: | + export EXTRA_INSTALL="git+https://github.com/inducer/arraycontext" curl -L -O https://tiker.net/ci-support-v0 . ci-support-v0 build_py_project_in_conda_env @@ -54,6 +55,7 @@ jobs: python-version: '3.x' - name: "Main Script" run: | + export EXTRA_INSTALL="git+https://github.com/inducer/arraycontext" curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 build_py_project_in_conda_env diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 988c11a16..a274a335a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,6 +29,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method @@ -565,4 +566,31 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int: # }}} + +# {{{ PytatoKeyBuilder + +class PytatoKeyBuilder(LoopyKeyBuilder): + """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass + for objects within :mod:`pytato`. + """ + # The types below aren't immutable in general, but in the context of + # pytato, they are used as such. + + def update_for_ndarray(self, key_hash: Any, key: Any) -> None: + import numpy as np + assert isinstance(key, np.ndarray) + self.rec(key_hash, key.data.tobytes()) + + def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray + assert isinstance(key, TaggableCLArray) + self.rec(key_hash, key.get()) + + def update_for_Array(self, key_hash: Any, key: Any) -> None: + from pyopencl.array import Array + assert isinstance(key, Array) + self.rec(key_hash, key.get()) + +# }}} + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 4a29ae13c..38e2fda5e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1364,6 +1364,99 @@ def test_dot_visualizers(): # }}} +# {{{ Test PytatoKeyBuilder + +def run_test_with_new_python_invocation(f, *args, extra_env_vars=None) -> None: + import os + if extra_env_vars is None: + extra_env_vars = {} + + from base64 import b64encode + from pickle import dumps + from subprocess import check_call + + env_vars = { + "INVOCATION_INFO": b64encode(dumps((f, args))).decode(), + } + env_vars.update(extra_env_vars) + + my_env = os.environ.copy() + my_env.update(env_vars) + + check_call([sys.executable, __file__], env=my_env) + + +def run_test_with_new_python_invocation_inner() -> None: + import os + from base64 import b64decode + from pickle import loads + + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) + + f(*args) + + +def test_persistent_hashing_and_persistent_dict() -> None: + import shutil + import tempfile + + from pytools.persistent_dict import ReadOnlyEntryError, WriteOncePersistentDict + + from pytato.analysis import PytatoKeyBuilder + + try: + tmpdir = tempfile.mkdtemp() + + pkb = PytatoKeyBuilder() + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir, + safe_sync=False) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=5, use_numpy=True) + + dag = make_random_dag(rdagc) + + # Make sure the PytatoKeyBuilder can handle 'dag' + pd[dag] = 42 + + # Make sure that the key stays the same within the same Python invocation + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + + # Make sure that the key stays the same across Python invocations + run_test_with_new_python_invocation( + _test_persistent_hashing_and_persistent_dict_stage2, tmpdir) + finally: + shutil.rmtree(tmpdir) + + +def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None: + from pytools.persistent_dict import ReadOnlyEntryError, WriteOncePersistentDict + + from pytato.analysis import PytatoKeyBuilder + pkb = PytatoKeyBuilder() + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir, + safe_sync=False) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=5, use_numpy=True) + + dag = make_random_dag(rdagc) + + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + +# }}} + + def test_numpy_type_promotion_with_pytato_arrays(): class NotReallyAnArray: @property @@ -1427,7 +1520,10 @@ def test_pickling_hash(): if __name__ == "__main__": - if len(sys.argv) > 1: + import os + if "INVOCATION_INFO" in os.environ: + run_test_with_new_python_invocation_inner() + elif len(sys.argv) > 1: exec(sys.argv[1]) else: from pytest import main