diff --git a/RecoTracker/DisplacedRegionalTracking/plugins/DisplacedRegionSeedingVertexProducer.cc b/RecoTracker/DisplacedRegionalTracking/plugins/DisplacedRegionSeedingVertexProducer.cc index 941cb055d9c8b..d45d904b9ea67 100644 --- a/RecoTracker/DisplacedRegionalTracking/plugins/DisplacedRegionSeedingVertexProducer.cc +++ b/RecoTracker/DisplacedRegionalTracking/plugins/DisplacedRegionSeedingVertexProducer.cc @@ -1,4 +1,3 @@ -#include #include #include #include @@ -46,6 +45,7 @@ class DisplacedRegionSeedingVertexProducer : public edm::global::EDProducer<> { const double nearThreshold_; const double farThreshold_; const double discriminatorCut_; + const unsigned int maxPseudoROIs_; const vector input_names_; const vector output_names_; @@ -63,6 +63,7 @@ DisplacedRegionSeedingVertexProducer::DisplacedRegionSeedingVertexProducer(const nearThreshold_(cfg.getParameter("nearThreshold")), farThreshold_(cfg.getParameter("farThreshold")), discriminatorCut_(cfg.getParameter("discriminatorCut")), + maxPseudoROIs_(cfg.getParameter("maxPseudoROIs")), input_names_(cfg.getParameter >("input_names")), output_names_(cfg.getParameter >("output_names")), beamSpotToken_(consumes(cfg.getParameter("beamSpot"))), @@ -93,16 +94,24 @@ void DisplacedRegionSeedingVertexProducer::produce(edm::StreamID streamID, const auto &trackClusters = event.get(trackClustersToken_); // Initialize distances. - list pseudoROIs; - list distances; + vector pseudoROIs; + pseudoROIs.reserve(std::min(maxPseudoROIs_, trackClusters.size())); + vector distances; const double minTrackClusterRadius = minRadius_ - rParam_; for (unsigned i = 0; i < trackClusters.size(); i++) { const reco::VertexCompositeCandidate &trackCluster = trackClusters[i]; const math::XYZVector x(trackCluster.vertex()); if (minRadius_ < 0.0 || minTrackClusterRadius < 0.0 || (x - bs).rho() > minTrackClusterRadius) pseudoROIs.emplace_back(&trackClusters.at(i), rParam_); + if (pseudoROIs.size() == maxPseudoROIs_) { + edm::LogWarning("DisplacedRegionSeedingVertexProducer") + << "Truncated list of pseudoROIs at " << maxPseudoROIs_ << " out of " << trackClusters.size() + << " possible track clusters."; + break; + } } if (pseudoROIs.size() > 1) { + distances.reserve(pseudoROIs.size() * std::max(1.0, pseudoROIs.size() * 0.05)); DisplacedVertexClusterItr secondToLast = pseudoROIs.end(); secondToLast--; for (DisplacedVertexClusterItr i = pseudoROIs.begin(); i != secondToLast; i++) { @@ -119,11 +128,25 @@ void DisplacedRegionSeedingVertexProducer::produce(edm::StreamID streamID, } } + auto itBegin = distances.begin(); + auto itLast = distances.end(); + // Do clustering. - while (!distances.empty()) { - const auto comp = [](const Distance &a, const Distance &b) { return a.distance2() <= b.distance2(); }; - distances.sort(comp); - DistanceItr dBest = distances.begin(); + while (itBegin != itLast) { + //find the lowest distance. Lots of repeatitive calculations done here + //as from loop iteration to loop iteration only sqrt(distances.size()) distances + //need to be recomputed (those involving best_i + //but this is much better than sorting distances.. + DistanceItr dBest = itBegin; + double distanceBest = dBest->distance2(); + + for (auto i = itBegin; i != itLast; i++) { + if (distanceBest > i->distance2()) { + dBest = i; + distanceBest = i->distance2(); + } + } + if (dBest->distance2() > rParam_ * rParam_) break; @@ -133,11 +156,12 @@ void DisplacedRegionSeedingVertexProducer::produce(edm::StreamID streamID, const auto distancePred = [](const Distance &a) { return (!a.entities().first->valid() || !a.entities().second->valid()); }; - const auto pseudoROIPred = [](const DisplacedVertexCluster &a) { return !a.valid(); }; - distances.remove_if(distancePred); - pseudoROIs.remove_if(pseudoROIPred); + itLast = std::remove_if(itBegin, itLast, distancePred); } + const auto pseudoROIPred = [](const DisplacedVertexCluster &a) { return !a.valid(); }; + auto remove_invalid = std::remove_if(pseudoROIs.begin(), pseudoROIs.end(), pseudoROIPred); + // Remove invalid ROIs. const auto roiPred = [&](const DisplacedVertexCluster &roi) { if (!roi.valid()) @@ -150,7 +174,7 @@ void DisplacedRegionSeedingVertexProducer::produce(edm::StreamID streamID, return true; return false; }; - pseudoROIs.remove_if(roiPred); + auto remove_pred = std::remove_if(pseudoROIs.begin(), remove_invalid, roiPred); auto nearRegionsOfInterest = make_unique >(); auto farRegionsOfInterest = make_unique >(); @@ -158,7 +182,8 @@ void DisplacedRegionSeedingVertexProducer::produce(edm::StreamID streamID, constexpr std::array errorA{{1.0, 0.0, 1.0, 0.0, 0.0, 1.0}}; static const reco::Vertex::Error errorRegion(errorA.begin(), errorA.end(), true, true); - for (const auto &roi : pseudoROIs) { + for (auto it = pseudoROIs.begin(); it != remove_pred; ++it) { + auto const &roi = *it; const auto &x(roi.centerOfMass()); if ((x - bs).rho() < nearThreshold_) nearRegionsOfInterest->emplace_back(reco::Vertex::Point(roi.centerOfMass()), errorRegion); @@ -180,6 +205,7 @@ void DisplacedRegionSeedingVertexProducer::fillDescriptions(edm::ConfigurationDe desc.add("discriminatorCut", -1.0); desc.add >("input_names", {"phi_0", "phi_1"}); desc.add >("output_names", {"model_5/activation_10/Softmax"}); + desc.add("maxPseudoROIs", 10000); desc.addUntracked("nThreads", 1); desc.add( "graph_path", diff --git a/RecoTracker/DisplacedRegionalTracking/plugins/DisplacedVertexCluster.h b/RecoTracker/DisplacedRegionalTracking/plugins/DisplacedVertexCluster.h index c66a0c44d00ae..be35be235ef64 100644 --- a/RecoTracker/DisplacedRegionalTracking/plugins/DisplacedVertexCluster.h +++ b/RecoTracker/DisplacedRegionalTracking/plugins/DisplacedVertexCluster.h @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include "DataFormats/Common/interface/View.h" @@ -11,7 +11,7 @@ #include "DataFormats/Math/interface/Vector3D.h" class DisplacedVertexCluster; -typedef std::list::iterator DisplacedVertexClusterItr; +typedef std::vector::iterator DisplacedVertexClusterItr; class DisplacedVertexCluster { public: @@ -53,7 +53,7 @@ class DisplacedVertexCluster { std::pair entities_; }; - typedef std::list::iterator DistanceItr; + typedef std::vector::iterator DistanceItr; private: bool valid_;