Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Template algorithms on track container frontend TrackContainer #3193

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "Acts/EventData/MultiTrajectoryHelpers.hpp"
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/EventData/TrackContainerFrontendConcept.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Logger.hpp"

Expand Down Expand Up @@ -77,13 +78,11 @@ class GreedyAmbiguityResolution {
/// @param state An empty state object which is expected to be default constructed.
/// @param sourceLinkHash A functor to acquire a hash from a given source link.
/// @param sourceLinkEquality A functor to check equality of two source links.
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
typename source_link_equality_t>
void computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
State& state, source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const;
template <TrackContainerFrontend track_container_t,
typename source_link_hash_t, typename source_link_equality_t>
void computeInitialState(const track_container_t& tracks, State& state,
source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const;

/// Updates the state iteratively by evicting one track after the other until
/// the final state conditions are met.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@

namespace Acts {

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
template <TrackContainerFrontend track_container_t, typename source_link_hash_t,
typename source_link_equality_t>
void GreedyAmbiguityResolution::computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
State& state, source_link_hash_t&& sourceLinkHash,
const track_container_t& tracks, State& state,
source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const {
auto measurementIndexMap =
std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include "Acts/Definitions/Units.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/EventData/TrackContainerFrontendConcept.hpp"
#include "Acts/EventData/TrackProxyConcept.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Logger.hpp"

Expand Down Expand Up @@ -116,16 +118,12 @@ class ScoreBasedAmbiguityResolution {
/// The optional cuts,weights and score are used to remove tracks that are not
/// good enough, based on some criteria. Users are free to add their own cuts
/// with the help of this struct.
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <TrackProxyConcept track_proxy_t>
struct OptionalCuts {
using OptionalFilter =
std::function<bool(const Acts::TrackProxy<track_container_t, traj_t,
holder_t, ReadOnly>&)>;
using OptionalFilter = std::function<bool(const track_proxy_t&)>;

using OptionalScoreModifier = std::function<void(
const Acts::TrackProxy<track_container_t, traj_t, holder_t, ReadOnly>&,
double&)>;
using OptionalScoreModifier =
std::function<void(const track_proxy_t&, double&)>;
std::vector<OptionalFilter> cuts = {};
std::vector<OptionalScoreModifier> weights = {};

Expand All @@ -146,12 +144,10 @@ class ScoreBasedAmbiguityResolution {
/// @param sourceLinkEquality is the equality function for the source links
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @return a vector of the initial state of the tracks
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
typename source_link_equality_t>
template <TrackContainerFrontend track_container_t,
typename source_link_hash_t, typename source_link_equality_t>
std::vector<std::vector<MeasurementInfo>> computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
source_link_hash_t sourceLinkHash,
const track_container_t& tracks, source_link_hash_t sourceLinkHash,
source_link_equality_t sourceLinkEquality,
std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors) const;

Expand All @@ -161,12 +157,11 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @param optionalCuts is the user defined optional cuts to be applied.
/// @return a vector of scores for each track
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <TrackContainerFrontend track_container_t>
std::vector<double> simpleScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

/// Compute the score of each track based on the ambiguity function.
Expand All @@ -175,12 +170,11 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @param optionalCuts is the user defined optional cuts to be applied.
/// @return a vector of scores for each track
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <TrackContainerFrontend track_container_t>
std::vector<double> ambiguityScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

/// Remove hits that are not good enough for each track and removes tracks
Expand All @@ -205,13 +199,12 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the map of detector id to trackFeatures for each track
/// @param optionalCuts is the optional cuts to be applied
/// @return a vector of IDs of the tracks we want to keep
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <TrackContainerFrontend track_container_t>
std::vector<int> solveAmbiguity(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<MeasurementInfo>>& measurementsPerTrack,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "Acts/AmbiguityResolution/ScoreBasedAmbiguityResolution.hpp"
#include "Acts/Definitions/Units.hpp"
#include "Acts/EventData/TrackContainerFrontendConcept.hpp"
#include "Acts/Utilities/VectorHelpers.hpp"

#include <unordered_map>
Expand All @@ -20,13 +21,11 @@ inline const Logger& ScoreBasedAmbiguityResolution::logger() const {
return *m_logger;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
template <TrackContainerFrontend track_container_t, typename source_link_hash_t,
typename source_link_equality_t>
std::vector<std::vector<ScoreBasedAmbiguityResolution::MeasurementInfo>>
ScoreBasedAmbiguityResolution::computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
source_link_hash_t sourceLinkHash,
const track_container_t& tracks, source_link_hash_t sourceLinkHash,
source_link_equality_t sourceLinkEquality,
std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors) const {
auto MeasurementIndexMap =
Expand Down Expand Up @@ -98,12 +97,11 @@ ScoreBasedAmbiguityResolution::computeInitialState(
return measurementsPerTrack;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>
template <TrackContainerFrontend track_container_t>
std::vector<double> Acts::ScoreBasedAmbiguityResolution::simpleScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
std::vector<double> trackScore;
trackScore.reserve(tracks.size());
Expand Down Expand Up @@ -248,12 +246,11 @@ std::vector<double> Acts::ScoreBasedAmbiguityResolution::simpleScore(
return trackScore;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>
template <TrackContainerFrontend track_container_t>
std::vector<double> Acts::ScoreBasedAmbiguityResolution::ambiguityScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
std::vector<double> trackScore;
trackScore.reserve(tracks.size());
Expand Down Expand Up @@ -425,13 +422,13 @@ std::vector<double> Acts::ScoreBasedAmbiguityResolution::ambiguityScore(

return trackScore;
}
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>

template <TrackContainerFrontend track_container_t>
std::vector<int> Acts::ScoreBasedAmbiguityResolution::solveAmbiguity(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<MeasurementInfo>>& measurementsPerTrack,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
ACTS_INFO("Number of tracks before Ambiguty Resolution: " << tracks.size());
// vector of trackFeaturesVectors. where each trackFeaturesVector contains the
Expand Down
33 changes: 33 additions & 0 deletions Core/include/Acts/EventData/TrackContainerFrontendConcept.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// This file is part of the Acts project.
//
// Copyright (C) 2024 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/EventData/TrackContainerBackendConcept.hpp"
#include "Acts/EventData/Types.hpp"

#include <concepts>

namespace Acts {

template <typename T>
concept TrackContainerFrontend = requires() {
andiwand marked this conversation as resolved.
Show resolved Hide resolved
{ T::ReadOnly } -> std::same_as<const bool &>;

requires std::same_as<typename T::IndexType, TrackIndexType>;

requires TrackContainerBackend<typename T::TrackContainerBackend>;
requires CommonMultiTrajectoryBackend<typename T::TrackStateContainerBackend>;

typename T::TrackProxy;
typename T::ConstTrackProxy;
typename T::TrackStateProxy;
typename T::ConstTrackStateProxy;
};

} // namespace Acts
3 changes: 2 additions & 1 deletion Core/include/Acts/EventData/TrackProxy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "Acts/EventData/ParticleHypothesis.hpp"
#include "Acts/EventData/TrackContainerBackendConcept.hpp"
#include "Acts/EventData/TrackParameters.hpp"
#include "Acts/EventData/TrackProxyConcept.hpp"
#include "Acts/EventData/TrackStatePropMask.hpp"
#include "Acts/Utilities/HashedString.hpp"
#include "Acts/Utilities/UnitVectors.hpp"
Expand Down Expand Up @@ -682,7 +683,7 @@ class TrackProxy {
/// @tparam track_proxy_t the other track proxy's type
/// @param other The track proxy
/// @param copyTrackStates Copy the track state sequence from @p other
template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
void copyFrom(const track_proxy_t& other, bool copyTrackStates = true)
requires(!ReadOnly)
{
Expand Down
30 changes: 30 additions & 0 deletions Core/include/Acts/EventData/TrackProxyConcept.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// This file is part of the Acts project.
//
// Copyright (C) 2024 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/EventData/MultiTrajectoryBackendConcept.hpp"
#include "Acts/EventData/TrackContainerBackendConcept.hpp"
#include "Acts/EventData/Types.hpp"

#include <concepts>

namespace Acts {

template <typename T>
concept TrackProxyConcept = requires() {
andiwand marked this conversation as resolved.
Show resolved Hide resolved
{ T::ReadOnly } -> std::same_as<const bool &>;

requires TrackContainerBackend<typename T::Container>;

requires CommonMultiTrajectoryBackend<typename T::Trajectory>;

requires std::same_as<typename T::IndexType, TrackIndexType>;
};

} // namespace Acts
9 changes: 5 additions & 4 deletions Core/include/Acts/TrackFinding/TrackSelector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include "Acts/EventData/TrackProxyConcept.hpp"
#include "Acts/EventData/TrackStateType.hpp"
#include "Acts/Geometry/GeometryHierarchyMap.hpp"
#include "Acts/Geometry/GeometryIdentifier.hpp"
Expand Down Expand Up @@ -37,7 +38,7 @@ class TrackSelector {

boost::container::small_vector<CounterElement, 4> counters;

template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
bool isValidTrack(const track_proxy_t& track) const;

void addCounter(const std::vector<GeometryIdentifier>& identifiers,
Expand Down Expand Up @@ -227,7 +228,7 @@ class TrackSelector {
/// @tparam track_proxy_t is the type of the track proxy
/// @param track is the track proxy
/// @return true if the track is valid
template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
bool isValidTrack(const track_proxy_t& track) const;

/// Get readonly access to the config parameters
Expand Down Expand Up @@ -389,7 +390,7 @@ void TrackSelector::selectTracks(const input_tracks_t& inputTracks,
}
}

template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
bool TrackSelector::isValidTrack(const track_proxy_t& track) const {
auto checkMin = [](auto x, auto min) { return min <= x; };
auto checkMax = [](auto x, auto max) { return x <= max; };
Expand Down Expand Up @@ -478,7 +479,7 @@ inline TrackSelector::TrackSelector(
inline TrackSelector::TrackSelector(const Config& config)
: TrackSelector{EtaBinnedConfig{config}} {}

template <typename track_proxy_t>
template <TrackProxyConcept track_proxy_t>
bool TrackSelector::MeasurementCounter::isValidTrack(
const track_proxy_t& track) const {
// No hit cuts, accept everything
Expand Down
27 changes: 11 additions & 16 deletions Core/include/Acts/TrackFitting/GaussianSumFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,12 @@ struct GaussianSumFitter {

/// @brief The fit function for the Direct navigator
template <typename source_link_it_t, typename start_parameters_t,
typename track_container_t, template <typename> class holder_t>
TrackContainerFrontend track_container_t>
auto fit(source_link_it_t begin, source_link_it_t end,
const start_parameters_t& sParameters,
const GsfOptions<traj_t>& options,
const std::vector<const Surface*>& sSequence,
TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
const {
track_container_t& trackContainer) const {
// Check if we have the correct navigator
static_assert(
std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
Expand Down Expand Up @@ -148,12 +147,11 @@ struct GaussianSumFitter {

/// @brief The fit function for the standard navigator
template <typename source_link_it_t, typename start_parameters_t,
typename track_container_t, template <typename> class holder_t>
TrackContainerFrontend track_container_t>
auto fit(source_link_it_t begin, source_link_it_t end,
const start_parameters_t& sParameters,
const GsfOptions<traj_t>& options,
TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
const {
track_container_t& trackContainer) const {
// Check if we have the correct navigator
static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);

Expand Down Expand Up @@ -200,16 +198,13 @@ struct GaussianSumFitter {
/// first measurementSurface
template <typename source_link_it_t, typename start_parameters_t,
typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
typename track_container_t, template <typename> class holder_t>
Acts::Result<
typename TrackContainer<track_container_t, traj_t, holder_t>::TrackProxy>
fit_impl(source_link_it_t begin, source_link_it_t end,
const start_parameters_t& sParameters,
const GsfOptions<traj_t>& options,
const fwd_prop_initializer_t& fwdPropInitializer,
const bwd_prop_initializer_t& bwdPropInitializer,
TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
const {
TrackContainerFrontend track_container_t>
Acts::Result<typename track_container_t::TrackProxy> fit_impl(
source_link_it_t begin, source_link_it_t end,
const start_parameters_t& sParameters, const GsfOptions<traj_t>& options,
const fwd_prop_initializer_t& fwdPropInitializer,
const bwd_prop_initializer_t& bwdPropInitializer,
track_container_t& trackContainer) const {
// return or abort utility
auto return_error_or_abort = [&](auto error) {
if (options.abortOnError) {
Expand Down
Loading
Loading