Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend proveLinearAndGetStride to support missing dependencies #3984

Merged
merged 12 commits into from
Mar 5, 2025

Conversation

zasdfgbnm
Copy link
Collaborator

@zasdfgbnm zasdfgbnm commented Feb 27, 2025

Suppose we have

[16, 8, 2, 4] -> merge -> [1024] -> split -> [32o, 32i]

Then 32i is linear w.r.t. [8, 2, 4]. Although 16 is a dependency of 32i, the result of whether 32i is linear or not is irrelevant to 16, so proveLinearAndGetStride should not require it to exist.

Copy link

github-actions bot commented Feb 27, 2025

Review updated until commit 291e17e

Description

  • Extend proveLinearAndGetStride to support missing dependencies

  • Add early stopping in propagation for efficiency

  • Update tests to cover new functionality


Changes walkthrough 📝

Relevant files
Enhancement
utils.cpp
Update proveLinearAndGetStride for missing dependencies and early
stopping

csrc/device_lower/utils.cpp

  • Update direction handling in fromGroups and toGroups
  • Modify proveLinearAndGetStride to use ValGraphPermissiveBFS and add
    early stopping
  • Add comments explaining the changes and the rationale
  • +46/-16 
    Tests
    test_utils.cpp
    Add tests for `proveLinearAndGetStride` enhancements         

    tests/cpp/test_utils.cpp

  • Include for random number generation
  • Add test ProveLinearAndGetStrideWithMissingDependency to verify
    functionality
  • Add test ProveLinearAndGetStrideEarlyStopping to verify early stopping
  • +85/-0   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Early Stopping Logic

    Ensure that the early stopping logic correctly identifies when to stop propagation without missing valid linear dependencies.

    // not contain full dependency of domain.
    Projection frontier = linear_g;
    auto path =
        ValGraphPermissiveBFS::getExprGroupsBetween(
            id_graph, {linear_g}, domain, /*require_all_to_visited=*/false)
            .first;
    // Propagate along the path from linear_g to domain. Note that we do not
    // always propagate all the way through the path. Instead, early stopping
    // is necessary to be functionally correct. For example, if we have the
    // following ValGroups:
    //   4   2
    //    \ /
    //     8
    //    / \.
    //   4'  2'
    // and we are asking: is 2' linear in [4, 2']? The answer is trivially
    // yes by eyeballing, because 2' is the inner of [4, 2']. However, we must be
    // careful in propagation to algorithmically get it right. Although we can
    // directly tell the answer for this example without any progagation, because
    // ValGraphPermissiveBFS has no information about the underlying problem we
    // are solving, it always generate a path that visits `domain` as much as
    // possible, regardless of whether the underlying problem want it or not.
    // For this case, although the 4 in `domain` is unrelated to the answer,
    // ValGraphPermissiveBFS will still visit it. Therefore, it will generate a
    // path that include the merge of 4 and 2, and the split of 8. If we
    // mindlessly propagate along this path without early stopping, we will
    // propagate linear_g into frontier = 2, which leads to a conclusion that
    // "linear_g is the 2, and domain is [4, 2'], linear_g is not in domain, so I
    // can not prove linearity", which is not the answer we want. Note that
    // patterns like this can appear anywhere in the path, so we need to check for
    // early stopping at each step of the propagation.
    Val* stride = proveLinearAndGetStrideAfterPropagation(frontier, domain);
    if (stride != nullptr) {
      return stride;
    }
    for (const auto& [eg, direction] : path) {
      frontier = propagate(frontier, id_graph, eg, direction);
      if (!frontier.hasValue()) {
        // Not representable (or don't know how to represent) by the language of
        // the dynamic type Projection.
        return nullptr;
      }
      // Check for early stopping.
      Val* stride = proveLinearAndGetStrideAfterPropagation(frontier, domain);
      if (stride != nullptr) {
        return stride;
      }
    }
    Randomness in Tests

    The test ProveLinearAndGetStrideWithMissingDependency uses std::rand(), which can lead to non-deterministic behavior. Consider using a fixed seed for reproducibility.

    (void)_;
    // [16, 8, 2, 4]
    Test Coverage

    Ensure that the new tests cover all edge cases and provide comprehensive coverage for the changes made to proveLinearAndGetStride.

      Val* v4_6_in_v4 = lower_utils::proveLinearAndGetStride(g, v4[6], v4);
      EXPECT_EQ(simplifyExpr(v4_6_in_v4)->value(), 64);
    
      Val* v4_7_in_v4 = lower_utils::proveLinearAndGetStride(g, v4[7], v4);
      EXPECT_EQ(simplifyExpr(v4_7_in_v4)->value(), 1);
    }
    
    // Test that lower_utils::proveLinearAndGetStride still works even if some
    // dependency are missing, as long as the missing dependency is irrelevant to
    // result.
    TEST_F(NVFuserTest, ProveLinearAndGetStrideWithMissingDependency) {
      Fusion fusion;
      FusionGuard fg(&fusion);
      for (auto _ : c10::irange(100)) {
        (void)_;
        // [16, 8, 2, 4]
        auto id16 =
            IterDomainBuilder(
                fusion.zeroVal(), IrBuilder::create<Val>(16, DataType::Index))
                .build();
        auto id8 = IterDomainBuilder(
                       fusion.zeroVal(), IrBuilder::create<Val>(8, DataType::Index))
                       .build();
        auto id2 = IterDomainBuilder(
                       fusion.zeroVal(), IrBuilder::create<Val>(2, DataType::Index))
                       .build();
        auto id4 = IterDomainBuilder(
                       fusion.zeroVal(), IrBuilder::create<Val>(4, DataType::Index))
                       .build();
    
        ValGraph g;
        g.initializeVal(id16);
        g.initializeVal(id8);
        g.initializeVal(id2);
        g.initializeVal(id4);
        ValGroup g16{g.toGroup(id16)};
        ValGroup g8{g.toGroup(id8)};
        ValGroup g2{g.toGroup(id2)};
        ValGroup g4{g.toGroup(id4)};
        ValGroupAndItsGraph gg16{g16, &g};
        ValGroupAndItsGraph gg8{g8, &g};
        ValGroupAndItsGraph gg2{g2, &g};
        ValGroupAndItsGraph gg4{g4, &g};
    
        AbstractTensor v({gg16, gg8, gg2, gg4});
        // Merge all dims in random order
        while (v.size() > 1) {
          v.merge(std::rand() % (v.size() - 1));
        }
        v.split(0, 32);
    
        ValGroup linear_g = v[1].as<ValGroupAndItsGraph>().group;
        // Although linear_g depend on g16, whether it is linear w.r.t. [8, 2, 4] is
        // not relevant to g16. So we should not require g16 to exist in order to
        // prove linearity.
        Val* stride =
            lower_utils::proveLinearAndGetStride(g, linear_g, {g8, g2, g4});
        ASSERT_NE(stride, nullptr);
        EXPECT_EQ(simplifyExpr(stride)->value(), 1);
      }
    }
    
    TEST_F(NVFuserTest, ProveLinearAndGetStrideEarlyStopping) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      // [4, 2]
      auto id4 = IterDomainBuilder(
                     fusion.zeroVal(), IrBuilder::create<Val>(4, DataType::Index))
                     .build();
      auto id2 = IterDomainBuilder(
                     fusion.zeroVal(), IrBuilder::create<Val>(2, DataType::Index))
                     .build();
    
      ValGraph g;
      g.initializeVal(id4);
      g.initializeVal(id2);
      ValGroup g4{g.toGroup(id4)};
      ValGroup g2{g.toGroup(id2)};
      ValGroupAndItsGraph gg4{g4, &g};
      ValGroupAndItsGraph gg2{g2, &g};
      AbstractTensor v({gg4, gg2});
      v.merge(0);
      v.split(0, 2);
      ValGroup g4_ = v[0].as<ValGroupAndItsGraph>().group;
      ValGroup g2_ = v[1].as<ValGroupAndItsGraph>().group;
      Val* stride = lower_utils::proveLinearAndGetStride(g, g2_, {g4, g2_});
      ASSERT_NE(stride, nullptr);
      EXPECT_EQ(simplifyExpr(stride)->value(), 1);
    }
    
    using TestCpp23BackPort = NVFuserTest;

    @zasdfgbnm
    Copy link
    Collaborator Author

    !test

    @zasdfgbnm zasdfgbnm requested review from naoyam and Copilot February 27, 2025 08:01

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

    @zasdfgbnm
    Copy link
    Collaborator Author

    CI failure is real, changing to draft

    @zasdfgbnm zasdfgbnm marked this pull request as draft February 27, 2025 21:35
    @zasdfgbnm zasdfgbnm removed the request for review from naoyam February 28, 2025 00:48
    @zasdfgbnm
    Copy link
    Collaborator Author

    !test

    @liqiangxl
    Copy link
    Collaborator

    !test

    @zasdfgbnm zasdfgbnm marked this pull request as ready for review March 2, 2025 18:50
    @zasdfgbnm zasdfgbnm requested a review from Copilot March 2, 2025 18:51

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

    @zasdfgbnm zasdfgbnm requested a review from naoyam March 2, 2025 18:51
    // partial information on how to reach to a state that is easiest for our
    // proof. It is possible that the easiest state is not the final state of
    // the propagation. So we need to try the proof each step of the
    // propagation.
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Could you give a few examples here? Specifically, one for a case where no propagation is done, another that involves a partial traversal of path, and also a case where complete traversal of path is required.

    Copy link
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    I completely rewrite the comment here, and added a new test ProveLinearAndGetStrideEarlyStopping demonstrating why early stopping is important. I believe the new comment and the single example included in it is sufficiently to clearly explain the reason for early stopping. Let me know if you believe more examples are needed.

    @zasdfgbnm
    Copy link
    Collaborator Author

    !test

    @zasdfgbnm zasdfgbnm requested review from naoyam and Copilot March 5, 2025 00:45

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

    @zasdfgbnm
    Copy link
    Collaborator Author

    !test

    @zasdfgbnm
    Copy link
    Collaborator Author

    !test

    Copy link
    Collaborator

    @naoyam naoyam left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    LGTM

    @zasdfgbnm zasdfgbnm merged commit f246e8e into main Mar 5, 2025
    48 of 49 checks passed
    @zasdfgbnm zasdfgbnm deleted the proveLinearAndGetStride branch March 5, 2025 20:45
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Labels
    None yet
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    3 participants