Skip to content

Commit

Permalink
Remove temporary handling from domain inference
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Sep 3, 2024
1 parent fc4846f commit e8e679d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 97 deletions.
52 changes: 7 additions & 45 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from gt4py.next.common import Dimension
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union
from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, domain_union
from gt4py.next.iterator.transforms.trace_shifts import TraceShifts


Expand Down Expand Up @@ -209,31 +209,13 @@ def infer_expr(
raise ValueError(f"Unsupported expression: {expr}")


def _validate_temporary_usage(body: list[itir.Stmt], temporaries: list[str]):
assigned_targets = set()
for stmt in body:
assert isinstance(stmt, itir.SetAt) # TODO: extend for if-statements when they land
assert isinstance(
stmt.target, itir.SymRef
) # TODO: stmt.target can be an expr, e.g. make_tuple
if stmt.target.id in assigned_targets:
raise ValueError("Temporaries can only be used once within a program.")
if stmt.target.id in temporaries:
assigned_targets.add(stmt.target.id)


def infer_program(
program: itir.Program,
offset_provider: Dict[str, Dimension],
) -> itir.Program:
accessed_domains: dict[str, SymbolicDomain | None] = {}
transformed_set_ats: list[itir.SetAt] = []

temporaries: list[str] = [tmp.id for tmp in program.declarations]

# TODO(tehrengruber): disabled since it breaks with tuples
# _validate_temporary_usage(program.body, temporaries)

for set_at in reversed(program.body):
assert isinstance(set_at, itir.SetAt)
if isinstance(set_at.expr, itir.SymRef):
Expand All @@ -243,11 +225,8 @@ def infer_program(
assert isinstance(
set_at.target, itir.SymRef
) # TODO: stmt.target can be an expr, e.g. make_tuple
if set_at.target.id in temporaries:
# ignore temporaries as their domain is the `AUTO_DOMAIN` placeholder
assert set_at.domain == AUTO_DOMAIN
else:
accessed_domains[set_at.target.id] = SymbolicDomain.from_expr(set_at.domain)

accessed_domains[set_at.target.id] = SymbolicDomain.from_expr(set_at.domain)
transformed_call, current_accessed_domains = infer_expr(
set_at.expr, accessed_domains[set_at.target.id], offset_provider
)
Expand All @@ -264,31 +243,14 @@ def infer_program(

for field in current_accessed_domains:
if field in accessed_domains:
# multiple accesses to the same field -> compute union of accessed domains
if field in temporaries:
if accessed_domains[field] is None:
accessed_domains[field] = current_accessed_domains[field]
elif current_accessed_domains[field] is None:
accessed_domains[field] = accessed_domains[field]
else:
accessed_domains[field] = domain_union(
[accessed_domains[field], current_accessed_domains[field]] # type: ignore[list-item] # ensured by if condition
)
else:
# TODO(tehrengruber): if domain_ref is an external field the domain must
# already be larger. This should be checked, but would require additions
# to the IR.
pass
# TODO(tehrengruber): if domain_ref is an external field the domain must
# already be larger. This should be checked, but would require additions
# to the IR.
pass
else:
accessed_domains[field] = current_accessed_domains[field]

new_declarations = copy.deepcopy(program.declarations)
for temporary in new_declarations:
temporary.domain = (
SymbolicDomain.as_expr(accessed_domains[temporary.id]) # type: ignore[arg-type] # ensured by if condition
if accessed_domains[temporary.id] is not None
else None
)

return itir.Program(
id=program.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms.infer_domain import infer_program, infer_expr
from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, AUTO_DOMAIN
from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain
import pytest
from gt4py.eve.extended_typing import Dict
from gt4py.next.common import Dimension, DimensionKind
Expand Down Expand Up @@ -443,15 +443,15 @@ def test_program(offset_provider):
domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})
domain_tmp = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)})

params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]]
params = [im.sym(name) for name in ["in_field", "out_field"]]

testee = itir.Program(
id="forward_diff_with_tmp",
function_definitions=[],
params=params,
declarations=[itir.Temporary(id="tmp", domain=AUTO_DOMAIN, dtype=float_type)],
declarations=[itir.Temporary(id="tmp", domain=domain_tmp, dtype=float_type)],
body=[
itir.SetAt(expr=applied_as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")),
itir.SetAt(expr=applied_as_fieldop_tmp, domain=domain_tmp, target=im.ref("tmp")),
itir.SetAt(expr=applied_as_fieldop, domain=domain, target=im.ref("out_field")),
],
)
Expand Down Expand Up @@ -484,19 +484,19 @@ def test_program_two_tmps(offset_provider):
domain_tmp1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 13)})
domain_tmp2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)})

params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]]
params = [im.sym(name) for name in ["in_field", "out_field"]]

