Skip to content

Commit

Permalink
GTIR domain inference (#1568)
Browse files Browse the repository at this point in the history
Uses `TraceShifts` to infer the minimal domain of (nested) `as_fieldop`s and `SetAt` statements.

Domain inference of let-statements is not implemented yet and will be done in PR #1591

Changes:
- New file gt4py.next.iterator.transforms.infer_domain with functions
`infer_as_fieldop` and `infer_program`
- Several new tests in test_infer_domain.py to test functionality
- New function SymbolicDomain.translate
- New helper functions `domain` and `as_fieldop` in `ir_utils.ir_makers`
- Uses constant folding of `itir.Literal`s to simplify resulting domains

---------

Co-authored-by: Sara Faghih-Naini <[email protected]>
Co-authored-by: Till Ehrengruber <[email protected]>
  • Loading branch information
3 people authored Aug 21, 2024
1 parent 6eba654 commit 196cc5f
Show file tree
Hide file tree
Showing 5 changed files with 850 additions and 11 deletions.
39 changes: 39 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Callable, Iterable, Optional, Union

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Dict, Tuple
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.type_system import type_specifications as ts, type_translation

Expand Down Expand Up @@ -397,9 +399,46 @@ def map_(op):
return call(call("map_")(op))


def domain(
    grid_type: Union[common.GridType, str],
    ranges: Dict[Union[common.Dimension, str], Tuple[itir.Expr, itir.Expr]],
) -> itir.FunCall:
    """
    Build a domain call with one `named_range` argument per entry of `ranges`.

    Args:
        grid_type: Either a `common.GridType`, in which case the called function is
            named `f"{grid_type!s}_domain"`, or a string that is used verbatim as the
            name of the called function.
        ranges: Mapping from a dimension (a `common.Dimension` or just its name) to
            the `(start, stop)` pair of that dimension. The `named_range` arguments
            are emitted in the iteration order of this mapping.

    Examples
    --------
    >>> str(
    ...     domain(
    ...         common.GridType.CARTESIAN,
    ...         {
    ...             common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL): (0, 10),
    ...             common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL): (0, 20),
    ...         },
    ...     )
    ... )
    'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩'
    >>> str(domain(common.GridType.CARTESIAN, {"IDim": (0, 10), "JDim": (0, 20)}))
    'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩'
    >>> str(domain(common.GridType.UNSTRUCTURED, {"IDim": (0, 10), "JDim": (0, 20)}))
    'u⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩'
    """
    # A `GridType` is mapped to the name of its domain builtin; strings pass through.
    if isinstance(grid_type, common.GridType):
        grid_type = f"{grid_type!s}_domain"

    named_ranges = []
    for dim, bounds in ranges.items():
        if isinstance(dim, common.Dimension):
            # keep the dimension kind when a full `Dimension` is given
            axis = itir.AxisLiteral(value=dim.value, kind=dim.kind)
        else:
            axis = itir.AxisLiteral(value=dim)
        named_ranges.append(call("named_range")(axis, bounds[0], bounds[1]))

    return call(grid_type)(*named_ranges)


def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call:
"""
Create an `as_fieldop` call.
Examples
--------
>>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2"))
Expand Down
58 changes: 52 additions & 6 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import gt4py.next as gtx
from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.eve.extended_typing import Dict, Tuple
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
Expand Down Expand Up @@ -446,12 +447,56 @@ def from_expr(cls, node: ir.Node) -> SymbolicDomain:
return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above

def as_expr(self) -> ir.FunCall:
    """
    Convert this symbolic domain into the corresponding IR domain call.

    Each entry of `self.ranges` is reduced to its `(start, stop)` pair and the
    construction of the `named_range` arguments is delegated to `im.domain`, so
    that there is a single place where domain expressions are assembled.
    `self.grid_type` is forwarded unchanged as the name of the called function.
    """
    converted_ranges: dict[common.Dimension | str, tuple[ir.Expr, ir.Expr]] = {
        key: (value.start, value.stop) for key, value in self.ranges.items()
    }
    return im.domain(self.grid_type, converted_ranges)

def translate(
    self: SymbolicDomain,
    shift: Tuple[ir.OffsetLiteral, ...],
    offset_provider: Dict[str, common.Dimension],
) -> SymbolicDomain:
    """
    Return the domain obtained by applying `shift` to this domain.

    Args:
        shift: Flat sequence of `(offset, value)` literal pairs, i.e. its length must
            be even. An empty shift returns `self` unchanged; longer shifts are
            applied pair by pair from left to right.
        offset_provider: Maps the offset name to either a `common.Dimension`
            (cartesian offset) or a `common.Connectivity` (unstructured shift).
            NOTE(review): the annotation only mentions `Dimension`, but connectivities
            are handled below as well.

    Returns:
        `self` for an empty shift, otherwise a new `SymbolicDomain` with the same
        grid type and the translated ranges.

    Raises:
        AssertionError: If the number of shift elements is odd or the offset provider
            entry is neither a `Dimension` nor a `Connectivity`.
    """
    if len(shift) == 0:
        return self
    if len(shift) % 2 != 0:
        # fail fast instead of translating the leading pairs first
        raise AssertionError("Number of shifts must be a multiple of 2.")
    if len(shift) > 2:
        # apply the first (offset, value) pair, then recurse on the remainder
        return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider)

    off, val = shift
    assert isinstance(off.value, str) and isinstance(val.value, int)
    nbt_provider = offset_provider[off.value]

    # work on a copy such that the ranges of `self` are left untouched
    new_ranges = dict(self.ranges)
    if isinstance(nbt_provider, common.Dimension):
        # cartesian offset: only the range of the shifted dimension changes
        current_dim = nbt_provider
        new_ranges[current_dim] = SymbolicRange.translate(self.ranges[current_dim], val.value)
    elif isinstance(nbt_provider, common.Connectivity):
        # unstructured shift
        # note: the sizes are recomputed on every call; this is cheap, but is
        # expected to disappear eventually
        horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider)

        old_dim = nbt_provider.origin_axis
        new_dim = nbt_provider.neighbor_axis

        assert new_dim not in new_ranges or old_dim == new_dim

        # TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON?
        new_range = SymbolicRange(
            im.literal("0", ir.INTEGER_INDEX_BUILTIN),
            im.literal(str(horizontal_sizes[new_dim.value]), ir.INTEGER_INDEX_BUILTIN),
        )
        # replace the origin dimension by the neighbor dimension, keeping the
        # position of the entry within the domain
        new_ranges = {
            (new_dim if dim == old_dim else dim): (new_range if dim == old_dim else range_)
            for dim, range_ in new_ranges.items()
        }
    else:
        raise AssertionError(
            f"Unsupported offset provider type '{type(nbt_provider).__name__}' "
            f"for offset '{off.value}'."
        )
    return SymbolicDomain(self.grid_type, new_ranges)


def domain_union(domains: list[SymbolicDomain]) -> SymbolicDomain:
Expand All @@ -469,6 +514,7 @@ def domain_union(domains: list[SymbolicDomain]) -> SymbolicDomain:
[domain.ranges[dim].stop for domain in domains],
)
new_domain_ranges[dim] = SymbolicRange(start, stop)

return SymbolicDomain(domains[0].grid_type, new_domain_ranges)


Expand Down
205 changes: 205 additions & 0 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Dict, Tuple
from gt4py.next.common import Dimension
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union
from gt4py.next.iterator.transforms.trace_shifts import TraceShifts


def _merge_domains(
    original_domains: Dict[str, SymbolicDomain], additional_domains: Dict[str, SymbolicDomain]
) -> Dict[str, SymbolicDomain]:
    """
    Combine two id-to-domain mappings into a new dictionary.

    Ids present in both mappings are assigned the `domain_union` of both domains,
    all other ids keep the domain of the mapping they occur in. Neither argument
    is modified.
    """
    merged = dict(original_domains)
    for field_id, extra_domain in additional_domains.items():
        merged[field_id] = (
            domain_union([original_domains[field_id], extra_domain])
            if field_id in original_domains
            else extra_domain
        )
    return merged


# FIXME[#1582](tehrengruber): Use new TraceShift API when #1592 is merged.
def trace_shifts(
    stencil: itir.Expr, input_ids: list[str], domain: itir.Expr
) -> dict[str, set[tuple[itir.OffsetLiteral, ...]]]:
    """
    Run `TraceShifts` on `stencil` and return the recorded shifts per input id.

    `TraceShifts.apply` is invoked on a stencil closure, hence the stencil is wrapped
    into a throw-away `itir.StencilClosure` whose inputs are references to
    `input_ids` and whose output is a dummy reference.
    """
    input_refs = [im.ref(id_) for id_ in input_ids]
    closure = itir.StencilClosure(
        stencil=stencil,
        inputs=input_refs,
        output=im.ref("__dummy"),  # placeholder, only the inputs are of interest
        domain=domain,
    )
    return TraceShifts.apply(closure, inputs_only=True)  # type: ignore[return-value] # ensured by inputs_only=True


def extract_shifts_and_translate_domains(
    stencil: itir.Expr,
    input_ids: list[str],
    target_domain: SymbolicDomain,
    offset_provider: Dict[str, Dimension],
    accessed_domains: Dict[str, SymbolicDomain],
):
    """
    Record for every stencil input the domain on which it is accessed.

    The shifts applied to each input are traced, `target_domain` is translated by
    every one of them and the union of the translated domains is stored in
    `accessed_domains` under the id of the input. `accessed_domains` is updated in
    place; inputs for which no shift was recorded do not get an entry.
    """
    target_domain_expr = SymbolicDomain.as_expr(target_domain)
    shifts_by_input = trace_shifts(stencil, input_ids, target_domain_expr)

    for input_id in input_ids:
        translated_domains = [
            SymbolicDomain.translate(target_domain, shift, offset_provider)
            for shift in shifts_by_input[input_id]
        ]
        if not translated_domains:
            # nothing recorded for this input -> leave `accessed_domains` untouched
            continue
        accessed_domains[input_id] = domain_union(translated_domains)


def infer_as_fieldop(
    applied_fieldop: itir.FunCall,
    target_domain: SymbolicDomain | itir.FunCall,
    offset_provider: Dict[str, Dimension],
) -> Tuple[itir.FunCall, Dict[str, SymbolicDomain]]:
    """
    Infer the domains of an applied `as_fieldop`, i.e. `as_fieldop(stencil)(inputs...)`.

    Args:
        applied_fieldop: The call to infer. Its inputs may be `SymRef`s, `Literal`s or
            nested function calls, which are inferred recursively.
        target_domain: Domain on which the result is required, either symbolic or as
            a domain expression.
        offset_provider: Forwarded to the translation of the domains.

    Returns:
        A tuple of the transformed call, where the domain argument of this and all
        nested `as_fieldop`s is set, and a mapping from the ids of the referenced
        symbols to the domain on which they are accessed. Entries of the internally
        generated temporary ids are not part of the mapping.

    Raises:
        ValueError: If an input is neither a `FunCall`, a `Literal` nor a `SymRef`.
    """
    assert isinstance(applied_fieldop, itir.FunCall)
    assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop")

    # `as_fieldop(stencil)(inputs...)`
    stencil = applied_fieldop.fun.args[0]
    inputs = applied_fieldop.args

    # ensure stencil has as many params as arguments
    if isinstance(stencil, itir.Lambda):
        assert len(stencil.params) == len(applied_fieldop.args)

    # Assign ids for all inputs to `as_fieldop`. `SymRef`s stay as is, everything else
    # (nested calls and literals) gets a temporary id.
    tmp_uid_gen = eve_utils.UIDGenerator(prefix="__dom_inf")
    input_ids: list[str] = []
    for in_field in inputs:
        if isinstance(in_field, (itir.FunCall, itir.Literal)):
            input_ids.append(tmp_uid_gen.sequential_id())
        elif isinstance(in_field, itir.SymRef):
            input_ids.append(in_field.id)
        else:
            raise ValueError(f"Unsupported type {type(in_field)}")

    if isinstance(target_domain, itir.FunCall):
        target_domain = SymbolicDomain.from_expr(target_domain)

    accessed_domains: Dict[str, SymbolicDomain] = {}
    extract_shifts_and_translate_domains(
        stencil, input_ids, target_domain, offset_provider, accessed_domains
    )

    # Recursively infer domain of inputs and update domain arg of nested `as_fieldops`
    transformed_inputs: list[itir.Expr] = []
    for in_field_id, in_field in zip(input_ids, inputs):
        if isinstance(in_field, itir.FunCall):
            nested_call, nested_accessed_domains = infer_as_fieldop(
                in_field, accessed_domains[in_field_id], offset_provider
            )
            transformed_inputs.append(nested_call)
            # symbols accessed by the nested call are also accessed by this call
            accessed_domains = _merge_domains(accessed_domains, nested_accessed_domains)
        elif isinstance(in_field, (itir.SymRef, itir.Literal)):
            transformed_inputs.append(in_field)
        else:
            raise ValueError(f"Unsupported type {type(in_field)}")

    domain_expr = SymbolicDomain.as_expr(target_domain)
    transformed_call = im.as_fieldop(stencil, domain_expr)(*transformed_inputs)

    # the temporary ids are an implementation detail and must not leak to the caller
    accessed_domains_without_tmp = {}
    for k, v in accessed_domains.items():
        if k.startswith(tmp_uid_gen.prefix):  # type: ignore[arg-type] # prefix is always str
            continue
        accessed_domains_without_tmp[k] = v

    return transformed_call, accessed_domains_without_tmp


def _validate_temporary_usage(body: list[itir.Stmt], temporaries: list[str]):
    """
    Check that no temporary is the target of more than one statement of `body`.

    Targets that are not listed in `temporaries` may be assigned multiple times.

    Raises:
        ValueError: If a temporary is the target of a second statement.
    """
    seen_temporaries = set()
    for stmt in body:
        assert isinstance(stmt, itir.SetAt)  # TODO: extend for if-statements when they land
        assert isinstance(
            stmt.target, itir.SymRef
        )  # TODO: stmt.target can be an expr, e.g. make_tuple
        target_id = stmt.target.id
        if target_id not in temporaries:
            # only assignments to temporaries are tracked
            continue
        if target_id in seen_temporaries:
            raise ValueError("Temporaries can only be used once within a program.")
        seen_temporaries.add(target_id)


def infer_program(
    program: itir.Program,
    offset_provider: Dict[str, Dimension],
) -> itir.Program:
    """
    Infer the domain of all `as_fieldop`s and temporaries of `program`.

    The statements of the program are processed from last to first such that the
    domain on which a temporary is read is known when the statement writing to it is
    reached. Statements whose expression is a plain `SymRef` are passed through
    unchanged.

    Args:
        program: The program to infer. It is not modified, the temporaries of the
            result are copies with the inferred domain set.
        offset_provider: Forwarded to `infer_as_fieldop`.

    Returns:
        A new program with the transformed `SetAt` statements and declarations.

    Raises:
        ValueError: If a temporary is the target of more than one statement.
        KeyError: If no accessed domain was recorded for a temporary, e.g. because
            no later statement reads it through an `as_fieldop`.
    """
    # function-scope import keeps this block independent of the module's import list
    import copy

    accessed_domains: dict[str, SymbolicDomain] = {}
    temporaries: list[str] = [tmp.id for tmp in program.declarations]

    _validate_temporary_usage(program.body, temporaries)

    # Statements are collected in reverse program order and reversed once at the end,
    # which avoids the repeated shifting of `list.insert(0, ...)`.
    reversed_set_ats: list[itir.SetAt] = []
    for set_at in reversed(program.body):
        assert isinstance(set_at, itir.SetAt)
        if isinstance(set_at.expr, itir.SymRef):
            reversed_set_ats.append(set_at)
            continue
        assert isinstance(set_at.expr, itir.FunCall)
        assert cpm.is_call_to(set_at.expr.fun, "as_fieldop")
        assert isinstance(
            set_at.target, itir.SymRef
        )  # TODO: stmt.target can be an expr, e.g. make_tuple
        target_id = set_at.target.id
        if target_id in temporaries:
            # ignore temporaries as their domain is the `AUTO_DOMAIN` placeholder
            assert set_at.domain == AUTO_DOMAIN
        else:
            accessed_domains[target_id] = SymbolicDomain.from_expr(set_at.domain)

        target_domain = accessed_domains[target_id]
        transformed_as_fieldop, current_accessed_domains = infer_as_fieldop(
            set_at.expr, target_domain, offset_provider
        )
        reversed_set_ats.append(
            itir.SetAt(
                expr=transformed_as_fieldop,
                domain=SymbolicDomain.as_expr(target_domain),
                target=set_at.target,
            )
        )

        for field, field_domain in current_accessed_domains.items():
            if field not in accessed_domains:
                accessed_domains[field] = field_domain
            elif field in temporaries:
                # multiple accesses to the same field -> compute union of accessed domains
                accessed_domains[field] = domain_union([accessed_domains[field], field_domain])
            # else:
            # TODO(tehrengruber): if domain_ref is an external field the domain must
            #  already be larger. This should be checked, but would require additions
            #  to the IR.

    # Set the inferred domain on (shallow) copies of the declarations instead of
    # mutating the nodes of the input program.
    new_declarations = []
    for temporary in program.declarations:
        updated_temporary = copy.copy(temporary)
        updated_temporary.domain = SymbolicDomain.as_expr(accessed_domains[temporary.id])
        new_declarations.append(updated_temporary)

    return itir.Program(
        id=program.id,
        function_definitions=program.function_definitions,
        params=program.params,
        declarations=new_declarations,
        body=reversed_set_ats[::-1],
    )
14 changes: 9 additions & 5 deletions src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import enum
import sys
from collections.abc import Callable
from typing import Any, Final, Iterable, Literal
Expand Down Expand Up @@ -45,11 +44,11 @@ def copy_recorded_shifts(from_: ir.Node, to: ir.Node) -> None:
to.annex.recorded_shifts = from_.annex.recorded_shifts


class Sentinel(eve.StrEnum):
    """
    String-valued marker values used while tracing shifts.

    `ALL_NEIGHBORS` and `VALUE` are accepted by `_shift` in place of an
    `ir.OffsetLiteral`. Every member's value equals its name.
    """

    VALUE = "VALUE"
    TYPE = "TYPE"

    ALL_NEIGHBORS = "ALL_NEIGHBORS"


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -137,6 +136,11 @@ def _can_deref(x):


def _shift(*offsets):
assert all(
isinstance(offset, ir.OffsetLiteral) or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE]
for offset in offsets
)

def apply(arg):
assert isinstance(arg, IteratorTracer)
return arg.shift(offsets)
Expand Down
Loading

0 comments on commit 196cc5f

Please sign in to comment.