Skip to content

Commit

Permalink
DeepTau - Do not call TF inference with empty grid
Browse files Browse the repository at this point in the history
  • Loading branch information
valsdav committed Mar 18, 2024
1 parent 28f76ff commit 6bb9279
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
5 changes: 4 additions & 1 deletion RecoTauTag/HLTProducers/src/L2TauTagNNProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,10 @@ void L2TauNNProducer::fillPatatracks(tensorflow::Tensor& cellGridMatrix,

std::vector<float> L2TauNNProducer::getTauScore(const tensorflow::Tensor& cellGridMatrix) {
std::vector<tensorflow::Tensor> pred_tensor;
tensorflow::run(L2cacheData_->session, {{inputTensorName_, cellGridMatrix}}, {outputTensorName_}, &pred_tensor);
// Check for empty input
if (cellGridMatrix.NumElements() != 0) {
tensorflow::run(L2cacheData_->session, {{inputTensorName_, cellGridMatrix}}, {outputTensorName_}, &pred_tensor);
}
const int nTau = cellGridMatrix.shape().dim_size(0);
std::vector<float> pred_vector(nTau);
for (int tau_idx = 0; tau_idx < nTau; ++tau_idx) {
Expand Down
14 changes: 13 additions & 1 deletion RecoTauTag/RecoTau/plugins/DeepTauId.cc
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ class DeepTauId : public DeepTauIdBase<DeepTauIdWrapper> {
{"outer_all_dropout_4/Identity"},
&pred_vector);
}

return pred_vector.at(0);
}

Expand All @@ -547,8 +548,10 @@ class DeepTauId : public DeepTauIdBase<DeepTauIdWrapper> {
bool is_inner) {
if (debug_level >= 2) {
std::cout << "<DeepTauId::createConvFeatures (is_inner = " << is_inner << ")>:" << std::endl;
std::cout << "number of valid cells = " << grid.num_valid_cells() << std::endl;
}
tensorflow::Tensor& convTensor = *convTensor_.at(is_inner);

eGammaTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
tensorflow::DT_FLOAT,
tensorflow::TensorShape{
Expand Down Expand Up @@ -605,8 +608,17 @@ class DeepTauId : public DeepTauIdBase<DeepTauIdWrapper> {
}
}
}
tensorflow::Tensor predTensor;
//check if at least one input is there to
//avoid calling TF with empty grid #TODO understand why the grid is empty
if (idx != 0) {
predTensor = getPartialPredictions(is_inner);
} else {
if (debug_level >= 2) {
std::cout << " no valid cells found, skipped TF evaluation" << std::endl;
}
}

const auto predTensor = getPartialPredictions(is_inner);
idx = 0;
for (int eta = -grid.maxEtaIndex(); eta <= grid.maxEtaIndex(); ++eta) {
for (int phi = -grid.maxPhiIndex(); phi <= grid.maxPhiIndex(); ++phi) {
Expand Down

0 comments on commit 6bb9279

Please sign in to comment.