[NIT][BACKEND] Clean up Allocation.cpp (#5021)
1. Remove unnecessary header files.
2. Remove the unused `getCvtOrder` helper, since the dot operand layout now defines its own order (see the sketch after this list).
3. Remove unnecessary forward declarations.
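
A minimal sketch (not part of this commit) of why the helper is removable: with an order defined on the dot operand layout, `gpu::getOrder` can be queried on either side of the conversion directly, so a dedicated conversion-order helper is redundant. The name `getCvtOrderDirect` is hypothetical and only illustrative; it assumes the TritonGPU layout helpers declared in `triton/Dialect/TritonGPU/IR/Dialect.h`.

```cpp
#include <utility>

#include "mlir/IR/Attributes.h"
#include "llvm/ADT/SmallVector.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Hypothetical illustration: the order now comes straight from each layout,
// including dot operand layouts, so no mma/dot special-casing is needed.
static std::pair<llvm::SmallVector<unsigned>, llvm::SmallVector<unsigned>>
getCvtOrderDirect(mlir::Attribute srcLayout, mlir::Attribute dstLayout) {
  return {mlir::triton::gpu::getOrder(srcLayout),
          mlir::triton::gpu::getOrder(dstLayout)};
}
```

The removed `getCvtOrder` in the diff below shows the mma/dot special-casing this makes unnecessary.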
Jokeren authored Oct 31, 2024
1 parent 534aacb commit 9293f0a
Showing 1 changed file with 25 additions and 61 deletions.
lib/Analysis/Allocation.cpp
@@ -4,9 +4,7 @@
#include <limits>
#include <numeric>

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Alias.h"
@@ -15,19 +13,6 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"

using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getUniqueContigPerThread;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

namespace mlir {

//===----------------------------------------------------------------------===//
@@ -38,27 +23,6 @@ namespace triton {
// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;

static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
auto srcMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout);
auto srcDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(srcLayout);
auto dstMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout);
auto dstDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(dstLayout);

assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() &&
!srcMmaLayout.isHopper()) &&
"mma -> mma layout conversion is only supported on Ampere");

// mma or dot layout does not have an order, so the order depends on the
// layout of the other operand.
auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
: getOrder(srcLayout);
auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
: getOrder(dstLayout);

return {inOrd, outOrd};
}

static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
RankedTensorType dstTy) {
Attribute srcLayout = srcTy.getEncoding();
@@ -70,15 +34,17 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,

if (shouldUseDistSmem(srcLayout, dstLayout)) {
// TODO: padding to avoid bank conflicts
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
return convertType<unsigned, int64_t>(gpu::getShapePerCTA(srcTy));
}

assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()");

auto srcShapePerCTA = getShapePerCTA(srcTy);
auto dstShapePerCTA = getShapePerCTA(dstTy);
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
auto srcShapePerCTATile =
gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile =
gpu::getShapePerCTATile(dstLayout, dstTy.getShape());

unsigned rank = dstTy.getRank();
SmallVector<unsigned> repShape(rank);
@@ -124,9 +90,9 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
scratchConfig.order = outOrd;

unsigned srcContigPerThread =
getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
unsigned dstContigPerThread =
getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
unsigned innerDim = rank - 1;
@@ -135,12 +101,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
: srcContigPerThread;
scratchConfig.outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread;

if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout)) {
if (auto mma = mlir::dyn_cast<gpu::NvidiaMmaEncodingAttr>(srcLayout)) {
if (mma.getVersionMajor() == 1) {
// For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the
// codegen.
scratchConfig.inVec = srcContigPerThread;
} else if (mlir::isa<BlockedEncodingAttr>(dstLayout)) {
} else if (mlir::isa<gpu::BlockedEncodingAttr>(dstLayout)) {
// when storing from mma layout and loading in blocked layout vectorizing
// the load back gives better performance even if there is a
// transposition.
@@ -186,12 +152,12 @@ class AllocationAnalysis {
/// Initializes explicitly defined shared memory values for a given operation.
void getExplicitValueSize(Operation *op) {
for (Value result : op->getResults()) {
auto alloc = result.getDefiningOp<triton::gpu::LocalAllocOp>();
auto alloc = result.getDefiningOp<gpu::LocalAllocOp>();
if (alloc && alloc.isSharedMemoryAlloc()) {
// Bytes could be a different value once we support padding or other
// allocation policies.
auto allocType = alloc.getType();
auto shapePerCTA = triton::gpu::getShapePerCTA(allocType);
auto shapePerCTA = gpu::getShapePerCTA(allocType);
auto bytes = product<int64_t>(shapePerCTA) *
allocType.getElementTypeBitWidth() / 8;

@@ -218,31 +184,31 @@ class AllocationAnalysis {
/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
const size_t scratchAlignment = 128;
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
} else if (auto scanOp = dyn_cast<ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto histogram = dyn_cast<triton::HistogramOp>(op)) {
} else if (auto histogram = dyn_cast<HistogramOp>(op)) {
auto dstTy = histogram.getType();
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
} else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType();
auto dstTy = cvtLayout.getType();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (mlir::isa<SharedEncodingAttr>(srcEncoding) ||
mlir::isa<SharedEncodingAttr>(dstEncoding)) {
if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
// Conversions from/to shared memory do not need scratch memory.
return;
}
@@ -253,12 +219,12 @@ class AllocationAnalysis {
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
auto bytes =
isa<triton::PointerType>(srcTy.getElementType())
isa<PointerType>(srcTy.getElementType())
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
} else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
@@ -267,12 +233,10 @@ class AllocationAnalysis {
} else {
auto smemShape = getRepShapeForAtomic(op->getResult(0));
auto elems = getNumScratchElements(smemShape);
auto elemTy =
cast<triton::PointerType>(value.getType()).getPointeeType();
auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
auto bytes =
isa<triton::PointerType>(elemTy)
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
