From 608638c1c6bd255fd363aa34385c4fa772dd0c03 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 15 Oct 2024 13:32:10 +0100 Subject: [PATCH] Use `blockwise_broadcast_reduce` in reduction fusions. (#1668) Currently, we handle reductions that are fused with gemm-like operations by using atomic stores to the destination buffer. This can be cripplingly slow when most of the output is being reduced, as evidenced in layer_norm cases. This PR adds the ability to perform blockwise_broadcast_reduce on the block sub-tiles of the gemm output. However, in order to do that, we need to make sure the reduction dimension is uniformly distributed across the blocks. This is achieved as follows: Firstly, this PR introduces a utility that, for a given set of upper dimensions, traverses a transform stack and produces, for each lower dimension, the list of sub-dimensions that the upper reduction axes map to. Then, this PR introduces the ShuffleGemmForReductions pass, which splits and transposes the parallel dimensions of the gemm so that the reduction dimension is uniformly distributed across the blocks. Then, in the AlignTiling pass, we extract the block sub-tile when fusing in the rock.reduce operator and perform a blockwise_broadcast_reduce on it. Since we only want to write the partial reductions per block, we pad out the broadcasted part of the sub-tile. (We rely on the fact that any block coordinate that maps to the padded region within the block will not be written out.) Then we recombine the modified sub-tile coordinate maps with the grid-only coordinate maps: a) Here, we drop all the upper dimensions except g_block, m_block and n_block to obtain the grid-only transform map stack. b) In parallel, we reuse the getLowerSubDimensions utility to figure out which sub-dimensions the above grid-only dimensions map to. c) Then we extract those sub-dimensions in a bottom-up fashion and stitch them together with the said grid-only transform map stack. 
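As a rough illustration of the shuffle (the sizes below are hypothetical and chosen only for exposition; the d_* names follow the ones used in reorderReductionDims in ShuffleGemmForReductions.cpp): suppose M = 512, MPerBlock = 64, and getLowerSubDimensions reports that the reduction axis maps to a single sub-dimension of M with size = 256 and stride = 2. The shuffle of M then amounts to:

  Unmerge:  M(512) -> d_nr0(1) x d_r(256) x d_nr_end(2)        ; split out the reduced sub-dimension
  Merge:    d_nr = d_nr0 x d_nr_end (= 2), d_r (= 256)          ; one reduced and one non-reduced dim
  Re-split: d_r(256) -> d_rh(4) x d_rl(64), d_rl = gcd(256, MPerBlock)
  Re-merge: M' = d_rh(4) x d_nr(2) x d_rl(64)

With this ordering, every MPerBlock = 64 tile of M' covers one complete d_rl slice of the reduction, so blockwise_broadcast_reduce can finish those partial reductions inside each block, and only the leftover d_rh factor still has to be accumulated across blocks with atomic stores.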
--- mlir/include/mlir/Dialect/Rock/Passes.h | 1 + mlir/include/mlir/Dialect/Rock/Passes.td | 6 + .../Dialect/Rock/Tuning/GridwiseGemmParams.h | 6 +- .../mlir/Dialect/Rock/utility/loweringUtils.h | 3 + .../Dialect/Rock/utility/transformMapUtils.h | 22 + mlir/lib/Dialect/Rock/IR/RockDialect.cpp | 2 +- mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | 9 +- .../Dialect/Rock/Transforms/AlignTiling.cpp | 512 +++++++++++++- .../Dialect/Rock/Transforms/CMakeLists.txt | 1 + .../Dialect/Rock/Transforms/OutputSwizzle.cpp | 2 + mlir/lib/Dialect/Rock/Transforms/ReuseLDS.cpp | 4 +- .../Transforms/ShuffleGemmForReductions.cpp | 629 ++++++++++++++++++ .../Rock/Tuning/GridwiseGemmParams.cpp | 46 +- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 9 +- .../Dialect/Rock/utility/loweringUtils.cpp | 6 + .../Rock/utility/transformMapUtils.cpp | 292 +++++++- .../rock-shuffle-gemm-for-reductions.mlir | 76 +++ .../pr-e2e/mixr-multi-reduce-mo-4d.mlir | 29 + ...ixr-multi-reduce-mo-dot-reduce-n-only.mlir | 24 + .../pr-e2e/mixr-multi-reduce-mo-dot.mlir | 25 + .../pr-e2e/mixr-multi-reduce-mo-raxis-1.mlir | 29 + .../fusion/pr-e2e/mixr-multi-reduce-mo.mlir | 29 + .../fusion/rock-gemm-reduce-align-tiling.mlir | 73 +- mlir/test/rocmlir-driver/pipelines.mlir | 3 +- 24 files changed, 1757 insertions(+), 81 deletions(-) create mode 100644 mlir/lib/Dialect/Rock/Transforms/ShuffleGemmForReductions.cpp create mode 100644 mlir/test/Dialect/Rock/rock-shuffle-gemm-for-reductions.mlir create mode 100644 mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-4d.mlir create mode 100644 mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-dot-reduce-n-only.mlir create mode 100644 mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-dot.mlir create mode 100644 mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-raxis-1.mlir create mode 100644 mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo.mlir diff --git a/mlir/include/mlir/Dialect/Rock/Passes.h b/mlir/include/mlir/Dialect/Rock/Passes.h index ccec5efbec46..cf7f95a02f8b 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.h +++ b/mlir/include/mlir/Dialect/Rock/Passes.h @@ -44,6 +44,7 @@ namespace rock { #define GEN_PASS_DECL_ROCKVECTORIZEFUSIONSPASS #define GEN_PASS_DECL_ROCKOUTPUTSWIZZLEPASS #define GEN_PASS_DECL_ROCKREUSELDSPASS +#define GEN_PASS_DECL_ROCKSHUFFLEGEMMFORREDUCTIONS #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Rock/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index a88f5ec8d441..872618360b82 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -161,4 +161,10 @@ def RockReuseLDSPass : Pass<"rock-reuse-lds", "::mlir::func::FuncOp"> { let dependentDialects = ["rock::RockDialect", "memref::MemRefDialect"]; } +def RockShuffleGemmForReductions : Pass<"rock-shuffle-gemm-for-reductions", "::mlir::func::FuncOp"> { + let summary = "This pass shuffles parallel gemm dimensions for equally block splits for fused reductions (if any)"; + let dependentDialects = ["rock::RockDialect", "memref::MemRefDialect"]; +} + + #endif // MLIR_DIALECT_ROCK_PASSES diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h index e20f3c66b782..3f5ee596819a 100644 --- a/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h +++ b/mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h @@ -76,19 +76,21 @@ struct PopulateParamsInfo { KernelType kernelType; int64_t batchSize; uint32_t numCu; + bool hasFusedReduction; PopulateParamsInfo(GemmSize 
gemmSize, StringRef arch, GemmFeatures gemmFeatures, Type gemmAType, Type gemmBType, KernelType kernelType) : gemmSize(gemmSize), arch(arch), gemmFeatures(gemmFeatures), - gemmAType(gemmAType), gemmBType(gemmBType), kernelType(kernelType) {} + gemmAType(gemmAType), gemmBType(gemmBType), kernelType(kernelType), + hasFusedReduction(false) {} PopulateParamsInfo(GemmSize gemmSize, StringRef arch, GemmFeatures gemmFeatures, Type gemmAType, Type gemmBType, KernelType kernelType, int64_t batchSize, uint32_t numCu) : gemmSize(gemmSize), arch(arch), gemmFeatures(gemmFeatures), gemmAType(gemmAType), gemmBType(gemmBType), kernelType(kernelType), - batchSize(batchSize), numCu(numCu) {} + batchSize(batchSize), numCu(numCu), hasFusedReduction(false) {} /// Extract the relevant information from a RockGemmWrapperInterface operation static PopulateParamsInfo fromOp(RockGemmWrapperInterface op); diff --git a/mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h b/mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h index 2ec3bfbdec0e..52547e5d71d9 100644 --- a/mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h +++ b/mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h @@ -201,6 +201,9 @@ TypedValue viewBufferAs(OpBuilder &b, Value buffer, Type type); Value gpuAlloc(OpBuilder &b, Location loc, int64_t bufferDim, Type elementType, gpu::AddressSpace memoryAddressSpace); +// helper to verify a lds allocation fits in the GPU +LogicalResult checkLDSSize(StringAttr arch, int64_t ldsBytes); + } // end namespace rock } // end namespace mlir #endif diff --git a/mlir/include/mlir/Dialect/Rock/utility/transformMapUtils.h b/mlir/include/mlir/Dialect/Rock/utility/transformMapUtils.h index 143a5ea8374c..7bb105562588 100644 --- a/mlir/include/mlir/Dialect/Rock/utility/transformMapUtils.h +++ b/mlir/include/mlir/Dialect/Rock/utility/transformMapUtils.h @@ -266,6 +266,28 @@ FailureOr removeUpperDims(OpBuilder &b, ArrayAttr transformAttrs, // padded data. FailureOr removeUpperDims(OpBuilder &b, ArrayAttr transformAttrs, const StringSet<> &removeDimNamesSet); + +struct SubDimInfo { + int64_t size; + int64_t stride; +}; + +inline raw_ostream &operator<<(raw_ostream &os, const SubDimInfo &sdInfo) { + os << ""; + return os; +} + +// Given a sequence of transform maps, this will obtain the lower sub-dimensions +// each provided upper dim would map to. 
+FailureOr>> +getLowerSubDimensions(OpBuilder &b, ArrayAttr transformAttrs, int64_t dim); +FailureOr>> +getLowerSubDimensions(OpBuilder &b, ArrayAttr transformAttrs, + ArrayRef dims); + +SmallVector> createDimNames(int64_t len, StringRef prefix); +SmallVector getStringRefsFor(ArrayRef> strings); + } // end namespace rock } // end namespace mlir #endif diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index ee5552d948ea..d1d1f098d267 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -2033,7 +2033,7 @@ LogicalResult BlockwiseBroadcastReduceOp::verify() { } if (blockwiseInputPartialReductionTensorElements > wsShape[0]) { return emitError( - "workspace should be at least the size of elements per block"); + "workspace should be at least the size of elements per block "); } return success(); } diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index adfaff9141ed..d5e41e420b4d 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -151,8 +151,14 @@ void rock::buildKernelPipeline(OpPassManager &pm, funcPm.addPass(rock::createRockConvToGemmPass()); funcPm.addPass(rock::createRockGemmToGridwisePass()); funcPm.addPass(rock::createRockRegularizePass()); + funcPm.addPass(rock::createRockShuffleGemmForReductions()); funcPm.addPass(rock::createRockGridwiseGemmToBlockwisePass()); - funcPm.addPass(rock::createRockBlockwiseGemmToThreadwisePass()); + // We want to delay blockwise lowering in the fusion cases + // until after linalg align pass because with reduction fusion + // it may introduce blockwise_reductions. + if (!options.enableFusion) { + funcPm.addPass(rock::createRockBlockwiseGemmToThreadwisePass()); + } if (options.enableFusion) { // align linalg tiling @@ -160,6 +166,7 @@ void rock::buildKernelPipeline(OpPassManager &pm, * --convert-linalg-to-affine-loops */ funcPm.addPass(rock::createRockLinalgAlignPass()); + funcPm.addPass(rock::createRockBlockwiseGemmToThreadwisePass()); funcPm.addPass(rock::createRockPipelinePass()); funcPm.addPass(createCanonicalizerPass()); funcPm.addPass(createConvertLinalgToAffineLoopsPass()); diff --git a/mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp b/mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp index 85800af8cf91..ad27437691ea 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/Rock/IR/TransformMapBuilder.h" #include "mlir/Dialect/Rock/Passes.h" #include "mlir/Dialect/Rock/utility/builderUtils.h" +#include "mlir/Dialect/Rock/utility/loweringUtils.h" #include "mlir/Dialect/Rock/utility/transformMapUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" @@ -1097,6 +1098,469 @@ MemcpyRewritePattern::matchAndRewrite(memref::CopyOp copy, return failure(); } +// We have mutated the blockSubTile views by declaring the broadcasted +// reductions as padding +// -- so that they will not be written out. +// The purpose of this function is to integrate back the mutated blockSubTile +// view -- known as "paddedReducedTrStack" back to the original views. In order +// to do that, the function expects the caller to provide (among others): a) +// "gridOnlyDims" which is basically g_block x m_block x n_block --> d0 x d1 x +// ... x dr x ... x dn views. 
b) "lowerGridOnlySubDims" which denotes the +// subdimensions that above maps to in the lower view. Then it will concatanate +// the grid only dims and block sub tile dims as: [ grid_only_d0 x ... x +// grid_only_dn x block_d0 x ... x block_dn ] Additionally, it will split the +// lower view into grid only sub dims and the rest. Then rearrange the split in +// the above form in a bottom up manner. Finally, it will re-stitch transform +// stack back and return to be used in the final writeback. +static FailureOr> +getRecombinedGridOnlyBlockOnlyTiles( + LinalgAlignRewriter &rewriter, ArrayAttr gridOnlyDims, + ArrayAttr paddedReducedTrStack, unsigned upperGemmSpaceRank, + ArrayAttr toBeReducedViews, + const llvm::SmallDenseMap> + &lowerGridOnlySubDims) { + SmallVector transformAttrs; + ArrayRef lowerShapeGridOnly = getLowerShape(gridOnlyDims); + size_t lowerGridOnlyRank = lowerShapeGridOnly.size(); + for (auto [idx, attr] : llvm::enumerate(paddedReducedTrStack)) { + TransformMapAttr trMapAttr = cast(attr); + SmallVector trAttrs; + ArrayRef gridUpperShape; + ArrayRef gridLowerShape; + if (idx < gridOnlyDims.size()) { + TransformMapAttr gridOnlyAttr = cast(gridOnlyDims[idx]); + ArrayRef ops = gridOnlyAttr.getOps(); + trAttrs.insert(trAttrs.end(), ops.begin(), ops.end()); + gridUpperShape = gridOnlyAttr.getUpperBounds().asArrayRef(); + gridLowerShape = gridOnlyAttr.getLowerBounds().asArrayRef(); + } else { + SmallVector> names = + createDimNames(lowerGridOnlyRank, "dim"); + SmallVector nameRefs = getStringRefsFor(names); + SmallVector dims; + for (unsigned i = 0; i < lowerGridOnlyRank; i++) { + dims.push_back(i); + } + gridUpperShape = lowerShapeGridOnly; + gridLowerShape = lowerShapeGridOnly; + TransformAttr blockPt = + TransformAttr::get(rewriter.getContext(), TransformType::PassThrough, + {}, nameRefs, dims, nameRefs, dims); + trAttrs.push_back(blockPt); + } + for (TransformAttr trAttr : trMapAttr.getOps()) { + SmallVector upperDims; + llvm::transform( + trAttr.getUpperDims(), std::back_inserter(upperDims), + [&](unsigned idx) { return idx + gridUpperShape.size(); }); + SmallVector lowerDims; + llvm::transform( + trAttr.getLowerDims(), std::back_inserter(lowerDims), + [&](unsigned idx) { return idx + gridLowerShape.size(); }); + TransformAttr newTrAttr = TransformAttr::get( + rewriter.getContext(), trAttr.getType(), trAttr.getParams(), + trAttr.getUpperNames(), upperDims, trAttr.getLowerNames(), lowerDims); + trAttrs.push_back(newTrAttr); + } + // set the bounds + SmallVector upperBounds = llvm::to_vector(gridUpperShape); + ArrayRef origUpperBounds = trMapAttr.getUpperBounds().asArrayRef(); + upperBounds.insert(upperBounds.end(), origUpperBounds.begin(), + origUpperBounds.end()); + SmallVector lowerBounds = llvm::to_vector(gridLowerShape); + ArrayRef origLowerBounds = trMapAttr.getLowerBounds().asArrayRef(); + lowerBounds.insert(lowerBounds.end(), origLowerBounds.begin(), + origLowerBounds.end()); + // create new trMapAttr + LLVM_DEBUG(llvm::dbgs() << "trAttrs = "; + llvm::interleaveComma(trAttrs, llvm::dbgs()); + llvm::dbgs() << "\n"; llvm::dbgs() << "upperBounds = "; + llvm::interleaveComma(upperBounds, llvm::dbgs()); + llvm::dbgs() << "\n"; llvm::dbgs() << "lowerBounds = "; + llvm::interleaveComma(lowerBounds, llvm::dbgs()); + llvm::dbgs() << "\n"); + TransformMapAttr newTrMap = + TransformMapAttr::get(trAttrs, upperBounds, lowerBounds); + transformAttrs.push_back(newTrMap); + } + ArrayRef currLowerShape = + cast(transformAttrs.back()).getLowerBounds(); + if (currLowerShape.size() < 
lowerGridOnlyRank * 2) { + SmallVector> names = + createDimNames(currLowerShape.size(), "d"); + SmallVector nameRefs = getStringRefsFor(names); + TopDownTMBuilder toAddMissingBlockDims(rewriter, nameRefs, currLowerShape); + { + SmallVector gridOnlyDimIdxs; + for (unsigned i = 0; i < upperGemmSpaceRank - 2; i++) { + gridOnlyDimIdxs.push_back(i); + } + toAddMissingBlockDims.passThrough(gridOnlyDimIdxs, gridOnlyDimIdxs); + int64_t missingDimCount = lowerGridOnlyRank * 2 - currLowerShape.size(); + SmallVector> names = createDimNames(missingDimCount, "cd"); + SmallVector nameRefs = getStringRefsFor(names); + unsigned dimInsertionPoint = 3; + for (int64_t md = 0; md < missingDimCount; md++) { + toAddMissingBlockDims.constDim(nameRefs.back(), dimInsertionPoint++, 0, + 1); + } + for (unsigned lowerDim = 3; lowerDim < currLowerShape.size(); + lowerDim++) { + toAddMissingBlockDims.passThrough({dimInsertionPoint++}, {lowerDim}); + } + TransformMapAttr addMissingBlockDims = toAddMissingBlockDims.get(); + LLVM_DEBUG(llvm::dbgs() + << "addMissingBlockDims = " << addMissingBlockDims << "\n"); + transformAttrs.push_back(addMissingBlockDims); + } + } + currLowerShape = + cast(transformAttrs.back()).getLowerBounds(); + if (currLowerShape.size() != lowerGridOnlyRank * 2) { + LLVM_DEBUG(llvm::dbgs() + << "Recombine: currLowerRank=" << currLowerShape.size() << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "Recombine: lowerGridOnlyRank=" << lowerGridOnlyRank << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Recombine: current lower rank should be 2x " + "as the grid only rank\n"); + return failure(); + } + + // The last two transforms are constructed bottom up as it is easier. + // where we joint them once we have grid and block tiles coordinates + // seperated. + ArrayRef toBeReducedShape = getLowerShape(toBeReducedViews); + SmallVector> reduceLowerShapeNames = + createDimNames(toBeReducedShape.size(), "d"); + SmallVector reduceLowerShapeNameRefs = + getStringRefsFor(reduceLowerShapeNames); + BottomUpTMBuilder toMatrixView(rewriter, reduceLowerShapeNameRefs, + toBeReducedShape); + llvm::SmallDenseMap> gridSubDims; + llvm::SmallDenseMap> blockSubDims; + TransformMapAttr lastMerge; + { + llvm::SmallDenseMap>> names; + llvm::SmallDenseMap> nameRefs; + llvm::SmallDenseMap> upperDims; + int64_t dimInsertionPoint = 0; + for (unsigned dim = 0; dim < toBeReducedShape.size(); dim++) { + // The lower subDims contain sub-dimensions where blocking + // indices -- namely g_block, m_block and n_block -- maps to + // in the matrix coordinates. Here we split out matrix dims + // into sub-dims that are related to the said blocking dimensions. 
+ SmallVector subDims; + if (lowerGridOnlySubDims.contains(dim)) { + subDims = lowerGridOnlySubDims.at(dim); + } + llvm::sort(subDims, [](const SubDimInfo &L, const SubDimInfo &R) { + return L.stride > R.stride; + }); + SmallVector splitSizes; + int64_t currSize = toBeReducedShape[dim]; + for (const SubDimInfo &subDim : subDims) { + if (currSize % (subDim.size * subDim.stride) != 0) { + LLVM_DEBUG(llvm::dbgs() + << "Recombine: currSize=" << currSize << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "Recombine: subDim.size=" << subDim.size << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "Recombine: subDim.stride=" << subDim.stride << "\n"); + LLVM_DEBUG( + llvm::dbgs() + << "Recombine: subDims should equally divide current dims\n"); + return failure(); + } + int64_t newSize = currSize / (subDim.size * subDim.stride); + if (newSize > 1) { + blockSubDims[dim].push_back(dimInsertionPoint); + SmallString<8> dimName( + Twine("block_dim" + Twine(dim) + "_" + Twine(dimInsertionPoint)) + .str()); + names[dim].push_back(dimName); + nameRefs[dim].push_back(names[dim].back()); + upperDims[dim].push_back(dimInsertionPoint++); + splitSizes.push_back(newSize); + } + gridSubDims[dim].push_back(dimInsertionPoint); + SmallString<8> dimName( + Twine("grid_dim" + Twine(dim) + "_" + Twine(dimInsertionPoint)) + .str()); + names[dim].push_back(dimName); + nameRefs[dim].push_back(names[dim].back()); + upperDims[dim].push_back(dimInsertionPoint++); + splitSizes.push_back(subDim.size); + currSize = subDim.stride; + } + if (currSize > 1 || splitSizes.empty()) { + blockSubDims[dim].push_back(dimInsertionPoint); + SmallString<8> dimName( + Twine("block_dim" + Twine(dim) + "_" + Twine(dimInsertionPoint)) + .str()); + names[dim].push_back(dimName); + nameRefs[dim].push_back(names[dim].back()); + upperDims[dim].push_back(dimInsertionPoint++); + splitSizes.push_back(currSize); + } + LLVM_DEBUG(llvm::dbgs() << "dim=" << dim << "\n"); + LLVM_DEBUG(llvm::dbgs() << "\tsplits="; + llvm::interleaveComma(splitSizes, llvm::dbgs()); + llvm::dbgs() << "\n"); + toMatrixView.unmerge(nameRefs[dim], upperDims[dim], + reduceLowerShapeNameRefs[dim], splitSizes); + } + lastMerge = toMatrixView.get(); + } + LLVM_DEBUG(llvm::dbgs() << "lastMerge=" << lastMerge << "\n"); + // The above view contains splitted sub-dims that are either associated + // with grid and non-grid dimensions. Then, we concat them as follows: + // [concat_grid_dim0, concat_grid_dim1, .. , concat_grid_dimX, + // concat_blk_dim0, concat_blk_dim1, .. 
, concat_blk_dimX] + BottomUpTMBuilder toGridBlockSeperation = + BottomUpTMBuilder::above(toMatrixView, lastMerge); + TransformMapAttr gridblockSeperation; + { + SmallVector lowerNameRefs; + toGridBlockSeperation.getStartNames(lowerNameRefs); + SmallVector upperGridNames; + for (unsigned dim = 0; dim < toBeReducedShape.size(); dim++) { + upperGridNames.push_back(Twine("grid_dim" + Twine(dim)).str()); + if (gridSubDims.contains(dim)) { + SmallVector upperGridSubDimNames; + for (int64_t upperGridSubDim : gridSubDims[dim]) { + upperGridSubDimNames.push_back(lowerNameRefs[upperGridSubDim]); + } + toGridBlockSeperation.merge(upperGridNames.back(), dim, + upperGridSubDimNames); + } else { + toGridBlockSeperation.addDim(upperGridNames.back(), dim, 1); + } + } + SmallVector upperBlockNames; + for (unsigned dim = 0; dim < toBeReducedShape.size(); dim++) { + upperBlockNames.push_back(Twine("block_dim" + Twine(dim)).str()); + if (blockSubDims.contains(dim)) { + SmallVector upperBlockSubDimNames; + for (int64_t upperBlockSubDim : blockSubDims[dim]) { + upperBlockSubDimNames.push_back(lowerNameRefs[upperBlockSubDim]); + } + toGridBlockSeperation.merge(upperBlockNames.back(), + dim + toBeReducedShape.size(), + upperBlockSubDimNames); + } else { + toGridBlockSeperation.addDim(upperBlockNames.back(), + dim + toBeReducedShape.size(), 1); + } + } + gridblockSeperation = toGridBlockSeperation.get(); + } + LLVM_DEBUG(llvm::dbgs() << "gridblockSeperation=" << gridblockSeperation + << "\n"); + // Now we join them to finish the recombination. + transformAttrs.push_back(gridblockSeperation); + transformAttrs.push_back(lastMerge); + LLVM_DEBUG(llvm::dbgs() << "transformAttrs = " + << "\n"; + llvm::interleaveComma(transformAttrs, llvm::dbgs()); + llvm::dbgs() << "\n"); + return transformAttrs; +} + +// This function will attempt to add blockwise reductions when fusing +// in reduction to the write back of the core kernel. +static LogicalResult insertBlockwiseReduction( + LinalgAlignRewriter &rewriter, Location loc, rock::ReduceOp reduceOp, + ThreadwiseWriteAllOp threadwiseWriteOp, StoreMethodAttr stMethod) { + // This has < block dimensions ... > x tid x iter to Gemm Dimensions. + ArrayAttr extraViews = threadwiseWriteOp.getExtraViews(); + ArrayAttr destTrs; + Value dest; + std::tie(dest, destTrs, std::ignore) = + untransform(rewriter, threadwiseWriteOp.getDest()); + + ArrayAttr toBeReducedViews = prependUpperViews(rewriter, extraViews, destTrs); + TransformMapAttr firstCoordTransform = + cast(toBeReducedViews[0]); + int64_t upperRank = firstCoordTransform.getUpperBounds().size(); + SetVector removeIndicesSet; + // We only want to keep tid x iter in the maps + // which is the last two for block subtile + for (int64_t i = 0; i < upperRank - 2; i++) { + removeIndicesSet.insert(i); + } + FailureOr blockSubTileViews = + removeUpperDims(rewriter, toBeReducedViews, removeIndicesSet); + if (failed(blockSubTileViews)) { + LLVM_DEBUG(llvm::dbgs() << "blockSubTileViews creation using " + "removeUpperDims is unsuccesful.\n"); + return failure(); + } + // We only want to keep tid in the maps + // which is the last two for block subtile tid + // hence, add back iter to remove indices. 
+ removeIndicesSet.insert(upperRank - 1); + + FailureOr blockSubTileTidSliceViews = + removeUpperDims(rewriter, toBeReducedViews, removeIndicesSet); + if (failed(blockSubTileTidSliceViews)) { + LLVM_DEBUG(llvm::dbgs() << "blockSubTileTidSliceViews creation using " + "removeUpperDims is unsuccesful.\n"); + return failure(); + } + // We only want to keep iter in the maps + // which is the last one. + removeIndicesSet.remove(upperRank - 1); + removeIndicesSet.insert(upperRank - 2); + + FailureOr threadSubTileViews = + removeUpperDims(rewriter, toBeReducedViews, removeIndicesSet); + if (failed(threadSubTileViews)) { + LLVM_DEBUG(llvm::dbgs() << "threadSubTileViews creation using " + "removeUpperDims is unsuccesful.\n"); + return failure(); + } + + // Extract grid-only dims + removeIndicesSet.clear(); + for (int64_t i = upperRank - 2; i < upperRank; i++) { + removeIndicesSet.insert(i); + } + FailureOr gridOnlyDims = + removeUpperDims(rewriter, toBeReducedViews, removeIndicesSet); + if (failed(gridOnlyDims)) { + LLVM_DEBUG( + llvm::dbgs() + << "gridOnlyDims creation using removeUpperDims is unsuccesful.\n"); + return failure(); + } + + SmallVector gridOnlyDimIdxs; + for (int64_t i = 0; i < upperRank - 2; i++) { + gridOnlyDimIdxs.push_back(i); + } + FailureOr>> + lowerSubDims = + getLowerSubDimensions(rewriter, toBeReducedViews, gridOnlyDimIdxs); + if (failed(lowerSubDims)) { + LLVM_DEBUG(llvm::dbgs() << "lowerSubDims creation using " + "getLowerSubDimensions is unsuccesful.\n"); + return failure(); + } + + int64_t reductionAxis = reduceOp.getAxisAttr().getInt(); + TypedValue redOut = reduceOp.getOut(); + ArrayRef reduceOutShape = redOut.getType().getShape(); + TypedValue redIn = reduceOp.getIn(); + ArrayRef reduceInShape = redIn.getType().getShape(); + + int64_t blockReductionAxis = reductionAxis; + int64_t blockReductionAxisFromLeft = + (reduceInShape.size() - 1) - blockReductionAxis; + + ArrayRef blockLowerShape = getLowerShape(blockSubTileViews.value()); + ArrayRef blockSubTileTidSliceShape = + getLowerShape(blockSubTileTidSliceViews.value()); + int64_t blockSubTileTidSliceRank = blockSubTileTidSliceShape.size(); + // The block sub-tile view might not have the slower changing + // dimensions in it. Thus, we always keep track of the reduction + // dimensions from its distance to fastest changing dimensions. 
+ blockReductionAxis = + blockSubTileTidSliceRank - 1 - blockReductionAxisFromLeft; + int64_t partialReductionsPerThread = + blockSubTileTidSliceShape[blockReductionAxis]; + int64_t ldsWorkspaceSize = 1; + for (auto [idx, size] : llvm::enumerate(blockLowerShape)) { + if (idx == (size_t)blockReductionAxis) { + ldsWorkspaceSize *= partialReductionsPerThread; + } else { + ldsWorkspaceSize *= size; + } + } + auto maybeArch = getArch(reduceOp); + if (succeeded(maybeArch)) { + if (failed(checkLDSSize(maybeArch.value(), ldsWorkspaceSize))) { + LLVM_DEBUG(llvm::dbgs() + << "lds size for blockwise reduction does not fit.\n"); + return failure(); + } + } + TypedValue src = threadwiseWriteOp.getSource(); + auto broadcastReducedSrc = rewriter.create(loc, src.getType()); + Value ldsWorkspace = rock::gpuAlloc(rewriter, loc, ldsWorkspaceSize, + src.getType().getElementType(), + gpu::AddressSpace::Workgroup); + + rewriter.create( + loc, src, ldsWorkspace, broadcastReducedSrc, + /*extraOut=*/nullptr, rewriter.getIndexAttr(blockReductionAxis), + reduceOp.getReduceMethodAttr(), blockSubTileViews.value(), + blockSubTileTidSliceViews.value(), threadSubTileViews.value(), + /*extraViews=*/nullptr, + getBlockSize(reduceOp->getParentOfType()).value()); + + ViewLikeOpInterface viewOp = + ldsWorkspace.getDefiningOp(); + rewriter.create(loc, viewOp.getViewSource()); + // Create partial reduction views + ArrayAttr paddedReducedTrStack; + { + SmallVector transformAttrs; + ArrayRef blockTileShape = getLowerShape(blockSubTileViews.value()); + SmallVector> names = + createDimNames(blockTileShape.size(), "dim"); + SmallVector nameRefs = getStringRefsFor(names); + TopDownTMBuilder toReducedView(rewriter, nameRefs, blockTileShape); + for (unsigned i = 0; i < blockTileShape.size(); i++) { + if (blockReductionAxis == i) { + // The blockwise_broadcast_reduce will populate + // all indices of pre-reduction space with the + // reduced value. However, for the write back + // we only want one of the reduced values to be + // written. Therefore, we keep the 0th and declare + // rest as padding. 
+ toReducedView.pad({nameRefs[i]}, {0, blockTileShape[i] - 1}); + } else { + toReducedView.passThrough({nameRefs[i]}, {i}, {nameRefs[i]}); + } + } + transformAttrs.push_back(toReducedView.get()); + ArrayAttr arrayTransformAttrs = rewriter.getArrayAttr(transformAttrs); + paddedReducedTrStack = prependUpperViews( + rewriter, blockSubTileViews.value(), arrayTransformAttrs); + } + + // Recombine block dimensions + FailureOr> transformAttrs = + getRecombinedGridOnlyBlockOnlyTiles( + rewriter, gridOnlyDims.value(), paddedReducedTrStack, upperRank, + toBeReducedViews, lowerSubDims.value()); + if (failed(transformAttrs)) { + LLVM_DEBUG(llvm::dbgs() << "Recombination failed.\n"); + return failure(); + } + reduceInShape = + cast(transformAttrs.value().back()).getLowerBounds(); + BottomUpTMBuilder dropReductionDim(rewriter, reduceOutShape, loc); + for (uint32_t i = 0; i < reduceOutShape.size(); ++i) { + if (i == reductionAxis) { + dropReductionDim.broadcast({i}, {reduceInShape[i]}); + } else { + dropReductionDim.passThrough({i}, {i}); + } + } + transformAttrs.value().push_back(dropReductionDim.get()); + threadwiseWriteOp.setExtraViewsAttr(rewriter.getArrayAttr({})); + threadwiseWriteOp.getSourceMutable().assign(broadcastReducedSrc); + TypedValue reduceOut = reduceOp.getOut(); + reduceOut = cast>( + applyViewsOnDest(rewriter, loc, reduceOut, transformAttrs.value())); + threadwiseWriteOp.getDestMutable().assign(reduceOut); + // TODO : in future if all reductions are done within the block + // we can revert this back to a non-atomic store. + threadwiseWriteOp.setStoreMethodAttr(stMethod); + return success(); +} + LogicalResult ReduceRewritePattern::matchAndRewrite(rock::ReduceOp reduceOp, LinalgAlignRewriter &rewriter) const { @@ -1145,27 +1609,37 @@ ReduceRewritePattern::matchAndRewrite(rock::ReduceOp reduceOp, } rewriter.moveAfterIfNeeded(threadwiseWriteOp, reduceOp); - int64_t reductionAxis = reduceOp.getAxisAttr().getInt(); - TypedValue redOut = reduceOp.getOut(); - ArrayRef reduceOutShape = redOut.getType().getShape(); - TypedValue redIn = reduceOp.getIn(); - ArrayRef reduceInShape = redIn.getType().getShape(); - BottomUpTMBuilder dropReductionDim(rewriter, reduceOutShape, loc); - for (uint32_t i = 0; i < reduceOutShape.size(); ++i) { - if (i == reductionAxis) { - dropReductionDim.broadcast({i}, {reduceInShape[i]}); - } else { - dropReductionDim.passThrough({i}, {i}); + LogicalResult canUseBlockwiseReductions = insertBlockwiseReduction( + rewriter, loc, reduceOp, threadwiseWriteOp, stMethod); + // fallback to doing pure atomics based reductions + if (failed(canUseBlockwiseReductions)) { + LLVM_DEBUG(llvm::dbgs() << "Unable to add blockwise reductions for this " + "reduction fusion case.\n"); + int64_t reductionAxis = reduceOp.getAxisAttr().getInt(); + TypedValue redOut = reduceOp.getOut(); + ArrayRef reduceOutShape = redOut.getType().getShape(); + TypedValue redIn = reduceOp.getIn(); + ArrayRef reduceInShape = redIn.getType().getShape(); + BottomUpTMBuilder dropReductionDim(rewriter, reduceOutShape, loc); + for (uint32_t i = 0; i < reduceOutShape.size(); ++i) { + if (i == reductionAxis) { + dropReductionDim.broadcast({i}, {reduceInShape[i]}); + } else { + dropReductionDim.passThrough({i}, {i}); + } } + TransformMapAttr trAttr = dropReductionDim.get(); + views.push_back(trAttr); + LLVM_DEBUG(llvm::dbgs() << "views = " + << "\n"; + llvm::interleaveComma(views, llvm::dbgs()); + llvm::dbgs() << "\n"); + TypedValue reduceOut = reduceOp.getOut(); + reduceOut = cast>( + applyViewsOnDest(rewriter, loc, 
reduceOut, views)); + threadwiseWriteOp.getDestMutable().assign(reduceOut); + threadwiseWriteOp.setStoreMethodAttr(stMethod); } - TransformMapAttr trAttr = dropReductionDim.get(); - views.push_back(trAttr); - TypedValue reduceOut = reduceOp.getOut(); - reduceOut = cast>( - applyViewsOnDest(rewriter, loc, reduceOut, views)); - threadwiseWriteOp.getDestMutable().assign(reduceOut); - threadwiseWriteOp.setStoreMethodAttr(stMethod); - rewriter.eraseOp(reduceOp); return success(); } diff --git a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt index 636b294bea19..0e1042783c31 100644 --- a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt @@ -26,6 +26,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms VectorizeFusions.cpp OutputSwizzle.cpp ReuseLDS.cpp + ShuffleGemmForReductions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Rock diff --git a/mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp b/mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp index ed80202e0ea9..c67aa04a0cce 100644 --- a/mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/OutputSwizzle.cpp @@ -108,6 +108,8 @@ static LogicalResult checkLDSSize(Operation *op, int64_t ldsBytes) { static std::optional> getIdToLDS(ThreadwiseWriteAllOp &op, OpBuilder &b) { ArrayAttr srcTransform = op.getExtraViewsAttr(); + if (srcTransform.empty()) + return std::nullopt; StringSet<> dimensionsToRemove{"g_block", "m_block", "n_block"}; FailureOr maybeIdToLDS = removeUpperDims(b, srcTransform, dimensionsToRemove); diff --git a/mlir/lib/Dialect/Rock/Transforms/ReuseLDS.cpp b/mlir/lib/Dialect/Rock/Transforms/ReuseLDS.cpp index 2e700158086c..d20fc18035ac 100644 --- a/mlir/lib/Dialect/Rock/Transforms/ReuseLDS.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/ReuseLDS.cpp @@ -59,9 +59,7 @@ static LogicalResult checkLDSSize(Operation *op, int64_t ldsBytes) { // Check for arch limitations exceeded FailureOr maybeArch = getArch(op); if (succeeded(maybeArch)) { - StringAttr arch = maybeArch.value(); - const int64_t ldsSize = rock::lookupArchInfo(arch).maxSharedMemPerWG; - return success(ldsBytes <= ldsSize); + return checkLDSSize(maybeArch.value(), ldsBytes); } return success(); } diff --git a/mlir/lib/Dialect/Rock/Transforms/ShuffleGemmForReductions.cpp b/mlir/lib/Dialect/Rock/Transforms/ShuffleGemmForReductions.cpp new file mode 100644 index 000000000000..e121f4c867a8 --- /dev/null +++ b/mlir/lib/Dialect/Rock/Transforms/ShuffleGemmForReductions.cpp @@ -0,0 +1,629 @@ +//===- ShuffleGemmForReductions - MLIR Rock ops lowering passes -----===// +// +// Copyright 2024 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================ +// +// This pass will rearrange M & N dimensions of a gemm that +// is being fused with a reduction at the end -- possibly with +// reshapes in between. 
This pass will re-order M & N dimensions +// such that sub-dimensions of M/N being reduced will be split +// equally across blocks. +// +//===-----------------------------------------------------===// +#include "mlir/Dialect/Rock/utility/AmdArchDb.h" +#include "mlir/Dialect/Rock/utility/loweringUtils.h" +#include "mlir/Dialect/Rock/utility/math.h" +#include "mlir/Dialect/Rock/utility/transformMapUtils.h" + +#include "mlir/Analysis/BufferDependencyAnalysis.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Rock/Passes.h" +#include "mlir/Dialect/Rock/utility/builderUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace rock { +#define GEN_PASS_DEF_ROCKSHUFFLEGEMMFORREDUCTIONS +#include "mlir/Dialect/Rock/Passes.h.inc" +} // namespace rock +} // namespace mlir + +#define DEBUG_TYPE "rock-shuffle-gemm-for-reductions" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::rock; + +namespace { +struct RockShuffleGemmForReductionsPass + : public rock::impl::RockShuffleGemmForReductionsBase< + RockShuffleGemmForReductionsPass> { + void runOnOperation() override; +}; +} // end anonymous namespace + +ArrayAttr reverse(ArrayAttr attrs) { + SmallVector attrsReversed = llvm::to_vector(llvm::reverse(attrs)); + IRRewriter rewriter(attrs.getContext()); + return rewriter.getArrayAttr(attrsReversed); +} + +ArrayAttr getAllViewsFromSource(OpOperand *operand) { + Value val = operand->get(); + IRRewriter rewriter(val.getContext()); + ArrayAttr trs; + Value untransformed; + std::tie(untransformed, trs, std::ignore) = untransform(rewriter, val); + return reverse(trs); +} + +FailureOr> +obtainViewsFromReaderToWriter(memref::AllocOp buffer, + const BufferDependencyAnalysis &deps, + ArrayAttr currViews) { + LLVM_DEBUG(llvm::dbgs() << "buffer = " << buffer << "\n"); + IRRewriter rewriter(buffer.getContext()); + std::optional> writersOperands = + deps.getWriters(buffer); + if (!writersOperands.has_value()) + return failure(); + for (OpOperand *writerOperand : writersOperands.value()) { + ArrayAttr viewsFromAllocOp = getAllViewsFromSource(writerOperand); + currViews = prependUpperViews(rewriter, currViews, viewsFromAllocOp); + if (isa(writerOperand->getOwner())) { + return std::make_tuple(reverse(currViews), writerOperand->getOwner()); + } + LLVM_DEBUG(llvm::dbgs() + << "write op = " << *writerOperand->getOwner() << "\n"); + auto writeOp = dyn_cast(writerOperand->getOwner()); + if (!writeOp) { + LLVM_DEBUG(llvm::dbgs() << "\tit is not a memory effect interface op\n"); + continue; + } + SmallVector effects; + writeOp.getEffects(effects); + for (const MemoryEffects::EffectInstance &effect : effects) { + OpOperand *readOperand = effect.getEffectValue(); + LLVM_DEBUG(llvm::dbgs() + << "readOperand = " << readOperand->get() << "\n"); + // Test against the write operand to guard against [MemRead, MemWrite] + if (readOperand && readOperand != writerOperand && + isa(effect.getEffect())) { + if (memref::AllocOp readBuffer = + dyn_cast(readOperand->get().getDefiningOp())) { + FailureOr> mayBeViewsAndGemmOp = + obtainViewsFromReaderToWriter(readBuffer, deps, currViews); + if (succeeded(mayBeViewsAndGemmOp)) { + return mayBeViewsAndGemmOp; + } + } + } + } + } + LLVM_DEBUG(llvm::dbgs() << "No writer goes to a gemm op.\n"); + return failure(); +} + +FailureOr> +obtainGemmToReduceViews(ReduceOp rOp, const BufferDependencyAnalysis &deps) { + IRRewriter rewriter(rOp.getContext()); + memref::AllocOp rSrc = rOp.getIn().getDefiningOp(); + if (!rSrc) + 
return failure(); + ArrayAttr views = rewriter.getArrayAttr({}); + return obtainViewsFromReaderToWriter(rSrc, deps, views); +} + +struct MNPerBlock { + int64_t MPerBlock; + int64_t NPerBlock; +}; + +static FailureOr getMNPerBlock(Operation *gemmOp) { + MNPerBlock ret; + if (auto xdlGemmOp = dyn_cast(gemmOp)) { + ret.MPerBlock = xdlGemmOp.getParams().getMPerBlock(); + ret.NPerBlock = xdlGemmOp.getParams().getNPerBlock(); + } else if (auto nonxdlGemmOp = dyn_cast(gemmOp)) { + ret.MPerBlock = nonxdlGemmOp.getParams().getMPerBlock(); + ret.NPerBlock = nonxdlGemmOp.getParams().getNPerBlock(); + } else { + return failure(); + } + return ret; +} + +ArrayAttr reorderReductionDims(BottomUpTMBuilder &toReductionSplit, + ArrayRef reductionSubDims, + StringRef dName, int64_t dLen, + int64_t dPerBlock) { + SmallVector reductionSubDimsVec = + llvm::to_vector(reductionSubDims); + llvm::sort(reductionSubDimsVec, [](const SubDimInfo &L, const SubDimInfo &R) { + return L.stride > R.stride; + }); + llvm::SmallDenseMap reductionDims; + { + toReductionSplit.passThrough(ArrayRef{0, 1}, + ArrayRef{0, 1}); + SmallVector splitSizes; + SmallVector> splitNames; + SmallVector splitNamesRefs; + SmallVector splitDims; + int64_t dimInsertionPoint = 2; + int64_t currSize = dLen; + LLVM_DEBUG(llvm::dbgs() << "dLen = " << dLen << "\n"); + for (auto [idx, sdInfo] : enumerate(reductionSubDimsVec)) { + { + SmallString<8> dimName(Twine("d_nr" + Twine(idx)).str()); + splitNames.push_back(dimName); + } + splitNamesRefs.push_back(splitNames.back()); + splitDims.push_back(dimInsertionPoint++); + LLVM_DEBUG(llvm::dbgs() + << "\tsplitSize = " << currSize / (sdInfo.size * sdInfo.stride) + << "\n"); + splitSizes.push_back(currSize / (sdInfo.size * sdInfo.stride)); + { + SmallString<8> dimName(Twine("d_r" + Twine(idx)).str()); + splitNames.push_back(dimName); + } + splitNamesRefs.push_back(splitNames.back()); + reductionDims[dimInsertionPoint] = sdInfo.size; + splitDims.push_back(dimInsertionPoint++); + LLVM_DEBUG(llvm::dbgs() << "\tsplitSize = " << sdInfo.size << "\n"); + splitSizes.push_back(sdInfo.size); + currSize = sdInfo.stride; + } + if (currSize > 1) { + { + SmallString<8> dimName(Twine("d_nr_end").str()); + splitNames.push_back(dimName); + } + splitNamesRefs.push_back(splitNames.back()); + splitDims.push_back(dimInsertionPoint++); + LLVM_DEBUG(llvm::dbgs() << "\tsplitSize = " << currSize << "\n"); + splitSizes.push_back(currSize); + } + toReductionSplit.unmerge(splitNamesRefs, splitDims, dName, splitSizes); + } + TransformMapAttr reduceSplit = toReductionSplit.get(); + LLVM_DEBUG(llvm::dbgs() << "reduceSplit = " << reduceSplit << "\n"); + auto toCommonReductionDim = + BottomUpTMBuilder::above(toReductionSplit, reduceSplit); + { + toCommonReductionDim.passThrough(ArrayRef{0, 1}, + ArrayRef{0, 1}); + SmallVector startNames; + toCommonReductionDim.getStartNames(startNames); + SmallVector reduceDimNames; + SmallVector nonReduceDimNames; + unsigned dimInsertionPoint = 2; + for (unsigned dim = 2; dim < reduceSplit.getUpperBounds().size(); dim++) { + if (!reductionDims.contains(dim)) { + nonReduceDimNames.push_back(startNames[dim]); + } else { + reduceDimNames.push_back(startNames[dim]); + } + } + toCommonReductionDim.merge("d_nr", dimInsertionPoint++, nonReduceDimNames); + toCommonReductionDim.merge("d_r", dimInsertionPoint++, reduceDimNames); + } + TransformMapAttr commonReduction = toCommonReductionDim.get(); + LLVM_DEBUG(llvm::dbgs() << "commonReduction = " << commonReduction << "\n"); + auto toResplitReduction = + 
BottomUpTMBuilder::above(toCommonReductionDim, commonReduction); + { + unsigned upperDimCount = commonReduction.getUpperBounds().size(); + int64_t commonReduceSize = + commonReduction.getUpperBounds().asArrayRef().back(); + int64_t commonFactor = math_util::gcd(commonReduceSize, dPerBlock); + toResplitReduction.passThrough(ArrayRef{0, 1, 3}, + ArrayRef{0, 1, 2}); + toResplitReduction.unmerge({"d_rh", "d_rl"}, {2, upperDimCount}, "d_r", + {commonReduceSize / commonFactor, commonFactor}); + } + TransformMapAttr resplitReduction = toResplitReduction.get(); + LLVM_DEBUG(llvm::dbgs() << "resplitReductionAttr = " << resplitReduction + << "\n"); + auto toRecombined = + BottomUpTMBuilder::above(toResplitReduction, resplitReduction); + { + toRecombined.passThrough(ArrayRef{0, 1}, + ArrayRef{0, 1}); + toRecombined.merge("d", 2, {"d_rh", "d_nr", "d_rl"}); + } + TransformMapAttr recombined = toRecombined.get(); + LLVM_DEBUG(llvm::dbgs() << "recombined = " << recombined << "\n"); + OpBuilder builder(recombined.getContext()); + return builder.getArrayAttr( + {reduceSplit, commonReduction, resplitReduction, recombined}); +} + +// This function will shuffle the M & N dimensions so that the +// reductions are uniformly split across block tiles. Note that +// we dont consider G as we dont block tile across G dimension. +std::tuple generateShuffledGemmInputViews( + OpBuilder &builder, int64_t g, int64_t m, int64_t mPerBlock, int64_t k, + int64_t n, int64_t nPerBlock, + const llvm::SmallDenseMap> + &reductionSubDims) { + BottomUpTMBuilder toReductionSplitA(builder, {"G", "K", "M"}, {g, k, m}); + ArrayAttr additionalViewsA = builder.getArrayAttr({}); + if (reductionSubDims.contains(1) && !reductionSubDims.at(1).empty()) { + additionalViewsA = reorderReductionDims( + toReductionSplitA, reductionSubDims.at(1), "M", m, mPerBlock); + } + BottomUpTMBuilder toReductionSplitB(builder, {"G", "K", "N"}, {g, k, n}); + ArrayAttr additionalViewsB = builder.getArrayAttr({}); + if (reductionSubDims.contains(2) && !reductionSubDims.at(2).empty()) { + additionalViewsB = reorderReductionDims( + toReductionSplitB, reductionSubDims.at(2), "N", n, nPerBlock); + } + return {additionalViewsA, additionalViewsB}; +} + +ArrayAttr generateShuffledGemmOutputViews( + OpBuilder &builder, int64_t g, int64_t m, int64_t mPerBlock, int64_t n, + int64_t nPerBlock, + const llvm::SmallDenseMap> + &reductionSubDims) { + // Split the reduction and non-reduction splits + int64_t totalReductionSizeM = 1; + if (reductionSubDims.contains(1)) { + for (const SubDimInfo &sdInfo : reductionSubDims.at(1)) { + totalReductionSizeM *= sdInfo.size; + } + } + int64_t totalReductionSizeN = 1; + if (reductionSubDims.contains(2)) { + for (const SubDimInfo &sdInfo : reductionSubDims.at(2)) { + totalReductionSizeN *= sdInfo.size; + } + } + int64_t commonMPerBlockReductionFactor = + math_util::gcd(totalReductionSizeM, mPerBlock); + int64_t commonNPerBlockReductionFactor = + math_util::gcd(totalReductionSizeN, nPerBlock); + + // Split the reduction and non-reduction splits + BottomUpTMBuilder toReductionSplit(builder, {"G", "M", "N"}, {g, m, n}); + { + toReductionSplit.passThrough("G"); + toReductionSplit.unmerge( + {"m_rh", "m_nr", "m_rl"}, {1, 2, 3}, "M", + {totalReductionSizeM / commonMPerBlockReductionFactor, + m / totalReductionSizeM, commonMPerBlockReductionFactor}); + toReductionSplit.unmerge( + {"n_rh", "n_nr", "n_rl"}, {4, 5, 6}, "N", + {totalReductionSizeN / commonNPerBlockReductionFactor, + n / totalReductionSizeN, commonNPerBlockReductionFactor}); + } + 
TransformMapAttr reductionSplit = toReductionSplit.get(); + LLVM_DEBUG(llvm::dbgs() << "reductionSplit = " << reductionSplit << "\n"); + + // combine reduction dimension + auto toCombinedReductionDim = + BottomUpTMBuilder::above(toReductionSplit, reductionSplit); + { + toCombinedReductionDim.passThrough("G"); + toCombinedReductionDim.passThrough({1}, {2}); + toCombinedReductionDim.merge("m_r", 2, {"m_rh", "m_rl"}); + toCombinedReductionDim.passThrough({3}, {5}); + toCombinedReductionDim.merge("n_r", 4, {"n_rh", "n_rl"}); + } + TransformMapAttr combinedReduction = toCombinedReductionDim.get(); + LLVM_DEBUG(llvm::dbgs() << "combinedReduction = " << combinedReduction + << "\n"); + + // Split to original sub dimensions + auto toSplitOriginalSubDims = + BottomUpTMBuilder::above(toCombinedReductionDim, combinedReduction); + int64_t nSubDimStartPoint = -1; + { + toSplitOriginalSubDims.passThrough("G"); + SmallVector mReductionSubDimInfo; + if (reductionSubDims.contains(1)) { + mReductionSubDimInfo = reductionSubDims.at(1); + } + SmallVector nReductionSubDimInfo; + if (reductionSubDims.contains(2)) { + nReductionSubDimInfo = reductionSubDims.at(2); + } + + unsigned dimInsertionPoint = 1; + { + SmallVector mReductionSubDims; + SmallVector mReductionSubDimSizes; + SmallVector> mReductionSubDimNames; + SmallVector mReductionSubDimNameRefs; + + SmallVector mNonReductionSubDims; + SmallVector mNonReductionSubDimSizes; + SmallVector> mNonReductionSubDimNames; + SmallVector mNonReductionSubDimNameRefs; + int64_t currSize = m; + for (const auto &[idx, sdInfo] : enumerate(mReductionSubDimInfo)) { + mNonReductionSubDimSizes.push_back(currSize / + (sdInfo.size * sdInfo.stride)); + { + SmallString<8> dimName(Twine("m_nr" + Twine(idx)).str()); + mNonReductionSubDimNames.push_back(dimName); + } + mNonReductionSubDimNameRefs.push_back(mNonReductionSubDimNames.back()); + mNonReductionSubDims.push_back(dimInsertionPoint++); + + mReductionSubDimSizes.push_back(sdInfo.size); + { + SmallString<8> dimName(Twine("m_r" + Twine(idx)).str()); + mReductionSubDimNames.push_back(dimName); + } + mReductionSubDimNameRefs.push_back(mReductionSubDimNames.back()); + mReductionSubDims.push_back(dimInsertionPoint++); + + currSize = sdInfo.stride; + } + if (currSize > 1 || mNonReductionSubDimSizes.empty()) { + mNonReductionSubDimSizes.push_back(currSize); + { + SmallString<8> dimName("m_nr_last"); + mNonReductionSubDimNames.push_back(dimName); + } + mNonReductionSubDimNameRefs.push_back(mNonReductionSubDimNames.back()); + mNonReductionSubDims.push_back(dimInsertionPoint++); + } + toSplitOriginalSubDims.unmerge(mNonReductionSubDimNameRefs, + mNonReductionSubDims, "m_nr", + mNonReductionSubDimSizes); + if (!mReductionSubDimSizes.empty()) { + toSplitOriginalSubDims.unmerge(mReductionSubDimNameRefs, + mReductionSubDims, "m_r", + mReductionSubDimSizes); + } else { + toSplitOriginalSubDims.passThrough({"m_r"}, {dimInsertionPoint++}, + {"m_r"}); + } + } + nSubDimStartPoint = dimInsertionPoint; + + { + SmallVector nReductionSubDims; + SmallVector nReductionSubDimSizes; + SmallVector> nReductionSubDimNames; + SmallVector nReductionSubDimNameRefs; + + SmallVector nNonReductionSubDims; + SmallVector nNonReductionSubDimSizes; + SmallVector> nNonReductionSubDimNames; + SmallVector nNonReductionSubDimNameRefs; + int64_t currSize = n; + for (const auto &[idx, sdInfo] : enumerate(nReductionSubDimInfo)) { + nNonReductionSubDimSizes.push_back(currSize / + (sdInfo.size * sdInfo.stride)); + { + SmallString<8> dimName(Twine("n_nr" + Twine(idx)).str()); 
+ nNonReductionSubDimNames.push_back(dimName); + } + nNonReductionSubDimNameRefs.push_back(nNonReductionSubDimNames.back()); + nNonReductionSubDims.push_back(dimInsertionPoint++); + + nReductionSubDimSizes.push_back(sdInfo.size); + { + SmallString<8> dimName(Twine("n_r" + Twine(idx)).str()); + nReductionSubDimNames.push_back(dimName); + } + nReductionSubDimNameRefs.push_back(nReductionSubDimNames.back()); + nReductionSubDims.push_back(dimInsertionPoint++); + + currSize = sdInfo.stride; + } + if (currSize > 1 || nNonReductionSubDimSizes.empty()) { + nNonReductionSubDimSizes.push_back(currSize); + { + SmallString<8> dimName("n_nr_last"); + nNonReductionSubDimNames.push_back(dimName); + } + nNonReductionSubDimNameRefs.push_back(nNonReductionSubDimNames.back()); + nNonReductionSubDims.push_back(dimInsertionPoint++); + } + toSplitOriginalSubDims.unmerge(nNonReductionSubDimNameRefs, + nNonReductionSubDims, "n_nr", + nNonReductionSubDimSizes); + if (!nReductionSubDimSizes.empty()) { + toSplitOriginalSubDims.unmerge(nReductionSubDimNameRefs, + nReductionSubDims, "n_r", + nReductionSubDimSizes); + } else { + toSplitOriginalSubDims.passThrough({"n_r"}, {dimInsertionPoint++}, + {"n_r"}); + } + } + } + TransformMapAttr splitOriginalSubDims = toSplitOriginalSubDims.get(); + LLVM_DEBUG(llvm::dbgs() << "splitOriginalSubDims = " << splitOriginalSubDims + << "\n"); + + // Recombine into original M & N + auto toRecombineMN = + BottomUpTMBuilder::above(toSplitOriginalSubDims, splitOriginalSubDims); + { + toRecombineMN.passThrough("G"); + SmallVector startNames; + toRecombineMN.getStartNames(startNames); + + // M + { + SmallVector mSubDimNames; + for (int dim = 1; dim < nSubDimStartPoint; dim++) { + mSubDimNames.push_back(startNames[dim]); + } + toRecombineMN.merge("M", 1, mSubDimNames); + } + + // N + { + SmallVector nSubDimNames; + for (unsigned dim = nSubDimStartPoint; dim < startNames.size(); dim++) { + nSubDimNames.push_back(startNames[dim]); + } + toRecombineMN.merge("N", 2, nSubDimNames); + } + } + TransformMapAttr recombineMN = toRecombineMN.get(); + LLVM_DEBUG(llvm::dbgs() << "recombineMN = " << recombineMN << "\n"); + + return builder.getArrayAttr( + {reductionSplit, combinedReduction, splitOriginalSubDims, recombineMN}); +} + +// This function will attempt to shuffle M & N dimensions of the gemm so that +// reductions sub-dimensions within it could be split to blocks equally. +// However, for that to work the transform stack needs to be invertible and +// sub-dimension should be discoverable using "getLowerSubDimensions". If one of +// those fail, we bail and not attempt to use blockwise_reductions in such +// fusions. 
+static LogicalResult +rearrangeGemmParallelDimsForReduction(ReduceOp rOp, + const BufferDependencyAnalysis &deps) { + FailureOr> maybeViewsAndGemmOp = + obtainGemmToReduceViews(rOp, deps); + if (succeeded(maybeViewsAndGemmOp)) { + auto [views, gemmOp] = maybeViewsAndGemmOp.value(); + LLVM_DEBUG(llvm::dbgs() << "gemmToReduceViews=" << views << "\n"); + FailureOr mnPerBlock = getMNPerBlock(gemmOp); + if (failed(mnPerBlock)) { + LLVM_DEBUG(llvm::dbgs() + << "m/n per block extraction failed from gemm op.\n"); + return failure(); + } + IRRewriter rewriter(rOp.getContext()); + ArrayAttr invertedViews = invertTransforms(rewriter, rOp.getLoc(), views); + LLVM_DEBUG(llvm::dbgs() + << "inv(gemmToReduceViews)=" << invertedViews << "\n"); + if (!invertedViews || invertedViews.empty()) { + LLVM_DEBUG(llvm::dbgs() << "gemm to reduce view inversion failed.\n"); + return failure(); + } + FailureOr>> + reductionSubDimsinGemmSpace = getLowerSubDimensions( + rewriter, invertedViews, rOp.getAxis().getZExtValue()); + if (failed(reductionSubDimsinGemmSpace) || + reductionSubDimsinGemmSpace.value().empty()) { + LLVM_DEBUG(llvm::dbgs() + << "reduce to gemm lower subdimension tracing failed.\n"); + return failure(); + } + for (auto [dim, subDimInfos] : reductionSubDimsinGemmSpace.value()) { + LLVM_DEBUG(llvm::dbgs() << "dim=" << dim << ":"); + LLVM_DEBUG(llvm::interleaveComma(subDimInfos, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + } + + TypedValue gemmInA; + TypedValue gemmInB; + TypedValue gemmOut; + if (GridwiseGemmAccelOp gemmAccelOp = + dyn_cast(gemmOp)) { + gemmInA = gemmAccelOp.getA(); + gemmInB = gemmAccelOp.getB(); + gemmOut = gemmAccelOp.getC(); + } else if (GridwiseGemmOp gemmNonAccelOp = + dyn_cast(gemmOp)) { + gemmInA = gemmNonAccelOp.getA(); + gemmInB = gemmNonAccelOp.getB(); + gemmOut = gemmNonAccelOp.getC(); + } else { + LLVM_DEBUG(llvm::dbgs() << "unsupported op:" << *gemmOp << "\n"); + return failure(); + } + int64_t g = gemmInA.getType().getShape()[0]; + int64_t k = gemmInA.getType().getShape()[1]; + int64_t m = gemmInA.getType().getShape()[2]; + int64_t n = gemmInB.getType().getShape()[2]; + auto [additionalViewsA, additionalViewsB] = generateShuffledGemmInputViews( + rewriter, g, m, mnPerBlock.value().MPerBlock, k, n, + mnPerBlock.value().NPerBlock, reductionSubDimsinGemmSpace.value()); + Value trGemmInA = gemmInA; + rewriter.setInsertionPointAfterValue(gemmInA); + for (Attribute trMap : additionalViewsA) { + trGemmInA = rewriter.create(rOp.getLoc(), trGemmInA, + cast(trMap)); + } + Value trGemmInB = gemmInB; + rewriter.setInsertionPointAfterValue(gemmInB); + for (Attribute trMap : additionalViewsB) { + trGemmInB = rewriter.create(rOp.getLoc(), trGemmInB, + cast(trMap)); + } + ArrayAttr additionalOutputViews = generateShuffledGemmOutputViews( + rewriter, g, m, mnPerBlock.value().MPerBlock, n, + mnPerBlock.value().NPerBlock, reductionSubDimsinGemmSpace.value()); + rewriter.setInsertionPointAfterValue(gemmOut); + Value trGemmOut = gemmOut; + ArrayAttr invertedOutViews = + invertTransforms(rewriter, rOp.getLoc(), additionalOutputViews); + for (Attribute trMap : invertedOutViews) { + trGemmOut = rewriter.create(rOp.getLoc(), trGemmOut, + cast(trMap)); + } + if (GridwiseGemmAccelOp gemmAccelOp = + dyn_cast(gemmOp)) { + gemmAccelOp.getAMutable().assign(trGemmInA); + gemmAccelOp.getBMutable().assign(trGemmInB); + gemmAccelOp.getCMutable().assign(trGemmOut); + } else if (GridwiseGemmOp gemmNonAccelOp = + dyn_cast(gemmOp)) { + gemmNonAccelOp.getAMutable().assign(trGemmInA); + 
gemmNonAccelOp.getBMutable().assign(trGemmInB); + gemmNonAccelOp.getCMutable().assign(trGemmOut); + } else { + LLVM_DEBUG(llvm::dbgs() << "unsupported op:" << *gemmOp << "\n"); + return failure(); + } + } else { + LLVM_DEBUG(llvm::dbgs() << "failed to obtain gemm to reduce views.\n"); + return failure(); + } + return success(); +} + +void RockShuffleGemmForReductionsPass::runOnOperation() { + func::FuncOp func = getOperation(); + // Only run this pass on GPU kernel functions. + if (!func->hasAttr("kernel")) { + return; + } + ReduceOp largestReductionOp; + int64_t currReductionDimSize = 0; + func.walk([&](ReduceOp rOp) -> WalkResult { + TypedValue rIn = rOp.getIn(); + int64_t reduceDimSize = + rIn.getType().getShape()[rOp.getAxis().getZExtValue()]; + if (reduceDimSize > currReductionDimSize) { + largestReductionOp = rOp; + } + return WalkResult::advance(); + }); + if (largestReductionOp) { + auto &bufferDeps = getAnalysis(); + LogicalResult res = + rearrangeGemmParallelDimsForReduction(largestReductionOp, bufferDeps); + if (failed(res)) { + LLVM_DEBUG( + llvm::dbgs() + << "unable to shuffle the gemm dims for blockwise reductions.\n"); + } + } +} diff --git a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp index bd4442049442..eec87ee51444 100644 --- a/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp @@ -81,6 +81,10 @@ PopulateParamsInfo PopulateParamsInfo::fromOp(RockGemmWrapperInterface op) { info.numCu = convOp.getNumCU(); info.batchSize = convDims.n; } + func::FuncOp func = op->getParentOfType(); + WalkResult wRes = func.walk( + [&](ReduceOp rOp) -> WalkResult { return WalkResult::interrupt(); }); + info.hasFusedReduction = wRes.wasInterrupted(); return info; } @@ -211,12 +215,44 @@ PopulateParams::paramsProbablyValid(OpBuilder &b, return populateDerived(params); } +static LogicalResult couldFusedReductionBePerformant(const GemmSize &gemmSize, + int64_t mPerBlock, + int64_t nPerBlock) { + // 16 is practically lowest m in MFMAs/WMMAs + // that could be performant. If the gemm sizes + // are not divisible by that, then we definitely + // need padding. Therefore, it can't use blockwise + // reductions. + + // Thus, it becomes a competition among + // atomic_store based reduction kernels. + // So basically, all configs could be performant relative to each other. + if (gemmSize.m % 16 != 0) { + return success(); + } + if (gemmSize.n % 16 != 0) { + return success(); + } + // We can skip knowing that dPerBlock=16 + // is there on the tuning space that should + // be faster than anyone that use m or n + // padding. + if (gemmSize.m % mPerBlock != 0) { + return failure(); + } + if (gemmSize.n % nPerBlock != 0) { + return failure(); + } + return success(); +} + LogicalResult PopulateParams::couldBePerformant(const PopulateParamsInfo &info, const InitParamsNonAccel ¶ms) { - // Implement this if needed. 
-  (void)info;
-  (void)params;
+  if (info.hasFusedReduction) {
+    return couldFusedReductionBePerformant(info.gemmSize, params.gemmMPerBlock,
+                                           params.gemmNPerBlock);
+  }
   return success();
 }
 
@@ -338,6 +374,10 @@ PopulateParamsAccel::paramsProbablyValid(OpBuilder &b,
 LogicalResult
 PopulateParamsAccel::couldBePerformant(const PopulateParamsInfo &info,
                                        const InitParamsAccel &params) {
+  if (info.hasFusedReduction) {
+    return couldFusedReductionBePerformant(info.gemmSize, params.gemmMPerBlock,
+                                           params.gemmNPerBlock);
+  }
   return specificCouldBePerformant(params, info.gemmAType, info.gemmBType);
 }
 
diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
index 092344905fb3..b483e0770741 100644
--- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
+++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
@@ -364,7 +364,8 @@ void createQuickTuningRange(TuningParamSet *newSpace,
            tuningInfo.getTuningParameters(info.kernelType, info.gemmAType,
                                           info.gemmBType, info.arch),
            info.gemmSize)) {
-      if (succeeded(tuningInfo.paramsProbablyValid(b, info, param)))
+      if (succeeded(tuningInfo.paramsProbablyValid(b, info, param)) &&
+          succeeded(tuningInfo.couldBePerformant(info, param)))
         newSpace->tuningRange.push_back(cast(
             tuningInfo.getGemmParamsAttr(b, param)));
     }
@@ -375,7 +376,8 @@ void createQuickTuningRange(TuningParamSet *newSpace,
            tuningInfo.getTuningParameters(info.kernelType, info.gemmAType,
                                           info.gemmBType, info.arch),
            info.gemmSize)) {
-      if (succeeded(tuningInfo.paramsProbablyValid(b, info, param)))
+      if (succeeded(tuningInfo.paramsProbablyValid(b, info, param)) &&
+          succeeded(tuningInfo.couldBePerformant(info, param)))
         newSpace->tuningRange.push_back(cast(
             tuningInfo.getGemmParamsAttr(b, param)));
     }
@@ -386,7 +388,8 @@ void createQuickTuningRange(TuningParamSet *newSpace,
            tuningInfo.getTuningParameters(info.kernelType, info.gemmAType,
                                           info.gemmBType),
            info.gemmSize)) {
-      if (succeeded(tuningInfo.paramsProbablyValid(b, info, param)))
+      if (succeeded(tuningInfo.paramsProbablyValid(b, info, param)) &&
+          succeeded(tuningInfo.couldBePerformant(info, param)))
         newSpace->tuningRange.push_back(cast(
             tuningInfo.getGemmParamsAttr(b, param)));
     }
diff --git a/mlir/lib/Dialect/Rock/utility/loweringUtils.cpp b/mlir/lib/Dialect/Rock/utility/loweringUtils.cpp
index 3900e4b4e330..f1b5f0774a2f 100644
--- a/mlir/lib/Dialect/Rock/utility/loweringUtils.cpp
+++ b/mlir/lib/Dialect/Rock/utility/loweringUtils.cpp
@@ -775,3 +775,9 @@ Value mlir::rock::gpuAlloc(OpBuilder &b, Location loc, int64_t bufferDim,
 
   return viewBufferAs(b, buffer, elementType);
 }
+
+LogicalResult mlir::rock::checkLDSSize(StringAttr arch, int64_t ldsBytes) {
+  // Check whether the arch's LDS capacity would be exceeded.
+  const int64_t ldsSize = rock::lookupArchInfo(arch).maxSharedMemPerWG;
+  return success(ldsBytes <= ldsSize);
+}
diff --git a/mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp b/mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp
index 9216fac4e780..5b9c7097b1d1 100644
--- a/mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp
+++ b/mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp
@@ -1884,11 +1884,6 @@ void remapDims(std::vector &argsVector,
   }
 }
 
-struct SubDimInfo {
-  int64_t size;
-  int64_t stride;
-};
-
 static SmallVector<int64_t> getStrides(ArrayRef<int64_t> dimLens) {
   SmallVector<int64_t> ret{1};
   for (int64_t dimLen : llvm::reverse(dimLens)) {
@@ -1993,15 +1988,27 @@ static FailureOr removeUpperDimsFromMap(
     assert(preservedLowerDims.size() == 1);
     SmallVector<int64_t> subDimStrides = getStrides(tr.getParams());
     // Collect all removedSubDims in
upper to the lower dim + SetVector alreadyAddedStrides; for (auto [upperDim, subDimStride] : zip(tr.getUpperDims(), subDimStrides)) { - for (const SubDimInfo &remSubDimInfo : removedSubDims[upperDim]) { - LLVM_DEBUG(llvm::dbgs() << "creating newRemovedSubDim /w size = " - << remSubDimInfo.size << ", stride=" - << remSubDimInfo.stride * subDimStride - << " @ " << preservedLowerDims[0] << "\n"); - newRemovedSubDims[preservedLowerDims[0]].push_back( - {remSubDimInfo.size, remSubDimInfo.stride * subDimStride}); + if (removedSubDims.contains(upperDim)) { + LLVM_DEBUG(llvm::dbgs() << "remSubDimInfo.size = " + << removedSubDims[upperDim].size() << "\n"); + for (const SubDimInfo &remSubDimInfo : removedSubDims[upperDim]) { + int64_t newStride = remSubDimInfo.stride * subDimStride; + if (alreadyAddedStrides.contains(newStride)) + continue; + alreadyAddedStrides.insert(newStride); + if (remSubDimInfo.size > 1) { + LLVM_DEBUG(llvm::dbgs() + << "1:creating newRemovedSubDim /w size = " + << remSubDimInfo.size << ", stride=" << newStride + << ",upperdim=" << upperDim << " " + << " @ " << preservedLowerDims[0] << "\n"); + newRemovedSubDims[preservedLowerDims[0]].push_back( + {remSubDimInfo.size, newStride}); + } + } } } SetVector removedDimsInTr = @@ -2009,12 +2016,14 @@ static FailureOr removeUpperDimsFromMap( for (auto [idx, subDimSize] : enumerate(tr.getParams())) { int64_t upperDim = tr.getUpperDims()[idx]; if (removedDimsInTr.contains(upperDim)) { - LLVM_DEBUG(llvm::dbgs() - << "creating newRemovedSubDim /w size = " << subDimSize - << ", stride=" << subDimStrides[idx] << " @ " - << preservedLowerDims[0] << "\n"); - newRemovedSubDims[preservedLowerDims[0]].push_back( - {subDimSize, subDimStrides[idx]}); + if (subDimSize > 1) { + LLVM_DEBUG(llvm::dbgs() + << "2:creating newRemovedSubDim /w size = " + << subDimSize << ", stride=" << subDimStrides[idx] + << " @ " << preservedLowerDims[0] << "\n"); + newRemovedSubDims[preservedLowerDims[0]].push_back( + {subDimSize, subDimStrides[idx]}); + } } } uint32_t total = 1; @@ -2052,7 +2061,7 @@ static FailureOr removeUpperDimsFromMap( LLVM_DEBUG(llvm::dbgs() << "The relative stride of removed subDim is larger " "than original subDim\n"); - } else if (removedSubDimInfo.stride * removedSubDimInfo.size < + } else if (removedSubDimInfo.stride * removedSubDimInfo.size <= subDimStrides[subDim]) { // do nothing LLVM_DEBUG(llvm::dbgs() @@ -2068,32 +2077,58 @@ static FailureOr removeUpperDimsFromMap( int diff = 0; int newRemovedSubDimStride = 0; // Overlap on right side of removedSubDim + int64_t maxStrideSubDim = + subDimStrides[subDim] * tr.getParams()[subDim]; if (removedSubDimInfo.stride * removedSubDimInfo.size >= - subDimStrides[subDim] * tr.getParams()[subDim]) { + maxStrideSubDim) { int64_t rhsBoundForRemoval = std::max(removedSubDimInfo.stride, subDimStrides[subDim]); - diff = (subDimStrides[subDim] * tr.getParams()[subDim]) / - rhsBoundForRemoval; + if (maxStrideSubDim % rhsBoundForRemoval != 0) { + LLVM_DEBUG( + llvm::dbgs() + << "non divisible subDim removal found. 
aborting..\n"); + return failure(); + } + diff = maxStrideSubDim / rhsBoundForRemoval; newRemovedSubDimStride = rhsBoundForRemoval / subDimStrides[subDim]; } // The whole of removedSubDim is within the newly created lowDim else if (removedSubDimInfo.stride >= subDimStrides[subDim]) { diff = removedSubDimInfo.size; - newRemovedSubDimStride = removedSubDimInfo.stride; + newRemovedSubDimStride = + removedSubDimInfo.stride / subDimStrides[subDim]; } // Overlap is left side of removedSubDim else { - diff = (removedSubDimInfo.stride * removedSubDimInfo.size) / - subDimStrides[subDim]; + int64_t maxStrideRemovedSubDim = + removedSubDimInfo.stride * removedSubDimInfo.size; + if (maxStrideRemovedSubDim % subDimStrides[subDim] != 0) { + LLVM_DEBUG( + llvm::dbgs() + << "non divisible subDim removal found. aborting..\n"); + return failure(); + } + diff = maxStrideRemovedSubDim / subDimStrides[subDim]; newRemovedSubDimStride = 1; } - LLVM_DEBUG(llvm::dbgs() - << "creating newRemovedSubDim /w size = " << diff - << ", stride=" << newRemovedSubDimStride << " @ " - << lowDim << "\n"); - newRemovedSubDims[lowDim].push_back( - {diff, newRemovedSubDimStride}); + if (diff > 1) { + LLVM_DEBUG(llvm::dbgs() + << "creating newRemovedSubDim /w size = " << diff + << ", stride=" << newRemovedSubDimStride << " @ " + << lowDim << "\n"); + newRemovedSubDims[lowDim].push_back( + {diff, newRemovedSubDimStride}); + } + LLVM_DEBUG(llvm::dbgs() << "origLowerBounds[lowDim]=" + << origLowerBounds[lowDim] << "\n"); + LLVM_DEBUG(llvm::dbgs() << "diff=" << diff << "\n"); + if (origLowerBounds[lowDim] % diff != 0) { + LLVM_DEBUG( + llvm::dbgs() + << "non divisible subDim removal found. aborting..\n"); + return failure(); + } origLowerBounds[lowDim] = origLowerBounds[lowDim] / diff; } } @@ -2110,9 +2145,17 @@ static FailureOr removeUpperDimsFromMap( origLowerBounds[lowerDim] = origUpperBounds[upperDim]; } for (auto [dim, subDimInfo] : removedSubDims) { - LLVM_DEBUG(llvm::dbgs() << "copying removedSubDimInfo from:" << dim - << " to:" << upperToLower[dim] << "\n"); - newRemovedSubDims[upperToLower[dim]] = subDimInfo; + if (upperToLower.contains(dim)) { + LLVM_DEBUG(llvm::dbgs() << "copying removedSubDimInfo from:" << dim + << " to:" << upperToLower[dim] << "\n"); + for (const auto &sdIndo : subDimInfo) { + LLVM_DEBUG(llvm::dbgs() + << "\tcreating newRemovedSubDim /w size = " + << sdIndo.size << ", stride=" << sdIndo.stride << " @ " + << upperToLower[dim] << "\n"); + } + newRemovedSubDims[upperToLower[dim]] = subDimInfo; + } } llvm::copy(tr.getParams(), std::back_inserter(args.params)); break; @@ -2132,6 +2175,15 @@ static FailureOr removeUpperDimsFromMap( const auto lowerDim = tr.getLowerDims()[idx]; origLowerBounds[lowerDim] = origUpperBounds[upperDim]; args.params.append({0, 0}); + LLVM_DEBUG(llvm::dbgs() << "copying removedSubDimInfo from:" + << upperDim << " to:" << lowerDim << "\n"); + for (const auto &sdIndo : removedSubDims[upperDim]) { + LLVM_DEBUG(llvm::dbgs() + << "\tcreating newRemovedSubDim /w size = " + << sdIndo.size << ", stride=" << sdIndo.stride << " @ " + << lowerDim << "\n"); + } + newRemovedSubDims[lowerDim] = removedSubDims[upperDim]; } else { args.params.append( {tr.getParams()[idx * 2], tr.getParams()[idx * 2 + 1]}); @@ -2262,7 +2314,7 @@ mlir::rock::removeUpperDims(OpBuilder &b, ArrayAttr transformAttrs, SmallVector results; llvm::SmallVector upperBounds = {}; - llvm::SmallDenseMap> preservedSubDims; + llvm::SmallDenseMap> removedSubDims; if (!transformAttrs.empty()) { auto first = *(transformAttrs.begin()); 
auto trMap = cast(first); @@ -2278,7 +2330,7 @@ mlir::rock::removeUpperDims(OpBuilder &b, ArrayAttr transformAttrs, llvm::SmallVector lowerBounds = {}; FailureOr maybeNewTrMapAttr = removeUpperDimsFromMap(b, trMap, removeIndicesSet, upperBounds, - lowerBounds, preservedSubDims); + lowerBounds, removedSubDims); upperBounds = lowerBounds; if (failed(maybeNewTrMapAttr)) { return failure(); @@ -2321,3 +2373,171 @@ mlir::rock::removeUpperDims(OpBuilder &b, ArrayAttr transformAttrs, convertDimNamesToIndices(transformAttrs, removeDimNamesSet); return removeUpperDims(b, transformAttrs, removeIndicesSet); } + +FailureOr>> +mlir::rock::getLowerSubDimensions(OpBuilder &b, ArrayAttr transformAttrs, + int64_t dim) { + return getLowerSubDimensions(b, transformAttrs, ArrayRef{dim}); +} + +FailureOr>> +mlir::rock::getLowerSubDimensions(OpBuilder &b, ArrayAttr transformAttrs, + ArrayRef dims) { + llvm::SmallDenseMap> subDimInfo; + if (transformAttrs.empty()) { + LLVM_DEBUG(llvm::dbgs() << "transformAttrs is empty.\n"); + return failure(); + } + TransformMapAttr topMap = cast(transformAttrs[0]); + for (int64_t dim : dims) { + // No point of tracing size 1 dimensions + if (topMap.getUpperBounds()[dim] == 1) { + continue; + } + LLVM_DEBUG(llvm::dbgs() + << "creating subDim of @ " << dim << "\n"); + subDimInfo[dim].push_back({topMap.getUpperBounds()[dim], 1}); + } + + for (TransformMapAttr trMap : transformAttrs.getAsRange()) { + LLVM_DEBUG(llvm::dbgs() << "analyzing trMap:" << trMap << "\n"); + // local function to update the next subdim info + auto getNextSubDimInfo = + [&trMap](const llvm::SmallDenseMap> + &currSubDimInfo) + -> FailureOr>> { + llvm::SmallDenseMap> nextSubDimInfo; + for (TransformAttr trAttr : trMap.getOps()) { + switch (trAttr.getType()) { + case TransformType::PassThrough: { + llvm::SmallDenseMap upperToLower; + for (auto [idx, upperDim] : llvm::enumerate(trAttr.getUpperDims())) { + const auto lowerDim = trAttr.getLowerDims()[idx]; + LLVM_DEBUG(llvm::dbgs() << "pt:upper=" << upperDim + << ",lower=" << lowerDim << "\n"); + upperToLower[upperDim] = lowerDim; + }; + for (auto [dim, subDimInfo] : currSubDimInfo) { + if (upperToLower.contains(dim)) { + LLVM_DEBUG(llvm::dbgs() << "remapping:" << dim << " to " + << upperToLower[dim] << "\n"); + nextSubDimInfo[upperToLower[dim]] = subDimInfo; + } + } + } break; + case TransformType::Merge: { + SmallVector subDimStrides = getStrides(trAttr.getParams()); + int64_t upperDim = trAttr.getUpperDims()[0]; + for (auto [lowDim, subDimStride, param] : llvm::zip( + trAttr.getLowerDims(), subDimStrides, trAttr.getParams())) { + if (currSubDimInfo.contains(upperDim)) { + for (const SubDimInfo &sdInfo : currSubDimInfo.at(upperDim)) { + if (sdInfo.stride >= subDimStride * param) { + LLVM_DEBUG(llvm::dbgs() + << "No overlap: stride of analyzed dim is larger " + "than new subdim stride.\n"); + } else if (sdInfo.stride * sdInfo.size <= subDimStride) { + LLVM_DEBUG(llvm::dbgs() + << "No overlap: stride of new subdim stride is " + "larger than the analyzed dim.\n"); + } else { + // New sizes and strides for newly annotated subdims + int64_t newSize; + int64_t newStride; + int64_t maxStrideSubDim = subDimStride * param; + // Overlap on the right side of annotated subdim + if (sdInfo.stride * sdInfo.size >= maxStrideSubDim) { + int64_t rhsBoundForAnnotate = + std::max(sdInfo.stride, subDimStride); + newSize = maxStrideSubDim / rhsBoundForAnnotate; + newStride = rhsBoundForAnnotate / subDimStride; + } + // The whole of annotatedSubDim is within the newly created + // 
lowDim + else if (sdInfo.stride >= subDimStride) { + newSize = sdInfo.size; + newStride = sdInfo.stride / subDimStride; + } + // Overlap on the left side of annotated subdim + else { + int64_t maxStrideRemovedSubDim = + sdInfo.stride * sdInfo.size; + newSize = maxStrideRemovedSubDim / subDimStride; + newStride = 1; + } + LLVM_DEBUG(llvm::dbgs() << "creating subDim of @ " << lowDim << "\n"); + if (newSize > 1) { + nextSubDimInfo[lowDim].push_back({newSize, newStride}); + } + } + } + } + } + } break; + case TransformType::Unmerge: { + SmallVector subDimStrides = getStrides(trAttr.getParams()); + int64_t lowDim = trAttr.getLowerDims()[0]; + for (size_t subDim = 0; subDim < trAttr.getParams().size(); + subDim++) { + int64_t upperDim = trAttr.getUpperDims()[subDim]; + if (currSubDimInfo.contains(upperDim)) { + for (const SubDimInfo &sdInfo : currSubDimInfo.at(upperDim)) { + int64_t newStride = sdInfo.stride * subDimStrides[subDim]; + LLVM_DEBUG(llvm::dbgs() + << "creating subDim of @ " << lowDim + << "\n"); + if (sdInfo.size > 1) { + nextSubDimInfo[lowDim].push_back({sdInfo.size, newStride}); + } + } + } + } + break; + } + case TransformType::ConstDim: + case TransformType::AddDim: { + // Nothing to do + break; + } + default: + LLVM_DEBUG(llvm::dbgs() + << "Unsupported transform type : " << trAttr << "\n"); + return failure(); + } + } + return nextSubDimInfo; + }; + FailureOr>> + nextSubDimInfo = getNextSubDimInfo(subDimInfo); + if (failed(nextSubDimInfo)) + return failure(); + subDimInfo = nextSubDimInfo.value(); + } + if (subDimInfo.empty()) { + return failure(); + } + return subDimInfo; +} + +SmallVector> mlir::rock::createDimNames(int64_t len, + StringRef prefix) { + SmallVector> names; + for (unsigned d = 0; d < len; d++) { + SmallString<8> dimName(prefix.str() + Twine(d).str()); + names.push_back(dimName); + } + return names; +} + +SmallVector +mlir::rock::getStringRefsFor(ArrayRef> strings) { + SmallVector nameRefs; + for (const SmallString<8> &str : strings) { + nameRefs.push_back(str); + } + return nameRefs; +} diff --git a/mlir/test/Dialect/Rock/rock-shuffle-gemm-for-reductions.mlir b/mlir/test/Dialect/Rock/rock-shuffle-gemm-for-reductions.mlir new file mode 100644 index 000000000000..4dee78871737 --- /dev/null +++ b/mlir/test/Dialect/Rock/rock-shuffle-gemm-for-reductions.mlir @@ -0,0 +1,76 @@ +// RUN: rocmlir-opt -rock-shuffle-gemm-for-reductions -mlir-print-local-scope -rock-gemm-to-gridwise %s | FileCheck %s + +// CHECK-LABEL: @mlir_convolution_multi_reduce +func.func @mlir_convolution_multi_reduce(%arg0: memref<320xf32>, %arg1: memref<32768xf32>, %arg2: memref<11520xf32>, %arg3: memref<64xf32> {mhal.read_access, rock.prefill = 0.000000e+00 : f32}, %arg4: memref<64xf32>, %arg5: memref<2621440xf32>) attributes {arch = "gfx942:sramecc+:xnack-", block_size = 256 : i32, grid_size = 320 : i32, kernel = "mixr"} { + %cst = arith.constant 2.44140629E-5 : f32 + %0 = rock.transform %arg0 by (d0 * 10 + d1 + d2 + d3 + d4)> by [ ["dim0"] at [0]>] bounds = [32, 10, 1, 1, 1] -> [320]> : memref<320xf32> to memref<32x10x1x1x1xf32> + %1 = rock.transform %0 by (d1, d2, d3, d4, d0)> by [ ["dim4", "dim0", "dim1", "dim2", "dim3"] at [4, 0, 1, 2, 3]>] bounds = [1, 32, 10, 1, 1] -> [32, 10, 1, 1, 1]> : memref<32x10x1x1x1xf32> to memref<1x32x10x1x1xf32> + %2 = rock.transform %1 by (0, d1, d2, 0, 0)> by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, ["dim3"] at [3]>, ["dim4"] at [4]>] bounds = [2, 32, 10, 64, 64] -> [1, 32, 10, 1, 1]> : memref<1x32x10x1x1xf32> to memref<2x32x10x64x64xf32> + %3 = 
rock.transform %arg2 by (((d0 * 4 + d1) * 3 + d2) * 3 + d3)> by [ ["dim0"] at [0]>] bounds = [320, 4, 3, 3] -> [11520]> : memref<11520xf32> to memref<320x4x3x3xf32> + %4 = rock.transform %arg1 by (((d0 * 4 + d1) * 64 + d2) * 64 + d3)> by [ ["dim0"] at [0]>] bounds = [2, 4, 64, 64] -> [32768]> : memref<32768xf32> to memref<2x4x64x64xf32> + %alloc = memref.alloc() : memref<2x32x10x64x64xf32> + %5 = rock.transform %alloc by (d0, d1 floordiv 10, d1 mod 10, d2, d3)> by [ ["dim0"] at [0]>, ["exp1", "exp2"] at [1, 2]>, ["dim2"] at [3]>, ["dim3"] at [4]>] bounds = [2, 320, 64, 64] -> [2, 32, 10, 64, 64]> : memref<2x32x10x64x64xf32> to memref<2x320x64x64xf32> + %6 = rock.transform %4 by (d0, d1 * 4 + d2, d3, d4)> by [ ["n", "h", "w"] at [0, 2, 3]>, ["c"] at [1]>] bounds = [2, 1, 4, 64, 64] -> [2, 4, 64, 64]> : memref<2x4x64x64xf32> to memref<2x1x4x64x64xf32> + %7 = rock.transform %3 by (d0 * 320 + d1, d2, d3, d4)> by [ ["c", "y", "x"] at [1, 2, 3]>, ["k"] at [0]>] bounds = [1, 320, 4, 3, 3] -> [320, 4, 3, 3]> : memref<320x4x3x3xf32> to memref<1x320x4x3x3xf32> + %8 = rock.transform %5 by (d0, d1 * 320 + d2, d3, d4)> by [ ["n", "h", "w"] at [0, 2, 3]>, ["k"] at [1]>] bounds = [2, 1, 320, 64, 64] -> [2, 320, 64, 64]> : memref<2x320x64x64xf32> to memref<2x1x320x64x64xf32> + // CHECK: %[[GEMM_IN_A:.+]] = rock.transform %{{.+}} by (d0, d2, d1 floordiv 9, (d1 mod 9) floordiv 3, d1 mod 3)> by [ ["g"] at [0]>, ["c", "0", "1"] at [2, 3, 4]>, ["k"] at [1]>] bounds = [1, 36, 320] -> [1, 320, 4, 3, 3]> : memref<1x320x4x3x3xf32> to memref<1x36x320xf32> + // CHECK: %[[GEMM_IN_A_TR0:.+]] = rock.transform %[[GEMM_IN_A]] by (d0, d1, d2 * 10 + d3)> by [ ["G", "K"] at [0, 1]>, ["M"] at [2]>] bounds = [1, 36, 32, 10] -> [1, 36, 320]> : memref<1x36x320xf32> to memref<1x36x32x10xf32> + // CHECK: %[[GEMM_IN_A_TR1:.+]] = rock.transform %[[GEMM_IN_A_TR0]] by (d0, d1, d2, d3)> by [ ["G", "K"] at [0, 1]>, ["d_nr0"] at [2]>, ["d_r0"] at [3]>] bounds = [1, 36, 32, 10] -> [1, 36, 32, 10]> : memref<1x36x32x10xf32> to memref<1x36x32x10xf32> + // CHECK: %[[GEMM_IN_A_TR2:.+]] = rock.transform %[[GEMM_IN_A_TR1]] by (d0, d1, d3, d2 * 2 + d4)> by [ ["G", "K", "d_nr"] at [0, 1, 2]>, ["d_r"] at [3]>] bounds = [1, 36, 5, 32, 2] -> [1, 36, 32, 10]> : memref<1x36x32x10xf32> to memref<1x36x5x32x2xf32> + // CHECK: %[[GEMM_IN_A_TR3:.+]] = rock.transform %[[GEMM_IN_A_TR2]] by (d0, d1, d2 floordiv 64, (d2 mod 64) floordiv 2, d2 mod 2)> by [ ["G", "K"] at [0, 1]>, ["d_rh", "d_nr", "d_rl"] at [2, 3, 4]>] bounds = [1, 36, 320] -> [1, 36, 5, 32, 2]> : memref<1x36x5x32x2xf32> to memref<1x36x320xf32> + %9 = rock.transform %7 by (d0, d2, d1 floordiv 9, (d1 mod 9) floordiv 3, d1 mod 3)> by [ ["g"] at [0]>, ["c", "0", "1"] at [2, 3, 4]>, ["k"] at [1]>] bounds = [1, 36, 320] -> [1, 320, 4, 3, 3]> : memref<1x320x4x3x3xf32> to memref<1x36x320xf32> + %10 = rock.transform %6 by (d0, d1, d2, d3 - 1, d4 - 1)> by [ ["ni"] at [0]>, ["gi"] at [1]>, ["ci"] at [2]>, ["0i", "1i"] at [3, 4]>] bounds = [2, 1, 4, 66, 66] -> [2, 1, 4, 64, 64]> : memref<2x1x4x64x64xf32> to memref<2x1x4x66x66xf32> + %11 = rock.transform %10 by (d0, d1, d2, d3 + d4, d5 + d6)> by [ ["ni", "gi", "ci"] at [0, 1, 2]>, ["0ipad"] at [3]>, ["1ipad"] at [4]>] bounds = [2, 1, 4, 3, 64, 3, 64] -> [2, 1, 4, 66, 66]> : memref<2x1x4x66x66xf32> to memref<2x1x4x3x64x3x64xf32> + // CHECK: %[[GEMM_IN_B:.+]] = rock.transform %{{.+}} by (d2 floordiv 4096, d0, d1 floordiv 9, (d1 mod 9) floordiv 3, (d2 mod 4096) floordiv 64, d1 mod 3, d2 mod 64)> by [ ["gi"] at [1]>, ["ci", "0", "1"] at [2, 3, 5]>, ["ni", "0o", 
"1o"] at [0, 4, 6]>] bounds = [1, 36, 8192] -> [2, 1, 4, 3, 64, 3, 64]> : memref<2x1x4x3x64x3x64xf32> to memref<1x36x8192xf32> + // CHECK: %[[GEMM_IN_B_TR0:.+]] = rock.transform %[[GEMM_IN_B]] by (d0, d1, (d2 * 64 + d3 + d4) * 64 + d5)> by [ ["G", "K"] at [0, 1]>, ["N"] at [2]>] bounds = [1, 36, 2, 64, 1, 64] -> [1, 36, 8192]> : memref<1x36x8192xf32> to memref<1x36x2x64x1x64xf32> + // CHECK: %[[GEMM_IN_B_TR1:.+]] = rock.transform %[[GEMM_IN_B_TR0]] by (d0, d1, d2, d3 floordiv 64, 0, d3 mod 64)> by [ ["G", "K"] at [0, 1]>, ["d_nr0", "d_nr1"] at [2, 4]>, ["d_r0", "d_r1"] at [3, 5]>] bounds = [1, 36, 2, 4096] -> [1, 36, 2, 64, 1, 64]> : memref<1x36x2x64x1x64xf32> to memref<1x36x2x4096xf32> + // CHECK: %[[GEMM_IN_B_TR2:.+]] = rock.transform %[[GEMM_IN_B_TR1]] by (d0, d1, d3, d2 * 128 + d4)> by [ ["G", "K", "d_nr"] at [0, 1, 2]>, ["d_r"] at [3]>] bounds = [1, 36, 32, 2, 128] -> [1, 36, 2, 4096]> : memref<1x36x2x4096xf32> to memref<1x36x32x2x128xf32> + // CHECK: %[[GEMM_IN_B_TR3:.+]] = rock.transform %[[GEMM_IN_B_TR2]] by (d0, d1, d2 floordiv 256, (d2 mod 256) floordiv 128, d2 mod 128)> by [ ["G", "K"] at [0, 1]>, ["d_rh", "d_nr", "d_rl"] at [2, 3, 4]>] bounds = [1, 36, 8192] -> [1, 36, 32, 2, 128]> : memref<1x36x32x2x128xf32> to memref<1x36x8192xf32> + %12 = rock.transform %11 by (d2 floordiv 4096, d0, d1 floordiv 9, (d1 mod 9) floordiv 3, (d2 mod 4096) floordiv 64, d1 mod 3, d2 mod 64)> by [ ["gi"] at [1]>, ["ci", "0", "1"] at [2, 3, 5]>, ["ni", "0o", "1o"] at [0, 4, 6]>] bounds = [1, 36, 8192] -> [2, 1, 4, 3, 64, 3, 64]> : memref<2x1x4x3x64x3x64xf32> to memref<1x36x8192xf32> + // CHECK:%[[GEMM_OUT_C:.+]] = rock.transform %{{.+}} by (d2 floordiv 4096, d0, d1, (d2 mod 4096) floordiv 64, d2 mod 64)> by [ ["go"] at [1]>, ["ko"] at [2]>, ["no", "0o", "1o"] at [0, 3, 4]>] bounds = [1, 320, 8192] -> [2, 1, 320, 64, 64]> : memref<2x1x320x64x64xf32> to memref<1x320x8192xf32> + // CHECK: %[[GEMM_OUT_C_TR0:.+]] = rock.transform %[[GEMM_OUT_C]] by (d0, d1 * 10 + d2, (d3 * 64 + d4 + d5) * 64 + d6)> by [ ["G"] at [0]>, ["M"] at [1]>, ["N"] at [2]>] bounds = [1, 32, 10, 2, 64, 1, 64] -> [1, 320, 8192]> : memref<1x320x8192xf32> to memref<1x32x10x2x64x1x64xf32> + // CHECK: %[[GEMM_OUT_C_TR1:.+]] = rock.transform %[[GEMM_OUT_C_TR0]] by (d0, d1, d2, d3, d4 floordiv 64, 0, d4 mod 64)> by [ ["G"] at [0]>, ["m_nr0"] at [1]>, ["m_r0"] at [2]>, ["n_nr0", "n_nr1"] at [3, 5]>, ["n_r0", "n_r1"] at [4, 6]>] bounds = [1, 32, 10, 2, 4096] -> [1, 32, 10, 2, 64, 1, 64]> : memref<1x32x10x2x64x1x64xf32> to memref<1x32x10x2x4096xf32> + // CHECK: %[[GEMM_OUT_C_TR2:.+]] = rock.transform %[[GEMM_OUT_C_TR1]] by (d0, d2, d1 * 2 + d3, d5, d4 * 128 + d6)> by [ ["G"] at [0]>, ["m_nr"] at [1]>, ["m_r"] at [2]>, ["n_nr"] at [3]>, ["n_r"] at [4]>] bounds = [1, 5, 32, 2, 32, 2, 128] -> [1, 32, 10, 2, 4096]> : memref<1x32x10x2x4096xf32> to memref<1x5x32x2x32x2x128xf32> + // CHECK: %[[GEMM_OUT_C_TR3:.+]] = rock.transform %[[GEMM_OUT_C_TR2]] by (d0, d1 floordiv 64, (d1 mod 64) floordiv 2, d1 mod 2, d2 floordiv 256, (d2 mod 256) floordiv 128, d2 mod 128)> by [ ["G"] at [0]>, ["m_rh", "m_nr", "m_rl"] at [1, 2, 3]>, ["n_rh", "n_nr", "n_rl"] at [4, 5, 6]>] bounds = [1, 320, 8192] -> [1, 5, 32, 2, 32, 2, 128]> : memref<1x5x32x2x32x2x128xf32> to memref<1x320x8192xf32> + %13 = rock.transform %8 by (d2 floordiv 4096, d0, d1, (d2 mod 4096) floordiv 64, d2 mod 64)> by [ ["go"] at [1]>, ["ko"] at [2]>, ["no", "0o", "1o"] at [0, 3, 4]>] bounds = [1, 320, 8192] -> [2, 1, 320, 64, 64]> : memref<2x1x320x64x64xf32> to memref<1x320x8192xf32> + // CHECK: 
rock.gridwise_gemm_accel(%[[GEMM_IN_A_TR3]], %[[GEMM_IN_B_TR3]], %[[GEMM_OUT_C_TR3]]) + rock.gridwise_gemm_accel(%9, %12, %13) storeMethod( set) features = mfma|dot|atomic_add {arch = "gfx942:sramecc+:xnack-", blockSize = 256 : i32, gridSize = 320 : i32, numCU = 228 : i32, params = #rock.xdlops_gemm_derived_params} : memref<1x36x320xf32>, memref<1x36x8192xf32>, memref<1x320x8192xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x32x10x64x64xf32> + %alloc_1 = memref.alloc() : memref<2621440xf32> + %14 = rock.transform %alloc_1 by ((((d0 * 32 + d1) * 10 + d2) * 64 + d3) * 64 + d4)> by [ ["dim0"] at [0]>] bounds = [2, 32, 10, 64, 64] -> [2621440]> : memref<2621440xf32> to memref<2x32x10x64x64xf32> + %alloc_2 = memref.alloc() : memref<2x32x10x64x64xf32> + %alloc_3 = memref.alloc() : memref<2x32x10x64x64xf32> + linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%alloc, %2 : memref<2x32x10x64x64xf32>, memref<2x32x10x64x64xf32>) outs(%alloc_0 : memref<2x32x10x64x64xf32>) attrs = {rock.majorTensorNumber = 1 : index} { + ^bb0(%in: f32, %in_8: f32, %out: f32): + %19 = arith.addf %in, %in_8 : f32 + linalg.yield %19 : f32 + } + memref.copy %alloc_0, %14 : memref<2x32x10x64x64xf32> to memref<2x32x10x64x64xf32> + memref.copy %alloc_0, %alloc_2 : memref<2x32x10x64x64xf32> to memref<2x32x10x64x64xf32> + memref.copy %alloc_0, %alloc_3 : memref<2x32x10x64x64xf32> to memref<2x32x10x64x64xf32> + %alloc_4 = memref.alloc() : memref<2x32x40960xf32> + %15 = rock.transform %alloc_4 by (d0, d1, (d2 * 64 + d3) * 64 + d4)> by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>] bounds = [2, 32, 10, 64, 64] -> [2, 32, 40960]> : memref<2x32x40960xf32> to memref<2x32x10x64x64xf32> + linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%alloc_2 : memref<2x32x10x64x64xf32>) outs(%15 : memref<2x32x10x64x64xf32>) attrs = {rock.majorTensorNumber = 0 : index} { + ^bb0(%in: f32, %out: f32): + %19 = arith.mulf %in, %cst : f32 + linalg.yield %19 : f32 + } + %alloc_5 = memref.alloc() : memref<64xf32> + %16 = rock.transform %alloc_5 by (d0 * 32 + d1 + d2)> by [ ["dim0"] at [0]>] bounds = [2, 32, 1] -> [64]> : memref<64xf32> to memref<2x32x1xf32> + rock.reduce sum %alloc_4 into %16 features = mfma|dot|atomic_add {axis = 2 : index, blockSize = 256 : i32, gridSize = 10240 : i32} : memref<2x32x40960xf32> into memref<2x32x1xf32> + %alloc_6 = memref.alloc() : memref<2x32x40960xf32> + %17 = rock.transform %alloc_6 by (d0, d1, (d2 * 64 + d3) * 64 + d4)> by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>] bounds = [2, 32, 10, 64, 64] -> [2, 32, 40960]> : memref<2x32x40960xf32> to memref<2x32x10x64x64xf32> + linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%alloc_3 : memref<2x32x10x64x64xf32>) outs(%17 : memref<2x32x10x64x64xf32>) attrs = {rock.majorTensorNumber = 0 : index} { + ^bb0(%in: f32, %out: f32): + %19 = arith.mulf %in, %in : f32 + %20 = arith.mulf %19, %cst : f32 + linalg.yield %20 : f32 + } + %alloc_7 = 
memref.alloc() : memref<64xf32> + %18 = rock.transform %alloc_7 by (d0 * 32 + d1 + d2)> by [ ["dim0"] at [0]>] bounds = [2, 32, 1] -> [64]> : memref<64xf32> to memref<2x32x1xf32> + rock.reduce sum %alloc_6 into %18 features = mfma|dot|atomic_add {axis = 2 : index, blockSize = 256 : i32, gridSize = 10240 : i32} : memref<2x32x40960xf32> into memref<2x32x1xf32> + memref.copy %alloc_5, %arg3 : memref<64xf32> to memref<64xf32> + memref.copy %alloc_7, %arg4 : memref<64xf32> to memref<64xf32> + memref.copy %alloc_1, %arg5 : memref<2621440xf32> to memref<2621440xf32> + return +} diff --git a/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-4d.mlir b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-4d.mlir new file mode 100644 index 000000000000..cd8a4f33c9e0 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-4d.mlir @@ -0,0 +1,29 @@ +// RUN: rocmlir-gen -fut mlir_convolution_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_convolution_multi_reduce_wrapper --verifier clone -relDiff_threshold 0.01 -RMS_threshold 0.01 -absDiff_threshold 0.4 -| rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 + +// We need a check for each output as this test case has three outputs in it. +// CHECK: [1 1 1] +// CHECK: [1 1 1] +// CHECK: [1 1 1] +module { + func.func @mlir_convolution_multi_reduce(%arg0: !migraphx.shaped<2x32x10x64x64xf32, 0x10x1x0x0>, %arg1: !migraphx.shaped<2x4x64x64xf32, 16384x4096x64x1>, %arg2: !migraphx.shaped<320x4x3x3xf32, 36x9x3x1>) -> (!migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1>) // attributes {arch = "gfx90a:sramecc+:xnack-", kernel = "mixr"} + { + %0 = migraphx.literal(dense<2.44140629E-5> : tensor<1xf32>) : <1xf32, 0> + %1 = migraphx.literal(dense<2.44140629E-5> : tensor<1xf32>) : <1xf32, 0> + %2 = migraphx.convolution %arg1, %arg2 {dilation = [1, 1], group = 1 : i64, padding = [1, 1, 1, 1], padding_mode = 0 : i64, stride = [1, 1]} : <2x4x64x64xf32, 16384x4096x64x1>, <320x4x3x3xf32, 36x9x3x1> -> <2x320x64x64xf32, 1310720x4096x64x1> + %3 = migraphx.reshape %2 {dims = [2, 32, 10, 64, 64]} : <2x320x64x64xf32, 1310720x4096x64x1> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %4 = migraphx.add %3, %arg0 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x10x1x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %5 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0> + %6 = migraphx.mul %4, %5 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %7 = migraphx.reshape %6 {dims = [2, 32, 40960, 1]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x40960x1xf32, 1310720x40960x1x1> + %8 = migraphx.reduce_sum %7 {axes = [2]} : <2x32x40960x1xf32, 1310720x40960x1x1> -> <2x32x1x1xf32, 32x1x1x1> + %9 = migraphx.reshape %8 {dims = [2, 32, 1, 1, 1]} : 
<2x32x1x1xf32, 32x1x1x1> -> <2x32x1x1x1xf32, 32x1x1x1x1> + %10 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0> + %11 = migraphx.mul %4, %4 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %12 = migraphx.mul %11, %10 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %13 = migraphx.reshape %12 {dims = [2, 32, 40960]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x40960xf32, 1310720x40960x1> + %14 = migraphx.reduce_sum %13 {axes = [2]} : <2x32x40960xf32, 1310720x40960x1> -> <2x32x1xf32, 32x1x1> + %15 = migraphx.reshape %14 {dims = [2, 32, 1, 1, 1]} : <2x32x1xf32, 32x1x1> -> <2x32x1x1x1xf32, 32x1x1x1x1> + return %9, %15, %4 : !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1> + } +} diff --git a/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-dot-reduce-n-only.mlir b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-dot-reduce-n-only.mlir new file mode 100644 index 000000000000..5c951aa2cd57 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-dot-reduce-n-only.mlir @@ -0,0 +1,24 @@ +// RUN: rocmlir-gen -fut mlir_dot_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_dot_multi_reduce_wrapper --verifier clone -relDiff_threshold 0.01 -RMS_threshold 0.01 -absDiff_threshold 1.2 -| rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 + +// We need a check for each output as this test case has three outputs in it. 
+// CHECK: [1 1 1] +// CHECK: [1 1 1] +module { + func.func @mlir_dot_multi_reduce(%arg0: !migraphx.shaped<2x32x10x64x64xf16, 0x10x1x20480x320>, %arg1: !migraphx.shaped<2x32x10x64x64xf16, 1310720x40960x4096x64x1>, %arg2: !migraphx.shaped<2x4096x320xf16, 1310720x320x1>, %arg3: !migraphx.shaped<320x320xf16, 320x1>) -> (!migraphx.shaped<2x320x1xf32, 320x1x1>, !migraphx.shaped<2x32x10x64x64xf16, 1310720x40960x4096x64x1>) // attributes {arch = "gfx942:sramecc+:xnack-", kernel = "mixr", num_cu = 304 : i64} + { + %0 = migraphx.literal(dense<2.441410e-05> : tensor<1xf32>) : <1xf32, 0> + %1 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [2, 320, 320]} : <320x320xf16, 320x1> -> <2x320x320xf16, 0x320x1> + %2 = migraphx.dot %arg2, %1 : <2x4096x320xf16, 1310720x320x1>, <2x320x320xf16, 0x320x1> -> <2x4096x320xf16, 1310720x320x1> + %3 = migraphx.reshape %2 {dims = [2, 64, 64, 32, 10]} : <2x4096x320xf16, 1310720x320x1> -> <2x64x64x32x10xf16, 1310720x20480x320x10x1> + %4 = migraphx.transpose %3 {permutation = [0, 3, 4, 1, 2]} : <2x64x64x32x10xf16, 1310720x20480x320x10x1> -> <2x32x10x64x64xf16, 1310720x10x1x20480x320> + %5 = migraphx.add %4, %arg0 : <2x32x10x64x64xf16, 1310720x10x1x20480x320>, <2x32x10x64x64xf16, 0x10x1x20480x320> -> <2x32x10x64x64xf16, 1310720x10x1x20480x320> + %6 = migraphx.add %5, %arg1 : <2x32x10x64x64xf16, 1310720x10x1x20480x320>, <2x32x10x64x64xf16, 1310720x40960x4096x64x1> -> <2x32x10x64x64xf16, 1310720x40960x4096x64x1> + %7 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0> + %8 = migraphx.convert %6 {target_type = 2 : i64} : <2x32x10x64x64xf16, 1310720x40960x4096x64x1> to <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %9 = migraphx.mul %8, %7 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %10 = migraphx.reshape %9 {dims = [2, 320, 4096]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x320x4096xf32, 1310720x4096x1> + %11 = migraphx.reduce_sum %10 {axes = [2]} : <2x320x4096xf32, 1310720x4096x1> -> <2x320x1xf32, 320x1x1> + return %11, %6 : !migraphx.shaped<2x320x1xf32, 320x1x1>, !migraphx.shaped<2x32x10x64x64xf16, 1310720x40960x4096x64x1> + } +} diff --git a/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-dot.mlir b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-dot.mlir new file mode 100644 index 000000000000..68a74529dd33 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-dot.mlir @@ -0,0 +1,25 @@ +// RUN: rocmlir-gen -fut mlir_dot_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_dot_multi_reduce_wrapper --verifier clone -relDiff_threshold 0.01 -RMS_threshold 0.01 -absDiff_threshold 1.2 -| rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 + +// We need a check for each output as this test case has three outputs in it. 
+// CHECK: [1 1 1] +// CHECK: [1 1 1] +module { + func.func @mlir_dot_multi_reduce(%arg0: !migraphx.shaped<2x32x10x64x64xf16, 0x10x1x20480x320>, %arg1: !migraphx.shaped<2x32x10x64x64xf16, 1310720x40960x4096x64x1>, %arg2: !migraphx.shaped<2x4096x320xf16, 1310720x320x1>, %arg3: !migraphx.shaped<320x320xf16, 320x1>) -> (!migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf16, 1310720x40960x4096x64x1>) // attributes {arch = "gfx942:sramecc+:xnack-", kernel = "mixr", num_cu = 304 : i64} + { + %0 = migraphx.literal(dense<2.441410e-05> : tensor<1xf32>) : <1xf32, 0> + %1 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [2, 320, 320]} : <320x320xf16, 320x1> -> <2x320x320xf16, 0x320x1> + %2 = migraphx.dot %arg2, %1 : <2x4096x320xf16, 1310720x320x1>, <2x320x320xf16, 0x320x1> -> <2x4096x320xf16, 1310720x320x1> + %3 = migraphx.reshape %2 {dims = [2, 64, 64, 32, 10]} : <2x4096x320xf16, 1310720x320x1> -> <2x64x64x32x10xf16, 1310720x20480x320x10x1> + %4 = migraphx.transpose %3 {permutation = [0, 3, 4, 1, 2]} : <2x64x64x32x10xf16, 1310720x20480x320x10x1> -> <2x32x10x64x64xf16, 1310720x10x1x20480x320> + %5 = migraphx.add %4, %arg0 : <2x32x10x64x64xf16, 1310720x10x1x20480x320>, <2x32x10x64x64xf16, 0x10x1x20480x320> -> <2x32x10x64x64xf16, 1310720x10x1x20480x320> + %6 = migraphx.add %5, %arg1 : <2x32x10x64x64xf16, 1310720x10x1x20480x320>, <2x32x10x64x64xf16, 1310720x40960x4096x64x1> -> <2x32x10x64x64xf16, 1310720x40960x4096x64x1> + %7 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0> + %8 = migraphx.convert %6 {target_type = 2 : i64} : <2x32x10x64x64xf16, 1310720x40960x4096x64x1> to <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %9 = migraphx.mul %8, %7 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %10 = migraphx.reshape %9 {dims = [2, 32, 40960]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x40960xf32, 1310720x40960x1> + %11 = migraphx.reduce_sum %10 {axes = [2]} : <2x32x40960xf32, 1310720x40960x1> -> <2x32x1xf32, 32x1x1> + %12 = migraphx.reshape %11 {dims = [2, 32, 1, 1, 1]} : <2x32x1xf32, 32x1x1> -> <2x32x1x1x1xf32, 32x1x1x1x1> + return %12, %6 : !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf16, 1310720x40960x4096x64x1> + } +} diff --git a/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-raxis-1.mlir b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-raxis-1.mlir new file mode 100644 index 000000000000..cd452691264d --- /dev/null +++ b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo-raxis-1.mlir @@ -0,0 +1,29 @@ +// RUN: rocmlir-gen -fut mlir_convolution_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_convolution_multi_reduce_wrapper --verifier clone -relDiff_threshold 0.01 -RMS_threshold 0.01 -absDiff_threshold 0.4 -| rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 + +// We need a check for each output as this test 
case has three outputs in it. +// CHECK: [1 1 1] +// CHECK: [1 1 1] +// CHECK: [1 1 1] +module { + func.func @mlir_convolution_multi_reduce(%arg0: !migraphx.shaped<2x32x10x64x64xf32, 0x10x1x0x0>, %arg1: !migraphx.shaped<2x4x64x64xf32, 16384x4096x64x1>, %arg2: !migraphx.shaped<320x4x3x3xf32, 36x9x3x1>) -> (!migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x1x32xf32, 32x32x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1>) // attributes {arch = "gfx942:sramecc+:xnack-", kernel = "mixr"} + { + %0 = migraphx.literal(dense<2.44140629E-5> : tensor<1xf32>) : <1xf32, 0> + %1 = migraphx.literal(dense<2.44140629E-5> : tensor<1xf32>) : <1xf32, 0> + %2 = migraphx.convolution %arg1, %arg2 {dilation = [1, 1], group = 1 : i64, padding = [1, 1, 1, 1], padding_mode = 0 : i64, stride = [1, 1]} : <2x4x64x64xf32, 16384x4096x64x1>, <320x4x3x3xf32, 36x9x3x1> -> <2x320x64x64xf32, 1310720x4096x64x1> + %3 = migraphx.reshape %2 {dims = [2, 32, 10, 64, 64]} : <2x320x64x64xf32, 1310720x4096x64x1> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %4 = migraphx.add %3, %arg0 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x10x1x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %5 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0> + %6 = migraphx.mul %4, %5 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %7 = migraphx.reshape %6 {dims = [2, 32, 40960]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x40960xf32, 1310720x40960x1> + %8 = migraphx.reduce_sum %7 {axes = [2]} : <2x32x40960xf32, 1310720x40960x1> -> <2x32x1xf32, 32x1x1> + %9 = migraphx.reshape %8 {dims = [2, 32, 1, 1, 1]} : <2x32x1xf32, 32x1x1> -> <2x32x1x1x1xf32, 32x1x1x1x1> + %10 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0> + %11 = migraphx.mul %4, %4 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %12 = migraphx.mul %11, %10 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %13 = migraphx.reshape %12 {dims = [2, 32, 40960]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x40960xf32, 1310720x40960x1> + %14 = migraphx.transpose %13 {permutation = [0, 2, 1]} : <2x32x40960xf32, 1310720x40960x1> -> <2x40960x32xf32, 1310720x32x1> + %15 = migraphx.reduce_sum %14 {axes = [1]} : <2x40960x32xf32, 1310720x32x1> -> <2x1x32xf32, 32x32x1> + return %9, %15, %4 : !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x1x32xf32, 32x32x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1> + } +} diff --git a/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo.mlir b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo.mlir new file mode 100644 index 000000000000..b366a144b439 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/mixr-multi-reduce-mo.mlir @@ -0,0 +1,29 @@ +// RUN: rocmlir-gen -fut mlir_convolution_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_convolution_multi_reduce_wrapper --verifier clone -relDiff_threshold 0.01 -RMS_threshold 0.01 -absDiff_threshold 0.4 -| rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner 
--shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// ALLOW_RETRIES: 2 + +// We need a check for each output as this test case has three outputs in it. +// CHECK: [1 1 1] +// CHECK: [1 1 1] +// CHECK: [1 1 1] +module { + func.func @mlir_convolution_multi_reduce(%arg0: !migraphx.shaped<2x32x10x64x64xf32, 0x10x1x0x0>, %arg1: !migraphx.shaped<2x4x64x64xf32, 16384x4096x64x1>, %arg2: !migraphx.shaped<320x4x3x3xf32, 36x9x3x1>) -> (!migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1>) // attributes {arch = "gfx942:sramecc+:xnack-", kernel = "mixr"} + { + %0 = migraphx.literal(dense<2.44140629E-5> : tensor<1xf32>) : <1xf32, 0> + %1 = migraphx.literal(dense<2.44140629E-5> : tensor<1xf32>) : <1xf32, 0> + %2 = migraphx.convolution %arg1, %arg2 {dilation = [1, 1], group = 1 : i64, padding = [1, 1, 1, 1], padding_mode = 0 : i64, stride = [1, 1]} : <2x4x64x64xf32, 16384x4096x64x1>, <320x4x3x3xf32, 36x9x3x1> -> <2x320x64x64xf32, 1310720x4096x64x1> + %3 = migraphx.reshape %2 {dims = [2, 32, 10, 64, 64]} : <2x320x64x64xf32, 1310720x4096x64x1> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %4 = migraphx.add %3, %arg0 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x10x1x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %5 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0> + %6 = migraphx.mul %4, %5 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %7 = migraphx.reshape %6 {dims = [2, 32, 40960]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x40960xf32, 1310720x40960x1> + %8 = migraphx.reduce_sum %7 {axes = [2]} : <2x32x40960xf32, 1310720x40960x1> -> <2x32x1xf32, 32x1x1> + %9 = migraphx.reshape %8 {dims = [2, 32, 1, 1, 1]} : <2x32x1xf32, 32x1x1> -> <2x32x1x1x1xf32, 32x1x1x1x1> + %10 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0> + %11 = migraphx.mul %4, %4 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %12 = migraphx.mul %11, %10 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1> + %13 = migraphx.reshape %12 {dims = [2, 32, 40960]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x40960xf32, 1310720x40960x1> + %14 = migraphx.reduce_sum %13 {axes = [2]} : <2x32x40960xf32, 1310720x40960x1> -> <2x32x1xf32, 32x1x1> + %15 = migraphx.reshape %14 {dims = [2, 32, 1, 1, 1]} : <2x32x1xf32, 32x1x1> -> <2x32x1x1x1xf32, 32x1x1x1x1> + return %9, %15, %4 : !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1> + } +} diff --git a/mlir/test/fusion/rock-gemm-reduce-align-tiling.mlir b/mlir/test/fusion/rock-gemm-reduce-align-tiling.mlir index 3d9933a1ba5b..95bad400cdfc 100644 --- a/mlir/test/fusion/rock-gemm-reduce-align-tiling.mlir +++ 
b/mlir/test/fusion/rock-gemm-reduce-align-tiling.mlir @@ -1,14 +1,24 @@ -// RUN: rocmlir-opt --rock-view-to-transform -rock-affix-params -rock-conv-to-gemm -rock-gemm-to-gridwise -rock-gridwise-gemm-to-blockwise -rock-linalg-align %s | FileCheck %s - -// CHECK: [[MAP0:.*]] = #rock.transform_map<{{.*}} by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>] bounds = [1, 128, 256] -> [1, 128, 1]> -// CHECK: [[MAP1:.*]] = #rock.transform_map<{{.*}} by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>] bounds = [1, 128, 256] -> [1, 1, 256]> +// RUN: rocmlir-opt --rock-view-to-transform -rock-affix-params -rock-conv-to-gemm -rock-gemm-to-gridwise -rock-shuffle-gemm-for-reductions -rock-gridwise-gemm-to-blockwise -rock-linalg-align %s -mlir-print-local-scope | FileCheck %s // CHECK: test_gemm_reduce_last_axis_fusion func.func @test_gemm_reduce_last_axis_fusion(%arg0: memref<1x128x64xf32>, %arg1: memref<1x64x256xf32>, %arg2: memref<1x128x1xf32>) attributes {arch = "", kernel} { %0 = memref.alloc() : memref<1x128x256xf32> rock.gemm %0 = %arg0 * %arg1 features = none storeMethod = set {arch = ""} : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> - // CHECK: %[[trOut:.*]] = rock.transform %arg2 by [[MAP0]] : memref<1x128x1xf32> to memref<1x128x256xf32> - // CHECK: rock.threadwise_write_all {{.*}}(%[[trOut]]){{.*}} by atomic_add : {{.*}} -> memref<1x128x256xf32> + // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} into %[[BLOCK_RED_OUT:[0-9]+]] + + // CHECK: %[[TR0:.+]] = rock.transform %arg2 by {{.*}} : memref<1x128x1xf32> to memref<1x128x256xf32> + // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<2x128x2x128xf32> + // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<2x128x2x128xf32> to memref<2x1x2x1x128x128xf32> + // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<2x1x2x1x128x128xf32> to memref<2x1x2x128x1xf32> + // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<2x1x2x128x1xf32> to memref<2x1x2x128x128xf32> + + // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x4x4x4x4x2x4x2x4xf32 + // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<2x1x2x4x4x4x4x2x4x2x4xf32> to memref<2x1x2x256x64xf32> + + // CHECK: rock.threadwise_write_all {{.*}}%[[BLOCK_RED_OUT]] -> [](%[[TR9]]){{.*}} by atomic_add : {{.*}} rock.reduce sum %0 into %arg2 features = mfma|dot|atomic_add {axis = 2 : index, blockSize = 256 : i32, gridSize = 1 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> return } @@ -18,8 +28,21 @@ func.func @test_gemm_reduce_last_axis_fusion(%arg0: memref<1x128x64xf32>, %arg1: func.func @test_gemm_reduce_middle_axis_fusion(%arg0: memref<1x128x64xf32>, %arg1: memref<1x64x256xf32>, %arg2: memref<1x1x256xf32>) attributes {arch = "", kernel} { %0 = memref.alloc() : memref<1x128x256xf32> rock.gemm %0 = %arg0 * %arg1 features = none storeMethod = set {arch = ""} : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> - // CHECK: %[[trOut:.*]] = rock.transform %arg2 by [[MAP1]] : memref<1x1x256xf32> to memref<1x128x256xf32> - // CHECK: 
rock.threadwise_write_all {{.*}}(%[[trOut]]){{.*}} by atomic_add : {{.*}} -> memref<1x128x256xf32> + // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} into %[[BLOCK_RED_OUT:[0-9]+]] + + // CHECK: %[[TR0:.+]] = rock.transform %arg2 by {{.*}} : memref<1x1x256xf32> to memref<1x128x256xf32> + // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<2x128x2x128xf32> + // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<2x128x2x128xf32> to memref<2x1x2x1x128x128xf32> + // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<2x1x2x1x128x128xf32> to memref<2x1x2x1x128xf32> + // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim0"] at [3]>{{.*}} : memref<2x1x2x1x128xf32> to memref<2x1x2x128x128xf32> + + // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x4x4x4x4x2x4x2x4xf32> + // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<2x1x2x4x4x4x4x2x4x2x4xf32> to memref<2x1x2x256x64xf32> + + // CHECK: rock.threadwise_write_all {{.*}}%[[BLOCK_RED_OUT]] -> [](%[[TR9]]){{.*}} by atomic_add : {{.*}} rock.reduce sum %0 into %arg2 features = mfma|dot|atomic_add {axis = 1 : index, blockSize = 256 : i32, gridSize = 1 : i32} : memref<1x128x256xf32> into memref<1x1x256xf32> return } @@ -35,8 +58,21 @@ func.func @test_gemm_add_reduce_fusion(%arg0: memref<1x128x64xf32>, %arg1: memre %4 = arith.addf %arg4, %arg5 : f32 linalg.yield %4 : f32 } - // CHECK: %[[trOut:.*]] = rock.transform %arg3 by [[MAP0]] : memref<1x128x1xf32> to memref<1x128x256xf32> - // CHECK: rock.threadwise_write_all {{.*}}(%[[trOut]]){{.*}} by atomic_add : {{.*}} -> memref<1x128x256xf32> + // CHECK: rock.blockwise_broadcast_reduce sum {{.*}} into %[[BLOCK_RED_OUT:[0-9]+]] + + // CHECK: %[[TR0:.+]] = rock.transform %arg3 by {{.*}} : memref<1x128x1xf32> to memref<1x128x256xf32> + // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<2x128x2x128xf32> + // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<2x128x2x128xf32> to memref<2x1x2x1x128x128xf32> + // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<2x1x2x1x128x128xf32> to memref<2x1x2x128x1xf32> + // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<2x1x2x128x1xf32> to memref<2x1x2x128x128xf32> + + // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x4x4x4x4x2x4x2x4xf32> + // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<2x1x2x4x4x4x4x2x4x2x4xf32> to memref<2x1x2x256x64xf32> + + // CHECK: rock.threadwise_write_all {{.*}}%[[BLOCK_RED_OUT]] -> [](%[[TR9]]){{.*}} by atomic_add : {{.*}} rock.reduce sum %1 into %arg3 features = mfma|dot|atomic_add {axis = 2 : index, blockSize = 256 : i32, gridSize = 1 : i32} 
: memref<1x128x256xf32> into memref<1x128x1xf32> return } @@ -45,8 +81,21 @@ func.func @test_gemm_add_reduce_fusion(%arg0: memref<1x128x64xf32>, %arg1: memre func.func @test_gemm_reduce_max(%arg0: memref<1x128x64xf32>, %arg1: memref<1x64x256xf32>, %arg2: memref<1x128x1xf32>) attributes {arch = "", kernel} { %0 = memref.alloc() : memref<1x128x256xf32> rock.gemm %0 = %arg0 * %arg1 features = none storeMethod = set {arch = ""} : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> - // CHECK: %[[trOut:.*]] = rock.transform %arg2 by [[MAP0]] : memref<1x128x1xf32> to memref<1x128x256xf32> - // CHECK: rock.threadwise_write_all {{.*}}(%[[trOut]]){{.*}} by atomic_max : {{.*}} -> memref<1x128x256xf32> + // CHECK: rock.blockwise_broadcast_reduce max {{.*}} into %[[BLOCK_RED_OUT:[0-9]+]] + + // CHECK: %[[TR0:.+]] = rock.transform %arg2 by {{.*}} : memref<1x128x1xf32> to memref<1x128x256xf32> + // CHECK: %[[TR1:.+]] = rock.transform %[[TR0]] by {{.*}} : memref<1x128x256xf32> to memref<2x128x2x128xf32> + // CHECK: %[[TR2:.+]] = rock.transform %[[TR1]] by {{.*}} : memref<2x128x2x128xf32> to memref<2x1x2x1x128x128xf32> + // CHECK: %[[TR3:.+]] = rock.transform %[[TR2]] by {{.*}} : memref<2x1x2x1x128x128xf32> to memref<2x1x2x128x1xf32> + // CHECK: %[[TR4:.+]] = rock.transform %[[TR3]] by {{.*}} ["dim1"] at [4]>{{.*}} : memref<2x1x2x128x1xf32> to memref<2x1x2x128x128xf32> + + // CHECK: %[[TR5:.+]] = rock.transform %[[TR4]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR6:.+]] = rock.transform %[[TR5]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR7:.+]] = rock.transform %[[TR6]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x128x128xf32> + // CHECK: %[[TR8:.+]] = rock.transform %[[TR7]] by {{.*}} : memref<2x1x2x128x128xf32> to memref<2x1x2x4x4x4x4x2x4x2x4xf32> + // CHECK: %[[TR9:.+]] = rock.transform %[[TR8]] by {{.*}} : memref<2x1x2x4x4x4x4x2x4x2x4xf32> to memref<2x1x2x256x64xf32> + + // CHECK: rock.threadwise_write_all {{.*}}%[[BLOCK_RED_OUT]] -> [](%[[TR9]]){{.*}} by atomic_max : {{.*}} rock.reduce max %0 into %arg2 features = mfma|dot|atomic_add {axis = 2 : index, blockSize = 256 : i32, gridSize = 1 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> return } diff --git a/mlir/test/rocmlir-driver/pipelines.mlir b/mlir/test/rocmlir-driver/pipelines.mlir index 192932a1624f..c90034f4bf73 100644 --- a/mlir/test/rocmlir-driver/pipelines.mlir +++ b/mlir/test/rocmlir-driver/pipelines.mlir @@ -18,9 +18,10 @@ // GPU-NEXT:rock-conv-to-gemm, // GPU-NEXT:rock-gemm-to-gridwise, // GPU-NEXT:rock-regularize, +// GPU-NEXT:rock-shuffle-gemm-for-reductions, // GPU-NEXT:rock-gridwise-gemm-to-blockwise, -// GPU-NEXT:rock-blockwise-gemm-to-threadwise, // GPU-NEXT:rock-linalg-align, +// GPU-NEXT:rock-blockwise-gemm-to-threadwise, // GPU-NEXT:rock-pipeline{rock-pipeline-remove-stages=true}, // GPU-NEXT:canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, // GPU-NEXT:convert-linalg-to-affine-loops,