Skip to content

Commit

Permalink
[LinalgExt] Add iree_linalg_ext.im2col op and verifier (#17644)
Browse files Browse the repository at this point in the history
This PR adds a new iree_linalg_ext.im2col op representing the im2col
transformation for convolutions. The PR contains the op definition and
the verifier.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Jun 24, 2024
1 parent e41e71c commit 1f69b85
Show file tree
Hide file tree
Showing 5 changed files with 450 additions and 15 deletions.
177 changes: 162 additions & 15 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
Expand Down Expand Up @@ -624,24 +625,30 @@ areNotFullTiles(ArrayRef<int64_t> inputShape,
return false;
}

static SmallVector<OpFoldResult> getMixedValues(MLIRContext *context,
ArrayRef<int64_t> staticValues,
OperandRange dynamicValues) {
OpBuilder b(context);
return mlir::getMixedValues(staticValues, dynamicValues, b);
}

static SmallVector<int64_t>
getStaticValues(SmallVector<OpFoldResult> mixedValues) {
SmallVector<Value> dynamicTiles;
SmallVector<int64_t> staticTiles;
dispatchIndexOpFoldResults(mixedValues, dynamicTiles, staticTiles);
return staticTiles;
}

/// Utility function shared between Pack and UnPack to get the tile sizes as
/// OpFoldResults.
// TODO: interface or base class in .td
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTiles(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
SmallVector<OpFoldResult> mixedInnerTiles;
unsigned dynamicValIndex = 0;
OpBuilder b(op.getContext());
for (int64_t tileSize : op.getStaticInnerTiles()) {
if (!ShapedType::isDynamic(tileSize)) {
mixedInnerTiles.push_back(b.getIndexAttr(tileSize));
} else {
mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
}
}
return mixedInnerTiles;
return LinalgExt::getMixedValues(op.getContext(), op.getStaticInnerTiles(),
op.getInnerTiles());
}

/// Return the tile sizes as `int64_t`. If a tile size is dynamic a sentinel
Expand All @@ -650,10 +657,7 @@ template <typename OpTy>
static SmallVector<int64_t> getStaticTiles(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
SmallVector<Value> dynamicTiles;
SmallVector<int64_t> staticTiles;
dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
return staticTiles;
return getStaticValues(op.getMixedTiles());
}

/// Utility function shared between Pack and UnPack to get a map between
Expand Down Expand Up @@ -1502,6 +1506,148 @@ SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}

//===----------------------------------------------------------------------===//
// Im2colOp
//===----------------------------------------------------------------------===//

/// Return all static and dynamic kernel_size as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedKernelSize() {
return LinalgExt::getMixedValues(getContext(), getStaticKernelSize(),
getKernelSize());
}

/// Return all static and dynamic k_offset as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedKOffset() {
return LinalgExt::getMixedValues(getContext(), getStaticKOffset(),
getKOffset());
}

/// Return all static and dynamic k_offset as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedMOffset() {
return LinalgExt::getMixedValues(getContext(), getStaticMOffset(),
getMOffset());
}

void Im2colOp::setMixedKOffset(SmallVector<OpFoldResult> kOffset) {
SmallVector<int64_t> staticKOffset;
SmallVector<Value> dynamicKOffset;
dispatchIndexOpFoldResults(kOffset, dynamicKOffset, staticKOffset);
setStaticKOffset(staticKOffset);
getKOffsetMutable().assign(dynamicKOffset);
}

void Im2colOp::setMixedMOffset(SmallVector<OpFoldResult> mOffset) {
SmallVector<int64_t> staticMOffset;
SmallVector<Value> dynamicMOffset;
dispatchIndexOpFoldResults(mOffset, dynamicMOffset, staticMOffset);
setStaticMOffset(staticMOffset);
getMOffsetMutable().assign(dynamicMOffset);
}

