Skip to content

Commit

Permalink
Test Dot. Bypass a proper getOrder use
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano committed Nov 4, 2024
1 parent a0ce675 commit f7a4ad7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
5 changes: 4 additions & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,17 @@ SmallVector<unsigned> getOrder(Attribute layout) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the output elements in the LLVM lowerings. We choose row-major.
auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout);
if (nvidiaMma && nvidiaMma.isHopper()) {
if (nvidiaMma) {
llvm::report_fatal_error("Testing");
}
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (isa<NvidiaMmaEncodingAttr>(dotLayout.getParent())) {
llvm::report_fatal_error("Testing DotOperand");
}
auto rank = dotLayout.getWarpsPerCTA().size();
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
}
Expand Down
21 changes: 14 additions & 7 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,12 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
// TODO Revert
auto orderedDimNames =
permuteDimNames(dimNames, getMatrixOrder(rank, /*rowMajor=*/true));
// auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
// By using `reverse(dimNames)` below, we set the order to be row-major
assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
// assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
Expand Down Expand Up @@ -876,11 +879,15 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,

MLIRContext *ctx = mma.getContext();
// A and B have kMajor order
assert(getOrder(dot) ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
// TODO Revert
// assert(getOrder(dot) ==
// getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

// auto kMajorDims =
// permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
auto kMajorDims = permuteDimNames(
standardOutDimNames(ctx, rank),
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
Expand Down
11 changes: 9 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,16 @@ class FuseTransHopper : public OpRewritePattern<LocalAllocOp> {

// MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3
// without transpose for other data types.
auto newInnerCvtOrder = getOrder(srcTy.getEncoding());
// TODO Revert
// Temporary helper that picks the order for the inner convert.
// NVIDIA MMA encodings are routed through getThreadOrder; every other
// encoding still goes through getOrder.
auto getCvtOrder = [](Attribute encoding) {
  auto mmaEncoding = dyn_cast<NvidiaMmaEncodingAttr>(encoding);
  return mmaEncoding ? getThreadOrder(mmaEncoding) : getOrder(encoding);
};
auto newInnerCvtOrder = getCvtOrder(srcTy.getEncoding());
if (auto cvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
newInnerCvtOrder = getOrder(cvt.getSrc().getType().getEncoding());
newInnerCvtOrder = getCvtOrder(cvt.getSrc().getType().getEncoding());
}
auto srcElemTy = allocType.getElementType();
if (!srcElemTy.isF16() && !srcElemTy.isBF16()) {
Expand Down

0 comments on commit f7a4ad7

Please sign in to comment.