Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use an indexed sort for expensive comparisons in MuonShowerInformationFiller #40675

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Use an indexed sort for expensive comparison functions
  • Loading branch information
Dan Riley authored and Dan Riley committed Feb 2, 2023
commit 569e19c26ef873a5f37dd9baca9057e37682a717
55 changes: 33 additions & 22 deletions RecoMuon/MuonIdentification/interface/MuonShowerInformationFiller.h
Original file line number Diff line number Diff line change
@@ -108,14 +108,23 @@ class MuonShowerInformationFiller {
GlobalPoint crossingPoint(const GlobalPoint&, const GlobalPoint&, const Disk&) const;
std::vector<const GeomDet*> dtPositionToDets(const GlobalPoint&) const;
std::vector<const GeomDet*> cscPositionToDets(const GlobalPoint&) const;
MuonRecHitContainer findPerpCluster(MuonRecHitContainer& muonRecHits) const;
TransientTrackingRecHit::ConstRecHitContainer findThetaCluster(TransientTrackingRecHit::ConstRecHitContainer&,
MuonRecHitContainer findPerpCluster(const MuonRecHitContainer& muonRecHits) const;
TransientTrackingRecHit::ConstRecHitContainer findThetaCluster(const TransientTrackingRecHit::ConstRecHitContainer&,
const GlobalPoint&) const;
TransientTrackingRecHit::ConstRecHitContainer hitsFromSegments(const GeomDet*,
edm::Handle<DTRecSegment4DCollection>,
edm::Handle<CSCSegmentCollection>) const;
std::vector<const GeomDet*> getCompatibleDets(const reco::Track&) const;

struct MagTransform {
MagTransform(const GlobalPoint& point) : thePoint(point) {}
double operator()(const GlobalPoint& p) const { return (p - thePoint).mag(); }
double operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& hit) const {
return (hit->globalPosition() - thePoint).mag();
}
GlobalPoint thePoint;
};

struct LessMag {
LessMag(const GlobalPoint& point) : thePoint(point) {}
bool operator()(const GlobalPoint& lhs, const GlobalPoint& rhs) const {
@@ -128,22 +137,10 @@ class MuonShowerInformationFiller {
GlobalPoint thePoint;
};

struct LessDPhi {
LessDPhi(const GlobalPoint& point) : thePoint(point) {}
bool operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& lhs,
const MuonTransientTrackingRecHit::MuonRecHitPointer& rhs) const {
return deltaPhi(lhs->globalPosition().barePhi(), thePoint.barePhi()) <
deltaPhi(rhs->globalPosition().barePhi(), thePoint.barePhi());
}
GlobalPoint thePoint;
};

struct AbsLessDPhi {
AbsLessDPhi(const GlobalPoint& point) : thePoint(point) {}
bool operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& lhs,
const MuonTransientTrackingRecHit::MuonRecHitPointer& rhs) const {
return (fabs(deltaPhi(lhs->globalPosition().barePhi(), thePoint.barePhi())) <
fabs(deltaPhi(rhs->globalPosition().barePhi(), thePoint.barePhi())));
struct AbsDThetaTransform {
AbsDThetaTransform(const GlobalPoint& point) : thePoint(point) {}
double operator()(const TransientTrackingRecHit::ConstRecHitPointer& hit) const {
return std::fabs(hit->globalPosition().bareTheta() - thePoint.bareTheta());
}
GlobalPoint thePoint;
};
@@ -158,6 +155,12 @@ class MuonShowerInformationFiller {
GlobalPoint thePoint;
};

struct PhiTransform {
double operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& hit) const {
return hit->globalPosition().barePhi();
}
};

struct LessPhi {
LessPhi() : thePoint(0, 0, 0) {}
bool operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& lhs,
@@ -167,22 +170,30 @@ class MuonShowerInformationFiller {
GlobalPoint thePoint;
};

struct PerpTransform {
double operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& hit) const {
return hit->globalPosition().perp();
}
};

struct LessPerp {
LessPerp() : thePoint(0, 0, 0) {}
bool operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& lhs,
const MuonTransientTrackingRecHit::MuonRecHitPointer& rhs) const {
return (lhs->globalPosition().perp() < rhs->globalPosition().perp());
}
GlobalPoint thePoint;
};

struct AbsMagTransform {
double operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& hit) const {
return hit->globalPosition().mag();
}
};

