Skip to content

Commit

Permalink
feat[next][dace]: Build clean nestedSDFG without unused data connecto…
Browse files Browse the repository at this point in the history
…rs (#1628)

The main objective of this PR is to reduce the GT4Py dependency on DaCe
(especially on DaCe transformations) to produce clean SDFGs. Instead of
relying on the `PruneConnectors` pass to remove unused data connectors
in nested SDFGs, these connectors can be avoided in the first place,
when lowering let-lambdas to nested SDFGs.
  • Loading branch information
edopao authored Sep 4, 2024
1 parent 07e1b1d commit feda23d
Showing 1 changed file with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any, Dict, List, Optional, Protocol, Sequence, Set, Tuple, Union

import dace
import dace.transformation.dataflow as dace_dataflow

from gt4py import eve
from gt4py.eve import concepts
Expand Down Expand Up @@ -395,13 +394,21 @@ def visit_Lambda(
lambda_symbols = self.global_symbols | {
pname: type_ for pname, (_, type_) in lambda_args_mapping.items()
}
# obtain the set of symbols that are used in the lambda node and all its child nodes
used_symbols = {str(sym.id) for sym in eve.walk_values(node).if_isinstance(gtir.SymRef)}

nsdfg = dace.SDFG(f"{sdfg.label}_nested")
nstate = nsdfg.add_state("lambda")

# add sdfg storage for the symbols that need to be passed as input parameters,
# that is only the symbols that are used in the context of the lambda node
self._add_sdfg_params(
nsdfg,
[gtir.Sym(id=p_name, type=p_type) for p_name, p_type in lambda_symbols.items()],
[
gtir.Sym(id=p_name, type=p_type)
for p_name, p_type in lambda_symbols.items()
if p_name in used_symbols
],
)

lambda_nodes = GTIRToSDFG(self.offset_provider, lambda_symbols.copy()).visit(
Expand Down Expand Up @@ -527,9 +534,5 @@ def build_sdfg_from_gtir(
sdfg = sdfg_genenerator.visit(program)
assert isinstance(sdfg, dace.SDFG)

# nested-SDFGs for let-lambda may contain unused symbols, in which case
# we can remove unnecesssary data connectors (not done by dace simplify pass)
sdfg.apply_transformations_repeated(dace_dataflow.PruneConnectors)

sdfg.simplify()
return sdfg

0 comments on commit feda23d

Please sign in to comment.