-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NVIDIA][Backend] Add CoalesceAsyncCopy Pass for in-DotOpEnc Upcasting (
#5222) This is a follow-up to the dotOp hoisting optimization for WGMMA (MMAv3). See #5003 (comment) In short, when upcasting operand A in registers prior to WGMMA and when pipelining is enabled, `AsyncCopyGlobalToLocal`'s src gmem blocked encoding will have `sizePerThread` > smem view's `vec` (along the contiguous dimension). This will result in multiple `cp.async` instructions being generated for a contiguous global data segment, producing uncoalesced loads. This was previously confirmed in ncu. See above comment for an example. I've added a generalized fix in a new pass after the pipeliner. I've reused the logic in the LLVM lowering for `AsyncCopyGlobalToLocal` to calculate the max contiguous copy size. I compare that to the blockEnc's `sizePerThread` along the inner (contiguous) dimension. If the former is less than the latter, I set the latter to the former. When A is k-major, I can verify a small perf improvement and that ncu no longer reports uncoalesced loads. When A is m-major, this pass is a no-op because `copy size == sizePerThread == 16` ptal, thanks @ThomasRaoux
- Loading branch information
Showing
9 changed files
with
228 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#include "mlir/Support/LLVM.h" | ||
#include "mlir/Transforms/Passes.h" | ||
#include "triton/Analysis/Utility.h" | ||
#include "triton/Dialect/TritonGPU/IR/Dialect.h" | ||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h" | ||
|
||
namespace mlir { | ||
namespace triton { | ||
namespace gpu { | ||
|
||
#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY | ||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" | ||
|
||
// This pass currently only applies if the following are all true...
//   1) Operand A for WGMMA is to be loaded in registers
//   2) We upcast operand A in registers before the WGMMA
//      (downcasting is not yet supported)
//   3) Pipelining is enabled for loading A
//
// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding
// vec will be less than BlockedEncoding's sizePerThread for k-dim. E.g. if
// we're upcasting from int8 to bf16, then shared vec is 8 and sizePerThread
// for k is 16. In this case, AsyncCopyGlobalToLocal will generate two
// 8-byte-cp.async's for each contiguous 16B global data owned by each
// thread. This breaks coalescing (i.e. results in 2x the minimum required
// transactions).
//
// This issue occurs for cp.async because it combines load and store into one
// instruction. The fix is to clip each dim of sizePerThread by shared vec, so
// that the vectorization of load and store are equal along the contiguous
// dimension. In the above example, each thread will then only own 8B contiguous
// global data.
struct ClipAsyncCopySizePerThread
    : public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
  using OpRewritePattern::OpRewritePattern;

  // Rewrites a single AsyncCopyGlobalToLocal op: if the src blocked
  // encoding's sizePerThread along the contiguous dim exceeds the maximum
  // contiguous copy size, clip it and wrap src/mask/other in
  // convert_layout ops to the clipped encoding.
  // Returns failure (no rewrite) when the src is not blocked-encoded or
  // the encoding is already small enough.
  LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
                                PatternRewriter &rewriter) const override {
    Value src = copyOp.getSrc();
    Value mask = copyOp.getMask();     // optional predicate operand
    Value other = copyOp.getOther();   // optional fill value for masked-off lanes
    auto srcTy = cast<RankedTensorType>(src.getType());
    auto dstTy = cast<MemDescType>(copyOp.getResult().getType());
    auto blockEnc = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
    if (!blockEnc)
      return rewriter.notifyMatchFailure(copyOp,
                                         "src must be of blocked encoding");
    auto sharedEnc = cast<SharedEncodingAttr>(dstTy.getEncoding());
    // NOTE(review): sharedVec is unused below — the contiguity bound comes
    // from getRegToSharedLayout instead; consider removing it.
    auto sharedVec = sharedEnc.getVec();

    // obtain max contiguous copy size
    // Note this can be further optimized, as copyContigSize can be even
    // smaller when lowering, depending on contiguity and mask alignment
    // (see AsyncCopyGlobalToLocalOpConversion)
    auto elemBitWidth = dstTy.getElementTypeBitWidth();
    auto regToSharedLayout =
        getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc,
                             sharedEnc, elemBitWidth);
    auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut();

    // obtain block sizePerThread along contig dim
    // (order[0] is the fastest-varying, i.e. contiguous, dimension)
    auto sizePerThread = blockEnc.getSizePerThread();
    auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]];

    if (blockContigSize <= copyContigSize)
      return rewriter.notifyMatchFailure(
          copyOp,
          "blocked sizePerThread along contiguous dim must be greater than the "
          "max contiguous copy size ");

    // Clip so each thread owns no more contiguous data than one cp.async
    // can move in a single instruction.
    sizePerThread[blockEnc.getOrder()[0]] = copyContigSize;

    // obtain new blockedEnc based on clipped sizePerThread
    auto mod = copyOp->getParentOfType<ModuleOp>();
    int numWarps = TritonGPUDialect::getNumWarps(mod);
    int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
    auto newBlockEnc = BlockedEncodingAttr::get(
        copyOp.getContext(), srcTy.getShape(), sizePerThread,
        blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout());

    // insert cvt's after src, mask, and other
    auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) {
      auto ty = cast<TensorType>(src.getType());
      auto newTy =
          RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
      auto cvt = rewriter.create<ConvertLayoutOp>(copyOp->getLoc(), newTy, src);
      return cvt.getResult();
    };
    src = convertBlockLayout(src, newBlockEnc);
    if (mask)
      mask = convertBlockLayout(mask, newBlockEnc);
    if (other)
      other = convertBlockLayout(other, newBlockEnc);

    // Re-point the copy's operands at the converted values in place.
    rewriter.modifyOpInPlace(copyOp, [&]() {
      copyOp.getSrcMutable().assign(src);
      if (mask)
        copyOp.getMaskMutable().assign(mask);
      if (other)
        copyOp.getOtherMutable().assign(other);
    });

    return success();
  }
};
|
||
class CoalesceAsyncCopyPass | ||
: public impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> { | ||
public: | ||
void runOnOperation() override { | ||
ModuleOp m = getOperation(); | ||
MLIRContext *context = &getContext(); | ||
|
||
mlir::RewritePatternSet patterns(context); | ||
patterns.add<ClipAsyncCopySizePerThread>(context); | ||
|
||
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) | ||
signalPassFailure(); | ||
} | ||
}; | ||
|
||
} // namespace gpu | ||
} // namespace triton | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s

// Case 1: async copy of i8 with both mask and other operands.
// The src blocked encoding owns 16 contiguous elements per thread while the
// shared view's vec is 8, so the pass clips sizePerThread to [1, 8] and
// inserts convert_layout ops on src, mask, and other.
// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
      %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>,
      %mask: tensor<64x16xi1, #blocked>,
      %other: tensor<64x16xi8, #blocked>) {
    %token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
    tt.return
  }
}
|
||
// -----

// Case 2: same clipping, but the copy has no mask/other operands — only the
// src pointer tensor is converted to the clipped layout.
// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
      %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>) {
    %token = triton_gpu.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
    tt.return
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters