-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30123 from hqucms/onnxrt-106x
[10_6_X] ONNXRuntime-based implementation of DeepJet, DeepAK8 and DeepDoubleX
- Loading branch information
Showing
37 changed files
with
1,110 additions
and
718 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
<!-- Library build rules: depend on the external onnxruntime package and
     FWCore/Utilities, and export the produced library to dependents. -->
<use name="onnxruntime"/>
<use name="FWCore/Utilities"/>
<export>
  <lib name="1"/>
</export>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/*
 * ONNXRuntime.h
 *
 * A convenience wrapper of the ONNXRuntime C++ API.
 * Based on https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp.
 *
 * Created on: Jun 28, 2019
 * Author: hqu
 */

#ifndef PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
#define PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_

#include <vector>
#include <map>
#include <string>
#include <memory>

#include "onnxruntime/core/session/onnxruntime_cxx_api.h"

namespace cms::Ort {

  // One float vector per node; used for both the inputs and the outputs of `run`.
  // (`using` alias instead of `typedef`: fully equivalent, modern idiom.)
  using FloatArrays = std::vector<std::vector<float>>;

  class ONNXRuntime {
  public:
    // Loads the ONNX model at `model_path` and creates the inference session.
    // When `session_options` is null, a default configuration is used (see the
    // implementation). The object owns the session and is non-copyable.
    ONNXRuntime(const std::string& model_path, const ::Ort::SessionOptions* session_options = nullptr);
    ONNXRuntime(const ONNXRuntime&) = delete;
    ONNXRuntime& operator=(const ONNXRuntime&) = delete;
    ~ONNXRuntime();

    // Run inference and get outputs
    // input_names: list of the names of the input nodes.
    // input_values: list of input arrays for each input node. The order of `input_values` must match `input_names`.
    // output_names: names of the output nodes to get outputs from. Empty list means all output nodes.
    // batch_size: number of samples in the batch. Each array in `input_values` must have a shape layout of (batch_size, ...).
    // Returns: a std::vector<std::vector<float>>, with the order matched to `output_names`.
    // When `output_names` is empty, will return all outputs ordered as in `getOutputNames()`.
    FloatArrays run(const std::vector<std::string>& input_names,
                    FloatArrays& input_values,
                    const std::vector<std::string>& output_names = {},
                    int64_t batch_size = 1) const;

    // Get a list of names of all the output nodes
    const std::vector<std::string>& getOutputNames() const;

    // Get the shape of an output node.
    // The 0th dim depends on the batch size, therefore is set to -1.
    const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;

  private:
    // Single static ORT environment shared by all ONNXRuntime instances.
    static const ::Ort::Env env_;
    std::unique_ptr<::Ort::Session> session_;

    // Owned copies of the node names; the `*_node_names_` vectors hold
    // pointers into `*_node_strings_`, in the order expected by Session::Run.
    std::vector<std::string> input_node_strings_;
    std::vector<const char*> input_node_names_;
    std::map<std::string, std::vector<int64_t>> input_node_dims_;

    std::vector<std::string> output_node_strings_;
    std::vector<const char*> output_node_names_;
    std::map<std::string, std::vector<int64_t>> output_node_dims_;
  };

}  // namespace cms::Ort

#endif /* PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_ */
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
/*
 * ONNXRuntime.cc
 *
 * Created on: Jun 28, 2019
 * Author: hqu
 */

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"

#include <cassert>
#include <iostream>
#include <algorithm>
#include <numeric>
#include <functional>
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/thread_safety_macros.h"

namespace cms::Ort {

  using namespace ::Ort;

