Skip to content

Commit

Permalink
refactor[cartesian]: readability improvements in gtir -> oir conversi…
Browse files Browse the repository at this point in the history
…on and other cleanups (#1630)

Being new to the codebase and in preparation of the [GT4Py-DaCe bridge
refactor](GEOS-ESM/NDSL#53) I was reading a
lot of code. Some things that caught my eye are summarized in this PR.
Bigger changes include

1. Moving `daceir.py` into the `dace/` backend folder. This follows the
convention of all other backends who have their IR inside the backend
folder. (last commit)
2. Readability improvements in `gtir_to_oir.py` (the "Cleanup visit_*"
commits): In my opinion, there's no point in shortening "statement" to
"stmt". I wouldn't go as far as renaming the classes. For the renamed
local variables, I think this adds a lot to readability.
3. `visit_While` in `gtir_to_oir.py` was having an extra mask statement
around the `while` loop (in case a mask was defined. I inlined the
potential `mask` with the condition of `while` loop.
4. In `gtir_to_oir.py` don't translate every body statement into a
single `MaskStmt`. Instead, create one (or two incase of `else`)
`MaskStmt` with multiple statements in the `body` (which is already
typed as a list of statements).

Tested locally by running the test suite. All tests passing.

Co-authored-by: Roman Cattaneo <>
  • Loading branch information
romanc authored Sep 11, 2024
1 parent fe8349e commit e0fb2a2
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 99 deletions.
7 changes: 1 addition & 6 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import dace
import dace.data
from dace.sdfg.utils import inline_sdfgs
from dace.serialize import dumps

from gt4py import storage as gt_storage
from gt4py.cartesian import config as gt_config
Expand Down Expand Up @@ -56,10 +55,6 @@
from gt4py.cartesian.stencil_object import StencilObject


def _serialize_sdfg(sdfg: dace.SDFG):
return dumps(sdfg)


def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
repldict = replace_strides(
[array for array in sdfg.arrays.values() if array.transient], layout_map
Expand Down Expand Up @@ -125,7 +120,7 @@ def _set_expansion_orders(sdfg: dace.SDFG):


def _set_tile_sizes(sdfg: dace.SDFG):
import gt4py.cartesian.gtc.daceir as dcir # avoid circular import
import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import

for node, _ in filter(
lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive()
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/frontend/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@
"""

from __future__ import annotations

import enum
import operator
import sys
Expand Down Expand Up @@ -704,7 +706,7 @@ def is_single_index(self) -> bool:

return self.start.level == self.end.level and self.start.offset == self.end.offset - 1

def disjoint_from(self, other: "AxisInterval") -> bool:
def disjoint_from(self, other: AxisInterval) -> bool:
def get_offset(bound: AxisBound) -> int:
return (
0 + bound.offset if bound.level == LevelMarker.START else sys.maxsize + bound.offset
Expand Down
File renamed without changes.
34 changes: 14 additions & 20 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion_specification import Loop, Map, Sections, Stages
from gt4py.cartesian.gtc.dace.utils import (
compute_dcir_access_infos,
Expand Down Expand Up @@ -68,7 +69,13 @@ def _iterator():
).unique(key=lambda x: x[2])


def _get_tasklet_inout_memlets(node: oir.HorizontalExecution, *, get_outputs, global_ctx, **kwargs):
def _get_tasklet_inout_memlets(
node: oir.HorizontalExecution,
*,
get_outputs: bool,
global_ctx: DaCeIRBuilder.GlobalContext,
**kwargs,
):
access_infos = compute_dcir_access_infos(
node,
block_extents=global_ctx.library_node.get_extents,
Expand Down Expand Up @@ -190,12 +197,7 @@ def _get_dcir_decl(
@dataclass
class IterationContext:
grid_subset: dcir.GridSubset
parent: Optional[DaCeIRBuilder.IterationContext]

@classmethod
def init(cls, *args, **kwargs):
res = cls(*args, parent=None, **kwargs)
return res
parent: Optional[DaCeIRBuilder.IterationContext] = None

def push_axes_extents(self, axes_extents) -> DaCeIRBuilder.IterationContext:
res = self.grid_subset
Expand Down Expand Up @@ -611,7 +613,7 @@ def _process_map_item(
scope_nodes,
item: Map,
*,
global_ctx,
global_ctx: DaCeIRBuilder.GlobalContext,
iteration_ctx: DaCeIRBuilder.IterationContext,
symbol_collector: DaCeIRBuilder.SymbolCollector,
**kwargs,
Expand Down Expand Up @@ -787,19 +789,13 @@ def _process_iteration_item(self, scope, item, **kwargs):
def visit_VerticalLoop(
self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs
):
start, end = (node.sections[0].interval.start, node.sections[0].interval.end)

overall_interval = dcir.DomainInterval(
start=dcir.AxisBound(axis=dcir.Axis.K, level=start.level, offset=start.offset),
end=dcir.AxisBound(axis=dcir.Axis.K, level=end.level, offset=end.offset),
)
overall_extent = Extent.zeros(2)
for he in node.walk_values().if_isinstance(oir.HorizontalExecution):
overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he))

iteration_ctx = DaCeIRBuilder.IterationContext.init(
iteration_ctx = DaCeIRBuilder.IterationContext(
grid_subset=dcir.GridSubset.from_gt4py_extent(overall_extent).set_interval(
axis=dcir.Axis.K, interval=overall_interval
axis=dcir.Axis.K, interval=node.sections[0].interval
)
)

Expand Down Expand Up @@ -849,13 +845,11 @@ def visit_VerticalLoop(

read_fields = set(memlet.field for memlet in read_memlets)
write_fields = set(memlet.field for memlet in write_memlets)
res = dcir.NestedSDFG(
return dcir.NestedSDFG(
label=global_ctx.library_node.label,
states=self.to_state(computations, grid_subset=iteration_ctx.grid_subset),
field_decls=field_decls,
read_memlets=[memlet for memlet in field_memlets if memlet.field in read_fields],
write_memlets=[memlet for memlet in field_memlets if memlet.field in write_fields],
symbol_decls=list(symbol_collector.symbol_decls.values()),
)

return res
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dace.subsets
import sympy

from gt4py.cartesian.gtc import daceir as dcir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder
from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import daceir as dcir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen
from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import gt4py.cartesian.gtc.common as common
from gt4py import eve
from gt4py.cartesian.gtc import daceir as dcir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.symbol_utils import get_axis_bound_str
from gt4py.cartesian.gtc.dace.utils import make_dace_subset
from gt4py.eve.codegen import FormatTemplate as as_fmt
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.definitions import Extent


Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import dace

from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.definitions import Extent


Expand Down
15 changes: 7 additions & 8 deletions src/gt4py/cartesian/gtc/dace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import numpy as np
from dace import library

from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection
Expand Down Expand Up @@ -215,10 +216,8 @@ def has_splittable_regions(self):
def tile_strides(self):
if self.tile_sizes_interpretation == "strides":
return self.tile_sizes
else:
overall_extent: Extent = next(iter(self.extents.values()))
for extent in self.extents.values():
overall_extent |= extent
return {
key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items()
}

overall_extent: Extent = next(iter(self.extents.values()))
for extent in self.extents.values():
overall_extent |= extent
return {key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items()}
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import gt4py.cartesian.gtc.oir as oir
from gt4py import eve
from gt4py.cartesian.gtc import daceir as dcir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/symbol_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


if TYPE_CHECKING:
import gt4py.cartesian.gtc.daceir as dcir
import gt4py.cartesian.gtc.dace.daceir as dcir


def data_type_to_dace_typeclass(data_type):
Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import numpy as np

from gt4py import eve
from gt4py.cartesian.gtc import common, daceir as dcir, oir
from gt4py.cartesian.gtc.common import CartesianOffset
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.common import CartesianOffset, VariableKOffset
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents


Expand Down Expand Up @@ -56,7 +57,9 @@ def replace_strides(arrays, get_layout_map):
return symbol_mapping


def get_tasklet_symbol(name, offset, is_target):
def get_tasklet_symbol(
name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool
):
if is_target:
return f"__{name}"

Expand Down
Loading

0 comments on commit e0fb2a2

Please sign in to comment.