diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 1a0fdfdd2365b..321f3b3d7e52d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -69,17 +69,16 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We } } -bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) { - const auto& input_name = input.Name(); - const auto* shape_proto = input.Shape(); +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) { + const auto& node_arg_name = node_arg.Name(); + const auto* shape_proto = node_arg.Shape(); // Optional tensors can be indicated by an empty name, just ignore it. - if (input_name.empty()) { + if (node_arg_name.empty()) { return true; } - // We do not support input with no shape. + // We do not support input/output with no shape. if (!shape_proto) { - LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name - << "] has not shape"; + LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape"; return false; } @@ -87,12 +86,11 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons // WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape. if (!dim.has_dim_value()) { LOGS(logger, VERBOSE) << "Dynamic shape is not supported, " - << "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: " - << input_name; + << "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name; return false; } if (dim.dim_value() == 0) { - LOGS(logger, VERBOSE) << "The shape of [" << input_name << "] has 0 dimension which is not supported by WebNN"; + LOGS(logger, VERBOSE) << "The shape of [" << node_arg_name << "] has 0 dimension which is not supported by WebNN"; return false; } } @@ -106,13 +104,6 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const emscripten::val& wnn_limits, const logging::Logger& logger) { std::vector> supported_node_groups; - - for (const auto* input : graph_viewer.GetInputs()) { - if (!IsInputSupported(*input, "graph", logger)) { - return supported_node_groups; - } - } - std::vector supported_node_group; const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index ccc8394349c38..7b4a77336562f 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } -bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); +bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger); // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP. std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index d9a3a8a6ff1b3..704896e43a7e1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -32,7 +32,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons if (!HasSupportedInputs(node, wnn_limits, logger)) return false; - if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) + if (!HasSupportedOutputs(node, wnn_limits, logger)) return false; if (!HasSupportedOpSet(node, logger)) @@ -45,7 +45,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsInputSupported(*input, node_name, logger)) { + if (!IsTensorShapeSupported(*input, node_name, logger)) { return false; } } @@ -66,6 +66,18 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger); } +bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); + for (const auto* output : node.OutputDefs()) { + if (!IsTensorShapeSupported(*output, node_name, logger)) { + return false; + } + } + + return HasSupportedOutputsImpl(node, wnn_limits, logger); +} + bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 584455f62cb4e..a632876dab2b9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -54,6 +54,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 8bcd3621b5d4b..8a82fce42189d 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -227,7 +227,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (!shape.empty()) { dims.reserve(shape.size()); for (const auto& dim : shape) { - // dim_param free dimensions should have already been excluded by IsInputSupported(). + // dim_param free dimensions should have already been excluded by IsTensorShapeSupported(). assert(dim.has_dim_value()); dims.push_back(SafeInt(dim.dim_value())); }