Check that warps are only accessing the subpartition of TMem that they can access #4016

Open · wants to merge 6 commits into base: tmem-pm
23 changes: 22 additions & 1 deletion csrc/device_lower/analysis/tensor_memory.cpp
@@ -437,7 +437,28 @@ computeTMemLdStDataPath(Fusion* fusion, const TMemAlllocationInfo& allocation) {
       }
       NVF_THROW(error.str());
     }
-    // TODO: Validate that we are accessing the correct sub-partition
+    // Validate that warps are accessing the correct sub-partition
+    AbstractTensorWithInfo<Contiguity> t = pdims;
+    t.split(-1, 32);
+    t.split(-2, 4);
+    Val* warp_group_stride = lower_utils::proveLinearAndGetStride(
+        id_graph,
+        t[-2].as<ValGroupAndItsGraph>().group,
+        lane_allocation_valgroups);
+    NVF_ERROR(
+        warp_group_stride != nullptr,
+        "Invalid data access pattern in TMem load/store: ",
+        "Warps are not accessing the correct sub-partition.");
+    // The stride must be either 0 or 32; 32 is the most common case.
+    // 0 is a special value indicating that there is only one warp.
+    GpuLower::current()->validate(
+        SimplifyingIrBuilder::logicalOrExpr(
+            SimplifyingIrBuilder::eqExpr(
+                warp_group_stride, IrBuilder::create<Val>(32)),
+            SimplifyingIrBuilder::eqExpr(
+                warp_group_stride, IrBuilder::create<Val>(0))),
+        "Invalid data access pattern in TMem load/store: ",
+        "Warps are not accessing the correct sub-partition.");
   }
   return {std::move(load_data_path), std::move(store_data_path)};
 }
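For intuition, here is a minimal standalone sketch of the invariant this new check enforces. This is hypothetical illustration code, not part of nvFuser: after the lane dimension is split into (warp-in-group, lane-in-warp) = (4, 32), proving that the warp-in-group dimension has stride 32 (or 0, when only one warp participates) over the lane allocation domain means warp w of a warp group touches only lanes [32w, 32w + 32), i.e. its own TMem sub-partition.

```cpp
// Standalone illustration (not nvFuser code) of the sub-partition rule:
// with a warp-group stride of 32, warp w of a warp group reaches only
// TMem lanes [32*w, 32*w + 32), i.e. its own sub-partition.
#include <cassert>
#include <cstdint>

// Lane reached by thread `tid` of a 128-thread warp group, given the
// stride of the warp-in-group dimension over the lane allocation domain.
int64_t laneOf(int64_t tid, int64_t warp_group_stride) {
  int64_t warp_in_group = (tid / 32) % 4; // 4 warps per warp group
  int64_t lane_in_warp = tid % 32;        // 32 lanes per warp
  return warp_in_group * warp_group_stride + lane_in_warp;
}

int main() {
  for (int64_t tid = 0; tid < 128; ++tid) {
    // Stride 32: each warp stays inside its own 32-lane sub-partition.
    assert(laneOf(tid, 32) / 32 == (tid / 32) % 4);
    // Stride 0 only makes sense when a single warp is active; then all
    // accesses fall into sub-partition 0.
    if (tid < 32) {
      assert(laneOf(tid, 0) / 32 == 0);
    }
  }
  return 0;
}
```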
9 changes: 4 additions & 5 deletions csrc/device_lower/utils.cpp
@zasdfgbnm (Collaborator, Author) commented on Mar 6, 2025:

This file is updated to handle an edge case where we have transformations like:

split(1, 4) -> (1, 4)

For this case, we still consider it linear because only 1 of the 4 elements is the valid part, and by definition, an extent-1 domain is always linear w.r.t. anything.
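A worked toy example of this edge case may help (a standalone sketch of my own, not nvFuser code): splitting an extent-1 domain by 4 yields extents (1, 4), but only the single point (outer=0, inner=0) is in bounds, so the valid part has extent 1 and is trivially linear. On this reading, a fully cancelled projection can now become an empty composition, whose extent is the empty product 1 and whose stride is reported as zeroVal(), which is also why the NVF_ERROR(!dq.empty()) assertions below are dropped.

```cpp
// Toy illustration (not nvFuser code) of split(1, 4) -> (1, 4): only one
// of the 4 inner positions is in bounds, so the valid part has extent 1,
// and an extent-1 domain is linear w.r.t. anything (its index is always 0).
#include <cassert>
#include <cstdint>

int main() {
  const int64_t original_extent = 1;
  const int64_t factor = 4;
  const int64_t outer_extent = (original_extent + factor - 1) / factor; // 1
  int64_t valid = 0;
  for (int64_t outer = 0; outer < outer_extent; ++outer) {
    for (int64_t inner = 0; inner < factor; ++inner) {
      if (outer * factor + inner < original_extent) {
        ++valid; // only (outer=0, inner=0) survives the bound check
      }
    }
  }
  assert(valid == 1); // the valid part has extent 1 -> trivially linear
  return 0;
}
```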

@@ -1298,7 +1298,7 @@ Val* extent(const Composition<Projection>& comp) {
   return std::accumulate(
       comp.begin(),
       comp.end(),
-      static_cast<Val*>(nullptr),
+      FusionGuard::getCurFusion()->oneVal(),
       [](Val* acc, const auto& g) {
         return SimplifyingIrBuilder::mulExpr(acc, extent(g));
       });
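A hedged aside on why the seed changes from nullptr to oneVal() (my reading of the diff, sketched with plain integers rather than nvFuser Vals): the extent of a composition is the product of its parts' extents, and seeding the fold with the multiplicative identity 1 makes the empty composition report extent 1 instead of a null value.

```cpp
// Plain-integer sketch (not nvFuser code) of the accumulate-seed change:
// seeding with 1, the multiplicative identity, gives an empty composition
// the conventional empty-product extent of 1.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int64_t extent(const std::vector<int64_t>& part_extents) {
  return std::accumulate(
      part_extents.begin(),
      part_extents.end(),
      int64_t{1}, // analogous to FusionGuard::getCurFusion()->oneVal()
      [](int64_t acc, int64_t e) { return acc * e; });
}

int main() {
  assert(extent({}) == 1);      // empty composition: extent is now defined
  assert(extent({4, 8}) == 32); // non-empty case is unchanged
  return 0;
}
```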
@@ -1483,8 +1483,6 @@ Projection propagate(
     const ExprGroup& eg,
     Direction direction) {
   // Just recursively propagate subtree.
-  auto from = fromGroups(id_graph, eg, direction);
-  auto to = toGroups(id_graph, eg, direction);
   auto propagated = propagate(*part.what, id_graph, eg, direction);
   if (!propagated.hasValue()) {
     return {};
@@ -1618,6 +1616,9 @@ Val* proveLinearAndGetStrideAfterPropagation(
 Val* proveLinearAndGetStrideAfterPropagation(
     const Composition<Projection>& comp,
     const ValGroups& domain) {
+  if (comp.empty()) {
+    return FusionGuard::getCurFusion()->zeroVal();
+  }
   auto it = search(domain, comp);
   if (it == domain.end()) {
     return nullptr;
@@ -1717,7 +1718,6 @@ PartOf<Projection> cancelCommonFactors(const PartOf<Projection>& part) {
   if (new_inner_extent->isOne()) {
     new_inner_extent = nullptr;
   }
-  NVF_ERROR(!dq.empty());
   if (dq.size() == 1) {
     return PartOf<Projection>{
         std::make_shared<Projection>(dq.front()),
@@ -1806,7 +1806,6 @@ PartOf<Projection> trimRedundant(const PartOf<Projection>& part) {
   while (count < (int64_t)dq.size()) {
     dq.pop_front();
   }
-  NVF_ERROR(!dq.empty());
   if (dq.size() == 1) {
     return PartOf<Projection>{
         std::make_shared<Projection>(dq.front()),
8 changes: 4 additions & 4 deletions doc/dev/tmem.md
@@ -586,7 +586,6 @@ columns of the tensor memory, while all the specified patterns require the warp
 to access contiguous 32 or 16 lanes of data.<!-- */ //-->\
 ```cpp
 TEST_F(TMemTutorialC, WrongSubpartition) {
-  NOT_IMPLEMENTED
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -610,7 +609,8 @@ TEST_F(TMemTutorialC, WrongSubpartition) {
   EXPECT_THAT(
       [&]() { KernelExecutor().compile(&fusion); },
       ::testing::ThrowsMessage<nvfError>(::testing::HasSubstr(
-          "Invalid data access pattern in TMem load/store.")));
+          "Invalid data access pattern in TMem load/store: "
+          "Warps are not accessing the correct sub-partition.")));
 } /*
 ```
 
@@ -621,7 +621,6 @@ However, warp 0 can only access subpartition 0, and warp 1 can only access
 subpartition 1.<!-- */ //-->\
 ```cpp
 TEST_F(TMemTutorialC, WrongSubpartition2) {
-  NOT_IMPLEMENTED
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -645,7 +644,8 @@ TEST_F(TMemTutorialC, WrongSubpartition2) {
   EXPECT_THAT(
       [&]() { KernelExecutor().compile(&fusion); },
       ::testing::ThrowsMessage<nvfError>(::testing::HasSubstr(
-          "Invalid data access pattern in TMem load/store.")));
+          "Invalid data access pattern in TMem load/store: "
+          "Warps are not accessing the correct sub-partition.")));
 } /*
 ```
 