DID loop split on reshaped IDs #3875
Conversation
!test
!test
csrc/scheduler/utils.cpp (Outdated)

```cpp
// Reorder the reshape dimensions to the front of the domain
```
@naoyam I'm quite confused by this pre-existing logic, and I need to understand it before I can make sense of this PR. Why is it necessary to move the reshape dimensions to the front of the loop domain? It can conflict with the pre-existing assumption that DIDs have to be at the front as well.
I think this is related to the propagation done at line 2279. IIRC, it propagates the outermost N dimensions, where N is `old2new.size()` in this case. Since here we just want to propagate the transformations related to the rfactor, this is how we limit the propagation.

We can probably just reorder `tv` back after line 2279.
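To make this concrete, here is a rough sketch of the reorder-then-propagate pattern being discussed, assuming the `tv` and `old2new` map from the surrounding function; the `TransformPropagator`/`MaxLogicalDomainInfoSpanningTree` usage is written from memory and may not match utils.cpp exactly:

```cpp
// Hedged sketch (not the exact code in utils.cpp): reorder the reshape IDs to
// the front so that only they fall within the propagated position range.
tv->reorder(old2new);  // reshape IDs now occupy positions [0, old2new.size())

// Propagate only the outermost old2new.size() IDs, i.e. the reshape transforms.
TransformPropagator propagator(tv, static_cast<int64_t>(old2new.size()));
MaxLogicalDomainInfoSpanningTree(tv).traverse(&propagator);

// Possible follow-up suggested above: reorder tv back to its original order
// afterwards so the DID axis can stay at the front for the schedulers.
```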
Co-authored-by: Jingyue Wu <[email protected]>
csrc/scheduler/utils.cpp (Outdated)

```cpp
auto find_it = std::find(
    tv->getLoopDomain().begin(), tv->getLoopDomain().end(), logical_id);

// If not found directly and there is a sharded loop ID,
```
I think I see what the below part is trying to do and why, which seems to make sense, but can you expand the comment and elaborate a little more?
This bears several assumptions that will break in the foreseeable future. With context parallelism, the sequence dimension `s` will be split into `[tp, iDIDy{cp}, s/tp/cp]`, so the code below won't be able to find `tp` and `s/tp/cp`. Similarly, with overlapping, the sequence dimension `s` will be split into `[sp, iDID{tp}, s/sp/tp]`, where `sp` is the stream parallelization factor. See this test for the idea.

I understand this change does fix some narrow cases that we care about at this very moment, but I'll have to think more about how to fix the broader issue...
(I still haven't given up on improving ID model.)

If we have to do graph traversal like this, we may want to do it in a place where the logic can be generalized and reused (and therefore ID model). At this moment, there are two use cases:
- splitting reshape: `[h] => [d, h/d]` and `[h] => [d, a/d, h/a]`
- merging reshape: `[a, h/a] => [d, a/d, h/a]` and `[a, h/a] => [h] => [d, h/d]`

We want ID model to map the `d`s in both cases so these reshapes won't be considered resharding (a concrete sketch of the splitting case follows this comment).

How much harder is it to make ID model support these cases than working around it using reshape transformations? I suspect the latter has a bigger blast radius, because the former is local to ID model and the latter changes TensorViews.
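To make the splitting-reshape use case concrete, here is a minimal sketch in the style of the other scheduling snippets in this thread; the sizes `d`, `a`, `h` and the fusion setup are made up for illustration:

```cpp
// Hedged sketch of the splitting-reshape use case: [h] => [d, h/d] on the
// producer and [h] => (reshape) => [a, h/a] => [d, a/d, h/a] on the consumer.
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const int64_t d = 2, a = 8, h = 64;  // illustrative sizes; h divisible by a, a by d

TensorView* in = makeContigConcreteTensor({h});   // logical: [h]
TensorView* out = reshape(in, {h}, {a, h / a});   // logical: [a, h/a]
fusion->addInput(in);
fusion->addOutput(out);

auto mesh = DeviceMesh::createForNumDevices(d);
in->setDeviceMesh(mesh);
out->setDeviceMesh(mesh);

// Producer: [h] => [iDIDx{d}, h/d]
in->split(0, d, /*inner_split=*/false);
in->axis(0)->parallelize(ParallelType::DIDx);

// Consumer: [a, h/a] => [iDIDx{d}, a/d, h/a]
out->split(0, d, /*inner_split=*/false);
out->axis(0)->parallelize(ParallelType::DIDx);

// For this reshape not to be treated as resharding, ID model would have to
// map the two iDIDx{d} IDs even though they come from different splits.
```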
I realized the same limitation for DID loop split on slice:
```cpp
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const int d = communicator_->size();
const int64_t b = 2, s = 2, h = 4;

TensorView* in = makeContigConcreteTensor({b, s, 3 * d * h});
TensorView* out = slice(in, {0, 0, 0}, {b, s, d * h});
fusion->addInput(in);
fusion->addOutput(out);

auto mesh = DeviceMesh::createForNumDevices(d);
for (auto tv : {in, out}) {
  tv->setDeviceMesh(mesh);
  tv->split(-1, d, /*inner_split=*/false);
  tv->axis(-2)->parallelize(ParallelType::DIDx);
  tv->setAllocationDomain(tv->getLoopDomain(), true);
}
```
I was trying to manually handle the case of `SliceOp` in `hasDifferentShardings`, but it would make certain assumptions about the parallelization patterns and can easily break.
> I understand this change does fix some narrow cases that we care about at this very moment, but I'll have to think more about how to fix the broader issue...

Yes, I agree. I wanted to add an example to demonstrate how reshapes can be loop split, but it certainly does not cover all the cases.
> I still haven't given up on improving ID model.

(Sorry -- I wish I knew more about IdModel to be more constructive.)

Another use case to consider is manual sharding -- the user wants to manually shard a subset of TVs to improve perf when our sharding propagation is suboptimal. They may well annotate `[b,s,h] => (reshape) => [b,s,a,h/a]` as follows:

```
in:  [b,s,h]     => [b,s,d,h/d]
out: [b,s,a,h/a] => [b,s,d,a/d,h/a]
```

and expect nvFuser to recognize this reshape is local. In this case, it's hard to replay the reshape on the input because `h` there is already split by `d`.
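For concreteness, a minimal sketch of that manual annotation, again with made-up sizes and following the API style of the slice example above:

```cpp
// Hedged sketch of the manual-sharding annotation described above:
// in:  [b, s, h]      => loop [b, s, iDIDx{d}, h/d]
// out: [b, s, a, h/a] => loop [b, s, iDIDx{d}, a/d, h/a]
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const int64_t d = 2, b = 2, s = 4, a = 8, h = 64;  // illustrative sizes

TensorView* in = makeContigConcreteTensor({b, s, h});
TensorView* out = reshape(in, {b, s, h}, {b, s, a, h / a});
fusion->addInput(in);
fusion->addOutput(out);

auto mesh = DeviceMesh::createForNumDevices(d);
in->setDeviceMesh(mesh);
out->setDeviceMesh(mesh);

// in: split h by d => [b, s, iDIDx{d}, h/d]
in->split(-1, d, /*inner_split=*/false);
in->axis(-2)->parallelize(ParallelType::DIDx);

// out: split a by d => [b, s, iDIDx{d}, a/d, h/a]
out->split(2, d, /*inner_split=*/false);
out->axis(2)->parallelize(ParallelType::DIDx);

// The user would expect nvFuser to treat this reshape as local (non-resharding),
// even though the reshape cannot simply be replayed on the already-split input.
```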
csrc/scheduler/utils.cpp (Outdated)

```cpp
auto split = dynamic_cast<Split*>(
    tv->getLoopDomain().at(sharded_axis)->definition());
if (split != nullptr && split->in() == logical_id) {
  find_it = std::find(
```
While I understand your intention, I don't know the implications of doing this for TransformPropagator. E.g., with

`root=[b, s, h], logical=[b, s, a, h/a], loop=[b, s, d, a/d, h/a]`

this will move `a/d` and `h/a` to the front, so the new loop domain becomes `[a/d, h/a, b, s, d]`, and we later ask `TransformPropagator` to replay with `replayed_pos_` = 2. What is `TransformPropagator` supposed to do with that? The first two loop IDs (`a/d` and `h/a`) don't even form a split in this TV.

cc @naoyam
I see your point. In the case of reshape with DID loop split, we have already propagated the reshape upwards, so the TransformPropagator only reorders the axes when called later. In the absence of the earlier reshape propagation before the loop split, the behavior could be erroneous since they don't form a split.

Although, since the reshape has already been propagated and, as @naoyam mentioned above, the `tv` is reordered back, maybe this propagation can be skipped altogether. Let me think about it more and see what the schedulers expect from this `propagateReshapeTransform`.

However, this may not work for the manual sharding case you mentioned in the above comment.
One caveat for the above to hold: the reshape transform has to have been applied to all tensorviews prior to the reshape in the given fusion segment.

In the transformer forward:
- For the split-reshape: the order is linear -> slice -> reshape -> permute -> SDPA. Since linear will be its own segment, that leaves slice and reshape. I am uncertain whether these will form a single segment, since it can potentially depend on how the sharding on slice is represented.
- For the merge-reshape after SDPA: the order is SDPA -> permute -> reshape -> linear. Again, SDPA and linear will be their own segments.

More generally, the boundary up to which the reshape transform is propagated upwards is important, since we may see different patterns appear in other models.

FWIW, in my tests, I found that TransformPropagator can propagate the split-reshapes upwards after DID loop split as well, but we lose the DID parallelization. This might be similar to the comment here.

An orthogonal issue I see here is resharding at the boundary up to which the reshape has been propagated upwards. At that boundary, we will go from [h] -> [a, h/a] and should hit the same resharding error.
> but we lose the DID parallelization

This is expected, and it is the reason for `parallelizeAllLike` and functions like that.
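For readers following the thread, a minimal sketch of that pattern, assuming a `reference_tv` that carries the desired transforms and parallelization; the `scheduler_utils::parallelizeAllLike` argument list here is quoted from memory and may not be exact:

```cpp
// Hedged sketch: TransformPropagator replays the (reshape) transforms from a
// reference TensorView, but parallel types are re-applied separately.
TransformPropagator propagator(reference_tv);
MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator);

// Copy the reference's parallelization (here restricted to DIDx) onto the
// other TensorViews whose loop IDs map to the reference's.
scheduler_utils::parallelizeAllLike(
    reference_tv,
    /*pos=*/-1,
    /*selected_tvs=*/{},
    /*selected_parallel_types=*/{ParallelType::DIDx});
```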
@wujingyue should we move forward with this PR?
!test
csrc/scheduler/utils.cpp (Outdated)

```cpp
auto split = dynamic_cast<Split*>(
    tv->getLoopDomain().at(sharded_axis)->definition());
if (split != nullptr && split->in() == logical_id) {
  // Move the DIDx dimension to the front
```
Will it be more general to move all loop domains reachable from `logical_id` to the front? Still, DIDx needs to be at the very front. This way, you don't need to assume that DIDx is the immediate outer split of `logical_id`, and the code can probably be made simpler.
> Still, DIDx needs to be at the very front.

Is this for the schedulers? In that case, we don't have to worry about DID being at the front here. `reorderDIDToFront` is called after this function within the scheduler.

> Will it be more general to move all loop domains reachable from logical_id to the front?

Yes, that should work. Are you aware of any direct utilities for this? Otherwise, I should be able to use `getExprsBetween` to find the relevant transforms and find the loop IDs from their outputs.
> getExprsBetween

That's about right. I suspect `getInputsTo` would also work. I didn't try enough to understand their differences. Many of the graph traversal utilities seem to overlap and/or be redundant...
I am using `getExprsBetween` as follows:

```cpp
auto transforms = StmtSort::getExprsBetween(
    {logical_id}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
```

For the newly added test, the reshaped tensorview is `[i{b}, i{s}, i{a}, i{h/a}]` (I am using TE notation here; I use a different notation in the test to guarantee divisibility).

For the logical ID `i{a}`, I also see the split `h -> [a, h/a]`; I expected to only see the DID split. Similarly, for the logical ID `i{h/a}`, I expected to see no transforms/exprs in between since it is directly found, but I see both the `h -> [a, h/a]` and `a -> [d, a/d]` splits.

This is required since:
- I also have transforms that are not necessarily on the reshaped IDs (for example, in the test case `ViewWithSplit`, we will see the split creating DIDx even though it is not on a reshaped ID), and these should not be reordered or propagated.
- It is difficult to tell if there is at least one loop iterdomain reachable from a particular reshaped logical ID.

```
Logical ID: iS7{4}rf
Expr: Outer split: iS7{4}rf by factor 1 -> ideviceIdx.x13{1}, iS14{4}
Output: ideviceIdx.x13{1}
Output: iS14{4}
Expr: Outer split: iS6{48}rf by factor 4 -> iS7{4}rf, iS8{12}rf
Output: iS7{4}rf
Output: iS8{12}rf
```

Any suggestions on what I may be missing in this function? I have not used it from a specific ID like above, only between entire domains.
I should be using `DependencyCheck::getAllExprsBetween`! That gives me the expected transforms.
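For the record, a sketch of how the reachable IDs can be collected with it (the `DependencyCheck::getAllExprsBetween` signature is quoted from memory, and the variable names follow the snippets above):

```cpp
// Hedged sketch: gather every ID reachable from the reshaped logical ID by
// walking the exprs between it and the loop domain.
std::vector<Expr*> transforms = DependencyCheck::getAllExprsBetween(
    {logical_id},
    {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});

std::unordered_set<Val*> reachable_ids;
for (Expr* transform : transforms) {
  for (Val* output : transform->outputs()) {
    reachable_ids.insert(output);
  }
}
```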
TIL!
!test
!test
csrc/scheduler/utils.cpp (Outdated)

```cpp
}

bool has_reachable_loop_id = false;
for (auto id : reachable_ids) {
```
I don't know whether `reachable_ids` will be ordered -- it really depends on the implementation of `getAllExprsBetween`. Therefore, I'd instead loop over the loop domain and try to find a match in `reachable_ids` (which probably should be a set instead of a vector). It's roughly the same logic, but more deterministic and more aligned with the existing order in the loop domain.
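A sketch of the suggested structure, with `reachable_ids` as an `unordered_set` (building on the collection sketch earlier in the thread):

```cpp
// Hedged sketch: iterate the loop domain (whose order is well defined) and
// check membership in the set of reachable IDs, rather than iterating the set.
bool has_reachable_loop_id = false;
for (const auto loop_idx : c10::irange(std::ssize(tv->getLoopDomain()))) {
  IterDomain* loop_id = tv->getLoopDomain().at(loop_idx);
  if (reachable_ids.count(loop_id) == 0) {
    continue;
  }
  has_reachable_loop_id = true;
  // ... record loop_idx in old2new so this loop ID is reordered to the front ...
}
```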
LGTM overall
@wujingyue CI failures seem like script failures; I will re-run. The PR is ready for another review.
LGTM with comments
!test
!test
```diff
@@ -2257,7 +2257,8 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {
   }
 
   bool has_reachable_loop_id = false;
-  for (auto loop_idx : c10::irange(static_cast<int64_t>(tv->getLoopDomain().size()))) {
+  for (auto loop_idx :
+       c10::irange(static_cast<int64_t>(tv->getLoopDomain().size()))) {
```
```diff
-       c10::irange(static_cast<int64_t>(tv->getLoopDomain().size()))) {
+       c10::irange(std::ssize(tv->getLoopDomain()))) {
```

FYI. Don't bother changing this if you are about to submit the PR.
Will make this change in the other PR!
This PR updates `propagateReshapeTransform` to support DID loop split.

When the loop split is on the iterdomains being reshaped, the reshaped logical iterdomain is no longer present in the loop domain since it has been split. In this case, we check if there is a sharded loop ID and compare the reshaped logical iterdomain to the producer of this DID split.
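A consolidated sketch of that check, pieced together from the snippets quoted in the review above; how `sharded_axis` is located here (scanning for a device-parallel loop ID) and the completion of the `std::find` call are my assumptions, not necessarily the exact code:

```cpp
// Hedged sketch of the fallback described above. First, look for the reshaped
// logical ID directly in the loop domain.
auto find_it = std::find(
    tv->getLoopDomain().begin(), tv->getLoopDomain().end(), logical_id);

// If it is not found directly, look for a DID-parallelized loop ID whose
// defining outer split consumed the reshaped logical ID.
if (find_it == tv->getLoopDomain().end()) {
  for (const auto sharded_axis : c10::irange(std::ssize(tv->getLoopDomain()))) {
    IterDomain* loop_id = tv->getLoopDomain().at(sharded_axis);
    if (!loop_id->isDeviceDim()) {
      continue;
    }
    auto* split = dynamic_cast<Split*>(loop_id->definition());
    if (split != nullptr && split->in() == logical_id) {
      // The reshaped logical ID was DID-split; treat the DID output as the
      // match so the reorder/propagation logic can proceed.
      find_it = std::find(
          tv->getLoopDomain().begin(), tv->getLoopDomain().end(), loop_id);
      break;
    }
  }
}
```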