Revert "[LinalgExt] Add online_attention op" (#17658)
Reverts #17536

This caused `sdxl-scheduled-unet-3-tank` to hit timeouts when compiling
for cpu:
https://github.com/iree-org/iree/actions/runs/9484305572/job/26134004282
ScottTodd authored Jun 12, 2024
1 parent 71c07fa commit 2ff4102
Showing 25 changed files with 87 additions and 954 deletions.
42 changes: 20 additions & 22 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -890,7 +890,7 @@ getDefaultDistributedLevelTileSizes(Operation *op,
/// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the
/// reduction loops.
static void splitParallelAndReductionTiles(
Operation *op, SmallVectorImpl<int64_t> &parallelSizes,
linalg::LinalgOp op, SmallVectorImpl<int64_t> &parallelSizes,
SmallVectorImpl<int64_t> &reductionSizes,
SmallVectorImpl<bool> *parallelScalableFlags = nullptr,
SmallVectorImpl<bool> *reductionScalableFlags = nullptr) {
@@ -900,9 +900,8 @@ static void splitParallelAndReductionTiles(
reductionScalableFlags->assign(parallelScalableFlags->begin(),
parallelScalableFlags->end());
}
TilingInterface tilingOp = cast<TilingInterface>(op);
for (auto [index, iteratorType] :
llvm::enumerate(tilingOp.getLoopIteratorTypes())) {
llvm::enumerate(op.getIteratorTypesArray())) {
if (iteratorType == utils::IteratorType::parallel) {
reductionSizes[index] = 0;
if (reductionScalableFlags)
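For context, the logic this hunk touches is simple: copy the parallel tile sizes into the reduction list, then zero each entry on the side where it does not belong, based on the loop's iterator type. A minimal standalone sketch, assuming a simplified `IteratorType` enum in place of MLIR's `utils::IteratorType` (this is not the exact IREE code):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

enum class IteratorType { parallel, reduction };

// Tile size 0 conventionally means "do not tile this loop at this level".
static void splitTiles(const std::vector<IteratorType> &iterators,
                       std::vector<int64_t> &parallelSizes,
                       std::vector<int64_t> &reductionSizes) {
  reductionSizes.assign(parallelSizes.begin(), parallelSizes.end());
  for (size_t i = 0; i < iterators.size(); ++i) {
    if (iterators[i] == IteratorType::parallel)
      reductionSizes[i] = 0;  // parallel loop: tiled at the parallel level only
    else
      parallelSizes[i] = 0;   // reduction loop: tiled at the reduction level only
  }
}

int main() {
  std::vector<IteratorType> iters = {IteratorType::parallel,
                                     IteratorType::parallel,
                                     IteratorType::reduction};
  std::vector<int64_t> parallel = {8, 32, 16}, reduction;
  splitTiles(iters, parallel, reduction);
  // parallel is now {8, 32, 0}; reduction is {0, 0, 16}.
  std::printf("%lld %lld %lld\n", (long long)reduction[0],
              (long long)reduction[1], (long long)reduction[2]);
}
```

The revert also changes the parameter from `Operation *` back to `linalg::LinalgOp`, so the iterator types come from `getIteratorTypesArray()` rather than a `TilingInterface` cast.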
@@ -1122,9 +1121,9 @@ setMatmulRootConfig(mlir::FunctionOpInterface entryPointFn,
SmallVector<int64_t> parallelTileSizes = vecTileSizes;
SmallVector<int64_t> reductionTileSizes;
SmallVector<bool> reductionScalableFlags;
splitParallelAndReductionTiles(op, parallelTileSizes, reductionTileSizes,
&parallelScalableFlags,
&reductionScalableFlags);
splitParallelAndReductionTiles(
cast<linalg::LinalgOp>(op.getOperation()), parallelTileSizes,
reductionTileSizes, &parallelScalableFlags, &reductionScalableFlags);

if (vecPreProcStrategy == VectorPreProcStrategy::None) {
setVectorSizesForDynamicShapes(cast<linalg::LinalgOp>(op.getOperation()),
@@ -1752,13 +1751,14 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,

// Batch, M and N (parallel dimensions) are distributed on workgroups.
DistributionHeuristicConfig config;
SmallVector<int64_t> distTileSizes =
getDefaultDistributedLevelTileSizes(attnOp, config);
SmallVector<int64_t> distTileSizes = getDefaultDistributedLevelTileSizes(
attnOp, DistributionHeuristicConfig{});

// Batch, M and N (parallel dimensions) are distributed on workgroups.
SmallVector<int64_t> vecTileSizes(attnOp.getIterationDomainRank(), 1);
// Mark k1 reduction dimensions not to distribute.
for (int i : opInfo.getK1Dims()) {
// Mark reduction dimensions not to distribute.
for (int64_t i :
llvm::concat<const int64_t>(opInfo.getK1Dims(), opInfo.getK2Dims())) {
vecTileSizes[i] = 0;
}
int64_t vectorSize = getVectorSize(entryPointFn, attnOp.getOutputType());
Expand All @@ -1773,17 +1773,18 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
/*numElem=*/tileSize, vectorSize, vectorSize);
}

SmallVector<int64_t> parallelTileSizes = vecTileSizes;
SmallVector<int64_t> reductionTileSizes;
splitParallelAndReductionTiles(attnOp, parallelTileSizes, reductionTileSizes);
// TODO (17467): Due to a bug in TileAndDecomposeAttention, N dimension
// cannot be tiled. Remove this once fixed.
for (int64_t i : opInfo.getNDims()) {
distTileSizes[i] = 0;
vecTileSizes[i] = 0;
}

LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (parallel): "
<< parallelTileSizes << "\n");
LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (reduction): "
<< reductionTileSizes << "\n");
TileSizesListType tileSizes = {distTileSizes, vecTileSizes};

TileSizesListType tileSizes = {distTileSizes, parallelTileSizes,
reductionTileSizes};
// TODO: (Groverkss): Tile K2 here using reduction tiling interface once we
// have it. TileAndDecomposeAttention pass only tiles K2. I think it should
// be possible to tile K1 also, but need to explore it more.

return setOpConfigAndEntryPointFnTranslation(
entryPointFn, attnOp, tileSizes,
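To make the two shapes this hunk toggles between concrete: the removed code split the vector-level sizes into parallel and reduction levels (a three-level list), while the restored code keeps a two-level {distribution, vector} list and zeroes the N dimensions. A hedged sketch; the dimension positions and values are illustrative only, loosely based on the test expectations further below:

```cpp
#include <cstdint>
#include <vector>

using TileSizes = std::vector<int64_t>;
using TileSizesList = std::vector<TileSizes>;

int main() {
  // Illustrative iteration domain [B, M, N, K1, K2], reductions at {3, 4}.
  TileSizes distTileSizes = {20, 64, 0, 0, 0};  // reductions never distributed
  TileSizes vecTileSizes  = {20, 32, 0, 0, 0};  // N already zeroed (TODO 17467)

  // Removed form: split the vector sizes into parallel/reduction levels.
  TileSizes parallelTileSizes = vecTileSizes;
  TileSizes reductionTileSizes(vecTileSizes.size(), 0);
  const size_t reductionDims[] = {3, 4};
  for (size_t d : reductionDims) {
    reductionTileSizes[d] = parallelTileSizes[d];
    parallelTileSizes[d] = 0;
  }
  TileSizesList removedForm = {distTileSizes, parallelTileSizes,
                               reductionTileSizes};

  // Restored form: just distribution + vector levels.
  TileSizesList restoredForm = {distTileSizes, vecTileSizes};
  (void)removedForm;
  (void)restoredForm;
}
```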
@@ -1842,9 +1843,6 @@ setWinogradRootConfig(mlir::FunctionOpInterface entryPointFn,
tileSizes.push_back(distTileSizes);
SmallVector<int64_t> vecTileSizes(iterationRank, 1);
tileSizes.push_back(vecTileSizes);
// Dummy tiling config for reduction level.
SmallVector<int64_t> reductionTileSizes(iterationRank, 0);
tileSizes.push_back(reductionTileSizes);
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, winogradOp, tileSizes,
DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize);
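The winograd change is the same story in miniature: the removed lines appended an all-zero "dummy" reduction level so the list shape matched the three-level pipeline. A small sketch, with the distribution sizes borrowed from the test expectations below:

```cpp
#include <cstdint>
#include <vector>

int main() {
  const int64_t iterationRank = 4;  // rank of the winograd iteration domain
  std::vector<std::vector<int64_t>> tileSizes;
  tileSizes.push_back({0, 1, 6, 64});  // distribution level (from the tests)
  tileSizes.push_back(std::vector<int64_t>(iterationRank, 1));  // vector level
  // Removed by the revert: a dummy all-zero reduction level.
  // tileSizes.push_back(std::vector<int64_t>(iterationRank, 0));
  return 0;
}
```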
7 changes: 2 additions & 5 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -617,13 +617,10 @@ void addCPULinalgExtTileAndVectorizePipeline(
createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
// TODO: Remove the pass once we have PartialReductionOpInterface implemented
// for AttentionOp.
funcPassManager.addPass(
IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
funcPassManager.addPass(
createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass());
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
funcPassManager.addPass(
IREE::LinalgExt::createDecomposeWinogradTransformPass());
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());

{
GenericVectorizationPassOptions options;
(Diff for an LLVMCPU lowering-strategy test file; the file name and change summary were not captured in this view.)
@@ -1531,7 +1531,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_output_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1556,7 +1556,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_input_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1581,7 +1581,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1], [0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_filter_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1613,7 +1613,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 64], [20, 32, 0, 0, 32], [0, 0, 0, 32, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 0], [20, 32, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @attention()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
9 changes: 2 additions & 7 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -29,7 +29,6 @@ iree_td_library(
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:LinalgOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PDLDialectTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
@@ -160,9 +159,7 @@ iree_gentbl_cc_library(
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "LinalgExtOps.td",
deps = [
":td_files",
],
deps = [":td_files"],
)

iree_gentbl_cc_library(
@@ -215,7 +212,5 @@ iree_tablegen_doc(
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "LinalgExtOps.td",
deps = [
":td_files",
],
deps = [":td_files"],
)
78 changes: 0 additions & 78 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -8,7 +8,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -1316,9 +1315,6 @@ LogicalResult AttentionOp::verify() {
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
if (ShapedType::isDynamic(valShape[i])) {
continue;
}
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
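For readers skimming the verifier diff: `checkShape` walks each operand's indexing map, records the first concrete size seen for every iteration-space dimension, and reports a mismatch on disagreement; the removed `isDynamic` guard had made dynamic sizes exempt. A self-contained sketch of the same idea, with plain integers standing in for the MLIR types:

```cpp
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN;  // stand-in for ShapedType::kDynamic

// `dims[i]` is the iteration-space dimension that result i of the operand's
// indexing map refers to (i.e. a flattened AffineDimExpr position).
bool checkShape(const std::vector<int64_t> &dims,
                const std::vector<int64_t> &valShape,
                std::vector<int64_t> &shape, std::vector<bool> &foundDims) {
  if (dims.size() != valShape.size())
    return false;  // rank mismatch
  for (size_t i = 0; i < dims.size(); ++i) {
    const int64_t pos = dims[i];
    if (valShape[i] == kDynamic)
      continue;  // the exemption this revert removes
    if (!foundDims[pos]) {
      foundDims[pos] = true;
      shape[pos] = valShape[i];
    }
    if (shape[pos] != valShape[i])
      return false;  // shape mismatch against an earlier operand
  }
  return true;
}
```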
@@ -1431,79 +1427,6 @@ SmallVector<AffineMap> AttentionOp::getIndexingMapsArray() {
return results;
}

//===----------------------------------------------------------------------===//
// OnlineAttentionOp
//===----------------------------------------------------------------------===//

LogicalResult OnlineAttentionOp::verify() {
OnlineAttentionOp attnOp = *this;

SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();

// Check if indexing maps can represent attention.
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(indexingMaps);

// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank());
SmallVector<bool> foundDims(getIterationDomainRank(), false);
auto checkShape = [&shape, &foundDims,
&attnOp](StringRef operandName, ArrayRef<int64_t> valShape,
AffineMap indexingMap) -> LogicalResult {
if (indexingMap.getNumResults() != valShape.size()) {
return attnOp->emitError("Rank Mismatch for ")
<< operandName << ". Expected: " << indexingMap.getNumResults()
<< " Got: " << valShape.size();
}
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
if (ShapedType::isDynamic(valShape[i])) {
continue;
}
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
}
if (shape[pos] != valShape[i]) {
return attnOp->emitError("Shape Mismatch for ")
<< operandName << ". Expected: " << shape[pos]
<< " Got: " << valShape[i];
}
}
return success();
};

if (failed(checkShape("Query", getQuery().getType().getShape(),
getQueryMap())) ||
failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) ||
failed(checkShape("Value", getValue().getType().getShape(),
getValueMap())) ||
failed(checkShape("Output", getOutput().getType().getShape(),
getOutputMap())) ||
failed(checkShape("Max", getMax().getType().getShape(), getMaxMap())) ||
failed(checkShape("Sum", getSum().getType().getShape(), getSumMap()))) {
return failure();
}

return success();
}

MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3);
}

LogicalResult OnlineAttentionOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}

SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
return SmallVector<AffineMap>(
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}

#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
@@ -1523,7 +1446,6 @@ DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(OnlineAttentionOp)

} // namespace mlir::iree_compiler::IREE::LinalgExt
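The `DEFINE_OP_GET_EFFECTS` fragment above is truncated by the diff view; its role is to stamp out one identical `getEffects` definition per op, and the revert simply drops the `OnlineAttentionOp` instantiation. A standalone analog of the mechanism (this illustrates only the pattern, not IREE's actual macro body or signature):

```cpp
#include <cstdio>

struct AttentionOp { void getEffects(); };
struct ScanOp { void getEffects(); };

// One macro stamps out an identical member-function definition per op type.
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
  void OP_NAME::getEffects() { std::puts(#OP_NAME "::getEffects"); }

DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(ScanOp)

int main() {
  AttentionOp{}.getEffects();
  ScanOp{}.getEffects();
}
```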

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h (change summary not captured in this view)
@@ -7,7 +7,6 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Attributes.h"
91 changes: 0 additions & 91 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -9,7 +9,6 @@

include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -679,96 +678,6 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
}];
}

//===----------------------------------------------------------------------===//
// OnlineAttention
//===----------------------------------------------------------------------===//

def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Online Attention operator";
let description = [{
Traditional scaled dot product attention computes:

attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V

Online Attention on the other hand, uses an online normalizer instead of
softmax:

online_attention(Q, K, V, scale, running_max, running_sum)
= online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V

The advantage of this online_normalizer is that it can be tiled along
its reduction dimension, making the online_attention operator:
- Tilable along softmax reduction dimension
- Associative along softmax reduction dimension
- Commutative along softmax associative dimension

Note: The results of online_attention need to be combined after computing
it over the entire softmax reduction dimension by:
x, _, sum : results
x = (1 / sum) * x
}];

let arguments = (ins AnyShaped:$query,
AnyShaped:$key,
AnyShaped:$value,
AnyFloat:$scale,
AnyShaped:$output,
AnyShaped:$max,
AnyShaped:$sum,
AffineMapArrayAttr:$indexing_maps
);

let results = (outs Variadic<AnyRankedTensor>:$results);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict
`ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
`outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
(`->` type($results)^)?
}];

let extraClassDeclaration = [{
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable();

SmallVector<AffineMap> getIndexingMapsArray();

AffineMap getQueryMap() {
return getIndexingMapsArray()[0];
}
AffineMap getKeyMap() {
return getIndexingMapsArray()[1];
}
AffineMap getValueMap() {
return getIndexingMapsArray()[2];
}
AffineMap getOutputMap() {
return getIndexingMapsArray()[3];
}
AffineMap getMaxMap() {
return getIndexingMapsArray()[4];
}
AffineMap getSumMap() {
return getIndexingMapsArray()[5];
}

int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}
}];
}
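For reference, the online normalizer the removed description alludes to is the standard online-softmax recurrence; this formulation is general background, not text from the patch. With per-tile scores $S_j = \alpha\, Q K_j^T$, running max $m$, running sum $\ell$, and accumulator $x$:

```latex
\begin{aligned}
m_j    &= \max\bigl(m_{j-1},\ \operatorname{rowmax}(S_j)\bigr) \\
P_j    &= \exp\bigl(S_j - m_j\bigr) \\
\ell_j &= e^{\,m_{j-1} - m_j}\,\ell_{j-1} + \operatorname{rowsum}(P_j) \\
x_j    &= e^{\,m_{j-1} - m_j}\,x_{j-1} + P_j\,V_j
\end{aligned}
```

After the final tile, dividing by the running sum, $x \leftarrow x / \ell$, recovers ordinary softmax attention, which is exactly what the description's closing note (`x = (1 / sum) * x`) states.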

} // OpGroupNonStructuredOps

//===----------------------------------------------------------------------===//
(The remaining changed file diffs were not loaded in this view.)