Skip to content

Commit

Permalink
test push indirections
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Mar 17, 2024
1 parent f02ba1c commit 9ae9474
Showing 1 changed file with 142 additions and 1 deletion.
143 changes: 142 additions & 1 deletion test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa

import pytato as pt
from testlib import assert_allclose_to_numpy, get_random_pt_dag
from testlib import assert_allclose_to_numpy, get_random_pt_dag, auto_test_vs_ref
import pymbolic.primitives as p


Expand Down Expand Up @@ -2002,6 +2002,147 @@ def call_bar(tracer, x, y):
np.testing.assert_allclose(result_out[k], expect_out[k])


def _evaluator_for_indirection_folding(cl_ctx, dictofarys):
from immutables import Map
cq = cl.CommandQueue(cl_ctx)
_, out_dict = pt.generate_loopy(dictofarys)(cq)
return Map({k: pt.make_data_wrapper(v) for k, v in out_dict.items()})


@pytest.mark.parametrize("fold_constant_idxs", (False, True))
def test_push_indirections_0(ctx_factory, fold_constant_idxs):
from testlib import (are_all_indexees_materialized_nodes,
are_all_indexer_arrays_datawrappers)

cl_ctx = cl.create_some_context()
rng = np.random.default_rng(0)
x_np = rng.random((10, 4))
map1_np = rng.integers(0, 10, size=17)
map2_np = rng.integers(0, 17, size=29)

x = pt.make_data_wrapper(x_np)
map1 = pt.make_data_wrapper(map1_np)
map2 = pt.make_data_wrapper(map2_np)

y = 3.14 * ((42*((2*x)[map1]))[map2])
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
)

if fold_constant_idxs:
assert not are_all_indexer_arrays_datawrappers(y_transformed)
y_transformed = pt.fold_constant_indirections(
y_transformed,
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
doa)
)
assert are_all_indexer_arrays_datawrappers(y_transformed)

auto_test_vs_ref(cl_ctx, y, y_transformed)
assert are_all_indexees_materialized_nodes(y_transformed)


@pytest.mark.parametrize("fold_constant_idxs", (False, True))
def test_push_indirections_1(ctx_factory, fold_constant_idxs):
from testlib import (are_all_indexees_materialized_nodes,
are_all_indexer_arrays_datawrappers)

cl_ctx = cl.create_some_context()
rng = np.random.default_rng(0)
x_np = rng.random((100, 4))
map1_np = rng.integers(0, 20, size=17)

x = pt.make_data_wrapper(x_np)
map1 = pt.make_data_wrapper(map1_np)

y = 3.14 * ((42*((2*x)[2:92:3, :3]))[map1])
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
)

if fold_constant_idxs:
assert not are_all_indexer_arrays_datawrappers(y_transformed)
y_transformed = pt.fold_constant_indirections(
y_transformed,
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
doa)
)
assert are_all_indexer_arrays_datawrappers(y_transformed)

auto_test_vs_ref(cl_ctx, y, y_transformed)
assert are_all_indexees_materialized_nodes(y_transformed)


@pytest.mark.parametrize("fold_constant_idxs", (False, True))
def test_push_indirections_2(ctx_factory, fold_constant_idxs):
from testlib import (are_all_indexees_materialized_nodes,
are_all_indexer_arrays_datawrappers)

cl_ctx = cl.create_some_context()
rng = np.random.default_rng(0)
x_np = rng.random((100, 10))
map1_np = rng.integers(0, 20, size=17)
map2_np = rng.integers(0, 4, size=29)

x = pt.make_data_wrapper(x_np)
map1 = pt.make_data_wrapper(map1_np)
map2 = pt.make_data_wrapper(map2_np)

y = (1729*((3.14*((42*((2*x)[2:92:3, ::2]))[map1]))[map2]))[1:-3:2, 1:-2:7]
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
)

if fold_constant_idxs:
assert not are_all_indexer_arrays_datawrappers(y_transformed)
y_transformed = pt.fold_constant_indirections(
y_transformed,
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
doa)
)
assert are_all_indexer_arrays_datawrappers(y_transformed)

auto_test_vs_ref(cl_ctx, y, y_transformed)
assert are_all_indexees_materialized_nodes(y_transformed)


@pytest.mark.parametrize("fold_constant_idxs", (False, True))
def test_push_indirections_3(ctx_factory, fold_constant_idxs):
from testlib import (are_all_indexees_materialized_nodes,
are_all_indexer_arrays_datawrappers)

cl_ctx = cl.create_some_context()
rng = np.random.default_rng(0)
x_np = rng.random((10, 4))
map1_np = rng.integers(0, 10, size=17)
map2_np = rng.integers(0, 17, size=29)
map3_np = rng.integers(0, 4, size=60)
map4_np = rng.integers(0, 60, size=22)

x = pt.make_data_wrapper(x_np)
map1 = pt.make_data_wrapper(map1_np)
map2 = pt.make_data_wrapper(map2_np)
map3 = pt.make_data_wrapper(map3_np)
map4 = pt.make_data_wrapper(map4_np)

y = 3.14 * ((42*((2*x)[map1.reshape(-1, 1), map3]))[map2.reshape(-1, 1), map4])
y_transformed = pt.push_axis_indirections_towards_materialized_nodes(
pt.decouple_multi_axis_indirections_into_single_axis_indirections(y)
)

if fold_constant_idxs:
assert not are_all_indexer_arrays_datawrappers(y_transformed)
y_transformed = pt.fold_constant_indirections(
y_transformed,
lambda doa: _evaluator_for_indirection_folding(cl_ctx,
doa)
)
assert are_all_indexer_arrays_datawrappers(y_transformed)

auto_test_vs_ref(cl_ctx, y, y_transformed)
assert are_all_indexees_materialized_nodes(y_transformed)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit 9ae9474

Please sign in to comment.