Skip to content

Commit

Permalink
Forward debug info from gt4py to dace
Browse files Browse the repository at this point in the history
Forward debug info from gt4py to dace if we know it. If we don't know,
just don't specify instead of setting `DebugInfo(0)`.

Moved `get_dace_debuginfo()` one folder higher from expansion utils into
"general" dace utils because its not only used in expansion.
  • Loading branch information
romanc committed Feb 1, 2025
1 parent 0d121e8 commit 0b75c61
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 28 deletions.
10 changes: 4 additions & 6 deletions src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from gt4py import eve
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen
from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import make_dace_subset
from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset


class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait):
Expand Down Expand Up @@ -268,13 +267,13 @@ def visit_ComputationState(
for memlet in computation.read_memlets:
if memlet.field not in read_acc_and_conn:
read_acc_and_conn[memlet.field] = (
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
sdfg_ctx.state.add_access(memlet.field),
None,
)
for memlet in computation.write_memlets:
if memlet.field not in write_acc_and_conn:
write_acc_and_conn[memlet.field] = (
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
sdfg_ctx.state.add_access(memlet.field),
None,
)
node_ctx = StencilComputationSDFGBuilder.NodeContext(
Expand All @@ -298,7 +297,7 @@ def visit_FieldDecl(
dtype=data_type_to_dace_typeclass(node.dtype),
storage=node.storage.to_dace_storage(),
transient=node.name not in non_transients,
debuginfo=dace.DebugInfo(0),
debuginfo=get_dace_debuginfo(node),
)

def visit_SymbolDecl(
Expand Down Expand Up @@ -343,7 +342,6 @@ def visit_NestedSDFG(
inputs=node.input_connectors,
outputs=node.output_connectors,
symbol_mapping=symbol_mapping,
debuginfo=dace.DebugInfo(0),
)
self.visit(
node.read_memlets,
Expand Down
14 changes: 0 additions & 14 deletions src/gt4py/cartesian/gtc/dace/expansion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@

from typing import TYPE_CHECKING, List

import dace
import dace.data
import dace.library
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
Expand All @@ -25,15 +20,6 @@
from gt4py.cartesian.gtc.dace.nodes import StencilComputation


def get_dace_debuginfo(node: common.LocNode):
if node.loc is not None:
return dace.dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)
else:
return dace.dtypes.DebugInfo(0)


class HorizontalIntervalRemover(eve.NodeTranslator):
def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis):
mask_attrs = dict(i=node.i, j=node.j)
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/gtc/dace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion
from gt4py.cartesian.gtc.dace.expansion.utils import HorizontalExecutionSplitter
from gt4py.cartesian.gtc.dace.expansion_specification import ExpansionItem, make_expansion_order
from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection

from .expansion.utils import HorizontalExecutionSplitter, get_dace_debuginfo
from .expansion_specification import ExpansionItem, make_expansion_order


def _set_expansion_order(
node: StencilComputation, expansion_order: Union[List[ExpansionItem], List[str]]
Expand Down
14 changes: 9 additions & 5 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset
from gt4py.cartesian.gtc.dace.utils import (
compute_dcir_access_infos,
get_dace_debuginfo,
make_dace_subset,
)
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import (
AccessCollector,
Expand Down Expand Up @@ -115,15 +119,15 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFG
access_collection = AccessCollector.apply(node)

for field in access_collection.read_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
library_node.add_in_connector("__in_" + field)
subset = ctx.make_input_dace_subset(node, field)
state.add_edge(
access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset)
)

for field in access_collection.write_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
library_node.add_out_connector("__out_" + field)
subset = ctx.make_output_dace_subset(node, field)
state.add_edge(
Expand All @@ -146,7 +150,7 @@ def visit_Stencil(self, node: oir.Stencil):
],
dtype=data_type_to_dace_typeclass(param.dtype),
transient=False,
debuginfo=dace.DebugInfo(0),
debuginfo=get_dace_debuginfo(param),
)
else:
ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype))
Expand All @@ -164,7 +168,7 @@ def visit_Stencil(self, node: oir.Stencil):
],
dtype=data_type_to_dace_typeclass(decl.dtype),
transient=True,
debuginfo=dace.DebugInfo(0),
debuginfo=get_dace_debuginfo(decl),
)
self.generic_visit(node, ctx=ctx)
ctx.sdfg.validate()
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents


def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo:
if node.loc is None:
return dace.dtypes.DebugInfo(0)

return dace.dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)


def array_dimensions(array: dace.data.Array):
dims = [
any(
Expand Down

0 comments on commit 0b75c61

Please sign in to comment.