/// Custom builder methods for im2col op.
void Im2colOp::build(OpBuilder &builder, OperationState &state, Value input,
Value output, ArrayRef<int64_t> strides,
ArrayRef<int64_t> dilations,
ArrayRef<OpFoldResult> kernelSize,
ArrayRef<OpFoldResult> kOffset,
ArrayRef<OpFoldResult> mOffset, ArrayRef<int64_t> batchPos,
ArrayRef<int64_t> mPos, ArrayRef<int64_t> kPos) {
assert(strides.size() == kernelSize.size() &&
dilations.size() == kernelSize.size() &&
mPos.size() == kernelSize.size() &&
"strides, dilations, m_pos, and kernel expected to be the same rank");
SmallVector<int64_t> staticKernelSize, staticMOffset, staticKOffset;
SmallVector<Value> dynamicKernelSize, dynamicMOffset, dynamicKOffset;
dispatchIndexOpFoldResults(kernelSize, dynamicKernelSize, staticKernelSize);
dispatchIndexOpFoldResults(mOffset, dynamicMOffset, staticMOffset);
dispatchIndexOpFoldResults(kOffset, dynamicKOffset, staticKOffset);
SmallVector<Type> resultType;
auto outputType = output.getType();
if (isa<RankedTensorType>(outputType)) {
resultType.push_back(outputType);
}
build(builder, state, resultType, input, output,
builder.getDenseI64ArrayAttr(strides),
builder.getDenseI64ArrayAttr(dilations), dynamicKernelSize,
builder.getDenseI64ArrayAttr(staticKernelSize), dynamicKOffset,
builder.getDenseI64ArrayAttr(staticKOffset), dynamicMOffset,
builder.getDenseI64ArrayAttr(staticMOffset),
builder.getDenseI64ArrayAttr(batchPos),
builder.getDenseI64ArrayAttr(mPos), builder.getDenseI64ArrayAttr(kPos));
}

LogicalResult Im2colOp::verify() {
Operation *op = getOperation();
if (llvm::count_if(getDpsInputs(), [](Value v) {
return isa<ShapedType>(v.getType());
}) != 1) {
return op->emitOpError("expected only one ShapedType operand");
}
if (getNumDpsInits() != 1) {
return op->emitOpError("expected one output operand");
}

// TODO(Max191): Support cases with more than 1 m or k dimension, and remove
// the check for a single m_offset and k_offset.
if (getMixedMOffset().size() != 1) {
return op->emitOpError("expected one m_offset");
}
if (getMixedKOffset().size() != 1) {
return op->emitOpError("expected one k_offset");
}
auto inputType = getInputType();
unsigned inputRank = inputType.getRank();
ArrayRef<int64_t> batchPos = getBatchPos();
ArrayRef<int64_t> mPos = getMPos();
ArrayRef<int64_t> kPos = getKPos();
if (inputRank != batchPos.size() + mPos.size() + kPos.size()) {
return op->emitOpError(
"expected input rank to be the sum of batch, m, and k ranks");
}
ArrayRef<int64_t> strides = getStrides();
ArrayRef<int64_t> dilations = getDilations();
SmallVector<OpFoldResult> kernelSize = getMixedKernelSize();
if (kernelSize.size() != mPos.size()) {
return op->emitOpError(
"expected kernel rank to be equal to the m_pos rank");
}
if (strides.size() != kernelSize.size()) {
return op->emitOpError(
"expected strides rank to be equal to the kernel rank");
}
if (dilations.size() != kernelSize.size()) {
return op->emitOpError(
"expected dilations rank to be equal to the kernel rank");
}

ArrayRef<int64_t> inputShape = inputType.getShape();
SmallVector<int64_t> expectedOutputShape;
for (auto pos : batchPos) {
expectedOutputShape.push_back(inputShape[pos]);
}
ArrayRef<int64_t> outputShape = getOutputType().getShape();
// When the op is tiled, the m and k dimensions of the output are tiled, but
// they are not tiled in the input, so we cannot verify the output size of
// these dimensions.
expectedOutputShape.push_back(outputShape[outputShape.size() - 2]);
expectedOutputShape.push_back(outputShape.back());
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
return success();
}

LogicalResult Im2colOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}

LogicalResult
Im2colOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}

#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
Expand All @@ -1522,6 +1668,7 @@ DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(OnlineAttentionOp)
DEFINE_OP_GET_EFFECTS(Im2colOp)

} // namespace mlir::iree_compiler::IREE::LinalgExt

Expand Down
128 changes: 128 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,134 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
}
}];
}
//===----------------------------------------------------------------------===//
// Im2col
//===----------------------------------------------------------------------===//

