diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 1281072b7178..f1e361f64d66 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -195,7 +195,13 @@ int getNVIDIAComputeCapability(Operation *module); std::optional getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible); -bool loadIsMMAv3(Operation *loadOp); +enum class MMALoadType { + SharedV3, + Registers, // may be v2 or v3 + DoNotPipeline, // could be a valid shared/registers MMA operand, but skip + // pipelining +}; +MMALoadType getMMALoadType(Operation *loadOp); } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp index 818b0c15ed6e..9d6d903f4d2c 100644 --- a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp +++ b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp @@ -145,11 +145,18 @@ filterPipelinedLoad(llvm::SmallVector> bool hasSharedEncoding = false; if (use->hasTrait()) { - if (loadIsMMAv3(op)) { + auto mmaLoadType = getMMALoadType(op); + auto dot = dyn_cast(use); + auto warpGroupDot = dyn_cast(use); + bool isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; + bool isMMAv3Registers = + (mmaLoadType == MMALoadType::Registers) && warpGroupDot; + + if (isMMAv3Shared) { hasSharedEncoding = true; } else if (isa(op)) { hasSharedEncoding = true; - } else if (auto dot = dyn_cast(use)) { + } else if (isMMAv3Registers || dot) { // FIXME: if we have a better solution in handling incompatible shared // encoding, we can simplify the logic here by checking if all users are // dot encoding. Fow now, getSharedEncIfAllUsersAreDotEnc will be used diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 4695984acfd3..fdb9bafa0b6b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -1,9 +1,11 @@ +#include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -15,6 +17,125 @@ namespace gpu { namespace { +// Helpers + +// Returns whether we can hoist DotOp Encoding through `op`. +// Roughly, whether op is elementwise and thus threads don't need +// to exchange elements. But some ops are not currently supported even though +// they meet that criterion. +bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) { + // Only consider custom conversions or arith ops. + // TODO(jlebar): Is this too restrictive? + if (!isa(op) && !isPureUnaryInlineAsm(op) && + !isa(op->getDialect())) + return false; + + // Quick handling to fix loading issues when computing the original + // bitwidth is unable to realize that there is a mixed-precision dot + // (hence kWidth = 1) but wants to hoist through the type conversion. 
+ if (isa(op) && dotOpEnc.getKWidth() == 1) + return false; + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(op)) + return false; + + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(op)) { + Type opType = getElementTypeOrSelf(op->getOperand(0)); + if (opType.isInteger(1)) + return false; + } + + return true; +} + +// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A +// is in registers). +bool canHoistDotOpEncV3(Operation *op) { + // Must have exactly one result and at least one operand + if (op->getNumOperands() == 0 || op->getNumResults() != 1) + return false; + + auto isBlockedOrDotOpRankedTensor = [](Type ty) { + auto tensorTy = dyn_cast(ty); + if (!tensorTy) + return false; + return isa( + tensorTy.getEncoding()); + }; + + // Operands and results must be of RankedTensorType and Blocked or DotOp + if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) && + all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor))) + return false; + + // Only consider custom conversions or arith ops. + if (!isa(op) && !isPureUnaryInlineAsm(op) && + !isa(op->getDialect())) + return false; + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(op)) + return false; + + // Downcasting not currently supported; it will likely require minor + // adjustments in sharedToDotOperandMMv2 + auto oprType = getElementTypeOrSelf(op->getOperand(0)); + auto resType = getElementTypeOrSelf(op->getResult(0)); + if (oprType.getIntOrFloatBitWidth() > resType.getIntOrFloatBitWidth()) + return false; + + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(op) && oprType.isInteger(1)) + return false; + + return true; +} + +// Helper to perform a "deep" clone of the given slice (i.e., set of ops), +// returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice, +// and sliceMap the IRMapping that maps the ops and result values of the +// original slice to those in the cloned slice. +auto cloneSlice(PatternRewriter &rewriter, + const SetVector &slice) { + IRMapping sliceMap; + SetVector newSlice; + + // First pass: clone ops; the result values are cloned as well, but the + // operands still refer to the original result values + for (Operation *op : slice) { + rewriter.setInsertionPoint(op); + auto newOp = rewriter.clone(*op); + newSlice.insert(newOp); + sliceMap.map(op, newOp); + for (auto [result, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + assert(result != newResult); + sliceMap.map(result, newResult); + } + } + + // Second pass: replace operand references in cloned ops to point to cloned + // values + for (auto [op, newOp] : sliceMap.getOperationMap()) + for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) { + auto defOp = operand.getDefiningOp(); + if (!slice.contains(defOp)) + continue; + + newOp->setOperand(oprIdx, sliceMap.lookup(operand)); + } + + return std::make_tuple(newSlice, sliceMap); +} + // Given // convert(trans(src)) #dot_operand -> // convert(local_load(trans(alloc(src)))) @@ -111,7 +232,8 @@ class HoistLayoutConversion : public OpRewritePattern { PatternRewriter &rewriter) const override { // Only consider conversions to dot operand. 
auto cvtTy = cast(cvt.getType()); - if (!isa(cvtTy.getEncoding())) + auto dotOpEnc = dyn_cast(cvtTy.getEncoding()); + if (!dotOpEnc) return failure(); auto src = cvt.getSrc().getDefiningOp(); @@ -126,16 +248,7 @@ class HoistLayoutConversion : public OpRewritePattern { [](Type ty) { return isa(ty); })) return failure(); - // Only consider custom conversions or arith ops. - // TODO(jlebar): Is this too restrictive? - if (!isa(src) && !isPureUnaryInlineAsm(src) && - src->getDialect()->getTypeID() != TypeID::get()) - return failure(); - - // Currently, these instructions are not supported during lowering of - // shared -> dot_operand layout. Not all types and type conversions are - // supported. - if (isa(src)) + if (!canHoistDotOpEncV2(src, dotOpEnc)) return failure(); // Check that the conversion is transitively dependent on a load, and all @@ -165,12 +278,7 @@ class HoistLayoutConversion : public OpRewritePattern { if (isa(currOp)) { foundLoad = true; } else if (foundLoad) { - // Bail out if there exists an op after Load that is not FpToFp, - // Bitcast, or Arith. - if (!isa(currOp) && - !isPureUnaryInlineAsm(currOp) && - currOp->getDialect()->getTypeID() != - TypeID::get()) + if (!canHoistDotOpEncV2(currOp, dotOpEnc)) return failure(); } } @@ -301,6 +409,150 @@ struct MMAV3UseRegOperand } }; +// MMAV3's analog of HoistLayoutConversion, for operand A only; will make +// WarpGroupDot accept operand A in registers instead of shmem. +// +// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot +// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; +// warp_group_dot +// +// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a +// time and requires multiple passes, this pattern will directly hoist the +// convert to the right place in one pass. +// +// Or, to be more precise, this pattern deletes the local_alloc op and inserts a +// convert_layout op after each load that warp_group_dot uses; so this is not +// simply hoisting a convert_layout op up as in V2, but can be considered as +// first changing local_alloc to convert_layout and then hoisting, which results +// in WGMMA now accepting operand A in DotOp layout rather than Shared. 
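As an illustrative sketch of this rewrite (schematic MLIR mirroring the mma_v3_reg_push_elementwise test added further down; types and encodings are abbreviated, with #dot_op standing in for the full triton_gpu.dot_op attribute, so these are not verbatim lines from the patch):

    // Before: operand A reaches the WGMMA through shared memory.
    %a = tt.load %pa : tensor<128x64x!tt.ptr<bf16>, #blocked>
    %c = tt.fp_to_fp %a : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked>
    %s = triton_gpu.local_alloc %c : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared>
    %r = triton_nvidia_gpu.warp_group_dot %s, %b, %acc : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>

    // After: the local_alloc is gone; a convert_layout sits right after the
    // load, and the elementwise op now runs in #dot_op layout.
    %a = tt.load %pa : tensor<128x64x!tt.ptr<bf16>, #blocked>
    %v = triton_gpu.convert_layout %a : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #dot_op>
    %c = tt.fp_to_fp %v : tensor<128x64xbf16, #dot_op> -> tensor<128x64xf16, #dot_op>
    %r = triton_nvidia_gpu.warp_group_dot %c, %b, %acc : tensor<128x64xf16, #dot_op> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>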
+struct MMAV3HoistLayoutConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp, + PatternRewriter &rewriter) const override { + // Can only hoist operand 0 + auto alloc = dotOp.getOperand(0).getDefiningOp(); + if (!alloc || !alloc.getSrc()) + return rewriter.notifyMatchFailure( + dotOp, "operand A must be produced by local_alloc"); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + if (!isa(getEncoding(dotOp.getOperand(0)))) + return rewriter.notifyMatchFailure( + dotOp, "requires Shared encoding for operand A"); + + // Step 1: Performs checks for early stop + auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); + if (!srcEnc) + return rewriter.notifyMatchFailure( + alloc, "requires src to have Blocked encoding"); + + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstEnc || dstEnc.getVersionMajor() != 3) + return rewriter.notifyMatchFailure( + dotOp, "requires result in NvidiaMma encoding"); + + // Step 2: Obtain slice of ops between load/constant and local_alloc + SetVector slice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = [&](Operation *op) { + // Stop before Load, ConstantOp, or LocalLoad + return (op->getParentRegion() == alloc->getParentRegion()) && + !isa(op) && + (op->getNumOperands() != 0); + }; + getBackwardSlice(alloc.getOperation(), &slice, opt); + + // Step 3: Verify slice can be hoisted through + if (slice.empty()) + return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through"); + + // We define frontierOp as an op outside this slice whose result is used by + // an op in this slice. We must eventually convert the result of all + // frontierOps to DotOperandEncoding. This is done via the insertion of + // ConvertLayout after each frontierOp. We currently support frontierOp to + // be load or constant. + for (Operation *currOp : slice) { + if (!canHoistDotOpEncV3(currOp)) + return rewriter.notifyMatchFailure(currOp, "cannot hoist through"); + + // We previously ensured that all ops in slice have at least one operand + for (auto operand : currOp->getOperands()) { + auto defOp = operand.getDefiningOp(); + if (!slice.contains(defOp)) { + // ensure frontierOp is load or constant + if (!isa(defOp)) + return rewriter.notifyMatchFailure(defOp, + "must be load or constant"); + } + } + } + + // Step 4: Clone slice + auto [newSlice, sliceMap] = cloneSlice(rewriter, slice); + + // Step 5: Modify the cloned slice to have dotOp encoding. + // Before: load #blocked; (elementwise #blocked)+; local_alloc; + // warp_group_dot After: load #blocked; convert_layout #dot_op; + // (elementwise #dot_op)+; warp_group_dot + // + // Specifically, this step will change all value types from #blocked to + // #dot_op encoding in the cloned slice, and for those values produced by + // frontierOps (i.e., outside the slice), we will insert convert_layout's + // after the frontierOp. + auto srcTy = cast(alloc.getSrc().getType()); + Type inputEltTy = srcTy.getElementType(); + auto dotOperandEnc = DotOperandEncodingAttr::get( + dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy); + + for (auto op : newSlice) { + // Step 5a: If any operand is defined by a frontierOp, we must insert a + // convert_layout(#dot_op) after the frontierOp and before currOp + for (auto [oprIdx, operand] : llvm::enumerate(op->getOperands())) { + + auto defOp = operand.getDefiningOp(); + + // defOp is not frontier (i.e. 
it's within slice); no need to convert + // the layout of its result + if (newSlice.contains(defOp)) + continue; + + // We checked earlier that all operands are ranked tensors + auto operandTy = cast(operand.getType()); + auto operandEltTy = operandTy.getElementType(); + + Type cvtTy = RankedTensorType::get( + operandTy.getShape(), operandTy.getElementType(), dotOperandEnc); + rewriter.setInsertionPoint(op); + auto cvt = + rewriter.create(defOp->getLoc(), cvtTy, operand); + + op->setOperand(oprIdx, cvt); + } + + // Step 5b: Change the result to have DotOp rather than Blocked encoding + auto resTy = cast(op->getResult(0).getType()); + op->getResult(0).setType(RankedTensorType::get( + resTy.getShape(), resTy.getElementType(), dotOperandEnc)); + } + + // Step 6: replace LHS operand with alloc's parent in the cloned slice + // This changes the warpGroupDot to accept a DotOp tensor as operand A + // instead of a Shared memdesc. + auto newDotOperand = sliceMap.lookup(alloc.getSrc()); + rewriter.modifyOpInPlace(dotOp, + [&]() { dotOp.setOperand(0, newDotOperand); }); + + return success(); + } +}; + } // namespace #define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS @@ -322,6 +574,7 @@ class TritonGPUOptimizeDotOperandsPass auto ret = pm.run(m); mlir::RewritePatternSet patterns(context); + patterns.add(context); patterns.add(context); if (this->hoistLayoutConversion.getValue()) patterns.add(context); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 9f5bec98503c..f3b956870de5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -44,7 +44,8 @@ struct LoadInfo { ttg::SharedEncodingAttr sharedEncoding = nullptr; // Blocked encoding is used for loads not used by the dot. ttg::BlockedEncodingAttr blockedEncoding = nullptr; - bool loadIsMMAV3 = false; + bool isMMAv3Shared = false; + bool isMMAv3Registers = false; int distToUse = 0; bool usedByDot = false; }; @@ -120,26 +121,27 @@ static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Value src = loadOp.getPtr(); Value mask = loadOp.getMask(); Value other = loadOp.getOther(); + tt::MemDescType allocTy = cast(alloc.getType()); + + auto convertBlockLayout = [&](Value src, ttg::BlockedEncodingAttr enc) { + auto ty = cast(src.getType()); + auto newTy = RankedTensorType::get(ty.getShape(), ty.getElementType(), enc); + auto cvt = builder.createWithStage( + loadOp->getLoc(), stage, clusterId, newTy, src); + return cvt.getResult(); + }; + if (!isExpensiveLoadOrStore(loadOp) && loadToInfo[loadOp].blockedEncoding) { // For inexpensive loads that do not directly feed into dot ops // we want to use optimal layout for the data. 
ttg::BlockedEncodingAttr encoding = loadToInfo[loadOp].blockedEncoding; - auto convertBlockLayout = [&](Value src) { - auto ty = cast(src.getType()); - auto newTy = - RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); - auto cvt = builder.createWithStage( - loadOp->getLoc(), stage, clusterId, newTy, src); - return cvt.getResult(); - }; - src = convertBlockLayout(src); + src = convertBlockLayout(src, encoding); if (mask) - mask = convertBlockLayout(mask); + mask = convertBlockLayout(mask, encoding); if (other) - other = convertBlockLayout(other); + other = convertBlockLayout(other, encoding); } - tt::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); copyOffsets[0] = insertIdx; Attribute sharedMemorySpace = @@ -157,14 +159,14 @@ static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Operation *wait = builder.createWithStage( loc, stageForFirstUse, clusterForFirstUse, commmit->getResult(0), 0); - bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; + auto loadIsMMAv3Shared = loadToInfo[loadOp].isMMAv3Shared; // Extract part. SmallVector loadOffsets(allocTy.getRank(), zero); loadOffsets[0] = extractIdx; auto viewLoad = builder.createWithStage( loc, stageForFirstUse, clusterForFirstUse, subviewTy, alloc, loadOffsets); - if (isMMV3Load) { + if (loadIsMMAv3Shared) { auto alloc = cast((*loadOp->getUsers().begin())); replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); alloc.erase(); @@ -245,7 +247,7 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Operation *copy = builder.createWithStage( loc, stage, clusterId, tmaPtr, loadOp.getIndices(), barrier, view, pred); - bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; + auto loadIsMMAv3Shared = loadToInfo[loadOp].isMMAv3Shared; builder.setInsertionPointAfter(waitOp); // Extract part. 
@@ -253,7 +255,7 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, loadOffsets[0] = extractIdx; auto viewLoad = builder.createWithStage( loc, stageForFirstUse, clusterForFirstUse, subviewTy, alloc, loadOffsets); - if (isMMV3Load) { + if (loadIsMMAv3Shared) { auto alloc = cast((*loadOp->getUsers().begin())); replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); alloc.erase(); @@ -298,7 +300,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { } static std::optional -getSharedEncoding(Operation *loadOp, bool isMMAV3) { +getSharedEncoding(Operation *loadOp, bool isMMAV3Shared) { auto ty = cast(loadOp->getResultTypes()[0]); auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); auto blockedOrder = ttg::getOrder(ty.getEncoding()); @@ -313,7 +315,7 @@ getSharedEncoding(Operation *loadOp, bool isMMAV3) { } else { order = blockedOrder; } - if (isMMAV3) { + if (isMMAV3Shared) { return ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); } @@ -471,15 +473,22 @@ assignMemoryLayouts(scf::ForOp &forOp, for (auto use : users) { if (use->hasTrait()) { LDBG("set shared encoding with dot user: " << *use); + auto mmaLoadType = getMMALoadType(&op); + auto dot = dyn_cast(use); + auto warpGroupDot = dyn_cast(use); + loadInfo.usedByDot = true; - if (loadIsMMAv3(&op)) { - loadInfo.loadIsMMAV3 = true; + loadInfo.isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; + loadInfo.isMMAv3Registers = + (mmaLoadType == MMALoadType::Registers) && warpGroupDot; + + if (loadInfo.isMMAv3Shared) { loadInfo.sharedEncoding = getSharedEncoding(&op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (isa(op)) { loadInfo.sharedEncoding = getSharedEncoding(&op, /*loadIsMMAv3=*/true).value_or(nullptr); - } else if (auto dot = dyn_cast(use)) { + } else if (loadInfo.isMMAv3Registers || dot) { bool incompatible = false; loadInfo.sharedEncoding = getSharedEncIfAllUsersAreDotEnc(op.getResult(0), incompatible) @@ -492,7 +501,7 @@ assignMemoryLayouts(scf::ForOp &forOp, if (!loadInfo.sharedEncoding && !isa(use)) { LDBG("try generic shared encoding"); loadInfo.sharedEncoding = - getSharedEncoding(&op, /*isMMAV3=*/loadInfo.loadIsMMAV3) + getSharedEncoding(&op, /*isMMAV3=*/loadInfo.isMMAv3Shared) .value_or(nullptr); if (auto loadOp = dyn_cast(op)) loadInfo.blockedEncoding = @@ -593,7 +602,7 @@ static void createTMABarrierAndWait( if (it != loadToInfo.end()) { // Special case for MMAv3 loads, we can ignore the alloc and only // consider uses of the alloc op since it will be removed. - if (it->second.loadIsMMAV3) { + if (it->second.isMMAv3Shared) { auto alloc = cast( (*loadInfo->loadOp->getUsers().begin())); if (alloc->getBlock() == loadBlock) { @@ -713,8 +722,9 @@ createAsyncOps(scf::ForOp &forOp, auto &rhs) { return lhs.distToUse < rhs.distToUse; })->distToUse; - bool hasMMAV3 = - llvm::any_of(loadToInfo, [](auto &kv) { return kv.second.loadIsMMAV3; }); + bool hasMMAV3 = llvm::any_of(loadToInfo, [](auto &kv) { + return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers; + }); if (hasMMAV3) { // For MMAv3, we need an extra buffer as this is assumed in the wgmma // pipelining post-processing. @@ -1123,6 +1133,15 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, // // 1. All operands that touch shared memory are multi-buffered, i.e. can't read // an incomplete value while it's being written asynchronously by a load. +// 1a. 
If operand A is in registers, these registers cannot be updated +// inside +// the loop. +// **Exception** if the operand is produced by a preceding WGMMA, +// then this op can be properly async. Either the f16 shortcut is +// possible and the WGMMA's can run back-to-back (see rule 3 below), or +// elementwise truncate is needed, in which case the preceding WGMMA is +// not async and a WarpGroupDotWait is inserted right after, which +// guarantees exclusive access to the operand registers. // // 2. If the dot is used by any op in the loop, it must be used under an `if`, // and will be synced with a `wait 0` at the beginning of the `if` block. @@ -1158,7 +1177,15 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, auto checkOperand = [&](Value operand) { if (!isa( cast(operand.getType()).getEncoding())) { - return true; + // Rule 1a: Register operands must not be modified within the loop. + // First, check for chained WGMMA as an exception. + if (auto cvt = dyn_cast(operand.getDefiningOp())) { + return isa( + cvt.getSrc().getType().getEncoding()); + } + // And then, do a stricter-than-necessary check for now, that the operand + // is defined outside the loop. + return forOp.isDefinedOutsideOfLoop(operand); } // If it's a shmem operand, it must either be defined outside the loop, or diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 820db63022d7..bb60c1821ad7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -976,26 +976,43 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { return attr; } -bool loadIsMMAv3(Operation *loadOp) { +MMALoadType getMMALoadType(Operation *loadOp) { if (!loadOp->hasOneUse()) - return false; - auto alloc = dyn_cast(*loadOp->getUsers().begin()); - if (!alloc) - return false; - auto sharedEnc = cast(alloc.getType().getEncoding()); - if (!sharedEnc.getHasLeadingOffset()) - return false; + return MMALoadType::DoNotPipeline; + + if (auto alloc = dyn_cast(*loadOp->getUsers().begin())) { + auto sharedEnc = + cast(alloc.getType().getEncoding()); + + if (!sharedEnc.getHasLeadingOffset()) + return MMALoadType::DoNotPipeline; + + // MMA V3 case. + auto newOrder = sharedEnc.getOrder(); + auto ty = cast(loadOp->getResultTypes()[0]); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + + // The operand of MMAv3 is in SharedEncoding and its order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + return oldOrder == newOrder ? MMALoadType::SharedV3 + : MMALoadType::DoNotPipeline; + } else if (auto cvt = + dyn_cast(*loadOp->getUsers().begin())) { + auto resTy = dyn_cast(cvt->getResultTypes()[0]); + if (!resTy) { + return MMALoadType::DoNotPipeline; + } - // MMA V3 case. - auto newOrder = sharedEnc.getOrder(); - auto ty = cast(loadOp->getResultTypes()[0]); - auto oldOrder = ttg::getOrder(ty.getEncoding()); + if (isa(resTy.getEncoding())) { + return MMALoadType::Registers; + } - // The operand of MMAv3 is in SharedEncoding and its order should not - // be changed after FuseTranspositions Pass. So we only pipeline the - // load if the order of the loaded BlockedEncoding is the same as the - // order of the SharedEncoding it is converted to. 
- return oldOrder == newOrder; + return MMALoadType::DoNotPipeline; + } else { + return MMALoadType::DoNotPipeline; + } } namespace { diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 272ff25d7e4f..2d562b958776 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -208,3 +208,51 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : tt.return %td : tensor<128x128xf32, #mma> } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = triton_gpu.convert_layout %[[A_BLOCK]] : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_DOTOP]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !tt.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a = tt.fp_to_fp %a_bf16 : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked> + %dota = triton_gpu.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared1> + %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise_chained +// CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = triton_gpu.convert_layout %[[A_BLOCK]] : tensor<128x64xi8, #blocked> -> 
tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !tt.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %a_i8 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a_f16 = arith.sitofp %a_i8 : tensor<128x64xi8, #blocked> to tensor<128x64xf16, #blocked> + %a_scaled = arith.mulf %a_f16, %cst : tensor<128x64xf16, #blocked> + %a_negated = arith.negf %a_scaled : tensor<128x64xf16, #blocked> + %dota = triton_gpu.local_alloc %a_negated: (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared1> + %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 1f0ecaee439c..0b30ccb4191b 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -934,3 +934,66 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %17#0 : tensor<128x16xf32, #mma1> } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: dot_lhs_registers + tt.func @dot_lhs_registers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_3 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %c1_i32 = 
arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: scf.for + // CHECK: triton_gpu.local_load + // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_commit_group + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_commit_group + // CHECK: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %8, %arg6 = %16) -> (tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, + tensor<64x16x!tt.ptr, #blocked>) : i32 { + %a_block = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked1> + %b_block = tt.load %arg6 : tensor<64x16x!tt.ptr, #blocked> + %a_dotop = triton_gpu.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %b_smem = triton_gpu.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %21 = triton_nvidia_gpu.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma> + %25 = tt.addptr %arg5, %cst_3 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %26 = tt.addptr %arg6, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %21, %25, %26 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x16x!tt.ptr, #blocked> + } + tt.return %17#0 : tensor<128x16xf32, #mma> + } +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp 
b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 6f2b99dfa13c..945ac1cce02b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -402,8 +402,9 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, int canonWidth = (8 * elemBytes * inc) / canonBits; Type canonInt = int_ty(canonBits); std::array retElems; - // don't pack to 32b for Hopper - int vecSize = isHopper ? 1 : 32 / canonBits; + // Hopper may not contain 32b contiguously along k-dimension + int kBits = isHopper ? (8 * elemBytes * kWidth) : 32; + int vecSize = kBits / canonBits; retElems.fill(undef(vec_ty(canonInt, vecSize))); for (int r = 0; r < 2; ++r) { for (int em = 0; em < 2 * vecWidth; em += inc) { @@ -424,7 +425,7 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, if (isActualTrans) std::swap(retElems[1], retElems[2]); - auto iTy = isHopper ? int_ty(8 * elemBytes * inc) : i32_ty; + auto iTy = isHopper ? int_ty(kBits) : i32_ty; return {bitcast(retElems[0], iTy), bitcast(retElems[1], iTy), bitcast(retElems[2], iTy), bitcast(retElems[3], iTy)}; @@ -529,9 +530,8 @@ Value composeValuesToDotOperandLayoutStruct( bool isA) { auto bitwidth = eltTy.getIntOrFloatBitWidth(); assert(32 >= bitwidth && "only support 32-bit or less"); - auto numElemsPerVec = 32 / bitwidth; + auto numElemsPerVec = isHopper ? kWidth : 32 / bitwidth; auto vecTy = vec_ty(eltTy, numElemsPerVec); - // FIXME: Fix the hopper path // FIXME: [DOT LL] // `kWidth` specifies the number of contiguous elements each thread will load. // Loaded elements are packed into a vector of int32, which will then be diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 2b9b4f159bf4..9b1667db7083 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -264,6 +264,28 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, // Return a vector of Value of the accumulator start at startIndex and pack the // values into 32bits in case the accumulator is fp16. +// +// `elements` contains all loaded register values for operand A. +// This consists of operand A for possibly multiple wgmma instructions. +// For each wgmma, each warp in a warp group feeds a single "warp matrix" +// Each warp matrix consists of 2x2 "quads". +// Each thread holds several elements in each quad. Right before a wgmma, +// the sum of bitwidth of +// the elements in each quad should add up to 32. +// +// These values are stored unrolled in `elements`. +// The ordering of dimensions is as follows: +// batch (only 1 batch for Hopper currently) +// matM (m-index of the "warp matrix") +// matK (k-index of the "warp matrix") +// quadK (k-index of the "quad" in the core matrix) +// quadM (m-index of the "quad" in the core matrix) +// vecIdx (index of the element in the quad; this is always along the k-dim) +// +// This ordering is decided when a tensor in DotOpEnc is lowered into llvm. +// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand. +// Thus, both lowerings must obey this above ordering for the below code to be +// correct. 
llvm::SmallVector<Value> loadReg(ConversionPatternRewriter &rewriter, Location loc, const SmallVector<Value> &elements,
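To make the unrolled ordering described above concrete: assuming the dimensions are unrolled outermost-to-innermost exactly as listed (batch, matM, matK, quadK, quadM, vecIdx), with 2x2 quads per warp matrix, the flat position of an element in `elements` would be

    idx = ((((batch * numMatM + matM) * numMatK + matK) * 2 + quadK) * 2 + quadM) * vecSize + vecIdx

where numMatM, numMatK, and vecSize are hypothetical names for the per-dimension extents, not identifiers from this patch.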