From c8d9efc057340fbcb8dfb81d9aba6948c42b23e4 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 19 Dec 2024 16:26:09 +0800 Subject: [PATCH] [WebNN] Support SkipSimplifiedLayerNormalization op The algorithm of SkipSimplifiedLayerNormalization is quite similar to SimplifiedLayerNormalization; the only difference is that SkipSimplifiedLayerNormalization provides an additional output used for calculating the sum of the input, skip and bias (if it exists). BTW, fix a bug in SimplifiedLayerNormalization: add bias if it exists. --- js/web/docs/webnn-operators.md | 3 +- .../core/providers/webnn/builders/helper.h | 1 + .../builders/impl/normalization_op_builder.cc | 82 ++++++++++++++----- .../webnn/builders/op_builder_factory.cc | 1 + 4 files changed, 64 insertions(+), 23 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index af7348dba532f..636767e506a6f 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -89,9 +89,10 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements | ✗ | ✓ | Only supports 'reduction' == 'none' | | ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | -| SimplifiedLayerNormalization | ai.onnx(1+) | pow + reduceMean + add + sqrt + div + mul | ✓ | ✓ | | +| SimplifiedLayerNormalization | com.microsoft(1+) | pow, reduceMean, add, sqrt, div, mul | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | | Sign | ai.onnx(9-12, 13+) | sign | ✓ | ✓ | | +| SkipSimplifiedLayerNormalization | com.microsoft(1+) | pow, reduceMean, add, sqrt, div, mul | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | | | Softsign | ai.onnx(7+) | softsign | ✓ | ✓ | | | Sin | ai.onnx(7+) | sin | ✓ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h 
b/onnxruntime/core/providers/webnn/builders/helper.h index a06f46f1bdf0a..c68653ae8ebcb 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -278,6 +278,7 @@ static const InlinedHashMap op_map = { {"Softplus", "softplus"}, {"Softsign", "softsign"}, {"Sin", "sin"}, + {"SkipSimplifiedLayerNormalization", "layerNormalization"}, {"Slice", "slice"}, {"Softmax", "softmax"}, {"Split", "split"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 50e49884bdfa9..bb5c493e00f75 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -34,6 +34,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); + const auto& output_defs = node.OutputDefs(); ORT_RETURN_IF_NOT(input_defs.size() >= 2, op_type, " requires at least two inputs."); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); @@ -45,7 +46,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder options.set("label", node.Name()); std::vector scale_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); + const size_t scale_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 2 : 1; + ORT_RETURN_IF_NOT(GetShape(*input_defs[scale_input_index], scale_shape, logger), "Cannot get scale shape"); const auto scale_size = scale_shape.size(); // Except LayerNormalization, other normalization ops' scale input should be 1-D. 
if (op_type == "LayerNormalization") { @@ -55,19 +57,17 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder ORT_RETURN_IF_NOT(scale_size == 1, "The scale size should be one."); } - if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) { + emscripten::val scale = model_builder.GetOperand(input_defs[scale_input_index]->Name()); + options.set("scale", scale); + + const size_t bias_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 3 : 2; + emscripten::val bias = emscripten::val::undefined(); + if (input_defs.size() > bias_input_index && input_defs[bias_input_index]->Exists()) { // Bias input exists, and bias's shape should be the same as scale's shape. std::vector bias_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[2], bias_shape, logger), "Cannot get bias shape"); + ORT_RETURN_IF_NOT(GetShape(*input_defs[bias_input_index], bias_shape, logger), "Cannot get bias shape"); ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape."); - } - - emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); - options.set("scale", scale); - - if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) { - // Bias input exists, and bias's shape is the same as scale's shape. - emscripten::val bias = model_builder.GetOperand(input_defs[2]->Name()); + bias = model_builder.GetOperand(input_defs[bias_input_index]->Name()); options.set("bias", bias); } @@ -76,6 +76,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder options.set("epsilon", epsilon); emscripten::val output = emscripten::val::undefined(); + // SkipSimplifiedLayerNormalization's output: input_skip_bias_sum. 
+ emscripten::val input_skip_bias_sum = emscripten::val::undefined(); if (op_type == "BatchNormalization") { ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs."); emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name()); @@ -85,7 +87,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder } output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); - } else if (op_type == "LayerNormalization" || op_type == "SimplifiedLayerNormalization") { + } else if (op_type == "LayerNormalization" || + op_type == "SimplifiedLayerNormalization" || + op_type == "SkipSimplifiedLayerNormalization") { int64_t axis = helper.Get("axis", -1); axis = HandleNegativeAxis(axis, rank); std::vector axes(rank - SafeInt(axis)); @@ -94,13 +98,17 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder if (op_type == "LayerNormalization") { options.set("axes", emscripten::val::array(axes)); output = model_builder.GetBuilder().call("layerNormalization", input, options); - } else { // SimplifiedLayerNormalization + } else { // SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization /** - WebNN doesn't support SimplifiedLayerNormalization. So decompose it into a series of ops: - X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul - ^ ^ ^ ^ ^ - | | | | | - Y:2 axis B:epsilon A:X A:scale + WebNN doesn't support SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization. 
+ So decompose it into a series of ops: + X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul -> Add (optional) + ^ ^ ^ ^ ^ ^ + | | | | | | + Y:2 axis B:epsilon A:X A:scale B:bias + + If it is SkipSimplifiedLayerNormalization and its output input_skip_bias_sum exists, + input_skip_bias_sum = X + skip + bias (if it exists) */ int32_t input_type; @@ -137,6 +145,25 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder // Mul common_options.set("label", node.Name() + "_mul"); output = model_builder.GetBuilder().call("mul", scale, div, common_options); + + // Add (if bias exists) + if (!bias.isUndefined()) { + common_options.set("label", node.Name() + "_add_bias"); + output = model_builder.GetBuilder().call("add", output, bias, common_options); + } + + // SkipSimplifiedLayerNormalization's output input_skip_bias_sum is the sum of input, skip, and bias. + if (op_type == "SkipSimplifiedLayerNormalization" && output_defs.size() > 3 && output_defs[3]->Exists()) { + emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name()); + common_options.set("label", node.Name() + "_add_skip"); + input_skip_bias_sum = model_builder.GetBuilder().call("add", input, skip, common_options); + if (!bias.isUndefined()) { + common_options.set("label", node.Name() + "_add_skip_bias"); + input_skip_bias_sum = model_builder.GetBuilder().call( + "add", input_skip_bias_sum, bias, common_options); + } + model_builder.AddOperand(output_defs[3]->Name(), std::move(input_skip_bias_sum)); + } } } else if (op_type == "InstanceNormalization") { // WebNN spec only supports 4D input for instanceNormalization. 
@@ -188,7 +215,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type); } - model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + model_builder.AddOperand(output_defs[0]->Name(), std::move(output)); return Status::OK(); } @@ -215,9 +242,19 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi } const auto& output_defs = node.OutputDefs(); - if (output_defs.size() != 1) { - LOGS(logger, VERBOSE) << op_type << " output count must be one."; - return false; + if (op_type == "SkipSimplifiedLayerNormalization") { + for (size_t i = 1; i < output_defs.size(); i++) { + if (output_defs[i]->Exists() && i < 3) { + // Output mean and inv_std_var are used for training mode, which is not supported. + const auto output_name = i == 1 ? "mean" : "inv_std_var"; + LOGS(logger, VERBOSE) << "SkipSimplifiedLayerNormalization's output: " << output_name << " is not supported."; + } + } + } else { + if (output_defs.size() != 1) { + LOGS(logger, VERBOSE) << op_type << " output count must be one."; + return false; + } } if (op_type == "BatchNormalization" && helper.Get("training_mode", 0)) { @@ -277,6 +314,7 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat "InstanceNormalization", "LayerNormalization", "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 6d1c572128b93..e0ca50a36dbf9 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -159,6 +159,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { 
CreateNormalizationOpBuilder("InstanceNormalization", op_registrations); CreateNormalizationOpBuilder("LayerNormalization", op_registrations); CreateNormalizationOpBuilder("SimplifiedLayerNormalization", op_registrations); + CreateNormalizationOpBuilder("SkipSimplifiedLayerNormalization", op_registrations); } { // Pad