Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse batch normalization into convolution kernel #2629

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

mvpant
Copy link
Contributor

@mvpant mvpant commented Nov 18, 2024

This introduces a simplification that merges the batch normalization inference operation with convolution kernel (a.k.a. weight). The key idea is that while the batch normalization parameters change during the training phase, but remain constant during inference. This means that the convolution kernel can be adjusted to incorporate the effects of batch normalization. This optimization is applied by default to the ResNet model in the ONNX framework.

It performs the following transformation:

X = conv(input, kernel.old)
Y = batch_norm_inference(X, ...)

into

X = conv(input, kernel.new)
Y = add(X, broadcast_in_dim(bias.new))

using following calculations:

K.new = K.old * gamma * rsqrt(variance + epsilon)
B.new = (B.old - mean) * rsqrt(variance + epsilon) * gamma + beta
where: 
    gamma - scaling factor
    beta - shifting factor
    rsqrt - reciprocal square root function
    K - kernel(a.k.a weight)
    B - bias

Similar optimization can be found in PyTorch:
https://github.com/pytorch/pytorch/blob/main/torch/nn/utils/fusion.py#L56

@mvpant
Copy link
Contributor Author

mvpant commented Nov 18, 2024

Regarding terminology, what is preferred in StableHLO for convolution rhs: kernel or weight?

@GleasonK
Copy link
Member

The key idea is that while the batch normalization parameters change during the training phase, but remain constant during inference. This means that the convolution kernel can be adjusted to incorporate the effects of batch normalization.

Is this to say - during training these values won't be constant ops, and this pattern won't apply, but during inference it will? This seems reasonable. Overall certainly interested in growing the set of patterns available in the StableHLO repo.

We've discussed before that we'll need a way to adjust the knobs in terms of what patterns get applied, and that's a problem I plan to take on early next year. In the meantime, probably fine to have this pattern in this pass. If we decided it wasn't desirable on the default path, we can always make this it's own pass.

Regarding terminology, what is preferred in StableHLO for convolution rhs: kernel or weight?

cc @ghpvnist regarding the terminology question, any preference from a spec perspective?

@ghpvnist
Copy link
Member

I like kernel but both are equally well understood imo, so up to the code author :) Since this isn't affecting the spec, anything works!

@mvpant
Copy link
Contributor Author

mvpant commented Nov 19, 2024

Is this to say - during training these values won't be constant ops, and this pattern won't apply, but during inference it will? This seems reasonable.

Yes, I assume that’s why there are several operations like stablehlo.batch_norm_grad, stablehlo.batch_norm_inference, and stablehlo.batch_norm_training. The stablehlo.batch_norm_inference is designed to be used during the inference phase, normalizing input data using the statistics computed during training.

@abhigunj abhigunj added the Transformations Pertaining to MLIR passes and transformations label Dec 6, 2024
@GleasonK
Copy link
Member

GleasonK commented Dec 9, 2024

This fell off my radar a few weeks back - That all makes sense! Pattern LGTM if we can make the test file more targeted / shorter!

@mvpant
Copy link
Contributor Author

mvpant commented Dec 10, 2024

This fell off my radar a few weeks back - That all makes sense! Pattern LGTM if we can make the test file more targeted / shorter!

Sorry for the lack of updates, been a bit swamped lately. Not sure how to make test shorter. Started by taking the kernel/weight from the first layer of the ResNet model (probably resnet18) in ONNX as my expected data. Then I took a random picture and ran it through the ONNX Runtime, compiled with debug flags, to dump the input and output data from that layer for the current test case.

The goal is to see if the results from fused operators and the simplified batch normalization operations (according to the spec) match up. The problem is that the interpreter is running slower than I expected, so I cut down the input, expected output, and weights data (using stablehlo.slice and applying folding patterns to preserve the initial idea) to make it less CPU-intensive. But it’s still too slow

I think I can trim it down even more.

Also, I believe this requires a few tests to check which convolution configurations are currently supported.

@GleasonK
Copy link
Member

But it’s still too slow

I need to figure out why bazel builds are so much slower than cmake..this test only took a few seconds on cmake. At a bare minimum I'll figure out a way to tag tests as large and not run the bazel CI for them.

I didn't notice that this test was in testdata, that's totally fine to have more practical "exported from X model" tests there! It's actually probably best to have something roughly testing numerics in testdata.

I'm thinking about unit tests, i.e. stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir. A few things that test the structural changes from the comment, could use all ones and zeros / garbage data if that's feasible:

// X = conv(input, weight)
// Y = batch_norm_inference(X, ...)
// into ->
// X = conv(input, weight(new))
// Y = add(X, broadcast_in_dim(Bias(new)))

@GleasonK
Copy link
Member

Made the following PR which lets testdata tests to use the suffix .large.mlir to extend the timeout for the file

#2671

@mvpant
Copy link
Contributor Author

mvpant commented Dec 17, 2024

I'm thinking about unit tests, i.e. stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir. A few things that test the structural changes from the comment, could use all ones and zeros / garbage data if that's feasible:

// X = conv(input, weight)
// Y = batch_norm_inference(X, ...)
// into ->
// X = conv(input, weight(new))
// Y = add(X, broadcast_in_dim(Bias(new)))

Yes, i agree that it should be fine to use dummy data as we interested in transformations.

Made the following PR which lets testdata tests to use the suffix .large.mlir to extend the timeout for the file

Cool. I`ll try to finish up this pull request..

@mvpant mvpant marked this pull request as ready for review December 25, 2024 13:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Transformations Pertaining to MLIR passes and transformations
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants