From 2ff4102aba9e878f729840da66a44fe4bd3c8790 Mon Sep 17 00:00:00 2001
From: Scott Todd
Date: Wed, 12 Jun 2024 16:49:56 -0700
Subject: [PATCH] Revert "[LinalgExt] Add online_attention op" (#17658)

Reverts iree-org/iree#17536

This caused `sdxl-scheduled-unet-3-tank` to hit timeouts when compiling
for cpu:
https://github.com/iree-org/iree/actions/runs/9484305572/job/26134004282
---
 .../Codegen/LLVMCPU/KernelDispatch.cpp        |  42 ++--
 .../iree/compiler/Codegen/LLVMCPU/Passes.cpp  |   7 +-
 .../test/select_x86_64_lowering_strategy.mlir |   8 +-
 .../compiler/Dialect/LinalgExt/IR/BUILD.bazel |   9 +-
 .../Dialect/LinalgExt/IR/LinalgExtOps.cpp     |  78 ------
 .../Dialect/LinalgExt/IR/LinalgExtOps.h       |   1 -
 .../Dialect/LinalgExt/IR/LinalgExtOps.td      |  91 -------
 .../LinalgExtExtensionsOps.cpp                |  12 -
 .../LinalgExtExtensionsOps.td                 |  32 ---
 .../Transforms/AggregatedOpInterfaceImpl.cpp  | 228 ------------------
 .../Dialect/LinalgExt/Transforms/BUILD.bazel  |   1 -
 .../LinalgExt/Transforms/CMakeLists.txt       |   1 -
 .../Transforms/DecomposeAttention.cpp         |  10 -
 .../Dialect/LinalgExt/Transforms/Passes.h     |   6 -
 .../Dialect/LinalgExt/Transforms/Passes.td    |   8 -
 .../LinalgExt/Transforms/TileAttention.cpp    | 122 ----------
 .../Transforms/TilingInterfaceImpl.cpp        | 217 +++++------------
 .../LinalgExt/Transforms/test/BUILD.bazel     |   1 -
 .../LinalgExt/Transforms/test/CMakeLists.txt  |   1 -
 .../test/decompose_online_attention.mlir      |  64 -----
 .../LinalgExt/Transforms/test/tiling.mlir     |  63 -----
 .../Dialect/LinalgExt/Utils/IndexingUtils.cpp |  18 +-
 .../Dialect/LinalgExt/Utils/IndexingUtils.h   |   9 -
 .../Dialect/LinalgExt/Utils/Utils.cpp         |   7 -
 .../compiler/Dialect/LinalgExt/Utils/Utils.h  |   5 -
 25 files changed, 87 insertions(+), 954 deletions(-)
 delete mode 100644 compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
 delete mode 100644 compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir

diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index cc073548ba2e..9ae470db916c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -890,7 +890,7 @@ getDefaultDistributedLevelTileSizes(Operation *op,
 /// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the
 /// reduction loops.
 static void splitParallelAndReductionTiles(
-    Operation *op, SmallVectorImpl<int64_t> &parallelSizes,
+    linalg::LinalgOp op, SmallVectorImpl<int64_t> &parallelSizes,
     SmallVectorImpl<int64_t> &reductionSizes,
     SmallVectorImpl<bool> *parallelScalableFlags = nullptr,
     SmallVectorImpl<bool> *reductionScalableFlags = nullptr) {
@@ -900,9 +900,8 @@ static void splitParallelAndReductionTiles(
     reductionScalableFlags->assign(parallelScalableFlags->begin(),
                                    parallelScalableFlags->end());
   }
-  TilingInterface tilingOp = cast<TilingInterface>(op);
   for (auto [index, iteratorType] :
-       llvm::enumerate(tilingOp.getLoopIteratorTypes())) {
+       llvm::enumerate(op.getIteratorTypesArray())) {
     if (iteratorType == utils::IteratorType::parallel) {
       reductionSizes[index] = 0;
       if (reductionScalableFlags)
@@ -1122,9 +1121,9 @@ setMatmulRootConfig(mlir::FunctionOpInterface entryPointFn,
   SmallVector<int64_t> parallelTileSizes = vecTileSizes;
   SmallVector<int64_t> reductionTileSizes;
   SmallVector<bool> reductionScalableFlags;
-  splitParallelAndReductionTiles(op, parallelTileSizes, reductionTileSizes,
-                                 &parallelScalableFlags,
-                                 &reductionScalableFlags);
+  splitParallelAndReductionTiles(
+      cast<linalg::LinalgOp>(op.getOperation()), parallelTileSizes,
+      reductionTileSizes, &parallelScalableFlags, &reductionScalableFlags);

   if (vecPreProcStrategy == VectorPreProcStrategy::None) {
     setVectorSizesForDynamicShapes(cast<linalg::LinalgOp>(op.getOperation()),
@@ -1752,13 +1751,14 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,

   // Batch, M and N (parallel dimensions) are distributed on workgroups.
   DistributionHeuristicConfig config;
-  SmallVector<int64_t> distTileSizes =
-      getDefaultDistributedLevelTileSizes(attnOp, config);
+  SmallVector<int64_t> distTileSizes = getDefaultDistributedLevelTileSizes(
+      attnOp, DistributionHeuristicConfig{});
   // Batch, M and N (parallel dimensions) are distributed on workgroups.
   SmallVector<int64_t> vecTileSizes(attnOp.getIterationDomainRank(), 1);
-  // Mark k1 reduction dimensions not to distribute.
-  for (int i : opInfo.getK1Dims()) {
+  // Mark reduction dimensions not to distribute.
+  for (int64_t i :
+       llvm::concat<const int64_t>(opInfo.getK1Dims(), opInfo.getK2Dims())) {
     vecTileSizes[i] = 0;
   }
   int64_t vectorSize = getVectorSize(entryPointFn, attnOp.getOutputType());
@@ -1773,17 +1773,18 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
                          /*numElem=*/tileSize, vectorSize, vectorSize);
   }

-  SmallVector<int64_t> parallelTileSizes = vecTileSizes;
-  SmallVector<int64_t> reductionTileSizes;
-  splitParallelAndReductionTiles(attnOp, parallelTileSizes, reductionTileSizes);
+  // TODO (17467): Due to a bug in TileAndDecomposeAttention, N dimension
+  // cannot be tiled. Remove this once fixed.
+  for (int64_t i : opInfo.getNDims()) {
+    distTileSizes[i] = 0;
+    vecTileSizes[i] = 0;
+  }

-  LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (parallel): "
-                       << parallelTileSizes << "\n");
-  LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (reduction): "
-                       << reductionTileSizes << "\n");
+  TileSizesListType tileSizes = {distTileSizes, vecTileSizes};

-  TileSizesListType tileSizes = {distTileSizes, parallelTileSizes,
-                                 reductionTileSizes};
+  // TODO: (Groverkss): Tile K2 here using reduction tiling interface once we
+  // have it. TileAndDecomposeAttention pass only tiles K2. I think it should
+  // be possible to tile K1 also, but need to explore it more.
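// Editorial sketch, not part of this patch: the parallel/reduction tile-size
// split used above, restated with std:: containers instead of
// llvm::SmallVector (scalable flags omitted). Each loop keeps its tile size
// at exactly one level.
#include <cstddef>
#include <cstdint>
#include <vector>

enum class IteratorKind { parallel, reduction };

static void splitTileSizes(const std::vector<IteratorKind> &iteratorKinds,
                           std::vector<int64_t> &parallelSizes,
                           std::vector<int64_t> &reductionSizes) {
  // Start from one combined vector, then zero each entry out of the level
  // that does not own the corresponding loop.
  reductionSizes.assign(parallelSizes.begin(), parallelSizes.end());
  for (std::size_t i = 0; i < iteratorKinds.size(); ++i) {
    if (iteratorKinds[i] == IteratorKind::parallel)
      reductionSizes[i] = 0;
    else
      parallelSizes[i] = 0;
  }
}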
return setOpConfigAndEntryPointFnTranslation( entryPointFn, attnOp, tileSizes, @@ -1842,9 +1843,6 @@ setWinogradRootConfig(mlir::FunctionOpInterface entryPointFn, tileSizes.push_back(distTileSizes); SmallVector vecTileSizes(iterationRank, 1); tileSizes.push_back(vecTileSizes); - // Dummy tiling config for reduction level. - SmallVector reductionTileSizes(iterationRank, 0); - tileSizes.push_back(reductionTileSizes); return setOpConfigAndEntryPointFnTranslation( entryPointFn, winogradOp, tileSizes, DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 6a4363ff47af..31fcead83eca 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -617,13 +617,10 @@ void addCPULinalgExtTileAndVectorizePipeline( createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel())); // TODO: Remove the pass once we have PartialReductionOpInterface implemented // for AttentionOp. - funcPassManager.addPass( - IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass()); - funcPassManager.addPass( - createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel())); + funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass()); + funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass()); funcPassManager.addPass( IREE::LinalgExt::createDecomposeWinogradTransformPass()); - funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass()); { GenericVectorizationPassOptions options; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir index 6777e97511f5..4df56056a6f4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir @@ -1531,7 +1531,7 @@ module { return } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @winograd_output_transform() // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -1556,7 +1556,7 @@ module { return } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @winograd_input_transform() // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -1581,7 +1581,7 @@ module { return } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @winograd_filter_transform() // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -1613,7 +1613,7 @@ module { return } } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @attention() // CHECK-SAME: translation_info = #[[TRANSLATION]] diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel index a771ee0f6a74..8675c431b3e2 100644 --- 
a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel @@ -29,7 +29,6 @@ iree_td_library( "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:LinalgOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:PDLDialectTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", @@ -160,9 +159,7 @@ iree_gentbl_cc_library( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "LinalgExtOps.td", - deps = [ - ":td_files", - ], + deps = [":td_files"], ) iree_gentbl_cc_library( @@ -215,7 +212,5 @@ iree_tablegen_doc( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "LinalgExtOps.td", - deps = [ - ":td_files", - ], + deps = [":td_files"], ) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 34ce4f4c32d4..c5c42ec41920 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -8,7 +8,6 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" -#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -1316,9 +1315,6 @@ LogicalResult AttentionOp::verify() { for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) { AffineDimExpr dim = cast(dimExpr); int64_t pos = dim.getPosition(); - if (ShapedType::isDynamic(valShape[i])) { - continue; - } if (!foundDims[pos]) { foundDims[pos] = true; shape[pos] = valShape[i]; @@ -1431,79 +1427,6 @@ SmallVector AttentionOp::getIndexingMapsArray() { return results; } -//===----------------------------------------------------------------------===// -// OnlineAttentionOp -//===----------------------------------------------------------------------===// - -LogicalResult OnlineAttentionOp::verify() { - OnlineAttentionOp attnOp = *this; - - SmallVector indexingMaps = attnOp.getIndexingMapsArray(); - - // Check if indexing maps can represent attention. - FailureOr maybeOpInfo = - AttentionOpDetail::get(indexingMaps); - - // Check shape compatibility based on indexing maps. - SmallVector shape(getIterationDomainRank()); - SmallVector foundDims(getIterationDomainRank(), false); - auto checkShape = [&shape, &foundDims, - &attnOp](StringRef operandName, ArrayRef valShape, - AffineMap indexingMap) -> LogicalResult { - if (indexingMap.getNumResults() != valShape.size()) { - return attnOp->emitError("Rank Mismatch for ") - << operandName << ". Expected: " << indexingMap.getNumResults() - << " Got: " << valShape.size(); - } - for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) { - AffineDimExpr dim = cast(dimExpr); - int64_t pos = dim.getPosition(); - if (ShapedType::isDynamic(valShape[i])) { - continue; - } - if (!foundDims[pos]) { - foundDims[pos] = true; - shape[pos] = valShape[i]; - } - if (shape[pos] != valShape[i]) { - return attnOp->emitError("Shape Mismatch for ") - << operandName << ". 
Expected: " << shape[pos] - << " Got: " << valShape[i]; - } - } - return success(); - }; - - if (failed(checkShape("Query", getQuery().getType().getShape(), - getQueryMap())) || - failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) || - failed(checkShape("Value", getValue().getType().getShape(), - getValueMap())) || - failed(checkShape("Output", getOutput().getType().getShape(), - getOutputMap())) || - failed(checkShape("Max", getMax().getType().getShape(), getMaxMap())) || - failed(checkShape("Sum", getSum().getType().getShape(), getSumMap()))) { - return failure(); - } - - return success(); -} - -MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() { - return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3); -} - -LogicalResult OnlineAttentionOp::reifyResultShapes( - OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()) - .reifyResultShapes(b, reifiedReturnShapes); -} - -SmallVector OnlineAttentionOp::getIndexingMapsArray() { - return SmallVector( - getIndexingMaps().getAsValueRange()); -} - #define DEFINE_OP_GET_EFFECTS(OP_NAME) \ void OP_NAME::getEffects( \ SmallVectorImpl> \ @@ -1523,7 +1446,6 @@ DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp) DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp) DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp) DEFINE_OP_GET_EFFECTS(AttentionOp) -DEFINE_OP_GET_EFFECTS(OnlineAttentionOp) } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h index 3d52ae62b26b..97caaabc4699 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h @@ -7,7 +7,6 @@ #ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_ #define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_ -#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/Attributes.h" diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index 0eebd2e16976..bf9694d91002 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -9,7 +9,6 @@ include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td" include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td" -include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -679,96 +678,6 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention", }]; } -//===----------------------------------------------------------------------===// -// OnlineAttention -//===----------------------------------------------------------------------===// - -def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", - [DeclareOpInterfaceMethods, - DestinationStyleOpInterface, LinalgExtInterface, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Online Attention operator"; - let description = [{ - Traditional scaled dot product attention computes: - - attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V - - Online Attention on the 
other hand, uses an online normalizer instead of - softmax: - - online_attention(Q, K, V, scale, running_max, running_sum) - = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V - - The advantage of this online_normalizer is that it can be tiled along - it's reduction dimension, making the online_attention operator: - - Tilable along softmax reduction dimension - - Associative along softmax reduction dimension - - Commutative along softmax associative dimension - - Note: The results of online_attention need to be combined after computing - it over the entire softmax reduction dimension by: - x, _, sum : results - x = (1 / sum) * x - }]; - - let arguments = (ins AnyShaped:$query, - AnyShaped:$key, - AnyShaped:$value, - AnyFloat:$scale, - AnyShaped:$output, - AnyShaped:$max, - AnyShaped:$sum, - AffineMapArrayAttr:$indexing_maps - ); - - let results = (outs Variadic:$results); - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; - let assemblyFormat = [{ - attr-dict - `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)` - `outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)` - (`->` type($results)^)? - }]; - - let extraClassDeclaration = [{ - // Method to implement for specifying output range for - // DestinationStyleOpInterface - MutableOperandRange getDpsInitsMutable(); - - SmallVector getIndexingMapsArray(); - - AffineMap getQueryMap() { - return getIndexingMapsArray()[0]; - } - AffineMap getKeyMap() { - return getIndexingMapsArray()[1]; - } - AffineMap getValueMap() { - return getIndexingMapsArray()[2]; - } - AffineMap getOutputMap() { - return getIndexingMapsArray()[3]; - } - AffineMap getMaxMap() { - return getIndexingMapsArray()[4]; - } - AffineMap getSumMap() { - return getIndexingMapsArray()[5]; - } - - int64_t getIterationDomainRank() { - return getQueryMap().getNumDims(); - } - }]; -} - } // OpGroupNonStructuredOps //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp index 923b30a8da73..00bb383245b7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp @@ -51,17 +51,5 @@ DiagnosedSilenceableFailure LinalgExt::DecomposeTiledAttentionOp::applyToOne( return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure LinalgExt::ConvertToOnlineAttention::applyToOne( - transform::TransformRewriter &rewriter, LinalgExt::AttentionOp attentionOp, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { - SmallVector ops; - LinalgExt::convertToOnlineAttention(attentionOp, ops, rewriter); - for (Operation *op : ops) { - results.push_back(op); - } - return DiagnosedSilenceableFailure::success(); -} - #define GET_OP_CLASSES #include "iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.cpp.inc" diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td index 84e588af9621..c3a6310f30a2 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td +++ 
b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/LinalgExtExtensionsOps.td @@ -86,36 +86,4 @@ def DecomposeTiledAttentionOp : Op { - let description = [{ - Target iree_linalg_ext.attention ops and decompose them. - This transform consumes the target handle and produces a result handle. - }]; - - let arguments = ( - ins TransformHandleTypeInterface:$target - ); - let results = (outs Variadic:$result); - - let assemblyFormat = "attr-dict $target `:` functional-type(operands, results)"; - let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt"; - - let assemblyFormat = [{ - $target attr-dict `:` functional-type(operands, results) - }]; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::transform::TransformRewriter &rewriter, - ::mlir::iree_compiler::IREE::LinalgExt::AttentionOp target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; -} - #endif // IREE_DIALECT_LINALGEXT_TRANSFORMOPS diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp deleted file mode 100644 index b4370e7ef1a0..000000000000 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" - -namespace mlir::iree_compiler::IREE::LinalgExt { - -static Value scaleValueInPlace(OpBuilder &builder, Location loc, - AffineMap inputMap, AffineMap scaleMap, - Value value, Value scale) { - SmallVector compressedMaps = - compressUnusedDims(SmallVector{inputMap, scaleMap}); - inputMap = compressedMaps[0]; - scaleMap = compressedMaps[1]; - - SmallVector iteratorTypes(inputMap.getNumDims(), - utils::IteratorType::parallel); - - auto genericOp = builder.create( - loc, value.getType(), scale, value, - SmallVector{scaleMap, inputMap}, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - // Convert scale to the same datatype as input. - Value scale = convertScalarToDtype(b, loc, args[0], args[1].getType(), - /*isUnsignedCast=*/false); - Value result = b.create(loc, scale, args[1]); - b.create(loc, result); - }); - return genericOp.getResult(0); -} - -template -static Value reduce(OpBuilder &builder, Location loc, AffineMap inputMap, - AffineMap outputMap, Value input, Value output) { - SmallVector compressedMaps = - compressUnusedDims(SmallVector{inputMap, outputMap}); - inputMap = compressedMaps[0]; - outputMap = compressedMaps[1]; - - // Dims not present in outputMap are reductionDims. 
- SmallVector iteratorTypes( - inputMap.getNumDims(), utils::IteratorType::reduction); - for (AffineExpr dim : outputMap.getResults()) { - int pos = cast(dim).getPosition(); - iteratorTypes[pos] = utils::IteratorType::parallel; - } - - auto genericOp = builder.create( - loc, output.getType(), input, output, - SmallVector{inputMap, outputMap}, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - // Convert input to the same datatype as acc. - Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(), - /*isUnsignedCast=*/false); - Value result = b.create(loc, in, args[1]); - b.create(loc, result); - }); - - return genericOp.getResult(0); -} - -static Value computeMatmul(OpBuilder &builder, Location loc, AffineMap lhsMap, - AffineMap rhsMap, AffineMap accMap, Value lhs, - Value rhs, Value acc) { - - SmallVector compressedMaps = - compressUnusedDims(SmallVector{lhsMap, rhsMap, accMap}); - lhsMap = compressedMaps[0]; - rhsMap = compressedMaps[1]; - accMap = compressedMaps[2]; - - // Dims not present in accMap are reduction dims. - SmallVector iteratorTypes( - accMap.getNumDims(), utils::IteratorType::reduction); - for (AffineExpr dim : accMap.getResults()) { - int pos = cast(dim).getPosition(); - iteratorTypes[pos] = utils::IteratorType::parallel; - } - - auto genericOp = builder.create( - loc, acc.getType(), SmallVector{lhs, rhs}, acc, - SmallVector{lhsMap, rhsMap, accMap}, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - // Cast inputs to match output datatype. - Value lhs = convertScalarToDtype(b, loc, args[0], args[2].getType(), - /*isUnsignedCast=*/false); - Value rhs = convertScalarToDtype(b, loc, args[1], args[2].getType(), - /*isUnsignedCast=*/false); - Value mul = b.create(loc, lhs, rhs); - Value add = b.create(loc, mul, args[2]); - b.create(loc, add); - }); - - return genericOp.getResult(0); -} - -// Compute output = exp2(output - input) -static Value computeSubAndExp2(OpBuilder &builder, Location loc, - AffineMap inputMap, AffineMap outputMap, - Value input, Value output) { - SmallVector compressedMaps = - compressUnusedDims(SmallVector{inputMap, outputMap}); - inputMap = compressedMaps[0]; - outputMap = compressedMaps[1]; - - SmallVector iteratorTypes(inputMap.getNumDims(), - utils::IteratorType::parallel); - auto genericOp = builder.create( - loc, output.getType(), input, output, - SmallVector{inputMap, outputMap}, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - // Convert input to the same datatype as output. - Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(), - /*isUnsignedCast=*/false); - Value diff = b.create(loc, args[1], in); - Value weight = b.create(loc, diff); - b.create(loc, weight); - }); - return genericOp.getResult(0); -} - -FailureOr> -OnlineAttentionOp::decomposeOperation(OpBuilder &b) { - Location loc = getLoc(); - Value query = getQuery(); - Value key = getKey(); - Value value = getValue(); - Value oldAcc = getOutput(); - Value oldMax = getMax(); - Value oldSum = getSum(); - Type elementType = getQuery().getType().getElementType(); - - FailureOr maybeOpInfo = - AttentionOpDetail::get(getIndexingMapsArray()); - assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps"); - AttentionOpDetail opInfo = maybeOpInfo.value(); - - SmallVector sizes = llvm::map_to_vector( - getIterationDomain(b), [](Range x) { return x.size; }); - - // Since we use exp2 for attention instead of the original exp, we have to - // multiply the scale by log2(e). 
We use exp2 instead of exp as most platforms - // have better support for exp2 (we verified that we gain some speedup on - // some GPUs). - Value scale = getScale(); - Value log2e = - b.create(loc, b.getFloatAttr(elementType, M_LOG2E)); - scale = b.create(loc, scale, log2e); - - // In the original algorithm, the scaling is done after the softmax: - // softmax(Q @ K.T * scale) @ V - // - // But, it is mathematically equivalent to do it on Q first and then multiply - // it by K.T. This just allows us to do the scaling once, instead of each - // iteration of the loop. - AffineMap qMap = getQueryMap(); - AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(), - /*symbolCount=*/0, getContext()); - query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale); - - // ---- Matmul 1 ---- - - // Get sizes for S. - AffineMap sMap = opInfo.getSMap(); - SmallVector sSizes; - for (AffineExpr dimExpr : sMap.getResults()) { - int dim = cast(dimExpr).getPosition(); - sSizes.push_back(sizes[dim]); - } - - // S = Q @ K - // SMap = QMap @ KMap - Value emptyS = b.create(loc, sSizes, elementType); - Value sZero = b.create(loc, b.getZeroAttr(elementType)); - Value s = b.create(loc, sZero, emptyS).getResult(0); - s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s); - - // TODO: This decomposition should be in a seperate op called - // "online softmax". - // ---- Online Softmax ---- - - // newMax = max(oldMax, rowMax(S)) - AffineMap maxMap = getMaxMap(); - Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax); - - // P = exp2(S - newMax) - // PMap = SMap - AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s); - - // norm = exp2(oldMax - newMax) - // normMap = maxMap - AffineMap normMap = getMaxMap(); - Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax); - - // normSum = norm * oldSum - AffineMap sumMap = getSumMap(); - Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm); - - // newSum = normSum + rowMax(P) - Value newSum = reduce(b, loc, pMap, sumMap, p, normSum); - - // newAcc = norm * oldAcc - AffineMap accMap = getOutputMap(); - Value newAcc = scaleValueInPlace(b, loc, accMap, normMap, oldAcc, norm); - - // ---- Matmul 2 ---- - - // newAcc = P @ V + newAcc - newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc); - - return SmallVector{newAcc, newMax, newSum}; -} - -} // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel index 1a20a8f60f5a..1d08da973991 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel @@ -30,7 +30,6 @@ iree_gentbl_cc_library( iree_compiler_cc_library( name = "Transforms", srcs = [ - "AggregatedOpInterfaceImpl.cpp", "ConvertConv2DToWinograd.cpp", "ConvertToLoops.cpp", "DecomposeAttention.cpp", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt index 19c7522f002a..668d28aec84a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -26,7 +26,6 @@ iree_cc_library( "Passes.h" "Passes.h.inc" SRCS - "AggregatedOpInterfaceImpl.cpp" "ConvertConv2DToWinograd.cpp" "ConvertToLoops.cpp" "DecomposeAttention.cpp" 
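The decomposition implemented by the deleted AggregatedOpInterfaceImpl.cpp above is easiest to sanity-check with a scalar model. The sketch below is editorial and hypothetical (plain std:: types, illustrative names, no IREE APIs): it replays one k2 step of the online-softmax recurrence from decomposeOperation, including the exp2/log2(e) substitution, and notes the final 1/sum combine that the op description calls out.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct RowState {
  float runningMax;       // initialized to -infinity
  float runningSum;       // initialized to 0
  std::vector<float> acc; // initialized to zeros; unnormalized output row
};

// One k2 tile for a single query row. `s` holds the pre-scaled logits
// (q . k) * scale * log2(e), so exp2 below plays the role of exp.
void onlineSoftmaxStep(RowState &st, const std::vector<float> &s,
                       const std::vector<std::vector<float>> &v) {
  float newMax = st.runningMax;
  for (float x : s) // newMax = max(oldMax, rowMax(S))
    newMax = std::max(newMax, x);
  // norm = exp2(oldMax - newMax) rescales everything accumulated so far.
  float norm = std::exp2(st.runningMax - newMax);
  float newSum = norm * st.runningSum; // normSum = norm * oldSum
  for (float &a : st.acc)              // newAcc = norm * oldAcc
    a *= norm;
  for (std::size_t i = 0; i < s.size(); ++i) {
    float p = std::exp2(s[i] - newMax); // P = exp2(S - newMax)
    newSum += p;                        // newSum = normSum + rowSum(P)
    for (std::size_t j = 0; j < st.acc.size(); ++j)
      st.acc[j] += p * v[i][j];         // newAcc = P @ V + newAcc
  }
  st.runningMax = newMax;
  st.runningSum = newSum;
}
// After the last tile, the combine step divides: out[j] = acc[j] / runningSum.

Feeding the logits through onlineSoftmaxStep tile by tile and then dividing by runningSum should match a direct softmax-times-V over the concatenated logits; that is the associativity property the online_attention op relies on for reduction tiling.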
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index 2cd851dc58a0..c70000f09778 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -366,16 +366,6 @@ void DecomposeAttentionPass::runOnOperation() { SmallVector ops; decomposeTiledAttention(attnOp, ops, rewriter, optionalTileSize); }); - getOperation().walk([&](OnlineAttentionOp onlineAtt) { - rewriter.setInsertionPoint(onlineAtt); - FailureOr> results = - onlineAtt.decomposeOperation(rewriter); - if (failed(results)) { - onlineAtt->emitOpError("Could not decompose online attention"); - return signalPassFailure(); - } - rewriter.replaceOp(onlineAtt, results.value()); - }); } std::unique_ptr createDecomposeAttentionPass() { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h index 165430a75c85..43d44a377d83 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h @@ -56,18 +56,12 @@ void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp, RewriterBase &rewriter, std::optional tileSize = std::nullopt); -void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp, - SmallVectorImpl &ops, - RewriterBase &rewriter); - // Creates a pass to tile the attention op along the reduction dim. std::unique_ptr createTileAttentionPass(); // Creates a pass to convert the attention op into a sequence of linalg ops. std::unique_ptr createDecomposeAttentionPass(); -std::unique_ptr createConvertAttentionToOnlineAttentionPass(); - //===---------------------------------------------------------------------===// // Codegen Strategy passes that are moved into IREE. 
//===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td index ee801b623870..77bb10596821 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td @@ -98,12 +98,4 @@ def DecomposeAttention : ]; } -def ConvertAttentionToOnlineAttention : - InterfacePass<"iree-linalg-ext-convert-attention-to-online-attention", - "mlir::FunctionOpInterface"> { - let summary = ""; - let constructor = "mlir::iree_compiler::IREE::LinalgExt::" - "createConvertAttentionToOnlineAttentionPass()"; -} - #endif // IREE_DIALECT_LINALGEXT_PASSES diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp index df0d87e13fb2..4c862b58e5c7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp @@ -8,7 +8,6 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" -#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -159,17 +158,6 @@ struct TileAttentionPass : public TileAttentionBase { void runOnOperation() override; }; -struct ConvertAttentionToOnlineAttentionPass final - : ConvertAttentionToOnlineAttentionBase< - ConvertAttentionToOnlineAttentionPass> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - void runOnOperation() override; -}; - } // namespace /// Tile iree_linalg_ext.attention. @@ -304,103 +292,6 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp, return tiledAttentionOp; } -void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp, - SmallVectorImpl &ops, - RewriterBase &rewriter) { - rewriter.setInsertionPoint(attnOp); - - Location loc = attnOp.getLoc(); - MLIRContext *ctx = attnOp.getContext(); - - FailureOr maybeOpInfo = - AttentionOpDetail::get(attnOp.getIndexingMapsArray()); - assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps"); - AttentionOpDetail opInfo = maybeOpInfo.value(); - - // Create standard maps for max and sum: (batch, m) - int64_t rank = opInfo.getDomainRank(); - AffineMap maxMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, ctx); - for (auto dim : - llvm::concat(opInfo.getBatchDims(), opInfo.getMDims())) { - maxMap = maxMap.insertResult(rewriter.getAffineDimExpr(dim), - maxMap.getNumResults()); - } - AffineMap sumMap = maxMap; - - SmallVector sizes = attnOp.getIterationDomain(rewriter); - - // Create fill for acc, max and sum. - // TODO: Acc should not need a fill. The attention op should get a filled - // input instead of an empty input. 
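// Editorial note, not part of this patch: the fills below seed each running
// reduction with its identity element. For f32, with the finite-only choice
// passed to arith::getIdentityValue for the max, those identities are:
//
//   constexpr float kMaxIdentity = -3.40282347E+38f; // -FLT_MAX, for maximumf
//   constexpr float kSumIdentity = 0.0f;             // for addf
//
// The tiling.mlir test further down checks for exactly these constants.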
- Value zeroAcc = rewriter.create( - loc, rewriter.getZeroAttr(attnOp.getOutputType().getElementType())); - Value accFill = - rewriter - .create(loc, ValueRange{zeroAcc}, attnOp.getOutput()) - .result(); - - SmallVector rowRedSize = - llvm::map_to_vector(sizes, [](Range x) { return x.size; }); - rowRedSize = applyPermutationMap(maxMap, rowRedSize); - - Type f32Type = rewriter.getF32Type(); - Value rowRedEmpty = - rewriter.create(loc, rowRedSize, f32Type); - - Value maxInit = - arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, rewriter, - loc, /*useOnlyFiniteValue=*/true); - Value sumInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, - rewriter, loc); - - Value maxFill = - rewriter.create(loc, ValueRange{maxInit}, rowRedEmpty) - .getResult(0); - Value sumFill = - rewriter.create(loc, ValueRange{sumInit}, rowRedEmpty) - .getResult(0); - - // Create online attention op. - SmallVector indexingMaps = attnOp.getIndexingMapsArray(); - indexingMaps.push_back(maxMap); - indexingMaps.push_back(sumMap); - OnlineAttentionOp onlineAttn = rewriter.create( - loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()}, - attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(), - accFill, maxFill, sumFill, rewriter.getAffineMapArrayAttr(indexingMaps)); - onlineAttn->setDiscardableAttrs(attnOp->getDiscardableAttrDictionary()); - ops.push_back(onlineAttn); - - Value x = onlineAttn.getResult(0); - Value sum = onlineAttn.getResult(2); - - // Merge the outputs of online attention: - // x = (1 / sum) * x - - // Compress the indexing maps. - SmallVector compressedMaps = - compressUnusedDims(SmallVector{sumMap, attnOp.getOutputMap()}); - - SmallVector iteratorTypes(compressedMaps[0].getNumDims(), - utils::IteratorType::parallel); - - auto genericOp = rewriter.create( - loc, x.getType(), sum, x, compressedMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value one = b.create( - loc, b.getFloatAttr(args[0].getType(), 1.0)); - Value reciprocal = b.create(loc, one, args[0]); - // Convert sum to the same datatype as x. - reciprocal = convertScalarToDtype(b, loc, reciprocal, args[1].getType(), - /*isUnsignedCast=*/false); - Value result = b.create(loc, reciprocal, args[1]); - b.create(loc, result); - }); - ops.push_back(genericOp); - - rewriter.replaceOp(attnOp, genericOp); -} - void TileAttentionPass::runOnOperation() { MLIRContext *context = &getContext(); IRRewriter rewriter(context); @@ -414,21 +305,8 @@ void TileAttentionPass::runOnOperation() { }); } -void ConvertAttentionToOnlineAttentionPass::runOnOperation() { - MLIRContext *context = &getContext(); - IRRewriter rewriter(context); - getOperation().walk([&](AttentionOp attnOp) { - SmallVector ops; - convertToOnlineAttention(attnOp, ops, rewriter); - }); -} - std::unique_ptr createTileAttentionPass() { return std::make_unique(); } -std::unique_ptr createConvertAttentionToOnlineAttentionPass() { - return std::make_unique(); -} - } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp index 2ecfd80f02e4..6b291c730bbe 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp @@ -4,9 +4,11 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" +#include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -14,7 +16,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" namespace mlir::iree_compiler::IREE::LinalgExt { @@ -1603,16 +1605,16 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition( } //===----------------------------------------------------------------------===// -// Attention Helpers +// AttentionOp //===----------------------------------------------------------------------===// -static SmallVector -getAttentionIterationDomain(Location loc, OpBuilder &b, int64_t domainRank, - ArrayRef values, - ArrayRef indexingMaps) { +SmallVector AttentionOp::getIterationDomain(OpBuilder &builder) { + int64_t domainRank = getIterationDomainRank(); + SmallVector loopBounds(domainRank); - OpFoldResult zero = b.getIndexAttr(0); - OpFoldResult one = b.getIndexAttr(1); + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); for (auto dim : llvm::seq(0, domainRank)) { loopBounds[dim].offset = zero; @@ -1629,27 +1631,26 @@ getAttentionIterationDomain(Location loc, OpBuilder &b, int64_t domainRank, continue; } dimsFound[pos] = true; - loopBounds[pos].size = getDimValue(b, loc, val, idx); + loopBounds[pos].size = getDimValue(builder, loc, val, idx); } }; - for (auto [val, indexingMap] : llvm::zip_equal(values, indexingMaps)) { - fillSizes(val, indexingMap); - } + // Sizes can be found from Q, K, V alone. + fillSizes(getQuery(), getQueryMap()); + fillSizes(getKey(), getKeyMap()); + fillSizes(getValue(), getValueMap()); return loopBounds; } -static SmallVector -getAttentionIteratorTypes(int64_t domainRank, - ArrayRef indexingMaps) { +SmallVector AttentionOp::getLoopIteratorTypes() { FailureOr maybeOpInfo = - AttentionOpDetail::get(indexingMaps); + AttentionOpDetail::get(getIndexingMapsArray()); assert(succeeded(maybeOpInfo) && "Failed to infer attention op details"); AttentionOpDetail opInfo = maybeOpInfo.value(); // All dimensions other than k1 and k2 are parallel. 
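// Editorial sketch, not part of this patch: the classification performed by
// the code below, restated standalone (hypothetical names, std:: types).
#include <cstdint>
#include <vector>

enum class LoopKind { parallel, reduction };

static std::vector<LoopKind>
attentionIteratorKinds(int64_t domainRank, const std::vector<int64_t> &k1Dims,
                       const std::vector<int64_t> &k2Dims) {
  // Every dim is parallel except the two reduction dims: k1 (the Q @ K.T
  // contraction) and k2 (the softmax / P @ V reduction).
  std::vector<LoopKind> kinds(domainRank, LoopKind::parallel);
  for (int64_t d : k1Dims)
    kinds[d] = LoopKind::reduction;
  for (int64_t d : k2Dims)
    kinds[d] = LoopKind::reduction;
  return kinds;
}
// For the (batch, m, k1, k2, n) maps used in the tests below this yields
// [parallel, parallel, reduction, reduction, parallel].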
- SmallVector iteratorTypes(domainRank, + SmallVector iteratorTypes(getIterationDomainRank(), utils::IteratorType::parallel); for (auto dim : @@ -1660,42 +1661,6 @@ getAttentionIteratorTypes(int64_t domainRank, return iteratorTypes; } -static SmallVector getPermutedSlice(AffineMap permutation, - ArrayRef offsets, - ArrayRef sizes) { - auto one = IntegerAttr::get(IndexType::get(permutation.getContext()), 1); - assert(permutation.isProjectedPermutation() && - "Indexing map should be a projected permutation"); - SmallVector output; - for (AffineExpr dimExpr : permutation.getResults()) { - int dim = cast(dimExpr).getPosition(); - Range dimRange; - dimRange.offset = offsets[dim]; - dimRange.size = sizes[dim]; - dimRange.stride = one; - output.push_back(dimRange); - } - return output; -} - -//===----------------------------------------------------------------------===// -// AttentionOp -//===----------------------------------------------------------------------===// - -SmallVector AttentionOp::getIterationDomain(OpBuilder &b) { - // Attention shape can be determined from Q, K, V alone. - SmallVector shapedValues = {getQuery(), getKey(), getValue()}; - SmallVector indexingMaps = {getQueryMap(), getKeyMap(), - getValueMap()}; - return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(), - shapedValues, indexingMaps); -} - -SmallVector AttentionOp::getLoopIteratorTypes() { - return getAttentionIteratorTypes(getIterationDomainRank(), - getIndexingMapsArray()); -} - FailureOr AttentionOp::getTiledImplementation(OpBuilder &builder, ArrayRef offsets, @@ -1704,36 +1669,59 @@ AttentionOp::getTiledImplementation(OpBuilder &builder, assert(sizes.size() == getIterationDomainRank()); Location loc = getLoc(); + auto one = builder.getIndexAttr(1); + + auto tileValue = [&](Value val, AffineMap indexingMap) + -> std::tuple, SmallVector, + SmallVector> { + assert(indexingMap.isProjectedPermutation() && + "Indexing map should be a projected permutation"); + SmallVector outputOffsets; + SmallVector outputSizes; + SmallVector outputStrides(indexingMap.getNumResults(), one); + for (AffineExpr dimExpr : indexingMap.getResults()) { + int dim = cast(dimExpr).getPosition(); + outputOffsets.push_back(offsets[dim]); + outputSizes.push_back(sizes[dim]); + } + return {outputOffsets, outputSizes, outputStrides}; + }; - SmallVector querySlice = - getPermutedSlice(getQueryMap(), offsets, sizes); - SmallVector keySlice = getPermutedSlice(getKeyMap(), offsets, sizes); - SmallVector valueSlice = - getPermutedSlice(getValueMap(), offsets, sizes); - SmallVector outputSlice = - getPermutedSlice(getOutputMap(), offsets, sizes); + auto [queryOffsets, querySizes, queryStrides] = + tileValue(getQuery(), getQueryMap()); + auto [keyOffsets, keySizes, keyStrides] = tileValue(getKey(), getKeyMap()); + auto [valueOffsets, valueSizes, valueStrides] = + tileValue(getValue(), getValueMap()); + auto [outputOffsets, outputSizes, outputStrides] = + tileValue(getOutput(), getOutputMap()); Value scale = getScale(); SmallVector tiledOperands; - tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), querySlice)); - tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice)); - tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice)); + tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), queryOffsets, + querySizes, queryStrides)); + tiledOperands.emplace_back( + getSlice(builder, loc, getKey(), keyOffsets, keySizes, keyStrides)); + tiledOperands.emplace_back(getSlice(builder, loc, getValue(), 
valueOffsets, + valueSizes, valueStrides)); tiledOperands.emplace_back(scale); - tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice)); + tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputOffsets, + outputSizes, outputStrides)); std::optional max = getMax(); if (max) { - SmallVector maxSlice = - getPermutedSlice(*getMaxMap(), offsets, sizes); - tiledOperands.emplace_back(getSlice(builder, loc, max.value(), maxSlice)); + auto [maxOffsets, maxSizes, maxStrides] = + tileValue(max.value(), *getMaxMap()); + tiledOperands.emplace_back( + getSlice(builder, loc, max.value(), maxOffsets, maxSizes, maxStrides)); } std::optional sum = getMax(); if (sum) { - SmallVector sumSlice = - getPermutedSlice(*getSumMap(), offsets, sizes); - tiledOperands.emplace_back(getSlice(builder, loc, sum.value(), sumSlice)); + auto [sumOffsets, sumSizes, sumStrides] = + tileValue(sum.value(), *getSumMap()); + tiledOperands.emplace_back( + getSlice(builder, loc, sum.value(), sumOffsets, sumSizes, sumStrides)); } SmallVector resultTypes; @@ -1783,93 +1771,4 @@ LogicalResult AttentionOp::getResultTilePosition( return success(); } -//===----------------------------------------------------------------------===// -// OnlineAttentionOp -//===----------------------------------------------------------------------===// - -SmallVector OnlineAttentionOp::getIterationDomain(OpBuilder &b) { - // Attention shape can be determined from Q, K, V alone. - SmallVector shapedValues = {getQuery(), getKey(), getValue()}; - SmallVector indexingMaps = {getQueryMap(), getKeyMap(), - getValueMap()}; - return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(), - shapedValues, indexingMaps); -} - -SmallVector OnlineAttentionOp::getLoopIteratorTypes() { - return getAttentionIteratorTypes(getIterationDomainRank(), - getIndexingMapsArray()); -} - -FailureOr -OnlineAttentionOp::getTiledImplementation(OpBuilder &builder, - ArrayRef offsets, - ArrayRef sizes) { - assert(offsets.size() == getIterationDomainRank()); - assert(sizes.size() == getIterationDomainRank()); - - Location loc = getLoc(); - - SmallVector querySlice = - getPermutedSlice(getQueryMap(), offsets, sizes); - SmallVector keySlice = getPermutedSlice(getKeyMap(), offsets, sizes); - SmallVector valueSlice = - getPermutedSlice(getValueMap(), offsets, sizes); - SmallVector outputSlice = - getPermutedSlice(getOutputMap(), offsets, sizes); - SmallVector maxSlice = getPermutedSlice(getMaxMap(), offsets, sizes); - SmallVector sumSlice = getPermutedSlice(getSumMap(), offsets, sizes); - - Value scale = getScale(); - - SmallVector tiledOperands; - tiledOperands.emplace_back(getSlice(builder, loc, getQuery(), querySlice)); - tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keySlice)); - tiledOperands.emplace_back(getSlice(builder, loc, getValue(), valueSlice)); - tiledOperands.emplace_back(scale); - tiledOperands.emplace_back(getSlice(builder, loc, getOutput(), outputSlice)); - tiledOperands.emplace_back(getSlice(builder, loc, getMax(), maxSlice)); - tiledOperands.emplace_back(getSlice(builder, loc, getSum(), sumSlice)); - - SmallVector resultTypes; - resultTypes.push_back(tiledOperands[4].getType()); - resultTypes.push_back(tiledOperands[5].getType()); - resultTypes.push_back(tiledOperands[6].getType()); - - Operation *tiledOp = - mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; -} - -LogicalResult OnlineAttentionOp::getResultTilePosition( - 
OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, - ArrayRef sizes, SmallVector &resultOffsets, - SmallVector &resultSizes) { - resultOffsets.clear(); - resultSizes.clear(); - - AffineMap resultIndexingMap; - switch (resultNumber) { - case 0: - resultIndexingMap = getOutputMap(); - break; - case 1: - resultIndexingMap = getMaxMap(); - break; - case 2: - resultIndexingMap = getSumMap(); - break; - default: - return failure(); - } - - for (AffineExpr dimExpr : resultIndexingMap.getResults()) { - int dim = cast(dimExpr).getPosition(); - resultOffsets.push_back(offsets[dim]); - resultSizes.push_back(sizes[dim]); - } - return success(); -} - } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel index fe21d60b99ff..19176319e7d6 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel @@ -19,7 +19,6 @@ iree_lit_test_suite( "conv2d_to_winograd.mlir", "convert_to_loops.mlir", "decompose_attention.mlir", - "decompose_online_attention.mlir", "decompose_winograd.mlir", "distribution.mlir", "pad_contraction_to_block_size.mlir", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt index 7ef5fd789a33..92abdf79269c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt @@ -17,7 +17,6 @@ iree_lit_test_suite( "conv2d_to_winograd.mlir" "convert_to_loops.mlir" "decompose_attention.mlir" - "decompose_online_attention.mlir" "decompose_winograd.mlir" "distribution.mlir" "pad_contraction_to_block_size.mlir" diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir deleted file mode 100644 index 945cff8e32a0..000000000000 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-attention),canonicalize,cse)" %s | FileCheck %s - -#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> -#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> -#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> -#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> -#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> - -func.func @attention_f16(%query: tensor<192x1024x64xf16>, - %key: tensor<192x1024x64xf16>, - %value: tensor<192x1024x64xf16>, - %output: tensor<192x1024x64xf32>, - %max: tensor<192x1024xf32>, - %sum: tensor<192x1024xf32>) - -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) { - %scale = arith.constant 1.0 : f16 - - %out:3 = iree_linalg_ext.online_attention - { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] } - ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) - outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) - -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> - - return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> -} - 
-// We just want to check if we are using the correct algorithm. -// CHECK-LABEL: @attention_f16 -// S = Q @ K -// CHECK: linalg.generic -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: linalg.yield -// newMax = max(oldMax, rowMax(S)) -// CHECK: linalg.generic -// CHECK: arith.maximumf -// CHECK: linalg.yield -// P = exp2(S - newMax) -// CHECK: linalg.generic -// CHECK: arith.subf -// CHECK: math.exp2 -// CHECK: linalg.yield -// norm = exp2(oldMax - newMax) -// CHECK: linalg.generic -// CHECK: arith.subf -// CHECK: math.exp2 -// CHECK: linalg.yield -// normSum = norm * oldSum -// CHECK: linalg.generic -// CHECK: arith.mulf -// CHECK: linalg.yield -// newSum = normSum + rowMax(P) -// CHECK: linalg.generic -// CHECK: arith.addf -// CHECK: linalg.yield -// newAcc = norm * oldAcc -// CHECK: linalg.generic -// CHECK: arith.mulf -// CHECK: linalg.yield -// newAcc = P @ V + newAcc -// CHECK: linalg.generic -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: linalg.yield diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir index a974f3a3fba4..f92bdb838da6 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir @@ -1536,66 +1536,3 @@ module attributes { transform.with_named_sequence } { // CHECK: } // CHECK: return // CHECK: } - -// ----- - -#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> -#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> -#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> -#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> -#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> - -func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { - %scale = arith.constant 1.0 : f32 - - %output_empty = tensor.empty() : tensor<192x1024x64xf32> - %row_red_empty = tensor.empty() : tensor<192x1024xf32> - - %sum_ident = arith.constant 0.000000e+00 : f32 - %max_ident = arith.constant -3.40282347E+38 : f32 - - %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> - %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32> - %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32> - - %out:3 = iree_linalg_ext.online_attention - { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] } - ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) - outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) - -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> - - return %out#0 : tensor<192x1024x64xf32> -} - -// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)> -// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)> -// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> -// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> -// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> -// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> -// CHECK-DAG: #[[$MAP4:.+]] 
= affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> -// CHECK-LABEL: @online_attention -// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2) -// CHECK-DAG: %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]]) -// CHECK-DAG: %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]]) -// CHECK-DAG: %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]]) -// CHECK-DAG: %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32> -// CHECK-DAG: %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32> -// CHECK-DAG: %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32> -// CHECK-DAG: %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32> -// CHECK-DAG: %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32> -// CHECK-DAG: %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32> -// CHECK-DAG: iree_linalg_ext.online_attention -// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP4]]]} -// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}} : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32) -// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>) -// CHECK: scf.forall.in_parallel - -module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op - %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp index ac5f5c04aa23..7feee69c1eb2 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp @@ -27,7 +27,7 @@ findPermutationsIndexingOperand(AffineMap indexingMap) { void AttentionOpDetail::inferFromIndexingMaps( ArrayRef indexingMaps) { - assert(indexingMaps.size() >= 4); + assert(indexingMaps.size() == 4); AffineMap qMap = indexingMaps[0]; AffineMap kMap = indexingMaps[1]; AffineMap vMap = indexingMaps[2]; @@ -82,23 +82,7 @@ AttentionOpDetail::get(ArrayRef indexingMaps) { AttentionOpDetail opInfo; opInfo.inferFromIndexingMaps(indexingMaps); - opInfo.maps = SmallVector(indexingMaps); return opInfo; } -AffineMap AttentionOpDetail::getSMap() const { - // We need to create an indexing map for the intermediate result of first - // matmul. 
There could be other options, but we choose to create a standard - // indexing map: - // SMap = (batch, m, k1, k2, n) -> (batch, m, k2) - AffineMap sMap = AffineMap::get(/*dimCount=*/getDomainRank(), - /*symbolCount=*/0, getContext()); - for (auto dim : - llvm::concat(getBatchDims(), getMDims(), getK2Dims())) { - AffineExpr dimExpr = getAffineDimExpr(dim, getContext()); - sMap = sMap.insertResult(dimExpr, sMap.getNumResults()); - } - return sMap; -} - }; // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h index e66bc865f011..cfba4ff51ece 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h @@ -42,29 +42,20 @@ class AttentionOpDetail { public: static FailureOr get(ArrayRef indexingMaps); - int64_t getDomainRank() const { return maps[0].getNumDims(); } ArrayRef getBatchDims() const { return batch; } ArrayRef getMDims() const { return m; } ArrayRef getK1Dims() const { return k1; } ArrayRef getK2Dims() const { return k2; } ArrayRef getNDims() const { return n; } - ArrayRef getIndexingMaps() const { return maps; } - - AffineMap getSMap() const; - private: void inferFromIndexingMaps(ArrayRef indexingMaps); - MLIRContext *getContext() const { return maps[0].getContext(); } - SmallVector batch; SmallVector m; SmallVector k1; SmallVector k2; SmallVector n; - - SmallVector maps; }; }; // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp index e6c3548e2f64..d2313221246d 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp @@ -44,13 +44,6 @@ SmallVector getDims(OpBuilder &builder, Location loc, [&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); }); } -Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef slice) { - return getSlice(b, loc, src, - llvm::map_to_vector(slice, [](Range x) { return x.offset; }), - llvm::map_to_vector(slice, [](Range x) { return x.size; }), - llvm::map_to_vector(slice, [](Range x) { return x.stride; })); -} - Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h index eec973fcd2bd..9b40a5425d63 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h @@ -13,10 +13,6 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" -namespace mlir { -struct Range; -}; // namespace mlir - namespace mlir::iree_compiler::IREE::LinalgExt { /// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at @@ -30,7 +26,6 @@ SmallVector getDims(OpBuilder &builder, Location loc, Value v); /// Returns a `memref.subview` or a `tensor.extract_slice` based on the type of /// `src`. -Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef slice); Value getSlice(OpBuilder &b, Location loc, Value src, ArrayRef offsets, ArrayRef sizes, ArrayRef strides);
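A closing note on the Utils changes above: the removed Range-based getSlice overload (and the tileValue lambda restored in TilingInterfaceImpl.cpp) both amount to unzipping per-dimension ranges into the offset/size/stride triple that the surviving getSlice signature takes. A minimal standalone equivalent, assuming plain std:: types rather than MLIR's Range/OpFoldResult:

#include <cstdint>
#include <vector>

struct Range {
  int64_t offset;
  int64_t size;
  int64_t stride;
};

struct SliceSpec {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

// Unzip an array of ranges into the three parallel arrays consumed by the
// remaining getSlice() overload.
static SliceSpec unzipRanges(const std::vector<Range> &ranges) {
  SliceSpec spec;
  for (const Range &r : ranges) {
    spec.offsets.push_back(r.offset);
    spec.sizes.push_back(r.size);
    spec.strides.push_back(r.stride);
  }
  return spec;
}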