From e1bc3630adae67517fcf62b6ee105dc0e57c8c57 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Sep 2024 11:10:22 +0200 Subject: [PATCH 1/2] Modified `gt_simplify()`. Before the pass was just calling the native DaCe version. Now it will not call the `PromoteScalarToSymbol` and `ConstantPropagation` pass. This is because the lowering sometimes has to change between a symbol and a scalar and back and back angain and so on. Furthermore, it looks like these passes have problems with that, so we excluded them. This is a temporary solution, at the end, it might be feasable or good to run the full simplify pass. --- .../transformations/auto_opt.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index e19dccc67f..bba6ef8947 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -8,7 +8,7 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any, Optional, Sequence +from typing import Any, Final, Optional, Sequence import dace from dace.transformation import dataflow as dace_dataflow @@ -21,6 +21,17 @@ ) +GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} +"""Set of simplify passes `gt_simplify()` skips by default. + +The following passes are included: +- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a + symbol or vice versa and at a later point to invert this again. However, this + pass has some problems with this pattern so for the time being it is disabled. +- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. +""" + + def gt_simplify( sdfg: dace.SDFG, validate: bool = True, @@ -32,19 +43,18 @@ def gt_simplify( Instead of calling `sdfg.simplify()` directly, you should use this function, as it is specially tuned for GridTool based SDFGs. + By default this function will run the normal DaCe simplify pass. However, if + `skip` is not set or `None` then the parts listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET` + will be skipped. + Args: sdfg: The SDFG to optimize. validate: Perform validation after the pass has run. validate_all: Perform extensive validation. skip: List of simplify passes that should not be applied. - - Note: - The reason for this function is that we can influence how simplify works. - Since some parts in simplify might break things in the SDFG. - However, currently nothing is customized yet, and the function just calls - the simplification pass directly. """ - + if skip is None: + skip = GT_SIMPLIFY_DEFAULT_SKIP_SET return dace_passes_simplify.SimplifyPass( validate=validate, validate_all=validate_all, From 3e45710146140a0c4e635a338ad1a7ef9acb615d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Sep 2024 11:42:17 +0200 Subject: [PATCH 2/2] `GT_SIMPLIFY_SKIP_SET` is now the default argument. Thus if `skip` is `None` or empty, then the full DaCe pass will be run. Futhermore, the argument now accepts an iterable, for correctness we have to turn it into a `set` though. --- .../dace_fieldview/transformations/auto_opt.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index bba6ef8947..3895f7f5e8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -8,7 +8,7 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any, Final, Optional, Sequence +from typing import Any, Final, Iterable, Optional, Sequence import dace from dace.transformation import dataflow as dace_dataflow @@ -36,30 +36,30 @@ def gt_simplify( sdfg: dace.SDFG, validate: bool = True, validate_all: bool = False, - skip: Optional[set[str]] = None, + skip: Optional[Iterable[str]] = GT_SIMPLIFY_DEFAULT_SKIP_SET, ) -> Any: """Performs simplifications on the SDFG in place. Instead of calling `sdfg.simplify()` directly, you should use this function, as it is specially tuned for GridTool based SDFGs. - By default this function will run the normal DaCe simplify pass. However, if - `skip` is not set or `None` then the parts listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET` - will be skipped. + By default this function will run the normal DaCe simplify pass, but skip + passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET`. If `skip` is passed it + will be forwarded to DaCe, i.e. `GT_SIMPLIFY_DEFAULT_SKIP_SET` are not + added automatically. Args: sdfg: The SDFG to optimize. validate: Perform validation after the pass has run. validate_all: Perform extensive validation. - skip: List of simplify passes that should not be applied. + skip: List of simplify passes that should not be applied, defaults + to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. """ - if skip is None: - skip = GT_SIMPLIFY_DEFAULT_SKIP_SET return dace_passes_simplify.SimplifyPass( validate=validate, validate_all=validate_all, verbose=False, - skip=skip, + skip=set(skip) if skip is not None else skip, ).apply_pass(sdfg, {})