Skip to content

Commit

Permalink
Merge pull request #45147 from mmusich/mm_fix_protect_L2TauTagNNProdu…
Browse files Browse the repository at this point in the history
…cerAlpaka_empty_inputs

`L2TauTagNNProducerAlpaka`: do not call TF inference with empty grid
  • Loading branch information
cmsbuild authored Jun 6, 2024
2 parents ebe68aa + 322acb9 commit a91b3f7
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions RecoTauTag/HLTProducers/src/L2TauTagNNProducerAlpaka.cc
Original file line number Diff line number Diff line change
Expand Up @@ -731,15 +731,19 @@ void L2TauNNProducerAlpaka::fillPatatracks(tensorflow::Tensor& cellGridMatrix,
}

std::vector<float> L2TauNNProducerAlpaka::getTauScore(const tensorflow::Tensor& cellGridMatrix) {
std::vector<tensorflow::Tensor> pred_tensor;
tensorflow::run(L2cacheData_->session, {{inputTensorName_, cellGridMatrix}}, {outputTensorName_}, &pred_tensor);
const int nTau = cellGridMatrix.shape().dim_size(0);
std::vector<float> pred_vector(nTau);
for (int tau_idx = 0; tau_idx < nTau; ++tau_idx) {
pred_vector[tau_idx] = pred_tensor[0].matrix<float>()(tau_idx, 0);
}
if (nTau == 0) {
return std::vector<float>();
} else {
std::vector<tensorflow::Tensor> pred_tensor;
tensorflow::run(L2cacheData_->session, {{inputTensorName_, cellGridMatrix}}, {outputTensorName_}, &pred_tensor);
std::vector<float> pred_vector(nTau);
for (int tau_idx = 0; tau_idx < nTau; ++tau_idx) {
pred_vector[tau_idx] = pred_tensor[0].matrix<float>()(tau_idx, 0);
}

return pred_vector;
return pred_vector;
}
}

void L2TauNNProducerAlpaka::produce(edm::Event& event, const edm::EventSetup& eventsetup) {
Expand Down

0 comments on commit a91b3f7

Please sign in to comment.