Commit

Merge pull request microsoft#18 from NonStatic2014/bohu/executor
Executor for prediction and MLValue2TensorProto
NonStatic2014 authored Apr 2, 2019
2 parents a5f97db + 6aed3be commit 5390322
Showing 14 changed files with 1,687 additions and 15 deletions.
6 changes: 5 additions & 1 deletion cmake/onnxruntime_hosting.cmake
@@ -29,17 +29,21 @@ set(re2_src ${REPO_ROOT}/cmake/external/re2)
# Setup source code
file(GLOB_RECURSE onnxruntime_hosting_lib_srcs
"${ONNXRUNTIME_ROOT}/hosting/http/*.cc"
"${ONNXRUNTIME_ROOT}/hosting/environment.cc"
"${ONNXRUNTIME_ROOT}/hosting/executor.cc"
"${ONNXRUNTIME_ROOT}/hosting/converter.cc"
)
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/http/json_handling.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/http/predict_request_handler.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/executor.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/converter.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()

file(GLOB_RECURSE onnxruntime_hosting_srcs
"${ONNXRUNTIME_ROOT}/hosting/main.cc"
"${ONNXRUNTIME_ROOT}/hosting/environment.cc"
)

# Hosting library
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -535,6 +535,7 @@ if (onnxruntime_BUILD_HOSTING)
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties("${TEST_SRC_DIR}/hosting/json_handling_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties("${TEST_SRC_DIR}/hosting/converter_tests.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()

@@ -548,7 +549,7 @@ if (onnxruntime_BUILD_HOSTING)
AddTest(
TARGET onnxruntime_hosting_tests
SOURCES ${onnxruntime_test_hosting_src}
LIBS ${onnxruntime_test_hosting_libs} hosting_proto onnxruntime_hosting_lib
LIBS ${onnxruntime_test_hosting_libs} hosting_proto onnxruntime_hosting_lib ${onnxruntime_test_providers_libs}
DEPENDS ${onnxruntime_EXTERNAL_DEPENDENCIES}
)
endif()
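The test target above now registers converter_tests.cc and links the provider libraries into onnxruntime_hosting_tests. As a rough illustration of what such a test could check (a hedged sketch, not the actual converter_tests.cc from this commit; the GTest usage and include paths are assumptions about the repository layout), a case exercising the type mapping added in converter.cc might read:

// Hedged sketch only -- not the converter_tests.cc added by this commit.
// Include paths are assumptions about the repository layout.
#include "gtest/gtest.h"
#include "core/framework/data_types.h"
#include "hosting/converter.h"

namespace onnxruntime {
namespace hosting {
namespace test {

TEST(ConverterTests, MLDataTypeToTensorProtoDataTypeFloat) {
  // float should map to TensorProto_DataType_FLOAT.
  EXPECT_EQ(MLDataTypeToTensorProtoDataType(DataTypeImpl::GetType<float>()),
            onnx::TensorProto_DataType_FLOAT);
}

TEST(ConverterTests, MLDataTypeToTensorProtoDataTypeFloat16) {
  // MLFloat16 should map to TensorProto_DataType_FLOAT16.
  EXPECT_EQ(MLDataTypeToTensorProtoDataType(DataTypeImpl::GetType<MLFloat16>()),
            onnx::TensorProto_DataType_FLOAT16);
}

}  // namespace test
}  // namespace hosting
}  // namespace onnxruntime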
258 changes: 258 additions & 0 deletions onnxruntime/hosting/converter.cc
@@ -0,0 +1,258 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <onnx/onnx_pb.h>
#include "core/common/logging/logging.h"
#include "core/framework/data_types.h"
#include "core/framework/environment.h"
#include "core/framework/framework_common.h"
#include "core/framework/mem_buffer.h"
#include "core/framework/ml_value.h"
#include "core/framework/tensor.h"
#include "core/framework/tensorprotoutils.h"

#include "onnx-ml.pb.h"
#include "predict.pb.h"

#include "converter.h"

namespace onnxruntime {
namespace hosting {

namespace protobufutil = google::protobuf::util;

onnx::TensorProto_DataType MLDataTypeToTensorProtoDataType(const onnxruntime::DataTypeImpl* cpp_type) {
if (cpp_type == onnxruntime::DataTypeImpl::GetType<float>()) {
return onnx::TensorProto_DataType_FLOAT;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint8_t>()) {
return onnx::TensorProto_DataType_UINT8;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int8_t>()) {
return onnx::TensorProto_DataType_INT8;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint16_t>()) {
return onnx::TensorProto_DataType_UINT16;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int16_t>()) {
return onnx::TensorProto_DataType_INT16;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int32_t>()) {
return onnx::TensorProto_DataType_INT32;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<int64_t>()) {
return onnx::TensorProto_DataType_INT64;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<std::string>()) {
return onnx::TensorProto_DataType_STRING;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<bool>()) {
return onnx::TensorProto_DataType_BOOL;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<onnxruntime::MLFloat16>()) {
return onnx::TensorProto_DataType_FLOAT16;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<onnxruntime::BFloat16>()) {
return onnx::TensorProto_DataType_BFLOAT16;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<double>()) {
return onnx::TensorProto_DataType_DOUBLE;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint32_t>()) {
return onnx::TensorProto_DataType_UINT32;
} else if (cpp_type == onnxruntime::DataTypeImpl::GetType<uint64_t>()) {
return onnx::TensorProto_DataType_UINT64;
} else {
return onnx::TensorProto_DataType_UNDEFINED;
}
}

common::Status MLValue2TensorProto(onnxruntime::MLValue& ml_value, bool using_raw_data,
std::shared_ptr<onnxruntime::logging::Logger> logger,
/* out */ onnx::TensorProto& tensor_proto) {
// Tensor in MLValue
const auto& tensor = ml_value.Get<onnxruntime::Tensor>();

// dims field
const onnxruntime::TensorShape& tensor_shape = tensor.Shape();
for (const auto& dim : tensor_shape.GetDims()) {
tensor_proto.add_dims(dim);
}

// data_type field
onnx::TensorProto_DataType data_type = MLDataTypeToTensorProtoDataType(tensor.DataType());
tensor_proto.set_data_type(data_type);

// data_location field: data is stored in raw_data (if set); otherwise it goes in the type-specific field.
if (using_raw_data && data_type != onnx::TensorProto_DataType_STRING) {
tensor_proto.set_data_location(onnx::TensorProto_DataLocation_DEFAULT);
}

// *_data field
// According to onnx_ml.proto, depending on the data_type field,
// exactly one of the *_data fields is used to store the elements of the tensor.
switch (data_type) {
case onnx::TensorProto_DataType_FLOAT: { // Target: raw_data or float_data
const auto* data = tensor.Data<float>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (int i = 0, count = tensor.Shape().Size(); i < count; ++i) {
tensor_proto.add_float_data(data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_INT32: { // Target: raw_data or int32_data
const auto* data = tensor.Data<int32_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (int i = 0, count = tensor.Shape().Size(); i < count; ++i) {
tensor_proto.add_int32_data(data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_UINT8: { // Target: raw_data or int32_data
const auto* data = tensor.Data<uint8_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (int i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_INT8: { // Target: raw_data or int32_data
const auto* data = tensor.Data<int8_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (int i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_UINT16: { // Target: raw_data or int32_data
const auto* data = tensor.Data<uint16_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (int i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_INT16: { // Target: raw_data or int32_data
const auto* data = tensor.Data<int16_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (int i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_BOOL: { // Target: raw_data or int32_data
const auto* data = tensor.Data<bool>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (int i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_FLOAT16: { // Target: raw_data or int32_data
const auto* data = tensor.Data<onnxruntime::MLFloat16>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto i32data = reinterpret_cast<const int32_t*>(data);
for (int i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_BFLOAT16: { // Target: raw_data or int32_data
const auto* data = tensor.Data<onnxruntime::BFloat16>();
std::vector<uint16_t> raw_data;
for (int i = 0; i < tensor.Shape().Size(); ++i) {
raw_data.push_back(data[i].val);
}

if (using_raw_data) {
tensor_proto.set_raw_data(raw_data.data(), raw_data.size() * sizeof(uint16_t));
} else {
auto i32data = reinterpret_cast<const int32_t*>(raw_data.data());
for (int i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(int32_t)); i < count; ++i) {
tensor_proto.add_int32_data(i32data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_STRING: { // Target: string_data
// Strings cannot be written into raw_data.
const auto* data = tensor.Data<std::string>();
for (int i = 0, count = tensor.Shape().Size(); i < count; ++i) {
tensor_proto.add_string_data(data[i]);
}
break;
}
case onnx::TensorProto_DataType_INT64: { // Target: raw_data or int64_data
const auto* data = tensor.Data<int64_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (int x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
tensor_proto.add_int64_data(data[x]);
}
}
break;
}
case onnx::TensorProto_DataType_UINT32: { // Target: raw_data or uint64_data
const auto* data = tensor.Data<uint32_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
auto u64data = reinterpret_cast<const uint64_t*>(data);
for (int i = 0, count = 1 + ((tensor.Size() - 1) / sizeof(uint64_t)); i < count; ++i) {
tensor_proto.add_uint64_data(u64data[i]);
}
}
break;
}
case onnx::TensorProto_DataType_UINT64: { // Target: raw_data or uint64_data
const auto* data = tensor.Data<uint64_t>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (int x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
tensor_proto.add_uint64_data(data[x]);
}
}
break;
}
case onnx::TensorProto_DataType_DOUBLE: { // Target: raw_data or double_data
auto data = tensor.Data<double>();
if (using_raw_data) {
tensor_proto.set_raw_data(data, tensor.Size());
} else {
for (int x = 0, loop_length = tensor.Shape().Size(); x < loop_length; ++x) {
tensor_proto.add_double_data(data[x]);
}
}
break;
}
default: {
LOGS(*logger, ERROR) << "Unsupported TensorProto DataType: " << data_type;
return common::Status(common::StatusCategory::ONNXRUNTIME,
common::StatusCode::NOT_IMPLEMENTED,
"Unsupported TensorProto DataType: " + std::to_string(data_type));
}
}

return common::Status::OK();
}
} // namespace hosting
} // namespace onnxruntime
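A note on the switch above: for element types narrower than 32 bits (uint8, int8, uint16, int16, bool, float16, bfloat16), the converter does not emit one int32 per element. It reinterprets the tensor's byte buffer as 32-bit words and emits `1 + ((tensor.Size() - 1) / sizeof(int32_t))` words, i.e. the byte size ceiling-divided by four. The standalone sketch below (not part of this commit; the 10-element uint8 tensor is hypothetical, and it copies into a zero-padded buffer where the commit reads the buffer in place via reinterpret_cast) illustrates the packing:

// Standalone sketch of the int32_data packing used for sub-32-bit element types.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical tensor contents: 10 uint8 elements -> 10 bytes of data.
  std::vector<uint8_t> elements = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const size_t byte_size = elements.size() * sizeof(uint8_t);

  // Same ceiling division as converter.cc: 1 + ((byte_size - 1) / sizeof(int32_t)).
  const size_t word_count = 1 + ((byte_size - 1) / sizeof(int32_t));

  // Zero-padded copy so the last partial word is well defined.
  std::vector<int32_t> packed(word_count, 0);
  std::memcpy(packed.data(), elements.data(), byte_size);

  // 10 bytes pack into 3 int32 words; a reader unpacks by reversing the memcpy.
  std::cout << byte_size << " bytes -> " << word_count << " int32 words\n";
  return 0;
}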
32 changes: 32 additions & 0 deletions onnxruntime/hosting/converter.h
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifndef ONNXRUNTIME_HOSTING_CONVERTER_H
#define ONNXRUNTIME_HOSTING_CONVERTER_H

#include <google/protobuf/stubs/status.h>

#include "core/framework/data_types.h"

#include "environment.h"
#include "predict.pb.h"

namespace onnxruntime {
namespace hosting {

onnx::TensorProto_DataType MLDataTypeToTensorProtoDataType(const onnxruntime::DataTypeImpl* cpp_type);

// Convert an MLValue to a TensorProto. Some fields are ignored:
//   * name field: cannot be obtained from the MLValue
//   * doc_string field: cannot be obtained from the MLValue
//   * segment field: we do not expect very large tensors in the prediction output
//   * external_data field: we do not expect very large tensors in the prediction output
// Note: if any input tensor has its data in the raw_data field, all output tensor data will be put into the raw_data field.
common::Status MLValue2TensorProto(onnxruntime::MLValue& ml_value, bool using_raw_data,
std::shared_ptr<onnxruntime::logging::Logger> logger,
/* out */ onnx::TensorProto& tensor_proto);

} // namespace hosting
} // namespace onnxruntime

#endif // ONNXRUNTIME_HOSTING_CONVERTER_H
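For context on how this declaration is meant to be consumed, here is a minimal hedged sketch, not code from this commit: ConvertOutputs, output_names, output_values, and the output map are hypothetical names, and only MLValue2TensorProto and its signature come from the header above. A caller converting a set of named prediction outputs might look like:

// Hedged usage sketch only; ConvertOutputs and its parameters are hypothetical.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "converter.h"

namespace onnxruntime {
namespace hosting {

common::Status ConvertOutputs(const std::vector<std::string>& output_names,
                              std::vector<onnxruntime::MLValue>& output_values,
                              bool using_raw_data,
                              std::shared_ptr<onnxruntime::logging::Logger> logger,
                              /* out */ std::unordered_map<std::string, onnx::TensorProto>& outputs) {
  for (size_t i = 0; i < output_names.size(); ++i) {
    onnx::TensorProto tensor_proto;
    // Each MLValue is converted independently; a failure (e.g. an unsupported
    // data type) is propagated to the caller unchanged.
    auto status = MLValue2TensorProto(output_values[i], using_raw_data, logger, tensor_proto);
    if (!status.IsOK()) {
      return status;
    }
    outputs[output_names[i]] = tensor_proto;
  }
  return common::Status::OK();
}

}  // namespace hosting
}  // namespace onnxruntime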