Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[dace]: Modified gt_simplify() #1647

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

"""Fast access to the auto optimization on DaCe."""

from typing import Any, Optional, Sequence
from typing import Any, Final, Iterable, Optional, Sequence

import dace
from dace.transformation import dataflow as dace_dataflow
Expand All @@ -21,35 +21,45 @@
)


GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"}
"""Set of simplify passes `gt_simplify()` skips by default.

The following passes are included:
- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a
symbol or vice versa and at a later point to invert this again. However, this
pass has some problems with this pattern so for the time being it is disabled.
- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`.
"""


def gt_simplify(
sdfg: dace.SDFG,
validate: bool = True,
validate_all: bool = False,
skip: Optional[set[str]] = None,
skip: Optional[Iterable[str]] = GT_SIMPLIFY_DEFAULT_SKIP_SET,
) -> Any:
"""Performs simplifications on the SDFG in place.

Instead of calling `sdfg.simplify()` directly, you should use this function,
as it is specially tuned for GridTool based SDFGs.

By default this function will run the normal DaCe simplify pass, but skip
passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET`. If `skip` is passed it
will be forwarded to DaCe, i.e. `GT_SIMPLIFY_DEFAULT_SKIP_SET` are not
added automatically.

Args:
sdfg: The SDFG to optimize.
validate: Perform validation after the pass has run.
validate_all: Perform extensive validation.
skip: List of simplify passes that should not be applied.

Note:
The reason for this function is that we can influence how simplify works.
Since some parts in simplify might break things in the SDFG.
However, currently nothing is customized yet, and the function just calls
the simplification pass directly.
skip: List of simplify passes that should not be applied, defaults
to `GT_SIMPLIFY_DEFAULT_SKIP_SET`.
"""

return dace_passes_simplify.SimplifyPass(
validate=validate,
validate_all=validate_all,
verbose=False,
skip=skip,
skip=set(skip) if skip is not None else skip,
).apply_pass(sdfg, {})


Expand Down
Loading