Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore TF and MXNet-based inference for DeepJet, DeepDoubleX and DeepAK8 #29172

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@
#include <string>
#include <memory>

// currently ONNXRUNTIME only supports x86 and ARM
#if defined(__arm__) || defined(__aarch64__) || defined(__x86_64__) || defined(__i386__)
#define CMS_USE_ONNXRUNTIME
#endif

#ifdef CMS_USE_ONNXRUNTIME
#include "onnxruntime/core/session/onnxruntime_cxx_api.h"
#else
namespace Ort {
struct SessionOptions {};
} // namespace Ort
#endif

namespace cms::Ort {

Expand Down Expand Up @@ -48,6 +59,7 @@ namespace cms::Ort {
// The 0th dim depends on the batch size, therefore is set to -1
const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;

#ifdef CMS_USE_ONNXRUNTIME
private:
static const ::Ort::Env env_;
std::unique_ptr<::Ort::Session> session_;
Expand All @@ -59,6 +71,7 @@ namespace cms::Ort {
std::vector<std::string> output_node_strings_;
std::vector<const char*> output_node_names_;
std::map<std::string, std::vector<int64_t>> output_node_dims_;
#endif
};

} // namespace cms::Ort
Expand Down
18 changes: 17 additions & 1 deletion PhysicsTools/ONNXRuntime/src/ONNXRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ namespace cms::Ort {

using namespace ::Ort;

#ifdef CMS_USE_ONNXRUNTIME
const Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_WARNING, "");
#endif

ONNXRuntime::ONNXRuntime(const std::string& model_path, const SessionOptions* session_options) {
#ifdef CMS_USE_ONNXRUNTIME
// create session
if (session_options) {
session_.reset(new Session(env_, model_path.c_str(), *session_options));
Expand Down Expand Up @@ -76,6 +79,7 @@ namespace cms::Ort {
// the 0th dim depends on the batch size
output_node_dims_[output_name].at(0) = -1;
}
#endif
}

// Trivial destructor: the Ort::Session is held in a std::unique_ptr (session_),
// so the ONNX Runtime session is released automatically.
ONNXRuntime::~ONNXRuntime() {}
Expand All @@ -84,6 +88,7 @@ namespace cms::Ort {
FloatArrays& input_values,
const std::vector<std::string>& output_names,
int64_t batch_size) const {
#ifdef CMS_USE_ONNXRUNTIME
assert(input_names.size() == input_values.size());
assert(batch_size > 0);

Expand Down Expand Up @@ -142,23 +147,34 @@ namespace cms::Ort {
assert(outputs.size() == run_output_node_names.size());

return outputs;
#else
throw cms::Exception("RuntimeError") << "ONNXRuntime does not support the current architecture";
#endif
}

// Returns the names of all output nodes discovered when the session was created.
// Throws cms::Exception if no session has been initialized.
// NOTE(review): this span is a rendered GitHub diff — the two consecutive
// `throw` lines below are the removed (old) and added (new) variants of the
// SAME statement, not two live statements; confirm against the real source file.
const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
#ifdef CMS_USE_ONNXRUNTIME
if (session_) {
return output_node_strings_;
} else {
// old message, removed by this diff:
throw cms::Exception("RuntimeError") << "Needs to call createSession() first before getting the output names!";
// new message, added by this diff:
throw cms::Exception("RuntimeError") << "ONNXRuntime session is not initialized!";
}
#else
// Stub path: CMS_USE_ONNXRUNTIME is undefined on non-x86/ARM architectures.
throw cms::Exception("RuntimeError") << "ONNXRuntime does not support the current architecture";
#endif
}

// Returns the cached shape for the requested output node.  The 0th (batch)
// dimension is stored as -1, since it is only known at run() time.
// Throws cms::Exception("RuntimeError") when the name is not a known output,
// or unconditionally on architectures without ONNX Runtime support.
const std::vector<int64_t>& ONNXRuntime::getOutputShape(const std::string& output_name) const {
#ifdef CMS_USE_ONNXRUNTIME
  const auto pos = output_node_dims_.find(output_name);
  if (pos != output_node_dims_.end()) {
    return pos->second;
  }
  throw cms::Exception("RuntimeError") << "Output name " << output_name << " is invalid!";
#else
  // Stub path: CMS_USE_ONNXRUNTIME is undefined on non-x86/ARM architectures.
  throw cms::Exception("RuntimeError") << "ONNXRuntime does not support the current architecture";
#endif
}

} /* namespace cms::Ort */
5 changes: 5 additions & 0 deletions PhysicsTools/ONNXRuntime/test/testONNXRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/Utilities/interface/Exception.h"

#include <chrono>
#include <iostream>
Expand All @@ -27,11 +28,15 @@ void testONNXRuntime::checkAll() {
std::vector<float>(batch_size * 2, 1),
};
FloatArrays outputs;
#ifdef CMS_USE_ONNXRUNTIME
CPPUNIT_ASSERT_NO_THROW(outputs = rt.run({"X"}, input_values, {"Y"}, batch_size));
CPPUNIT_ASSERT(outputs.size() == 1);
CPPUNIT_ASSERT(outputs[0].size() == batch_size);
for (const auto &v : outputs[0]) {
CPPUNIT_ASSERT(v == 3);
}
#else
CPPUNIT_ASSERT_THROW(rt.run({"X"}, input_values, {"Y"}, batch_size), cms::Exception);
#endif
}
}
10 changes: 10 additions & 0 deletions PhysicsTools/PatAlgos/python/tools/jetTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def setupBTagging(process, jetSource, pfCandidates, explicitJTA, pvSource, svSou
process.load("RecoBTag.CTagging.cTagging_EventSetup_cff")
import RecoBTag.Configuration.RecoBTag_cff as btag
import RecoJets.JetProducers.caTopTaggers_cff as toptag
from RecoBTag.ONNXRuntime.SwitchProducerONNX import SwitchProducerONNX

if tightBTagNTkHits:
if not runIVF:
Expand Down Expand Up @@ -720,6 +721,15 @@ def setupBTagging(process, jetSource, pfCandidates, explicitJTA, pvSource, svSou
process,
task
)
elif isinstance(getattr(btag, btagDiscr), SwitchProducerONNX):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this branch is needed only because of cloneAll in the name.
It would be better to avoid a special method case

addToProcessAndTask(
newDiscr,
getattr(btag, btagDiscr).cloneAll(
src = btagPrefix + supportedBtagDiscr[discriminator_name][0][0] + labelName + postfix
),
process,
task
)
else:
raise ValueError('I do not know how to update %s it does not have neither "tagInfos" nor "src" attributes' % btagDiscr)
acceptedBtagDiscriminators.append(discriminator_name)
Expand Down
4 changes: 2 additions & 2 deletions PhysicsTools/TensorFlow/interface/TensorFlow.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ namespace tensorflow {
// return a new session that will contain an already loaded graph def, sessionOptions are predefined
// an error is thrown when graphDef is a nullptr or when the graph has no nodes
// transfers ownership
Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions);
Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions);

