[WebNN] Support SkipSimplifiedLayerNormalization op
The algorithm of SkipSimplifiedLayerNormalization is quite
similar to SimplifiedLayerNormalization; the only difference
is that SkipSimplifiedLayerNormalization provides an additional
output used for calculating the sum of the input, skip, and
bias (if it exists).

Also, fix a bug in SimplifiedLayerNormalization: add the bias
if it exists.
Honry committed Dec 19, 2024
1 parent e76bd2f commit c8d9efc
Showing 4 changed files with 64 additions and 23 deletions.
3 changes: 2 additions & 1 deletion js/web/docs/webnn-operators.md
@@ -89,9 +89,10 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements ||| Only supports 'reduction' == 'none' |
| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND ||| Only supports 'reduction' == 'none' |
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice ||| |
| SimplifiedLayerNormalization | ai.onnx(1+) | pow + reduceMean + add + sqrt + div + mul ||| |
| SimplifiedLayerNormalization | com.microsoft(1+) | pow, reduceMean, add, sqrt, div, mul ||| |
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid ||| |
| Sign | ai.onnx(9-12, 13+) | sign ||| |
| SkipSimplifiedLayerNormalization | com.microsoft(1+) | pow, reduceMean, add, sqrt, div, mul ||| |
| Softplus | ai.onnx(7+) | softplus ||| |
| Softsign | ai.onnx(7+) | softsign ||| |
| Sin | ai.onnx(7+) | sin ||| |
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -278,6 +278,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Softplus", "softplus"},
{"Softsign", "softsign"},
{"Sin", "sin"},
{"SkipSimplifiedLayerNormalization", "layerNormalization"},
{"Slice", "slice"},
{"Softmax", "softmax"},
{"Split", "split"},
onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc
@@ -34,6 +34,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
const logging::Logger& logger) const {
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
const auto& output_defs = node.OutputDefs();
ORT_RETURN_IF_NOT(input_defs.size() >= 2, op_type, " requires at least two inputs.");

emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
@@ -45,7 +46,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
options.set("label", node.Name());

std::vector<int64_t> scale_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape");
const size_t scale_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 2 : 1;
ORT_RETURN_IF_NOT(GetShape(*input_defs[scale_input_index], scale_shape, logger), "Cannot get scale shape");
const auto scale_size = scale_shape.size();
// Except LayerNormalization, other normalization ops' scale input should be 1-D.
if (op_type == "LayerNormalization") {
@@ -55,19 +57,17 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(scale_size == 1, "The scale size should be one.");
}

if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
emscripten::val scale = model_builder.GetOperand(input_defs[scale_input_index]->Name());
options.set("scale", scale);

const size_t bias_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 3 : 2;
emscripten::val bias = emscripten::val::undefined();
if (input_defs.size() > bias_input_index && input_defs[bias_input_index]->Exists()) {
// Bias input exists, and bias's shape should be the same as scale's shape.
std::vector<int64_t> bias_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[2], bias_shape, logger), "Cannot get bias shape");
ORT_RETURN_IF_NOT(GetShape(*input_defs[bias_input_index], bias_shape, logger), "Cannot get bias shape");
ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape.");
}

emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name());
options.set("scale", scale);

if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
emscripten::val bias = model_builder.GetOperand(input_defs[2]->Name());
bias = model_builder.GetOperand(input_defs[bias_input_index]->Name());
options.set("bias", bias);
}

@@ -76,6 +76,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
options.set("epsilon", epsilon);

emscripten::val output = emscripten::val::undefined();
// SkipSimplifiedLayerNormalization's output: input_skip_bias_sum.
emscripten::val input_skip_bias_sum = emscripten::val::undefined();
if (op_type == "BatchNormalization") {
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
@@ -85,7 +87,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
}

output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
} else if (op_type == "LayerNormalization" || op_type == "SimplifiedLayerNormalization") {
} else if (op_type == "LayerNormalization" ||
op_type == "SimplifiedLayerNormalization" ||
op_type == "SkipSimplifiedLayerNormalization") {
int64_t axis = helper.Get("axis", -1);
axis = HandleNegativeAxis(axis, rank);
std::vector<uint32_t> axes(rank - SafeInt<uint32_t>(axis));
@@ -94,13 +98,17 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
if (op_type == "LayerNormalization") {
options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("layerNormalization", input, options);
} else { // SimplifiedLayerNormalization
} else { // SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization
/**
WebNN doesn't support SimplifiedLayerNormalization. So decompose it into a series of ops:
X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul
      ^           ^           ^                ^      ^
      |           |           |                |      |
     Y:2         axis     B:epsilon           A:X  A:scale
WebNN doesn't support SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.
So decompose it into a series of ops:
X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul -> Add (optional)
      ^           ^           ^                ^      ^      ^
      |           |           |                |      |      |
     Y:2         axis     B:epsilon           A:X  A:scale  B:bias
If it is SkipSimplifiedLayerNormalization and its output input_skip_bias_sum exists,
input_skip_bias_sum = X + skip + bias (if it exists)
*/

int32_t input_type;
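
// Rough sketch (for illustration; not taken from this commit) of the collapsed
// Pow -> ReduceMean -> Add -> Sqrt -> Div chain that precedes the Mul below.
// The scalar constants `pow_constant` (value 2) and `epsilon_constant` are assumed
// to be built through the model builder; their creation is elided here.
emscripten::val common_options = emscripten::val::object();

// Pow: square the input element-wise.
common_options.set("label", node.Name() + "_pow");
emscripten::val pow = model_builder.GetBuilder().call<emscripten::val>("pow", input, pow_constant, common_options);

// ReduceMean over the normalized axes, keeping dimensions so the result broadcasts back.
emscripten::val reduce_options = emscripten::val::object();
reduce_options.set("axes", emscripten::val::array(axes));
reduce_options.set("keepDimensions", true);
reduce_options.set("label", node.Name() + "_reduceMean");
emscripten::val reduce_mean = model_builder.GetBuilder().call<emscripten::val>("reduceMean", pow, reduce_options);

// Add epsilon.
common_options.set("label", node.Name() + "_add_epsilon");
emscripten::val add = model_builder.GetBuilder().call<emscripten::val>("add", reduce_mean, epsilon_constant, common_options);

// Sqrt.
common_options.set("label", node.Name() + "_sqrt");
emscripten::val sqrt = model_builder.GetBuilder().call<emscripten::val>("sqrt", add, common_options);

// Div: input / sqrt(mean(x^2) + epsilon).
common_options.set("label", node.Name() + "_div");
emscripten::val div = model_builder.GetBuilder().call<emscripten::val>("div", input, sqrt, common_options);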
@@ -137,6 +145,25 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
// Mul
common_options.set("label", node.Name() + "_mul");
output = model_builder.GetBuilder().call<emscripten::val>("mul", scale, div, common_options);

// Add (if bias exists)
if (!bias.isUndefined()) {
common_options.set("label", node.Name() + "_add_bias");
output = model_builder.GetBuilder().call<emscripten::val>("add", output, bias, common_options);
}

// SkipSimplifiedLayerNormalization's output input_skip_bias_sum is the sum of input, skip, and bias.
if (op_type == "SkipSimplifiedLayerNormalization" && output_defs.size() > 3 && output_defs[3]->Exists()) {
emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name());
common_options.set("label", node.Name() + "_add_skip");
input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>("add", input, skip, common_options);
if (!bias.isUndefined()) {
common_options.set("label", node.Name() + "_add_skip_bias");
input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>(
"add", input_skip_bias_sum, bias, common_options);
}
model_builder.AddOperand(output_defs[3]->Name(), std::move(input_skip_bias_sum));
}
}
} else if (op_type == "InstanceNormalization") {
// WebNN spec only supports 4D input for instanceNormalization.
@@ -188,7 +215,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
}
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
model_builder.AddOperand(output_defs[0]->Name(), std::move(output));

Check warning on line 218 in onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc (GitHub Actions / Optional Lint C++): [cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4]

return Status::OK();
}
@@ -215,9 +242,19 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
}

const auto& output_defs = node.OutputDefs();
if (output_defs.size() != 1) {
LOGS(logger, VERBOSE) << op_type << " output count must be one.";
return false;
if (op_type == "SkipSimplifiedLayerNormalization") {
for (size_t i = 1; i < output_defs.size(); i++) {
if (output_defs[i]->Exists() && i < 3) {
// Output mean and inv_std_var are used for training mode, which is not supported.
const auto output_name = i == 1 ? "mean" : "inv_std_var";
LOGS(logger, VERBOSE) << "SkipSimplifiedLayerNormalization's output: " << output_name << " is not supported.";
}
}
} else {
if (output_defs.size() != 1) {
LOGS(logger, VERBOSE) << op_type << " output count must be one.";
return false;
}
}

if (op_type == "BatchNormalization" && helper.Get("training_mode", 0)) {
@@ -277,6 +314,7 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat
"InstanceNormalization",
"LayerNormalization",
"SimplifiedLayerNormalization",
"SkipSimplifiedLayerNormalization",
};

op_registrations.builders.push_back(std::make_unique<NormalizationOpBuilder>());
@@ -159,6 +159,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateNormalizationOpBuilder("InstanceNormalization", op_registrations);
CreateNormalizationOpBuilder("LayerNormalization", op_registrations);
CreateNormalizationOpBuilder("SimplifiedLayerNormalization", op_registrations);
CreateNormalizationOpBuilder("SkipSimplifiedLayerNormalization", op_registrations);
}

{ // Pad
