[MLIR][NVGPU] Adding nvgpu.warpgroup.mma Op for Hopper GPUs #65440

Merged 8 commits on Sep 22, 2023
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1610,9 +1610,9 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
PredOpTrait<"input struct and result struct must be the same type",
TCresIsSameAsOpBase<0, 0>>,]>
{
-let results = (outs LLVM_AnyAggregate:$results);
+let results = (outs LLVM_AnyStruct:$results);
let arguments = (ins
-LLVM_AnyAggregate:$inouts,
+LLVM_AnyStruct:$inouts,
I64:$descriptorA,
I64:$descriptorB,
NVVM_MMAShapeAttr:$shape,
56 changes: 56 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,6 +192,19 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w
let assemblyFormat = "`<` struct(params) `>`";
}

def NVGPU_WarpgroupAccumulator : NVGPU_Type<"WarpgroupAccumulator", "warpgroup.accumulator", []> {
let parameters = (ins "VectorType":$fragmented);
let assemblyFormat = "`<` struct(params) `>`";
let description = [{
This type represents the result matrix obtained from `nvgpu.warpgroup.mma`.
The `$fragmented` type signifies the distributed or fragmented result
vector that is collectively owned by all the threads in the warp-group
that executed `nvgpu.warpgroup.mma`.
[See the details of register fragment layout for accumulator matrix D]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
}];
}
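For illustration, here is a hedged sketch of the type in IR; the `@consume` function is hypothetical and not part of this patch:

```mlir
// A 64x128 f32 accumulator fragment: each of the warpgroup's 128 threads
// holds a piece of this value after `nvgpu.warpgroup.mma` executes.
func.func @consume(%acc: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>) {
  return
}
```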

//===----------------------------------------------------------------------===//
// NVGPU Op Definitions
//===----------------------------------------------------------------------===//
@@ -664,5 +677,48 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
let hasVerifier = 1;
}

def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
let description = [{
The `nvgpu.warpgroup.mma` op performs the warpgroup-level (4 warps)
matrix-multiply-and-accumulate (mma) operation that lowers to
`nvvm.wgmma.mma_async`.

The operands are `descriptorA` and `descriptorB`, wgmma matrix descriptors
that describe the properties of the matrices in shared memory. The results
are per-thread fragments of the accumulator matrix, collectively owned by
the threads of the warpgroup. The wgmma shape is deduced from the
descriptor types and the output vector.

The Op may map to multiple `nvvm.wgmma.mma_async` operations to complete
the given shape. Since `nvvm.wgmma.mma_async` is asynchronous, this Op
groups the generated operations and surrounds them with
`wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and
`wgmma.wait.group.sync.aligned` Ops.

Example:
```mlir
%r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2:
!nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
->
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
```
}];

let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
OptionalAttr<UnitAttr>:$transposeB,
Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
let assemblyFormat = [{
$descriptorA`,` $descriptorB`,` $matrixC attr-dict
`:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
}];
let hasVerifier = 1;
}
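For the example above, a hedged sketch of the NVVM-level sequence the lowering produces (operands and attributes of `nvvm.wgmma.mma_async` are elided into comments; the SSA names and descriptor increments are illustrative):

```mlir
nvvm.wgmma.fence.aligned
// Each 64-row slice of D chains sizeK / wgmmaShapeK = 4 m64n128k16 ops,
// feeding the previous accumulator struct into the next op, e.g.:
//   %d1 = nvvm.wgmma.mma_async %descA, %descB, %c0, ...
//   %d2 = nvvm.wgmma.mma_async (%descA + 2), (%descB + 128), %d1, ...
nvvm.wgmma.commit.group.sync.aligned
nvvm.wgmma.wait.group.sync.aligned 1
```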

#endif // NVGPU
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -21,6 +21,8 @@

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"

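/// Number of threads in a CUDA warp; a warpgroup (as used by the
/// warpgroup.mma ops) is 4 contiguous warps, i.e. 128 threads.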
constexpr int kWarpSize = 32;

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"

165 changes: 162 additions & 3 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,10 +17,12 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"
@@ -34,6 +36,10 @@ namespace mlir {

using namespace mlir;

/// Number of bits that need to be excluded when building the matrix
/// descriptor for wgmma operations: byte offsets and strides are encoded
/// with their 4 least-significant bits dropped (i.e., divided by 16).
constexpr int exclude4LSB = 4;

/// GPUs have 32-bit registers; this function truncates values when a larger
/// width is not needed.
static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
@@ -419,6 +425,15 @@ struct ConvertNVGPUToNVVMPass
converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
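// A fragmented accumulator lowers to an LLVM struct with one element-type
// member per row of the fragment; e.g. vector<64x128xf32> becomes a
// struct of 64 f32 members.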
converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
VectorType vtype = type.getFragmented();
SmallVector<Type> structBody;
for (unsigned i = 0; i < vtype.getDimSize(0); i++)
structBody.push_back(vtype.getElementType());
auto convertedType =
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
return converter.convertType(convertedType);
});
converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
return converter.convertType(IntegerType::get(type.getContext(), 64));
});
@@ -438,6 +453,8 @@ struct ConvertNVGPUToNVVMPass
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::memref::MemRefDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
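// The SCF structural patterns let values whose types are converted here
// (e.g. warpgroup accumulators) flow through scf.for/scf.if regions.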
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
converter, patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -984,10 +1001,9 @@ struct NVGPUGenerateGmmaDescriptorLowering
shiftLeft(val, startBit));
};

-int ex4LSB = 4;
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
-uint64_t strideDimVal = (layout << 3) >> ex4LSB;
-uint64_t leadDimVal = (sizeN * layout) >> ex4LSB;
+uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
+uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
uint64_t offsetVal = 0;

Value strideDim = makeConst(strideDimVal);
@@ -1141,6 +1157,148 @@ struct NVGPUTmaCreateDescriptorOpLowering
}
};

struct NVGPUWarpgroupMmaOpLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

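/// Deduces the wgmma instruction shape from the operand shapes and input
/// element type. M is fixed at 64, N is taken whole from the B operand,
/// and K scales inversely with the element width: k8 for tf32, k16 for
/// f16/bf16, k32 for 8-bit types, k256 for 1-bit types (256 bits per row).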
LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
int &wgmmaShapeM, int &wgmmaShapeN,
int &wgmmaShapeK) const {
wgmmaShapeM = 64;
wgmmaShapeN = sizeN;
if (inputElemType.isTF32()) {
wgmmaShapeK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaShapeK = 16;
} else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
inputElemType.isInteger(8)) {
wgmmaShapeK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaShapeK = 256;
} else {
llvm_unreachable("msg: not supported K shape");
}
LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
<< ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
<< "]\n");
return success();
}

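/// Builds a single `nvvm.wgmma.mma_async` op for one wgmma-shaped tile,
/// with row-major A, column-major B, scale factors of one, and (per the
/// todo below) f16 input types hard-coded.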
Value generateNVVMWgmmaOp(MLIRContext *ctx,
ConversionPatternRewriter &rewriter, Location loc,
int m, int n, int k, Type resultStructType,
Value inout, Value descriptorA,
Value descriptorB) const {
auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
// todo: handle other input and output types
auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
auto overflow =
NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
return res;
}

LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);

LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
<< sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
<< sizeN << "] ---===\n");

int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
wgmmaShapeN, wgmmaShapeK))) {
return failure();
}

Value descriptorA = adaptor.getDescriptorA();
Value descriptorB = adaptor.getDescriptorB();

// Generate wgmma group

auto loc = op->getLoc();
MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();

auto makeAdd = [&](Value lhs, Value rhs) -> Value {
return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
};

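// Advances the wgmma descriptor of A to the (iterM, iterK) tile: the byte
// offset within the row-major tile is re-encoded by dropping its 4
// least-significant bits, as the descriptor format requires.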
auto iterateDescA = [&](Value desc, int iterM, int iterN,
int iterK) -> Value {
// todo : Handle column major
int byte = typeTensorA.getElementTypeBitWidth() / 8;
int tileShapeA = typeTensorA.getDimSize(1);
int incrementVal =
((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
<< iterK << "] [wgmma descriptors] Descriptor A + "
<< incrementVal << " | \t ");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
};

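// Advances the wgmma descriptor of B along K. As a worked example, for
// memref<64x128xf16> and wgmmaShapeK = 16, iterK = 1 advances by
// 64 * 16 * 2 = 2048 bytes, encoded as 2048 >> 4 = 128.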
auto iterateDescB = [&](Value desc, int iterM, int iterN,
int iterK) -> Value {
// todo : Handle row major
int byte = typeTensorB.getElementTypeBitWidth() / 8;
int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
};

rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);

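// Tile the GEMM into wgmma-sized pieces: one accumulator struct per 64-row
// slice of D, each built up by sizeK / wgmmaShapeK chained wgmma ops that
// consume the previous result as input.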
SmallVector<Value> wgmmaResults;
for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
Value matrixC = adaptor.getMatrixC()[iterM];
Value matrixD = op.getMatrixD()[iterM];
Type structType = getTypeConverter()->convertType(matrixD.getType());
LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
<< (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
<< ":" << wgmmaShapeN << "] += \n");
for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
<< wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
<< ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
<< (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
<< " B[" << (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
<< ":" << wgmmaShapeN << "])\n");
matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
structType, matrixC, descA, descB);
}
wgmmaResults.push_back(matrixC);
}
rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());

ValueRange myres(wgmmaResults);
rewriter.replaceOp(op, myres);
return success();
}
};

} // namespace

@@ -1156,6 +1314,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);