Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GTIR domain inference #1568

Merged
merged 41 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
717ea3f
Extend constant folding for literals and add respective test cases
SF-N Jun 28, 2024
7480402
Add GTIR domain inference for nested as_fieldops and respective test …
SF-N Jul 3, 2024
a1ef6d8
Simplify im.cartesian_domain, update test_domain_inference and run pr…
SF-N Jul 4, 2024
0eb0032
Update infer_domain wrt review comments and run pre-commit
SF-N Jul 4, 2024
087b19a
Add offset_provider fixture and pass it to infer_as_fieldop
SF-N Jul 4, 2024
bd0915c
Merge branch 'main' into GTIR_domain-inference
SF-N Jul 8, 2024
e71379e
Merge branch 'main' of github.com:GridTools/gt4py into GTIR_domain-in…
SF-N Jul 8, 2024
4305051
Merge branch 'GTIR_domain-inference' of github.com:SF-N/gt4py into GT…
SF-N Jul 8, 2024
ab367e2
Further refactor wrt review comments
SF-N Jul 8, 2024
8e495c1
Update SymbolicDomain.as_expr to make use of im.cartesian_domain and …
SF-N Jul 9, 2024
718be39
Remove comment
SF-N Jul 9, 2024
3fe3b5d
Add support for unstructured shifts
tehrengruber Jul 11, 2024
ed4fc32
Address review comments
SF-N Jul 22, 2024
f09494f
Start working on domain-inference for let-statements
SF-N Jul 24, 2024
7f8ad2a
Cleanup with Sara
tehrengruber Jul 24, 2024
af12b21
Format
tehrengruber Jul 24, 2024
c18e61f
Cleanup and remove work on let domain inference
SF-N Jul 25, 2024
709b5aa
Move _translate to SymbolicDomain.translate
SF-N Jul 25, 2024
cfcd807
Remove ConstantFolding from domain_union and apply it in the tests in…
SF-N Jul 25, 2024
2097226
Refactor test_domain_inerence
SF-N Jul 25, 2024
67f1c2b
MyPy and TODO cleanup with Till
SF-N Jul 26, 2024
a007bf2
Merge main
SF-N Jul 26, 2024
685929f
Some further cleanup and refactoring
SF-N Jul 26, 2024
4447221
Merge branch 'main' into GTIR_domain-inference
SF-N Aug 7, 2024
e10f1a5
Add support for let-statements with expr=ir.SymRef(...) and passing L…
SF-N Aug 8, 2024
cc940e1
Merge branch 'GTIR_domain-inference' of github.com:SF-N/gt4py into GT…
SF-N Aug 8, 2024
1e6cafc
Merge branch 'main' into GTIR_domain-inference
SF-N Aug 9, 2024
6dff994
Add domain inference for cond
SF-N Aug 9, 2024
d07a94d
Merge branch 'main' into GTIR_domain-inference
SF-N Aug 13, 2024
4955266
Change licence of new files
SF-N Aug 13, 2024
2ccc36f
Remove comd and rename input_domain -> target_domain
SF-N Aug 14, 2024
af5ea76
Merge branch 'main' into GTIR_domain-inference
SF-N Aug 20, 2024
e007846
Add support for unused input in as_fieldop inference and correspondin…
SF-N Aug 20, 2024
7b99b99
Cleanup tests
tehrengruber Aug 21, 2024
03f42ed
Small fix
tehrengruber Aug 21, 2024
19af61b
Small fix
tehrengruber Aug 21, 2024
587f883
Small fix
tehrengruber Aug 21, 2024
52f4628
Small fix
tehrengruber Aug 21, 2024
bb8b00e
Small cleanup
tehrengruber Aug 21, 2024
9b82c6b
Fix TraceShift for dynamic shifts
tehrengruber Aug 21, 2024
d75f715
Retrigger CI
tehrengruber Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from typing import Callable, Iterable, Union

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Any, Dict, Tuple
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, SymbolicRange
from gt4py.next.type_system import type_specifications as ts, type_translation


Expand Down Expand Up @@ -389,3 +391,24 @@ def _impl(*its: itir.Expr) -> itir.FunCall:
def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))


def cartesian_domain(ranges: Dict[str, Tuple[Any, Any]]) -> SymbolicDomain:
"""
>>> pformat(cartesian_domain({"IDim": (0, 10), "JDim": (0, 20)}))
SF-N marked this conversation as resolved.
Show resolved Hide resolved
'c⟨ IDim: [0, 10), JDim: [0, 20) ⟩'

"""

