Updated the optimizer changes a bit.
The main important changes are:
- I added some instance checks (just to make sure that we do the right thing) that the type checker can never verify.
- I also removed most of the accesses to `_pipeline_results` related to `StateReachability`, as I realized it was not working anyway.
    Instead, I turned them into TODOs that we must address at some point.
    However, I think they are not so pressing because our state machines are usually quite small.
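
For illustration, a minimal sketch of the instance-check pattern the first bullet refers to (hedged: the exact return-type annotations in DaCe are an assumption on my side). The `assert isinstance(...)` is cheap on well-formed SDFGs, but it narrows the loop variable to `dace.SDFGState`, a fact the type checker itself can never verify:

import dace

def _reset_wcr_flags(sdfg: dace.SDFG) -> None:
    # Sketch only, mirroring the auto-optimizer hunk below.
    for state in sdfg.states():
        # With the hierarchical state machine, `states()` may be annotated
        # more broadly than `SDFGState`; the assert narrows the type.
        assert isinstance(state, dace.SDFGState)
        for edge in state.edges():
            edge.data.wcr_nonatomic = False
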
philip-paul-mueller committed Feb 5, 2025
1 parent 8774462 commit 1a513be
Showing 5 changed files with 36 additions and 45 deletions.
@@ -280,6 +280,7 @@ def gt_auto_optimize(
# For compatibility with DaCe (and until we find out why) the GT4Py
# auto optimizer will emulate this behaviour.
for state in sdfg.states():
assert isinstance(state, dace.SDFGState)
for edge in state.edges():
edge.data.wcr_nonatomic = False

@@ -38,7 +38,6 @@ def gt_create_local_double_buffering(
it is not necessary to store the whole data, but only the working set
of a single thread.
"""

processed_maps = 0
for nsdfg in sdfg.all_sdfgs_recursive():
processed_maps += _create_local_double_buffering_non_recursive(nsdfg)
@@ -60,6 +59,7 @@ def _create_local_double_buffering_non_recursive(

processed_maps = 0
for state in sdfg.states():
assert isinstance(state, dace.SDFGState)
scope_dict = state.scope_dict()
for node in state.nodes():
if not isinstance(node, dace_nodes.MapEntry):
@@ -11,7 +11,6 @@
import collections
import copy
import uuid
import warnings
from typing import Any, Final, Iterable, Optional, TypeAlias

import dace
@@ -210,12 +209,15 @@ def gt_substitute_compiletime_symbols(
repl: Maps the name of the symbol to the value it should be replaced with.
validate: Perform validation at the end of the function.
validate_all: Perform validation also on intermediate steps.
Todo: This function needs improvement.
"""

# We will use the `replace` function of the top SDFG, while lower levels
# are handled using ConstantPropagation.
sdfg.replace_dict(repl)

# TODO(phimuell): Get rid of the `ConstantPropagation`
const_prop = dace_passes.ConstantPropagation()
const_prop.recursive = True
const_prop.progress = False
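
As a hedged usage sketch (the symbol names and values are made up, and the exact signature is only inferred from the docstring above), the function is meant to be called with a mapping from symbol name to its compile-time value:

# Illustrative call; "N" and "stride_y" are hypothetical symbols, and the
# string values assume `SDFG.replace_dict()`-style replacements.
gt_substitute_compiletime_symbols(
    sdfg,
    repl={"N": "128", "stride_y": "1"},
    validate=True,
)
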
@@ -280,9 +282,6 @@ def __init__(
def expressions(cls) -> Any:
return [dace.sdfg.utils.node_path_graph(cls.node_read_g, cls.node_tmp, cls.node_write_g)]

def depends_on(self) -> set[type[dace_transformation.Pass]]:
return {dace_transformation.passes.StateReachability}

def can_be_applied(
self,
graph: dace.SDFGState | dace.SDFG,
@@ -344,18 +343,12 @@ def _is_read_downstream(
write_g: dace_nodes.AccessNode = self.node_write_g
tmp_node: dace_nodes.AccessNode = self.node_tmp

reachable_states: Optional[dict[dace.SDFGState, set[dace.SDFGState]]] = None
if self._pipeline_results and "StateReachability" in self._pipeline_results:
warnings.warn(
"The 'StateReachability' analysis pass was not part of the pipeline results.",
stacklevel=0,
)
reachable_states = self._pipeline_results["StateReachability"]

# TODO(phimuell): Run the `StateReachability` pass in a pipeline and use
# the `_pipeline_results` member to access the data.
return gtx_transformations.utils.is_accessed_downstream(
start_state=start_state,
sdfg=sdfg,
reachable_states=reachable_states,
reachable_states=None,
data_to_look=data_to_look,
nodes_to_ignore={read_g, write_g, tmp_node},
)
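
For reference, a hedged sketch of what the TODO above envisions once the transformation is driven by a pipeline that has already run `StateReachability`; the `cfg_id` indexing mirrors the `apply_pass()` further down, and this is not what the code does after this commit:

# Sketch only: assumes a surrounding Pipeline populated `_pipeline_results`.
reachable_states = None
if self._pipeline_results and "StateReachability" in self._pipeline_results:
    reachable_states = self._pipeline_results["StateReachability"][sdfg.cfg_id]
return gtx_transformations.utils.is_accessed_downstream(
    start_state=start_state,
    sdfg=sdfg,
    reachable_states=reachable_states,
    data_to_look=data_to_look,
    nodes_to_ignore={read_g, write_g, tmp_node},
)
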
@@ -454,17 +447,20 @@ def depends_on(self) -> set[type[dace_transformation.Pass]]:
def apply_pass(
self, sdfg: dace.SDFG, pipeline_results: dict[str, Any]
) -> Optional[dict[dace.SDFGState, set[str]]]:
reachable: dict[dace.SDFGState, set[dace.SDFGState]] = pipeline_results[
"StateReachability"
][sdfg.cfg_id]

# NOTE: We can not use `AccessSets` because this pass operates on
# `ControlFlowBlock`s, which might consist of multiple states. Thus we are
# using `FindAccessStates` which has this `SDFGState` granularity. However,
# we have to determine if it is a write or not.
# using `FindAccessStates` which has this `SDFGState` granularity. The downside
# is, however, that we have to determine if the access in that state is a
# write or not, which means we have to find it first.
access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][
sdfg.cfg_id
]

# For speeding up the `is_accessed_downstream()` calls.
reachable: dict[dace.SDFGState, set[dace.SDFGState]] = pipeline_results[
"StateReachability"
][sdfg.cfg_id]

result: dict[dace.SDFGState, set[str]] = collections.defaultdict(set)

to_relocate = self._find_candidates(sdfg, reachable, access_states)
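
To illustrate the extra work the NOTE refers to, a hedged sketch of how an access found via `FindAccessStates` could be classified as a write; the helper name is hypothetical and the real logic lives in `_find_candidates()`:

def _is_written_in_state(state: dace.SDFGState, data: str) -> bool:
    # Hypothetical helper: an AccessNode with incoming edges receives data,
    # i.e. the state writes to `data`; otherwise the access is a pure read.
    return any(
        dnode.data == data and state.in_degree(dnode) != 0
        for dnode in state.data_nodes()
    )
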
@@ -834,9 +830,6 @@ def __init__(
def expressions(cls) -> Any:
return [dace.sdfg.utils.node_path_graph(cls.tasklet, cls.access_node, cls.map_entry)]

def depends_on(self) -> set[type[dace_transformation.Pass]]:
return {dace_transformation.passes.StateReachability}

def can_be_applied(
self,
graph: dace.SDFGState | dace.SDFG,
@@ -955,18 +948,11 @@ def apply(
# The data is no longer referenced in this state, so we can potentially
# remove it.
if graph.out_degree(access_node) == 0:
reachable_states: Optional[dict[dace.SDFGState, set[dace.SDFGState]]] = None
if self._pipeline_results and "StateReachability" in self._pipeline_results:
warnings.warn(
"The 'StateReachability' analysis pass was not part of the pipeline results.",
stacklevel=0,
)
reachable_states = self._pipeline_results["StateReachability"]

# TODO(phimuell): Use the pipeline to run `StateReachability` once.
if not gtx_transformations.utils.is_accessed_downstream(
start_state=graph,
sdfg=sdfg,
reachable_states=reachable_states,
reachable_states=None,
data_to_look=access_node.data,
nodes_to_ignore={access_node},
):
@@ -1027,6 +1013,7 @@ class GT4PyMapBufferElimination(dace_transformation.SingleStateTransformation):
Todo:
- Implement a real pointwise test.
- Run this inside a pipeline.
"""

map_exit = dace_transformation.PatternNode(dace_nodes.MapExit)
@@ -1054,10 +1041,7 @@ def expressions(cls) -> Any:
return [dace.sdfg.utils.node_path_graph(cls.map_exit, cls.tmp_ac, cls.glob_ac)]

def depends_on(self) -> set[type[dace_transformation.Pass]]:
return {
dace_transformation.passes.ConsolidateEdges,
dace_transformation.passes.analysis.StateReachability,
}
return {dace_transformation.passes.ConsolidateEdges}

def can_be_applied(
self,
@@ -1066,14 +1050,6 @@ def can_be_applied(
sdfg: dace.SDFG,
permissive: bool = False,
) -> bool:
reachable_states: Optional[dict[dace.SDFGState, set[dace.SDFGState]]] = None
if self._pipeline_results and "StateReachability" in self._pipeline_results:
warnings.warn(
"The 'StateReachability' analysis pass was not part of the pipeline results.",
stacklevel=0,
)
reachable_states = self._pipeline_results["StateReachability"]

tmp_ac: dace_nodes.AccessNode = self.tmp_ac
glob_ac: dace_nodes.AccessNode = self.glob_ac
tmp_desc: dace_data.Data = tmp_ac.desc(sdfg)
Expand Down Expand Up @@ -1101,10 +1077,12 @@ def can_be_applied(
# Test if `tmp` is used anywhere else; this is important for removing it.
if graph.out_degree(tmp_ac) != 1:
return False
# TODO(phimuell): Use the pipeline system to run the `StateReachability` pass
# only once, taking care of DaCe issue 1911.
if gtx_transformations.utils.is_accessed_downstream(
start_state=graph,
sdfg=sdfg,
reachable_states=reachable_states,
reachable_states=None,
data_to_look=tmp_ac.data,
nodes_to_ignore={tmp_ac},
):
@@ -192,6 +192,8 @@ def gt_propagate_strides_from_access_node(
processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored.
Only specify it when you know what you are doing.
"""
assert isinstance(state, dace.SDFGState)

if processed_nsdfgs is None:
# For preventing the case that nested SDFGs are handled multiple times.
processed_nsdfgs = set()
@@ -631,6 +633,7 @@ def _gt_find_toplevel_data_accesses(
not_top_level_data: set[str] = set()

for state in sdfg.states():
assert isinstance(state, dace.SDFGState)
scope_dict = state.scope_dict()
for dnode in state.data_nodes():
data: str = dnode.data
@@ -171,9 +171,18 @@ def is_accessed_downstream(
`reachable_states` argument.
- Fix the behaviour for `states_to_ignore`.
"""
# DaCe 1 switched to a hierarchical version of the state machine. Thus it is
# no longer possible to traverse the SDFG in a simple way. As a temporary
# solution we use the `StateReachability` pass. However, this has some issues;
# see the note about `states_to_ignore`.
if reachable_states is None:
state_reachability_pass = dace_analysis.StateReachability()
reachable_states = state_reachability_pass.apply_pass(sdfg, None)[sdfg.cfg_id]
else:
# Ensures that the externally generated result was passed properly.
assert all(
isinstance(state, dace.SDFGState) and state.sdfg is sdfg for state in reachable_states
)

ign_dnodes: set[dace_nodes.AccessNode] = nodes_to_ignore or set()
ign_states: set[dace.SDFGState] = states_to_ignore or set()
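
A hedged usage sketch of the new fallback: a caller that issues many queries can still compute the reachability once and pass it in, which is what the TODOs elsewhere in this commit point towards (the `candidates` loop is purely illustrative):

# Illustrative only: precompute once, reuse for several queries.
reachable = dace_analysis.StateReachability().apply_pass(sdfg, None)[sdfg.cfg_id]
for state, node in candidates:  # `candidates` is a hypothetical iterable
    if is_accessed_downstream(
        start_state=state,
        sdfg=sdfg,
        reachable_states=reachable,
        data_to_look=node.data,
        nodes_to_ignore={node},
    ):
        ...
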

