[NVIDIA][Backend] Add CoalesceAsyncCopy Pass for in-DotOpEnc Upcasting #5222
Merged
Changes shown from 5 of 8 commits:
dc94f69 Add coalesce async copy pass (ggengnv)
14858e4 Make logic more general (ggengnv)
f1af158 Document and format (ggengnv)
2124a06 Fix random typo (ggengnv)
5b0f4ad Move memdesc to ttg in lit test (ggengnv)
d3a50e9 Address comments (ggengnv)
82e9f63 Fix bug and add test (ggengnv)
3d9be5a Remove unused includes (ggengnv)
@@ -0,0 +1,125 @@ (new file: the CoalesceAsyncCopy pass implementation)
#include "mlir/Support/LLVM.h" | ||
#include "mlir/Transforms/Passes.h" | ||
#include "triton/Analysis/Utility.h" | ||
#include "triton/Conversion/TritonGPUToLLVM/Utility.h" | ||
#include "triton/Dialect/TritonGPU/IR/Dialect.h" | ||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h" | ||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h" | ||
|
||
#include <memory> | ||
|
||
namespace tt = mlir::triton; | ||
|
||
namespace mlir { | ||
namespace triton { | ||
namespace gpu { | ||
|
||
#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY | ||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" | ||
|
||
// This pass currently only applies if the following are all true... | ||
// 1) Operand A for WGMMA is to be loaded in registers | ||
// 2) We upcast operand A in registers before the WGMMA | ||
// (downcasting is not yet supported) | ||
// 3) Pipelining is enabled for loading A | ||
// | ||
// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding | ||
// vec will be less than BlockedEncoding's sizePerThread for k-dim. E.g. if | ||
// we're upcasting from int8 to bf16, then shared vec is 8 and sizePerThread | ||
// for k is 16. In this case, AsyncCopyGlobalToLocal will generate two | ||
// 8-byte-cp.async's for each contiguous 16B global data owned by each | ||
// thread. This breaks coalescing (i.e. results 2x the minimum required | ||
// transactions). | ||
// | ||
// This issue occurs for cp.async because it combines load and store into one | ||
// instruction. The fix is to clip each dim of sizePerThread by shared vec, so | ||
// that the vectorization of load and store are equal along the contiguous | ||
// dimension. In the above example, each thread will then only own 8B contiguous | ||
// global data. | ||
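// Worked example (illustrative numbers, extending the note above): with i8
// elements and sizePerThread[k] = 16, each thread owns 16 contiguous bytes
// of global data, but a shared vec of 8 caps each cp.async at 8 bytes, so
// each thread issues 16 B / 8 B = 2 transactions. After clipping
// sizePerThread[k] down to the vec (8), each thread owns exactly 8
// contiguous bytes and issues a single cp.async.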
struct ClipAsyncCopySizePerThread
    : public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
                                PatternRewriter &rewriter) const override {
    Value src = copyOp.getSrc();
    Value mask = copyOp.getMask();
    Value other = copyOp.getOther();
    auto srcTy = cast<RankedTensorType>(src.getType());
    auto blockEnc = cast<BlockedEncodingAttr>(srcTy.getEncoding());
Review comment: you can't assume the copy will use blocked layout
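A defensive variant along the reviewer's lines would bail out gracefully instead of asserting via cast (a hypothetical sketch, not necessarily how the comment was resolved in later commits):

    auto blockEnc = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
    if (!blockEnc)
      return rewriter.notifyMatchFailure(
          copyOp, "expected a blocked src encoding");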
||
auto dstTy = cast<MemDescType>(copyOp.getResult().getType()); | ||
auto sharedEnc = cast<SharedEncodingAttr>(dstTy.getEncoding()); | ||
auto sharedVec = sharedEnc.getVec(); | ||
|
||
// obtain max contiguous copy size | ||
// Note this can be further optimized, as copyContigSize can be even | ||
// smaller when lowering, depending on contiguity and mask alignment | ||
// (see AsyncCopyGlobalToLocalOpConversion) | ||
auto elemBitWidth = dstTy.getElementTypeBitWidth(); | ||
auto regToSharedLayout = | ||
getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc, | ||
sharedEnc, elemBitWidth); | ||
auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut(); | ||
|
||
// obtain block sizePerThread along contig dim | ||
auto sizePerThread = blockEnc.getSizePerThread(); | ||
auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]]; | ||
|
||
if (blockContigSize <= copyContigSize) | ||
return rewriter.notifyMatchFailure( | ||
copyOp, | ||
"blocked sizePerThread along contiguous dim must be greater than the " | ||
"max contiguous copy size "); | ||
|
||
sizePerThread[blockEnc.getOrder()[0]] = copyContigSize; | ||
|
||
// obtain new blockedEnc based on clipped sizePerThread | ||
auto mod = copyOp->getParentOfType<ModuleOp>(); | ||
int numWarps = TritonGPUDialect::getNumWarps(mod); | ||
int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); | ||
auto newBlockEnc = BlockedEncodingAttr::get( | ||
copyOp.getContext(), srcTy.getShape(), sizePerThread, | ||
blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout()); | ||
|
||
    // insert cvt's after src, mask, and other
    auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) {
      auto ty = cast<TensorType>(src.getType());
      auto newTy =
          RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
      auto cvt = rewriter.create<ConvertLayoutOp>(copyOp->getLoc(), newTy, src);
      return cvt.getResult();
    };
    src = convertBlockLayout(src, newBlockEnc);
    if (mask)
      mask = convertBlockLayout(mask, newBlockEnc);
    if (other)
      other = convertBlockLayout(other, newBlockEnc);

    // replace the asyncCopy
    auto newCopyOp = rewriter.create<AsyncCopyGlobalToLocalOp>(
        copyOp.getLoc(), src, copyOp.getResult(), mask, other,
        copyOp.getCache(), copyOp.getEvict(), copyOp.getIsVolatile());
    rewriter.replaceOp(copyOp, newCopyOp);
Review comment: nit, you could do an in-place update
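An in-place update along the reviewer's nit could look like this sketch (assuming the ODS-generated mutable operand accessors and RewriterBase::modifyOpInPlace; this revision of the PR keeps the replaceOp form):

    rewriter.modifyOpInPlace(copyOp, [&] {
      // Re-point the operands at the layout-converted values.
      copyOp.getSrcMutable().assign(src);
      if (mask)
        copyOp.getMaskMutable().assign(mask);
      if (other)
        copyOp.getOtherMutable().assign(other);
    });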

    return success();
  }
};

class CoalesceAsyncCopyPass
    : public impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
public:
  void runOnOperation() override {
    ModuleOp m = getOperation();
    MLIRContext *context = &getContext();

    mlir::RewritePatternSet patterns(context);
    patterns.add<ClipAsyncCopySizePerThread>(context);

    if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace gpu
} // namespace triton
} // namespace mlir
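Outside of lit, the pass would typically be scheduled on a module through a pass manager, roughly as in this sketch (the factory name createTritonGPUCoalesceAsyncCopy is assumed from the GEN_PASS_DEF above, not confirmed by the diff):

    #include "mlir/Pass/PassManager.h"
    #include "triton/Dialect/TritonGPU/Transforms/Passes.h"

    static mlir::LogicalResult runCoalesceAsyncCopy(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Factory name assumed to be generated by the tablegen pass definition.
      pm.addPass(mlir::triton::gpu::createTritonGPUCoalesceAsyncCopy());
      return pm.run(module);
    }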
@@ -0,0 +1,19 @@ (new file: lit test for the pass)
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s

// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
                         %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>,
                         %mask: tensor<64x16xi1, #blocked>,
                         %other: tensor<64x16xi8, #blocked>) {
    %token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other : tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
    tt.return
  }
}
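A note on the CHECK'd layout: clipping sizePerThread from [1, 16] to [1, 8] means each warp's 32 threads are redistributed over the same 64x16 tile, so threadsPerWarp goes from [4, 8] to [16, 2]: along the contiguous k-dim, 16 / 8 = 2 threads now suffice, leaving 32 / 2 = 16 threads for the other dimension.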
Review comment: nit: this is a bit of a layering violation; getRegToSharedLayout probably belongs in the TritonGPU dialect utils.

Review comment: can you remove the include now?

Author reply: oh right, I forgot - just removed.