Skip to content

Commit

Permalink
[BACKEND][NVIDIA] Add Lowering for Shared-to-MMAv3-DotOp Copy (triton…
Browse files Browse the repository at this point in the history
…-lang#5009)

Allows for upcasting in DotOp encoding in RF.
This lowering path is not currently in use; pending
triton-lang#5003
  • Loading branch information
ggengnv authored Oct 30, 2024
1 parent ebce7f3 commit cfddb09
Show file tree
Hide file tree
Showing 14 changed files with 225 additions and 86 deletions.
44 changes: 29 additions & 15 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false.
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
// ---- begin Ampere & Hopper ----
if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
Expand Down Expand Up @@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false.
llvm_unreachable("invalid operand index");
}

// ---- begin version 3 ----
if (mmaEnc.isHopper()) {
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
" is Hopper has not been implemented yet");
return $_get(context, 1, 1, 1, order, CTALayout, true);
}

// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>,
Expand Down Expand Up @@ -1224,7 +1217,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
SmallVector<int> getMMAv1Rep(int opIdx) const;
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
int getMMAv1Vec(int opIdx) const;
SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
SmallVector<int64_t> getMMAv2OrV3RepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int kWidth, int opIdx) const;

bool supportReduction() const {
Expand Down Expand Up @@ -1319,6 +1312,27 @@ The parent field is the layout of d.
kWidth defines number of consecutive elements stored by one thread along k dimension.
Some layouts do not use this parameter, either because they have a fixed number of
elements along the K dim, or they use all elements of the tensor along the K dim.

# WGMMA Notes
We require kWidth to be provided for Hopper because the dtype at loading might be
different from the dtype at WGMMA, due to casting. The kWidth is determined by the
dtype at WGMMA.

The encoded tensor consists of operand A for possibly multiple wgmma instructions.
For each wgmma, each warp in a warp group feeds a single "warp matrix"
Each warp matrix consists of 2x2 "quads".
Each thread holds several elements in each quad. Right before a wgmma,
the sum of bitwidth of
the elements in each quad should add up to 32.

These values are stored unrolled in `elements`.
The ordering of dimensions is as follows by convention:
batch (only 1 batch for Hopper currently)
matM (m-index of the "warp matrix")
matK (k-index of the "warp matrix")
quadK (k-index of the "quad" in the core matrix)
quadM (m-index of the "quad" in the core matrix)
vecIdx (index of the element in the quad; this is always along the k-dim)
}];

let parameters = (
Expand All @@ -1329,16 +1343,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
);

let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"Type":$eltTy), [{
NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper()))
return $_get(context, opIdx, parent, 0); // For MMAV1
// For MMAV2 and V3
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
unsigned kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, kWidth);
}]>
];

Expand Down
21 changes: 20 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@ using namespace mlir::triton::gpu;

namespace mlir::triton::gpu {

namespace {

bool isDotOpTensorAndPacked(Type srcTy) {
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
if (!tensorTy)
return false;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!encoding)
return false;
auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
// By code convention, values for Hopper's dotOp-encoded tensors are not
// packed
if (!parentEnc || parentEnc.isHopper())
return false;
return true;
}

} // namespace

Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = dyn_cast<RankedTensorType>(type))
Expand All @@ -33,7 +52,7 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
// If the parent of the dot operand is in block encoding, we don't need to
// reorder elements
auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
if (!parentEncoding)
if (!parentEncoding || parentEncoding.isHopper())
return values;
size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
Expand Down
33 changes: 24 additions & 9 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1074,13 +1074,18 @@ LogicalResult DotOperandEncodingAttr::verify(
return emitError() << "triton_gpu.dot_op parent paramenter cannot be null";
}
if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (kWidth != 0 && !parentAttr.isAmpere())
if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "triton_gpu.dot_op kWidth parameter can only be "
"non-zero for Ampere MMA parent";
if (kWidth == 0 && parentAttr.isAmpere())
"non-zero for Ampere or Hopper MMA parent";
if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError()
<< "triton_gpu.dot_op kWidth parameter is mandatory for "
"Ampere MMA parent";
"Ampere or Hopper MMA parent";
if (opIdx != 0 && parentAttr.isHopper())
return emitError()
<< "triton_gpu.dot_op opIdx parameter must be 0 for "
"Hopper MMA parent, since Hopper WGMMA only allows first "
"operand to be in registers";
return success();
}

