Skip to content

Commit

Permalink
add support for onnx equal (microsoft#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
guschmue authored May 19, 2023
1 parent e23e593 commit 958e800
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 0 deletions.
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Div", "div"},
{"Pow", "pow"},
{"Cos", "cos"},
{"Equal", "equal"},
{"Erf", "erf"},
{"Not", "logicalNot"},
{"Floor", "floor"},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include <core/providers/common.h>

#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"
#include "core/providers/webnn/builders/helper.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class LogicalOpBuilder : public BaseOpBuilder {
// Add operator related.
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const logging::Logger& logger) const override;
};

// Add operator related.

Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& /* logger */) const {
const auto& op_type = node.OpType();
emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name());
emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name());
emscripten::val output = emscripten::val::object();
if (op_type == "Equal") {
output = model_builder.GetBuilder().call<emscripten::val>("equal", input0, input1);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}

void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.find(op_type) != op_registrations.op_builder_map.cend())
return;

static std::vector<std::string> op_types =
{
"Equal",
};

op_registrations.builders.push_back(std::make_unique<LogicalOpBuilder>());
for (const auto& type : op_types) {
op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get());
}
}

bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const logging::Logger& logger) const {
const auto& name = node.Name();
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
if (input_defs.size() < 2) {
LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: "
<< input_defs.size();
return false;
}
return true;
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateGemmOpBuilder("MatMul", op_registrations);
}

{ // Logical
CreateLogicalOpBuilder("Equal", op_registrations);
}

{ // Pool
CreatePoolOpBuilder("GlobalAveragePool", op_registrations);
CreatePoolOpBuilder("GlobalMaxPool", op_registrations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations&
void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

} // namespace webnn
} // namespace onnxruntime

0 comments on commit 958e800

Please sign in to comment.