
Check that warps are only accessing the subpartition of TMem that they can access #4016

Open

wants to merge 5 commits into base: tmem-pm

Conversation

zasdfgbnm
Collaborator

@zasdfgbnm zasdfgbnm commented Mar 6, 2025

Warp i can only access subpartition i % 4.

This PR is stacked on #4015
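As a rough illustration of the constraint being validated (a sketch, not code from this PR; it assumes the 128-lane TMem layout with four 32-lane sub-partitions, and the helper name is hypothetical):

```cpp
#include <cassert>

// TMem has 128 lanes split into four 32-lane sub-partitions, and
// warp w of a warp group may only touch sub-partition w % 4,
// i.e. lanes [32*(w%4), 32*(w%4)+31].
struct LaneRange {
  int begin;  // first accessible lane
  int end;    // one past the last accessible lane
};

LaneRange tmemLanesForWarp(int warp_id) {
  int subpartition = warp_id % 4;
  return {subpartition * 32, subpartition * 32 + 32};
}
```

For example, warps 1 and 5 both map to sub-partition 1 (lanes 32..63), which is exactly the "i % 4" rule above.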


github-actions bot commented Mar 6, 2025

Review updated until commit a31d9c6

Description

  • Validate warps access correct TMem sub-partition

  • Improve error messages in tests

  • Handle empty composition in proveLinearAndGetStrideAfterPropagation

  • Remove redundant error checks in cancelCommonFactors and trimRedundant


Changes walkthrough 📝

Relevant files

Enhancement

tensor_memory.cpp: Validate TMem sub-partition access
csrc/device_lower/analysis/tensor_memory.cpp

  • Added validation to ensure warps access the correct TMem sub-partition
  • Improved error message for incorrect sub-partition access
  • +22/-1

utils.cpp: Improve utils functions
csrc/device_lower/utils.cpp

  • Changed initial value in extent function to FusionGuard::getCurFusion()->oneVal()
  • Added handling for empty composition in proveLinearAndGetStrideAfterPropagation
  • Removed redundant error checks in cancelCommonFactors and trimRedundant
  • +4/-5

tmem.md: Update test error messages
doc/dev/tmem.md

  • Updated error messages in tests to specify incorrect sub-partition access
  • Removed NOT_IMPLEMENTED tags from tests
  • +4/-4

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Performance Impact

The new validation logic may introduce overhead at lowering time. Verify that its impact is minimal and does not degrade overall compilation performance.

      AbstractTensorWithInfo<Contiguity> t = pdims;
      t.split(-1, 32);
      t.split(-2, 4);
      Val* warp_group_stride = lower_utils::proveLinearAndGetStride(
          id_graph,
          t[-2].as<ValGroupAndItsGraph>().group,
          lane_allocation_valgroups);
      NVF_ERROR(
          warp_group_stride != nullptr,
          "Invalid data access pattern in TMem load/store: ",
          "Warps are not accessing the correct sub-partition.");
      // The stride must be either 0 or 32, 32 is the most common case.
      // 0 is a special value indicating that there is only one warp.
      GpuLower::current()->validate(
          SimplifyingIrBuilder::logicalOrExpr(
              SimplifyingIrBuilder::eqExpr(
                  warp_group_stride, IrBuilder::create<Val>(32)),
              SimplifyingIrBuilder::eqExpr(
                  warp_group_stride, IrBuilder::create<Val>(0))),
          "Invalid data access pattern in TMem load/store: ",
          "Warps are not accessing the correct sub-partition.");
    }
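To see why the check above accepts only strides 0 and 32, here is a small standalone sketch (illustrative only, not the PR's code): if warp w's lanes start at offset w * stride, every warp lands in its own sub-partition exactly when the stride is 32, or trivially when there is a single warp (stride 0).

```cpp
#include <cassert>

// If warp w's lanes start at w * stride, it occupies sub-partition
// (w * stride / 32) % 4. The hardware requires that to equal w % 4
// for every warp in the group.
bool strideKeepsWarpsInOwnSubpartition(int stride, int num_warps) {
  for (int w = 0; w < num_warps; ++w) {
    if ((w * stride / 32) % 4 != w % 4) {
      return false;
    }
  }
  return true;
}
```

A stride of 64, for instance, would send warp 1 to sub-partition 2, which is why the validation rejects it.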
    Code Clarity

The changes to the extent and propagate functions should be reviewed for clarity and correctness, ensuring they do not introduce unintended behavior.

    Val* extent(const Composition<Projection>& comp) {
      return std::accumulate(
          comp.begin(),
          comp.end(),
          FusionGuard::getCurFusion()->oneVal(),
          [](Val* acc, const auto& g) {
            return SimplifyingIrBuilder::mulExpr(acc, extent(g));
          });
    }
    
    Val* extent(const std::monostate&) {
      NVF_THROW("Cannot get extent of std::monostate");
    }
    
    Val* extent(const Projection& proj) {
      return Projection::dispatch(
          [&](const auto& proj) { return extent(proj); }, proj);
    }
    
    // Simplify the abstract syntax tree so that it is easier to be pattern
    // matched. Defined below.
    Projection simplify(Projection proj);
    
    // Given an expression on the traversal path and its direction, get the from
    // and to groups.
    auto fromGroups(
        const ValGraph& id_graph,
        const ExprGroup& eg,
        Direction direction) {
      return direction == Direction::Backward ? id_graph.outputGroups(eg)
                                              : id_graph.inputGroups(eg);
    }
    
    auto toGroups(
        const ValGraph& id_graph,
        const ExprGroup& eg,
        Direction direction) {
      return direction == Direction::Backward ? id_graph.inputGroups(eg)
                                              : id_graph.outputGroups(eg);
    }
    
    // Do the propagation to project linear_g on domain through the given
    // expression, build out and simplify the abstract syntax tree on the fly by
    // substituting equivalent items. For example, if we have
    //   2   [2]  3
    //    \    \ /
    //     \    6
    //      \  /
    //       12   2
    //        \  /
    //         24
    //        /  \.
    //       4    6
    // and the linear_g is [2], when we propagate from [2] to 24, we will build out
    // the abstract syntax tree with the following steps:
    //
    // First, we will traverse the expression 6 = merge(2, 3). We will build out
    //   linear_g = PartOf{what=6, inner_extent=3, selected_extent=2}
    //
    // Second, we will traverse the expression 12 = merge(2, 6). From this
    // expression, we know that
    //   6 = PartOf{what=12, inner_extent=nullptr, selected_extent=6}
    // Substituting definition of 6, in the above definition of linear_g, we get
    //   linear_g = PartOf{
    //     what=PartOf{what=12, inner_extent=nullptr, selected_extent=6},
    //     inner_extent=3,
    //     selected_extent=2
    //   }
    //
    // Third, we will traverse the expression 24 = merge(12, 2). From this
    // expression, we know that
    //   12 = PartOf{what=24, inner_extent=2, selected_extent=12}
    // Substituting definition of 12, in the above definition of linear_g, we get
    //   linear_g = PartOf{
    //     what=PartOf{
    //        what=PartOf{what=24, inner_extent=2, selected_extent=12},
    //        inner_extent=nullptr,
    //        selected_extent=6
    //     },
    //     inner_extent=3,
    //     selected_extent=2
    //   }
    //
    // Finally, we will traverse the expression 4, 6 = split(24). From this
    // expression, we know that
    //   24 = Composition{4, 6}
    // Substituting definition of 24, in the above definition of linear_g, we get
    //   linear_g = PartOf{
    //     what=PartOf{
    //       what=PartOf{
    //         what=Composition{4, 6},
    //         inner_extent=2,
    //         selected_extent=12
    //       },
    //       inner_extent=nullptr,
    //       selected_extent=6
    //     },
    //     inner_extent=3,
    //     selected_extent=2
    //   }
    //
    // Note that the dynamic type Projection has limited expressiveness, we may
    // encounter cases where the projection can not be represented in the language
    // of the dynamic type Projection. For such cases, we will just use
    // std::monostate to denote "unknown".
    Projection propagate(
        const Projection& proj,
        const ValGraph& id_graph,
        const ExprGroup& eg,
        Direction direction);
    
    Projection propagate(
        const ValGroup& group,
        const ValGraph& id_graph,
        const ExprGroup& eg,
        Direction direction) {
      auto from = fromGroups(id_graph, eg, direction);
      auto to = toGroups(id_graph, eg, direction);
      if (from.size() == 1 && to.size() == 2) {
        // If we have
        //    group
        //    /   \.
        //   g1   g2
        // and the split is divisible, then build the following abstract syntax
        // tree:
        //   group = Composition{g1, g2}
        // If the split is not divisible, then build the following abstract syntax
        // tree:
        //   group = PartOf{what=Composition{g1, g2},
        //                  inner_extent=nullptr,
        //                  selected_extent=extent(group)}
        NVF_ERROR(eg->front()->isA<Split>() || eg->front()->isA<Merge>());
        if (from.front() != group) {
          return group;
        }
        auto comp = Composition<Projection>{to.front(), to.back()};
        bool may_be_indivisible_split = eg->front()->isA<Split>() &&
            !simplifyExpr(eg->front()->as<Split>()->isDivisible())->isTrue();
        if (may_be_indivisible_split) {
          return PartOf<Projection>{
              std::make_shared<Projection>(comp),
              /*inner_extent=*/nullptr,
              /*selected_extent=*/extent(group)};
        }
        return comp;
      } else if (from.size() == 2 && to.size() == 1) {
        // If we have
        //   group    g1
        //        \  /
        //         g2
        // then build the following abstract syntax tree
        //   group = PartOf{what=g2,
        //                  inner_extent=extent(g1),
        //                  selected_extent=extent(group)}
        //
        // If we have
        //   g1   group
        //     \  /
        //      g2
        // then build the following abstract syntax tree
        //   group = PartOf{what=g2,
        //                  inner_extent=nullptr,
        //                  selected_extent=extent(group)}
        NVF_ERROR(eg->front()->isA<Split>() || eg->front()->isA<Merge>());
        if (from.front() != group && from.back() != group) {
          return group;
        }
        return PartOf<Projection>{
            std::make_shared<Projection>(to.front()),
            /*inner_extent=*/from.front() == group ? extent(from.back()) : nullptr,
            /*selected_extent=*/
            simplifyExpr(extent(group))};
      }
      if (std::none_of(from.begin(), from.end(), [&](const auto& g) {
            return g == group;
          })) {
        return group;
      }
      // Not representable (or don't know how to represent) by the language of the
      // dynamic type Projection.
      return {};
    }
    
    Projection propagate(
        const PartOf<Projection>& part,
        const ValGraph& id_graph,
        const ExprGroup& eg,
        Direction direction) {
      // Just recursively propagate subtree.
      auto propagated = propagate(*part.what, id_graph, eg, direction);
      if (!propagated.hasValue()) {
        return {};
      }
    Documentation

    The updated test cases in the documentation should be reviewed to ensure that they accurately reflect the new validation logic and provide clear examples of expected behavior.

to access 32 or 16 contiguous lanes of data.<!-- */ //-->\
    ```cpp
    TEST_F(TMemTutorialC, WrongSubpartition) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeContigConcreteTensor({2, 2, 32, 2});
      fusion.addInput(tv0);
      auto tv1 = set(tv0);
      auto tv2 = set(tv1);
      auto tv3 = set(tv2);
      auto tv4 = set(tv3);
      fusion.addOutput(tv4);
      tv2->setMemoryType(MemoryType::Tensor);
      tv2->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StTMem);
      tv3->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::LdTMem);
    
      tv4->axis(0)->parallelize(ParallelType::TIDy);
      tv4->axis(2)->parallelize(ParallelType::TIDx);
      scheduler_utils::parallelizeAllLike(tv4);
    
      tv2->setTMemDimSepPos(3);
    
      EXPECT_THAT(
          [&]() { KernelExecutor().compile(&fusion); },
          ::testing::ThrowsMessage<nvfError>(::testing::HasSubstr(
              "Invalid data access pattern in TMem load/store: "
              "Warps are not accessing the correct sub-partition.")));
    } /*

    The above example is invalid because the warps access the wrong subpartitions
    of the tensor memory. In this example, there are two warps: warp 0 accesses
    subpartitions 0 and 1, and warp 1 accesses subpartitions 2 and 3. However,
    warp 0 may only access subpartition 0, and warp 1 may only access
    subpartition 1.\

    TEST_F(TMemTutorialC, WrongSubpartition2) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto tv0 = makeContigConcreteTensor({32, 2});
      fusion.addInput(tv0);
      auto tv1 = set(tv0);
      auto tv2 = set(tv1);
      auto tv3 = set(tv2);
      auto tv4 = set(tv3);
      fusion.addOutput(tv4);
      tv2->setMemoryType(MemoryType::Tensor);
      tv2->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StTMem);
      tv3->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::LdTMem);
    
      tv4->axis(0)->parallelize(ParallelType::TIDx);
      tv4->axis(1)->parallelize(ParallelType::TIDy);
      scheduler_utils::parallelizeAllLike(tv4);
    
      tv2->setTMemDimSepPos(1);
    
      EXPECT_THAT(
          [&]() { KernelExecutor().compile(&fusion); },
          ::testing::ThrowsMessage<nvfError>(::testing::HasSubstr(
              "Invalid data access pattern in TMem load/store: "
              "Warps are not accessing the correct sub-partition.")));
    } /*
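The explanation of the first example can be sanity-checked with a small sketch (assuming, as described above, that warp w = TIDy occupies 64 consecutive lanes starting at lane w * 64; the helper names are hypothetical):

```cpp
#include <cassert>

// Warp w covers lanes [w*64, w*64+63]; each sub-partition is 32 lanes,
// so the warp spans two sub-partitions instead of only sub-partition
// w % 4 as the hardware requires.
int firstSubpartition(int w) { return (w * 64) / 32 % 4; }
int lastSubpartition(int w) { return (w * 64 + 63) / 32 % 4; }
```

This matches the explanation: warp 0 touches sub-partitions 0 and 1 rather than only sub-partition 0, so compilation fails with the error checked in the test.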

    Collaborator Author

    @zasdfgbnm zasdfgbnm Mar 6, 2025

    This file is updated to handle an edge case where we have transformations like:

    split(1, 4) -> (1, 4)

    For this case, we still consider it linear because only 1 of the 4 is the valid part, and by definition, an extent-1 domain is always linear w.r.t. anything.
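The empty-composition handling mentioned in the walkthrough can be illustrated with a plain-integer analogue (a sketch with a hypothetical name, mirroring the switch of the accumulation seed to oneVal()): the extent of an empty composition is the empty product, which is 1, and an extent-1 domain is linear with respect to anything.

```cpp
#include <numeric>
#include <vector>
#include <cassert>

// Analogue of the extent() accumulation over a composition's factors:
// seeding std::accumulate with 1 makes the empty composition evaluate
// to extent 1 instead of being an error case.
int extentOf(const std::vector<int>& factor_extents) {
  return std::accumulate(
      factor_extents.begin(),
      factor_extents.end(),
      1,
      [](int acc, int e) { return acc * e; });
}
```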

    @zasdfgbnm
    Collaborator Author

    !test
