Skip to content

Commit

Permalink
Merge f03b6fc into 10877f6
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhanW authored Jul 24, 2024
2 parents 10877f6 + f03b6fc commit 41084fc
Show file tree
Hide file tree
Showing 14 changed files with 313 additions and 72 deletions.
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ iree_compiler_cc_library(
"LLVMCPUAssignConstantOrdinals.cpp",
"LLVMCPUAssignImportOrdinals.cpp",
"LLVMCPUCheckIRBeforeLLVMConversion.cpp",
"LLVMCPUDropVectorUnitDims.cpp",
"LLVMCPUEmitVectorizationRemarks.cpp",
"LLVMCPULinkExecutables.cpp",
"LLVMCPULowerExecutableTarget.cpp",
"LLVMCPUMmt4dVectorLowering.cpp",
"LLVMCPUOptimizeVectorShapes.cpp",
"LLVMCPUPeel.cpp",
"LLVMCPUSelectLoweringStrategy.cpp",
"LLVMCPUSplitReduction.cpp",
Expand All @@ -69,6 +69,7 @@ iree_compiler_cc_library(
"LLVMCPUTileAndFuse.cpp",
"LLVMCPUUnfuseFMAOps.cpp",
"LLVMCPUVectorShapeCastLowering.cpp",
"LLVMCPUVectorBitCastLowering.cpp",
"LLVMCPUVectorTransferLowering.cpp",
"LLVMCPUVectorTransposeLowering.cpp",
"LLVMCPUVerifyVectorSizeLegality.cpp",
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,19 @@ iree_cc_library(
"LLVMCPUAssignConstantOrdinals.cpp"
"LLVMCPUAssignImportOrdinals.cpp"
"LLVMCPUCheckIRBeforeLLVMConversion.cpp"
"LLVMCPUDropVectorUnitDims.cpp"
"LLVMCPUEmitVectorizationRemarks.cpp"
"LLVMCPULinkExecutables.cpp"
"LLVMCPULowerExecutableTarget.cpp"
"LLVMCPUMmt4dVectorLowering.cpp"
"LLVMCPUOptimizeVectorShapes.cpp"
"LLVMCPUPeel.cpp"
"LLVMCPUSelectLoweringStrategy.cpp"
"LLVMCPUSplitReduction.cpp"
"LLVMCPUSynchronizeSymbolVisibility.cpp"
"LLVMCPUTile.cpp"
"LLVMCPUTileAndFuse.cpp"
"LLVMCPUUnfuseFMAOps.cpp"
"LLVMCPUVectorBitCastLowering.cpp"
"LLVMCPUVectorShapeCastLowering.cpp"
"LLVMCPUVectorTransferLowering.cpp"
"LLVMCPUVectorTransposeLowering.cpp"
Expand Down
18 changes: 11 additions & 7 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1635,14 +1635,18 @@ getPackVectorTileSizes(mlir::FunctionOpInterface entryPointFn,
if (!hasAVX512fFeature(targetAttr) || !isPackMatmulLHS(op)) {
return tileSizes;
}
if (op.getSourceType().getElementType().isF32()) {
SmallVector<int64_t> innerTiles = op.getStaticTiles();
Type elemType = op.getSourceType().getElementType();
if (elemType.isF32() && innerTiles.back() == 1) {
tileSizes.back() = vectorSize;
}
// TODO(#16314): Generate efficient tile sizes for non-f32 cases.
if (op.getSourceType().getElementType().isF16()) {
if (elemType.isF16() || elemType.isBF16()) {
// We adjust the vector size to half to use the same lowering strategy as
// f32.
tileSizes.back() = vectorSize / 2;
tileSizes.back() = innerTiles[0];
}
if (elemType.isInteger(8)) {
tileSizes.back() = innerTiles[0];
}
return tileSizes;
}
Expand Down Expand Up @@ -2193,14 +2197,14 @@ static LogicalResult setElementwiseGenericOpRootConfig(
LLVM_DEBUG(KD_DBGS() << "Vector pre-processing strategy: "
<< vecPreProcStrategy << "\n");

// Adjust tiling sizes of vector levels to avoid large unroll factors. Most of
// the cases are f32 and i32, so we divide it by 4.
int64_t vecSize = getNativeVectorSizeInBytes(entryPointFn) / 4;
int64_t vecSize = getNativeVectorSizeInBytes(entryPointFn);
SmallVector<int64_t> vecTileSizes = distConfig.minTileSizes;
LLVM_DEBUG(KD_DBGS() << "vecTileSizes: " << vecTileSizes << "\n");
for (auto &i : vecTileSizes) {
i = roundUpToPow2(std::min(i, vecSize),
vecPreProcStrategy == VectorPreProcStrategy::Masking);
}
LLVM_DEBUG(KD_DBGS() << "vecTileSizes: " << vecTileSizes << "\n");

// Setting reduction tile sizes is a workaround to kick in peeling transform.
// The tiling won't happen because the sizes are zeros. Also, no need for
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-llvmcpu-optimize-vector-shapes"

using namespace mlir::iree_compiler;

static unsigned getNativeVectorSizeInBytes(mlir::FunctionOpInterface funcOp) {
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
auto nativeVectorSizeAttr =
getConfigIntegerAttr(targetAttr, "native_vector_size");
if (nativeVectorSizeAttr) {
unsigned nativeVectorSizeVal = nativeVectorSizeAttr->getInt();
if (nativeVectorSizeVal) {
return nativeVectorSizeVal;
}
}

return 0;
}

namespace mlir::iree_compiler {
namespace {
class LLVMCPUOptimizeVectorShapesPass
: public LLVMCPUOptimizeVectorShapesBase<LLVMCPUOptimizeVectorShapesPass> {
public:
using LLVMCPUOptimizeVectorShapesBase::LLVMCPUOptimizeVectorShapesBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}
void runOnOperation() override;
};

// TODO: Replace/remove the existing OptimizeVectorTransferPass.
void LLVMCPUOptimizeVectorShapesPass::runOnOperation() {
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();

// Apply transfer ops write to read forwarding and dead transfer write
// optimizations.
// TODO: Remove store-to-load forwarding to a separate pass as it's unrelated
// to the vector shape optimizations applied later.
IRRewriter rewriter(ctx);
vector::transferOpflowOpt(rewriter, funcOp);

// Remove unit dimensons.
// TODO: Revisit some of these patterns to make sure they are not redundant.
// Dropping trailing unit dims from transfer ops is equivalent to apply
// flattening.
{
RewritePatternSet patterns(ctx);
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
patterns);
vector::populateVectorTransferDropUnitDimsPatterns(patterns);
vector::populateDropUnitDimWithShapeCastPatterns(patterns);
vector::InsertOp::getCanonicalizationPatterns(patterns, ctx);
vector::ExtractOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
return signalPassFailure();
}

unsigned targetVectorBitwidth = getNativeVectorSizeInBytes(funcOp) * 8;
if (targetVectorBitwidth == 0)
return;

// Collapse dimensions from simple operations and transfer ops that are
// contiguous in memoryif main vector dimension of those ops is lower than the
// target vector length.
{
RewritePatternSet patterns(ctx);
TypeConverter typeConverter;
ConversionTarget target(*ctx);
vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter, patterns, target, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
{
RewritePatternSet patterns(ctx);
vector::populateFlattenVectorTransferPatterns(patterns,
targetVectorBitwidth);
memref::CollapseShapeOp::getCanonicalizationPatterns(patterns, ctx);
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
return signalPassFailure();
}
}
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUOptimizeVectorShapesPass() {
return std::make_unique<LLVMCPUOptimizeVectorShapesPass>();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-llvmcpu-vector-bitcast-lowering"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_LLVMCPUVECTORBITCASTLOWERINGPASS
#include "iree/compiler/Codegen/LLVMCPU/Passes.h.inc"

namespace {
class LLVMCPUVectorBitCastLoweringPass
: public impl::LLVMCPUVectorBitCastLoweringPassBase<
LLVMCPUVectorBitCastLoweringPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}
void runOnOperation() override;
};

void LLVMCPUVectorBitCastLoweringPass::runOnOperation() {
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();

RewritePatternSet patterns(ctx);
vector::populateVectorBitCastLoweringPatterns(patterns, /*targetRank=*/1);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
} // namespace
} // namespace mlir::iree_compiler
Loading

0 comments on commit 41084fc

Please sign in to comment.