Skip to content

Commit

Permalink
Merge pull request cms-sw#40284 from riga/use_global_session_DeepMET
Browse files Browse the repository at this point in the history
Move TF session in DeepMETProducer to global cache
  • Loading branch information
cmsbuild authored Dec 13, 2022
2 parents 25b8359 + 8d81903 commit 39394a2
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 26 deletions.
33 changes: 33 additions & 0 deletions PhysicsTools/TensorFlow/interface/TensorFlow.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,39 @@ namespace tensorflow {
run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
}

// struct that can be used in edm::stream modules for caching a graph and a session instance,
// both made atomic for cases where access is required from multiple threads
struct SessionCache {
std::atomic<GraphDef*> graph;
std::atomic<Session*> session;

// constructor
SessionCache() {}

// initializing constructor, forwarding all arguments to createSession
template <typename... Args>
SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
createSession(graphPath, std::forward<Args>(sessionArgs)...);
}

// destructor
~SessionCache() { closeSession(); }

// create the internal graph representation from graphPath and the session object, forwarding
// all additional arguments to the central tensorflow::createSession
template <typename... Args>
void createSession(const std::string& graphPath, Args&&... sessionArgs) {
graph.store(loadGraphDef(graphPath));
session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
}

// return a pointer to the const session
inline const Session* getSession() const { return session.load(); }

// closes and removes the session as well as the graph, and sets the atomic members to nullptr's
void closeSession();
};

} // namespace tensorflow

#endif // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
15 changes: 15 additions & 0 deletions PhysicsTools/TensorFlow/src/TensorFlow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,19 @@ namespace tensorflow {
run(session, {}, outputNames, outputs, threadPoolName);
}

void SessionCache::closeSession() {
// delete the session if set
Session* s = session.load();
if (s != nullptr) {
tensorflow::closeSession(s);
session.store(nullptr);
}

// delete the graph if set
if (graph.load() != nullptr) {
delete graph.load();
graph.store(nullptr);
}
}

} // namespace tensorflow
6 changes: 6 additions & 0 deletions PhysicsTools/TensorFlow/test/BuildFile.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
<use name="PhysicsTools/TensorFlow"/>
</bin>

<bin name="testTFSessionCache" file="testRunner.cpp,testSessionCache.cc">
<use name="boost_filesystem"/>
<use name="cppunit"/>
<use name="PhysicsTools/TensorFlow"/>
</bin>

