From 926599136b916420406d2f2b37dadedd6db566ed Mon Sep 17 00:00:00 2001
From: Gary Geng
Date: Tue, 29 Oct 2024 21:56:55 +0000
Subject: [PATCH] Run pre-commit hook

---
 .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp   |  3 +-
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |  4 +-
 .../SharedToDotOperandMMAv2OrV3.cpp           | 41 ++++++++++---------
 .../DotOpToLLVM/MMAv2.cpp                     | 14 ++++---
 .../DotOpToLLVM/WGMMA.cpp                     |  6 ++-
 5 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index cc7ab7496340..1b7088870ca4 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -21,7 +21,8 @@ bool isDotOpTensorAndPacked(Type srcTy) {
   if (!encoding)
     return false;
   auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
-  // By code convention, values for Hopper's dotOp-encoded tensors are not packed
+  // By code convention, values for Hopper's dotOp-encoded tensors are not
+  // packed
   if (!parentEnc || parentEnc.isHopper())
     return false;
   return true;
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 709e7c6b49fd..3b5316ecc0e3 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2064,8 +2064,8 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
   }
   // A100
   if (isAmpere()) {
-    auto rep = getMMAv2OrV3RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(),
-                                         kWidth, opIdx);
+    auto rep = getMMAv2OrV3RepForOperand(
+        shapePerCTA, eltTy.getIntOrFloatBitWidth(), kWidth, opIdx);
     if (opIdx == 0)
       return 4 * rep[0] * rep[1] * rep[2];
     if (opIdx == 1)
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
index ed73882711a2..6094a911189d 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp
@@ -444,8 +444,8 @@ MMA16816SmemLoader::MMA16816SmemLoader(
     ArrayRef<int> instrShape, ArrayRef<int> matShape,
     SmallVector<Value> multiDimWarpId, int perPhase, int maxPhase,
     int elemBytes, int mmaElemBytes, bool isHopper,
-    ConversionPatternRewriter &rewriter,
-    const LLVMTypeConverter *typeConverter, const Location &loc)
+    ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter,
+    const Location &loc)
     : nPerWarp(nPerWarp), order(order.begin(), order.end()),
       warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
       kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
@@ -453,12 +453,12 @@ MMA16816SmemLoader::MMA16816SmemLoader(
       matShape(matShape.begin(), matShape.end()),
       multiDimWarpId(multiDimWarpId.begin(), multiDimWarpId.end()),
       perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes),
-      mmaElemBytes(mmaElemBytes), isHopper(isHopper),
-      rewriter(rewriter), loc(loc), ctx(rewriter.getContext()) {
-  // If the current elemType width is different from the MMA elemType width, i.e.
-  // width-changing casting is done later in DotOp Layout... then, in the case of
-  // Hopper, the number of bytes held by each thread after loading will no longer
-  // be 32B. Hence this flag is required to stipulate different logic.
+      mmaElemBytes(mmaElemBytes), isHopper(isHopper), rewriter(rewriter),
+      loc(loc), ctx(rewriter.getContext()) {
+  // If the current elemType width is different from the MMA elemType width,
+  // i.e. width-changing casting is done later in DotOp Layout... then, in the
+  // case of Hopper, the number of bytes held by each thread after loading will
+  // no longer be 32B. Hence this flag is required to stipulate different logic.
   bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes);
 
   contiguousMatShape = matShape[order[0]];
@@ -536,7 +536,8 @@ std::vector<Value> unpackInt(const std::vector<Value> &inValues, Type elTy,
   for (auto v : inValues) {
     // cast i32 to appropriate eltType vector and extract elements
    auto eltType = typeConverter->convertType(elTy);
-    auto vecType = vec_ty(eltType, inBitWidth / eltType.getIntOrFloatBitWidth());
+    auto vecType =
+        vec_ty(eltType, inBitWidth / eltType.getIntOrFloatBitWidth());
     auto vec = bitcast(v, vecType);
     for (int i = 0; i < inBitWidth / eltType.getIntOrFloatBitWidth(); i++) {
       outValues.push_back(extract_element(vec, i32_val(i)));
@@ -597,12 +598,12 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
       std::max<int>(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8);
   // (a, b) is the coordinate.
   auto load = [=, &rewriter, &vals](int batch, int a, int b) {
-    MMA16816SmemLoader loader(
-        nPerWarp, warpsPerTile, sharedLayout.getOrder(),
-        mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides,
-        shapePerCTA /*tileShape*/, instrShape, matShape, multiDimWarpId,
-        perPhase, maxPhase, elemBytes, mmaElemBytes,
-        isHopper, rewriter, typeConverter, loc);
+    MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(),
+                              mmaLayout.getWarpsPerCTA(), kOrder, kWidth,
+                              smemObj.strides, shapePerCTA /*tileShape*/,
+                              instrShape, matShape, multiDimWarpId, perPhase,
+                              maxPhase, elemBytes, mmaElemBytes, isHopper,
+                              rewriter, typeConverter, loc);
     // Offset of a slice within the original tensor in shared memory
     Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
     SmallVector<Value> offs = loader.computeOffsets(lane, cSwizzleOffset);
@@ -647,9 +648,9 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
   bool isHopper = mmaLayout.getVersionMajor() == 3;
   auto shapePerCTA = getShapePerCTA(descTy);
   int bitwidth = descTy.getElementTypeBitWidth();
-  // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should add
-  // up to 32. We use kWidth to compute the element bitwidth of the input to WGMMA,
-  // which could be different from `bitwidth` due to later casting.
+  // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should
+  // add up to 32. We use kWidth to compute the element bitwidth of the input to
+  // WGMMA, which could be different from `bitwidth` due to later casting.
   int mmaBitwidth = isHopper ? (32 / encoding.getKWidth()) : bitwidth;
 
   ValueTable vals;
@@ -657,8 +658,8 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
   int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth;
 
   int kWidth = encoding.getKWidth();
-  auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(shapePerCTA, bitwidth, kWidth,
-                                                    encoding.getOpIdx());
+  auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(
+      shapePerCTA, bitwidth, kWidth, encoding.getOpIdx());
 
   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
   auto order = triton::gpu::getOrder(mmaLayout);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 5925c27bfd06..b03fb0989dda 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -393,13 +393,15 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
   int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
   auto dotOpA = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
-  auto repA = cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
-                  .getMMAv2OrV3RepForOperand(aShapePerCTA, bitwidth,
-                                             dotOpA.getKWidth(), dotOpA.getOpIdx());
+  auto repA =
+      cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
+          .getMMAv2OrV3RepForOperand(aShapePerCTA, bitwidth, dotOpA.getKWidth(),
+                                     dotOpA.getOpIdx());
   auto dotOpB = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
-  auto repB = cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
-                  .getMMAv2OrV3RepForOperand(bShapePerCTA, bitwidth,
-                                             dotOpB.getKWidth(), dotOpB.getOpIdx());
+  auto repB =
+      cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
+          .getMMAv2OrV3RepForOperand(bShapePerCTA, bitwidth, dotOpB.getKWidth(),
+                                     dotOpB.getOpIdx());
 
   assert(repA[2] == repB[1]);
   assert(repA[0] == repB[0]);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index 6b4acdcc04f3..2b9b4f159bf4 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
@@ -442,8 +442,10 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
     if (aSharedLayout) {
       a = aLoader.smemLoad(m, k, rewriter, loc);
     } else {
-      auto aDotOpEnc = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
-      assert(aDotOpEnc.getKWidth() == 32 / aTensorTy.getElementTypeBitWidth());
+      auto aDotOpEnc =
+          cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
+      assert(aDotOpEnc.getKWidth() ==
+             32 / aTensorTy.getElementTypeBitWidth());
       unsigned regASize = (instrShape[0] * instrShape[2]) / 32;
       llvm::SmallVector<Value> regA =
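
Reviewer note (not part of the patch): beyond the mechanical reflow, the
comments touched in SharedToDotOperandMMAv2OrV3.cpp pin down one invariant
that is easy to lose in the formatting noise: on Hopper the MMA element
bitwidth is derived from kWidth as 32 / kWidth, not from the tensor's current
element type, which may still be awaiting a width-changing cast. The
standalone C++ sketch below restates that arithmetic; mmaBitwidthFor and the
8-bit-to-16-bit cast scenario are illustrative assumptions, not Triton APIs.

#include <cassert>
#include <cstdio>

// Mirrors `int mmaBitwidth = isHopper ? (32 / encoding.getKWidth()) : bitwidth;`
// from loadArg: Hopper quads must sum to 32 bits; Ampere reads the element type.
int mmaBitwidthFor(bool isHopper, int kWidth, int elemBitwidth) {
  return isHopper ? (32 / kWidth) : elemBitwidth;
}

int main() {
  // Ampere (MMAv2): values are packed, bitwidth comes from the element type.
  assert(mmaBitwidthFor(false, /*kWidth=*/2, /*elemBitwidth=*/16) == 16);

  // Hopper with fp16 operands: kWidth == 32 / 16 == 2 (the relation asserted
  // in the WGMMA.cpp hunk above), so 32 / kWidth recovers the same 16 bits.
  assert(mmaBitwidthFor(true, 2, 16) == 16);

  // Hypothetical width-change case from the loader comment: the tensor still
  // holds 8-bit elements that a later DotOp-layout cast widens to 16 bits, so
  // mmaElemBytes != elemBytes and the isHopperWidthChange path is taken.
  int mmaBitwidth = mmaBitwidthFor(true, /*kWidth=*/2, /*elemBitwidth=*/8);
  assert(mmaBitwidth == 16 && mmaBitwidth != 8);

  // Same matShapeK arithmetic as loadArg: 2 * 64 / mmaBitwidth == 8 here.
  printf("mmaBitwidth=%d matShapeK=%d\n", mmaBitwidth, 2 * 64 / mmaBitwidth);
  return 0;
}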