axis_order = ["IDim", "JDim", "KDim"]
SF-N marked this conversation as resolved.
Show resolved Hide resolved

symbolic_ranges = {}

for axis, (start, stop) in ranges.items():
if axis not in axis_order:
raise ValueError("The ranges need to contain either IDim, JDim or KDim.")
symbolic_ranges[axis] = SymbolicRange(start, stop)

domain = SymbolicDomain(grid_type="cartesian_domain", ranges=symbolic_ranges)
SF-N marked this conversation as resolved.
Show resolved Hide resolved

return SymbolicDomain.as_expr(domain)
12 changes: 8 additions & 4 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ def visit_FunCall(self, node: ir.FunCall):
and len(new_node.args) > 0
and all(isinstance(arg, ir.Literal) for arg in new_node.args)
): # `1 + 1` -> `2`
if new_node.fun.id in ir.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(new_node.fun.id))
arg_values = [getattr(embedded, str(arg.type))(arg.value) for arg in new_node.args] # type: ignore[attr-defined] # arg type already established in if condition
new_node = im.literal_from_value(fun(*arg_values))
try:
if new_node.fun.id in ir.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(new_node.fun.id))
arg_values = [getattr(embedded, str(arg.type))(arg.value) for arg in
new_node.args] # type: ignore[attr-defined] # arg type already established in if condition
new_node = im.literal_from_value(fun(*arg_values))
except ValueError:
pass # happens for inf and neginf

return new_node
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs
from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs

from gt4py.next.iterator.transforms.constant_folding import ConstantFolding


