
[VectorDistribution] Configure contraction layouts at linalg level #18152

Merged: 5 commits, Aug 26, 2024

@@ -21,13 +21,72 @@ namespace {
using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

static LogicalResult isSubgroupLayoutCompatible(
IREE::GPU::MMAAttr::SingleSubgroupLayout subgroupLayout,
NestedLayoutAttr layout, int64_t dim1, int64_t dim2) {
SmallVector<int64_t> element = {layout.getElementsPerThread()[dim1],
layout.getElementsPerThread()[dim2]};
SmallVector<int64_t> thread = {layout.getThreadsPerOuter()[dim1],
layout.getThreadsPerOuter()[dim2]};
SmallVector<int64_t> tstrides = {layout.getThreadStrides()[dim1],
layout.getThreadStrides()[dim2]};
SmallVector<int64_t> outer = {layout.getOutersPerBatch()[dim1],
layout.getOutersPerBatch()[dim2]};

if (subgroupLayout.element != element) {
return failure();
}
if (subgroupLayout.thread != thread) {
return failure();
}
if (subgroupLayout.tstrides != tstrides) {
return failure();
}
if (subgroupLayout.outer != outer) {
return failure();
}

return success();
}

static LogicalResult isIntrinsicLayoutCompatible(VectorContractOpInfo &opInfo,
IREE::GPU::MMAAttr intrinsic,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout,
NestedLayoutAttr accLayout) {
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
auto [accM, accN] = opInfo.getResultMNIndex();
if (failed(isSubgroupLayoutCompatible(intrinsic.getASingleSubgroupLayout(),
lhsLayout, lhsM, lhsK))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(intrinsic.getBSingleSubgroupLayout(),
rhsLayout, rhsK, rhsN))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(intrinsic.getCSingleSubgroupLayout(),
accLayout, accM, accN))) {
return failure();
}
return success();
}

/// Distributes `vector.contract` ops with nested layouts.
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
// Infer the contract kind so that we know how to correlate M/N/K dims.
auto maybeOpDetail = VectorContractOpInfo::inferFromIndexingMaps(
contractOp.getIndexingMapsArray());
if (failed(maybeOpDetail)) {
return rewriter.notifyMatchFailure(contractOp, "invalid contraction");
}
VectorContractOpInfo opDetail = maybeOpDetail.value();

auto resultType = dyn_cast<VectorType>(contractOp.getResultType());
if (!resultType) {
return rewriter.notifyMatchFailure(
@@ -55,6 +114,12 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction rhs");
}
NestedLayoutAttr accLayout =
dyn_cast<NestedLayoutAttr>(signature[resultValue]);
if (!accLayout) {
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction acc");
}

// We assume a decision has already been made regarding which mfma intrinsic
// to use, and that it is attached as an attribute to this contract op.
@@ -65,8 +130,13 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
contractOp, "missing iree.amdgpu.mma intrinsic attribute");
}

// Infer the contract kind so that we know know to correlate M/N/K dims.
VectorContractOpInfo opDetail(contractOp);
// Check if the given intrinsic can be distributed with the given
// layouts.
if (failed(isIntrinsicLayoutCompatible(opDetail, mmaAttr, lhsLayout,
rhsLayout, accLayout))) {
return rewriter.notifyMatchFailure(
contractOp, "the intrinsic does not match the expected layouts");
}

SmallVector<int64_t> distShape = resultLayout.getDistributedShape();
LLVM_DEBUG({
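
The helpers added above gate distribution on whether the anchoring nested layouts actually match the chosen intrinsic. As a rough, hypothetical illustration (the attribute below is not part of this PR, and the subgroup/batch fields are assumptions), an LHS (M x K) layout that isSubgroupLayoutCompatible would accept for the WMMA_F16_16x16x16_F16 A-operand layout defined later in this diff (outer = {1, 1}, thread = {16, 1}, strides = {1, 0}, element = {1, 16}) could look like:

// Hypothetical nested layout for a contraction LHS. Only the
// outers_per_batch / threads_per_outer / elements_per_thread /
// thread_strides entries for the (M, K) dims participate in the check;
// the subgroup and batch fields here are illustrative assumptions.
#lhs_layout = #iree_vector_ext.nested_layout<
  subgroups_per_workgroup = [1, 1],
  batches_per_subgroup = [1, 1],
  outers_per_batch = [1, 1],
  threads_per_outer = [16, 1],
  elements_per_thread = [1, 16],

  subgroup_strides = [1, 1],
  thread_strides = [1, 0]
>

Everything outside those four fields is left to the surrounding subgroup/batch distribution and is not compared against the intrinsic.
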
62 changes: 29 additions & 33 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
@@ -5,6 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
#include "llvm/Support/Debug.h"
@@ -80,12 +82,7 @@ static FailureOr<Value> allocateTensorForVector(OpBuilder &b, Location loc,
Value copied = b.create<vector::TransferWriteOp>(loc, vector, allocTensorOp,
indices, inBounds)
.getResult();
// Create a marker for bufferization to keep this tensor in place. This
// prevents read/write forwarding of the transfers used to do the copy.
return b
.create<bufferization::MaterializeInDestinationOp>(copied.getLoc(),
copied, copied)
->getResult(0);
return copied;
}

static Value readVectorFromTensor(OpBuilder &b, VectorType vectorType,
@@ -104,46 +101,45 @@ struct GPUVectorAllocPass final
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();

SmallVector<vector::ContractionOp> opsToPromote;
funcOp.walk([&](vector::ContractionOp op) {
// Today we only do promotion for certain contractions.
if (contractOpFilter(op))
SmallVector<IREE::VectorExt::ToLayoutOp> opsToPromote;
funcOp.walk([&](IREE::VectorExt::ToLayoutOp op) {
if (op.getSharedMemoryConversion()) {
opsToPromote.push_back(op);
}
});
for (vector::ContractionOp contractOp : opsToPromote) {
OpBuilder builder(contractOp);

for (IREE::VectorExt::ToLayoutOp op : opsToPromote) {
OpBuilder builder(op);

// HACK: Until proper barrier placement is handled later we have to
// synchronize explicitly in this pass.

// Synchronize before the write to shared memory to avoid stepping over
// reads in the previous iteration of a loop.
builder.create<gpu::BarrierOp>(contractOp->getLoc());
// reads in the previous iteration of a loop. We set this barrier
// at the start of this block.
builder.setInsertionPointToStart(op->getBlock());
builder.create<gpu::BarrierOp>(op->getLoc());

// Promote both of the input operands, excluding the accumulator.
OpOperand &lhs = contractOp.getLhsMutable();
FailureOr<Value> lhsRet =
allocateTensorForVector(builder, contractOp->getLoc(), lhs.get());
if (failed(lhsRet)) {
return signalPassFailure();
}

OpOperand &rhs = contractOp.getRhsMutable();
FailureOr<Value> rhsRet =
allocateTensorForVector(builder, contractOp->getLoc(), rhs.get());
if (failed(rhsRet)) {
builder.setInsertionPoint(op);
OpOperand &operand = op.getInputMutable();
FailureOr<Value> ret =
allocateTensorForVector(builder, op->getLoc(), operand.get());
if (failed(ret)) {
return signalPassFailure();
}

// Synchronize after the write to shared memory before we read from it.
builder.create<gpu::BarrierOp>(contractOp->getLoc());

Value lhsVec =
readVectorFromTensor(builder, contractOp.getLhsType(), *lhsRet);
Value rhsVec =
readVectorFromTensor(builder, contractOp.getRhsType(), *rhsRet);
lhs.set(lhsVec);
rhs.set(rhsVec);
auto synced =
builder.create<IREE::GPU::ValueBarrierOp>(op->getLoc(), *ret);

VectorType inputTy = cast<VectorType>(op.getType());
Value read = readVectorFromTensor(builder, inputTy, synced.getResult(0));
operand.set(read);

// Remove the shared_memory_conversion attribute from the to_layout
// operation.
op.setSharedMemoryConversion(false);
}
}
};
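
With this change the pass promotes iree_vector_ext.to_layout ops that carry the shared_memory_conversion unit attribute instead of walking vector.contract operands. A sketch of the rewrite, pieced together from the updated lit test below; SSA names are illustrative and the exact printed form of iree_gpu.value_barrier is an approximation, not verbatim pass output:

// Before: a to_layout op tagged for staging through shared memory.
%out = iree_vector_ext.to_layout %vector to #layout {shared_memory_conversion} : vector<16x16xf16>

// After (roughly): barrier, allocate a workgroup tensor, write the vector,
// fence it with iree_gpu.value_barrier, read it back, and keep the
// to_layout with the shared_memory_conversion attribute dropped.
// %vector and #layout come from the surrounding function.
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f16
gpu.barrier
%alloc = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16x16xf16, #gpu.address_space<workgroup>>
%write = vector.transfer_write %vector, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, tensor<16x16xf16, #gpu.address_space<workgroup>>
%barrier = iree_gpu.value_barrier %write : tensor<16x16xf16, #gpu.address_space<workgroup>>
%read = vector.transfer_read %barrier[%c0, %c0], %pad {in_bounds = [true, true]} : tensor<16x16xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
%out = iree_vector_ext.to_layout %read to #layout : vector<16x16xf16>
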
5 changes: 4 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -208,7 +208,10 @@ def GPUVectorAllocPass :
let summary = "Pass to create allocations for contraction inputs to copy "
"to GPU shared memory";
let dependentDialects = [
"::mlir::gpu::GPUDialect", "::mlir::bufferization::BufferizationDialect"
"::mlir::gpu::GPUDialect",
"::mlir::vector::VectorDialect",
"::mlir::bufferization::BufferizationDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
];
}

@@ -455,7 +455,7 @@ builtin.module attributes { transform.with_named_sequence } {
elements_per_thread = [1, 4],

subgroup_strides = [1, 1],
thread_strides = [32, 1]
thread_strides = [1, 32]
>

// C: shape = 32x64, layout = layoutC
@@ -1,41 +1,26 @@
// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-vector-alloc))" | FileCheck %s

func.func @matmul_256x256x256(%lhs: tensor<16x256xf16>,
%rhs: tensor<256x16xf16>,
%out: tensor<16x16xf32>) -> tensor<16x16xf32> {
%cst = arith.constant 0.000000e+00 : f16
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf32>
%c32 = arith.constant 32 : index
%c256 = arith.constant 256 : index
%c0 = arith.constant 0 : index
%8 = scf.for %arg0 = %c0 to %c256 step %c32 iter_args(%arg1 = %cst_0) -> (vector<16x16xf32>) {
%10 = vector.transfer_read %lhs[%c0, %arg0], %cst {in_bounds = [true, true]} : tensor<16x256xf16>, vector<16x32xf16>
%11 = vector.transfer_read %rhs[%arg0, %c0], %cst {in_bounds = [true, true]} : tensor<256x16xf16>, vector<32x16xf16>
%12 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %10, %11, %arg1 : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
scf.yield %12 : vector<16x16xf32>
}
%9 = vector.transfer_write %8, %out[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32>
return %9 : tensor<16x16xf32>
}

#layout = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [1, 1],
batches_per_subgroup = [1, 1],
outers_per_batch = [1, 1],
threads_per_outer = [4, 16],
elements_per_thread = [4, 1],

// CHECK-LABEL: func.func @matmul_256x256x256
// CHECK: scf.for {{.*}} -> (vector<16x16xf32>) {
// CHECK-DAG: %[[A:.*]] = vector.transfer_read %{{.*}} : tensor<16x256xf16>, vector<16x32xf16>
// CHECK-DAG: %[[B:.*]] = vector.transfer_read %{{.*}} : tensor<256x16xf16>, vector<32x16xf16>
// CHECK: gpu.barrier
subgroup_strides = [1, 1],
thread_strides = [0, 0]
>

// LHS copy.
// CHECK: %[[PA:.*]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16x32xf16, #gpu.address_space<workgroup>>
// CHECK: %[[LWRITE:.+]] = vector.transfer_write %[[A]], %[[PA]]{{.*}} : vector<16x32xf16>, tensor<16x32xf16, #gpu.address_space<workgroup>>
// CHECK: %[[LCOPY:.+]] = bufferization.materialize_in_destination %[[LWRITE]] in %[[LWRITE]]
func.func @test(%vector: vector<16x16xf16>) -> vector<16x16xf16> {
%out = iree_vector_ext.to_layout %vector to #layout {shared_memory_conversion} : vector<16x16xf16>
return %out : vector<16x16xf16>
}

// RHS copy.
// CHECK: %[[PB:.*]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<32x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[RWRITE:.+]] = vector.transfer_write %[[B]], %[[PB]]{{.*}} : vector<32x16xf16>, tensor<32x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[RCOPY:.+]] = bufferization.materialize_in_destination %[[RWRITE]] in %[[RWRITE]]
// CHECK: gpu.barrier

// CHECK: %[[LHS:.+]] = vector.transfer_read %[[LCOPY]]{{.*}} : tensor<16x32xf16, #gpu.address_space<workgroup>>, vector<16x32xf16>
// CHECK: %[[RHS:.+]] = vector.transfer_read %[[RCOPY]]{{.*}} : tensor<32x16xf16, #gpu.address_space<workgroup>>, vector<32x16xf16>
// CHECK: %12 = vector.contract {{.*}} %[[LHS]], %[[RHS]], %{{.*}}
// CHECK-LABEL: func.func @test
// CHECK: gpu.barrier
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16x16xf16, #gpu.address_space<workgroup>>
// CHECK: %[[WRITE:.+]] = vector.transfer_write %{{.*}}, %[[ALLOC]]
// CHECK: %[[BAR:.+]] = iree_gpu.value_barrier %[[WRITE]]
// CHECK: %[[READ:.+]] = vector.transfer_read %[[BAR]]
// CHECK: %[[OUT:.+]] = iree_vector_ext.to_layout %[[READ]]
@@ -214,8 +214,8 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict(
Value input = opOperand.get();
// Create a resolution operation. This conflict should be handled later by
// someone else, not this analysis.
Operation *resolveOp = builder.create<IREE::VectorExt::ToLayoutOp>(
input.getLoc(), input.getType(), input, rhs);
Operation *resolveOp =
builder.create<IREE::VectorExt::ToLayoutOp>(input.getLoc(), input, rhs);
Value resolvedValue = resolveOp->getResult(0);
opOperand.set(resolvedValue);

@@ -566,7 +566,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 16},
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 0},
/*element=*/{1, 16}};
}
}
@@ -598,7 +598,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{0, 1},
/*element=*/{16, 1}};
}
}
@@ -624,7 +624,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
/*element=*/{1, 1}};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*strides=*/{0, 1},
/*element=*/{1, 1}};
}
}
@@ -977,8 +977,14 @@ NestedLayoutAttr createNestedLayout(MLIRContext *context, int64_t rank,
FailureOr<std::tuple<VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface>>
MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
VectorContractOpInfo opInfo(contractOp);
MMAScheduleAttr::getContractionLayout(linalg::LinalgOp contractOp) const {
auto maybeOpInfo = VectorContractOpInfo::inferFromIndexingMaps(
contractOp.getIndexingMapsArray());
if (failed(maybeOpInfo)) {
return failure();
}
VectorContractOpInfo opInfo = maybeOpInfo.value();

LLVM_DEBUG({
llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n";
llvm::errs() << "For schedule: " << *this << "\n";
@@ -991,8 +997,12 @@ MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
auto mmaAttr = llvm::cast<MMAAttr>(getIntrinsic());
MLIRContext *context = getContext();

SmallVector<int64_t> bounds;
contractOp.getIterationBounds(bounds);
SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
if (llvm::any_of(bounds,
[](int64_t x) { return x == ShapedType::kDynamic; })) {
return failure();
}

int64_t batchCount = opInfo.getBatchCount();
if (batchCount == 1 && bounds[0] != 1) {
LLVM_DEBUG({ llvm::errs() << "non-unit batch dimension\n"; });
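
getContractionLayout now takes a linalg::LinalgOp and reads its static loop ranges, returning failure as soon as any range is dynamic. A minimal example (shapes chosen purely for illustration, not taken from this PR) of a contraction the new path can still handle:

// All loop ranges (M = 32, N = 128, K = 64) are static, so
// getStaticLoopRanges() yields concrete bounds for the M/N/K checks below.
func.func @static_matmul(%a: tensor<32x64xf16>, %b: tensor<64x128xf16>,
                         %c: tensor<32x128xf32>) -> tensor<32x128xf32> {
  %0 = linalg.matmul ins(%a, %b : tensor<32x64xf16>, tensor<64x128xf16>)
                     outs(%c : tensor<32x128xf32>) -> tensor<32x128xf32>
  return %0 : tensor<32x128xf32>
}

With a dynamically shaped operand such as tensor<?x?xf16>, one of the returned ranges would be ShapedType::kDynamic and the new early exit fires.
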
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -251,7 +251,7 @@ def IREEGPU_MmaScheduleAttr : AttrDef<IREEGPU_Dialect, "MMASchedule"> {
::mlir::FailureOr<::std::tuple<VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface,
VectorExt::VectorLayoutInterface>>
getContractionLayout(::mlir::vector::ContractionOp contractOp) const;
getContractionLayout(::mlir::linalg::LinalgOp contractOp) const;
}];
}

@@ -298,5 +298,4 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
let genVerifyDecl = 1;
}


#endif // IREE_DIALECT_VECTOREXT_ATTRS