  // Process-wide ORT environment; logging restricted to errors only.
  const Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_ERROR, "");

  ONNXRuntime::ONNXRuntime(const std::string& model_path, const SessionOptions* session_options) {
    // Create the session; if the caller did not provide options, default to a
    // single intra-op thread so inference does not oversubscribe CPU cores.
    if (session_options) {
      session_.reset(new Session(env_, model_path.c_str(), *session_options));
    } else {
      SessionOptions sess_opts;
      sess_opts.SetIntraOpNumThreads(1);
      session_.reset(new Session(env_, model_path.c_str(), sess_opts));
    }
    AllocatorWithDefaultOptions allocator;

    // GetInputName/GetOutputName return a C string allocated with `allocator`;
    // copy it into a std::string and free the original, otherwise one
    // allocation per node is leaked.
    auto copy_and_free = [&allocator](char* raw_name) {
      std::string name(raw_name);
      allocator.Free(raw_name);
      return name;
    };

    // get input names and shapes
    size_t num_input_nodes = session_->GetInputCount();
    input_node_strings_.resize(num_input_nodes);
    input_node_names_.resize(num_input_nodes);
    input_node_dims_.clear();

    for (size_t i = 0; i < num_input_nodes; i++) {
      // get input node names
      std::string input_name = copy_and_free(session_->GetInputName(i, allocator));
      input_node_strings_[i] = input_name;
      input_node_names_[i] = input_node_strings_[i].c_str();

      // get input shapes
      auto type_info = session_->GetInputTypeInfo(i);
      auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
      size_t num_dims = tensor_info.GetDimensionsCount();
      input_node_dims_[input_name].resize(num_dims);
      tensor_info.GetDimensions(input_node_dims_[input_name].data(), num_dims);

      // set the batch size to 1 by default
      input_node_dims_[input_name].at(0) = 1;
    }

    // get output names and shapes
    size_t num_output_nodes = session_->GetOutputCount();
    output_node_strings_.resize(num_output_nodes);
    output_node_names_.resize(num_output_nodes);
    output_node_dims_.clear();

    for (size_t i = 0; i < num_output_nodes; i++) {
      // get output node names
      std::string output_name = copy_and_free(session_->GetOutputName(i, allocator));
      output_node_strings_[i] = output_name;
      output_node_names_[i] = output_node_strings_[i].c_str();

      // get output node types
      auto type_info = session_->GetOutputTypeInfo(i);
      auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
      size_t num_dims = tensor_info.GetDimensionsCount();
      output_node_dims_[output_name].resize(num_dims);
      tensor_info.GetDimensions(output_node_dims_[output_name].data(), num_dims);

      // the 0th dim depends on the batch size
      output_node_dims_[output_name].at(0) = -1;
    }
  }

  // All members clean up via RAII (unique_ptr session); nothing to do here.
  ONNXRuntime::~ONNXRuntime() = default;

  FloatArrays ONNXRuntime::run(const std::vector<std::string>& input_names,
                               FloatArrays& input_values,
                               const std::vector<std::string>& output_names,
                               int64_t batch_size) const {
    assert(input_names.size() == input_values.size());
    assert(batch_size > 0);

    // create input tensor objects from data values
    std::vector<Value> input_tensors;
    auto memory_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    for (const auto& name : input_node_strings_) {
      auto iter = std::find(input_names.begin(), input_names.end(), name);
      if (iter == input_names.end()) {
        throw cms::Exception("RuntimeError") << "Input " << name << " is not provided!";
      }
      auto value = input_values.begin() + (iter - input_names.begin());
      auto input_dims = input_node_dims_.at(name);
      input_dims[0] = batch_size;
      // Use an int64_t seed for the product: an int literal would make the
      // accumulator `int` and silently truncate large shape products.
      auto expected_len =
          std::accumulate(input_dims.begin(), input_dims.end(), int64_t{1}, std::multiplies<int64_t>());
      if (expected_len != (int64_t)value->size()) {
        throw cms::Exception("RuntimeError")
            << "Input array " << name << " has a wrong size of " << value->size() << ", expected " << expected_len;
      }
      // The tensor is a non-owning view into `value`'s buffer, which must stay
      // alive (and unmodified) until Session::Run below returns.
      auto input_tensor =
          Value::CreateTensor<float>(memory_info, value->data(), value->size(), input_dims.data(), input_dims.size());
      assert(input_tensor.IsTensor());
      input_tensors.emplace_back(std::move(input_tensor));
    }

    // set output node names; will get all outputs if `output_names` is not provided
    std::vector<const char*> run_output_node_names;
    if (output_names.empty()) {
      run_output_node_names = output_node_names_;
    } else {
      for (const auto& name : output_names) {
        run_output_node_names.push_back(name.c_str());
      }
    }

    // run
    auto output_tensors = session_->Run(RunOptions{nullptr},
                                        input_node_names_.data(),
                                        input_tensors.data(),
                                        input_tensors.size(),
                                        run_output_node_names.data(),
                                        run_output_node_names.size());

    // convert output tensors to plain float vectors
    FloatArrays outputs;
    for (auto& output_tensor : output_tensors) {
      assert(output_tensor.IsTensor());

      // get output shape
      auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
      auto length = tensor_info.GetElementCount();

      auto floatarr = output_tensor.GetTensorMutableData<float>();
      outputs.emplace_back(floatarr, floatarr + length);
    }
    assert(outputs.size() == run_output_node_names.size());

    return outputs;
  }

  const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
    if (session_) {
      return output_node_strings_;
    } else {
      throw cms::Exception("RuntimeError") << "Needs to call createSession() first before getting the output names!";
    }
  }

  const std::vector<int64_t>& ONNXRuntime::getOutputShape(const std::string& output_name) const {
    auto iter = output_node_dims_.find(output_name);
    if (iter == output_node_dims_.end()) {
      throw cms::Exception("RuntimeError") << "Output name " << output_name << " is invalid!";
    } else {
      return iter->second;
    }
  }

} /* namespace cms::Ort */
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
<!-- CppUnit-based unit-test binary for the ONNXRuntime wrapper. -->
<bin name="testONNXRuntime" file="testRunner.cpp, testONNXRuntime.cc">
  <use name="boost_filesystem"/>
  <use name="cppunit"/>
  <use name="PhysicsTools/ONNXRuntime"/>
  <use name="FWCore/ParameterSet"/>
  <use name="FWCore/Utilities"/>