// return a new session that will contain an already loaded graph def, threading options are
// inferred from nThreads
// an error is thrown when graphDef is a nullptr or when the graph has no nodes
// transfers ownership
Session* createSession(GraphDef* graphDef, int nThreads = 1);
Session* createSession(const GraphDef* graphDef, int nThreads = 1);

// closes a session, calls its destructor, resets the pointer, and returns true on success
bool closeSession(Session*& session);
Expand Down
4 changes: 2 additions & 2 deletions PhysicsTools/TensorFlow/src/TensorFlow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ namespace tensorflow {
return createSession(metaGraphDef, exportDir, sessionOptions);
}

Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions) {
Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions) {
// check for valid pointer
if (graphDef == nullptr) {
throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
Expand All @@ -185,7 +185,7 @@ namespace tensorflow {
return session;
}

Session* createSession(GraphDef* graphDef, int nThreads) {
Session* createSession(const GraphDef* graphDef, int nThreads) {
// create session options and set thread options
SessionOptions sessionOptions;
setThreading(sessionOptions, nThreads);
Expand Down
48 changes: 48 additions & 0 deletions RecoBTag/FeatureTools/interface/tensor_fillers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef RecoBTag_FeatureTools_tensor_fillers_h
#define RecoBTag_FeatureTools_tensor_fillers_h

#include "DataFormats/BTauReco/interface/DeepFlavourTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepDoubleXTagInfo.h"

// Helpers that copy b-tagging feature objects into flat float buffers for
// ML inference backends.  Common conventions (from the signatures; the
// definitions live in the corresponding .cc file):
//   * ptr          — destination buffer, presumably sized by the caller to hold
//                    feature_dims floats per entry (TODO confirm in the .cc).
//   * feature_dims — number of float features written per object/candidate.
//   * max_*_n      — cap on how many candidates from the input vector are
//                    written (presumably zero-padded when fewer are present —
//                    confirm against the implementation).
namespace btagbtvdeep {

// Fill global jet features from a DeepFlavour tag info.
void jet_tensor_filler(float* ptr, const btagbtvdeep::DeepFlavourFeatures& features, unsigned feature_dims);

// Fill the jet four-vector features from a DeepFlavour tag info.
void jet4vec_tensor_filler(float* ptr, const btagbtvdeep::DeepFlavourFeatures& features, unsigned feature_dims);

// Fill global features for DeepDoubleX (double-b/c) tagging.
void db_tensor_filler(float* ptr, const btagbtvdeep::DeepDoubleXFeatures& features, unsigned feature_dims);

// Fill per-charged-PF-candidate features, up to max_c_pf_n candidates.
void c_pf_tensor_filler(float* ptr,
std::size_t max_c_pf_n,
const std::vector<btagbtvdeep::ChargedCandidateFeatures>& c_pf_features_vec,
unsigned feature_dims);

// Reduced (smaller) charged-PF-candidate feature set variant.
void c_pf_reduced_tensor_filler(float* ptr,
std::size_t max_c_pf_n,
const std::vector<btagbtvdeep::ChargedCandidateFeatures>& c_pf_features_vec,
unsigned feature_dims);

// Fill per-neutral-PF-candidate features, up to max_n_pf_n candidates.
void n_pf_tensor_filler(float* ptr,
std::size_t max_n_pf_n,
const std::vector<btagbtvdeep::NeutralCandidateFeatures>& n_pf_features_vec,
unsigned feature_dims);

// Fill per-secondary-vertex features, up to max_sv_n vertices.
void sv_tensor_filler(float* ptr,
std::size_t max_sv_n,
const std::vector<btagbtvdeep::SecondaryVertexFeatures>& sv_features_vec,
unsigned feature_dims);

// Reduced (smaller) secondary-vertex feature set variant.
void sv_reduced_tensor_filler(float* ptr,
std::size_t max_sv_n,
const std::vector<btagbtvdeep::SecondaryVertexFeatures>& sv_features_vec,
unsigned feature_dims);

// Fill features for a single seeding track (DeepVertex-style inputs).
void seed_tensor_filler(float* ptr, const btagbtvdeep::SeedingTrackFeatures& seed_features, unsigned feature_dims);

// Fill features of the tracks neighbouring a seeding track.
void neighbourTracks_tensor_filler(float* ptr,
const btagbtvdeep::SeedingTrackFeatures& seed_features,
unsigned feature_dims);

} // namespace btagbtvdeep

#endif
Loading