testee = itir.Program(
id="forward_diff_with_two_tmps",
function_definitions=[],
params=params,
declarations=[
itir.Temporary(id="tmp1", domain=AUTO_DOMAIN, dtype=float_type),
itir.Temporary(id="tmp2", domain=AUTO_DOMAIN, dtype=float_type),
itir.Temporary(id="tmp1", domain=domain_tmp1, dtype=float_type),
itir.Temporary(id="tmp2", domain=domain_tmp2, dtype=float_type),
],
body=[
itir.SetAt(expr=as_fieldop_tmp1, domain=AUTO_DOMAIN, target=im.ref("tmp1")),
itir.SetAt(expr=as_fieldop_tmp2, domain=AUTO_DOMAIN, target=im.ref("tmp2")),
itir.SetAt(expr=as_fieldop_tmp1, domain=domain_tmp1, target=im.ref("tmp1")),
itir.SetAt(expr=as_fieldop_tmp2, domain=domain_tmp2, target=im.ref("tmp2")),
itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")),
],
)
Expand All @@ -523,36 +523,6 @@ def test_program_two_tmps(offset_provider):
run_test_program(testee, expected, offset_provider)


@pytest.mark.xfail(
reason="this currently fails since _validate_temporary_usage is not called"
) # TODO
def test_program_ValueError(offset_provider):
with pytest.raises(ValueError, match=r"Temporaries can only be used once within a program."):
stencil = im.lambda_("arg0")(im.deref("arg0"))

as_fieldop_tmp = im.as_fieldop(stencil)(im.ref("in_field"))
as_fieldop = im.as_fieldop(stencil)(im.ref("tmp"))

domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})

params = [im.sym(name) for name in ["in_field", "out_field", "_gtmp_auto_domain"]]

infer_program(
itir.Program(
id="forward_diff_with_tmp",
function_definitions=[],
params=params,
declarations=[itir.Temporary(id="tmp", domain=AUTO_DOMAIN, dtype=float_type)],
body=[
itir.SetAt(expr=as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")),
itir.SetAt(expr=as_fieldop_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")),
itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")),
],
),
offset_provider,
)


def test_program_tree_tmps_two_inputs(offset_provider):
stencil = im.lambda_("arg0", "arg1")(
im.minus(im.deref(im.shift("Ioff", 1)("arg0")), im.deref("arg1"))
Expand All @@ -574,25 +544,22 @@ def test_program_tree_tmps_two_inputs(offset_provider):
domain_tmp2 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 12)})
domain_tmp3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})
domain_out = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})
params = [
im.sym(name)
for name in ["in_field1", "in_field2", "out_field1", "out_field2", "_gtmp_auto_domain"]
]
params = [im.sym(name) for name in ["in_field1", "in_field2", "out_field1", "out_field2"]]

testee = itir.Program(
id="differences_three_tmps_two_inputs",
function_definitions=[],
params=params,
declarations=[
itir.Temporary(id="tmp1", domain=AUTO_DOMAIN, dtype=float_type),
itir.Temporary(id="tmp2", domain=AUTO_DOMAIN, dtype=float_type),
itir.Temporary(id="tmp3", domain=AUTO_DOMAIN, dtype=float_type),
itir.Temporary(id="tmp1", domain=domain_tmp1, dtype=float_type),
itir.Temporary(id="tmp2", domain=domain_tmp2, dtype=float_type),
itir.Temporary(id="tmp3", domain=domain_tmp3, dtype=float_type),
],
body=[
itir.SetAt(expr=as_fieldop_tmp1, domain=AUTO_DOMAIN, target=im.ref("tmp1")),
itir.SetAt(expr=as_fieldop_tmp2, domain=AUTO_DOMAIN, target=im.ref("tmp2")),
itir.SetAt(expr=as_fieldop_tmp1, domain=domain_tmp1, target=im.ref("tmp1")),
itir.SetAt(expr=as_fieldop_tmp2, domain=domain_tmp2, target=im.ref("tmp2")),
itir.SetAt(expr=as_fieldop_out1, domain=domain_out, target=im.ref("out_field1")),
itir.SetAt(expr=as_fieldop_tmp3, domain=AUTO_DOMAIN, target=im.ref("tmp3")),
itir.SetAt(expr=as_fieldop_tmp3, domain=domain_tmp3, target=im.ref("tmp3")),
itir.SetAt(expr=as_fieldop_out2, domain=domain_out, target=im.ref("out_field2")),
],
)
Expand Down Expand Up @@ -912,15 +879,15 @@ def test_program_let(offset_provider):
domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})
domain_lm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 11)})

params = [im.sym(name) for name in ["in_field", "out_field", "outer", "_gtmp_auto_domain"]]
params = [im.sym(name) for name in ["in_field", "out_field", "outer"]]

testee = itir.Program(
id="forward_diff_with_tmp",
function_definitions=[],
params=params,
declarations=[itir.Temporary(id="tmp", domain=AUTO_DOMAIN, dtype=float_type)],
declarations=[itir.Temporary(id="tmp", domain=domain_lm1, dtype=float_type)],
body=[
itir.SetAt(expr=let_tmp, domain=AUTO_DOMAIN, target=im.ref("tmp")),
itir.SetAt(expr=let_tmp, domain=domain_lm1, target=im.ref("tmp")),
itir.SetAt(expr=as_fieldop, domain=domain, target=im.ref("out_field")),
],
)
Expand Down

0 comments on commit e8e679d

Please sign in to comment.