[VectorDistribution] Configure contraction layouts at linalg level
Groverkss committed Aug 9, 2024
1 parent df3d588 commit 1c682aa
Showing 23 changed files with 432 additions and 376 deletions.
@@ -28,6 +28,14 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 DistributionSignature &signature,
                                 PatternRewriter &rewriter) const override {
+    // Infer the contract kind so that we know how to correlate M/N/K dims.
+    auto maybeOpDetail = VectorContractOpInfo::inferFromIndexingMaps(
+        contractOp.getIndexingMapsArray());
+    if (failed(maybeOpDetail)) {
+      return rewriter.notifyMatchFailure(contractOp, "invalid contraction");
+    }
+    VectorContractOpInfo opDetail = maybeOpDetail.value();
+
     auto resultType = dyn_cast<VectorType>(contractOp.getResultType());
     if (!resultType) {
       return rewriter.notifyMatchFailure(
@@ -65,9 +73,6 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
           contractOp, "missing iree.amdgpu.mma intrinsic attribute");
     }
 
-    // Infer the contract kind so that we know how to correlate M/N/K dims.
-    VectorContractOpInfo opDetail(contractOp);
-
     SmallVector<int64_t> distShape = resultLayout.getDistributedShape();
     LLVM_DEBUG({
       llvm::dbgs() << "distributed shape: [";
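The recurring change in this commit is replacing the asserting VectorContractOpInfo(contractOp) constructor with the failable VectorContractOpInfo::inferFromIndexingMaps factory, so patterns can bail out gracefully on non-contraction maps. A minimal sketch of what the factory consumes, with hand-built maps for a plain matmul (illustrative code, not part of the diff; it assumes the factory accepts the maps directly, as the call above suggests):

// C[m, n] += A[m, k] * B[k, n]: three maps over the (m, n, k) iteration space.
MLIRContext ctx;
AffineExpr m = getAffineDimExpr(0, &ctx);
AffineExpr n = getAffineDimExpr(1, &ctx);
AffineExpr k = getAffineDimExpr(2, &ctx);
SmallVector<AffineMap> maps = {
    AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {m, k}, &ctx),  // A
    AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {k, n}, &ctx),  // B
    AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {m, n}, &ctx)}; // C
// Correlates which operand dims play the M/N/K roles; fails instead of
// asserting when the maps do not describe a contraction.
auto maybeInfo = VectorContractOpInfo::inferFromIndexingMaps(maps);
assert(succeeded(maybeInfo) && "canonical matmul maps form a valid contraction");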
48 changes: 22 additions & 26 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
 #include "llvm/Support/Debug.h"
@@ -104,46 +105,41 @@ struct GPUVectorAllocPass final
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
 
-    SmallVector<vector::ContractionOp> opsToPromote;
-    funcOp.walk([&](vector::ContractionOp op) {
-      // Today we only do promotion for certain contractions.
-      if (contractOpFilter(op))
+    SmallVector<IREE::VectorExt::ToLayoutOp> opsToPromote;
+    funcOp.walk([&](IREE::VectorExt::ToLayoutOp op) {
+      if (op->hasAttr("shared_memory_conversion")) {
         opsToPromote.push_back(op);
+      }
     });
-    for (vector::ContractionOp contractOp : opsToPromote) {
-      OpBuilder builder(contractOp);
+
+    for (IREE::VectorExt::ToLayoutOp op : opsToPromote) {
+      OpBuilder builder(op);
 
       // HACK: Until proper barrier placement is handled later we have to
       // synchronize explicitly in this pass.
 
       // Synchronize before the write to shared memory to avoid stepping over
       // reads in the previous iteration of a loop.
-      builder.create<gpu::BarrierOp>(contractOp->getLoc());
+      builder.create<gpu::BarrierOp>(op->getLoc());
 
-      // Promote both of the input operands, excluding the accumulator.
-      OpOperand &lhs = contractOp.getLhsMutable();
-      FailureOr<Value> lhsRet =
-          allocateTensorForVector(builder, contractOp->getLoc(), lhs.get());
-      if (failed(lhsRet)) {
-        return signalPassFailure();
-      }
-
-      OpOperand &rhs = contractOp.getRhsMutable();
-      FailureOr<Value> rhsRet =
-          allocateTensorForVector(builder, contractOp->getLoc(), rhs.get());
-      if (failed(rhsRet)) {
+      OpOperand &operand = op.getInputMutable();
+      FailureOr<Value> ret =
+          allocateTensorForVector(builder, op->getLoc(), operand.get());
+      if (failed(ret)) {
         return signalPassFailure();
       }
 
       // Synchronize after the write to shared memory before we read from it.
-      builder.create<gpu::BarrierOp>(contractOp->getLoc());
-
-      Value lhsVec =
-          readVectorFromTensor(builder, contractOp.getLhsType(), *lhsRet);
-      Value rhsVec =
-          readVectorFromTensor(builder, contractOp.getRhsType(), *rhsRet);
-      lhs.set(lhsVec);
-      rhs.set(rhsVec);
+      builder.create<gpu::BarrierOp>(op->getLoc());
+
+      VectorType inputTy = cast<VectorType>(op.getType());
+      Value read = readVectorFromTensor(builder, inputTy, *ret);
+      operand.set(read);
+
+      // Remove the shared_memory_conversion attribute from the to_layout
+      // operation.
+      op->removeAttr("shared_memory_conversion");
     }
   }
 };
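The hunk above is the consumer half of an attribute handshake: the new LLVMGPUConfigureTensorLayouts pass introduced later in this commit tags to_layout ops for promotion, and GPUVectorAlloc promotes the tagged input and clears the tag. A condensed sketch of the round trip (only the attribute name and ops come from this diff; the surrounding control flow is abridged):

// Producer side (tensor layout configuration): mark the op for promotion.
layoutOp->setAttr("shared_memory_conversion", rewriter.getUnitAttr());

// Consumer side (GPUVectorAlloc): promote the input through shared memory,
// then drop the tag so it does not leak into later passes.
if (layoutOp->hasAttr("shared_memory_conversion")) {
  // ... allocateTensorForVector / readVectorFromTensor, as shown above ...
  layoutOp->removeAttr("shared_memory_conversion");
}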
@@ -957,8 +957,14 @@ NestedLayoutAttr createNestedLayout(MLIRContext *context, int64_t rank,
 FailureOr<std::tuple<VectorExt::VectorLayoutInterface,
                      VectorExt::VectorLayoutInterface,
                      VectorExt::VectorLayoutInterface>>
-MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
-  VectorContractOpInfo opInfo(contractOp);
+MMAScheduleAttr::getContractionLayout(linalg::GenericOp contractOp) const {
+  auto maybeOpInfo = VectorContractOpInfo::inferFromIndexingMaps(
+      contractOp.getIndexingMapsArray());
+  if (failed(maybeOpInfo)) {
+    return failure();
+  }
+  VectorContractOpInfo opInfo = maybeOpInfo.value();
+
   LLVM_DEBUG({
     llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
     llvm::errs() << "For schedule: " << *this << "\n";
@@ -971,8 +977,7 @@ MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
   auto mmaAttr = llvm::cast<MMAAttr>(getIntrinsic());
   MLIRContext *context = getContext();
 
-  SmallVector<int64_t> bounds;
-  contractOp.getIterationBounds(bounds);
+  SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
   int64_t batchCount = opInfo.getBatchCount();
   if (batchCount == 1 && bounds[0] != 1) {
     LLVM_DEBUG({ llvm::errs() << "non-unit batch dimension\n"; });
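Since getContractionLayout now takes a linalg.generic, the loop bounds come from getStaticLoopRanges() rather than vector::ContractionOp::getIterationBounds; both return one bound per iteration-space dimension. A comment-level illustration (sizes invented for the example):

// For a linalg.generic computing C[b, m, n] += A[b, m, k] * B[b, k, n] with
// static sizes b = 1, m = 128, n = 128, k = 64, getStaticLoopRanges() returns
// {1, 128, 128, 64}; the bounds[0] != 1 check above then accepts the unit
// batch and would reject any non-unit one.
SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();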
@@ -11,6 +11,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -253,7 +253,7 @@ def IREEGPU_MmaScheduleAttr : AttrDef<IREEGPU_Dialect, "MMASchedule"> {
     ::mlir::FailureOr<::std::tuple<VectorExt::VectorLayoutInterface,
                                    VectorExt::VectorLayoutInterface,
                                    VectorExt::VectorLayoutInterface>>
-      getContractionLayout(::mlir::vector::ContractionOp contractOp) const;
+      getContractionLayout(::mlir::linalg::GenericOp contractOp) const;
   }];
 }
 
@@ -49,6 +49,8 @@ struct VectorizeToLayoutOpPattern final
     // Create the toLayout operation but with vector types instead.
     auto newLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
         loc, newInput.getType(), newInput, toLayoutOp.getLayout());
+    // Propagate the original op's attributes to the new op.
+    newLayoutOp->setAttrs(toLayoutOp->getAttrs());
 
     // Create the write back to a tensor.
     int64_t rank = inputTy.getRank();
@@ -66,6 +66,14 @@ struct AMDGPUPrepareForChainedMatmulPass
     registry.insert<vector::VectorDialect>();
   }
 
+  VectorContractOpInfo getOpInfo(vector::ContractionOp contract) const {
+    auto maybeOpInfo = VectorContractOpInfo::inferFromIndexingMaps(
+        contract.getIndexingMapsArray());
+    assert(succeeded(maybeOpInfo) &&
+           "contraction info for vector.contract should always be valid");
+    return maybeOpInfo.value();
+  }
+
   VectorValue swapDims(RewriterBase &rewriter, VectorValue val, int64_t dimA,
                        int64_t dimB) const {
     ArrayRef<int64_t> shape = val.getType().getShape();
@@ -106,7 +114,7 @@ struct AMDGPUPrepareForChainedMatmulPass
   /// simply swap the operands without transposing them.
   void swapOperandsAndTranspose(RewriterBase &rewriter,
                                 vector::ContractionOp contractOp) const {
-    VectorContractOpInfo opInfo(contractOp);
+    VectorContractOpInfo opInfo = getOpInfo(contractOp);
     auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
     auto [lhsK, rhsK] = opInfo.getOperandKIndex();
     auto [accM, accN] = opInfo.getResultMNIndex();
@@ -174,7 +182,7 @@ struct AMDGPUPrepareForChainedMatmulPass
   bool isOperandSwapInvariant(vector::ContractionOp contractOp) const {
     // Check if the innermost m, n, k dimensions are in the order:
     // lhs: (m, k), rhs: (n, k)
-    VectorContractOpInfo opInfo(contractOp);
+    VectorContractOpInfo opInfo = getOpInfo(contractOp);
     auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
     auto [lhsK, rhsK] = opInfo.getOperandKIndex();
     bool isLhsTransposed = lhsM > lhsK;
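A worked instance of the swap-invariance check, for intuition (editor's illustration; the rest of the function is collapsed in this view):

// With innermost-dim orders
//   lhs map: (m, n, k) -> (m, k)   => lhsM = 0, lhsK = 1, lhs not transposed
//   rhs map: (m, n, k) -> (n, k)   => rhsN = 0, rhsK = 1
// the contraction computes C = A * B^T. Swapping the operands to compute the
// transposed result keeps K innermost on both sides, so no vector.transpose
// has to be materialized; a transposed lhs map (m, n, k) -> (k, m) would make
// lhsM > lhsK and defeat the invariance.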
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -95,6 +95,7 @@ iree_compiler_cc_library(
         "KernelConfig.cpp",
         "LLVMGPUCastAddressSpaceFunction.cpp",
         "LLVMGPUCastTypeToFitMMA.cpp",
+        "LLVMGPUConfigureTensorLayouts.cpp",
        "LLVMGPUConfigureVectorLayouts.cpp",
         "LLVMGPULowerExecutableTarget.cpp",
         "LLVMGPUPackSharedMemoryAlloc.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -80,6 +80,7 @@ iree_cc_library(
     "KernelConfig.cpp"
     "LLVMGPUCastAddressSpaceFunction.cpp"
     "LLVMGPUCastTypeToFitMMA.cpp"
+    "LLVMGPUConfigureTensorLayouts.cpp"
     "LLVMGPUConfigureVectorLayouts.cpp"
     "LLVMGPULowerExecutableTarget.cpp"
     "LLVMGPUPackSharedMemoryAlloc.cpp"
@@ -28,7 +28,12 @@ struct UpcastContractOutput final : OpRewritePattern<vector::ContractionOp> {
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    VectorContractOpInfo opInfo(contractOp);
+    auto maybeOpInfo = VectorContractOpInfo::inferFromIndexingMaps(
+        contractOp.getIndexingMapsArray());
+    if (failed(maybeOpInfo)) {
+      return rewriter.notifyMatchFailure(contractOp, "not a contraction");
+    }
+    VectorContractOpInfo opInfo = maybeOpInfo.value();
 
     auto srcCType = dyn_cast<VectorType>(contractOp.getAccType());
     if (!srcCType) {
@@ -66,6 +71,8 @@ struct UpcastContractOutput final : OpRewritePattern<vector::ContractionOp> {
     auto newContractOp = rewriter.create<vector::ContractionOp>(
         loc, contractOp.getLhs(), contractOp.getRhs(), extOp,
         contractOp.getIndexingMaps(), contractOp.getIteratorTypes());
+    newContractOp->setDiscardableAttrs(
+        contractOp->getDiscardableAttrDictionary());
     rewriter.replaceOpWithNewOp<arith::TruncFOp>(contractOp, srcCType,
                                                  newContractOp);
     return success();
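The two added lines exist because rewriter.create builds the replacement contraction from scratch, so discardable attributes, such as the iree.amdgpu.mma intrinsic tag that DistributeContract looks up earlier in this commit, would silently vanish across the upcast rewrite. A minimal sketch of the failure mode being fixed (attribute name from this diff, surrounding code abridged):

// Tag set upstream on the original op:
//   contractOp->setAttr("iree.amdgpu.mma", mmaIntrinsicAttr);
// Without the fix, newContractOp carries only inherent attributes and the tag
// is lost; copying the dictionary keeps every discardable attribute alive:
newContractOp->setDiscardableAttrs(contractOp->getDiscardableAttrDictionary());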
157 changes: 157 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -0,0 +1,157 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "iree-llvmgpu-configure-vector-layouts"

namespace mlir::iree_compiler {

namespace {

LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
                                   RewriterBase &rewriter,
                                   linalg::GenericOp contract) {
  // TODO: Add SIMT fallback.
  if (!schedule) {
    return contract->emitError("missing mma schedule for contraction");
  }

  // This function should only be called on a contraction op.
  assert(linalg::isaContractionOpInterface(contract) &&
         "cannot set contraction anchor on non contraction op");

  auto layouts = schedule.getContractionLayout(contract);
  if (failed(layouts)) {
    return contract->emitError("cannot get concrete layout for contraction");
  }

  auto [aLayout, bLayout, cLayout] = *layouts;
  Location loc = contract.getLoc();

  Value lhs = contract.getOperand(0);
  Value rhs = contract.getOperand(1);
  Value acc = contract.getOperand(2);

  // Set layouts for lhs, rhs and acc.
  rewriter.setInsertionPoint(contract);
  auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
      loc, lhs.getType(), lhs, aLayout);
  auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
      loc, rhs.getType(), rhs, bLayout);
  auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
      loc, acc.getType(), acc, cLayout);

  // Promote matmul lhs and rhs.
  // TODO: We should read this from the lowering_config on the operation.
  // TODO: This is a hack until layout analysis is improved. The layout
  // analysis should decide where to put these shared memory conversions.
  layoutedLhs->setAttr("shared_memory_conversion", rewriter.getUnitAttr());
  layoutedRhs->setAttr("shared_memory_conversion", rewriter.getUnitAttr());

  contract->setOperand(0, layoutedLhs.getResult());
  contract->setOperand(1, layoutedRhs.getResult());
  contract->setOperand(2, layoutedAcc.getResult());

  // Set layout for result.
  rewriter.setInsertionPointAfter(contract);
  auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
      loc, contract.getResult(0).getType(), contract.getResult(0), cLayout);
  rewriter.replaceAllUsesExcept(contract.getResult(0), toLayout.getResult(),
                                toLayout);

  return success();
}

struct LLVMGPUConfigureTensorLayoutsPass
    : public LLVMGPUConfigureTensorLayoutsBase<
          LLVMGPUConfigureTensorLayoutsPass> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::VectorExt::IREEVectorExtDialect>();
    registry.insert<vector::VectorDialect>();
  }

  void runOnOperation() override {
    auto func = getOperation();

    std::array<int64_t, 3> workgroupSize;
    if (func->hasAttr("workgroup_size")) {
      auto tmpSizes =
          llvm::cast<ArrayAttr>(func->getAttr("workgroup_size")).getValue();
      for (auto [i, size] : llvm::enumerate(tmpSizes)) {
        workgroupSize[i] = llvm::cast<IntegerAttr>(size).getInt();
      }
    } else {
      std::optional<SmallVector<int64_t>> maybeWorkgroupSize =
          getWorkgroupSize(func);
      if (!maybeWorkgroupSize) {
        func->emitOpError()
            << "unable to query workgroup_size information from entry point";
        return signalPassFailure();
      }
      for (auto [index, value] : llvm::enumerate(maybeWorkgroupSize.value())) {
        workgroupSize[index] = value;
      }
      for (auto index : llvm::seq<size_t>(maybeWorkgroupSize->size(), 3)) {
        workgroupSize[index] = 1;
      }
    }
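    // Example of the fallback above (editor's note, not in the commit): an
    // entry point with workgroup size [64, 2] yields
    // workgroupSize = {64, 2, 1}; unspecified trailing dims default to 1.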

    llvm::StringLiteral scheduleAttrName =
        IREE::GPU::MMAScheduleAttr::getMnemonic();
    auto scheduleAttr =
        func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
    if (!scheduleAttr) {
      DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
      scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
          configDict.get(scheduleAttrName));
    }

    // Vector layout option setter aimed at contractions. For now, layout
    // setting for other problems like reductions is TODO.
    SmallVector<linalg::GenericOp> contracts;

    func->walk([&](linalg::GenericOp linalgOp) {
      if (linalg::isaContractionOpInterface(linalgOp)) {
        contracts.push_back(linalgOp);
      }
    });

    IRRewriter rewriter(func);

    for (linalg::GenericOp contract : contracts) {
      if (failed(setContractionAnchor(scheduleAttr, rewriter, contract))) {
        return signalPassFailure();
      }
    }
  }
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUConfigureTensorLayouts() {
  return std::make_unique<LLVMGPUConfigureTensorLayoutsPass>();
}

} // namespace mlir::iree_compiler
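Taken together, setContractionAnchor rewrites each matmul-like linalg.generic into a layout-anchored form that the later distribution passes consume. A hand-written before/after sketch of the intended effect (editor's illustration; types, indexing maps, and the exact to_layout syntax are approximate):

// Before:
//   %d = linalg.generic ... ins(%lhs, %rhs) outs(%acc)
//
// After:
//   %lhs_l = iree_vector_ext.to_layout %lhs to #a_layout
//              {shared_memory_conversion}
//   %rhs_l = iree_vector_ext.to_layout %rhs to #b_layout
//              {shared_memory_conversion}
//   %acc_l = iree_vector_ext.to_layout %acc to #c_layout
//   %d     = linalg.generic ... ins(%lhs_l, %rhs_l) outs(%acc_l)
//   %d_l   = iree_vector_ext.to_layout %d to #c_layout
//
// with all other uses of %d redirected to %d_l (replaceAllUsesExcept).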
