Skip to content

Commit

Permalink
propagated array filling back to windowbase
Browse files Browse the repository at this point in the history
  • Loading branch information
jkiesele committed Sep 25, 2020
1 parent 3ca732d commit 9fc4fd5
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 77 deletions.
13 changes: 1 addition & 12 deletions RecoHGCal/GraphReco/interface/InferenceWindow.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#define SRC_RECOHGCAL_GRAPHRECO_INTERFACE_INFERENCEWINDOW_H_

#include "../interface/WindowBase.h"
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

class InferenceWindow: public WindowBase {
public:
Expand All @@ -20,18 +19,12 @@ class InferenceWindow: public WindowBase {

~InferenceWindow();

void setupTFInterface(size_t padSize, size_t nFeatures, bool batchedModel,
const std::string& inputTensorName,
const std::string& outputTensorName);

void fillFeatureArrays();

void evaluate(tensorflow::Session* sess);
void evaluate();

void getOutput() const{}//needs output format etc.

void flattenRechitFeatures();

static std::vector<InferenceWindow> createWindows(size_t nSegmentsPhi,
size_t nSegmentsEta, double minEta, double maxEta, double frameWidthEta,
double frameWidthPhi);
Expand All @@ -41,10 +34,6 @@ class InferenceWindow: public WindowBase {
InferenceWindow(){}
//
//Inference
tensorflow::Tensor inputTensor;
tensorflow::NamedTensorList inputTensorList;
tensorflow::Tensor outputTensor;
std::string outputTensorName_;

};

Expand Down
2 changes: 1 addition & 1 deletion RecoHGCal/GraphReco/interface/NTupleWindow.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class NTupleWindow: public WindowBase {
//0 associate all rechits etc

void fillFeatureArrays();

//1
void fillTruthArrays();
//2
Expand Down Expand Up @@ -56,7 +57,6 @@ class NTupleWindow: public WindowBase {


//can be layer clusters or rechits according to mode
std::vector<std::vector<float> > hitFeatures_; //this includes tracks!
std::vector<float> recHitEnergy_;
std::vector<float> recHitEta_;
std::vector<float> recHitPhi_;
Expand Down
7 changes: 6 additions & 1 deletion RecoHGCal/GraphReco/interface/WindowBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class WindowBase {

void clear();

virtual void fillFeatureArrays()=0;
void fillFeatureArrays();


//debug functions
Expand Down Expand Up @@ -189,6 +189,8 @@ class WindowBase {
size_t nSegmentsEta, double minEta, double maxEta, double frameWidthEta,
double frameWidthPhi) ;

const std::vector<std::vector<float> >& getHitFeatures()const { return hitFeatures_;}

private:

mode mode_;
Expand Down Expand Up @@ -217,6 +219,9 @@ class WindowBase {
std::vector<const SimCluster*> badSimClusters_;
std::vector<bool> simClustersInnerWindow_;

//this is the input to the model!
std::vector<std::vector<float> > hitFeatures_; //this includes tracks!

//for one track
void fillTrackFeatures(float*& data, const TrackWithHGCalPos *) const;
//for one rechit
Expand Down
2 changes: 1 addition & 1 deletion RecoHGCal/GraphReco/plugins/WindowInference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void WindowInference::analyze(const edm::Event& event,

// run the evaluation per window
for (auto & window : windows_) {
window.evaluate(session_);
// window.evaluate(session_);
}

// reconstruct showers using all windows and put them into the event
Expand Down
33 changes: 0 additions & 33 deletions RecoHGCal/GraphReco/src/InferenceWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,4 @@ std::vector<InferenceWindow> InferenceWindow::createWindows(size_t nSegmentsPhi,
minEta,maxEta,frameWidthEta,frameWidthPhi);
}

void InferenceWindow::fillFeatureArrays(){
float * data = 0; //inputTensor.data()

if(getMode() == useRechits){
for(const auto& rh:recHits){
fillRecHitFeatures(data,rh);
}
}
else{
for(const auto& lc: layerClusters_){
fillLayerClusterFeatures(data,lc);
}
}
//add tracks LAST!
for(const auto& tr:tracks_){
fillTrackFeatures(data,tr);
}
//do some zero padding if needed

//FIXME

}

void InferenceWindow::setupTFInterface(size_t padSize, size_t nFeatures, bool batchedModel,
const std::string& inputTensorName,
const std::string& outputTensorName) {

}



void InferenceWindow::evaluate(tensorflow::Session* sess) {

}
30 changes: 1 addition & 29 deletions RecoHGCal/GraphReco/src/NTupleWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,35 +311,7 @@ void NTupleWindow::clear(){

void NTupleWindow::fillFeatureArrays(){
//NO CUTS HERE!

hitFeatures_.clear();
if(getMode() == useRechits){
for(const auto& rh:recHits){
std::vector<float> feats(nRechitFeatures_);
auto data = &feats.at(0);
fillRecHitFeatures(data,rh);
hitFeatures_.push_back(feats);
}
}
else{
for(const auto& lc: layerClusters_){
std::vector<float> feats(nLayerClusterFeatures_);
auto data = &feats.at(0);
fillLayerClusterFeatures(data,lc);
hitFeatures_.push_back(feats);
}
}

//return;
//add tracks LAST!
for(const auto& tr:tracks_){
std::vector<float> feats(nTrackFeatures_);
auto data = &feats.at(0);
fillTrackFeatures(data,tr);
hitFeatures_.push_back(feats);
}


WindowBase::fillFeatureArrays();
createDetIDHitAssociation();

}
Expand Down
34 changes: 34 additions & 0 deletions RecoHGCal/GraphReco/src/WindowBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ void WindowBase::clear() {
simClusters_.clear();
badSimClusters_.clear();
ticltracksters_.clear();
hitFeatures_.clear();
}


Expand Down Expand Up @@ -115,6 +116,39 @@ void WindowBase::fillLayerClusterFeatures(float*& data, const reco::CaloCluster
throw std::runtime_error("WindowBase::fillLayerClusterFeatures: LC not supported anymore");
}

void WindowBase::fillFeatureArrays(){
//NO CUTS HERE!

hitFeatures_.clear();
if(getMode() == useRechits){
for(const auto& rh:recHits){
std::vector<float> feats(nRechitFeatures_);
auto data = &feats.at(0);
fillRecHitFeatures(data,rh);
hitFeatures_.push_back(feats);
}
}
else{
for(const auto& lc: layerClusters_){
std::vector<float> feats(nLayerClusterFeatures_);
auto data = &feats.at(0);
fillLayerClusterFeatures(data,lc);
hitFeatures_.push_back(feats);
}
}

//return;
//add tracks LAST!
for(const auto& tr:tracks_){
std::vector<float> feats(nTrackFeatures_);
auto data = &feats.at(0);
fillTrackFeatures(data,tr);
hitFeatures_.push_back(feats);
}



}

WindowBase::particle_type WindowBase::pdgToParticleType(int pdgid)const{

Expand Down

0 comments on commit 9fc4fd5

Please sign in to comment.