Skip to content

Commit

Permalink
Support Int32, Uint32 tensor data type (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry authored May 18, 2023
1 parent ef537ad commit a99bf1a
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 6 deletions.
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,13 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn
return op_map.find(op_type) != op_map.end() && wnn_builder_[op_map.find(op_type)->second].as<bool>();
}

constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 4> supported_data_types = {
constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 6> supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
};

bool IsSupportedDataType(int32_t data_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
operand_type = "float32";
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
operand_type = "int32";
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
operand_type = "int64";
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
operand_type = "uint32";
break;
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The Cast node has unsupported 'to' type, name: ",
Expand Down Expand Up @@ -79,7 +85,7 @@ bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
if (!IsSupportedDataType(to_type)) {
LOGS(logger, VERBOSE) << "Invalid cast to type " << to_type
<< " . Current WebNN only support cast to bool, float32, float16 or int64.";
<< " . Current WebNN only support cast to bool, float32, float16, int32, int64 or uint32.";
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,15 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
element_size = sizeof(float);
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
element_size = sizeof(int32_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
element_size = sizeof(int64_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
element_size = sizeof(uint32_t);
break;
default:
break;
}
Expand Down
34 changes: 30 additions & 4 deletions onnxruntime/core/providers/webnn/builders/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,18 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const float*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const int32_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const int64_t*>(tensor.buffer))};
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint32_t*>(tensor.buffer))};
break;
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The input of graph has unsupported type, name: ",
Expand Down Expand Up @@ -88,11 +95,18 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const float*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const int32_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const int64_t*>(tensor.buffer))};
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint32_t*>(tensor.buffer))};
break;
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The output of graph has unsupported type, name: ",
Expand Down Expand Up @@ -150,9 +164,15 @@ void Model::AllocateInputOutputBuffers() {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
wnn_inputs_.set(input, emscripten::val::global("Float32Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
wnn_inputs_.set(input, emscripten::val::global("Int32Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
wnn_inputs_.set(input, emscripten::val::global("BigInt64Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
wnn_inputs_.set(input, emscripten::val::global("Uint32Array").new_(num_elements));
break;
default:
break;
}
Expand All @@ -172,9 +192,15 @@ void Model::AllocateInputOutputBuffers() {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
wnn_outputs_.set(output, emscripten::val::global("Float32Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
wnn_outputs_.set(output, emscripten::val::global("Int32Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
wnn_outputs_.set(output, emscripten::val::global("BigInt64Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
wnn_outputs_.set(output, emscripten::val::global("Uint32Array").new_(num_elements));
break;
default:
break;
}
Expand Down
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,21 @@ Status ModelBuilder::RegisterInitializers() {
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<float*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
desc.set("type", emscripten::val("int32"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int32_t*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
desc.set("type", emscripten::val("int64"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int64_t*>(unpacked_tensor.data()))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
desc.set("type", emscripten::val("uint32"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint32_t*>(unpacked_tensor.data()))};
break;
default:
break;
}
Expand Down Expand Up @@ -220,9 +230,15 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
desc.set("type", emscripten::val("float32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
desc.set("type", emscripten::val("int32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
desc.set("type", emscripten::val("int64"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
desc.set("type", emscripten::val("uint32"));
break;
default: {
// TODO: support other type.
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand Down Expand Up @@ -295,11 +311,21 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
reinterpret_cast<const float*>(dest))};
desc.set("type", emscripten::val("float32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int32_t),
reinterpret_cast<const int32_t*>(dest))};
desc.set("type", emscripten::val("int32"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int64_t),
reinterpret_cast<const int64_t*>(dest))};
desc.set("type", emscripten::val("int64"));
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint32_t),
reinterpret_cast<const uint32_t*>(dest))};
desc.set("type", emscripten::val("uint32"));
break;
default:
break;
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
output_buffer = output_tensor.GetTensorMutableRawData();
break;
default:
Expand Down

0 comments on commit a99bf1a

Please sign in to comment.