def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "Im2col operation for convolutions";
let description = [{
Im2col op for convolutions. The operation performs a transformation on the
input to convert it from a convolution input to an equivalent gemm input.
The op is defined by its input, output, some conv metadata, and some
indexing metadata. The `strides`, `dilations`, and `kernel_size` are taken
from the convolution from which this op is generated, and they define how
the input operand is indexed when the operation is decomposed. The shape of
the output should be `tensor<BxMxK>`, and the `m_pos`, `k_pos`, and
`batch_pos` indicate which input dimensions map to which output dimensions.

The `k_offset` is an offset within the output K dimension from which the
iteration space of the operation begins. This is used for tiling, since the
tiled implementation must leave the output K dimension untiled. Similarly,
`m_offset` is the offset within the output M dimension from which the
iteration space of the operation begins.
The iteration space is the full output shape of the im2col op, so if the
im2col op were tiled to loops with a scalar inner tile, it would look like
the following:
```
%im2col = iree_linalg_ext.im2col
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
m_offset = [0] k_offset = [0]
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%in : tensor<2x34x34x640xf32>)
outs(%out : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
```
becomes:
```
scf.for %arg0 = %c0 to %c2 step %c1
scf.for %arg1 = %c0 to %c1024 step %c1
scf.for %arg2 = %c0 to %c5760 step %c1
%im2col = iree_linalg_ext.im2col
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
m_offset = [%arg1] k_offset = [%arg2]
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%in_tile : tensor<1x34x34x640xf32>)
outs(%out_tile : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
```
Then, when the tiled op is decomposed, it becomes a loop over the iteration
space of the im2col op, whith an extract_slice from the `%in_tile` followed
by an insert_slice to the `%out_tile`. The indices for the extract slice are
computed using the `m_offset` and `k_offset` as:
(b, m, k) -> (b, M / 32 + K / (640*3), M % 32 + K % (640*3) / 640, K % 640)
Where `(b, m, k)` are the indices of the tiled op's iteration space, and
`M = m + m_offset` and `K = k + K_offset`.
}];

let arguments = (ins AnyShaped:$input, AnyShaped:$output,
DenseI64ArrayAttr:$strides,
DenseI64ArrayAttr:$dilations,
Variadic<Index>:$kernel_size,
DenseI64ArrayAttr:$static_kernel_size,
Variadic<Index>:$m_offset,
DenseI64ArrayAttr:$static_m_offset,
Variadic<Index>:$k_offset,
DenseI64ArrayAttr:$static_k_offset,
DenseI64ArrayAttr:$batch_pos,
DenseI64ArrayAttr:$m_pos,
DenseI64ArrayAttr:$k_pos);

let results = (outs Variadic<AnyShaped>:$results);
let hasFolder = 1;
let assemblyFormat = [{
attr-dict
`strides` `=` $strides
`dilations` `=` $dilations
`kernel_size` `=`
custom<DynamicIndexList>($kernel_size, $static_kernel_size)
`m_offset` `=`
custom<DynamicIndexList>($m_offset, $static_m_offset)
`k_offset` `=`
custom<DynamicIndexList>($k_offset, $static_k_offset)
`batch_pos` `=` $batch_pos
`m_pos` `=` $m_pos
`k_pos` `=` $k_pos
`ins` `(` $input `:` type($input) `)`
`outs` `(` $output `:` type($output) `)`
(`->` type($results)^)?
}];

let builders = [
OpBuilder<(ins "Value":$input, "Value":$output,
"ArrayRef<int64_t>":$strides,
"ArrayRef<int64_t>":$dilations,
"ArrayRef<OpFoldResult>":$kernel_size,
"ArrayRef<OpFoldResult>":$m_offset,
"ArrayRef<OpFoldResult>":$k_offset,
"ArrayRef<int64_t>":$batch_dimensions,
"ArrayRef<int64_t>":$m_dimensions,
"ArrayRef<int64_t>":$k_dimensions)>
];

let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
ShapedType getInputType() {
return cast<ShapedType>(getInput().getType());
}
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getInputRank() {
return getInputType().getRank();
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
// Return op metadata.
SmallVector<OpFoldResult> getMixedKernelSize();
SmallVector<OpFoldResult> getMixedMOffset();
SmallVector<OpFoldResult> getMixedKOffset();

// Set op metadata.
void setMixedKOffset(SmallVector<OpFoldResult> kOffset);
void setMixedMOffset(SmallVector<OpFoldResult> mOffset);

// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputMutable();
}
}];
}

} // OpGroupNonStructuredOps

Expand Down
Loading

0 comments on commit 1f69b85

Please sign in to comment.