diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index 0639df714c60..dec9271f3515 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -192,4 +192,16 @@ def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp">
            "number of pipeline stages">
   ];
 }
+
+def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::ModuleOp"> {
+  let summary = "Improve coalescing for async global to local copies";
+
+  let description = "For AsyncCopyGlobalToLocal ops where the shared encoding's vec is less than "
+                    "the blocked encoding's sizePerThread along the contiguous dimension, this "
+                    "pass improves coalescing by clipping the sizePerThread value";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h
index f1e361f64d66..0f6bd57afaf1 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -202,6 +202,11 @@ enum class MMALoadType {
                  // pipelining
 };
 MMALoadType getMMALoadType(Operation *loadOp);
+
+// Returns the composed LinearLayout for a register-to-shared copy.
+std::optional<triton::LinearLayout>
+getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                     Attribute srcEnc, Attribute dstEnc, int elemBitWidth);
 } // namespace mlir

 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index c681cd344ce8..49f05a758e42 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -4,6 +4,7 @@
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "llvm/ADT/STLExtras.h"

 namespace mlir {
@@ -174,41 +175,17 @@ bool emitTransferBetweenRegistersAndShared(
   StringAttr kLane = str_attr("lane");
   StringAttr kWarp = str_attr("warp");

-  std::optional<LinearLayout> regLayout =
-      triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
-  std::optional<LinearLayout> sharedLayout = triton::gpu::toLinearLayout(
-      shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
-  if (!regLayout.has_value() || !sharedLayout.has_value()) {
+  auto regToSharedLayout = getRegToSharedLayout(
+      ctx, shape, registerTy.getEncoding(), sharedTy.getEncoding(),
+      elemLlvmTy.getIntOrFloatBitWidth());
+  if (!regToSharedLayout.has_value())
     return false;
-  }
-  auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
-
-  // sharedLayout's in-dims are currently (offset, block). Reshape to
-  // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
-  // shmem strides. (The offsetX's appear in minor-to-major order.)
-  auto sharedLegacy =
-      cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
-  SmallVector<std::pair<StringAttr, int64_t>> multiDimSharedSize;
-  for (int i = 0; i < rank; i++) {
-    int dim = sharedOrder[i];
-    int64_t size = std::max(
-        int64_t{1},
-        shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
-    multiDimSharedSize.push_back(
-        {str_attr("offset" + std::to_string(dim)), size});
-  }
-  multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
-  sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);
-
-  // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
-  // ..., offsetXN, block), where the offsetX's are in minor-to-major order.
-  LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout);

   // TODO(jlebar): We don't currently support loading from shared memory in a
   // different CTA. We'd need to emit `mapa.shared::cluster` instructions.
-  for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
+  for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock);
        inBlock *= 2) {
-    auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
+    auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply(
         {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
     // offsetX1, ..., offsetXN must all be 0.
     if (!llvm::all_of(ArrayRef(idx).drop_back(1),
@@ -234,15 +211,15 @@ bool emitTransferBetweenRegistersAndShared(
   // which have known strides. This would allow us to vectorize across multiple
   // shmem out dimensions where possible.
   const int vecElems =
-      std::min(regToSharedLayout.getNumConsecutiveInOut(),
+      std::min(regToSharedLayout->getNumConsecutiveInOut(),
               maxVecElems.value_or(std::numeric_limits<int>::max()));

   Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
+  Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
   Value laneId = urem(threadId, threadsPerWarp);
   Value warpId = udiv(threadId, threadsPerWarp);

-  int numElems = regToSharedLayout.getInDimSize(kRegister);
+  int numElems = regToSharedLayout->getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
   auto ptrTy = shmemBase.getType();
   Value zero = i32_val(0);
@@ -253,7 +230,7 @@ bool emitTransferBetweenRegistersAndShared(
     // we drop_end to drop block, which we know from above will be 0.
     auto multiDimShmemOffset =
         llvm::to_vector(llvm::drop_end(llvm::make_second_range(
-            applyLinearLayout(loc, rewriter, regToSharedLayout,
+            applyLinearLayout(loc, rewriter, *regToSharedLayout,
                               {{kRegister, i32_val(i * vecElems)},
                                {kLane, laneId},
                                {kWarp, warpId},
@@ -261,6 +238,7 @@ bool emitTransferBetweenRegistersAndShared(

     // Reorder strides according to `order`. This way they match the
     // multi-dimensional offsets in regToSharedLayout.
+    auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
     Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
                             applyPermutation(shmemStrides, sharedOrder));
     auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
index ef4cec328f86..740014b77948 100644
--- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -18,6 +18,7 @@ add_triton_library(TritonGPUTransforms
   Prefetch.cpp
   RemoveLayoutConversions.cpp
   ReorderInstructions.cpp
+  CoalesceAsyncCopy.cpp
   Utility.cpp

   DEPENDS
diff --git a/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp b/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp
new file mode 100644
index 000000000000..2d634fc6fa7b
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp
@@ -0,0 +1,124 @@
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/Passes.h"
+#include "triton/Analysis/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+
+namespace mlir {
+namespace triton {
+namespace gpu {
+
+#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+// This pass currently only applies when all of the following hold:
+// 1) operand A for WGMMA is to be loaded in registers,
+// 2) we upcast operand A in registers before the WGMMA
+//    (downcasting is not yet supported), and
+// 3) pipelining is enabled for loading A.
+//
+// In that case, for the AsyncCopyGlobalToLocal op, the SharedEncoding vec
+// will be less than the BlockedEncoding's sizePerThread for the k-dim. E.g.
+// if we're upcasting from int8 to bf16, then shared vec is 8 and
+// sizePerThread for k is 16. AsyncCopyGlobalToLocal will then generate two
+// 8-byte cp.async's for each contiguous 16 bytes of global data owned by a
+// thread. This breaks coalescing (i.e. it results in 2x the minimum required
+// number of transactions).
+//
+// This issue is specific to cp.async because it combines the load and the
+// store into one instruction. The fix is to clip sizePerThread along the
+// contiguous dimension to the max contiguous copy size, so that the load and
+// store vectorization match along that dimension. In the above example, each
+// thread then owns only 8 contiguous bytes of global data.
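+//
+// Concrete illustration (mirroring test/TritonGPU/coalesce-async-copy.mlir;
+// the numbers below come from that lit test):
+//   #shared:  vec = 8 (i8 elements)
+//   #blocked: sizePerThread = [1, 16], threadsPerWarp = [4, 8], order = [1, 0]
+// The contiguous dim (dim 1) is clipped from 16 to 8 elements per thread, and
+// a new blocked encoding is derived for the same warp/CTA configuration:
+//   new #blocked: sizePerThread = [1, 8], threadsPerWarp = [16, 2], order = [1, 0]
+// convert_layout ops are inserted on src (and mask/other, when present) so
+// the async copy consumes its operands in the new layout.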
+struct ClipAsyncCopySizePerThread
+    : public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    Value src = copyOp.getSrc();
+    Value mask = copyOp.getMask();
+    Value other = copyOp.getOther();
+    auto srcTy = cast<RankedTensorType>(src.getType());
+    auto dstTy = cast<MemDescType>(copyOp.getResult().getType());
+    auto blockEnc = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
+    if (!blockEnc)
+      return rewriter.notifyMatchFailure(copyOp,
+                                         "src must be of blocked encoding");
+    auto sharedEnc = cast<SharedEncodingAttr>(dstTy.getEncoding());
+    auto sharedVec = sharedEnc.getVec();
+
+    // Obtain the max contiguous copy size.
+    // Note this can be further optimized, as copyContigSize can be even
+    // smaller when lowering, depending on contiguity and mask alignment
+    // (see AsyncCopyGlobalToLocalOpConversion).
+    auto elemBitWidth = dstTy.getElementTypeBitWidth();
+    auto regToSharedLayout =
+        getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc,
+                             sharedEnc, elemBitWidth);
+    if (!regToSharedLayout.has_value())
+      return rewriter.notifyMatchFailure(
+          copyOp, "could not compute the register-to-shared layout");
+    auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut();
+
+    // Obtain the blocked sizePerThread along the contiguous dim.
+    auto sizePerThread = blockEnc.getSizePerThread();
+    auto blockContigSize = sizePerThread[blockEnc.getOrder()[0]];
+
+    if (blockContigSize <= copyContigSize)
+      return rewriter.notifyMatchFailure(
+          copyOp,
+          "blocked sizePerThread along the contiguous dim must be greater "
+          "than the max contiguous copy size");
+
+    sizePerThread[blockEnc.getOrder()[0]] = copyContigSize;
+
+    // Obtain a new blockedEnc based on the clipped sizePerThread.
+    auto mod = copyOp->getParentOfType<ModuleOp>();
+    int numWarps = TritonGPUDialect::getNumWarps(mod);
+    int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
+    auto newBlockEnc = BlockedEncodingAttr::get(
+        copyOp.getContext(), srcTy.getShape(), sizePerThread,
+        blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout());
+
+    // Insert convert_layout ops for src, mask, and other.
+    auto convertBlockLayout = [&](Value src, BlockedEncodingAttr enc) {
+      auto ty = cast<RankedTensorType>(src.getType());
+      auto newTy =
+          RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
+      auto cvt = rewriter.create<ConvertLayoutOp>(copyOp->getLoc(), newTy, src);
+      return cvt.getResult();
+    };
+    src = convertBlockLayout(src, newBlockEnc);
+    if (mask)
+      mask = convertBlockLayout(mask, newBlockEnc);
+    if (other)
+      other = convertBlockLayout(other, newBlockEnc);
+
+    rewriter.modifyOpInPlace(copyOp, [&]() {
+      copyOp.getSrcMutable().assign(src);
+      if (mask)
+        copyOp.getMaskMutable().assign(mask);
+      if (other)
+        copyOp.getOtherMutable().assign(other);
+    });
+
+    return success();
+  }
+};
+
+class CoalesceAsyncCopyPass
+    : public impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
+public:
+  void runOnOperation() override {
+    ModuleOp m = getOperation();
+    MLIRContext *context = &getContext();
+
+    mlir::RewritePatternSet patterns(context);
+    patterns.add<ClipAsyncCopySizePerThread>(context);
+
+    if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace gpu
+} // namespace triton
+} // namespace mlir
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index b8f3abfcaca8..7effc18825aa 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -1153,4 +1153,40 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
   patterns.add<ForOpDeadArgElimination>(patterns.getContext());
 }

+std::optional<triton::LinearLayout>
+getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                     Attribute srcEnc, Attribute dstEnc, int elemBitWidth) {
+  StringAttr kBlock = StringAttr::get(ctx, ("block"));
+  int rank = shape.size();
+
+  std::optional<triton::LinearLayout> regLayout =
+      triton::gpu::toLinearLayout(shape, srcEnc);
+  std::optional<triton::LinearLayout> sharedLayout =
+      triton::gpu::toLinearLayout(shape, dstEnc, elemBitWidth);
+  if (!regLayout.has_value() || !sharedLayout.has_value()) {
+    return std::nullopt;
+  }
+  auto sharedOrder = triton::gpu::getOrder(dstEnc);
+
+  // sharedLayout's in-dims are currently (offset, block). Reshape to
+  // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
+  // shmem strides. (The offsetX's appear in minor-to-major order.)
+  auto sharedLegacy = cast<triton::gpu::SharedEncodingAttr>(dstEnc);
+  SmallVector<std::pair<StringAttr, int64_t>> multiDimSharedSize;
+  for (int i = 0; i < rank; i++) {
+    int dim = sharedOrder[i];
+    int64_t size = std::max(
+        int64_t{1},
+        shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
+    multiDimSharedSize.push_back(
+        {StringAttr::get(ctx, ("offset" + std::to_string(dim))), size});
+  }
+  multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
+  sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);
+
+  // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
+  // ..., offsetXN, block), where the offsetX's are in minor-to-major order.
+  return regLayout->invertAndCompose(*sharedLayout);
+}
+
 } // namespace mlir
diff --git a/python/src/passes.cc b/python/src/passes.cc
index d6612387b286..235eba4465cb 100644
--- a/python/src/passes.cc
+++ b/python/src/passes.cc
@@ -72,6 +72,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
                         createTritonGPUOptimizeAccumulatorInit);
   ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling",
                             createTritonGPULoopScheduling, int);
+  ADD_PASS_WRAPPER_0("add_coalesce_async_copy",
+                     createTritonGPUCoalesceAsyncCopy);
 }

 void init_triton_passes_convert(py::module &&m) {
diff --git a/test/TritonGPU/coalesce-async-copy.mlir b/test/TritonGPU/coalesce-async-copy.mlir
new file mode 100644
index 000000000000..4707ddaca9cb
--- /dev/null
+++ b/test/TritonGPU/coalesce-async-copy.mlir
@@ -0,0 +1,35 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s
+
+// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
+                       %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>,
+                       %mask: tensor<64x16xi1, #blocked>,
+                       %other: tensor<64x16xi8, #blocked>) {
+  %token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
+  tt.return
+}
+}
+
+// -----
+
+// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
+                                        %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>) {
+  %token = triton_gpu.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable>
+  tt.return
+}
+}
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 6d6d70fc87e3..233c11938fda 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -234,6 +234,7 @@ def make_ttgir(mod, metadata, opt, capability):
     passes.ttgpuir.add_pipeline(pm, opt.num_stages)
     passes.ttgpuir.add_prefetch(pm)
    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+    passes.ttgpuir.add_coalesce_async_copy(pm)
     passes.ttgpuir.add_remove_layout_conversions(pm)
     passes.ttgpuir.add_reduce_data_duplication(pm)
     passes.ttgpuir.add_reorder_instructions(pm)
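
A minimal standalone sketch of the clipping rule implemented by ClipAsyncCopySizePerThread, shown with the concrete values from the lit test above (shared vec = 8 i8 elements, blocked sizePerThread = [1, 16], order = [1, 0], so dim 1 is contiguous). This is illustrative only and not part of the patch; the plain-integer variables mirror the names used in the pass.

// Sketch: clip sizePerThread along the contiguous dim to the max contiguous
// copy size, reproducing the layout change checked by the lit test.
#include <array>
#include <cassert>

int main() {
  std::array<unsigned, 2> sizePerThread{1, 16}; // #blocked before the pass
  std::array<unsigned, 2> order{1, 0};          // dim 1 is fastest-varying
  unsigned copyContigSize = 8; // max contiguous copy size; here equal to vec

  unsigned contigDim = order[0];
  if (sizePerThread[contigDim] > copyContigSize)
    sizePerThread[contigDim] = copyContigSize; // clip 16 -> 8

  // Matches the #[[NEW_BLOCKED]] layout checked in the lit test.
  assert(sizePerThread[0] == 1 && sizePerThread[1] == 8);
  return 0;
}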