Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next][dace]: make if_ always execute branch exclusively #1846

Merged
merged 6 commits into from
Feb 7, 2025

Conversation

edopao
Copy link
Contributor

@edopao edopao commented Feb 5, 2025

This PR reverts the change previously made in #1824. Lowering if_ expressions to tasklet is semantically wrong, from a dataflow perspective. It causes segmentation faults in several stencils, that rely on exclusive branch execution.

The source problem was that full array shape was passed into the nested SDFG scope, which prevented map fusion in most cases. This PR extends the lowering with the detection of simple iterator dereferencing, without shifts: for this type of data access, only the local element is moved into the nested SDFG. However, when shift is applied on the iterator input (which typically happens in iterator view), the full array shape is still passed.

This approach increases the optimization opportunities by enabling more map fusion. At the same time, it keeps the if_ semantics of exclusive branch execution.

@edopao edopao changed the title fix[next][dace]: make if_ always execute branch exclusivly fix[next][dace]: make if_ always execute branch exclusively Feb 5, 2025
Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good, there are some small issues that have to be clarified.

Comment on lines 697 to 704
for shift_node in eve.walk_values(expr).filter(lambda x: cpm.is_applied_shift(x)):
shifted_iterators |= (
eve.walk_values(shift_node)
.if_isinstance(gtir.SymRef)
.map(lambda x: str(x.id))
.filter(lambda x: isinstance(self.symbol_map.get(x, None), IteratorExpr))
.to_set()
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not fully remembering, but we discussed something with Till and there you explained some heuristic and (as far as I can remember) his assessment of it was that it is most likely fine but not in general.
So now, is this a new heuristic?
I am aware that the problem is now different, as we always generate exclusive ifs and to me it looks fine, I just want to ask how the following cases are handled:

  • Reduction, especially with skip values?
  • dymref?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cases you ask about are not directly addressed (and I don't even recognize dymref). The PR deals with how to pass iterator expressions into a nested SDFG containing a ConditionalBlock. Note that an iterator needs to be dereferenced, in order to provide access to a grid element. Therefore, the iterator deref is present inside the true and/or false branch expression. In baseline, the lowering would do a 1:1 translation to SDFG, that would imply to pass the full array shape to the nested SDFG and implement the deref memlet inside. For simple deref (by simple I mean direct access, without shifting the iterator) this PR is moving the element access memlet outside the nested SDFG.

Now, about the heuristic question. Till's concern was about implementing if_ as a tasklet, that is as a non-exclusive if_ operator. Now that is out of scope. The heuristic I implement here is how to detect indirect iterator deref, that is iterators that are shifted before being dereferenced. I would not even call it a heuristic. I search for all shift expressions inside the true/false branch and what iterator they apply to. This search provides shifted_iterators: a list of iterators that are shifted, in one or both the true/false branches, so they have to be passed with full array shape.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay thanks for the explanation.
With dymref I meant the runtime deref (you create tasklets with the name tlet_{NUMBER}_runtime_deref) are they detected or are they just a shift combined with a deref?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are just shifts. I will change the name of those tasklets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I should have written this is just a deref tasklet. When the indices are all dace symbolic expressions, I lower the deref as a memlet, where the index is the memlet subset. When any of the indices is a runtime value (either a dynamic cartesian offset or a connectivity offset), the deref is lowered to a tasklet.

Comment on lines 697 to 704
for shift_node in eve.walk_values(expr).filter(lambda x: cpm.is_applied_shift(x)):
shifted_iterators |= (
eve.walk_values(shift_node)
.if_isinstance(gtir.SymRef)
.map(lambda x: str(x.id))
.filter(lambda x: isinstance(self.symbol_map.get(x, None), IteratorExpr))
.to_set()
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cases you ask about are not directly addressed (and I don't even recognize dymref). The PR deals with how to pass iterator expressions into a nested SDFG containing a ConditionalBlock. Note that an iterator needs to be dereferenced, in order to provide access to a grid element. Therefore, the iterator deref is present inside the true and/or false branch expression. In baseline, the lowering would do a 1:1 translation to SDFG, that would imply to pass the full array shape to the nested SDFG and implement the deref memlet inside. For simple deref (by simple I mean direct access, without shifting the iterator) this PR is moving the element access memlet outside the nested SDFG.

Now, about the heuristic question. Till's concern was about implementing if_ as a tasklet, that is as a non-exclusive if_ operator. Now that is out of scope. The heuristic I implement here is how to detect indirect iterator deref, that is iterators that are shifted before being dereferenced. I would not even call it a heuristic. I search for all shift expressions inside the true/false branch and what iterator they apply to. This search provides shifted_iterators: a list of iterators that are shifted, in one or both the true/false branches, so they have to be passed with full array shape.

Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It generally looks good to me, there is however style of question of the runtime deref.
I see that there is now one access, however, it does not fuse yet.
I will look into that to see where (either here or in MapFusion) we have to change something.

@@ -613,35 +614,56 @@ def _visit_if_branch_arg(
if_branch_state: The state inside the nested SDFG where the if branch is lowered.
param_name: The parameter name of the input argument.
arg: The input argument expression.
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet.
deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. This means that the values are copied into a temporary storage which is passed into the nested SDFG.

I think the description of this option should have a bit more information.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change you propose is not correct. We are not allocating temporary storage, we are just narrowing the memlet subset.

Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

The issue I observed is somehow related to substitute_compiletime_symbols(), however, I have no idea how.

Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too fast.

@edopao edopao merged commit 5c3393f into GridTools:main Feb 7, 2025
23 checks passed
@edopao edopao deleted the gtir-dace-exclusive_if branch February 7, 2025 10:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants