
[NVIDIA][Backend] Add CoalesceAsyncCopy Pass for in-DotOpEnc Upcasting #5222

Merged (8 commits) on Nov 27, 2024
Changes from 5 commits

5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1129,6 +1129,11 @@ SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Attribute layout, RankedTensorType type, bool withCTAOffset);

// Returns the composed LinearLayout for a register-to-shared copy.
std::optional<LinearLayout>
getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc, int elemBitWidth);

// Emits IR to load data from shared memory into registers, or to store data
// from registers into shared memory.
//
12 changes: 12 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -192,4 +192,16 @@ def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp">
"number of pipeline stages">
];
}

def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> {
let summary = "Improve coalescing for async global to local copies";

let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than "
"the blocked encoding's sizePerThread, this pass improves coalescing by clipping the "
"sizePerThread value";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}

#endif
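
A quick illustration of the clipping rule in the description above, as a plain-Python sketch. The numbers are the int8 example used elsewhere in this PR, not values read from a real encoding.

# Sketch: clip the blocked encoding's sizePerThread along the contiguous dim
# to the shared encoding's vec, so the load and store sides vectorize equally.
shared_vec = 8             # shared encoding's vec (store-side vectorization)
size_per_thread = [1, 16]  # blocked encoding; order[0] == 1 is the contiguous dim
contig_dim = 1

size_per_thread[contig_dim] = min(size_per_thread[contig_dim], shared_vec)
print(size_per_thread)     # [1, 8]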
69 changes: 41 additions & 28 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -158,36 +158,25 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
return ret;
}

bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
MLIRContext *ctx = rewriter.getContext();

auto shape = registerTy.getShape();
int rank = shape.size();

std::optional<LinearLayout>
getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc, int elemBitWidth) {
StringAttr kBlock = str_attr("block");
StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
int rank = shape.size();

std::optional<LinearLayout> regLayout =
triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
std::optional<LinearLayout> sharedLayout = triton::gpu::toLinearLayout(
shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
triton::gpu::toLinearLayout(shape, srcEnc);
std::optional<LinearLayout> sharedLayout =
triton::gpu::toLinearLayout(shape, dstEnc, elemBitWidth);
if (!regLayout.has_value() || !sharedLayout.has_value()) {
return false;
return std::nullopt;
}
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
auto sharedOrder = triton::gpu::getOrder(dstEnc);

// sharedLayout's in-dims are currently (offset, block). Reshape to
// (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
// shmem strides. (The offsetX's appear in minor-to-major order.)
auto sharedLegacy =
cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
auto sharedLegacy = cast<triton::gpu::SharedEncodingAttr>(dstEnc);
SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
for (int i = 0; i < rank; i++) {
int dim = sharedOrder[i];
@@ -202,13 +191,36 @@ bool emitTransferBetweenRegistersAndShared(

// regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
// ..., offsetXN, block), where the offsetX's are in minor-to-major order.
LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout);
return regLayout->invertAndCompose(*sharedLayout);
}

bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
MLIRContext *ctx = rewriter.getContext();

auto shape = registerTy.getShape();
int rank = shape.size();

StringAttr kBlock = str_attr("block");
StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");

auto regToSharedLayout = getRegToSharedLayout(
ctx, shape, registerTy.getEncoding(), sharedTy.getEncoding(),
elemLlvmTy.getIntOrFloatBitWidth());
if (!regToSharedLayout.has_value())
return false;

// TODO(jlebar): We don't currently support loading from shared memory in a
// different CTA. We'd need to emit `mapa.shared::cluster` instructions.
for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock);
inBlock *= 2) {
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply(
{{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
// offsetX1, ..., offsetXN must all be 0.
if (!llvm::all_of(ArrayRef(idx).drop_back(1),
@@ -234,15 +246,15 @@ bool emitTransferBetweenRegistersAndShared(
// which have known strides. This would allow us to vectorize across multiple
// shmem out dimensions where possible.
const int vecElems =
std::min(regToSharedLayout.getNumConsecutiveInOut(),
std::min(regToSharedLayout->getNumConsecutiveInOut(),
maxVecElems.value_or(std::numeric_limits<int>::max()));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);

int numElems = regToSharedLayout.getInDimSize(kRegister);
int numElems = regToSharedLayout->getInDimSize(kRegister);
auto vecTy = vec_ty(elemLlvmTy, vecElems);
auto ptrTy = shmemBase.getType();
Value zero = i32_val(0);
@@ -253,14 +265,15 @@ bool emitTransferBetweenRegistersAndShared(
// we drop_end to drop block, which we know from above will be 0.
auto multiDimShmemOffset =
llvm::to_vector(llvm::drop_end(llvm::make_second_range(
applyLinearLayout(loc, rewriter, regToSharedLayout,
applyLinearLayout(loc, rewriter, *regToSharedLayout,
{{kRegister, i32_val(i * vecElems)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, zero}}))));

// Reorder strides according to `order`. This way they match the
// multi-dimensional offsets in regToSharedLayout.
auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
applyPermutation(shmemStrides, sharedOrder));
auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -18,6 +18,7 @@ add_triton_library(TritonGPUTransforms
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
CoalesceAsyncCopy.cpp
Utility.cpp

DEPENDS
125 changes: 125 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp
@@ -0,0 +1,125 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
Review comment (Collaborator): nit: this is a bit of a layering violation, getRegToSharedLayout probably belongs to triton gpu dialect utils.

Review comment (Collaborator): can you remove the include now?

Reply (Contributor, author): oh right, I forgot - just removed

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

#include <memory>

namespace tt = mlir::triton;

namespace mlir {
namespace triton {
namespace gpu {

#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

// This pass currently only applies if the following are all true...
// 1) Operand A for WGMMA is to be loaded in registers
// 2) We upcast operand A in registers before the WGMMA
// (downcasting is not yet supported)
// 3) Pipelining is enabled for loading A
//
// ...then for the AsyncCopyGlobalToLocal op, the SharedEncoding
// vec will be less than BlockedEncoding's sizePerThread for k-dim. E.g. if
// we're upcasting from int8 to bf16, then shared vec is 8 and sizePerThread
// for k is 16. In this case, AsyncCopyGlobalToLocal will generate two
// 8-byte-cp.async's for each contiguous 16B global data owned by each
// thread. This breaks coalescing (i.e. it results in 2x the minimum required
// transactions).
//
// This issue occurs for cp.async because it combines load and store into one
// instruction. The fix is to clip each dim of sizePerThread by shared vec, so
// that the vectorization of load and store are equal along the contiguous
// dimension. In the above example, each thread will then only own 8B contiguous
// global data.
struct ClipAsyncCopySizePerThread
: public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
PatternRewriter &rewriter) const override {
Value src = copyOp.getSrc();
Value mask = copyOp.getMask();
Value other = copyOp.getOther();
auto srcTy = cast<RankedTensorType>(src.getType());
auto blockEnc = cast<BlockedEncodingAttr>(srcTy.getEncoding());
Review comment (Collaborator): you can't assume the copy will use blocked layout

auto dstTy = cast<MemDescType>(copyOp.getResult().getType());
auto sharedEnc = cast<SharedEncodingAttr>(dstTy.getEncoding());
auto sharedVec = sharedEnc.getVec();

// obtain max contiguous copy size
// Note this can be further optimized, as copyContigSize can be even
// smaller when lowering, depending on contiguity and mask alignment
// (see AsyncCopyGlobalToLocalOpConversion)
auto elemBitWidth = dstTy.getElementTypeBitWidth();
auto regToSharedLayout =
getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc,
sharedEnc, elemBitWidth);
auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut();

// obtain block sizePerThread along contig dim
auto sizePerThread = blockEnc.getSizePerThread();
auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]];

if (blockContigSize <= copyContigSize)
return rewriter.notifyMatchFailure(
copyOp,
"blocked sizePerThread along contiguous dim must be greater than the "
"max contiguous copy size ");

sizePerThread[blockEnc.getOrder()[0]] = copyContigSize;

// obtain new blockedEnc based on clipped sizePerThread
auto mod = copyOp->getParentOfType<ModuleOp>();
int numWarps = TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
auto newBlockEnc = BlockedEncodingAttr::get(
copyOp.getContext(), srcTy.getShape(), sizePerThread,
blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout());

// insert cvt's after src, mask, and other
auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) {
auto ty = cast<TensorType>(src.getType());
auto newTy =
RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
auto cvt = rewriter.create<ConvertLayoutOp>(copyOp->getLoc(), newTy, src);
return cvt.getResult();
};
src = convertBlockLayout(src, newBlockEnc);
if (mask)
mask = convertBlockLayout(mask, newBlockEnc);
if (other)
other = convertBlockLayout(other, newBlockEnc);

// replace the asyncCopy
auto newCopyOp = rewriter.create<AsyncCopyGlobalToLocalOp>(
copyOp.getLoc(), src, copyOp.getResult(), mask, other,
copyOp.getCache(), copyOp.getEvict(), copyOp.getIsVolatile());
rewriter.replaceOp(copyOp, newCopyOp);
Review comment (Collaborator): nit, you could do an in-place update


return success();
}
};

class CoalesceAsyncCopyPass
: public impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
public:
void runOnOperation() override {
ModuleOp m = getOperation();
MLIRContext *context = &getContext();

mlir::RewritePatternSet patterns(context);
patterns.add<ClipAsyncCopySizePerThread>(context);

if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
signalPassFailure();
}
};

} // namespace gpu
} // namespace triton
} // namespace mlir
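
To put numbers on the comment at the top of this file: with int8 data, shared vec = 8, and sizePerThread = 16 along the k dim, each thread owns 16 contiguous bytes but each cp.async only moves 8 bytes, so two instructions are issued per thread. A rough plain-Python model with illustrative values only:

# Model of the per-thread cp.async count before and after clipping sizePerThread.
elem_bytes = 1            # i8 elements in global and shared memory
shared_vec = 8            # elements per shared-memory vector
size_per_thread_k = 16    # contiguous elements per thread in the blocked layout

cp_async_bytes = shared_vec * elem_bytes
print((size_per_thread_k * elem_bytes) // cp_async_bytes)  # 2 cp.async per thread

# After clipping, each thread owns a single 8-byte chunk, neighbouring threads
# cover consecutive addresses, and one cp.async per thread suffices.
size_per_thread_k = min(size_per_thread_k, shared_vec)
print((size_per_thread_k * elem_bytes) // cp_async_bytes)  # 1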
2 changes: 2 additions & 0 deletions python/src/passes.cc
@@ -72,6 +72,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createTritonGPUOptimizeAccumulatorInit);
ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling",
createTritonGPULoopScheduling, int);
ADD_PASS_WRAPPER_0("add_coalesce_async_copy",
createTritonGPUCoalesceAsyncCopy);
}

void init_triton_passes_convert(py::module &&m) {
19 changes: 19 additions & 0 deletions test/TritonGPU/coalesce-async-copy.mlir
@@ -0,0 +1,19 @@
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s

// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]>
// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
%view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>,
%mask: tensor<64x16xi1, #blocked>,
%other: tensor<64x16xi8, #blocked>) {
%token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
tt.return
}
}
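
The CHECK lines change threadsPerWarp from [4, 8] to [16, 2] as a side effect of rebuilding the blocked encoding around the clipped sizePerThread. A small model of that redistribution; the distribute helper is an assumption about BlockedEncodingAttr::get's default minor-to-major thread placement, not Triton source code.

# Reproduce the encoding in the CHECK lines from the test's input layouts.
def distribute(shape, size_per_thread, order, num_threads):
    # Spread num_threads over dims in minor-to-major order, each dim taking
    # at most shape[d] // size_per_thread[d] threads.
    result = [1] * len(shape)
    remaining = num_threads
    for d in order:
        use = min(max(shape[d] // size_per_thread[d], 1), remaining)
        result[d] = use
        remaining = max(remaining // use, 1)
    return result

shape = [64, 16]
order = [1, 0]               # dim 1 is the contiguous dimension
shared_vec = 8               # from #shared<{vec = 8, ...}>
size_per_thread = [1, 16]    # original #blocked

size_per_thread[order[0]] = min(size_per_thread[order[0]], shared_vec)
threads_per_warp = distribute(shape, size_per_thread, order, 32)
print(size_per_thread, threads_per_warp)  # [1, 8] [16, 2], matching the CHECKs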
1 change: 1 addition & 0 deletions third_party/nvidia/backend/compiler.py
@@ -234,6 +234,7 @@ def make_ttgir(mod, metadata, opt, capability):
passes.ttgpuir.add_pipeline(pm, opt.num_stages)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.ttgpuir.add_coalesce_async_copy(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
passes.ttgpuir.add_reorder_instructions(pm)