Skip to content

Commit

Permalink
[WebNN EP] Support LRN operator (microsoft#22775)
Browse files Browse the repository at this point in the history
WebNN doesn't provide a dedicated op for LRN, so use a couple of WebNN ops to
emulate it in the WebNN EP:
pow -> transpose -> pad -> averagePool -> transpose -> mul -> add -> pow
-> div
@Honry @fdwr PTAL, thanks!
  • Loading branch information
miaobin authored and ankitm3k committed Dec 11, 2024
1 parent ee07444 commit d05f287
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 72 deletions.
11 changes: 5 additions & 6 deletions onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
int32_t input_data_type;
ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_data_type, logger), "Cannot get input type");
const auto input_data_type = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
const auto node_name = node.Name();
emscripten::val wnn_builder = model_builder.GetBuilder();
Expand All @@ -43,10 +42,10 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

// Prepare WebNN constants for alpha, beta, bias attributes.
// Assume T is float, because input_data_type has been limited to float32 and float16 in 'hasSupportedInitsImpl'.
emscripten::val alpha_constant = model_builder.CreateOrGetConstant<float>(input_data_type, alpha);
emscripten::val beta_constant = model_builder.CreateOrGetConstant<float>(input_data_type, beta);
emscripten::val bias_constant = model_builder.CreateOrGetConstant<float>(input_data_type, bias);
emscripten::val pow1_constant = model_builder.CreateOrGetConstant<float>(input_data_type, 2);
emscripten::val alpha_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, alpha);
emscripten::val beta_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, beta);
emscripten::val bias_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, bias);
emscripten::val pow1_constant = model_builder.CreateOrGetScalarConstant<float>(input_data_type, 2);

/**
WebNN doesn't support LRN. So decompose it into a series of ops:
Expand Down
108 changes: 42 additions & 66 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class ModelBuilder {
void AddOperand(const std::string& name, const emscripten::val& operand);
const emscripten::val& GetZeroConstant(
const int32_t& data_type, const std::vector<uint32_t>& shape = {});

template <typename T>
const emscripten::val& CreateOrGetScalarConstant(const int32_t& data_type, T value);

// Use the buffers to persist WebNN allocated data like transposed weight.
// It ensures the validity during inference session.
std::vector<std::unique_ptr<uint8_t[]>> mem_persist_buffers_;
Expand Down Expand Up @@ -100,12 +104,11 @@ class ModelBuilder {
static const IOpBuilder* GetOpBuilder(const Node& node);
};

// Create or retrieve one of the following:
// - A WebNN constant MLOperand filled with the specified value, data type, and shape.
// - A WebNN scalar constant MLOperand with the specified value and data type.
// For a scalar constant, this is a workaround for the builder.constant(type, value) method
// since it has not been implemented yet.
// Create a scalar constant MLOperand of the specified value and data type.
// Workaround for the builder.constant(type, value) method since it has not been implemented yet.
// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value
// BTW, the spec is discussing whether builder.constant(type, value) should be dropped at
// https://github.com/webmachinelearning/webnn/issues/475. Update this according to the spec decision.
//
// This function enforces a mapping between the data_type and the value types:
// - TensorProto_DataType_INT4 <-> int8_t
Expand All @@ -120,96 +123,69 @@ class ModelBuilder {
// - TensorProto_DataType_UINT32 <-> uint32_t
// - TensorProto_DataType_UINT64 <-> uint64_t
template <typename T>
const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_type, T value,
const std::vector<uint32_t>& shape) {
std::string name = "webnn_constant_" + std::to_string(data_type) + "_" + std::to_string(value);
emscripten::val dims = emscripten::val::array();
if (!shape.empty()) {
dims = emscripten::val::array(shape);
std::ostringstream name_stream;
name_stream << name;
for (const auto& dim : shape) {
name_stream << "_" << dim;
}
name = name_stream.str();
const emscripten::val& ModelBuilder::CreateOrGetScalarConstant(const int32_t& data_type, T value) {
std::string name = "webnn_scalar_constant_" + std::to_string(data_type) + "_" + std::to_string(value);
emscripten::val desc = emscripten::val::object();
desc.set("shape", emscripten::val::array());
emscripten::val scalar_buffer = emscripten::val::undefined();
uint16_t value_uint16 = 0;
uint8_t value_uint8 = 0;
if (!SetWebnnDataType(desc, data_type)) {
ORT_THROW("Unsupported data type: " + std::to_string(data_type));
}

// If the operand does not exist, create it.
if (wnn_operands_.find(name) == wnn_operands_.end()) {
emscripten::val desc = emscripten::val::object();
desc.set("shape", dims);
desc.set("dimensions", dims);
emscripten::val buffer = emscripten::val::undefined();
if (!SetWebnnDataType(desc, data_type)) {
ORT_THROW("Unsupported data type: " + std::to_string(data_type));
}
auto num_elements = Product(shape);
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
// For WebNN int4 and uint4 tensors are stored in Uint8Array,
// so we need to adjust the number of elements.
num_elements = (num_elements + 1) / 2;
buffer = emscripten::val::global("Uint8Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(PackInt8ToUint8AsNibble(value, data_type)));
}
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
value_uint8 = PackInt8ToUint8AsNibble(value, data_type);
scalar_buffer.call<void>("fill", emscripten::val(value_uint8));
break;
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value ? 1 : 0));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
buffer = emscripten::val::global("Uint8Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
scalar_buffer = emscripten::val::global("Uint8Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
buffer = emscripten::val::global("Int8Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
scalar_buffer = emscripten::val::global("Int8Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
buffer = emscripten::val::global("Uint16Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value)));
}
scalar_buffer = emscripten::val::global("Uint16Array").new_(1);
value_uint16 = PackFloat32ToUint16AsFloat16(value);
scalar_buffer.call<void>("fill", emscripten::val(value_uint16));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
buffer = emscripten::val::global("Float32Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
scalar_buffer = emscripten::val::global("Float32Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
buffer = emscripten::val::global("Int32Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
scalar_buffer = emscripten::val::global("Int32Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
buffer = emscripten::val::global("Uint32Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val(value));
}
scalar_buffer = emscripten::val::global("Uint32Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val(value));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
buffer = emscripten::val::global("BigInt64Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
}
scalar_buffer = emscripten::val::global("BigInt64Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
buffer = emscripten::val::global("BigUint64Array").new_(num_elements);
if (value) {
buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
}
scalar_buffer = emscripten::val::global("BigUint64Array").new_(1);
scalar_buffer.call<void>("fill", emscripten::val::global("BigInt")(value));
break;
default:
break;
}

const emscripten::val constant = wnn_builder_.call<emscripten::val>("constant", desc, buffer);
wnn_operands_.insert(std::make_pair(name, constant));
const emscripten::val scalar_constant = wnn_builder_.call<emscripten::val>("constant", desc, scalar_buffer);
wnn_operands_.insert(std::make_pair(name, scalar_constant));
}

return wnn_operands_.at(name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateLogicalOpBuilder("Xor", op_registrations);
}

{ // LRN
CreateLRNOpBuilder("LRN", op_registrations);
}

{ // LSTM
CreateLstmOpBuilder("LSTM", op_registrations);
}
Expand Down

0 comments on commit d05f287

Please sign in to comment.