struct LessAbsMag {
LessAbsMag() : thePoint(0, 0, 0) {}
bool operator()(const MuonTransientTrackingRecHit::MuonRecHitPointer& lhs,
const MuonTransientTrackingRecHit::MuonRecHitPointer& rhs) const {
return (lhs->globalPosition().mag() < rhs->globalPosition().mag());
}
GlobalPoint thePoint;
};

std::string category_;
80 changes: 61 additions & 19 deletions RecoMuon/MuonIdentification/src/MuonShowerInformationFiller.cc
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
#include <memory>
#include <algorithm>
#include <iostream>
#include <numeric>

// user include files
#include "FWCore/Framework/interface/Event.h"
@@ -62,6 +63,47 @@
using namespace std;
using namespace edm;

namespace {
//
// Sort function optimized for the case where the desired comparison
// is extravagantly expensive. To use, replace
//
// auto LessPhi = [](const auto& lhs, const auto& rhs) {
// return (lhs->globalPosition().barePhi() < rhs->globalPosition().barePhi())
// }
// stable_sort(muonRecHits[stat].begin(), muonRecHits[stat].end(), LessPhi);
//
// with
//
// auto calcPhi = [](const auto& hit) { return hit->globalPosition().barePhi(); };
// muonRecHits[stat] = sort_all_indexed(muonRecHits[stat], std::less(), calcPhi);
//
// This calculates the values to be compared once in O(n) time, instead of
// O(n*log(n)) times in the comparison function.
//
template <typename RandomAccessSequence, typename Predicate, typename Transform>
RandomAccessSequence sort_all_indexed(const RandomAccessSequence& s, Predicate p, Transform t) {
std::vector<size_t> idx(s.size());
std::iota(idx.begin(), idx.end(), 0);

// fill the cache of the values to be sorted
using valueCacheType = std::invoke_result_t<decltype(t), typename RandomAccessSequence::value_type>;
std::vector<valueCacheType> valcache(s.size());
std::transform(s.begin(), s.end(), valcache.begin(), t);

// sort the indices of the value cache
auto idxComp = [&valcache, p](auto i1, auto i2) { return p(valcache[i1], valcache[i2]); };
std::stable_sort(idx.begin(), idx.end(), idxComp);

// fill the sorted output vector
RandomAccessSequence r(s.size());
for (size_t i = 0; i < s.size(); ++i) {
r[i] = s[idx[i]];
}
return r;
}
} // namespace

//
// Constructor
//
@@ -273,18 +315,18 @@ TransientTrackingRecHit::ConstRecHitContainer MuonShowerInformationFiller::hitsF
//