</bin>
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"

#include <chrono>
#include <iostream>

using namespace cms::Ort;

// CppUnit fixture exercising the ONNXRuntime wrapper end to end on a small
// model bundled with the package.
class testONNXRuntime : public CppUnit::TestFixture {
  CPPUNIT_TEST_SUITE(testONNXRuntime);
  CPPUNIT_TEST(checkAll);
  CPPUNIT_TEST_SUITE_END();

public:
  void checkAll();
};

CPPUNIT_TEST_SUITE_REGISTRATION(testONNXRuntime);

void testONNXRuntime::checkAll() {
  // Locate the test model shipped in the package's data area.
  std::string model_path = edm::FileInPath("PhysicsTools/ONNXRuntime/test/data/model.onnx").fullPath();
  ONNXRuntime rt(model_path);
  // Exercise several batch sizes; input node "X" takes 2 features per sample
  // (all ones here) and output node "Y" yields one value per sample.
  for (const unsigned batch_size : {1, 2, 4}) {
    std::vector<float> flat_input(batch_size * 2, 1);
    FloatArrays input_values;
    input_values.push_back(std::move(flat_input));
    FloatArrays outputs;
    CPPUNIT_ASSERT_NO_THROW(outputs = rt.run({"X"}, input_values, {"Y"}, batch_size));
    CPPUNIT_ASSERT(outputs.size() == 1);
    CPPUNIT_ASSERT(outputs[0].size() == batch_size);
    for (const auto &score : outputs[0]) {
      CPPUNIT_ASSERT(score == 3);
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
#include <Utilities/Testing/interface/CppUnit_testdriver.icpp> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.