diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 301d795912443..ae7488c035834 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -94,6 +94,7 @@ static const InlinedHashMap op_map = { {"Div", "div"}, {"Pow", "pow"}, {"Cos", "cos"}, + {"Equal", "equal"}, {"Erf", "erf"}, {"Not", "logicalNot"}, {"Floor", "floor"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc new file mode 100644 index 0000000000000..58747797495b1 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" +#include "core/providers/webnn/builders/helper.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class LogicalOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& /* logger */) const { + const auto& op_type = node.OpType(); + emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); + emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val output = emscripten::val::object(); + if (op_type == "Equal") { + output = model_builder.GetBuilder().call("equal", input0, input1); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend()) + return; + + static std::vector op_types = + { + "Equal", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const logging::Logger& logger) const { + const auto& name = node.Name(); + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 2) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " + << input_defs.size(); + return false; + } + return true; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 82a45719ec1c3..a98b82ca24a30 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -77,6 +77,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateGemmOpBuilder("MatMul", op_registrations); } + { // Logical + CreateLogicalOpBuilder("Equal", op_registrations); + } + { // Pool CreatePoolOpBuilder("GlobalAveragePool", op_registrations); CreatePoolOpBuilder("GlobalMaxPool", op_registrations); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 8f8299e5138a4..f95085d9200b6 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -39,6 +39,7 @@ void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); } // namespace webnn } // namespace onnxruntime