[NVIDIA][Backend] Add CoalesceAsyncCopy Pass for in-DotOpEnc Upcasting (#5222)

This is a follow-up to the dotOp hoisting optimization for WGMMA
(MMAv3). See
#5003 (comment)

In short, when operand A is upcast in registers prior to WGMMA and
pipelining is enabled, the source gmem blocked encoding of
`AsyncCopyGlobalToLocal` has a `sizePerThread` greater than the smem
view's `vec` along the contiguous dimension. This causes multiple
`cp.async` instructions to be generated for each contiguous segment of
global data, i.e. uncoalesced loads. This was previously confirmed in
ncu; see the linked comment for an example.
 
I've added a generalized fix in a new pass that runs after the pipeliner. It
reuses the logic from the LLVM lowering of `AsyncCopyGlobalToLocal` to
calculate the max contiguous copy size, and compares that to the blocked
encoding's `sizePerThread` along the inner (contiguous) dimension. If the
former is less than the latter, the latter is clipped to the former.

When A is k-major, I can verify a small perf improvement and that ncu no
longer reports uncoalesced loads.
When A is m-major, this pass is a no-op because `copy size ==
sizePerThread == 16`.
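
For intuition, here is a minimal standalone sketch (not part of this commit) of the clipping rule, using the int8 → bf16 upcast numbers from the description above. The variable names are illustrative; the real pass derives the max contiguous copy size from the register-to-shared linear layout rather than reading `vec` directly.

```python
# Hypothetical numbers for the int8 -> bf16 upcast case described above.
shared_vec = 8           # elements the smem swizzle accepts per vectorized access
size_per_thread_k = 16   # contiguous elements per thread in the gmem blocked encoding

# cp.async performs the load and the store in a single instruction, so its
# width is bounded by the smaller of the two vectorizations; with 16 > 8 each
# thread would issue two 8-byte cp.async's for its 16 contiguous bytes.
copy_contig_size = min(shared_vec, size_per_thread_k)

# The fix: clip sizePerThread along the contiguous dim so each thread owns
# exactly what a single cp.async can move.
if size_per_thread_k > copy_contig_size:
    size_per_thread_k = copy_contig_size

assert size_per_thread_k == 8
```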

ptal, thanks @ThomasRaoux
ggengnv authored Nov 27, 2024
1 parent 7b2beae commit b8a4b87
Showing 9 changed files with 228 additions and 34 deletions.
12 changes: 12 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -192,4 +192,16 @@ def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp">
"number of pipeline stages">
];
}

def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> {
let summary = "Improve coalescing for async global to local copies";

let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than "
"the blocked encoding's sizePerThread, this pass improves coalescing by clipping the "
"sizePerThread value";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}

#endif
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -202,6 +202,11 @@ enum class MMALoadType {
// pipelining
};
MMALoadType getMMALoadType(Operation *loadOp);

// Returns composed LinearLayout for register to shared copy
std::optional<triton::LinearLayout>
getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc, int elemBitWidth);
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
46 changes: 12 additions & 34 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -4,6 +4,7 @@
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
@@ -174,41 +175,17 @@ bool emitTransferBetweenRegistersAndShared(
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");

std::optional<LinearLayout> regLayout =
triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
std::optional<LinearLayout> sharedLayout = triton::gpu::toLinearLayout(
shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
if (!regLayout.has_value() || !sharedLayout.has_value()) {
auto regToSharedLayout = getRegToSharedLayout(
ctx, shape, registerTy.getEncoding(), sharedTy.getEncoding(),
elemLlvmTy.getIntOrFloatBitWidth());
if (!regToSharedLayout.has_value())
return false;
}
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());

// sharedLayout's in-dims are currently (offset, block). Reshape to
// (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
// shmem strides. (The offsetX's appear in minor-to-major order.)
auto sharedLegacy =
cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
for (int i = 0; i < rank; i++) {
int dim = sharedOrder[i];
int64_t size = std::max(
int64_t{1},
shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
multiDimSharedSize.push_back(
{str_attr("offset" + std::to_string(dim)), size});
}
multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);

// regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
// ..., offsetXN, block), where the offsetX's are in minor-to-major order.
LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout);

// TODO(jlebar): We don't currently support loading from shared memory in a
// different CTA. We'd need to emit `mapa.shared::cluster` instructions.
for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock);
inBlock *= 2) {
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply(
{{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
// offsetX1, ..., offsetXN must all be 0.
if (!llvm::all_of(ArrayRef(idx).drop_back(1),
@@ -234,15 +211,15 @@
// which have known strides. This would allow us to vectorize across multiple
// shmem out dimensions where possible.
const int vecElems =
std::min(regToSharedLayout.getNumConsecutiveInOut(),
std::min(regToSharedLayout->getNumConsecutiveInOut(),
maxVecElems.value_or(std::numeric_limits<int>::max()));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);

int numElems = regToSharedLayout.getInDimSize(kRegister);
int numElems = regToSharedLayout->getInDimSize(kRegister);
auto vecTy = vec_ty(elemLlvmTy, vecElems);
auto ptrTy = shmemBase.getType();
Value zero = i32_val(0);
@@ -253,14 +230,15 @@
// we drop_end to drop block, which we know from above will be 0.
auto multiDimShmemOffset =
llvm::to_vector(llvm::drop_end(llvm::make_second_range(
applyLinearLayout(loc, rewriter, regToSharedLayout,
applyLinearLayout(loc, rewriter, *regToSharedLayout,
{{kRegister, i32_val(i * vecElems)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, zero}}))));

// Reorder strides according to `order`. This way they match the
// multi-dimensional offsets in regToSharedLayout.
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
applyPermutation(shmemStrides, sharedOrder));
auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -18,6 +18,7 @@ add_triton_library(TritonGPUTransforms
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
CoalesceAsyncCopy.cpp
Utility.cpp

DEPENDS
124 changes: 124 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp
@@ -0,0 +1,124 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

namespace mlir {
namespace triton {
namespace gpu {

#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

// This pass currently only applies if the following are all true...
// 1) Operand A for WGMMA is to be loaded in registers
// 2) We upcast operand A in registers before the WGMMA
// (downcasting is not yet supported)
// 3) Pipelining is enabled for loading A
//
// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding
// vec will be less than the BlockedEncoding's sizePerThread for the k-dim.
// E.g. if we're upcasting from int8 to bf16, then shared vec is 8 and
// sizePerThread for k is 16. In this case, AsyncCopyGlobalToLocal will
// generate two 8-byte cp.async's for each contiguous 16B of global data owned
// by each thread. This breaks coalescing (i.e. it results in 2x the minimum
// required transactions).
//
// This issue occurs for cp.async because it combines load and store into one
// instruction. The fix is to clip each dim of sizePerThread by shared vec, so
// that the vectorization of load and store are equal along the contiguous
// dimension. In the above example, each thread will then only own 8B contiguous
// global data.
struct ClipAsyncCopySizePerThread
: public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
PatternRewriter &rewriter) const override {
Value src = copyOp.getSrc();
Value mask = copyOp.getMask();
Value other = copyOp.getOther();
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(copyOp.getResult().getType());
auto blockEnc = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
if (!blockEnc)
return rewriter.notifyMatchFailure(copyOp,
"src must be of blocked encoding");
auto sharedEnc = cast<SharedEncodingAttr>(dstTy.getEncoding());
auto sharedVec = sharedEnc.getVec();

// obtain max contiguous copy size
// Note this can be further optimized, as copyContigSize can be even
// smaller when lowering, depending on contiguity and mask alignment
// (see AsyncCopyGlobalToLocalOpConversion)
auto elemBitWidth = dstTy.getElementTypeBitWidth();
auto regToSharedLayout =
getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc,
sharedEnc, elemBitWidth);
auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut();

// obtain block sizePerThread along contig dim
auto sizePerThread = blockEnc.getSizePerThread();
auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]];

if (blockContigSize <= copyContigSize)
return rewriter.notifyMatchFailure(
copyOp,
"blocked sizePerThread along contiguous dim must be greater than the "
"max contiguous copy size ");

sizePerThread[blockEnc.getOrder()[0]] = copyContigSize;

// obtain new blockedEnc based on clipped sizePerThread
auto mod = copyOp->getParentOfType<ModuleOp>();
int numWarps = TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
auto newBlockEnc = BlockedEncodingAttr::get(
copyOp.getContext(), srcTy.getShape(), sizePerThread,
blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout());

// insert cvt's after src, mask, and other
auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) {
auto ty = cast<TensorType>(src.getType());
auto newTy =
RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
auto cvt = rewriter.create<ConvertLayoutOp>(copyOp->getLoc(), newTy, src);
return cvt.getResult();
};
src = convertBlockLayout(src, newBlockEnc);
if (mask)
mask = convertBlockLayout(mask, newBlockEnc);
if (other)
other = convertBlockLayout(other, newBlockEnc);

rewriter.modifyOpInPlace(copyOp, [&]() {
copyOp.getSrcMutable().assign(src);
if (mask)
copyOp.getMaskMutable().assign(mask);
if (other)
copyOp.getOtherMutable().assign(other);
});

return success();
}
};

class CoalesceAsyncCopyPass
: public impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
public:
void runOnOperation() override {
ModuleOp m = getOperation();
MLIRContext *context = &getContext();

mlir::RewritePatternSet patterns(context);
patterns.add<ClipAsyncCopySizePerThread>(context);

if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
signalPassFailure();
}
};

} // namespace gpu
} // namespace triton
} // namespace mlir
36 changes: 36 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -1154,4 +1154,40 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
patterns.add<ForOpDeadArgElimination>(patterns.getContext());
}

std::optional<LinearLayout>
getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc, int elemBitWidth) {
StringAttr kBlock = StringAttr::get(ctx, ("block"));
int rank = shape.size();

std::optional<LinearLayout> regLayout =
triton::gpu::toLinearLayout(shape, srcEnc);
std::optional<LinearLayout> sharedLayout =
triton::gpu::toLinearLayout(shape, dstEnc, elemBitWidth);
if (!regLayout.has_value() || !sharedLayout.has_value()) {
return std::nullopt;
}
auto sharedOrder = triton::gpu::getOrder(dstEnc);

// sharedLayout's in-dims are currently (offset, block). Reshape to
// (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
// shmem strides. (The offsetX's appear in minor-to-major order.)
auto sharedLegacy = cast<triton::gpu::SharedEncodingAttr>(dstEnc);
SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
for (int i = 0; i < rank; i++) {
int dim = sharedOrder[i];
int64_t size = std::max(
int64_t{1},
shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
multiDimSharedSize.push_back(
{StringAttr::get(ctx, ("offset" + std::to_string(dim))), size});
}
multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);

// regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
// ..., offsetXN, block), where the offsetX's are in minor-to-major order.
return regLayout->invertAndCompose(*sharedLayout);
}

} // namespace mlir
2 changes: 2 additions & 0 deletions python/src/passes.cc
@@ -72,6 +72,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createTritonGPUOptimizeAccumulatorInit);
ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling",
createTritonGPULoopScheduling, int);
ADD_PASS_WRAPPER_0("add_coalesce_async_copy",
createTritonGPUCoalesceAsyncCopy);
}

void init_triton_passes_convert(py::module &&m) {
35 changes: 35 additions & 0 deletions test/TritonGPU/coalesce-async-copy.mlir
@@ -0,0 +1,35 @@
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s

// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
%view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>,
%mask: tensor<64x16xi1, #blocked>,
%other: tensor<64x16xi8, #blocked>) {
%token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
tt.return
}
}

// -----

// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
%view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>) {
%token = triton_gpu.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
tt.return
}
}
1 change: 1 addition & 0 deletions third_party/nvidia/backend/compiler.py
@@ -234,6 +234,7 @@ def make_ttgir(mod, metadata, opt, capability):
passes.ttgpuir.add_pipeline(pm, opt.num_stages)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.ttgpuir.add_coalesce_async_copy(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
passes.ttgpuir.add_reorder_instructions(pm)