TransientTrackingRecHit::ConstRecHitContainer MuonShowerInformationFiller::findThetaCluster(
TransientTrackingRecHit::ConstRecHitContainer& muonRecHits, const GlobalPoint& refpoint) const {
if (muonRecHits.empty())
return muonRecHits;
const TransientTrackingRecHit::ConstRecHitContainer& muonRecHitsIn, const GlobalPoint& refpoint) const {
if (muonRecHitsIn.empty())
return muonRecHitsIn;

//clustering step by theta
float step = 0.05;
TransientTrackingRecHit::ConstRecHitContainer result;

stable_sort(muonRecHits.begin(), muonRecHits.end(), AbsLessDTheta(refpoint));
auto muonRecHitsTmp = sort_all_indexed(muonRecHitsIn, std::less(), AbsDThetaTransform(refpoint));

for (TransientTrackingRecHit::ConstRecHitContainer::const_iterator ihit = muonRecHits.begin();
ihit != muonRecHits.end() - 1;
for (TransientTrackingRecHit::ConstRecHitContainer::const_iterator ihit = muonRecHitsTmp.begin();
ihit != muonRecHitsTmp.end() - 1;
++ihit) {
if (fabs((*(ihit + 1))->globalPosition().theta() - (*ihit)->globalPosition().theta()) < step) {
result.push_back(*ihit);
@@ -300,24 +342,24 @@ TransientTrackingRecHit::ConstRecHitContainer MuonShowerInformationFiller::findT
//Used to treat overlap region
//
MuonTransientTrackingRecHit::MuonRecHitContainer MuonShowerInformationFiller::findPerpCluster(
MuonTransientTrackingRecHit::MuonRecHitContainer& muonRecHits) const {
if (muonRecHits.empty())
return muonRecHits;
const MuonTransientTrackingRecHit::MuonRecHitContainer& muonRecHitsIn) const {
if (muonRecHitsIn.empty())
return muonRecHitsIn;

stable_sort(muonRecHits.begin(), muonRecHits.end(), LessPerp());
auto muonRecHitsTmp = sort_all_indexed(muonRecHitsIn, std::less(), PerpTransform());

MuonTransientTrackingRecHit::MuonRecHitContainer::const_iterator seedhit =
min_element(muonRecHits.begin(), muonRecHits.end(), LessPerp());
min_element(muonRecHitsTmp.begin(), muonRecHitsTmp.end(), LessPerp());

MuonTransientTrackingRecHit::MuonRecHitContainer::const_iterator ihigh = seedhit;
MuonTransientTrackingRecHit::MuonRecHitContainer::const_iterator ilow = seedhit;

float step = 0.1;
while (ihigh != muonRecHits.end() - 1 &&
while (ihigh != muonRecHitsTmp.end() - 1 &&
(fabs((*(ihigh + 1))->globalPosition().perp() - (*ihigh)->globalPosition().perp()) < step)) {
ihigh++;
}
while (ilow != muonRecHits.begin() &&
while (ilow != muonRecHitsTmp.begin() &&
(fabs((*ilow)->globalPosition().perp() - (*(ilow - 1))->globalPosition().perp()) < step)) {
ilow--;
}
@@ -358,7 +400,7 @@ vector<const GeomDet*> MuonShowerInformationFiller::getCompatibleDets(const reco
allCrossingPoints.push_back(xPoint);
}

stable_sort(allCrossingPoints.begin(), allCrossingPoints.end(), LessMag(innerPos));
allCrossingPoints = sort_all_indexed(allCrossingPoints, std::less(), MagTransform(innerPos));

vector<const GeomDet*> tempDT;

@@ -382,7 +424,7 @@ vector<const GeomDet*> MuonShowerInformationFiller::getCompatibleDets(const reco
(!(xPoint.y() == 0 && xPoint.x() == 0 && xPoint.z() == 0)))
allCrossingPoints.push_back(xPoint);
}
stable_sort(allCrossingPoints.begin(), allCrossingPoints.end(), LessMag(innerPos));
allCrossingPoints = sort_all_indexed(allCrossingPoints, std::less(), MagTransform(innerPos));

vector<const GeomDet*> tempCSC;
for (vector<GlobalPoint>::const_iterator ipos = allCrossingPoints.begin(); ipos != allCrossingPoints.end(); ++ipos) {
@@ -836,7 +878,7 @@ void MuonShowerInformationFiller::fillHitsByStation(const reco::Muon& muon) {

// Cluster seeds by global position phi. Best cluster is chosen to give greatest dphi
// Sort by phi (complexity = NLogN with enough memory, or = NLog^2N for insufficient mem)
stable_sort(muonRecHits[stat].begin(), muonRecHits[stat].end(), LessPhi());
muonRecHits[stat] = sort_all_indexed(muonRecHits[stat], std::less(), PhiTransform());

// Search for gaps (complexity = N)
std::vector<size_t> clUppers;
@@ -906,23 +948,23 @@ void MuonShowerInformationFiller::fillHitsByStation(const reco::Muon& muon) {

//fill showerTs
if (!muonRecHitsPhiBest.empty()) {
muonRecHits[stat] = muonRecHitsPhiBest;
stable_sort(muonRecHits[stat].begin(), muonRecHits[stat].end(), LessAbsMag());
muonRecHits[stat] = sort_all_indexed(muonRecHitsPhiBest, std::less(), AbsMagTransform());
GlobalPoint refpoint = muonRecHits[stat].front()->globalPosition();
theStationShowerTSize.at(stat) = refpoint.mag() * dphimax;
}

//for theta
if (!muonCorrelatedHits.at(stat).empty()) {
float dthetamax = 0;
auto muonCorrelatedHitsTmp{muonCorrelatedHits.at(stat)}; // findThetaCluster() sorts its argument
for (TransientTrackingRecHit::ConstRecHitContainer::const_iterator iseed = muonCorrelatedHits.at(stat).begin();
iseed != muonCorrelatedHits.at(stat).end();
++iseed) {
if (!(*iseed)->isValid())
continue;
GlobalPoint refpoint = (*iseed)->globalPosition(); //starting from the one with smallest value of phi
muonRecHitsThetaTemp.clear();
muonRecHitsThetaTemp = findThetaCluster(muonCorrelatedHits.at(stat), refpoint);
muonRecHitsThetaTemp = findThetaCluster(muonCorrelatedHitsTmp, refpoint);
if (muonRecHitsThetaTemp.size() > 1) {
float dtheta = fabs((float)muonRecHitsThetaTemp.back()->globalPosition().theta() -
(float)muonRecHitsThetaTemp.front()->globalPosition().theta());