From c0bc17aa70e724e80d3be4370ef57e8aaaae0730 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Fri, 1 Nov 2024 15:58:15 +0100 Subject: [PATCH 1/2] Update dace/transformation/interstate/gpu_transform_sdfg.py Co-authored-by: Tal Ben-Nun --- dace/transformation/interstate/gpu_transform_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index b7aa3b708a..3f0248836e 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -170,7 +170,7 @@ def _get_marked_inputs_and_outputs(self, state, entry_node) -> list: def _output_or_input_is_marked_host(self, state, entry_node) -> bool: marked_accesses = self._get_marked_inputs_and_outputs(state, entry_node) - return (len(marked_accesses) > 0) + return len(marked_accesses) > 0 def apply(self, _, sdfg: sd.SDFG): From f1b81c40a3a5e237f0079e06fd8b97a5cff2ee79 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Fri, 1 Nov 2024 15:58:22 +0100 Subject: [PATCH 2/2] Update dace/transformation/interstate/gpu_transform_sdfg.py Co-authored-by: Tal Ben-Nun --- dace/transformation/interstate/gpu_transform_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 3f0248836e..823d599737 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -161,7 +161,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True def _get_marked_inputs_and_outputs(self, state, entry_node) -> list: - if (self.host_data is None or self.host_data == []) and (self.host_maps is None or self.host_maps == []): + if not self.host_data and not self.host_maps: return [] marked_sources = [state.memlet_tree(e).root().edge.src for e in state.in_edges(entry_node)] marked_destinations = [state.memlet_tree(e).root().edge.dst for e in state.in_edges(state.exit_node(entry_node))]