<bin name="testTFThreadPools" file="testRunner.cpp,testThreadPools.cc">
<use name="boost_filesystem"/>
<use name="cppunit"/>
Expand Down
2 changes: 1 addition & 1 deletion PhysicsTools/TensorFlow/test/testBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void testBase::setUp() {

// create the graph
std::string testPath = cmsswPath("/src/PhysicsTools/TensorFlow/test");
std::string cmd = "python3 " + testPath + "/" + pyScript() + " " + dataPath_;
std::string cmd = "python3 -W ignore " + testPath + "/" + pyScript() + " " + dataPath_;
std::array<char, 128> buffer;
std::string result;
std::shared_ptr<FILE> pipe(popen(cmd.c_str(), "r"), pclose);
Expand Down
46 changes: 46 additions & 0 deletions PhysicsTools/TensorFlow/test/testSessionCache.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Tests for interacting with the SessionCache.
*
* Author: Marcel Rieger
*/

#include <stdexcept>
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"

class testSessionCache : public testBase {
CPPUNIT_TEST_SUITE(testSessionCache);
CPPUNIT_TEST(checkAll);
CPPUNIT_TEST_SUITE_END();

public:
std::string pyScript() const override;
void checkAll() override;
};

CPPUNIT_TEST_SUITE_REGISTRATION(testSessionCache);

std::string testSessionCache::pyScript() const { return "createconstantgraph.py"; }

void testSessionCache::checkAll() {
std::string pbFile = dataPath_ + "/constantgraph.pb";

tensorflow::setLogging();

// load the graph and the session
tensorflow::SessionCache cache(pbFile);
CPPUNIT_ASSERT(cache.graph.load() != nullptr);
CPPUNIT_ASSERT(cache.session.load() != nullptr);

// get a const session pointer
const tensorflow::Session* session = cache.getSession();
CPPUNIT_ASSERT(session != nullptr);

// cleanup
cache.closeSession();
CPPUNIT_ASSERT(cache.graph.load() == nullptr);
CPPUNIT_ASSERT(cache.session.load() == nullptr);
}
36 changes: 11 additions & 25 deletions RecoMET/METPUSubtraction/plugins/DeepMETProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,36 @@

using namespace deepmet_helper;

struct DeepMETCache {
std::atomic<tensorflow::GraphDef*> graph_def;
};

class DeepMETProducer : public edm::stream::EDProducer<edm::GlobalCache<DeepMETCache> > {
class DeepMETProducer : public edm::stream::EDProducer<edm::GlobalCache<tensorflow::SessionCache> > {
public:
explicit DeepMETProducer(const edm::ParameterSet&, const DeepMETCache*);
explicit DeepMETProducer(const edm::ParameterSet&, const tensorflow::SessionCache*);
void produce(edm::Event& event, const edm::EventSetup& setup) override;
static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);

// static methods for handling the global cache
static std::unique_ptr<DeepMETCache> initializeGlobalCache(const edm::ParameterSet&);
static void globalEndJob(DeepMETCache*);
static std::unique_ptr<tensorflow::SessionCache> initializeGlobalCache(const edm::ParameterSet&);
static void globalEndJob(tensorflow::SessionCache*){};

private:
const edm::EDGetTokenT<std::vector<pat::PackedCandidate> > pf_token_;
const float norm_;
const bool ignore_leptons_;
const unsigned int max_n_pf_;

tensorflow::Session* session_;
const tensorflow::Session* session_;

tensorflow::Tensor input_;
tensorflow::Tensor input_cat0_;
tensorflow::Tensor input_cat1_;
tensorflow::Tensor input_cat2_;
};

DeepMETProducer::DeepMETProducer(const edm::ParameterSet& cfg, const DeepMETCache* cache)
DeepMETProducer::DeepMETProducer(const edm::ParameterSet& cfg, const tensorflow::SessionCache* cache)
: pf_token_(consumes<std::vector<pat::PackedCandidate> >(cfg.getParameter<edm::InputTag>("pf_src"))),
norm_(cfg.getParameter<double>("norm_factor")),
ignore_leptons_(cfg.getParameter<bool>("ignore_leptons")),
max_n_pf_(cfg.getParameter<unsigned int>("max_n_pf")),
session_(tensorflow::createSession(cache->graph_def)) {
session_(cache->getSession()) {
produces<pat::METCollection>();

const tensorflow::TensorShape shape({1, max_n_pf_, 8});
Expand Down Expand Up @@ -125,22 +121,12 @@ void DeepMETProducer::produce(edm::Event& event, const edm::EventSetup& setup) {
event.put(std::move(pf_mets));
}

std::unique_ptr<DeepMETCache> DeepMETProducer::initializeGlobalCache(const edm::ParameterSet& params) {
// this method is supposed to create, initialize and return a DeepMETCache instance
std::unique_ptr<DeepMETCache> cache = std::make_unique<DeepMETCache>();

// load the graph def and save it
std::string graphPath = params.getParameter<std::string>("graph_path");
if (!graphPath.empty()) {
graphPath = edm::FileInPath(graphPath).fullPath();
cache->graph_def = tensorflow::loadGraphDef(graphPath);
}

return cache;
std::unique_ptr<tensorflow::SessionCache> DeepMETProducer::initializeGlobalCache(const edm::ParameterSet& params) {
// this method is supposed to create, initialize and return a SessionCache instance
std::string graphPath = edm::FileInPath(params.getParameter<std::string>("graph_path")).fullPath();
return std::make_unique<tensorflow::SessionCache>(graphPath);
}

void DeepMETProducer::globalEndJob(DeepMETCache* cache) { delete cache->graph_def; }

void DeepMETProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
edm::ParameterSetDescription desc;
desc.add<edm::InputTag>("pf_src", edm::InputTag("packedPFCandidates"));
Expand Down

0 comments on commit 39394a2

Please sign in to comment.