Use blockwise_broadcast_reduce in reduction fusions. (#1668)
Currently, we handle reductions that are fused with gemm-like operations
by using atomic stores to the destination buffer. This can be cripplingly slow when
most of the output is being reduced, as evidenced in layer_norm cases.
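For intuition, here is a minimal caricature of that contention (illustrative C++ only, not code from this repository): when an [M, N] output is reduced over M, all M*N partial values contend on just N destination slots.

#include <atomic>
#include <vector>

// Every partial value is atomically accumulated into one of only N slots,
// so the destination serializes the whole grid when M is large.
void atomicReduce(const std::vector<float> &partials, int M, int N,
                  std::vector<std::atomic<float>> &out /* size N */) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float cur = out[n].load();
      // The CAS loop stands in for the GPU's atomic update of global memory.
      while (!out[n].compare_exchange_weak(cur, cur + partials[m * N + n])) {
      }
    }
}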

This PR adds the ability to run blockwise_broadcast_reduce on the block sub-tiles of
the gemm output.
However, in order to do that, we need to make sure the reduction dimension is uniformly
distributed across the blocks. This is achieved as follows:

First, this PR introduces a utility that, for a given set of upper dimensions, can
traverse a transform stack and produce, for each lower dimension, the list of sub-dimensions
that the upper reduction axes map to.
Then, this PR introduces the ShuffleGemmForReductions pass, which splits and transposes
the parallel dimension of the gemm so that the reduction dimension is uniformly
distributed across the blocks (see the sketch after this list).
Then, in the AlignTiling pass, we extract the block sub-tile when fusing in the rock.reduce operator,
and perform a blockwise_broadcast_reduce on that block sub-tile.
Since we only want to write out the partial reductions per block, we pad out the broadcasted part of the sub-tile.
(We rely on the fact that any block coordinate that lands in the padded region within the block will not be written out.)
Finally, we recombine the modified sub-tile coordinate maps with the grid-only coordinate maps:
a) Here, we drop all the upper dimensions except g_block, m_block and n_block to obtain the grid-only transform map
stack.
b) In parallel, we reuse the getLowerSubDimensions utility to figure out which sub-dimensions the above
grid-only dimensions map to.
c) Then we extract those sub-dimensions in a bottom-up fashion and stitch them together with the said grid-only transform map
stack.
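A minimal sketch of the split-and-transpose shuffle (illustrative sizes, not repo code): viewing a parallel dimension of size n as [nInner, nOuter] and transposing it to [nOuter, nInner] makes elements that sit nOuter apart (the ones the fused reduction combines in this example) adjacent, so they land in the same block tile.

#include <cstdio>

int main() {
  constexpr int n = 8, nInner = 2, nOuter = n / nInner; // view n as [2, 4]
  constexpr int blockSize = 4; // two blocks of four elements each
  for (int d = 0; d < n; ++d) {
    // d = hi * nOuter + lo in the [nInner, nOuter] view;
    // transpose to [nOuter, nInner]: position = lo * nInner + hi.
    int hi = d / nOuter, lo = d % nOuter;
    int shuffled = lo * nInner + hi;
    std::printf("elem %d -> pos %d (block %d)\n", d, shuffled,
                shuffled / blockSize);
  }
  // After the shuffle, d and d + nOuter (reduced together here) share a block.
}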
manupak authored and dhernandez0 committed Oct 29, 2024
1 parent c2c5c91 commit 608638c
Showing 24 changed files with 1,757 additions and 81 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Rock/Passes.h
@@ -44,6 +44,7 @@ namespace rock {
#define GEN_PASS_DECL_ROCKVECTORIZEFUSIONSPASS
#define GEN_PASS_DECL_ROCKOUTPUTSWIZZLEPASS
#define GEN_PASS_DECL_ROCKREUSELDSPASS
#define GEN_PASS_DECL_ROCKSHUFFLEGEMMFORREDUCTIONS

#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Rock/Passes.h.inc"
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Rock/Passes.td
@@ -161,4 +161,10 @@ def RockReuseLDSPass : Pass<"rock-reuse-lds", "::mlir::func::FuncOp"> {
let dependentDialects = ["rock::RockDialect", "memref::MemRefDialect"];
}

def RockShuffleGemmForReductions : Pass<"rock-shuffle-gemm-for-reductions", "::mlir::func::FuncOp"> {
let summary = "This pass shuffles the parallel gemm dimensions so that fused reductions (if any) are split equally across blocks";
let dependentDialects = ["rock::RockDialect", "memref::MemRefDialect"];
}


#endif // MLIR_DIALECT_ROCK_PASSES
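A short sketch of wiring the new pass into a function pipeline, mirroring the Pipelines.cpp change later in this diff (assumptions: standard MLIR pass-manager API; the create function comes from the generated declarations in Passes.h above). Textually, the pass is exposed as rock-shuffle-gemm-for-reductions.

#include "mlir/Dialect/Rock/Passes.h"
#include "mlir/Pass/PassManager.h"

// Insert the shuffle right before gridwise lowering, as the pipeline does.
void addReductionShuffle(mlir::OpPassManager &funcPm) {
  funcPm.addPass(mlir::rock::createRockShuffleGemmForReductions());
}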
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h
@@ -76,19 +76,21 @@ struct PopulateParamsInfo {
KernelType kernelType;
int64_t batchSize;
uint32_t numCu;
bool hasFusedReduction;

PopulateParamsInfo(GemmSize gemmSize, StringRef arch,
GemmFeatures gemmFeatures, Type gemmAType, Type gemmBType,
KernelType kernelType)
: gemmSize(gemmSize), arch(arch), gemmFeatures(gemmFeatures),
gemmAType(gemmAType), gemmBType(gemmBType), kernelType(kernelType) {}
gemmAType(gemmAType), gemmBType(gemmBType), kernelType(kernelType),
hasFusedReduction(false) {}

PopulateParamsInfo(GemmSize gemmSize, StringRef arch,
GemmFeatures gemmFeatures, Type gemmAType, Type gemmBType,
KernelType kernelType, int64_t batchSize, uint32_t numCu)
: gemmSize(gemmSize), arch(arch), gemmFeatures(gemmFeatures),
gemmAType(gemmAType), gemmBType(gemmBType), kernelType(kernelType),
batchSize(batchSize), numCu(numCu) {}
batchSize(batchSize), numCu(numCu), hasFusedReduction(false) {}

/// Extract the relevant information from a RockGemmWrapperInterface operation
static PopulateParamsInfo fromOp(RockGemmWrapperInterface op);
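A hypothetical usage sketch (names and namespaces assumed, not repo code): both constructors now default hasFusedReduction to false, so a caller that detects a fused rock.reduce sets the flag explicitly before parameter selection.

#include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"

using namespace mlir;
using namespace mlir::rock;

PopulateParamsInfo makeInfo(GemmSize gemmSize, StringRef arch,
                            GemmFeatures features, Type aType, Type bType,
                            KernelType kernelType, int64_t batchSize,
                            uint32_t numCu, bool fusedReduce) {
  PopulateParamsInfo info(gemmSize, arch, features, aType, bType, kernelType,
                          batchSize, numCu);
  info.hasFusedReduction = fusedReduce; // stays false when nothing is fused
  return info;
}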
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h
@@ -201,6 +201,9 @@ TypedValue<MemRefType> viewBufferAs(OpBuilder &b, Value buffer, Type type);
Value gpuAlloc(OpBuilder &b, Location loc, int64_t bufferDim, Type elementType,
gpu::AddressSpace memoryAddressSpace);

// Helper to verify an LDS allocation fits on the GPU.
LogicalResult checkLDSSize(StringAttr arch, int64_t ldsBytes);

} // end namespace rock
} // end namespace mlir
#endif
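A hypothetical caller sketch for the new helper (surrounding names assumed): validate the LDS footprint against the target architecture before allocating.

#include "mlir/Dialect/Rock/utility/loweringUtils.h"
#include "mlir/IR/Diagnostics.h"

// Fail early if the requested workspace cannot fit in LDS on this arch.
mlir::LogicalResult validateWorkspace(mlir::Location loc,
                                      mlir::StringAttr arch,
                                      int64_t ldsBytes) {
  if (mlir::failed(mlir::rock::checkLDSSize(arch, ldsBytes)))
    return mlir::emitError(loc, "LDS allocation does not fit on the GPU");
  return mlir::success();
}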
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/Rock/utility/transformMapUtils.h
@@ -266,6 +266,28 @@ FailureOr<ArrayAttr> removeUpperDims(OpBuilder &b, ArrayAttr transformAttrs,
// padded data.
FailureOr<ArrayAttr> removeUpperDims(OpBuilder &b, ArrayAttr transformAttrs,
const StringSet<> &removeDimNamesSet);

struct SubDimInfo {
int64_t size;
int64_t stride;
};

inline raw_ostream &operator<<(raw_ostream &os, const SubDimInfo &sdInfo) {
os << "<size: " << sdInfo.size << ",stride=" << sdInfo.stride << ">";
return os;
}

// Given a sequence of transform maps, this will obtain the lower sub-dimensions
// each provided upper dim would map to.
FailureOr<llvm::SmallDenseMap<int64_t, SmallVector<SubDimInfo>>>
getLowerSubDimensions(OpBuilder &b, ArrayAttr transformAttrs, int64_t dim);
FailureOr<llvm::SmallDenseMap<int64_t, SmallVector<SubDimInfo>>>
getLowerSubDimensions(OpBuilder &b, ArrayAttr transformAttrs,
ArrayRef<int64_t> dims);

SmallVector<SmallString<8>> createDimNames(int64_t len, StringRef prefix);
SmallVector<StringRef> getStringRefsFor(ArrayRef<SmallString<8>> strings);

} // end namespace rock
} // end namespace mlir
#endif
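A hypothetical usage sketch of the new utility (surrounding setup assumed): ask which lower sub-dimensions upper dimension dim maps to across the transform stack, then print each SubDimInfo via the operator<< declared above.

#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
#include "llvm/Support/raw_ostream.h"

mlir::LogicalResult dumpSubDims(mlir::OpBuilder &b,
                                mlir::ArrayAttr transformAttrs, int64_t dim) {
  auto subDimMap = mlir::rock::getLowerSubDimensions(b, transformAttrs, dim);
  if (mlir::failed(subDimMap))
    return mlir::failure();
  for (auto &[lowerDim, subDims] : *subDimMap)
    for (const mlir::rock::SubDimInfo &sd : subDims)
      llvm::errs() << "lower dim " << lowerDim << " <- " << sd << "\n";
  return mlir::success();
}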
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Rock/IR/RockDialect.cpp
@@ -2033,7 +2033,7 @@ LogicalResult BlockwiseBroadcastReduceOp::verify() {
}
if (blockwiseInputPartialReductionTensorElements > wsShape[0]) {
return emitError(
"workspace should be at least the size of elements per block");
"workspace should be at least the size of elements per block ");
}
return success();
}
9 changes: 8 additions & 1 deletion mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
@@ -151,15 +151,22 @@ void rock::buildKernelPipeline(OpPassManager &pm,
funcPm.addPass(rock::createRockConvToGemmPass());
funcPm.addPass(rock::createRockGemmToGridwisePass());
funcPm.addPass(rock::createRockRegularizePass());
funcPm.addPass(rock::createRockShuffleGemmForReductions());
funcPm.addPass(rock::createRockGridwiseGemmToBlockwisePass());
funcPm.addPass(rock::createRockBlockwiseGemmToThreadwisePass());
// We want to delay blockwise lowering in the fusion cases
// until after the linalg align pass, because reduction fusion
// may introduce blockwise reductions there.
if (!options.enableFusion) {
funcPm.addPass(rock::createRockBlockwiseGemmToThreadwisePass());
}

if (options.enableFusion) {
// align linalg tiling
/* rocmlir-opt --rock-linalg-align --canonicalize
* --convert-linalg-to-affine-loops
*/
funcPm.addPass(rock::createRockLinalgAlignPass());
funcPm.addPass(rock::createRockBlockwiseGemmToThreadwisePass());
funcPm.addPass(rock::createRockPipelinePass());
funcPm.addPass(createCanonicalizerPass());
funcPm.addPass(createConvertLinalgToAffineLoopsPass());
(The remaining changed files of the 24 in this commit are not shown here.)