Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse batch normalization into convolution kernel #2629

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions stablehlo/testdata/bn_conv_fuse_float32.large.mlir

Large diffs are not rendered by default.

118 changes: 118 additions & 0 deletions stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1970,3 +1970,121 @@ func.func @generic_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf3
%0 = "test_dialect.op"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xf32>)
return %0 : tensor<2xf32>
}


// -----

/////////
// BatchNormInferenceOp

// CHECK-LABEL: @fuse_conv_bninf
func.func @fuse_conv_bninf() -> (tensor<1x8x5x5xf32>) {
%input = stablehlo.constant dense<33.0> : tensor<1x3x8x8xf32>
%kernel = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
%conv = stablehlo.convolution(%input, %kernel)
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
window = {}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
: (tensor<1x3x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>

%dummy = stablehlo.constant dense<1.0> : tensor<8xf32>
%out = "stablehlo.batch_norm_inference"(%conv, %dummy, %dummy, %dummy, %dummy)
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
: (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-> tensor<1x8x5x5xf32>

// CHECK-DAG: [[C0:%.+]] = stablehlo.convolution
// CHECK-DAG: [[C1:%.+]] = stablehlo.broadcast_in_dim
// CHECK-NOT: stablehlo.batch_norm_inference
// CHECK: [[C2:%.+]] = stablehlo.add [[C0]], [[C1]]
// CHECK: return [[C2]]
return %out : tensor<1x8x5x5xf32>
}

// CHECK-LABEL: @fuse_conv_bninf_unsupported_group
func.func @fuse_conv_bninf_unsupported_group()
-> (tensor<1x8x5x5xf32>, tensor<1x8x5x5xf32>) {
%input1 = stablehlo.constant dense<33.0> : tensor<2x3x8x8xf32>
%kernel1 = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
%conv1 = stablehlo.convolution(%input1, %kernel1)
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
{batch_group_count = 2 : i64, feature_group_count = 1 : i64}
: (tensor<2x3x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>

%input2 = stablehlo.constant dense<33.0> : tensor<1x6x8x8xf32>
%kernel2 = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
%conv2 = stablehlo.convolution(%input2, %kernel2)
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
{batch_group_count = 1 : i64, feature_group_count = 2 : i64}
: (tensor<1x6x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>

%cst = stablehlo.constant dense<1.0> : tensor<8xf32>
%out1 = "stablehlo.batch_norm_inference"(%conv1, %cst, %cst, %cst, %cst)
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
: (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-> tensor<1x8x5x5xf32>

%out2 = "stablehlo.batch_norm_inference"(%conv2, %cst, %cst, %cst, %cst)
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
: (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
-> tensor<1x8x5x5xf32>

// CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
// CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
// CHECK: return [[C0]], [[C1]]
return %out1, %out2 : tensor<1x8x5x5xf32>, tensor<1x8x5x5xf32>
}

// CHECK-LABEL: @fuse_conv_bninf_unsupported_configuration
func.func @fuse_conv_bninf_unsupported_configuration()
-> (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) {
%input = stablehlo.constant dense<33.0> : tensor<1x1x1x1xf32>
%kernel = stablehlo.constant dense<0.1> : tensor<1x1x1x1xf32>

%conv1 = stablehlo.convolution(%input, %kernel)
dim_numbers = [f, b, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
: (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

%conv2 = stablehlo.convolution(%input, %kernel)
dim_numbers = [0, 1, f, b]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
: (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

%conv3 = stablehlo.convolution(%input, %kernel)
dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
: (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

%conv4 = stablehlo.convolution(%input, %kernel)
dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
: (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

%cst = stablehlo.constant dense<1.0> : tensor<1xf32>

%out1 = "stablehlo.batch_norm_inference"(%conv1, %cst, %cst, %cst, %cst)
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
: (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
-> tensor<1x1x1x1xf32>
%out2 = "stablehlo.batch_norm_inference"(%conv2, %cst, %cst, %cst, %cst)
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
: (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
-> tensor<1x1x1x1xf32>
%out3 = "stablehlo.batch_norm_inference"(%conv3, %cst, %cst, %cst, %cst)
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
: (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
-> tensor<1x1x1x1xf32>
%out4 = "stablehlo.batch_norm_inference"(%conv4, %cst, %cst, %cst, %cst)
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
: (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
-> tensor<1x1x1x1xf32>

// CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
// CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
// CHECK: [[C2:%.+]] = "stablehlo.batch_norm_inference"
// CHECK: [[C3:%.+]] = "stablehlo.batch_norm_inference"
// CHECK: return [[C0]], [[C1]], [[C2]], [[C3]]
return %out1, %out2, %out3, %out4 : tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>,
tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>
}
153 changes: 153 additions & 0 deletions stablehlo/transforms/StablehloAggressiveSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -1467,6 +1468,156 @@ struct ReorderElementwiseAndShapeOp final
}
};

// Fuses batch normalization operation with convolution kernel:
// X = conv(input, kernel.old)
// Y = batch_norm_inference(X, ...)
// into ->
// X = conv(input, kernel.new)
// Y = add(X, broadcast_in_dim(bias.new))
//
struct FuseConvolutionBatchNormalization final
: OpRewritePattern<BatchNormInferenceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(BatchNormInferenceOp op,
PatternRewriter &rewriter) const override {
auto bnOperandType = op.getOperand().getType();
auto bnOperandShape = bnOperandType.getShape();
auto bnResultType = op.getResult().getType();
uint64_t bnFeatureIndex = op.getFeatureIndex();

auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
if (!convOp) return failure();

auto convKernel = convOp.getRhs();
auto convKernelType = convKernel.getType();
auto convKernelShape = convKernelType.getShape();

auto dimNumbers = convOp.getDimensionNumbers();
if (dimNumbers.getInputBatchDimension() != 0 ||
dimNumbers.getInputFeatureDimension() != 1 ||
dimNumbers.getKernelOutputFeatureDimension() != 0 ||
dimNumbers.getKernelInputFeatureDimension() != 1 ||
dimNumbers.getOutputBatchDimension() != 0 ||
dimNumbers.getOutputFeatureDimension() != 1) {
constexpr StringLiteral msg =
"Only [b, f, ...]x[o, i, ...]->[b, f, ...] configuration is "
"supported";
return rewriter.notifyMatchFailure(convOp, msg);
}

if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
return rewriter.notifyMatchFailure(
convOp, "feature or batch grouping is not supported");

if (bnOperandShape[bnFeatureIndex] != convKernelShape.front())
return failure();

DenseFPElementsAttr convKernelElems;
DenseFPElementsAttr scaleElems;
DenseFPElementsAttr offsetElems;
DenseFPElementsAttr meanElems;
DenseFPElementsAttr varianceElems;

const auto epsilon = op.getEpsilon();

if (!matchPattern(convKernel, m_Constant(&convKernelElems)))
return rewriter.notifyMatchFailure(
op, "expected constant convolution kernel");

if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
!matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
!matchPattern(op.getMean(), m_Constant(&meanElems)) ||
!matchPattern(op.getVariance(), m_Constant(&varianceElems)))
return failure();

const auto &convKernelSemantics =
cast<FloatType>(convKernelType.getElementType()).getFloatSemantics();

// K.new = K.old * gamma * rsqrt(variance + epsilon)
// B.new = (B.old - mean) * rsqrt(variance + epsilon) * gamma + beta
// where: gamma - scaling factor
// beta - shifting factor
// rsqrt - reciprocal square root function
// K - kernel(a.k.a weight)
// B - bias
//
const SmallVector<double> multipliers = llvm::map_to_vector(
llvm::zip_equal(varianceElems, scaleElems),
[&epsilon](const std::tuple<APFloat, APFloat> &pack) -> double {
const auto &[variance, scale] = pack;
auto varEps = (variance + epsilon).convertToDouble();
auto rsqrt = 1.0 / std::sqrt(varEps);
return rsqrt * scale.convertToDouble();
});

SmallVector<APFloat> newKernel;
newKernel.reserve(convKernelType.getNumElements());

const size_t outFeatureTileSize =
computeProduct(convKernelShape.drop_front());
auto it = convKernelElems.begin();
for (const auto &multiplier : multipliers) {
for (size_t i = 0; i < outFeatureTileSize; ++i) {
double v = (*it).convertToDouble() * multiplier;
APFloat result(v);
bool losesInfo;
if (APFloat::opStatus::opInvalidOp ==
result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
&losesInfo))
return failure();
newKernel.push_back(result);
++it;
}
}

SmallVector<APFloat> biasValues;
biasValues.reserve(multipliers.size());

for (const auto &[off, multiplier, mean] :
llvm::zip_equal(offsetElems, multipliers, meanElems)) {
// stablehlo convolution operation doesn't have a builtin bias
double convBias = 0;
double v = (convBias - mean.convertToDouble()) * multiplier +
off.convertToDouble();
APFloat result(v);

bool losesInfo;
if (APFloat::opStatus::opInvalidOp ==
result.convert(convKernelSemantics, APFloat::rmNearestTiesToEven,
&losesInfo))
return failure();

biasValues.push_back(result);
}

rewriter.setInsertionPoint(op);
auto newConvKernel = rewriter.create<ConstantOp>(
convKernel.getLoc(), convKernelType,
DenseFPElementsAttr::get(convKernelType, newKernel));

// Keep old convolution as it might have other users
auto newConvOp = rewriter.create<ConvolutionOp>(
convOp.getLoc(), convOp->getResultTypes(),
ValueRange{convOp.getLhs(), newConvKernel}, convOp->getAttrs());

SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
auto biasType =
convKernelType.cloneWith(biasShape, convKernelType.getElementType());
auto bias = rewriter.create<ConstantOp>(
op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));

auto indices =
rewriter.getDenseI64ArrayAttr({static_cast<int64_t>(bnFeatureIndex)});
auto bcast = rewriter.create<BroadcastInDimOp>(op.getLoc(), bnResultType,
bias, indices);
auto add = rewriter.create<AddOp>(op.getLoc(), newConvOp, bcast);

rewriter.replaceOp(op, add);
return success();
}
};

struct StablehloAggressiveSimplificationPass final
: impl::StablehloAggressiveSimplificationPassBase<
StablehloAggressiveSimplificationPass> {
Expand Down Expand Up @@ -1513,6 +1664,8 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
patterns
->add<GetDimensionSizeOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);

patterns->add<FuseConvolutionBatchNormalization>(context);
}

} // namespace stablehlo
Expand Down
Loading