Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Template algorithms on track container frontend TrackContainer #3193

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of the Acts project.
//
// Copyright (C) 2020 CERN for the benefit of the Acts project
// Copyright (C) 2023-2024 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand All @@ -11,6 +11,7 @@
#include "Acts/EventData/MultiTrajectoryHelpers.hpp"
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/EventData/TrackContainerFrontendConcept.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Logger.hpp"

Expand Down Expand Up @@ -77,13 +78,11 @@ class GreedyAmbiguityResolution {
/// @param state An empty state object which is expected to be default constructed.
/// @param sourceLinkHash A functor to acquire a hash from a given source link.
/// @param sourceLinkEquality A functor to check equality of two source links.
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
typename source_link_equality_t>
void computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
State& state, source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const;
template <TrackContainerFrontend track_container_t,
typename source_link_hash_t, typename source_link_equality_t>
void computeInitialState(const track_container_t& tracks, State& state,
source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const;

/// Updates the state iteratively by evicting one track after the other until
/// the final state conditions are met.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

namespace Acts {

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
template <TrackContainerFrontend track_container_t, typename source_link_hash_t,
typename source_link_equality_t>
void GreedyAmbiguityResolution::computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
State& state, source_link_hash_t&& sourceLinkHash,
const track_container_t& tracks, State& state,
source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const {
auto measurementIndexMap =
std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include "Acts/Definitions/Units.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/EventData/TrackContainerFrontendConcept.hpp"
#include "Acts/EventData/TrackProxyConcept.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Logger.hpp"

Expand Down Expand Up @@ -116,16 +118,12 @@ class ScoreBasedAmbiguityResolution {
/// The optional cuts,weights and score are used to remove tracks that are not
/// good enough, based on some criteria. Users are free to add their own cuts
/// with the help of this struct.
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <TrackProxyConcept track_proxy_t>
struct OptionalCuts {
using OptionalFilter =
std::function<bool(const Acts::TrackProxy<track_container_t, traj_t,
holder_t, ReadOnly>&)>;
using OptionalFilter = std::function<bool(const track_proxy_t&)>;

using OptionalScoreModifier = std::function<void(
const Acts::TrackProxy<track_container_t, traj_t, holder_t, ReadOnly>&,
double&)>;
using OptionalScoreModifier =
std::function<void(const track_proxy_t&, double&)>;
std::vector<OptionalFilter> cuts = {};
std::vector<OptionalScoreModifier> weights = {};

Expand All @@ -146,12 +144,10 @@ class ScoreBasedAmbiguityResolution {
/// @param sourceLinkEquality is the equality function for the source links
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @return a vector of the initial state of the tracks
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
typename source_link_equality_t>
template <TrackContainerFrontend track_container_t,
typename source_link_hash_t, typename source_link_equality_t>
std::vector<std::vector<MeasurementInfo>> computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
source_link_hash_t sourceLinkHash,
const track_container_t& tracks, source_link_hash_t sourceLinkHash,
source_link_equality_t sourceLinkEquality,
std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors) const;

Expand All @@ -161,12 +157,11 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @param optionalCuts is the user defined optional cuts to be applied.
/// @return a vector of scores for each track
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <TrackContainerFrontend track_container_t>
std::vector<double> simpleScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

/// Compute the score of each track based on the ambiguity function.
Expand All @@ -175,12 +170,11 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @param optionalCuts is the user defined optional cuts to be applied.
/// @return a vector of scores for each track
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <TrackContainerFrontend track_container_t>
std::vector<double> ambiguityScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

/// Remove hits that are not good enough for each track and removes tracks
Expand All @@ -205,13 +199,12 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the map of detector id to trackFeatures for each track
/// @param optionalCuts is the optional cuts to be applied
/// @return a vector of IDs of the tracks we want to keep
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <TrackContainerFrontend track_container_t>
std::vector<int> solveAmbiguity(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<MeasurementInfo>>& measurementsPerTrack,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "Acts/AmbiguityResolution/ScoreBasedAmbiguityResolution.hpp"
#include "Acts/Definitions/Units.hpp"
#include "Acts/EventData/TrackContainerFrontendConcept.hpp"
#include "Acts/Utilities/VectorHelpers.hpp"

#include <unordered_map>
Expand All @@ -20,13 +21,11 @@ inline const Logger& ScoreBasedAmbiguityResolution::logger() const {
return *m_logger;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
template <TrackContainerFrontend track_container_t, typename source_link_hash_t,
typename source_link_equality_t>
std::vector<std::vector<ScoreBasedAmbiguityResolution::MeasurementInfo>>
ScoreBasedAmbiguityResolution::computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
source_link_hash_t sourceLinkHash,
const track_container_t& tracks, source_link_hash_t sourceLinkHash,
source_link_equality_t sourceLinkEquality,
std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors) const {
auto MeasurementIndexMap =
Expand Down Expand Up @@ -98,12 +97,11 @@ ScoreBasedAmbiguityResolution::computeInitialState(
return measurementsPerTrack;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>
template <TrackContainerFrontend track_container_t>
std::vector<double> Acts::ScoreBasedAmbiguityResolution::simpleScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
std::vector<double> trackScore;
trackScore.reserve(tracks.size());
Expand Down Expand Up @@ -248,12 +246,11 @@ std::vector<double> Acts::ScoreBasedAmbiguityResolution::simpleScore(
return trackScore;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>
template <TrackContainerFrontend track_container_t>
std::vector<double> Acts::ScoreBasedAmbiguityResolution::ambiguityScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
std::vector<double> trackScore;
trackScore.reserve(tracks.size());
Expand Down Expand Up @@ -425,13 +422,13 @@ std::vector<double> Acts::ScoreBasedAmbiguityResolution::ambiguityScore(

return trackScore;
}
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>

template <TrackContainerFrontend track_container_t>
std::vector<int> Acts::ScoreBasedAmbiguityResolution::solveAmbiguity(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<MeasurementInfo>>& measurementsPerTrack,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
ACTS_INFO("Number of tracks before Ambiguty Resolution: " << tracks.size());
// vector of trackFeaturesVectors. where each trackFeaturesVector contains the
Expand Down
33 changes: 33 additions & 0 deletions Core/include/Acts/EventData/TrackContainerFrontendConcept.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// This file is part of the Acts project.
//
// Copyright (C) 2024 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/EventData/TrackContainerBackendConcept.hpp"
#include "Acts/EventData/Types.hpp"

#include <concepts>

namespace Acts {

template <typename T>
concept TrackContainerFrontend = requires() {
andiwand marked this conversation as resolved.
Show resolved Hide resolved
typename T::ReadOnly;

requires std::same_as<typename T::IndexType, TrackIndexType>;

requires TrackContainerBackend<typename T::TrackContainerBackend>;
requires CommonMultiTrajectoryBackend<typename T::TrackStateContainerBackend>;

typename T::TrackProxy;
typename T::ConstTrackProxy;
typename T::TrackStateProxy;
typename T::ConstTrackStateProxy;
};

} // namespace Acts
13 changes: 4 additions & 9 deletions Core/include/Acts/EventData/TrackHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,17 @@

#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/EventData/TrackProxyConcept.hpp"

namespace Acts {

/// Helper function to calculate a number of track level quantities and store
/// them on the track itself
/// @note The input track needs to be mutable, so @c ReadOnly=false
/// @tparam track_container_t the track container backend
/// @tparam track_state_container_t the track state container backend
/// @tparam holder_t the holder type for the track container backends
/// @tparam track_proxy_t The type of the track proxy
/// @param track A mutable track proxy to operate on
template <typename track_container_t, typename track_state_container_t,
template <typename> class holder_t>
void calculateTrackQuantities(
Acts::TrackProxy<track_container_t, track_state_container_t, holder_t,
false>
track) {
template <TrackProxyConcept track_proxy_t>
void calculateTrackQuantities(track_proxy_t track) {
track.chi2() = 0;
track.nDoF() = 0;

Expand Down
3 changes: 2 additions & 1 deletion Core/include/Acts/EventData/TrackProxy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "Acts/EventData/ParticleHypothesis.hpp"
#include "Acts/EventData/TrackContainerBackendConcept.hpp"
#include "Acts/EventData/TrackParameters.hpp"
#include "Acts/EventData/TrackProxyConcept.hpp"
#include "Acts/EventData/TrackStatePropMask.hpp"
#include "Acts/Utilities/HashedString.hpp"
#include "Acts/Utilities/UnitVectors.hpp"
Expand Down Expand Up @@ -668,7 +669,7 @@ class TrackProxy {
/// @tparam track_proxy_t the other track proxy's type
/// @param other The track proxy
/// @param copyTrackStates Copy the track state sequence from @p other
template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
void copyFrom(const track_proxy_t& other,
bool copyTrackStates = true) requires(!ReadOnly) {
// @TODO: Add constraint on which track proxies are allowed,
Expand Down
30 changes: 30 additions & 0 deletions Core/include/Acts/EventData/TrackProxyConcept.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// This file is part of the Acts project.
//
// Copyright (C) 2024 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/EventData/MultiTrajectoryBackendConcept.hpp"
#include "Acts/EventData/TrackContainerBackendConcept.hpp"
#include "Acts/EventData/Types.hpp"

#include <concepts>

namespace Acts {

template <typename T>
concept TrackProxyConcept = requires() {
andiwand marked this conversation as resolved.
Show resolved Hide resolved
typename T::ReadOnly;
andiwand marked this conversation as resolved.
Show resolved Hide resolved

requires TrackContainerBackend<typename T::Container>;

requires CommonMultiTrajectoryBackend<typename T::Trajectory>;

requires std::same_as<typename T::IndexType, TrackIndexType>;
};

} // namespace Acts
9 changes: 5 additions & 4 deletions Core/include/Acts/TrackFinding/TrackSelector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include "Acts/EventData/TrackProxyConcept.hpp"
#include "Acts/EventData/TrackStateType.hpp"
#include "Acts/Geometry/GeometryHierarchyMap.hpp"
#include "Acts/Geometry/GeometryIdentifier.hpp"
Expand Down Expand Up @@ -37,7 +38,7 @@ class TrackSelector {

boost::container::small_vector<CounterElement, 4> counters;

template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
bool isValidTrack(const track_proxy_t& track) const;

void addCounter(const std::vector<GeometryIdentifier>& identifiers,
Expand Down Expand Up @@ -226,7 +227,7 @@ class TrackSelector {
/// @tparam track_proxy_t is the type of the track proxy
/// @param track is the track proxy
/// @return true if the track is valid
template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
bool isValidTrack(const track_proxy_t& track) const;

/// Get readonly access to the config parameters
Expand Down Expand Up @@ -387,7 +388,7 @@ void TrackSelector::selectTracks(const input_tracks_t& inputTracks,
}
}

template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
bool TrackSelector::isValidTrack(const track_proxy_t& track) const {
auto checkMin = [](auto x, auto min) { return min <= x; };
auto checkMax = [](auto x, auto max) { return x <= max; };
Expand Down Expand Up @@ -474,7 +475,7 @@ inline TrackSelector::TrackSelector(
inline TrackSelector::TrackSelector(const Config& config)
: TrackSelector{EtaBinnedConfig{config}} {}

template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
bool TrackSelector::MeasurementCounter::isValidTrack(
const track_proxy_t& track) const {
// No hit cuts, accept everything
Expand Down
Loading
Loading