DID loop split on reshaped IDs #3875

Merged · Priya2698 merged 14 commits into main from pm/reshape_propagate on Mar 6, 2025
Conversation

@Priya2698 (Collaborator) commented Feb 11, 2025

This PR updates propagateReshapeTransform to support DID loop split.

When the loop split is applied to the iter domains being reshaped, the logical reshaped iter domain is no longer present in the loop domain, since it has been split. In this case, we check whether there is a sharded loop ID and compare the logical reshape iter domain against the producer of that DID split.
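
As a rough illustration of that check, here is a minimal sketch (not the exact PR code; `sharded_axis` and the use of the split's inner output are assumptions based on the snippets quoted in the review below):

    // Try to find the logical reshape ID directly in the loop domain.
    auto find_it = std::find(
        tv->getLoopDomain().begin(), tv->getLoopDomain().end(), logical_id);

    // If it is not found and the TV has a DID-parallelized loop ID, the
    // logical ID may have been loop-split for sharding. Check whether the
    // sharded loop ID's definition is a Split whose input is the logical
    // reshape ID; if so, use that split's inner output as a stand-in.
    if (find_it == tv->getLoopDomain().end() && sharded_axis >= 0) {
      auto* split = dynamic_cast<Split*>(
          tv->getLoopDomain().at(sharded_axis)->definition());
      if (split != nullptr && split->in() == logical_id) {
        find_it = std::find(
            tv->getLoopDomain().begin(),
            tv->getLoopDomain().end(),
            split->inner());
      }
    }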

@Priya2698 (Collaborator, Author):

!test

github-actions bot commented Feb 11, 2025

Review updated until commit de6a4b2

Description

  • Updated propagateReshapeTransform to handle DID loop splits.

  • Added test for sharded split reshape IDs.

  • Improved reordering of reshape dimensions in loop domain.


Changes walkthrough 📝

Relevant files:

Enhancement · csrc/scheduler/utils.cpp (+38/-10) · Improve reshape dimension reordering

  • Added logic to find all reachable IDs between the logical reshape IDs and the loop domain.

  • Reordered reshape dimensions to the front of the domain.

  • Improved error handling for missing logical IDs in the loop domain.

Tests · tests/cpp/test_multidevice_sharding.cpp (+62/-0) · Add test for sharded split reshape

  • Added test case for sharded split reshape IDs.

  • Demonstrated propagation of DID loop splits.

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Complexity

    The new logic for finding reachable IDs and reordering them might introduce additional computational overhead. Consider profiling this part to ensure it does not negatively impact performance.

    // Find all reachable ids between the logical id and the loop domain.
    // If the ids are in the loop domain, reorder them to the front.
    auto transforms = DependencyCheck::getAllExprsBetween(
        {logical_id},
        {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
    std::unordered_set<IterDomain*> reachable_ids;
    // Add the logical id for the case where it is directly in the loop
    // domain.
    reachable_ids.insert(logical_id);
    
    for (auto expr : transforms) {
      auto outputs = ir_utils::filterByType<IterDomain>(expr->outputs());
      reachable_ids.insert(outputs.begin(), outputs.end());
    }
    
    bool has_reachable_loop_id = false;
    for (auto loop_idx :
         c10::irange(static_cast<int64_t>(tv->getLoopDomain().size()))) {
      if (reachable_ids.count(tv->axis(loop_idx)) == 0) {
        continue;
      }
      has_reachable_loop_id = true;
      // Reorder the reshape dimensions to the front of the domain
      old2new[loop_idx] = (int64_t)old2new.size();
    }

    Error Handling

    The error message in NVF_ERROR could be more descriptive to help diagnose issues when has_reachable_loop_id is false.

          has_reachable_loop_id,
          "Require ",
          logical_id,
          " is in the active domain of ",
          tv->toString(),
          " for view propagation.");
    }

    Test Coverage

    Ensure that the new test case covers all edge cases, including scenarios where the reshaped dimensions are not sharded or where the loop split does not align with the reshaped dimensions.

    TEST_F(MultiDeviceTest, ShardedSplitReshapeIds) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      const int d = communicator_->size();
      const int64_t b = 2, s = 2, h = 4, e = 3;
    
      TensorView* tv0 = makeContigConcreteTensor(
          {b, s, d * h * e}); // in: loop domain: {b, s, d*h*e}
      TensorView* tv1 = reshape(
          tv0,
          {b, s, d * h * e},
          {b, s, d * h, e}); // out: loop domain: {b, s, d*h, e}
    
      fusion->addInput(tv0);
      fusion->addOutput(tv1);
    
      auto mesh = DeviceMesh::createForNumDevices(d);
    
      // Propagate transform from reshaped output to input.
      // Without this propagation, the two DID axes on `in` and `out` will not be
      // mapped together in the ID model. This causes scheduling to fail due to
      // resharding.
      TransformPropagator propagator_c2p(tv1);
      MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator_c2p);
      // in: loop domain: {b, s, d*h, e} after transform propagation
    
      // Loop split and parallelize input
      tv0->setDeviceMesh(mesh);
      tv0->split(-2, d, /*inner_split=*/false);
      tv0->axis(-3)->parallelize(ParallelType::DIDx);
      // in: loop domain: {b, s, DIDx{d}, h, e}
    
      // Propagate DID loop split to output
      TransformPropagator propagator_p2c(tv0);
      MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator_p2c);
      // out: loop domain: {b, s, d, h, e} after transform propagation
    
      // Parallelize output
      scheduler_utils::parallelizeAllLike(
          tv0,
          /*pos=*/-1,
          /*selected_tv=*/{tv1});
      // out: loop domain: {b, s, DIDx{d}, h, e} after parallelization
    
      tv0->setAllocationDomain(tv0->getLoopDomain(), true);
      tv1->setAllocationDomain(tv1->getLoopDomain(), true);
    
      FusionExecutorCache executor_cache(std::move(fusion));
      at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options);
      at::Tensor sharded_inp = shardTensor(inp, tv0);
      at::Tensor nvf_out =
          executor_cache.runFusionWithInputs({sharded_inp})[0].as<at::Tensor>();
      testValidate(
          executor_cache.fusion(),
          {nvf_out},
          {sharded_inp},
          {sharded_inp.view({b, s, h, e})},
          __LINE__,
          __FILE__);
    }
    
    } // namespace nvfuser

    @Priya2698 Priya2698 marked this pull request as ready for review February 12, 2025 22:34
    @Priya2698 Priya2698 force-pushed the pm/reshape_propagate branch from d3c602d to 52a7a0c Compare February 12, 2025 22:57
    @Priya2698 Priya2698 requested a review from wujingyue February 12, 2025 22:58
    @Priya2698 Priya2698 force-pushed the pm/reshape_propagate branch from 858f9fc to 90ab5ee Compare February 13, 2025 00:57
    @wujingyue wujingyue requested a review from naoyam February 13, 2025 04:01
    @Priya2698 (Collaborator, Author):

    !test


    // Reorder the reshape dimensions to the front of the domain
    Collaborator:

    @naoyam I'm quite confused by this pre-existing logic, which I need to understand before I can make sense of this PR. Why is it necessary to move reshape dimensions to the front of the loop domain? It can conflict with the pre-existing assumption that DIDs have to be at the front as well.

    Collaborator:

    I think this is related to the propagation done at line 2279. IIRC, it propagates the outermost N dimensions, where N is old2new.size() in this case. Since here we just want to propagate the transformations related to the rfactor, this is how we limit the propagation.

    We can probably just reorder tv back after line 2279.
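
    A minimal sketch of that reorder-then-restore idea (assumed code, not the PR's actual implementation; `old2new` is the map built in the snippet quoted earlier):

        // Reorder the reshape IDs to the front so that only the outermost
        // old2new.size() dimensions participate in the propagation.
        tv->reorder(old2new);

        // ... run the pre-existing reshape-transform propagation (the logic
        // around line 2279) ...

        // Then restore the original order by inverting old2new.
        std::unordered_map<int64_t, int64_t> new2old;
        for (const auto& [old_pos, new_pos] : old2new) {
          new2old[new_pos] = old_pos;
        }
        tv->reorder(new2old);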

    auto find_it = std::find(
    tv->getLoopDomain().begin(), tv->getLoopDomain().end(), logical_id);

    // If not found directly and there is a sharded loop ID,
    Collaborator:

    I think I see what the below part is trying to do and why, which seems to make sense, but can you expand the comment and elaborate a little more?

    Collaborator:

    This bears several assumptions that will break in the foreseeable future.

    With context parallelism, the sequence dimension s will be split into [tp, iDIDy{cp}, s/tp/cp], so the code below won't be able to find tp and s/tp/cp. Similarly, with overlapping, the sequence dimension s will be split into [sp, iDID{tp}, s/sp/tp], where sp is the stream parallelization factor. See this test for the idea.

    I understand this change does fix some narrow cases that we care about at this very moment, but I'll have to think more about how to fix the broader issue...
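
    For concreteness, a hedged sketch of the context-parallel split described above (the axis position and the `tp`/`cp` factors are assumptions for illustration only):

        const int64_t tp = 2;  // tensor-parallel factor (illustrative)
        const int64_t cp = 4;  // context-parallel factor (illustrative)
        // s -> [tp, s/tp] -> [tp, cp, s/(tp*cp)], with the cp factor
        // parallelized on DIDy. The check in this PR only inspects the single
        // Split that directly produces the DID loop ID; its input here is
        // s/tp rather than the logical s, so comparing against the logical
        // reshape ID would fail.
        tv->split(/*axis=*/1, tp, /*inner_split=*/false); // s -> [tp, s/tp]
        tv->split(/*axis=*/2, cp, /*inner_split=*/false); // s/tp -> [cp, s/(tp*cp)]
        tv->axis(2)->parallelize(ParallelType::DIDy);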

    Collaborator:

    (I still haven't given up on improving ID model.)

    If we have to do graph traversal like this, we may want to do it in a place where the logic can be generalized and reused (and therefore in the ID model). At this moment, there are two use cases:

    1. splitting reshape: [h]=>[d, h/d] and [h]=>[d,a/d,h/a]
    2. merging reshape: [a,h/a]=>[d,a/d,h/a] and [a,h/a]=>[h]=>[d,h/d]

    We want the ID model to map the d's in both cases so these reshapes won't be considered resharding.

    How much harder is it to make ID model support these cases than working around using reshape transformation? I suspect the latter has a bigger blast radius because the former is local to ID model and the latter changes TensorViews.

    Collaborator (Author):

    I realized the same limitation for DID loop split on slice:

      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      const int d = communicator_->size();
      const int64_t b = 2, s = 2, h = 4;
    
      TensorView* in = makeContigConcreteTensor(
          {b, s, 3 * d * h});
      // Slice the last dimension from 3*d*h down to d*h.
      TensorView* out = slice(
          in,
          {0, 0, 0},
          {b, s, d * h});

      fusion->addInput(in);
      fusion->addOutput(out);

      auto mesh = DeviceMesh::createForNumDevices(d);
      // DID loop split both input and output on the (sliced) last dimension.
      for (auto* tv : {in, out}) {
        tv->setDeviceMesh(mesh);
        tv->split(-1, d, /*inner_split=*/false);
        tv->axis(-2)->parallelize(ParallelType::DIDx);
        tv->setAllocationDomain(tv->getLoopDomain(), true);
      }
    

    I was trying to manually handle the case of SliceOp in hasDifferentShardings but it would make certain assumptions about the parallelization patterns and can easily break.

    Collaborator (Author):

    > I understand this change does fix some narrow cases that we care about at this very moment, but I'll have to think more about how to fix the broader issue...

    Yes, I agree. I wanted to add an example to demonstrate how reshapes can be loop split but it certainly does not cover all the cases.

    Collaborator:

    > I still haven't given up on improving ID model.

    (Sorry -- I wish I knew more about IdModel to be more constructive.)

    Another use case to consider is manual sharding -- the user wants to manually shard a subset of TVs to improve perf when our sharding propagation is suboptimal.

    They may well annotate [b,s,h]=>(reshape)=>[b,s,a,h/a] as follows

    in: [b,s,h] => [b,s,d,h/d]
    out: [b,s,a,h/a] => [b,s,d,a/d,h/a]
    

    and expect nvFuser to recognize this reshape is local. In this case, it's hard to replay the reshape on the input because h there is already split by d.
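
    A hedged sketch of that manual annotation with the scheduling API (the `in`/`out` names and axis positions are assumptions for illustration):

        // in: [b, s, h] -> [b, s, DIDx{d}, h/d]
        in->split(2, d, /*inner_split=*/false);
        in->axis(2)->parallelize(ParallelType::DIDx);

        // out: [b, s, a, h/a] -> [b, s, DIDx{d}, a/d, h/a]
        out->split(2, d, /*inner_split=*/false);
        out->axis(2)->parallelize(ParallelType::DIDx);

        // For the reshape [b,s,h] -> [b,s,a,h/a] to be treated as local, the
        // two DIDx{d} IDs need to be mapped, even though on `in` the DID
        // split consumes h directly while on `out` it consumes a.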

    auto split = dynamic_cast<Split*>(
    tv->getLoopDomain().at(sharded_axis)->definition());
    if (split != nullptr && split->in() == logical_id) {
    find_it = std::find(
    Collaborator:

    While I understand your intention, I don't know the implications of doing this for TransformPropagator. E.g.

    root=[b, s, h], logical=[b, s, a, h/a], loop=[b, s, d, a/d, h/a]
    

    This will move a/d and h/a to the front so the new loop domain becomes [a/d, h/a, b, s, d] and later ask TransformPropagator to replay at replayed_pos_ 2. What is TransformPropagator supposed to do with that? The first two loop IDs (a/d and h/a) don't even form a split in this TV.

    cc @naoyam

    Collaborator (Author):

    I see your point. In the case of reshape with DID loop split, we have already propagated the reshape upwards, so the TransformPropagator only reorders the axes when called later. In the absence of the earlier reshape propagation before the loop split, the behavior could be erroneous since those loop IDs don't form a split.

    Although, since the reshape has already been propagated, and, as @naoyam mentioned above, the tv is reordered back, maybe this propagation can be skipped altogether.
    Let me think about it more and see what the schedulers expect from this propagateReshapeTransform.

    However, this may not work for the manual sharding case you mentioned in the above comment.

    Collaborator (Author):

    One caveat for the above to hold: the reshape transform must have been applied to all TensorViews preceding the reshape in the given fusion segment.

    In the transformer forward, for the split-reshape, the order is linear -> slice -> reshape -> permute -> SDPA. Since linear will be its own segment, that leaves slice and reshape. I am uncertain whether this will be a single segment, since it can potentially depend on how the sharding on slice is represented.

    For the merge-reshape after SDPA, the order is SDPA -> permute -> reshape -> linear. Again, SDPA and linear will be their own segments.

    More generally, the boundary up to which the reshape transform is propagated upwards is important, since we may see different patterns appear in other models.

    FWIW, in my tests, I found that TransformPropagator can propagate the split-reshapes upwards after DID loop split as well, but we lose the DID parallelization. This might be similar to the comment here.

    An orthogonal issue I see here is resharding at the boundary up to which the reshape has been propagated upwards. At that boundary, we will go from [h] -> [a, h/a] and should hit the same resharding error.

    Collaborator:

    > but we lose the DID parallelization

    This is expected, and it is the reason for parallelizeAllLike and functions like that.
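
    In other words, after propagating the transform, the DID parallelization is re-applied explicitly, as the test added in this PR does (a short sketch; tv0/tv1 refer to the test quoted in the reviewer guide above):

        // Propagate the DID loop split from the reference TV, then re-apply
        // the DID parallelization, which propagation alone does not carry over.
        TransformPropagator propagator(tv0);
        MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator);
        scheduler_utils::parallelizeAllLike(
            tv0, /*pos=*/-1, /*selected_tv=*/{tv1});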

    @Priya2698 (Collaborator, Author):

    @wujingyue should we move forward with this PR?
    It fixes the case where we are using transform propagator for reshape before the DID loop split.
    PR #3953 currently uses hardcoding and we may want to merge it after PR #3482.
    Wdyt?

    @Priya2698 (Collaborator, Author):

    !test

    auto split = dynamic_cast<Split*>(
    tv->getLoopDomain().at(sharded_axis)->definition());
    if (split != nullptr && split->in() == logical_id) {
    // Move the DIDx dimension to the front
    Collaborator:

    Would it be more general to move all loop IDs reachable from logical_id to the front? Still, DIDx needs to be at the very front. That way, you don't need to assume that DIDx is the immediate outer split of logical_id, and the code can probably be made simpler.

    Collaborator (Author):

    > Still, DIDx needs to be the very front.

    Is this for the schedulers? In that case, we don't have to worry about DID being at the front here. reorderDIDToFront is called after this function within the scheduler.

    > Will it be more general to move all loop domains reachable from logical_id to the front?

    Yes, that should work. Are you aware of any direct utilities for this? Otherwise, I should be able to use getExprsBetween to find the relevant transforms and collect the loop IDs from their outputs.

    Collaborator:

    > getExprsBetween

    That's about right. I suspect getInputsTo would also work. I didn't try enough to understand their differences. Many of the graph traversal utilities seem to overlap and/or be redundant...

    Collaborator (Author):

    I am using getExprsBetween as follows:

     auto transforms = StmtSort::getExprsBetween(
         {logical_id},
         {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});

    For the newly added test, the reshaped TensorView is: [i{b}, i{s}, i{a}, i{h/a}] (I am using TE notation here; the test uses a different notation to guarantee divisibility).
    For the logical ID i{a}, I also see the split h -> [a, h/a], whereas I expected to only see the DID split. Similarly, for the logical ID i{h/a}, I expected to see no transforms/exprs in between since it is found directly, but I see both the h -> [a, h/a] and a -> [d, a/d] splits.
    Getting only the expected transforms is required since:

    1. There are also transforms that are not necessarily on the reshaped IDs (for example, in the test case ViewWithSplit, we will see the split creating DIDx even though it is not on a reshaped ID), and these should not be reordered or propagated.
    2. It is difficult to tell whether there is at least one loop iter domain reachable from a particular reshaped logical ID.
    Logical ID: iS7{4}rf
    
    Expr: Outer split: iS7{4}rf by factor 1 -> ideviceIdx.x13{1}, iS14{4}
    Output: ideviceIdx.x13{1}
    Output: iS14{4}
    
    Expr: Outer split: iS6{48}rf by factor 4 -> iS7{4}rf, iS8{12}rf
    Output: iS7{4}rf
    Output: iS8{12}rf
    

    Any suggestions on what I may be missing in this function? I have not used it from a specific ID like this before, only between entire domains.

    @Priya2698 (Collaborator, Author) commented Feb 28, 2025:

    I should be using DependencyCheck::getAllExprsBetween!
    That gives me the expected transforms.

    Collaborator:

    TIL!

    @Priya2698 (Collaborator, Author):

    !test

    @Priya2698 (Collaborator, Author):

    !test.

    }

    bool has_reachable_loop_id = false;
    for (auto id : reachable_ids) {
    Collaborator:

    I don't know whether reachable_ids will be ordered -- it really depends on the implementation of getAllExprsBetween. Therefore, I'd instead loop over the loop domain and try to find a match in reachable_ids (which should probably be a set instead of a vector). It's roughly the same logic, but more deterministic and more aligned with the existing order in the loop domain.

    @wujingyue (Collaborator) left a comment:

    LGTM overall

    @Priya2698 Priya2698 requested a review from wujingyue March 4, 2025 23:47
    @Priya2698 (Collaborator, Author):

    @wujingyue The CI failures seem like script failures; I will re-run. The PR is ready for another review.

    @wujingyue (Collaborator) left a comment:

    LGTM with comments

    @Priya2698 (Collaborator, Author):

    !test

    @Priya2698 (Collaborator, Author):

    !test

    @@ -2257,7 +2257,8 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {
    }

    bool has_reachable_loop_id = false;
    -    for (auto loop_idx : c10::irange(static_cast<int64_t>(tv->getLoopDomain().size()))) {
    +    for (auto loop_idx :
    +         c10::irange(static_cast<int64_t>(tv->getLoopDomain().size()))) {
    Collaborator:

    Suggested change:
    -         c10::irange(static_cast<int64_t>(tv->getLoopDomain().size()))) {
    +         c10::irange(std::ssize(tv->getLoopDomain()))) {

    FYI. Don't bother changing this if you are about to submit the PR.

    Collaborator (Author):

    Will make this change in the other PR!

    @Priya2698 Priya2698 merged commit 1bbc745 into main Mar 6, 2025
    48 of 49 checks passed
    @Priya2698 Priya2698 deleted the pm/reshape_propagate branch March 6, 2025 03:25