"""Iterator IR extension for global temporaries.

Expand Down Expand Up @@ -453,7 +455,7 @@ def domain_union(domains: list[SymbolicDomain]) -> SymbolicDomain:
lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr),
[domain.ranges[dim].stop for domain in domains],
)
new_domain_ranges[dim] = SymbolicRange(start, stop)
new_domain_ranges[dim] = SymbolicRange(ConstantFolding.apply(start), ConstantFolding.apply(stop))
return SymbolicDomain(domains[0].grid_type, new_domain_ranges)


Expand Down
169 changes: 169 additions & 0 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
import dataclasses

from gt4py.eve.extended_typing import Dict, List, Tuple, Union
from gt4py.next.common import Dimension, DimensionKind
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain, SymbolicRange, domain_union
from gt4py.next.iterator.transforms.trace_shifts import TraceShifts


# Define a mapping for offset values to dimension names and kinds
OFFSET_TO_DIMENSION = {
itir.SymbolRef("Ioff"): ("IDim", DimensionKind.HORIZONTAL),
itir.SymbolRef("Joff"): ("JDim", DimensionKind.HORIZONTAL),
itir.SymbolRef("Koff"): ("KDim", DimensionKind.VERTICAL),
}


@dataclasses.dataclass(frozen=True)
class InferDomain:
SF-N marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def _infer_dimension_from_offset(offset: itir.OffsetLiteral) -> Dimension:
if offset.value in OFFSET_TO_DIMENSION:
name, kind = OFFSET_TO_DIMENSION[offset.value]
return Dimension(name, kind)
else:
raise ValueError("offset must be either Ioff, Joff, or Koff")

@staticmethod
def _extract_axis_dims(domain_expr: SymbolicDomain | itir.FunCall) -> List[str]:
axis_dims = []
if isinstance(domain_expr, SymbolicDomain) and domain_expr.grid_type == "cartesian_domain":
axis_dims.extend(domain_expr.ranges.keys())
elif isinstance(domain_expr, itir.FunCall) and domain_expr.fun == im.ref(
"cartesian_domain"
):
for named_range in domain_expr.args:
if isinstance(named_range, itir.FunCall) and named_range.fun == im.ref(
"named_range"
):
axis_literal = named_range.args[0]
if isinstance(axis_literal, itir.AxisLiteral):
axis_dims.append(axis_literal.value)
return axis_dims

@staticmethod
def _get_symbolic_domain(domain: Union[SymbolicDomain, itir.FunCall]) -> SymbolicDomain:
if isinstance(domain, SymbolicDomain):
return domain
if isinstance(domain, itir.FunCall) and domain.fun == im.ref("cartesian_domain"):
return SymbolicDomain.from_expr(domain)
raise TypeError("domain must either be a FunCall or a SymbolicDomain.")

@staticmethod
def _merge_domains(original_domains, new_domains):
for key, value in new_domains.items():
if key in original_domains:
original_domains[key] = domain_union([original_domains[key], value])
else:
original_domains[key] = value

return {
key: domain_union(value) if isinstance(value, list) else value
for key, value in original_domains.items()
}

@classmethod
def _translate_domain(
cls, symbolic_domain: SymbolicDomain, shift: Tuple[itir.OffsetLiteral, int], dims: List[str]
) -> SymbolicDomain:
new_ranges = {dim: symbolic_domain.ranges[dim] for dim in dims}
if shift:
off, val = shift
current_dim = cls._infer_dimension_from_offset(off)
SF-N marked this conversation as resolved.
Show resolved Hide resolved
new_ranges[current_dim.value] = SymbolicRange.translate(
symbolic_domain.ranges[current_dim.value], val.value
)
return SymbolicDomain("cartesian_domain", new_ranges)

@classmethod
def infer_as_fieldop(
cls, applied_fieldop: itir.FunCall, input_domain: SymbolicDomain
) -> Tuple[itir.FunCall, Dict[str, SymbolicDomain]]: # todo: test scan operator
assert isinstance(applied_fieldop, itir.FunCall) and isinstance(
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved
applied_fieldop.fun, itir.FunCall
)
assert applied_fieldop.fun.fun == im.ref("as_fieldop")

stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args

inputs_node = []
accessed_domains = {}

# Set inputs for StencilClosure node by replacing FunCalls with temporary SymRefs
tmp_counter = 0
for in_field in inputs:
if isinstance(in_field, itir.FunCall):
in_field.id = im.ref(f"__dom_inf_{tmp_counter}")
inputs_node.append(im.ref(in_field.id))
accessed_domains[str(in_field.id)] = []
tmp_counter += 1
else:
inputs_node.append(in_field)
accessed_domains[str(in_field.id)] = []

out_field_name = "tmp" # todo: can this be derived from somewhere?

symbolic_domain = cls._get_symbolic_domain(input_domain)

# TODO: until TraceShifts directly supporty stencils we just wrap our expression into a dummy closure in this helper function.
def trace_shifts(stencil: itir.Expr, inputs: list[itir.Expr], domain: itir.Expr):
SF-N marked this conversation as resolved.
Show resolved Hide resolved
node = itir.StencilClosure(
stencil=stencil,
inputs=inputs,
output=im.ref(out_field_name),
domain=domain,
)
return TraceShifts.apply(node)

# Extract the shifts and translate the domains accordingly
shifts_results = trace_shifts(stencil, inputs_node, SymbolicDomain.as_expr(symbolic_domain))
dims = cls._extract_axis_dims(SymbolicDomain.as_expr(symbolic_domain))
SF-N marked this conversation as resolved.
Show resolved Hide resolved

for in_field in inputs:
in_field_id = str(in_field.id)
shifts_list = shifts_results[in_field_id]

new_domains = [
cls._translate_domain(symbolic_domain, shift, dims) for shift in shifts_list
]

accessed_domains[in_field_id] = domain_union(new_domains)

inputs_new = []
for in_field in inputs:
# Recursively traverse inputs
if isinstance(in_field, itir.FunCall):
transformed_calls_tmp, accessed_domains_tmp = cls.infer_as_fieldop(
in_field, accessed_domains[str(in_field.id)]
)
inputs_new.append(transformed_calls_tmp)

# Merge accessed_domains and accessed_domains_tmp
accessed_domains = cls._merge_domains(accessed_domains, accessed_domains_tmp)
else:
inputs_new.append(in_field)

transformed_call = im.call(
im.call("as_fieldop")(stencil, SymbolicDomain.as_expr(symbolic_domain))
)(*inputs_new)

accessed_domains_without_tmp = {
k: v for k, v in accessed_domains.items() if not k.startswith("__dom_inf_")
}

return transformed_call, accessed_domains_without_tmp
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,17 @@ def test_constant_folding_minimum():
expected = im.ref("a")
actual = ConstantFolding.apply(testee)
assert actual == expected


def test_constant_folding_literal():
testee = im.plus(im.literal_from_value(1), im.literal_from_value(2))
expected = im.literal_from_value(3)
actual = ConstantFolding.apply(testee)
assert actual == expected


def test_constant_folding_literal_maximum():
testee = im.call("maximum")(im.literal_from_value(1), im.literal_from_value(2))
expected = im.literal_from_value(2)
actual = ConstantFolding.apply(testee)
assert actual == expected
Loading
Loading