Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[WebNN EP] Create MLGraphBuilder for every model builder #21514

Merged
merged 2 commits
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
}

std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder_,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_groups;
Expand All @@ -103,7 +103,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
const auto* node(graph_viewer.GetNode(node_idx));
bool supported = false;
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) {
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
}
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder_,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const logging::Logger& logger);
static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
Expand Down Expand Up @@ -241,14 +241,14 @@
{"Where", {"where", true}},
};

inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_,
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,

Check warning on line 244 in onnxruntime/core/providers/webnn/builders/helper.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/helper.h:244: Add #include <string> for string [build/include_what_you_use] [4]
const WebnnDeviceType device_type) {
// Returns false if the op_type is not listed in the op_map.
if (op_map.find(op_type) == op_map.end()) {
return false;
}
// Returns false if the WebNN op has not been implemented in MLGraphBuilder in the current browser.
if (!wnn_builder_[op_map.find(op_type)->second.opName].as<bool>()) {
if (!wnn_builder[op_map.find(op_type)->second.opName].as<bool>()) {
return false;
}
// The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather
Expand Down
16 changes: 12 additions & 4 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@ namespace onnxruntime {
namespace webnn {

ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const emscripten::val& builder,
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type)
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type)
: graph_viewer_(graph_viewer),
logger_(logger),
wnn_context_(context),
wnn_builder_(builder),
preferred_layout_(preferred_layout),
wnn_device_type_(wnn_device_type) {}
wnn_device_type_(wnn_device_type) {
// Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
// is only allowed to be called once.
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context);
if (!wnn_builder_.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}
}

Status ModelBuilder::Initialize() {
PreprocessInitializers();
Expand Down Expand Up @@ -332,6 +338,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
if (!wnn_graph.as<bool>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
}
// Explicitly release the WebNN builder to free memory.
wnn_builder_ = emscripten::val::undefined();
model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_));
model->SetInputs(std::move(input_names_));
model->SetOutputs(std::move(output_names_));
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class IOpBuilder;
class ModelBuilder {
public:
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const emscripten::val& builder,
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type);
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type);
~ModelBuilder() = default;

Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
Expand Down Expand Up @@ -62,8 +62,8 @@ class ModelBuilder {
const GraphViewer& graph_viewer_;
const logging::Logger& logger_;

emscripten::val wnn_context_ = emscripten::val::object();
emscripten::val wnn_builder_ = emscripten::val::object();
emscripten::val wnn_context_ = emscripten::val::undefined();
emscripten::val wnn_builder_ = emscripten::val::undefined();
DataLayout preferred_layout_;
WebnnDeviceType wnn_device_type_;
InlinedHashMap<std::string, emscripten::val> wnn_operands_;
Expand Down
21 changes: 7 additions & 14 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
if (!wnn_context_.as<bool>()) {
ORT_THROW("Failed to create WebNN context.");
}
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
if (!wnn_builder_.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}
}

WebNNExecutionProvider::~WebNNExecutionProvider() {}
Expand Down Expand Up @@ -81,14 +77,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view

const auto& logger = *GetLogger();

if (!wnn_builder_.as<bool>()) {
// The GetCapability function may be called again after Compile due to the logic in the
// PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc).
// We need to re-create the wnn_builder_ here because it may have been released in the last Compile.
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
if (!wnn_builder.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}

const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger);
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger);
wnn_builder = emscripten::val::undefined();

if (node_groups.empty()) {
return result;
Expand Down Expand Up @@ -218,9 +213,10 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);

webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
wnn_builder_, preferred_layout_, wnn_device_type_);
preferred_layout_, wnn_device_type_);
std::unique_ptr<webnn::Model> model;
ORT_RETURN_IF_ERROR(builder.Compile(model));

// Build map from input name to its index in input definitions.
{
InlinedHashMap<std::string, size_t> input_map;
Expand Down Expand Up @@ -329,9 +325,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
node_compute_funcs.push_back(compute_info);
}

// Explicitly release the WebNN builder to free memory.
wnn_builder_ = emscripten::val::undefined();

return Status::OK();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class WebNNExecutionProvider : public IExecutionProvider {

private:
emscripten::val wnn_context_ = emscripten::val::undefined();
mutable emscripten::val wnn_builder_ = emscripten::val::undefined();

DataLayout preferred_layout_;
webnn::WebnnDeviceType wnn_device_type_;
Expand Down
Loading