Layernorm: convert instance norm and group norm to layer norm. #2595

Merged: 69 commits merged into onnx:main on Nov 9, 2023

Conversation

AlexandreEichenberger (Collaborator)

Layer norm is a superset of the existing InstanceNormalization and GroupNormalization operations.

For instance norm:

    %0 = "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32} : (tensor<2x3x4x5x6xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<2x3x4x5x6xf32>

becomes

  %0 = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64>
  %1 = "onnx.Unsqueeze"(%arg1, %0) : (tensor<3xf32>, tensor<3xi64>) -> tensor<3x1x1x1xf32>
  %2 = "onnx.Unsqueeze"(%arg2, %0) : (tensor<3xf32>, tensor<3xi64>) -> tensor<3x1x1x1xf32>
  %Y, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %2) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<2x3x4x5x6xf32>, tensor<3x1x1x1xf32>, tensor<3x1x1x1xf32>) -> (tensor<2x3x4x5x6xf32>, none, none)
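
To see why this rewrite is valid, here is a hedged numeric sketch (not part of the PR; the helper names normalizeTrailing and instanceNorm are hypothetical). It normalizes each (n, c) slice over its trailing spatial elements, which is exactly what the LayerNormalization above computes with axis = 2 once the per-channel scale and bias are unsqueezed to broadcast along C:

  #include <cmath>
  #include <vector>

  // Normalize each of `rows` contiguous slices of length `cols` to zero mean
  // and unit variance; this is the core of LayerNormalization once all dims
  // before `axis` collapse into `rows` and all dims at/after it into `cols`.
  static void normalizeTrailing(std::vector<float> &x, int rows, int cols,
                                float epsilon) {
    for (int r = 0; r < rows; ++r) {
      float *p = &x[(size_t)r * cols];
      float mean = 0.0f, var = 0.0f;
      for (int i = 0; i < cols; ++i)
        mean += p[i];
      mean /= cols;
      for (int i = 0; i < cols; ++i)
        var += (p[i] - mean) * (p[i] - mean);
      var /= cols;
      float invStd = 1.0f / std::sqrt(var + epsilon);
      for (int i = 0; i < cols; ++i)
        p[i] = (p[i] - mean) * invStd;
    }
  }

  // Instance norm over an N x C x D tensor (D = product of the spatial dims):
  // normalize every (n, c) slice, then apply the per-channel scale and bias.
  static void instanceNorm(std::vector<float> &x, int N, int C, int D,
                           const std::vector<float> &scale,
                           const std::vector<float> &bias, float epsilon) {
    normalizeTrailing(x, N * C, D, epsilon); // LayerNormalization with axis = 2.
    for (int n = 0; n < N; ++n)
      for (int c = 0; c < C; ++c)
        for (int d = 0; d < D; ++d) {
          size_t i = ((size_t)n * C + c) * D + d;
          x[i] = x[i] * scale[c] + bias[c]; // scale/bias broadcast along C.
        }
  }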

For group norm:

    %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x6x8x16xf32>

becomes

  %0 = onnx.Constant dense<[3, 4, 6, 8, 16]> : tensor<5xi64>
  %1 = onnx.Constant dense<[3, 2, -1, 6, 8, 16]> : tensor<6xi64>
  %2 = onnx.Constant dense<[1, 2, 3, 4]> : tensor<4xi64>
  %3 = "onnx.Unsqueeze"(%arg1, %2) : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x1x1x1x1xf32>
  %4 = "onnx.Unsqueeze"(%arg2, %2) : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x1x1x1x1xf32>
  %5 = "onnx.Reshape"(%arg0, %1) {allowzero = 0 : si64} : (tensor<3x4x6x8x16xf32>, tensor<6xi64>) -> tensor<3x2x2x6x8x16xf32>
  %Y, %Mean, %InvStdDev = "onnx.LayerNormalization"(%5, %3, %4) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<2x1x1x1x1xf32>, tensor<2x1x1x1x1xf32>) -> (tensor<3x2x2x6x8x16xf32>, none, none)
  %6 = "onnx.Reshape"(%Y, %0) {allowzero = 0 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<5xi64>) -> tensor<3x4x6x8x16xf32>
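
The group-norm case is the same idea after a reshape: viewing N x C x D1...Dn as N x NG x C/NG x D1...Dn makes each group a contiguous block, so the LayerNormalization with axis = 2 above normalizes each (n, g) block, and the per-group scale and bias (unsqueezed to NG x 1 x 1 x 1 x 1) broadcast along the group dimension before the result is reshaped back. A hedged sketch, reusing the hypothetical normalizeTrailing helper from the instance-norm sketch above:

  #include <cassert>

  // Group norm over an N x C x D tensor with NG groups and per-group scale/bias.
  static void groupNorm(std::vector<float> &x, int N, int C, int D, int NG,
                        const std::vector<float> &scale,
                        const std::vector<float> &bias, float epsilon) {
    assert(C % NG == 0 && "expected NG to divide C");
    // The reshape N x C x D -> N x NG x (C/NG) x D only reinterprets the same
    // contiguous buffer; each (n, g) block holds (C/NG) * D elements.
    normalizeTrailing(x, N * NG, (C / NG) * D, epsilon); // LayerNorm, axis = 2.
    // Per-group scale and bias, matching the unsqueezed NG x 1 x 1 x 1 x 1 view.
    int groupSize = C / NG;
    for (int n = 0; n < N; ++n)
      for (int c = 0; c < C; ++c)
        for (int d = 0; d < D; ++d) {
          size_t i = ((size_t)n * C + c) * D + d;
          x[i] = x[i] * scale[c / groupSize] + bias[c / groupSize];
        }
  }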

AlexandreEichenberger and others added 30 commits October 13, 2023 15:31

@AlexandreEichenberger (Collaborator, Author)

@philass @sorenlassen: this transformation is done at decompose time; at the moment, all instance and group norms are converted to layer norm.

Let me know if you want a switch to control this behavior.

@tungld (Collaborator) left a comment:

Could you add backend tests for GroupNormalization to inference_backend.py?

    assert(C % numGroups == 0 && "expected numGroups to divide C");
    layerNormShapeVal.emplace_back(C / numGroups);
  } else
    layerNormShapeVal.emplace_back(-1);
Collaborator:

Should this be ShapedType::kDynamic instead of -1? MLIR no longer uses -1 for dynamic dimensions.

Collaborator Author:

Thanks.
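
For illustration, a minimal sketch of what the suggested fix might look like (assuming the surrounding variables C, numGroups, and layerNormShapeVal from the hunk above; not necessarily the exact code that landed):

  // Use ShapedType::kDynamic rather than the legacy -1 for dynamic dims.
  if (C != ShapedType::kDynamic) {
    assert(C % numGroups == 0 && "expected numGroups to divide C");
    layerNormShapeVal.emplace_back(C / numGroups);
  } else
    layerNormShapeVal.emplace_back(ShapedType::kDynamic);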

Type biasScaleType = RankedTensorType::get(biasScaleShape, elementType);
Value newScale = create.onnx.unsqueeze(biasScaleType, scale, axes);
Value newBias = create.onnx.unsqueeze(biasScaleType, bias, axes);
// Convert input from N x C x D1...Dn to N x (NG x C/NG) x D1...Dn.
Collaborator:

Do we need to handle the following special cases differently, i.e., without splitting C into groups?

When the number of groups is the same as the number of channels, this operator is equivalent to InstanceNormalization. When there is only one group, this operator is equivalent to LayerNormalization.

Or do we get the same performance anyway even though we split C?

Collaborator Author:

From what I see, the consecutive reshapes get optimized away, so there is no need to special-case these two corner cases.

  return IntegerAttr::get(b().getIntegerType(64, /*isSigned=*/true),
      APInt(64, n, /*isSigned=*/true));
}

Collaborator:

Attribute is independent of ONNX. Perhaps we should have an AttrBuilder that can be used with different dialects. This can be done in another PR.

Collaborator Author:

This code was manually copied over and over again; I got tired of it, so I made a local private function. Would you like it to be in the MLIR dialect? If you want a new attribute builder, let's discuss which other functions you would want there.

Collaborator:

Yes, the MLIR dialect looks OK.

Collaborator Author:

@tungld, I decided against putting it into the MLIR builder, as otherwise I would need to build that other builder each time I simply want this. If you don't like this, I can make it a static helper function (not part of the class), or remove it altogether.

Collaborator:

OK, let's deal with Attribute in another PR since it is not the main issue in this patch. Another candidate is something like Support/TypeUtilities.hpp, which provides utility functions for types.

PartialSpecified_FullySpecified, // Flattened to 2D.
FullySpecified_Scalar, // Flattened to 2D.
FullySpecified_FullySpecified // Flattened to 2D.
};
Collaborator:

It's a worthwhile extension; the output is more specific. Thanks!

@sorenlassen (Member) left a comment:

LGTM

I didn't review the ONNXToKrnl stuff, but the rest looks good to me.

      PatternRewriter &rewriter) const final {
    // Match.
    Value input = instanceNormOp.getInput();
    if (!input.getType().isa<ShapedType>())
Member:

Also fail if the input has no rank, right?

Collaborator Author:

Good suggestion; included in the next batch of changes.
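
A minimal sketch of the extra guard (assuming the pattern reports a match failure; not necessarily the exact code that landed):

  // Fail when the input is not a ranked shaped type.
  Value input = instanceNormOp.getInput();
  ShapedType inputType = input.getType().dyn_cast<ShapedType>();
  if (!inputType || !inputType.hasRank())
    return failure();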

auto inputShape = inputType.getShape();
int64_t C = inputShape[1];
int64_t inputRank = inputType.getRank();
assert(inputRank > 2 && "expected instance norm with input ranks > 2");
Member:

Conceptually, is the "magic number" 2 here the same as the 2 in axis on line 775? Consider naming the constant 2 so that you can refer to it by name, both in InstanceNormIntoLayerNormPattern and GroupNormIntoLayerNormPattern, where the number 2 appears multiple times as well.

Collaborator Author:

Good idea.
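
A possible way to name it (the constant name is hypothetical, for illustration only):

  // The non-normalized leading dims (N and C), which is also the
  // LayerNormalization axis used by both rewrite patterns.
  static constexpr int64_t nonSpatialRank = 2;
  assert(inputRank > nonSpatialRank && "expected instance norm with input rank > 2");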

@@ -28,6 +28,11 @@ namespace onnx_mlir {

//====-------------------------- ONNX Builder ---------------------------===//

IntegerAttr OnnxBuilder::getSignedInt64Attr(int64_t n) const {
  return IntegerAttr::get(b().getIntegerType(64, /*isSigned=*/true),
      APInt(64, n, /*isSigned=*/true));
Member:

You can just pass n as the 2nd argument to IntegerAttr::get; there is no need for APInt.

Collaborator Author:

Nice, thanks.
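
For reference, a sketch of the simplified builder (IntegerAttr::get accepts an int64_t directly, so the APInt wrapper can be dropped):

  IntegerAttr OnnxBuilder::getSignedInt64Attr(int64_t n) const {
    return IntegerAttr::get(b().getIntegerType(64, /*isSigned=*/true), n);
  }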

@AlexandreEichenberger (Collaborator, Author)

This PR uncovered issue #2601

@tungld (Collaborator) left a comment:

LGTM.

@tungld (Collaborator) commented on Nov 2, 2023:

@AlexandreEichenberger does this issue #2601 block this patch? I am working on a fix for that issue, and it'll take a bit of time.

@AlexandreEichenberger (Collaborator, Author) commented on Nov 2, 2023:

> does this issue #2601 block this patch?

It does, but there is no rush for this patch.

@AlexandreEichenberger merged commit 1c13ecf into onnx:main on Nov 9, 2023
5 checks passed

@jenkins-droid (Collaborator) CI results:

Jenkins Linux s390x Build #13346 [push] Layernorm: convert insta... passed after 1 hr 50 min
Jenkins Linux ppc64le Build #12339 [push] Layernorm: convert insta... passed after 2 hr 4 min
Jenkins Linux amd64 Build #13321 [push] Layernorm: convert insta... failed after 3 hr 55 min

cjvolzka pushed a commit to cjvolzka/onnx-mlir that referenced this pull request Nov 15, 2023

* detect LayerNorm in presence of reciprocal and div of 1 (onnx#2609)
* [NNPA] Use F16 as element type for zTensor (onnx#2611)
* Layernorm: convert instance norm and group norm to layer norm. (onnx#2595)
* Parse and set --mcpu in onnx-mlir-opt command (onnx#2614)
* Update sqrt.mlir
* Update invsqrt.mlir

cjvolzka added a commit to cjvolzka/onnx-mlir that referenced this pull request Nov 15, 2023

* detect LayerNorm in presence of reciprocal and div of 1 (onnx#2609)
* [NNPA] Use F16 as element type for zTensor (onnx#2611)
* Layernorm: convert instance norm and group norm to layer norm. (onnx#2595)
* Parse and set --mcpu in onnx-mlir-opt command (onnx#2614)
* Import dim_param for model inputs and outputs (onnx#2616)
* [DialectBuilder] add builder funcrions for ONNXSumOp and ONNXConvOp (onnx#2572)
* [StableHLO] Lowers PadOp (constant mode) & GatherElements Op to StableHLO (onnx#2602)
* [build] Add cmake option to enable/disable Java components build (onnx#2613)

cjvolzka added a commit to cjvolzka/onnx-mlir that referenced this pull request Nov 15, 2023
* 'main' of github.ibm.com:zosdev/onnx-mlir:
  Use dim_params in dynamic dimension analysis (onnx#2620)
  Update rapidcheck to include the fix for missing <cstdint> include (onnx#2623)
  Initial changes for llvm uplift (onnx#2568)
  [build] Add cmake option to enable/disable Java components build (onnx#2613)
  [StableHLO] Lowers PadOp (constant mode) & GatherElements Op to StableHLO (onnx#2602)
  [DialectBuilder] add builder funcrions for ONNXSumOp and ONNXConvOp (onnx#2572)
  Import dim_param for model inputs and outputs (onnx#2616)
  Parse and set --mcpu in onnx-mlir-opt command (onnx#2614)
  Layernorm: convert instance norm and group norm to layer norm. (onnx#2595)
  [NNPA] Use F16 as element type for zTensor (onnx#2611)
  detect LayerNorm in presence of reciprocal and div of 1 (onnx#2609)

# Conflicts:
#	test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir