Skip to content

Commit

Permalink
refactor!: Template on track container frontend TrackContainer
Browse files Browse the repository at this point in the history
  • Loading branch information
andiwand committed Jul 23, 2024
1 parent c921214 commit 3c050df
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 145 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of the Acts project.
//
// Copyright (C) 2020 CERN for the benefit of the Acts project
// Copyright (C) 2023-2024 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand Down Expand Up @@ -77,13 +77,11 @@ class GreedyAmbiguityResolution {
/// @param state An empty state object which is expected to be default constructed.
/// @param sourceLinkHash A functor to acquire a hash from a given source link.
/// @param sourceLinkEquality A functor to check equality of two source links.
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
template <typename track_container_t, typename source_link_hash_t,
typename source_link_equality_t>
void computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
State& state, source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const;
void computeInitialState(const track_container_t& tracks, State& state,
source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const;

/// Updates the state iteratively by evicting one track after the other until
/// the final state conditions are met.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

namespace Acts {

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
template <typename track_container_t, typename source_link_hash_t,
typename source_link_equality_t>
void GreedyAmbiguityResolution::computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
State& state, source_link_hash_t&& sourceLinkHash,
const track_container_t& tracks, State& state,
source_link_hash_t&& sourceLinkHash,
source_link_equality_t&& sourceLinkEquality) const {
auto measurementIndexMap =
std::unordered_map<SourceLink, std::size_t, source_link_hash_t,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,12 @@ class ScoreBasedAmbiguityResolution {
/// The optional cuts,weights and score are used to remove tracks that are not
/// good enough, based on some criteria. Users are free to add their own cuts
/// with the help of this struct.
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <typename track_proxy_t>
struct OptionalCuts {
using OptionalFilter =
std::function<bool(const Acts::TrackProxy<track_container_t, traj_t,
holder_t, ReadOnly>&)>;
using OptionalFilter = std::function<bool(const track_proxy_t&)>;

using OptionalScoreModifier = std::function<void(
const Acts::TrackProxy<track_container_t, traj_t, holder_t, ReadOnly>&,
double&)>;
using OptionalScoreModifier =
std::function<void(const track_proxy_t&, double&)>;
std::vector<OptionalFilter> cuts = {};
std::vector<OptionalScoreModifier> weights = {};

Expand All @@ -146,12 +142,10 @@ class ScoreBasedAmbiguityResolution {
/// @param sourceLinkEquality is the equality function for the source links
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @return a vector of the initial state of the tracks
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
template <typename track_container_t, typename source_link_hash_t,
typename source_link_equality_t>
std::vector<std::vector<MeasurementInfo>> computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
source_link_hash_t sourceLinkHash,
const track_container_t& tracks, source_link_hash_t sourceLinkHash,
source_link_equality_t sourceLinkEquality,
std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors) const;

Expand All @@ -161,12 +155,11 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @param optionalCuts is the user defined optional cuts to be applied.
/// @return a vector of scores for each track
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <typename track_container_t>
std::vector<double> simpleScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

/// Compute the score of each track based on the ambiguity function.
Expand All @@ -175,12 +168,11 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the trackFeatures map from detector ID to trackFeatures
/// @param optionalCuts is the user defined optional cuts to be applied.
/// @return a vector of scores for each track
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <typename track_container_t>
std::vector<double> ambiguityScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

/// Remove hits that are not good enough for each track and removes tracks
Expand All @@ -205,13 +197,12 @@ class ScoreBasedAmbiguityResolution {
/// @param trackFeaturesVectors is the map of detector id to trackFeatures for each track
/// @param optionalCuts is the optional cuts to be applied
/// @return a vector of IDs of the tracks we want to keep
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly = true>
template <typename track_container_t>
std::vector<int> solveAmbiguity(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<MeasurementInfo>>& measurementsPerTrack,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts = {}) const;

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ inline const Logger& ScoreBasedAmbiguityResolution::logger() const {
return *m_logger;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, typename source_link_hash_t,
template <typename track_container_t, typename source_link_hash_t,
typename source_link_equality_t>
std::vector<std::vector<ScoreBasedAmbiguityResolution::MeasurementInfo>>
ScoreBasedAmbiguityResolution::computeInitialState(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
source_link_hash_t sourceLinkHash,
const track_container_t& tracks, source_link_hash_t sourceLinkHash,
source_link_equality_t sourceLinkEquality,
std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors) const {
auto MeasurementIndexMap =
Expand Down Expand Up @@ -98,12 +96,11 @@ ScoreBasedAmbiguityResolution::computeInitialState(
return measurementsPerTrack;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>
template <typename track_container_t>
std::vector<double> Acts::ScoreBasedAmbiguityResolution::simpleScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
std::vector<double> trackScore;
trackScore.reserve(tracks.size());
Expand Down Expand Up @@ -248,12 +245,11 @@ std::vector<double> Acts::ScoreBasedAmbiguityResolution::simpleScore(
return trackScore;
}

template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>
template <typename track_container_t>
std::vector<double> Acts::ScoreBasedAmbiguityResolution::ambiguityScore(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
std::vector<double> trackScore;
trackScore.reserve(tracks.size());
Expand Down Expand Up @@ -425,13 +421,13 @@ std::vector<double> Acts::ScoreBasedAmbiguityResolution::ambiguityScore(

return trackScore;
}
template <typename track_container_t, typename traj_t,
template <typename> class holder_t, bool ReadOnly>

template <typename track_container_t>
std::vector<int> Acts::ScoreBasedAmbiguityResolution::solveAmbiguity(
const TrackContainer<track_container_t, traj_t, holder_t>& tracks,
const track_container_t& tracks,
const std::vector<std::vector<MeasurementInfo>>& measurementsPerTrack,
const std::vector<std::vector<TrackFeatures>>& trackFeaturesVectors,
const OptionalCuts<track_container_t, traj_t, holder_t, ReadOnly>&
const OptionalCuts<typename track_container_t::ConstTrackProxy>&
optionalCuts) const {
ACTS_INFO("Number of tracks before Ambiguty Resolution: " << tracks.size());
// vector of trackFeaturesVectors. where each trackFeaturesVector contains the
Expand Down
12 changes: 3 additions & 9 deletions Core/include/Acts/EventData/TrackHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,10 @@ namespace Acts {
/// Helper function to calculate a number of track level quantities and store
/// them on the track itself
/// @note The input track needs to be mutable, so @c ReadOnly=false
/// @tparam track_container_t the track container backend
/// @tparam track_state_container_t the track state container backend
/// @tparam holder_t the holder type for the track container backends
/// @tparam track_proxy_t The type of the track proxy
/// @param track A mutable track proxy to operate on
template <typename track_container_t, typename track_state_container_t,
template <typename> class holder_t>
void calculateTrackQuantities(
Acts::TrackProxy<track_container_t, track_state_container_t, holder_t,
false>
track) {
template <typename track_proxy_t>
void calculateTrackQuantities(track_proxy_t track) {
track.chi2() = 0;
track.nDoF() = 0;

Expand Down
27 changes: 11 additions & 16 deletions Core/include/Acts/TrackFitting/GaussianSumFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,12 @@ struct GaussianSumFitter {

/// @brief The fit function for the Direct navigator
template <typename source_link_it_t, typename start_parameters_t,
typename track_container_t, template <typename> class holder_t>
typename track_container_t>
auto fit(source_link_it_t begin, source_link_it_t end,
const start_parameters_t& sParameters,
const GsfOptions<traj_t>& options,
const std::vector<const Surface*>& sSequence,
TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
const {
track_container_t& trackContainer) const {
// Check if we have the correct navigator
static_assert(
std::is_same_v<DirectNavigator, typename propagator_t::Navigator>);
Expand Down Expand Up @@ -148,12 +147,11 @@ struct GaussianSumFitter {

/// @brief The fit function for the standard navigator
template <typename source_link_it_t, typename start_parameters_t,
typename track_container_t, template <typename> class holder_t>
typename track_container_t>
auto fit(source_link_it_t begin, source_link_it_t end,
const start_parameters_t& sParameters,
const GsfOptions<traj_t>& options,
TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
const {
track_container_t& trackContainer) const {
// Check if we have the correct navigator
static_assert(std::is_same_v<Navigator, typename propagator_t::Navigator>);

Expand Down Expand Up @@ -200,16 +198,13 @@ struct GaussianSumFitter {
/// first measurementSurface
template <typename source_link_it_t, typename start_parameters_t,
typename fwd_prop_initializer_t, typename bwd_prop_initializer_t,
typename track_container_t, template <typename> class holder_t>
Acts::Result<
typename TrackContainer<track_container_t, traj_t, holder_t>::TrackProxy>
fit_impl(source_link_it_t begin, source_link_it_t end,
const start_parameters_t& sParameters,
const GsfOptions<traj_t>& options,
const fwd_prop_initializer_t& fwdPropInitializer,
const bwd_prop_initializer_t& bwdPropInitializer,
TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
const {
typename track_container_t>
Acts::Result<typename track_container_t::TrackProxy> fit_impl(
source_link_it_t begin, source_link_it_t end,
const start_parameters_t& sParameters, const GsfOptions<traj_t>& options,
const fwd_prop_initializer_t& fwdPropInitializer,
const bwd_prop_initializer_t& bwdPropInitializer,
track_container_t& trackContainer) const {
// return or abort utility
auto return_error_or_abort = [&](auto error) {
if (options.abortOnError) {
Expand Down
10 changes: 4 additions & 6 deletions Core/include/Acts/TrackFitting/GlobalChiSquareFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,15 +656,13 @@ class Gx2Fitter {
/// @return the output as an output track
template <typename source_link_iterator_t, typename start_parameters_t,
typename parameters_t = BoundTrackParameters,
typename track_container_t, template <typename> class holder_t,
bool _isdn = isDirectNavigator>
typename track_container_t, bool _isdn = isDirectNavigator>
auto fit(source_link_iterator_t it, source_link_iterator_t end,
const start_parameters_t& sParameters,
const Gx2FitterOptions<traj_t>& gx2fOptions,
TrackContainer<track_container_t, traj_t, holder_t>& trackContainer)
const -> std::enable_if_t<
!_isdn, Result<typename TrackContainer<track_container_t, traj_t,
holder_t>::TrackProxy>> {
track_container_t& trackContainer) const
-> std::enable_if_t<!_isdn,
Result<typename track_container_t::TrackProxy>> {
// Preprocess Measurements (SourceLinks -> map)
// To be able to find measurements later, we put them into a map
// We need to copy input SourceLinks anyway, so the map can own them.
Expand Down
Loading

0 comments on commit 3c050df

Please sign in to comment.