[VectorDistribution] Configure contraction layouts at linalg level (#18152)

This patch moves layout anchoring for contractions to the linalg level. This
is primarily motivated by allowing us to decide layouts based on the
lowering_config.

Future patches will also move the transfer_read anchoring to the linalg
level.
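For context, a layout anchor here is an iree_vector_ext.to_layout op carrying a nested layout. The following minimal sketch is adapted from the updated gpu_vector_alloc.mlir test in this change; the concrete layout values are only illustrative:

#layout = #iree_vector_ext.nested_layout<
  subgroups_per_workgroup = [1, 1],
  batches_per_subgroup = [1, 1],
  outers_per_batch = [1, 1],
  threads_per_outer = [4, 16],
  elements_per_thread = [4, 1],
  subgroup_strides = [1, 1],
  thread_strides = [0, 0]
>

func.func @anchor_example(%v: vector<16x16xf16>) -> vector<16x16xf16> {
  // Pin %v to #layout; the shared_memory_conversion unit attribute additionally
  // requests staging through workgroup shared memory (see GPUVectorAlloc below).
  %0 = iree_vector_ext.to_layout %v to #layout {shared_memory_conversion} : vector<16x16xf16>
  return %0 : vector<16x16xf16>
}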
Groverkss authored Aug 26, 2024
1 parent 7a7bfe1 commit 23c92d1
Showing 29 changed files with 585 additions and 562 deletions.
@@ -21,13 +21,72 @@ namespace {
using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

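/// Returns success if the (dim1, dim2) slice of `layout` matches the given
/// single-subgroup layout in outers, threads, thread strides, and elements
/// per thread.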
static LogicalResult isSubgroupLayoutCompatible(
IREE::GPU::MMAAttr::SingleSubgroupLayout subgroupLayout,
NestedLayoutAttr layout, int64_t dim1, int64_t dim2) {
SmallVector<int64_t> element = {layout.getElementsPerThread()[dim1],
layout.getElementsPerThread()[dim2]};
SmallVector<int64_t> thread = {layout.getThreadsPerOuter()[dim1],
layout.getThreadsPerOuter()[dim2]};
SmallVector<int64_t> tstrides = {layout.getThreadStrides()[dim1],
layout.getThreadStrides()[dim2]};
SmallVector<int64_t> outer = {layout.getOutersPerBatch()[dim1],
layout.getOutersPerBatch()[dim2]};

if (subgroupLayout.element != element) {
return failure();
}
if (subgroupLayout.thread != thread) {
return failure();
}
if (subgroupLayout.tstrides != tstrides) {
return failure();
}
if (subgroupLayout.outer != outer) {
return failure();
}

return success();
}

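/// Returns success if the lhs, rhs, and acc layouts of the contraction match
/// the A, B, and C single-subgroup layouts of the given MMA intrinsic.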
static LogicalResult isIntrinsicLayoutCompatible(VectorContractOpInfo &opInfo,
IREE::GPU::MMAAttr intrinsic,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout,
NestedLayoutAttr accLayout) {
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
auto [accM, accN] = opInfo.getResultMNIndex();
if (failed(isSubgroupLayoutCompatible(intrinsic.getASingleSubgroupLayout(),
lhsLayout, lhsM, lhsK))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(intrinsic.getBSingleSubgroupLayout(),
rhsLayout, rhsK, rhsN))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(intrinsic.getCSingleSubgroupLayout(),
accLayout, accM, accN))) {
return failure();
}
return success();
}

/// Distributes `vector.contract` ops with nested layouts.
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
// Infer the contract kind so that we know how to correlate M/N/K dims.
auto maybeOpDetail = VectorContractOpInfo::inferFromIndexingMaps(
contractOp.getIndexingMapsArray());
if (failed(maybeOpDetail)) {
return rewriter.notifyMatchFailure(contractOp, "invalid contraction");
}
VectorContractOpInfo opDetail = maybeOpDetail.value();

auto resultType = dyn_cast<VectorType>(contractOp.getResultType());
if (!resultType) {
return rewriter.notifyMatchFailure(
@@ -55,6 +114,12 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction rhs");
}
NestedLayoutAttr accLayout =
dyn_cast<NestedLayoutAttr>(signature[resultValue]);
if (!accLayout) {
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction acc");
}

// We assume a decision has already been made about which mfma intrinsic to
// use, and that it is attached as an attribute to this contract op.
@@ -65,8 +130,13 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
contractOp, "missing iree.amdgpu.mma intrinsic attribute");
}

// Infer the contract kind so that we know how to correlate M/N/K dims.
VectorContractOpInfo opDetail(contractOp);
// Check if the given intrinsic can be distributed with the given
// layouts.
if (failed(isIntrinsicLayoutCompatible(opDetail, mmaAttr, lhsLayout,
rhsLayout, accLayout))) {
return rewriter.notifyMatchFailure(
contractOp, "the intrinsic does not match the expected layouts");
}

SmallVector<int64_t> distShape = resultLayout.getDistributedShape();
LLVM_DEBUG({
62 changes: 29 additions & 33 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
@@ -5,6 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
#include "llvm/Support/Debug.h"
@@ -80,12 +82,7 @@ static FailureOr<Value> allocateTensorForVector(OpBuilder &b, Location loc,
Value copied = b.create<vector::TransferWriteOp>(loc, vector, allocTensorOp,
indices, inBounds)
.getResult();
// Create a marker for bufferization to keep this tensor in place. This
// prevents read/write forwarding of the transfers used to do the copy.
return b
.create<bufferization::MaterializeInDestinationOp>(copied.getLoc(),
copied, copied)
->getResult(0);
return copied;
}

static Value readVectorFromTensor(OpBuilder &b, VectorType vectorType,
@@ -104,46 +101,45 @@ struct GPUVectorAllocPass final
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();

SmallVector<vector::ContractionOp> opsToPromote;
funcOp.walk([&](vector::ContractionOp op) {
// Today we only do promotion for certain contractions.
if (contractOpFilter(op))
SmallVector<IREE::VectorExt::ToLayoutOp> opsToPromote;
funcOp.walk([&](IREE::VectorExt::ToLayoutOp op) {
if (op.getSharedMemoryConversion()) {
opsToPromote.push_back(op);
}
});
for (vector::ContractionOp contractOp : opsToPromote) {
OpBuilder builder(contractOp);

for (IREE::VectorExt::ToLayoutOp op : opsToPromote) {
OpBuilder builder(op);

// HACK: Until proper barrier placement is handled later, we have to
// synchronize explicitly in this pass.

// Synchronize before the write to shared memory to avoid stepping over
// reads in the previous iteration of a loop.
builder.create<gpu::BarrierOp>(contractOp->getLoc());
// reads in the previous iteration of a loop. We set this barrier
// at the start of this block.
builder.setInsertionPointToStart(op->getBlock());
builder.create<gpu::BarrierOp>(op->getLoc());

// Promote both of the input operands, excluding the accumulator.
OpOperand &lhs = contractOp.getLhsMutable();
FailureOr<Value> lhsRet =
allocateTensorForVector(builder, contractOp->getLoc(), lhs.get());
if (failed(lhsRet)) {
return signalPassFailure();
}

OpOperand &rhs = contractOp.getRhsMutable();
FailureOr<Value> rhsRet =
allocateTensorForVector(builder, contractOp->getLoc(), rhs.get());
if (failed(rhsRet)) {
builder.setInsertionPoint(op);
OpOperand &operand = op.getInputMutable();
FailureOr<Value> ret =
allocateTensorForVector(builder, op->getLoc(), operand.get());
if (failed(ret)) {
return signalPassFailure();
}

// Synchronize after the write to shared memory before we read from it.
builder.create<gpu::BarrierOp>(contractOp->getLoc());

Value lhsVec =
readVectorFromTensor(builder, contractOp.getLhsType(), *lhsRet);
Value rhsVec =
readVectorFromTensor(builder, contractOp.getRhsType(), *rhsRet);
lhs.set(lhsVec);
rhs.set(rhsVec);
auto synced =
builder.create<IREE::GPU::ValueBarrierOp>(op->getLoc(), *ret);

VectorType inputTy = cast<VectorType>(op.getType());
Value read = readVectorFromTensor(builder, inputTy, synced.getResult(0));
operand.set(read);

// Remove the shared_memory_conversion attribute from the to_layout
// operation.
op.setSharedMemoryConversion(false);
}
}
};
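For orientation, here is a rough sketch of the IR shape this rewrite produces for a to_layout op tagged with shared_memory_conversion, mirroring the CHECK lines in the updated gpu_vector_alloc.mlir test below. It assumes #layout is defined as in the earlier sketch, and the exact printed form of iree_gpu.value_barrier as well as the zero indices and padding value are illustrative rather than verbatim compiler output:

func.func @promote_to_shared(%v: vector<16x16xf16>) -> vector<16x16xf16> {
  // Barrier placed at the start of the block, before the write to shared memory.
  gpu.barrier
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f16
  // Stage the vector through a workgroup-memory tensor.
  %alloc = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16x16xf16, #gpu.address_space<workgroup>>
  %w = vector.transfer_write %v, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, tensor<16x16xf16, #gpu.address_space<workgroup>>
  // Synchronize the written tensor before reading it back.
  %b = iree_gpu.value_barrier %w : tensor<16x16xf16, #gpu.address_space<workgroup>>
  %r = vector.transfer_read %b[%c0, %c0], %pad {in_bounds = [true, true]} : tensor<16x16xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
  // The original to_layout remains, with shared_memory_conversion cleared.
  %out = iree_vector_ext.to_layout %r to #layout : vector<16x16xf16>
  return %out : vector<16x16xf16>
}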
5 changes: 4 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -208,7 +208,10 @@ def GPUVectorAllocPass :
let summary = "Pass to create allocations for contraction inputs to copy "
"to GPU shared memory";
let dependentDialects = [
"::mlir::gpu::GPUDialect", "::mlir::bufferization::BufferizationDialect"
"::mlir::gpu::GPUDialect",
"::mlir::vector::VectorDialect",
"::mlir::bufferization::BufferizationDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
];
}

@@ -455,7 +455,7 @@ builtin.module attributes { transform.with_named_sequence } {
elements_per_thread = [1, 4],

subgroup_strides = [1, 1],
thread_strides = [32, 1]
thread_strides = [1, 32]
>

// C: shape = 32x64, layout = layoutC
@@ -1,41 +1,26 @@
// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-vector-alloc))" | FileCheck %s

func.func @matmul_256x256x256(%lhs: tensor<16x256xf16>,
%rhs: tensor<256x16xf16>,
%out: tensor<16x16xf32>) -> tensor<16x16xf32> {
%cst = arith.constant 0.000000e+00 : f16
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf32>
%c32 = arith.constant 32 : index
%c256 = arith.constant 256 : index
%c0 = arith.constant 0 : index
%8 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_0) -> (vector<16x16xf32>) {
%10 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : tensor<16x256xf16>, vector<16x32xf16>
%11 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true]} : tensor<256x16xf16>, vector<32x16xf16>
%12 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %10, %11, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
scf.yield %12 : vector<16x16xf32>
}
%9 = vector.transfer_write %8, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32>
return %9 : tensor<16x16xf32>
}

#layout = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [1, 1],
batches_per_subgroup = [1, 1],
outers_per_batch = [1, 1],
threads_per_outer = [4, 16],
elements_per_thread = [4, 1],

// CHECK-LABEL: func.func @matmul_256x256x256
// CHECK: scf.for {{.*}} -> (vector<16x16xf32>) {
// CHECK-DAG: %[[A:.*]] = vector.transfer_read %{{.*}} : tensor<16x256xf16>, vector<16x32xf16>
// CHECK-DAG: %[[B:.*]] = vector.transfer_read %{{.*}} : tensor<256x16xf16>, vector<32x16xf16>
// CHECK: gpu.barrier
subgroup_strides = [1, 1],
thread_strides = [0, 0]
>

// LHS copy.
// CHECK: %[[PA:.*]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16x32xf16, #gpu.address_space<workgroup>>
// CHECK: %[[LWRITE:.+]] = vector.transfer_write %[[A]], %[[PA]]{{.*}} : vector<16x32xf16>, tensor<16x32xf16, #gpu.address_space<workgroup>>
// CHECK: %[[LCOPY:.+]] = bufferization.materialize_in_destination %[[LWRITE]] in %[[LWRITE]]
func.func @test(%vector: vector<16x16xf16>) -> vector<16x16xf16> {
%out = iree_vector_ext.to_layout %vector to #layout {shared_memory_conversion} : vector<16x16xf16>
return %out : vector<16x16xf16>
}

// RHS copy.
// CHECK: %[[PB:.*]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<32x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[RWRITE:.+]] = vector.transfer_write %[[B]], %[[PB]]{{.*}} : vector<32x16xf16>, tensor<32x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[RCOPY:.+]] = bufferization.materialize_in_destination %[[RWRITE]] in %[[RWRITE]]
// CHECK: gpu.barrier

// CHECK: %[[LHS:.+]] = vector.transfer_read %[[LCOPY]]{{.*}} : tensor<16x32xf16, #gpu.address_space<workgroup>>, vector<16x32xf16>
// CHECK: %[[RHS:.+]] = vector.transfer_read %[[RCOPY]]{{.*}} : tensor<32x16xf16, #gpu.address_space<workgroup>>, vector<32x16xf16>
// CHECK: %12 = vector.contract {{.*}} %[[LHS]], %[[RHS]], %{{.*}}
// CHECK-LABEL: func.func @test
// CHECK: gpu.barrier
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[WRITE:.+]] = vector.transfer_write %{{.*}}, %[[ALLOC]]
// CHECK: %[[BAR:.+]] = iree_gpu.value_barrier %[[WRITE]]
// CHECK: %[[READ:.+]] = vector.transfer_read %[[BAR]]
// CHECK: %[[OUT:.+]] = iree_vector_ext.to_layout %[[READ]]
@@ -214,8 +214,8 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict(
Value input = opOperand.get();
// Create a resolution operation. This conflict should be handled later by
// someone else, not this analysis.
Operation *resolveOp = builder.create<IREE::VectorExt::ToLayoutOp>(
input.getLoc(), input.getType(), input, rhs);
Operation *resolveOp =
builder.create<IREE::VectorExt::ToLayoutOp>(input.getLoc(), input, rhs);
Value resolvedValue = resolveOp->getResult(0);
opOperand.set(resolvedValue);

24 changes: 17 additions & 7 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -566,7 +566,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 16},
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 0},
/*element=*/{1, 16}};
}
}
@@ -598,7 +598,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{0, 1},
/*element=*/{16, 1}};
}
}
@@ -624,7 +624,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
/*element=*/{1, 1}};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*strides=*/{0, 1},
/*element=*/{1, 1}};
}
}
@@ -977,8 +977,14 @@ NestedLayoutAttr createNestedLayout(MLIRContext *context, int64_t rank,
FailureOr<std::tuple<VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface>>
MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
VectorContractOpInfo opInfo(contractOp);
MMAScheduleAttr::getContractionLayout(linalg::LinalgOp contractOp) const {
auto maybeOpInfo = VectorContractOpInfo::inferFromIndexingMaps(
contractOp.getIndexingMapsArray());
if (failed(maybeOpInfo)) {
return failure();
}
VectorContractOpInfo opInfo = maybeOpInfo.value();

LLVM_DEBUG({
llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
llvm::errs() << "For schedule: " << *this << "\n";
@@ -991,8 +997,12 @@ MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
auto mmaAttr = llvm::cast<MMAAttr>(getIntrinsic());
MLIRContext *context = getContext();

SmallVector<int64_t> bounds;
contractOp.getIterationBounds(bounds);
SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
if (llvm::any_of(bounds,
[](int64_t x) { return x == ShapedType::kDynamic; })) {
return failure();
}

int64_t batchCount = opInfo.getBatchCount();
if (batchCount == 1 && bounds[0] != 1) {
LLVM_DEBUG({ llvm::errs() << "non-unit batch dimension\n"; });
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -251,7 +251,7 @@ def IREEGPU_MmaScheduleAttr : AttrDef<IREEGPU_Dialect, "MMASchedule"> {
::mlir::FailureOr<::std::tuple<VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface>>
getContractionLayout(::mlir::vector::ContractionOp contractOp) const;
getContractionLayout(::mlir::linalg::LinalgOp contractOp) const;
}];
}

@@ -298,5 +298,4 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
let genVerifyDecl = 1;
}


#endif // IREE_DIALECT_VECTOREXT_ATTRS