Commit 9265991: Run pre-commit hook

ggengnv committed Oct 29, 2024
1 parent: 3009866
Showing 5 changed files with 37 additions and 31 deletions.
lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp (2 additions, 1 deletion)

@@ -21,7 +21,8 @@ bool isDotOpTensorAndPacked(Type srcTy) {
   if (!encoding)
     return false;
   auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
-  // By code convention, values for Hopper's dotOp-encoded tensors are not packed
+  // By code convention, values for Hopper's dotOp-encoded tensors are not
+  // packed
   if (!parentEnc || parentEnc.isHopper())
     return false;
   return true;
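For context on what "packed" means in this predicate: MMAv2-era dot operands keep several narrow elements in one 32-bit register, while Hopper's dotOp values stay unpacked by convention. A self-contained host-side analogy (illustrative only; the real code manipulates LLVM IR values, not host integers):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Four 8-bit elements packed into one 32-bit word, mirroring how
      // MMAv2 dot operands share an i32 register (illustrative analogy).
      uint8_t e[4] = {0xAA, 0xBB, 0xCC, 0xDD};
      uint32_t packed = 0;
      for (int i = 0; i < 4; ++i)
        packed |= uint32_t(e[i]) << (8 * i);
      std::printf("packed = 0x%08X\n", packed); // prints packed = 0xDDCCBBAA
      // By the convention noted in the hunk above, Hopper (MMAv3) dotOp
      // values are kept unpacked, so isDotOpTensorAndPacked() returns false
      // for them and no unpacking step is emitted.
      return 0;
    }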
lib/Dialect/TritonGPU/IR/Dialect.cpp (2 additions, 2 deletions)

@@ -2064,8 +2064,8 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
   }
   // A100
   if (isAmpere()) {
-    auto rep = getMMAv2OrV3RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(),
-                                         kWidth, opIdx);
+    auto rep = getMMAv2OrV3RepForOperand(
+        shapePerCTA, eltTy.getIntOrFloatBitWidth(), kWidth, opIdx);
     if (opIdx == 0)
       return 4 * rep[0] * rep[1] * rep[2];
     if (opIdx == 1)
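As a worked example of the Ampere branch above (rep values illustrative, not from this diff): if getMMAv2OrV3RepForOperand returned rep = {1, 2, 4} for operand A, the method would report 4 * 1 * 2 * 4 = 32 elements per thread.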
@@ -444,21 +444,21 @@ MMA16816SmemLoader::MMA16816SmemLoader(
     ArrayRef<int> instrShape, ArrayRef<int> matShape,
     SmallVector<Value> multiDimWarpId, int perPhase, int maxPhase,
     int elemBytes, int mmaElemBytes, bool isHopper,
-    ConversionPatternRewriter &rewriter,
-    const LLVMTypeConverter *typeConverter, const Location &loc)
+    ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter,
+    const Location &loc)
     : nPerWarp(nPerWarp), order(order.begin(), order.end()),
       warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
       kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
       instrShape(instrShape.begin(), instrShape.end()),
       matShape(matShape.begin(), matShape.end()),
       multiDimWarpId(multiDimWarpId.begin(), multiDimWarpId.end()),
       perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes),
-      mmaElemBytes(mmaElemBytes), isHopper(isHopper),
-      rewriter(rewriter), loc(loc), ctx(rewriter.getContext()) {
-  // If the current elemType width is different from the MMA elemType width, i.e.
-  // width-changing casting is done later in DotOp Layout... then, in the case of
-  // Hopper, the number of bytes held by each thread after loading will no longer
-  // be 32B. Hence this flag is required to stipulate different logic.
+      mmaElemBytes(mmaElemBytes), isHopper(isHopper), rewriter(rewriter),
+      loc(loc), ctx(rewriter.getContext()) {
+  // If the current elemType width is different from the MMA elemType width,
+  // i.e. width-changing casting is done later in DotOp Layout... then, in the
+  // case of Hopper, the number of bytes held by each thread after loading will
+  // no longer be 32B. Hence this flag is required to stipulate different logic.
   bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes);

   contiguousMatShape = matShape[order[0]];
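To make the isHopperWidthChange flag concrete, a minimal sketch with assumed values (fp8 data upcast to fp16 before the WGMMA; the numbers are illustrative, not taken from this diff):

    // Illustrative: tensor stored as 8-bit (e.g. fp8) in shared memory,
    // but the WGMMA consumes 16-bit elements after a later upcast.
    int elemBytes = 1;    // width of the stored element, in bytes
    int mmaElemBytes = 2; // width the MMA instruction expects, in bytes
    bool isHopper = true;
    bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); // true
    // Each thread then holds fewer than the usual 32B right after the load
    // (presumably 16B here), so the loader takes a different addressing path.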
@@ -536,7 +536,8 @@ std::vector<Value> unpackInt(const std::vector<Value> &inValues, Type elTy,
   for (auto v : inValues) {
     // cast i32 to appropriate eltType vector and extract elements
     auto eltType = typeConverter->convertType(elTy);
-    auto vecType = vec_ty(eltType, inBitWidth / eltType.getIntOrFloatBitWidth());
+    auto vecType =
+        vec_ty(eltType, inBitWidth / eltType.getIntOrFloatBitWidth());
     auto vec = bitcast(v, vecType);
     for (int i = 0; i < inBitWidth / eltType.getIntOrFloatBitWidth(); i++) {
       outValues.push_back(extract_element(vec, i32_val(i)));
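Tracing unpackInt on concrete types (illustrative): with inBitWidth = 32 and elTy = f16, vecType is a 2 x f16 vector, so each incoming i32 is bitcast once and yields two extract_element values; for i8 elements the same i32 would become a 4 x i8 vector and four extracts.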
@@ -597,12 +598,12 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
       std::max<int>(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8);
   // (a, b) is the coordinate.
   auto load = [=, &rewriter, &vals](int batch, int a, int b) {
-    MMA16816SmemLoader loader(
-        nPerWarp, warpsPerTile, sharedLayout.getOrder(),
-        mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides,
-        shapePerCTA /*tileShape*/, instrShape, matShape, multiDimWarpId,
-        perPhase, maxPhase, elemBytes, mmaElemBytes,
-        isHopper, rewriter, typeConverter, loc);
+    MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(),
+                              mmaLayout.getWarpsPerCTA(), kOrder, kWidth,
+                              smemObj.strides, shapePerCTA /*tileShape*/,
+                              instrShape, matShape, multiDimWarpId, perPhase,
+                              maxPhase, elemBytes, mmaElemBytes, isHopper,
+                              rewriter, typeConverter, loc);
     // Offset of a slice within the original tensor in shared memory
     Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
     SmallVector<Value> offs = loader.computeOffsets(lane, cSwizzleOffset);
@@ -647,18 +648,18 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
   bool isHopper = mmaLayout.getVersionMajor() == 3;
   auto shapePerCTA = getShapePerCTA(descTy);
   int bitwidth = descTy.getElementTypeBitWidth();
-  // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should add
-  // up to 32. We use kWidth to compute the element bitwidth of the input to WGMMA,
-  // which could be different from `bitwidth` due to later casting.
+  // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should
+  // add up to 32. We use kWidth to compute the element bitwidth of the input to
+  // WGMMA, which could be different from `bitwidth` due to later casting.
   int mmaBitwidth = isHopper ? (32 / encoding.getKWidth()) : bitwidth;

   ValueTable vals;
   int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth;
   int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth;

   int kWidth = encoding.getKWidth();
-  auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(shapePerCTA, bitwidth, kWidth,
-                                                    encoding.getOpIdx());
+  auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(
+      shapePerCTA, bitwidth, kWidth, encoding.getOpIdx());

   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
   auto order = triton::gpu::getOrder(mmaLayout);
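Plugging numbers into the Hopper branch above (illustrative): fp8 data destined for an fp16 WGMMA would carry kWidth = 2, so mmaBitwidth = 32 / 2 = 16, giving mmaInstrK = 4 * 64 / 16 = 16 and matShapeK = 2 * 64 / 16 = 8, even though the stored bitwidth is 8. Without the later cast, fp8 inputs would have kWidth = 4, hence mmaBitwidth = 8 and mmaInstrK = 32.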
@@ -393,13 +393,15 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,

   int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
   auto dotOpA = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
-  auto repA = cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
-                  .getMMAv2OrV3RepForOperand(aShapePerCTA, bitwidth,
-                                             dotOpA.getKWidth(), dotOpA.getOpIdx());
+  auto repA =
+      cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
+          .getMMAv2OrV3RepForOperand(aShapePerCTA, bitwidth, dotOpA.getKWidth(),
+                                     dotOpA.getOpIdx());
   auto dotOpB = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
-  auto repB = cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
-                  .getMMAv2OrV3RepForOperand(bShapePerCTA, bitwidth,
-                                             dotOpB.getKWidth(), dotOpB.getOpIdx());
+  auto repB =
+      cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
+          .getMMAv2OrV3RepForOperand(bShapePerCTA, bitwidth, dotOpB.getKWidth(),
+                                     dotOpB.getOpIdx());

   assert(repA[2] == repB[1]);
   assert(repA[0] == repB[0]);
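For orientation (inferred from the asserts above): the rep vectors appear to be ordered {batch, M, K} for operand A and {batch, K, N} for operand B, so repA[2] == repB[1] checks that both operands tile the K dimension the same number of times, and repA[0] == repB[0] does the same for batch.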
@@ -442,8 +442,10 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
       if (aSharedLayout) {
         a = aLoader.smemLoad(m, k, rewriter, loc);
       } else {
-        auto aDotOpEnc = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
-        assert(aDotOpEnc.getKWidth() == 32 / aTensorTy.getElementTypeBitWidth());
+        auto aDotOpEnc =
+            cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
+        assert(aDotOpEnc.getKWidth() ==
+               32 / aTensorTy.getElementTypeBitWidth());

         unsigned regASize = (instrShape[0] * instrShape[2]) / 32;
         llvm::SmallVector<Value> regA =
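A worked instance of the register path (shapes illustrative, not from this diff): a 16-bit A operand needs kWidth == 32 / 16 == 2, i.e. two elements per packed 32-bit register; and if instrShape were {64, N, 16} for an m64k16 WGMMA tile, regASize would be (64 * 16) / 32 = 32 packed registers.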