Expand Down Expand Up @@ -2013,17 +2018,20 @@ SmallVector<int> NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const {
int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
return 2 * getMMAv1Rep(opIdx)[opIdx];
}
SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2OrV3RepForOperand(
ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
assert(isAmpere() || (isHopper() && opIdx == 0));
auto rank = shape.size();
auto warpsPerCTA = getWarpsPerCTA();

// {batch, m, n, k}
// Hopper path never uses the n value, since this method is only invoked
// for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
int numRepBatch =
rank == 3
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
: 1;
assert(isAmpere());

if (opIdx == 0)
return {numRepBatch,
Expand All @@ -2038,19 +2046,26 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
warpsPerCTA[rank - 1]))};
}
}

unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
auto shapePerCTA = getShapePerCTA(*this, shape);
int warpsPerCTAM = getWarpsPerCTA()[0];
int warpsPerCTAN = getWarpsPerCTA()[1];
// H100
if (isHopper()) {
return getTotalElemsPerThread(shape, eltTy);
assert(opIdx == 0);
auto instrMNK = getInstrShape();
int repM = ceil<unsigned>(shapePerCTA[0], instrMNK[0] * warpsPerCTAM);
int repK = ceil<unsigned>(shapePerCTA[1], instrMNK[2]);
// For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds
// kWidth elements for each quadrant. WGMMA is repeated repM * repK times.
return 4 * kWidth * repM * repK;
}
// A100
if (isAmpere()) {
auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(),
kWidth, opIdx);
auto rep = getMMAv2OrV3RepForOperand(
shapePerCTA, eltTy.getIntOrFloatBitWidth(), kWidth, opIdx);
if (opIdx == 0)
return 4 * rep[0] * rep[1] * rep[2];
if (opIdx == 1)
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,9 @@ struct MMAV3UseRegOperand
dstEnc.getVersionMajor() != 3)
return failure();
auto srcTy = cast<RankedTensorType>(alloc.getSrc().getType());
auto kWidth = 32 / srcTy.getElementTypeBitWidth();
auto dotOperandEnc = DotOperandEncodingAttr::get(
dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0);
dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/kWidth);
auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(),
dotOperandEnc);
if (!matchMmaV3AndDotOperandLayout(srcTy, newTy))
Expand Down
24 changes: 19 additions & 5 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
%opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }:
tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
tt.return
}
}
Expand All @@ -114,10 +114,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// Generate a wgmma where the first operand is a struct.
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32}
tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) {
tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1>
%m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } :
tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1>
tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1>
tt.return
}
}
//
// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @dot_reg_operand_upcast(%a_desc: !tt.memdesc<128x64xi8, #shared>, %b: !tt.memdesc<64x64xf16, #shared>, %acc: tensor<128x64xf32, #mma>) {
%a_dotop = triton_gpu.local_load %a_desc : !tt.memdesc<128x64xi8, #shared> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%res = triton_nvidia_gpu.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
tt.return
}
}
Expand Down Expand Up @@ -193,7 +207,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: prmt.b32
// CHECK: prmt.b32
tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) {
%opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
%opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}
}
Expand Down
8 changes: 4 additions & 4 deletions test/TritonGPU/dot-operands.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ tt.func @update_kwidth_slice(
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK: tt.func @mma_v3_reg_operand_A
// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
%A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1>
%r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
Expand All @@ -180,8 +180,8 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK: tt.func @mma_v3_reg_operand_A_fp8
// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
%A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1>
%r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
Expand Down
14 changes: 10 additions & 4 deletions test/TritonGPU/invalid-attributes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// expected-error@+2 {{triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: 2}}
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked}>
#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}>

// -----

Expand All @@ -12,19 +12,25 @@

// -----

// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}}
// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}}
#mma = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>

// -----

// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere MMA parent}}
// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}}
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>

// -----

// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}}
// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}}
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>

// -----

// expected-error@+2 {{triton_gpu.dot_op opIdx parameter must be 0 for Hopper MMA parent, since Hopper WGMMA only allows first operand to be in registers}}
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>

Expand Down
Loading

0 comments on commit cfddb09

Please sign in to comment.