-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix[next][dace]: make if_ always execute branch exclusively #1846
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good, there are some small issues that have to be clarified.
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
for shift_node in eve.walk_values(expr).filter(lambda x: cpm.is_applied_shift(x)): | ||
shifted_iterators |= ( | ||
eve.walk_values(shift_node) | ||
.if_isinstance(gtir.SymRef) | ||
.map(lambda x: str(x.id)) | ||
.filter(lambda x: isinstance(self.symbol_map.get(x, None), IteratorExpr)) | ||
.to_set() | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not fully remembering, but we discussed something with Till and there you explained some heuristic and (as far as I can remember) his assessment of it was that it is most likely fine but not in general.
So now, is this a new heuristic?
I am aware that the problem is now different, as we always generate exclusive if
s and to me it looks fine, I just want to ask how the following cases are handled:
- Reduction, especially with skip values?
- dymref?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The cases you ask about are not directly addressed (and I don't even recognize dymref). The PR deals with how to pass iterator expressions into a nested SDFG containing a ConditionalBlock. Note that an iterator needs to be dereferenced, in order to provide access to a grid element. Therefore, the iterator deref
is present inside the true and/or false branch expression. In baseline, the lowering would do a 1:1 translation to SDFG, that would imply to pass the full array shape to the nested SDFG and implement the deref
memlet inside. For simple deref
(by simple I mean direct access, without shifting the iterator) this PR is moving the element access memlet outside the nested SDFG.
Now, about the heuristic question. Till's concern was about implementing if_
as a tasklet, that is as a non-exclusive if_
operator. Now that is out of scope. The heuristic I implement here is how to detect indirect iterator deref
, that is iterators that are shifted before being dereferenced. I would not even call it a heuristic. I search for all shift expressions inside the true/false branch and what iterator they apply to. This search provides shifted_iterators
: a list of iterators that are shifted, in one or both the true/false branches, so they have to be passed with full array shape.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay thanks for the explanation.
With dymref
I meant the runtime deref (you create tasklets with the name tlet_{NUMBER}_runtime_deref
) are they detected or are they just a shift combined with a deref?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are just shifts. I will change the name of those tasklets.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I should have written this is just a deref
tasklet. When the indices are all dace symbolic expressions, I lower the deref
as a memlet, where the index is the memlet subset. When any of the indices is a runtime value (either a dynamic cartesian offset or a connectivity offset), the deref
is lowered to a tasklet.
for shift_node in eve.walk_values(expr).filter(lambda x: cpm.is_applied_shift(x)): | ||
shifted_iterators |= ( | ||
eve.walk_values(shift_node) | ||
.if_isinstance(gtir.SymRef) | ||
.map(lambda x: str(x.id)) | ||
.filter(lambda x: isinstance(self.symbol_map.get(x, None), IteratorExpr)) | ||
.to_set() | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The cases you ask about are not directly addressed (and I don't even recognize dymref). The PR deals with how to pass iterator expressions into a nested SDFG containing a ConditionalBlock. Note that an iterator needs to be dereferenced, in order to provide access to a grid element. Therefore, the iterator deref
is present inside the true and/or false branch expression. In baseline, the lowering would do a 1:1 translation to SDFG, that would imply to pass the full array shape to the nested SDFG and implement the deref
memlet inside. For simple deref
(by simple I mean direct access, without shifting the iterator) this PR is moving the element access memlet outside the nested SDFG.
Now, about the heuristic question. Till's concern was about implementing if_
as a tasklet, that is as a non-exclusive if_
operator. Now that is out of scope. The heuristic I implement here is how to detect indirect iterator deref
, that is iterators that are shifted before being dereferenced. I would not even call it a heuristic. I search for all shift expressions inside the true/false branch and what iterator they apply to. This search provides shifted_iterators
: a list of iterators that are shifted, in one or both the true/false branches, so they have to be passed with full array shape.
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It generally looks good to me, there is however style of question of the runtime deref.
I see that there is now one access, however, it does not fuse yet.
I will look into that to see where (either here or in MapFusion) we have to change something.
@@ -613,35 +614,56 @@ def _visit_if_branch_arg( | |||
if_branch_state: The state inside the nested SDFG where the if branch is lowered. | |||
param_name: The parameter name of the input argument. | |||
arg: The input argument expression. | |||
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. | |
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. This means that the values are copied into a temporary storage which is passed into the nested SDFG. |
I think the description of this option should have a bit more information.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change you propose is not correct. We are not allocating temporary storage, we are just narrowing the memlet subset.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
The issue I observed is somehow related to substitute_compiletime_symbols()
, however, I have no idea how.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Too fast.
This PR reverts the change previously made in #1824. Lowering
if_
expressions to tasklet is semantically wrong, from a dataflow perspective. It causes segmentation faults in several stencils, that rely on exclusive branch execution.The source problem was that full array shape was passed into the nested SDFG scope, which prevented map fusion in most cases. This PR extends the lowering with the detection of simple iterator dereferencing, without shifts: for this type of data access, only the local element is moved into the nested SDFG. However, when shift is applied on the iterator input (which typically happens in iterator view), the full array shape is still passed.
This approach increases the optimization opportunities by enabling more map fusion. At the same time, it keeps the
if_
semantics of exclusive branch execution.