From 1b882761a58d0a58d4456dfa950e8a3c68b9b114 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20M=C3=BCller?=
 <147368808+philip-paul-mueller@users.noreply.github.com>
Date: Thu, 16 Jan 2025 11:14:07 +0100
Subject: [PATCH] fix[dace][next]: Fix for DistributedBufferRelocator (#1799)

This PR fixes an error that was reported by Edoardo (@edopao).

The bug occurred because the `DistributedBufferRelocator` transformation did
not check if its insertion would create a read-write conflict. This commit
adds such a check; it is, however, not very sophisticated and needs some
improvements.

However, the example
(`model/atmosphere/dycore/tests/dycore_stencil_tests/test_compute_exner_from_rhotheta.py`)
where it surfaced holds more challenges. The main purpose of this PR is to
unblock further development in ICON4Py.

Link to ICON4Py PR: https://github.com/C2SM/icon4py/pull/638
---
 .../transformations/simplify.py               | 254 ++++++++++++++----
 .../test_distributed_buffer_relocator.py      | 217 ++++++++++++++-
 2 files changed, 406 insertions(+), 65 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py
index 4339a761fa..bb95244aef 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py
@@ -374,7 +374,7 @@ def apply(
         raise
 
 
-AccessLocation: TypeAlias = tuple[dace.SDFGState, dace_nodes.AccessNode]
+AccessLocation: TypeAlias = tuple[dace_nodes.AccessNode, dace.SDFGState]
 """Describes an access node and the state in which it is located.
 """
 
@@ -387,29 +387,38 @@ class DistributedBufferRelocator(dace_transformation.Pass):
     in each branch and then in the join state written back. Thus there is some
     additional storage needed.
     The transformation will look for the following situation:
-    - A transient data container, called `src_cont`, is written into another
-        container, called `dst_cont`, which is not transient.
-    - The access node of `src_cont` has an in degree of zero and an out degree of one.
-    - The access node of `dst_cont` has an in degree of of one and an
+    - A transient data container, called `temp_storage`, is written into another
+        container, called `dest_storage`, which is not transient.
+    - The access node of `temp_storage` has an in degree of zero and an out degree of one.
+    - The access node of `dest_storage` has an in degree of one and an
         out degree of zero (this might be lifted).
-    - `src_cont` is not used afterwards.
-    - `dst_cont` is only used to implement the buffering.
+    - `temp_storage` is not used afterwards.
+    - `dest_storage` is only used to implement the buffering.
 
-    The function will relocate the writing of `dst_cont` to where `src_cont` is
+    The function will relocate the writing of `dest_storage` to where `temp_storage` is
     written, which might be multiple locations.
     It will also remove the writing back.
     It is advised that after this transformation simplify is run again.
 
+    The relocation will not take place if it might create a data race. A necessary
+    but not sufficient condition for a data race is that `dest_storage` is present
+    in the state where `temp_storage` is defined. In addition, at least one of the
+    following conditions has to be met:
+    - There are accesses to `dest_storage` that are not predecessors of the node where
+        the data is stored inside `temp_storage`. This check will ignore empty Memlets.
+    - There is a `dest_storage` access node that has an output degree larger
+        than one.
+
     Note:
-        Essentially this transformation removes the double buffering of `dst_cont`.
-        Because we ensure that that `dst_cont` is non transient this is okay, as our
-        rule guarantees this.
+        - Essentially this transformation removes the double buffering of
+            `dest_storage`. Because we ensure that `dest_storage` is non
+            transient this is okay, as our rule guarantees this.
 
     Todo:
-        - Allow that `dst_cont` can also be transient.
-        - Allow that `dst_cont` does not need to be a sink node, this is most
+        - Allow that `dest_storage` can also be transient.
+        - Allow that `dest_storage` does not need to be a sink node, this is most
             likely most relevant if it is transient.
-        - Check if `dst_cont` is used between where we want to place it and
+        - Check if `dest_storage` is used between where we want to place it and
             where it is currently used.
     """
 
@@ -489,10 +498,10 @@ def _find_candidates(
             where the temporary is defined.
         """
         # All nodes that are used as distributed buffers.
-        candidate_src_cont: list[AccessLocation] = []
+        candidate_temp_storage: list[AccessLocation] = []
 
-        # Which `src_cont` access node is written back to which global memory.
-        src_cont_to_global: dict[dace_nodes.AccessNode, str] = {}
+        # Which `temp_storage` access node is written back to which global memory.
+        temp_storage_to_global: dict[dace_nodes.AccessNode, str] = {}
 
         for state in sdfg.states():
             # These are the possible targets we want to write into.
@@ -508,26 +517,26 @@ def _find_candidates(
             if len(candidate_dst_nodes) == 0:
                 continue
 
-            for src_cont in state.source_nodes():
-                if not isinstance(src_cont, dace_nodes.AccessNode):
+            for temp_storage in state.source_nodes():
+                if not isinstance(temp_storage, dace_nodes.AccessNode):
                     continue
-                if not src_cont.desc(sdfg).transient:
+                if not temp_storage.desc(sdfg).transient:
                     continue
-                if state.out_degree(src_cont) != 1:
+                if state.out_degree(temp_storage) != 1:
                     continue
                 dst_candidate: dace_nodes.AccessNode = next(
-                    iter(edge.dst for edge in state.out_edges(src_cont))
+                    iter(edge.dst for edge in state.out_edges(temp_storage))
                 )
                 if dst_candidate not in candidate_dst_nodes:
                     continue
-                candidate_src_cont.append((src_cont, state))
-                src_cont_to_global[src_cont] = dst_candidate.data
+                candidate_temp_storage.append((temp_storage, state))
+                temp_storage_to_global[temp_storage] = dst_candidate.data
 
-        if len(candidate_src_cont) == 0:
+        if len(candidate_temp_storage) == 0:
             return []
 
         # Now we have to find the places where the temporary sources are defined.
-        #  I.e. This is also the location where the original value is defined.
+        #  I.e., this is also the location where the temporary source was initialized.
         result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = []
 
         def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
@@ -537,72 +546,199 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
                 if dst_state in reachable[src_state] and dst_state is not src_state
             }
 
-        for src_cont in candidate_src_cont:
+        for temp_storage in candidate_temp_storage:
+            temp_storage_node, temp_storage_state = temp_storage
             def_locations: list[AccessLocation] = []
-            for upstream_state in find_upstream_states(src_cont[1]):
-                if src_cont[0].data in access_sets[upstream_state][1]:
+            for upstream_state in find_upstream_states(temp_storage_state):
+                if temp_storage_node.data in access_sets[upstream_state][1]:
                     def_locations.extend(
                         (data_node, upstream_state)
                         for data_node in upstream_state.data_nodes()
-                        if data_node.data == src_cont[0].data
+                        if data_node.data == temp_storage_node.data
                     )
             if len(def_locations) != 0:
-                result_candidates.append((src_cont, def_locations))
+                result_candidates.append((temp_storage, def_locations))
 
-        # This transformation removes `src_cont` by writing its content directly
-        # to `dst_cont`, at the point where it is defined.
+        # This transformation removes `temp_storage` by writing its content directly
+        # to `dest_storage`, at the point where it is defined.
         # For this transformation to be valid the following conditions have to be met:
-        # - Between the definition of `src_cont` and the write back to `dst_cont`,
-        #   `dst_cont` can not be accessed.
-        # - Between the definitions of `src_cont` and the point where it is written
-        #   back, `src_cont` can only be accessed in the range that is written back.
-        # - After the write back point, `src_cont` shall not be accessed. This
+        # - Between the definition of `temp_storage` and the write back to `dest_storage`,
+        #   `dest_storage` can not be accessed.
+        # - Between the definitions of `temp_storage` and the point where it is written
+        #   back, `temp_storage` can only be accessed in the range that is written back.
+        # - After the write back point, `temp_storage` shall not be accessed. This
        #   restriction could be lifted.
         #
         # To keep the implementation simple, we use the conditions:
-        # - `src_cont` is only accessed were it is defined and at the write back
+        # - `temp_storage` is only accessed where it is defined and at the write back
         #   point.
-        # - Between the definitions of `src_cont` and the write back point,
-        #   `dst_cont` is not used.
+        # - Between the definitions of `temp_storage` and the write back point,
+        #   `dest_storage` is not used.
 
         result: list[tuple[AccessLocation, list[AccessLocation]]] = []
 
-        for wb_localation, def_locations in result_candidates:
+        for wb_location, def_locations in result_candidates:
+            # Get the state and the location where the temporary is written back
+            # into the global data container.
+            wb_node, wb_state = wb_location
+
             for def_node, def_state in def_locations:
-                # Test if `src_cont` is only accessed where it is defined and
+                # Test if `temp_storage` is only accessed where it is defined and
                 # where it is written back.
                 if gtx_transformations.util.is_accessed_downstream(
                     start_state=def_state,
                     sdfg=sdfg,
-                    data_to_look=wb_localation[0].data,
-                    nodes_to_ignore={def_node, wb_localation[0]},
+                    data_to_look=wb_node.data,
+                    nodes_to_ignore={def_node, wb_node},
                 ):
                     break
                 # check if the global data is not used between the definition of
-                # `dst_cont` and where its written back. We allow one exception,
-                # if the global data is used in the state the distributed temporary
-                # is defined is used only for reading then it is ignored. This is
-                # allowed because of rule 3 of ADR0018.
-                glob_nodes_in_def_state = {
-                    dnode
-                    for dnode in def_state.data_nodes()
-                    if dnode.data == src_cont_to_global[wb_localation[0]]
+                # `dest_storage` and where it is written back. However, we ignore
+                # the state where `temp_storage` is defined. These checks are
+                # performed by the `_check_read_write_dependency()` function.
+                global_data_name = temp_storage_to_global[wb_node]
+                global_nodes_in_def_state = {
+                    dnode for dnode in def_state.data_nodes() if dnode.data == global_data_name
                 }
-                if any(def_state.in_degree(gdnode) != 0 for gdnode in glob_nodes_in_def_state):
-                    break
                 if gtx_transformations.util.is_accessed_downstream(
                     start_state=def_state,
                     sdfg=sdfg,
-                    data_to_look=src_cont_to_global[wb_localation[0]],
-                    nodes_to_ignore=glob_nodes_in_def_state,
-                    states_to_ignore={wb_localation[1]},
+                    data_to_look=global_data_name,
+                    nodes_to_ignore=global_nodes_in_def_state,
+                    states_to_ignore={wb_state},
                 ):
                     break
+                if self._check_read_write_dependency(sdfg, wb_location, def_locations):
+                    break
             else:
-                result.append((wb_localation, def_locations))
+                result.append((wb_location, def_locations))
 
         return result
 
+    def _check_read_write_dependency(
+        self,
+        sdfg: dace.SDFG,
+        write_back_location: AccessLocation,
+        target_locations: list[AccessLocation],
+    ) -> bool:
+        """Tests if read-write conflicts would be created.
+
+        This function ensures that the substitution of `write_back_location` into
+        `target_locations` will not create a read-write conflict.
+        The rules that are used for this are outlined in the class description.
+
+        Args:
+            sdfg: The SDFG on which we operate.
+            write_back_location: Where currently the write back occurs.
+            target_locations: List of the locations where we would like to perform
+                the write back instead.
+
+        Returns:
+            `True` if a read-write conflict was detected and `False` if none was
+            detected.
+        """
+        for target_location in target_locations:
+            if self._check_read_write_dependency_impl(sdfg, write_back_location, target_location):
+                return True
+        return False
+
+    def _check_read_write_dependency_impl(
+        self,
+        sdfg: dace.SDFG,
+        write_back_location: AccessLocation,
+        target_location: AccessLocation,
+    ) -> bool:
+        """Tests if a read-write conflict would be created for a single location.
+
+        Args:
+            sdfg: The SDFG on which we operate.
+            write_back_location: Where currently the write back occurs.
+            target_location: Location where the new write back should be performed.
+
+        Todo:
+            Refine these checks later.
+
+        Returns:
+            `True` if a read-write conflict was detected and `False` if none was
+            detected.
+        """
+        assert write_back_location[0].data == target_location[0].data
+
+        # Get the state and the location where the temporary is written back
+        # into the global data container. Because `write_back_node` refers to
+        # the temporary, we must query the graph to find the global node.
+        write_back_node, write_back_state = write_back_location
+        write_back_edge = next(iter(write_back_state.out_edges(write_back_node)))
+        global_data_name = write_back_edge.dst.data
+        assert not sdfg.arrays[global_data_name].transient
+        assert write_back_state.out_degree(write_back_node) == 1
+        assert write_back_state.in_degree(write_back_node) == 0
+
+        # Get the location and the state where the temporary is originally defined.
+        def_location_of_intermediate, state_to_inspect = target_location
+        assert state_to_inspect.out_degree(def_location_of_intermediate) == 0
+
+        # These are all access nodes that refer to the global data, which we want
+        # to move into the state `state_to_inspect`. We need them to do the
+        # second test.
+        accesses_to_global_data: set[dace_nodes.AccessNode] = set()
+
+        # In the first check we look for an access node, to the global data, that
+        # has an output degree larger than one. However, for this we ignore all
+        # empty Memlets. This is done because such Memlets are used to induce a
+        # schedule or order in the dataflow graph.
+        # As a byproduct, for the second test, we also collect all of these nodes.
+        for dnode in state_to_inspect.data_nodes():
+            if dnode.data != global_data_name:
+                continue
+            dnode_degree = sum(
+                (1 for oedge in state_to_inspect.out_edges(dnode) if not oedge.data.is_empty())
+            )
+            if dnode_degree > 1:
+                return True
+            # TODO(phimuell): Maybe AccessNodes with zero input degree should be ignored.
+            accesses_to_global_data.add(dnode)
+
+        # There is no reference to the global data, so no need to do more tests.
+        if len(accesses_to_global_data) == 0:
+            return False
+
+        # For the second test we will explore the dataflow graph, in reverse order,
+        # starting from the definition of the temporary node. If we find an access
+        # to the global data we remove it from the `accesses_to_global_data` set.
+        # If the set does not become empty, then we know that there is some kind of
+        # branch (or concurrent dataflow) in this state that accesses the global
+        # data and we will have read-write conflicts.
+        # It is, however, important to realize that passing this check does not
+        # imply that there are no read-write conflicts. We assume here that all
+        # accesses to the global data that were made before the write back were
+        # constructed in a correct way.
+        to_process: list[dace_nodes.Node] = [def_location_of_intermediate]
+        seen: set[dace_nodes.Node] = set()
+        while len(to_process) != 0:
+            node = to_process.pop()
+            seen.add(node)
+
+            if isinstance(node, dace_nodes.AccessNode):
+                if node.data == global_data_name:
+                    accesses_to_global_data.discard(node)
+                    if len(accesses_to_global_data) == 0:
+                        return False
+
+            # Note that we only explore the incoming edges, thus we will not necessarily
+            # explore the whole graph. However, this is fine, because we will see the
+            # relevant parts. To see this, assume that we would also have to check the
+            # outgoing edges; this would mean that there was some branching point,
+            # which is a serialization point, so the dataflow would have been invalid
+            # before.
+            to_process.extend(
+                iedge.src for iedge in state_to_inspect.in_edges(node) if iedge.src not in seen
+            )
+
+        assert len(accesses_to_global_data) > 0
+        return True
+
 
 @dace_properties.make_properties
 class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation):
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
index 1543a048ad..d61b8a2d42 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
@@ -13,7 +13,7 @@
     transformations as gtx_transformations,
 )
 
-# from . import util
+from . import util
 
 
 # dace = pytest.importorskip("dace")
@@ -21,8 +21,8 @@
 import dace
 
 
-def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]:
-    sdfg = dace.SDFG("NAME")  # util.unique_name("distributed_buffer_sdfg"))
+def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]:
+    sdfg = dace.SDFG(util.unique_name("distributed_buffer_sdfg"))
 
     for name in ["a", "b", "tmp"]:
         sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False)
