Skip to content

Commit

Permalink
Use tf SessionCache in DeepMETProducer.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Dec 10, 2022
1 parent cf78146 commit 8d81903
Showing 1 changed file with 11 additions and 25 deletions.
36 changes: 11 additions & 25 deletions RecoMET/METPUSubtraction/plugins/DeepMETProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,36 @@

using namespace deepmet_helper;

struct DeepMETCache {
std::atomic<tensorflow::GraphDef*> graph_def;
};

class DeepMETProducer : public edm::stream::EDProducer<edm::GlobalCache<DeepMETCache> > {
class DeepMETProducer : public edm::stream::EDProducer<edm::GlobalCache<tensorflow::SessionCache> > {
public:
explicit DeepMETProducer(const edm::ParameterSet&, const DeepMETCache*);
explicit DeepMETProducer(const edm::ParameterSet&, const tensorflow::SessionCache*);
void produce(edm::Event& event, const edm::EventSetup& setup) override;
static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);

// static methods for handling the global cache
static std::unique_ptr<DeepMETCache> initializeGlobalCache(const edm::ParameterSet&);
static void globalEndJob(DeepMETCache*);
static std::unique_ptr<tensorflow::SessionCache> initializeGlobalCache(const edm::ParameterSet&);
static void globalEndJob(tensorflow::SessionCache*){};

private:
const edm::EDGetTokenT<std::vector<pat::PackedCandidate> > pf_token_;
const float norm_;
const bool ignore_leptons_;
const unsigned int max_n_pf_;

tensorflow::Session* session_;
const tensorflow::Session* session_;

tensorflow::Tensor input_;
tensorflow::Tensor input_cat0_;
tensorflow::Tensor input_cat1_;
tensorflow::Tensor input_cat2_;
};

DeepMETProducer::DeepMETProducer(const edm::ParameterSet& cfg, const DeepMETCache* cache)
DeepMETProducer::DeepMETProducer(const edm::ParameterSet& cfg, const tensorflow::SessionCache* cache)
: pf_token_(consumes<std::vector<pat::PackedCandidate> >(cfg.getParameter<edm::InputTag>("pf_src"))),
norm_(cfg.getParameter<double>("norm_factor")),
ignore_leptons_(cfg.getParameter<bool>("ignore_leptons")),
max_n_pf_(cfg.getParameter<unsigned int>("max_n_pf")),
session_(tensorflow::createSession(cache->graph_def)) {
session_(cache->getSession()) {
produces<pat::METCollection>();

const tensorflow::TensorShape shape({1, max_n_pf_, 8});
Expand Down Expand Up @@ -125,22 +121,12 @@ void DeepMETProducer::produce(edm::Event& event, const edm::EventSetup& setup) {
event.put(std::move(pf_mets));
}

std::unique_ptr<DeepMETCache> DeepMETProducer::initializeGlobalCache(const edm::ParameterSet& params) {
// this method is supposed to create, initialize and return a DeepMETCache instance
std::unique_ptr<DeepMETCache> cache = std::make_unique<DeepMETCache>();

// load the graph def and save it
std::string graphPath = params.getParameter<std::string>("graph_path");
if (!graphPath.empty()) {
graphPath = edm::FileInPath(graphPath).fullPath();
cache->graph_def = tensorflow::loadGraphDef(graphPath);
}

return cache;
std::unique_ptr<tensorflow::SessionCache> DeepMETProducer::initializeGlobalCache(const edm::ParameterSet& params) {
// this method is supposed to create, initialize and return a SessionCache instance
std::string graphPath = edm::FileInPath(params.getParameter<std::string>("graph_path")).fullPath();
return std::make_unique<tensorflow::SessionCache>(graphPath);
}

void DeepMETProducer::globalEndJob(DeepMETCache* cache) { delete cache->graph_def; }

void DeepMETProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
edm::ParameterSetDescription desc;
desc.add<edm::InputTag>("pf_src", edm::InputTag("packedPFCandidates"));
Expand Down

0 comments on commit 8d81903

Please sign in to comment.