Skip to content

Commit

Permalink
add PytatoKeyBuilder, persistent_dict test (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Dec 4, 2024
1 parent d57cb0c commit 006510c
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
python-version: '3.x'
- name: "Main Script"
run: |
export EXTRA_INSTALL="git+https://github.com/inducer/arraycontext"
curl -L -O https://tiker.net/ci-support-v0
. ci-support-v0
build_py_project_in_conda_env
Expand All @@ -54,6 +55,7 @@ jobs:
python-version: '3.x'
- name: "Main Script"
run: |
export EXTRA_INSTALL="git+https://github.com/inducer/arraycontext"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_conda_env
Expand Down
28 changes: 28 additions & 0 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from loopy.tools import LoopyKeyBuilder
from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

Expand Down Expand Up @@ -565,4 +566,31 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int:

# }}}


# {{{ PytatoKeyBuilder

class PytatoKeyBuilder(LoopyKeyBuilder):
"""A custom :class:`pytools.persistent_dict.KeyBuilder` subclass
for objects within :mod:`pytato`.
"""
# The types below aren't immutable in general, but in the context of
# pytato, they are used as such.

def update_for_ndarray(self, key_hash: Any, key: Any) -> None:
import numpy as np
assert isinstance(key, np.ndarray)
self.rec(key_hash, key.data.tobytes())

def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None:
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
assert isinstance(key, TaggableCLArray)
self.rec(key_hash, key.get())

def update_for_Array(self, key_hash: Any, key: Any) -> None:
from pyopencl.array import Array
assert isinstance(key, Array)
self.rec(key_hash, key.get())

# }}}

# vim: fdm=marker
98 changes: 97 additions & 1 deletion test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,99 @@ def test_dot_visualizers():
# }}}


# {{{ Test PytatoKeyBuilder

def run_test_with_new_python_invocation(f, *args, extra_env_vars=None) -> None:
import os
if extra_env_vars is None:
extra_env_vars = {}

from base64 import b64encode
from pickle import dumps
from subprocess import check_call

env_vars = {
"INVOCATION_INFO": b64encode(dumps((f, args))).decode(),
}
env_vars.update(extra_env_vars)

my_env = os.environ.copy()
my_env.update(env_vars)

check_call([sys.executable, __file__], env=my_env)


def run_test_with_new_python_invocation_inner() -> None:
import os
from base64 import b64decode
from pickle import loads

f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode()))

f(*args)


def test_persistent_hashing_and_persistent_dict() -> None:
import shutil
import tempfile

from pytools.persistent_dict import ReadOnlyEntryError, WriteOncePersistentDict

from pytato.analysis import PytatoKeyBuilder

try:
tmpdir = tempfile.mkdtemp()

pkb = PytatoKeyBuilder()

pd = WriteOncePersistentDict("test_persistent_dict",
key_builder=pkb,
container_dir=tmpdir,
safe_sync=False)

for i in range(100):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=5, use_numpy=True)

dag = make_random_dag(rdagc)

# Make sure the PytatoKeyBuilder can handle 'dag'
pd[dag] = 42

# Make sure that the key stays the same within the same Python invocation
with pytest.raises(ReadOnlyEntryError):
pd[dag] = 42

# Make sure that the key stays the same across Python invocations
run_test_with_new_python_invocation(
_test_persistent_hashing_and_persistent_dict_stage2, tmpdir)
finally:
shutil.rmtree(tmpdir)


def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None:
from pytools.persistent_dict import ReadOnlyEntryError, WriteOncePersistentDict

from pytato.analysis import PytatoKeyBuilder
pkb = PytatoKeyBuilder()

pd = WriteOncePersistentDict("test_persistent_dict",
key_builder=pkb,
container_dir=tmpdir,
safe_sync=False)

for i in range(100):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=5, use_numpy=True)

dag = make_random_dag(rdagc)

with pytest.raises(ReadOnlyEntryError):
pd[dag] = 42

# }}}


def test_numpy_type_promotion_with_pytato_arrays():
class NotReallyAnArray:
@property
Expand Down Expand Up @@ -1427,7 +1520,10 @@ def test_pickling_hash():


if __name__ == "__main__":
if len(sys.argv) > 1:
import os
if "INVOCATION_INFO" in os.environ:
run_test_with_new_python_invocation_inner()
elif len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
Expand Down

0 comments on commit 006510c

Please sign in to comment.