@@ -66,19 +66,224 @@ def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]:
     sdfg.validate()
     assert sdfg.number_of_nodes() == 3
 
-    return sdfg, state1
+    return sdfg, state1, state3
 
 
 def test_distributed_buffer_remover():
-    sdfg, state1 = _mk_distributed_buffer_sdfg()
+    sdfg, state1, state3 = _mk_distributed_buffer_sdfg()
     assert state1.number_of_nodes() == 5
     assert not any(dnode.data == "b" for dnode in state1.data_nodes())
 
     res = gtx_transformations.gt_reduce_distributed_buffering(sdfg)
-    assert res is not None
+    assert res[sdfg]["DistributedBufferRelocator"][state3] == {"tmp"}
 
     # Because the final state has now become empty
     assert sdfg.number_of_nodes() == 3
     assert state1.number_of_nodes() == 6
     assert any(dnode.data == "b" for dnode in state1.data_nodes())
     assert any(dnode.data == "tmp" for dnode in state1.data_nodes())
+
+
+def _make_distributed_buffer_global_memory_data_race_sdfg() -> tuple[dace.SDFG, dace.SDFGState]:
+    sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race"))
+    arr_names = ["a", "b", "t"]
+    for name in arr_names:
+        sdfg.add_array(
+            name=name,
+            shape=(10, 10),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.arrays["t"].transient = True
+
+    state1 = sdfg.add_state(is_start_block=True)
+    state2 = sdfg.add_state_after(state1)
+
+    a_state1 = state1.add_access("a")
+    state1.add_mapped_tasklet(
+        "computation",
+        map_ranges={"__i0": "0:10", "__i1": "0:10"},
+        inputs={"__in": dace.Memlet("a[__i0, __i1]")},
+        code="__out = __in + 10",
+        outputs={"__out": dace.Memlet("t[__i0, __i1]")},
+        input_nodes={a_state1},
+        external_edges=True,
+    )
+    state1.add_nedge(a_state1, state1.add_access("b"), dace.Memlet("a[0:10, 0:10]"))
+
+    state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]"))
+    sdfg.validate()
+
+    return sdfg, state2
+
+
+def test_distributed_buffer_global_memory_data_race():
+    """Tests if the transformation realizes that it would create a data race.
+
+    If the transformation were applied, then `a` would be read twice, from two
+    different branches, whose order of execution is indeterminate.
+    """
+    sdfg, state2 = _make_distributed_buffer_global_memory_data_race_sdfg()
+    assert state2.number_of_nodes() == 2
+
+    sdfg.simplify()
+    assert sdfg.number_of_nodes() == 2
+
+    res = gtx_transformations.gt_reduce_distributed_buffering(sdfg)
+    assert "DistributedBufferRelocator" not in res[sdfg]
+    assert state2.number_of_nodes() == 2
+
+
+def _make_distributed_buffer_global_memory_data_race_sdfg2() -> (
+    tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]
+):
+    sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race2_sdfg"))
+    arr_names = ["a", "b", "t"]
+    for name in arr_names:
+        sdfg.add_array(
+            name=name,
+            shape=(10, 10),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.arrays["t"].transient = True
+
+    state1 = sdfg.add_state(is_start_block=True)
+    state2 = sdfg.add_state_after(state1)
+
+    state1.add_mapped_tasklet(
+        "computation1",
+        map_ranges={"__i0": "0:10", "__i1": "0:10"},
+        inputs={"__in": dace.Memlet("a[__i0, __i1]")},
+        code="__out = __in + 10",
+        outputs={"__out": dace.Memlet("t[__i0, __i1]")},
+        external_edges=True,
+    )
+    state1.add_mapped_tasklet(
+        "computation2",
+        map_ranges={"__i0": "0:10", "__i1": "0:10"},
+        inputs={"__in": dace.Memlet("a[__i0, __i1]")},
+        code="__out = __in - 10",
+        outputs={"__out": dace.Memlet("b[__i0, __i1]")},
+        external_edges=True,
+    )
+    state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]"))
+    sdfg.validate()
+
+    return sdfg, state1, state2
+
+
+def test_distributed_buffer_global_memory_data_race2():
+    """Tests if the transformation realizes that it would create a data race.
+
+    Similar situation as above, but now the two accesses are in two separate
+    subgraphs. This is needed because a different branch of the check detects it.
+    """
+    sdfg, state1, state2 = _make_distributed_buffer_global_memory_data_race_sdfg2()
+    assert state1.number_of_nodes() == 10
+    assert state2.number_of_nodes() == 2
+
+    res = gtx_transformations.gt_reduce_distributed_buffering(sdfg)
+    assert "DistributedBufferRelocator" not in res[sdfg]
+    assert state1.number_of_nodes() == 10
+    assert state2.number_of_nodes() == 2
+
+
+def _make_distributed_buffer_global_memory_data_no_race() -> tuple[dace.SDFG, dace.SDFGState]:
+    sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_race_sdfg"))
+    arr_names = ["a", "t"]
+    for name in arr_names:
+        sdfg.add_array(
+            name=name,
+            shape=(10, 10),
+            dtype=dace.float64,
+            transient=False,
+        )
+    sdfg.arrays["t"].transient = True
+
+    state1 = sdfg.add_state(is_start_block=True)
+    state2 = sdfg.add_state_after(state1)
+
+    a_state1 = state1.add_access("a")
+    state1.add_mapped_tasklet(
+        "computation",
+        map_ranges={"__i0": "0:10", "__i1": "0:10"},
+        inputs={"__in": dace.Memlet("a[__i0, __i1]")},
+        code="__out = __in + 10",
+        outputs={"__out": dace.Memlet("t[__i0, __i1]")},
+        input_nodes={a_state1},
+        external_edges=True,
+    )
+
+    state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]"))
+    sdfg.validate()
+
+    return sdfg, state2
+
+
+def test_distributed_buffer_global_memory_data_no_race():
+    """Transformation applies if there is no data race.
+
+    According to ADR18, pointwise dependencies are fine. This test checks that
+    the read-write conflict checks are not too strong.
+ """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_no_rance() + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} + assert state2.number_of_nodes() == 0 + + +def _make_distributed_buffer_global_memory_data_no_rance2() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance2_sdfg")) + arr_names = ["a", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + a_state1 = state1.add_access("a") + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("a[__i0, __i1]")}, + output_nodes={a_state1}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "computation2", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + input_nodes={a_state1}, + external_edges=True, + ) + + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state2 + + +def test_distributed_buffer_global_memory_data_no_rance2(): + """Transformation applies if there is no data race. + + These dependency is fine, because the access nodes are in a clear serial order. + """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_no_rance2() + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} + assert state2.number_of_nodes() == 0