[LLVMGPU] Add tests for VectorDistribution subgroup reduction pipeline (#19285)
This patch adds two things:

- A way to set how resources (subgroups/threads) are distributed over a
basis (see the sketch below).
- A test showing subgroup reduction for matvec using the vector distribution
pipeline.

There are some differences from the current WarpReduction pipeline:

- Currently, the pipeline doesn't do a split reduction. This can be
configured later in the lowering config (by using partial_reduction tile
sizes instead of plain reduction tile sizes).
- The writeback to global memory is not guarded, i.e. all threads write to
global memory. A follow-up patch will fix this so that only one thread
performs the write.
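
As a quick illustration, here is a minimal sketch of how the new basis helpers might be used when assembling a lowering config. Only Basis, setBasis, and the TilingLevel values come from this patch; the surrounding setup and the concrete counts/mappings are hypothetical:

    // Illustrative only: distribute 4 subgroups and 64 threads per subgroup
    // over the second (reduction) dimension of a rank-2 iteration space.
    MLIRContext *context = ...;
    SmallVector<NamedAttribute> attrs;
    IREE::GPU::setBasis(context, attrs, IREE::GPU::TilingLevel::Subgroup,
                        IREE::GPU::Basis{/*counts=*/{1, 4}, /*mapping=*/{0, 1}});
    IREE::GPU::setBasis(context, attrs, IREE::GPU::TilingLevel::Thread,
                        IREE::GPU::Basis{/*counts=*/{1, 64}, /*mapping=*/{0, 1}});
    // attrs now carries subgroup_basis / thread_basis entries, each an array
    // of two integer arrays, printed roughly as [[1, 4], [0, 1]].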
Groverkss authored Dec 9, 2024
1 parent 39c56de commit aef2da1
Showing 8 changed files with 413 additions and 15 deletions.
@@ -279,6 +279,13 @@ struct DistributeTransferWrite final
Value slicedVector = rewriter.create<vector::ExtractOp>(
writeOp.getLoc(), distributedVector,
offsetArray.take_front(rank * 2));
// Promote the slicedVector to 0-d vector if it is a scalar.
if (!isa<VectorType>(slicedVector.getType())) {
auto promotedType =
VectorType::get({}, getElementTypeOrSelf(slicedVector));
slicedVector = rewriter.create<vector::BroadcastOp>(
writeOp.getLoc(), promotedType, slicedVector);
}
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), slicedVector, writeOp.getSource(), slicedIndices,
writeOp.getPermutationMapAttr(), writeOp.getMask(),
@@ -8,9 +8,9 @@

namespace mlir::iree_compiler::IREE::GPU {

static SmallVector<int64_t> getIntegerVector(ArrayAttr array) {
static std::optional<SmallVector<int64_t>> getIntegerVector(ArrayAttr array) {
if (!array || !llvm::all_of(array.getValue(), llvm::IsaPred<IntegerAttr>)) {
return {};
return std::nullopt;
}
return llvm::map_to_vector(array.getValue(), [](Attribute s) -> int64_t {
return cast<IntegerAttr>(s).getInt();
@@ -68,6 +68,54 @@ void setSubgroupNCount(MLIRContext *context,
IntegerAttr::get(IntegerType::get(context, 64), subgroup_n_count));
}

const StringLiteral kSubgroupBasisName = "subgroup_basis";
const StringLiteral kThreadBasisName = "thread_basis";

static StringLiteral getBasisLevelName(IREE::GPU::TilingLevel level) {
switch (level) {
case GPU::TilingLevel::Thread:
return kThreadBasisName;
case GPU::TilingLevel::Subgroup:
return kSubgroupBasisName;
default:
assert(false && "Unknown tiling level for distribution");
return "";
}
}

void setBasis(MLIRContext *context, SmallVector<NamedAttribute> &attrs,
IREE::GPU::TilingLevel level, const Basis &basis) {
Builder b(context);
ArrayAttr basisAttr = b.getArrayAttr(
{b.getI64ArrayAttr(basis.counts), b.getI64ArrayAttr(basis.mapping)});
attrs.emplace_back(b.getNamedAttr(getBasisLevelName(level), basisAttr));
}

FailureOr<Basis> getBasis(IREE::GPU::LoweringConfigAttr config,
IREE::GPU::TilingLevel level) {
auto basisAttr = dyn_cast_or_null<ArrayAttr>(
config.getAttributes().get(getBasisLevelName(level)));
if (!basisAttr) {
return failure();
}

ArrayRef<Attribute> attrs = basisAttr.getValue();
if (attrs.size() != 2) {
return failure();
}

std::optional<SmallVector<int64_t>> maybeCounts =
getIntegerVector(dyn_cast_or_null<ArrayAttr>(attrs[0]));
std::optional<SmallVector<int64_t>> maybeMapping =
getIntegerVector(dyn_cast_or_null<ArrayAttr>(attrs[1]));

if (!maybeCounts.has_value() || !maybeMapping.has_value()) {
return failure();
}

return Basis{maybeCounts.value(), maybeMapping.value()};
}
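
// Round-trip sketch (hypothetical lowering config): setBasis above stores the
// basis under "subgroup_basis"/"thread_basis" as a pair of integer arrays,
// roughly [[counts...], [mapping...]]; getBasis validates that shape and
// rebuilds the struct, returning failure() on a missing or malformed
// attribute:
//   FailureOr<Basis> basis = getBasis(config, TilingLevel::Subgroup);
//   // On success, basis->counts and basis->mapping hold the two arrays.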

constexpr StringLiteral kPromoteOperandsName = "promote_operands";

std::optional<SmallVector<int64_t>>
@@ -28,6 +28,27 @@ void setSubgroupNCount(MLIRContext *context,
SmallVectorImpl<NamedAttribute> &attrs,
int64_t subgroupNCount);

// The basis consists of two integer arrays:
// - "counts": the number of resources to use per dimension of the basis.
// - "mapping": a projected permutation mapping the basis to the operation's
//   iteration space.
//
// Given a resource "x", the basis can be used to determine its position in
// the iteration space:
//
//   b   = delinearize(x, counts)
//   idx = apply(b, mapping)
struct Basis {
SmallVector<int64_t> counts;
SmallVector<int64_t> mapping;
};
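
// A worked example with illustrative numbers (not from this patch): with
// counts = [4, 8] and mapping = [1, 0], resource x = 13 gives
//   b   = delinearize(13, [4, 8]) = [1, 5]        (13 = 1 * 8 + 5)
//   idx = [b[mapping[0]], b[mapping[1]]] = [b[1], b[0]] = [5, 1]
// so resource 13 covers iteration-space position (5, 1) at this level.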

// Helpers to retrieve/set the distribution basis.
FailureOr<Basis> getBasis(IREE::GPU::LoweringConfigAttr config,
IREE::GPU::TilingLevel level);
void setBasis(MLIRContext *context, SmallVector<NamedAttribute> &attrs,
IREE::GPU::TilingLevel level, const Basis &basis);

/// Helper to retrieve/set a list of operand indices to promote.
std::optional<SmallVector<int64_t>>
getPromotedOperandList(LoweringConfigAttr config);
@@ -122,14 +122,7 @@ struct LLVMGPUCastTypeToFitMMAPass final
auto func = getOperation();

// Set MMA type from config embedded in toLayoutOp of contraction.
func.walk([&](vector::ContractionOp contract) {
inferMmaKind(contract);
if (!contract->hasAttr("iree.amdgpu.mma")) {
func.emitOpError("Failed to detect valid to_layout consumer of "
"vector.contract to infer MMA kind.");
return signalPassFailure();
}
});
func.walk([&](vector::ContractionOp contract) { inferMmaKind(contract); });

MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
@@ -685,7 +686,8 @@ static IREE::VectorExt::VectorLayoutInterface
getLayoutForMap(VectorLayoutInterface layout, AffineMap map) {
// Project out unused dims in layout.
SmallVector<bool> projectedDims(layout.getRank(), false);
for (int dim : getUnusedDimsBitVector(map).set_bits()) {
// Materialize the bit vector in a local; set_bits() returns a range that
// references it, so it must outlive the loop below.
llvm::SmallBitVector unusedBits = getUnusedDimsBitVector(map);
for (int dim : unusedBits.set_bits()) {
projectedDims[dim] = true;
}
IREE::VectorExt::VectorLayoutInterface projectedLayout =
@@ -804,6 +806,153 @@ static LogicalResult setIntrinsicLoweringConfigLayout(
return failure();
}

/// Given two arrays, `bounds` and `tile`, computes bounds /= tile
/// elementwise.
///
/// Where `tile` contains a 0, or is shorter than `bounds`, the corresponding
/// bounds are divided by 1 (i.e. left unchanged).
///
/// Returns the actual divisor used, with zeros and missing entries replaced
/// by 1. For example, bounds = [32, 128] with tile = [0, 4] updates bounds to
/// [32, 32] and returns [1, 4].
FailureOr<SmallVector<int64_t>> divideTile(SmallVector<int64_t> &bounds,
ArrayRef<int64_t> tile) {
assert(bounds.size() >= tile.size() &&
"cannot divide bounds with a larger tile size");

SmallVector<int64_t> divisor(bounds.size(), 1);
for (auto [div, size] : llvm::zip(divisor, tile)) {
if (size == 0) {
continue;
}
div = size;
}

for (auto [bound, div] : llvm::zip_equal(bounds, divisor)) {
bound /= div;
}

return divisor;
}

/// Applies the projected permutation `perm` to `input`, i.e.
/// result[i] = input[perm[i]].
SmallVector<int64_t> applyProjectedPermutation(ArrayRef<int64_t> input,
ArrayRef<int64_t> perm) {
SmallVector<int64_t> result;
result.reserve(perm.size());
for (int64_t dim : perm) {
result.push_back(input[dim]);
}
return result;
}

/// Computes row-major strides for `basis`, e.g. basis = [2, 4, 8] gives
/// strides = [32, 8, 1].
SmallVector<int64_t> getStridesFromBasis(ArrayRef<int64_t> basis) {
SmallVector<int64_t> strides(basis.size());
int64_t currStride = 1;
for (auto [stride, size] : llvm::reverse(llvm::zip_equal(strides, basis))) {
stride = currStride;
currStride *= size;
}
return strides;
}

static LogicalResult distributeTilingSizes(linalg::LinalgOp candidate,
IREE::GPU::LoweringConfigAttr config,
IREE::GPU::TilingLevel level,
SmallVector<int64_t> &bounds,
SmallVector<int64_t> &sizes,
SmallVector<int64_t> &strides) {
if (ShapedType::isDynamicShape(bounds)) {
candidate->emitError()
<< "Cannot set layouts on a dynamically shaped iteration space";
return failure();
}

FailureOr<IREE::GPU::Basis> basis = IREE::GPU::getBasis(config, level);
if (failed(basis)) {
candidate->emitError()
    << "Could not find a basis for tiling level: "
    << IREE::GPU::stringifyTilingLevel(level);
return failure();
}

sizes = applyProjectedPermutation(basis->counts, basis->mapping);
strides = applyProjectedPermutation(getStridesFromBasis(basis->counts),
basis->mapping);

if (failed(divideTile(bounds, sizes))) {
candidate->emitError()
<< "Could not divide bounds over given basis for level: "
<< IREE::GPU::stringifyTilingLevel(level);
return failure();
}

return success();
}

static LogicalResult setGPULoweringConfigLayout(
IREE::GPU::LoweringConfigAttr config, linalg::LinalgOp candidate,
ArrayRef<int64_t> workgroupSize, RewriterBase &rewriter) {
MLIRContext *context = config.getContext();
Location loc = candidate.getLoc();

SmallVector<int64_t> bounds = candidate.getStaticLoopRanges();

// Subgroup distribution layouts.
SmallVector<int64_t> subgroupSizes, subgroupStrides;
if (failed(distributeTilingSizes(candidate, config,
IREE::GPU::TilingLevel::Subgroup, bounds,
subgroupSizes, subgroupStrides))) {
return failure();
}

// Thread distribution layouts.
SmallVector<int64_t> threadSizes, threadStrides;
if (failed(distributeTilingSizes(candidate, config,
IREE::GPU::TilingLevel::Thread, bounds,
threadSizes, threadStrides))) {
return failure();
}

// Use thread tile sizes as the vector width for each thread.
SmallVector<int64_t> threadTileSizes = config.getStaticTilingLevelSizes(
llvm::to_underlying(IREE::GPU::TilingLevel::Thread), candidate);
FailureOr<SmallVector<int64_t>> elementTile =
divideTile(bounds, threadTileSizes);
if (failed(elementTile)) {
  candidate->emitError() << "Could not divide bounds over given thread tile";
  return failure();
}
// The remaining bounds become the batch sizes. We could also use the
// subgroup tile sizes as a way of specifying the batch size, but since it
// is a derived property, we choose to compute it here.
ArrayRef<int64_t> batchTile = bounds;
SmallVector<int64_t> outerTile(bounds.size(), 1);

auto layout = IREE::VectorExt::NestedLayoutAttr::get(
context, subgroupSizes, batchTile, outerTile, threadSizes,
elementTile.value(), subgroupStrides, threadStrides);

SmallVector<bool> promotedOperands = getPromotedOperands(candidate);

rewriter.setInsertionPoint(candidate);
for (OpOperand &operand : candidate->getOpOperands()) {
VectorLayoutInterface operandLayout =
getLayoutForMap(layout, candidate.getMatchingIndexingMap(&operand));
auto toLayout =
rewriter.create<ToLayoutOp>(loc, operand.get(), operandLayout);
// Set shared memory promotion if requested.
toLayout.setSharedMemoryConversion(
promotedOperands[operand.getOperandNumber()]);
operand.set(toLayout);
}

rewriter.setInsertionPointAfter(candidate);
for (OpResult result : candidate->getResults()) {
VectorLayoutInterface resultLayout =
getLayoutForMap(layout, candidate.getIndexingMapMatchingResult(result));
auto toLayout = rewriter.create<ToLayoutOp>(loc, result, resultLayout);
rewriter.replaceAllUsesExcept(result, toLayout, toLayout);
}

return success();
}
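
// End-to-end sketch of the arithmetic above, with illustrative sizes only
// (none of these numbers come from this patch). For bounds = [1, 4096],
// e.g. a matvec-like reduction:
//   subgroup basis counts = [1, 4], mapping = [0, 1]
//     -> subgroupSizes = [1, 4], subgroupStrides = [4, 1], bounds -> [1, 1024]
//   thread basis counts = [1, 64], mapping = [0, 1]
//     -> threadSizes = [1, 64], threadStrides = [64, 1], bounds -> [1, 16]
//   thread tile sizes = [0, 8]
//     -> elementTile = [1, 8], bounds -> [1, 2] = batchTile; outerTile = [1, 1]
// Sanity check on the reduction dim: 4 * 2 * 1 * 64 * 8 = 4096.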

static Operation *getOpWithAttr(Operation *root, StringRef attr) {
Operation *result = nullptr;
WalkResult walkResult = root->walk([&](Operation *op) {
@@ -900,10 +1049,8 @@ struct LLVMGPUConfigureTensorLayoutsPass final
return setIntrinsicLoweringConfigLayout(
config, candidate, workgroupSize, rewriter);
}
candidate->emitError() << "Unable to set layouts on operation "
"based on given lowering config: "
<< config;
return failure();
return setGPULoweringConfigLayout(config, candidate,
workgroupSize, rewriter);
})
.Default([](Attribute) -> LogicalResult { return failure(); });

@@ -28,6 +28,7 @@ iree_lit_test_suite(
"pipeline_igemm_tile_and_fuse.mlir",
"pipeline_tile_and_fuse.mlir",
"pipeline_vector_distribute_gfx942.mlir",
"pipeline_vector_distribute_reduction_gfx942.mlir",
"pipeline_vector_distribute_gfx1100.mlir",
"pipeline_warp_reduction.mlir",
],
@@ -25,6 +25,7 @@ iree_lit_test_suite(
"pipeline_tile_and_fuse.mlir"
"pipeline_vector_distribute_gfx1100.mlir"
"pipeline_vector_distribute_gfx942.mlir"
"pipeline_vector_distribute_reduction_gfx942.mlir"
"pipeline_warp_reduction.mlir"
TOOLS
FileCheck