diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index 4233d8b212b7d..4d6cdeb7fc64a 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -118,6 +118,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
     {"GlobalAveragePool", "averagePool2d"},
     {"GlobalMaxPool", "maxPool2d"},
     {"AveragePool", "averagePool2d"},
+    {"LayerNormalization", "meanVarianceNormalization"},
     {"MaxPool", "maxPool2d"},
     {"ReduceMax", "reduceMax"},
     {"ReduceMean", "reduceMean"},
diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc
new file mode 100644
index 0000000000000..724dc1ca1ab0a
--- /dev/null
+++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc
@@ -0,0 +1,156 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Intel Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <numeric>
+
+#include "core/common/safeint.h"
+#include "core/optimizer/initializer.h"
+#include "core/providers/common.h"
+#include "core/providers/shared/utils/utils.h"
+#include "core/providers/webnn/builders/helper.h"
+#include "core/providers/webnn/builders/model_builder.h"
+#include "core/providers/webnn/builders/op_builder_factory.h"
+
+#include "base_op_builder.h"
+
+namespace onnxruntime {
+namespace webnn {
+
+// Maps ONNX LayerNormalization onto the WebNN meanVarianceNormalization op.
+class NormalizationOpBuilder : public BaseOpBuilder {
+  // Add operator related.
+ private:
+  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
+                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
+
+  // Operator support related.
+ private:
+  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
+                         const logging::Logger& logger) const override;
+};
+
+Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
+                                                     const Node& node,
+                                                     const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  ORT_RETURN_IF_NOT(input_defs.size() >= 2, "LayerNormalization requires at least two inputs.");
+
+  emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
+  std::vector<int64_t> input_shape;
+  ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
+  const auto rank = input_shape.size();
+
+  emscripten::val options = emscripten::val::object();
+
+  std::vector<int64_t> scale_shape;
+  ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape");
+  const auto scale_size = scale_shape.size();
+  ORT_RETURN_IF_NOT(scale_size >= 1 && scale_size <= rank, "The scale size should be less than or equal to input size.");
+
+  if (scale_size < rank) {
+    // Enlarge new shape to input.rank, right aligned with leading ones
+    scale_shape.insert(scale_shape.begin(), rank - scale_size, 1);
+    std::vector<int32_t> new_scale_shape;
+    std::transform(scale_shape.cbegin(), scale_shape.cend(),
+                   std::back_inserter(new_scale_shape),
+                   [](int64_t dim) -> int32_t { return SafeInt<int32_t>(dim); });
+    emscripten::val reshape_scale = model_builder.GetOperand(input_defs[1]->Name());
+    emscripten::val reshape_output_scale =
+        model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_scale, emscripten::val::array(new_scale_shape));
+    options.set("scale", reshape_output_scale);
+  } else {
+    options.set("scale", model_builder.GetOperand(input_defs[1]->Name()));
+  }
+
+  if (input_defs.size() == 3) {
+    // Inputs contain optional bias
+    std::vector<int64_t> bias_shape;
+    ORT_RETURN_IF_NOT(GetShape(*input_defs[2], bias_shape, logger), "Cannot get bias shape");
+    const auto bias_size = bias_shape.size();
+    ORT_RETURN_IF_NOT(bias_size >= 1 && bias_size <= rank, "The bias size should be less than or equal to input size.");
+
+    if (bias_size < rank) {
+      // Enlarge new shape to input.rank, right aligned with leading ones
+      bias_shape.insert(bias_shape.begin(), rank - bias_size, 1);
+      std::vector<int32_t> new_bias_shape;
+      std::transform(bias_shape.cbegin(), bias_shape.cend(),
+                     std::back_inserter(new_bias_shape),
+                     [](int64_t dim) -> int32_t { return SafeInt<int32_t>(dim); });
+      emscripten::val reshape_bias = model_builder.GetOperand(input_defs[2]->Name());
+      emscripten::val reshape_output_bias =
+          model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_bias, emscripten::val::array(new_bias_shape));
+      options.set("bias", reshape_output_bias);
+    } else {
+      options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
+    }
+  }
+
+  NodeAttrHelper helper(node);
+  options.set("epsilon", helper.Get("epsilon", 1e-05f));
+
+  int64_t axis = helper.Get("axis", -1);
+  axis = HandleNegativeAxis(axis, rank);
+  // ONNX LayerNormalization normalizes over ALL dimensions in [axis, rank - 1],
+  // not just the single `axis` dimension, so build the full axes list here.
+  std::vector<int32_t> axes(rank - SafeInt<size_t>(axis));
+  std::iota(axes.begin(), axes.end(), static_cast<int32_t>(axis));
+  options.set("axes", emscripten::val::array(axes));
+
+  emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
+  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
+
+  return Status::OK();
+}
+
+// Operator support related.
+
+bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
+                                               const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  if (input_defs.size() < 2) {
+    LOGS(logger, VERBOSE) << "LayerNormalization requires at least two inputs.";
+    return false;
+  }
+
+  std::vector<int64_t> input_shape;
+  if (!GetShape(*input_defs[0], input_shape, logger)) {
+    LOGS(logger, VERBOSE) << "Cannot get input shape.";
+    return false;
+  }
+  const auto rank = input_shape.size();
+
+  NodeAttrHelper helper(node);
+  int64_t axis = helper.Get("axis", -1);
+  axis = HandleNegativeAxis(axis, rank);
+
+  const auto& scale_name = input_defs[1]->Name();
+  if (!Contains(initializers, scale_name)) {
+    LOGS(logger, VERBOSE) << "The scale must be a constant initializer.";
+    return false;
+  }
+
+  if (input_defs.size() == 3) {
+    // Inputs contain optional bias
+    const auto& bias_name = input_defs[2]->Name();
+    if (!Contains(initializers, bias_name)) {
+      LOGS(logger, VERBOSE) << "The bias must be a constant initializer.";
+      return false;
+    }
+  }
+
+  const auto& output_defs = node.OutputDefs();
+  if (output_defs.size() != 1) {
+    LOGS(logger, VERBOSE) << "MeanVarianceNormalization output count must be one.";
+    return false;
+  }
+
+  return true;
+}
+
+void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
+  op_registrations.builders.push_back(std::make_unique<NormalizationOpBuilder>());
+  op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
+}
+
+}  // namespace webnn
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc
index 0745910594fbf..e4a7aa69f575a 100644
--- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc
+++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc
@@ -81,6 +81,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
     CreateLogicalOpBuilder("Equal", op_registrations);
   }
 
+  {  // LayerNormalization
+    CreateNormalizationOpBuilder("LayerNormalization", op_registrations);
+  }
+
   {  // Pool
     CreatePoolOpBuilder("GlobalAveragePool", op_registrations);
     CreatePoolOpBuilder("GlobalMaxPool", op_registrations);
diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h
index a4239edc9f3fb..40f96b4bfbde9 100644
--- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h
+++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h
@@ -31,6 +31,7 @@ void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations&
 void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
+void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);