-
Notifications
You must be signed in to change notification settings - Fork 637
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
313 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 0 additions & 55 deletions
55
compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUDropVectorUnitDims.cpp
This file was deleted.
Oops, something went wrong.
110 changes: 110 additions & 0 deletions
110
compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUOptimizeVectorShapes.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" | ||
#include "iree/compiler/Codegen/LLVMCPU/Passes.h" | ||
#include "iree/compiler/Codegen/Utils/Utils.h" | ||
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" | ||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" | ||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
#define DEBUG_TYPE "iree-llvmcpu-optimize-vector-shapes" | ||
|
||
using namespace mlir::iree_compiler; | ||
|
||
static unsigned getNativeVectorSizeInBytes(mlir::FunctionOpInterface funcOp) { | ||
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); | ||
auto nativeVectorSizeAttr = | ||
getConfigIntegerAttr(targetAttr, "native_vector_size"); | ||
if (nativeVectorSizeAttr) { | ||
unsigned nativeVectorSizeVal = nativeVectorSizeAttr->getInt(); | ||
if (nativeVectorSizeVal) { | ||
return nativeVectorSizeVal; | ||
} | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
namespace mlir::iree_compiler { | ||
namespace { | ||
class LLVMCPUOptimizeVectorShapesPass | ||
: public LLVMCPUOptimizeVectorShapesBase<LLVMCPUOptimizeVectorShapesPass> { | ||
public: | ||
using LLVMCPUOptimizeVectorShapesBase::LLVMCPUOptimizeVectorShapesBase; | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<vector::VectorDialect>(); | ||
} | ||
void runOnOperation() override; | ||
}; | ||
|
||
// TODO: Replace/remove the existing OptimizeVectorTransferPass. | ||
void LLVMCPUOptimizeVectorShapesPass::runOnOperation() { | ||
MLIRContext *ctx = &getContext(); | ||
auto funcOp = getOperation(); | ||
|
||
// Apply transfer ops write to read forwarding and dead transfer write | ||
// optimizations. | ||
// TODO: Remove store-to-load forwarding to a separate pass as it's unrelated | ||
// to the vector shape optimizations applied later. | ||
IRRewriter rewriter(ctx); | ||
vector::transferOpflowOpt(rewriter, funcOp); | ||
|
||
// Remove unit dimensons. | ||
// TODO: Revisit some of these patterns to make sure they are not redundant. | ||
// Dropping trailing unit dims from transfer ops is equivalent to apply | ||
// flattening. | ||
{ | ||
RewritePatternSet patterns(ctx); | ||
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); | ||
vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns( | ||
patterns); | ||
vector::populateVectorTransferDropUnitDimsPatterns(patterns); | ||
vector::populateDropUnitDimWithShapeCastPatterns(patterns); | ||
vector::InsertOp::getCanonicalizationPatterns(patterns, ctx); | ||
vector::ExtractOp::getCanonicalizationPatterns(patterns, ctx); | ||
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
|
||
unsigned targetVectorBitwidth = getNativeVectorSizeInBytes(funcOp) * 8; | ||
if (targetVectorBitwidth == 0) | ||
return; | ||
|
||
// Collapse dimensions from simple operations and transfer ops that are | ||
// contiguous in memoryif main vector dimension of those ops is lower than the | ||
// target vector length. | ||
{ | ||
RewritePatternSet patterns(ctx); | ||
TypeConverter typeConverter; | ||
ConversionTarget target(*ctx); | ||
vector::populateVectorLinearizeTypeConversionsAndLegality( | ||
typeConverter, patterns, target, targetVectorBitwidth); | ||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
{ | ||
RewritePatternSet patterns(ctx); | ||
vector::populateFlattenVectorTransferPatterns(patterns, | ||
targetVectorBitwidth); | ||
memref::CollapseShapeOp::getCanonicalizationPatterns(patterns, ctx); | ||
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx); | ||
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
} | ||
} // namespace | ||
|
||
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> | ||
createLLVMCPUOptimizeVectorShapesPass() { | ||
return std::make_unique<LLVMCPUOptimizeVectorShapesPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler |
40 changes: 40 additions & 0 deletions
40
compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorBitCastLowering.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" | ||
#include "iree/compiler/Codegen/LLVMCPU/Passes.h" | ||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
#define DEBUG_TYPE "iree-llvmcpu-vector-bitcast-lowering" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_LLVMCPUVECTORBITCASTLOWERINGPASS | ||
#include "iree/compiler/Codegen/LLVMCPU/Passes.h.inc" | ||
|
||
namespace { | ||
class LLVMCPUVectorBitCastLoweringPass | ||
: public impl::LLVMCPUVectorBitCastLoweringPassBase< | ||
LLVMCPUVectorBitCastLoweringPass> { | ||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<vector::VectorDialect>(); | ||
} | ||
void runOnOperation() override; | ||
}; | ||
|
||
void LLVMCPUVectorBitCastLoweringPass::runOnOperation() { | ||
MLIRContext *ctx = &getContext(); | ||
auto funcOp = getOperation(); | ||
|
||
RewritePatternSet patterns(ctx); | ||
vector::populateVectorBitCastLoweringPatterns(patterns, /*targetRank=*/1); | ||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); | ||
} | ||
} // namespace | ||
} // namespace mlir::iree_compiler |
Oops, something went wrong.