Avoid UBLK TMA out-of-bound access #3917

Draft · wants to merge 30 commits into base: main

Commits (30)
176631d  extend circular buffer tests and fix index (liqiangxl, Feb 9, 2025)
898d4a6  format (liqiangxl, Feb 9, 2025)
7134d2d  clean (liqiangxl, Feb 9, 2025)
b0fc704  simple fix (liqiangxl, Feb 10, 2025)
c3e035f  use idmodel index for 1d tma load (liqiangxl, Feb 10, 2025)
f87d23a  Merge branch 'main' into llu/fix_index_1dtma_warpspecialization (liqiangxl, Feb 10, 2025)
e8dfcf1  Merge branch 'main' into llu/fix_index_1dtma_warpspecialization (liqiangxl, Feb 12, 2025)
37ddfd1  skip tests causing src address overflow for 1d tma (liqiangxl, Feb 13, 2025)
d99a60f  Merge branch 'main' into llu/fix_index_1dtma_warpspecialization (liqiangxl, Feb 13, 2025)
7423bb4  Merge remote-tracking branch 'origin/llu/fix_index_1dtma_warpspeciali… (liqiangxl, Feb 14, 2025)
6786b38  use tid.x==0 (liqiangxl, Feb 14, 2025)
4a1153a  revert not related changes (liqiangxl, Feb 14, 2025)
ed772be  fix typo (liqiangxl, Feb 14, 2025)
a3bf733  clang (liqiangxl, Feb 14, 2025)
55772d5  Merge branch 'main' into llu/tma_predicate (liqiangxl, Feb 14, 2025)
d65c183  tidy (liqiangxl, Feb 14, 2025)
d0515c6  Merge branch 'llu/tma_predicate' of https://github.com/nvidia/fuser i… (liqiangxl, Feb 14, 2025)
aa4a20b  clangtidy (liqiangxl, Feb 14, 2025)
a5de65a  needs clean (liqiangxl, Feb 17, 2025)
67fb333  clean (liqiangxl, Feb 17, 2025)
5b2a40d  2 stages pass (liqiangxl, Feb 17, 2025)
0a24930  wip (liqiangxl, Feb 18, 2025)
878d9cb  Merge branch 'main' into llu/tma_predicate (liqiangxl, Feb 18, 2025)
70880c0  wip (liqiangxl, Feb 18, 2025)
49bee3f  wip (liqiangxl, Feb 18, 2025)
4f183da  use modulo to avoid out of bound access (liqiangxl, Feb 18, 2025)
932ac08  fix test (liqiangxl, Feb 18, 2025)
baf46f8  Merge branch 'main' into llu/tma_out_bound_access (liqiangxl, Feb 18, 2025)
68fcd73  only revise for gmem tv (liqiangxl, Feb 18, 2025)
48bedc1  Merge branch 'main' into llu/tma_out_bound_access (liqiangxl, Feb 20, 2025)
Files changed (4)
18 changes: 18 additions & 0 deletions csrc/device_lower/utils.cpp
@@ -232,6 +232,24 @@ inline CpAsyncBulkMode getCpAsyncBulkMode(const Expr* expr) {

} // namespace

bool isCpAsyncUblk(const Expr* expr) {
  if (auto ldst = dynamic_cast<const LoadStoreOp*>(expr)) {
    auto op_type = ldst->opType();
    if (op_type == LoadStoreOpType::CpAsyncBulk) {
      auto in_mem = getTv(ldst->in())->getMemoryType();
      auto out_mem = getTv(ldst->out())->getMemoryType();
      if ((in_mem == MemoryType::Global && out_mem == MemoryType::Shared) ||
          (in_mem == MemoryType::Shared && out_mem == MemoryType::Global)) {
        return true;
      } else {
        NVF_THROW("Invalid memory types for CpAsyncBulk");
      }
    }
    return false;
  }
  return false;
}

bool isCpAsyncBulk(const Expr* expr) {
  return getCpAsyncBulkMode(expr) != CpAsyncBulkMode::NotACpAsyncBulk;
}
1 change: 1 addition & 0 deletions csrc/device_lower/utils.h
@@ -134,6 +134,7 @@ bool isCpAsyncOp(const Expr* expr);
bool isCpAsyncBulkLoad(const Expr* expr);
bool isCpAsyncBulkStore(const Expr* expr);
bool isCpAsyncBulk(const Expr* expr);
//! A CpAsyncBulk copy between global and shared memory (1D TMA, a.k.a. UBLK)
bool isCpAsyncUblk(const Expr* expr);

//! Short-cut for detecting initialization for cpAsync op.
bool isCpAsyncInit(const Expr* expr);
22 changes: 22 additions & 0 deletions csrc/id_model/indexing.cpp
@@ -913,6 +913,28 @@ Val* TensorIndexer::getLinearIndex(
SimplifyingIrBuilder::addExpr(linear_index, circular_buffer_offset);
}

  // Modulo the linear index if the expr is a UBLK copy and the index is for a
  // gmem tv; this avoids out-of-bound access. For example, in test
  // `UblkPredicate`, tensor [I0,I1] is split as: [sm_count, I0/stages/sm_count,
  // stages, I1] and parallelized as: [BIDx, Serial, Serial, Bulk]. The TMA load
  // is nested within two for-loops, one for [I0/stages/sm_count] and the other
  // for [stages]. Since no predicate is generated for the TMA load, out-of-bound
  // access may happen if any of the splits is not divisible. The modulo
  // operation avoids this issue at the cost of several useless loads in the
  // last iteration (a standalone sketch with concrete numbers follows this
  // file's diff).
  if (ir_utils::isCpAsyncUblk(expr)) {
    auto gmem_tv = expr->input(0)->as<TensorView>();
    if (gmem_tv == tv) {
      auto logical_size = gmem_tv->fusion()->oneVal();
      const auto& logical_domain = gmem_tv->getLogicalDomain();
      for (const auto i : c10::irange(logical_domain.size())) {
        logical_size = SimplifyingIrBuilder::mulExpr(
            logical_size, logical_domain.at(i)->extent());
      }
      linear_index = SimplifyingIrBuilder::modExpr(linear_index, logical_size);
    }
  }

  return linear_index;
}

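The guard above is easiest to see with concrete numbers. The following is a minimal standalone sketch (editor's addition, not nvFuser code); the 5x4 tensor shape and the split factor of 2 are made-up values chosen so the split is non-divisible.

// With a hypothetical [5, 4] tensor whose outer dim is split by 2, the padded
// loop nest covers ceilDiv(5, 2) * 2 = 6 rows, so the last row's linear index
// points past the 20 valid elements; taking it modulo the logical size, as the
// change above does for UBLK copies, wraps it back in range.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t rows = 5, cols = 4;
  const int64_t logical_size = rows * cols;               // product of logical extents
  const int64_t padded_rows = ((rows + 2 - 1) / 2) * 2;   // ceilDiv(5, 2) * 2 = 6
  int64_t linear_index = (padded_rows - 1) * cols;        // first element of the padded row
  std::cout << "raw index: " << linear_index << " (numel = " << logical_size << ")\n";
  linear_index = linear_index % logical_size;             // the wrap added above
  std::cout << "wrapped index: " << linear_index << "\n"; // prints 0
  return 0;
}

The wrapped index simply re-reads row 0, which is the "several useless loads in the last iteration" trade-off mentioned in the comment above.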
95 changes: 95 additions & 0 deletions tests/cpp/test_circular_buffering.cpp
@@ -2158,4 +2158,99 @@ INSTANTIATE_TEST_SUITE_P(
    tmaCircularBufferingParams(),
    tmaName);

// Similar to TmaCircularBufferingTest, but only tests 1D TMA (UBLK) with one
// tensor size. The outer dim is a prime number so non-divisible splits exercise
// the predicate path (a standalone sketch of the split arithmetic follows the
// diff).
class TmaCircularBufferingTestUblk : public TmaCircularBufferingTest {};
TEST_P(TmaCircularBufferingTestUblk, Predicate) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);

  if (testEnablesRegisterSharing()) {
    GTEST_SKIP();
    return;
  }
  constexpr at::ScalarType dtype = at::ScalarType::Float;
  CompileParams index32bit{DataType::Int32, 255, false};
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  auto tv0 = makeContigTensor(2, aten_to_data_type(dtype));
  fusion->addInput(tv0);
  auto tv1 = add(tv0, tv0);
  fusion->addOutput(tv1);

  auto tv0a = tv0->cacheAfter(tma_load_type);
  auto tv1c = tv1->cacheBefore();
  tv0a->setMemoryType(MemoryType::Shared);

  // tensor_outer_dim is a prime number, so it is not divisible by
  // number_of_stages or by number_of_cta when stages is 1. When stages > 1,
  // increase number_of_cta to make sure the 2nd split is also not divisible by
  // number_of_cta.
  int64_t number_of_cta =
      at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  if (number_of_stages > 1) {
    int64_t after_stages =
        (tensor_outer_dim + number_of_stages - 1) / number_of_stages;
    while (after_stages % number_of_cta == 0) {
      number_of_cta++;
    }
  }
  tv1->split(0, number_of_stages);
  tv1->split(0, number_of_cta, false);
  TransformPropagator propagator(tv1);
  MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator);

  tv1->axis(0)->parallelize(ParallelType::BIDx);
  scheduler_utils::parallelizeAllLike(tv1);

  // TIDx for computation, Bulk for load
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv1c->axis(-1)->parallelize(ParallelType::TIDx);
  tv0a->axis(-1)->parallelize(ParallelType::Bulk);
  inlineMost();

  if (number_of_stages > 1) {
    tv0a->circularBuffer(
        number_of_stages, prefetch_distance, circular_buffer_type);
  }

  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
  at::Tensor at_tv0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);

  KernelExecutor ke;
  ke.compile(fusion.get(), {at_tv0}, {}, index32bit);
  auto outputs = ke.run({at_tv0});
  auto at_output = at_tv0 + at_tv0;
  testValidate(
      fusion.get(), outputs, {at_tv0}, {at_output}, __LINE__, __FILE__);
}
auto tmaUblkPredicateParams() {
  // When using register sharing with warp-specialized circular buffering, the
  // circular buffer loop must be the outer-most for-loop
  const std::vector<CircularBufferType> all_types{
      Pipelined(false),
      Pipelined(true),
      WarpSpecialized(ParallelType::TIDx),
      WarpSpecialized(ParallelType::TIDy),
      WarpSpecialized(ParallelType::TIDz)};
  int64_t dim0 = 8191, dim1 = 256;
  const std::vector<LoadStoreOpType> tma_types{LoadStoreOpType::CpAsyncBulk};
  std::vector<TmaCircularBufferingParams> values;
  for (int64_t i : {2, 4}) {
    for (int64_t j : c10::irange(-i, i)) {
      for (auto circular_buffer_type : all_types) {
        for (auto tma_load_type : tma_types) {
          values.emplace_back(
              i, j, dim0, dim1, circular_buffer_type, tma_load_type);
        }
      }
    }
  }
  return testing::ValuesIn(values);
}
INSTANTIATE_TEST_SUITE_P(
    UblkTma,
    TmaCircularBufferingTestUblk,
    tmaUblkPredicateParams(),
    tmaName);

} // namespace nvfuser
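
As a closing illustration, here is a standalone sketch (editor's addition, not part of the test file) of the shape arithmetic the UblkPredicate test sets up. The SM count of 132 is a made-up stand-in for multiProcessorCount; dim0 = 8191 and the stage counts 2 and 4 come from tmaUblkPredicateParams above.

// dim0 = 8191 is prime, so it is never divisible by the stage count, and with
// stages > 1 the test bumps number_of_cta until ceilDiv(dim0, stages) is not
// divisible by it either. Both splits therefore pad the iteration space past
// 8191 rows, which is exactly what the modulo added in getLinearIndex must
// absorb.
#include <cstdint>
#include <iostream>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t dim0 = 8191;
  for (int64_t stages : {2, 4}) {
    int64_t number_of_cta = 132;  // stand-in for multiProcessorCount
    const int64_t after_stages = ceilDiv(dim0, stages);
    while (after_stages % number_of_cta == 0) {
      number_of_cta++;
    }
    const int64_t padded_rows =
        number_of_cta * ceilDiv(after_stages, number_of_cta) * stages;
    std::cout << "stages=" << stages << " ctas=" << number_of_cta
              << " padded rows=" << padded_rows << " (logical rows=" << dim0
              << ")\n";
  }
  return 0;
}

Because the padded row count exceeds the logical 8191 rows for every parameter combination, the TMA loads in the last iterations would read out of bounds without the modulo guard in getLinearIndex.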