Skip to content

Commit

Permalink
feat: Manual Propagator initialization and AnySurfaceReached abor…
Browse files Browse the repository at this point in the history
…ter (acts-project#3208)

This allows to manually initialize the propagation so `propagate(state)` can be called multiple times without destruction of the state in-between. I also added an `AnySurfaceReached` aborter which, in combination with the other change, allows users to jump from surface to surface using the same propagator state without any resets.
  • Loading branch information
andiwand authored and EleniXoch committed May 31, 2024
1 parent 6472b0b commit 0633b98
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 67 deletions.
90 changes: 79 additions & 11 deletions Core/include/Acts/Propagator/Propagator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,49 @@ class Propagator final
propagate(const parameters_t& start, const Surface& target,
const propagator_options_t& options) const;

/// @brief Builds the propagator state object
///
/// This function creates the propagator state object from the initial track
/// parameters and the propagation options.
///
/// @note This will also initialize the state
///
/// @tparam parameters_t Type of initial track parameters to propagate
/// @tparam propagator_options_t Type of the propagator options
/// @tparam path_aborter_t The path aborter type to be added
///
/// @param [in] start Initial track parameters to propagate
/// @param [in] options Propagation options
///
/// @return Propagator state object
template <typename parameters_t, typename propagator_options_t,
typename path_aborter_t = PathLimitReached>
auto makeState(const parameters_t& start,
const propagator_options_t& options) const;

/// @brief Builds the propagator state object
///
/// This function creates the propagator state object from the initial track
/// parameters, the target surface, and the propagation options.
///
/// @note This will also initialize the state
///
/// @tparam parameters_t Type of initial track parameters to propagate
/// @tparam propagator_options_t Type of the propagator options
/// @tparam target_aborter_t The target aborter type to be added
/// @tparam path_aborter_t The path aborter type to be added
///
/// @param [in] start Initial track parameters to propagate
/// @param [in] target Target surface of to propagate to
/// @param [in] options Propagation options
///
/// @return Propagator state object
template <typename parameters_t, typename propagator_options_t,
typename target_aborter_t = SurfaceReached,
typename path_aborter_t = PathLimitReached>
auto makeState(const parameters_t& start, const Surface& target,
const propagator_options_t& options) const;

/// @brief Propagate track parameters
///
/// This function performs the propagation of the track parameters according
Expand All @@ -484,24 +527,46 @@ class Propagator final
template <typename propagator_state_t>
Result<void> propagate(propagator_state_t& state) const;

template <typename parameters_t, typename propagator_options_t,
typename path_aborter_t = PathLimitReached>
auto makeState(const parameters_t& start,
const propagator_options_t& options) const;

template <typename parameters_t, typename propagator_options_t,
typename target_aborter_t = SurfaceReached,
typename path_aborter_t = PathLimitReached>
auto makeState(const parameters_t& start, const Surface& target,
const propagator_options_t& options) const;

/// @brief Builds the propagator result object
///
/// This function creates the propagator result object from the propagator
/// state object. The `result` is passed to pipe a potential error from the
/// propagation call. The `options` are used to determine the type of the
/// result object. The `makeCurvilinear` flag is used to determine if the
/// result should contain curvilinear track parameters.
///
/// @tparam propagator_state_t Type of the propagator state object
/// @tparam propagator_options_t Type of the propagator options
///
/// @param [in] state Propagator state object
/// @param [in] result Result of the propagation
/// @param [in] options Propagation options
/// @param [in] makeCurvilinear Produce curvilinear parameters at the end of the propagation
///
/// @return Propagation result
template <typename propagator_state_t, typename propagator_options_t>
Result<
action_list_t_result_t<StepperCurvilinearTrackParameters,
typename propagator_options_t::action_list_type>>
makeResult(propagator_state_t state, Result<void> result,
const propagator_options_t& options, bool makeCurvilinear) const;

/// @brief Builds the propagator result object
///
/// This function creates the propagator result object from the propagator
/// state object. The `result` is passed to pipe a potential error from the
/// propagation call. The `options` are used to determine the type of the
/// result object.
///
/// @tparam propagator_state_t Type of the propagator state object
/// @tparam propagator_options_t Type of the propagator options
///
/// @param [in] state Propagator state object
/// @param [in] result Result of the propagation
/// @param [in] target Target surface of to propagate to
/// @param [in] options Propagation options
///
/// @return Propagation result
template <typename propagator_state_t, typename propagator_options_t>
Result<
action_list_t_result_t<StepperBoundTrackParameters,
Expand All @@ -516,6 +581,9 @@ class Propagator final
private:
const Logger& logger() const { return *m_logger; }

template <typename propagator_state_t, typename path_aborter_t>
void initialize(propagator_state_t& state) const;

template <typename propagator_state_t, typename result_t>
void moveStateToResult(propagator_state_t& state, result_t& result) const;

Expand Down
24 changes: 14 additions & 10 deletions Core/include/Acts/Propagator/Propagator.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ auto Acts::Propagator<S, N>::propagate(propagator_state_t& state) const

state.stage = PropagatorStage::prePropagation;

// Navigator initialize state call
m_navigator.initialize(state, m_stepper);
// Pre-Stepping call to the action list
state.options.actionList(state, m_stepper, m_navigator, logger());
// assume negative outcome, only set to true later if we actually have
Expand Down Expand Up @@ -178,10 +176,7 @@ auto Acts::Propagator<S, N>::makeState(
"Step method of the Stepper is not compatible with the propagator "
"state");

// Apply the loop protection - it resets the internal path limit
detail::setupLoopProtection(
state, m_stepper, state.options.abortList.template get<path_aborter_t>(),
false, logger());
initialize<StateType, path_aborter_t>(state);

return state;
}
Expand Down Expand Up @@ -222,10 +217,7 @@ auto Acts::Propagator<S, N>::makeState(
"Step method of the Stepper is not compatible with the propagator "
"state");

// Apply the loop protection, it resets the internal path limit
detail::setupLoopProtection(
state, m_stepper, state.options.abortList.template get<path_aborter_t>(),
false, logger());
initialize<StateType, path_aborter_t>(state);

return state;
}
Expand Down Expand Up @@ -318,6 +310,18 @@ auto Acts::Propagator<S, N>::makeResult(
return Result<ResultType>::success(std::move(result));
}

template <typename S, typename N>
template <typename propagator_state_t, typename path_aborter_t>
void Acts::Propagator<S, N>::initialize(propagator_state_t& state) const {
// Navigator initialize state call
m_navigator.initialize(state, m_stepper);

// Apply the loop protection - it resets the internal path limit
detail::setupLoopProtection(
state, m_stepper, state.options.abortList.template get<path_aborter_t>(),
false, logger());
}

template <typename S, typename N>
template <typename propagator_state_t, typename result_t>
void Acts::Propagator<S, N>::moveStateToResult(propagator_state_t& state,
Expand Down
30 changes: 28 additions & 2 deletions Core/include/Acts/Propagator/StandardAborters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ struct PathLimitReached {
///
/// @param [in,out] state The propagation state object
/// @param [in] stepper Stepper used for propagation
/// @param [in] navigator Navigator used for propagation
/// @param logger a logger instance
template <typename propagator_state_t, typename stepper_t,
typename navigator_t>
bool operator()(propagator_state_t& state, const stepper_t& stepper,
const navigator_t& /*navigator*/,
const Logger& logger) const {
const navigator_t& navigator, const Logger& logger) const {
(void)navigator;

// Check if the maximum allowed step size has to be updated
double distance =
std::abs(internalLimit) - std::abs(state.stepping.pathAccumulated);
Expand Down Expand Up @@ -181,4 +183,28 @@ struct EndOfWorldReached {
}
};

/// Aborter that checks if the propagation has reached any surface
struct AnySurfaceReached {
template <typename propagator_state_t, typename stepper_t,
typename navigator_t>
bool operator()(propagator_state_t& state, const stepper_t& stepper,
const navigator_t& navigator, const Logger& logger) const {
(void)stepper;
(void)logger;

const Surface* startSurface = navigator.startSurface(state.navigation);
const Surface* targetSurface = navigator.targetSurface(state.navigation);
const Surface* currentSurface = navigator.currentSurface(state.navigation);

// `startSurface` is excluded because we want to reach a new surface
// `targetSurface` is excluded because another aborter should handle it
if (currentSurface != nullptr && currentSurface != startSurface &&
currentSurface != targetSurface) {
return true;
}

return false;
}
};

} // namespace Acts
2 changes: 1 addition & 1 deletion Core/include/Acts/TrackFitting/KalmanFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ class KalmanFitter {
kalmanResult.fittedStates = &trackContainer.trackStateContainer();

// Run the fitter
auto result = m_propagator.template propagate(propagatorState);
auto result = m_propagator.propagate(propagatorState);

if (!result.ok()) {
ACTS_ERROR("Propagation failed: " << result.error());
Expand Down
45 changes: 2 additions & 43 deletions Tests/UnitTests/Core/Navigation/DetectorNavigatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,8 @@ BOOST_AUTO_TEST_CASE(DetectorNavigatorTestsInitialization) {
Acts::Experimental::DetectorNavigator>(
stepper, navigator);

auto state = propagator.makeState(start, options);

BOOST_CHECK_THROW(navigator.initialize(state, stepper),
BOOST_CHECK_THROW(propagator.makeState(start, options),
std::invalid_argument);

navigator.preStep(state, stepper);
auto preStepState = state.navigation;
BOOST_CHECK_EQUAL(preStepState.currentDetector, nullptr);
BOOST_CHECK_EQUAL(preStepState.currentVolume, nullptr);
BOOST_CHECK_EQUAL(preStepState.currentSurface, nullptr);
BOOST_CHECK_EQUAL(preStepState.currentPortal, nullptr);
BOOST_CHECK(preStepState.surfaceCandidates.empty());

navigator.postStep(state, stepper);
auto postStepState = state.navigation;
BOOST_CHECK_EQUAL(postStepState.currentDetector, nullptr);
BOOST_CHECK_EQUAL(postStepState.currentVolume, nullptr);
BOOST_CHECK_EQUAL(postStepState.currentSurface, nullptr);
BOOST_CHECK_EQUAL(postStepState.currentPortal, nullptr);
BOOST_CHECK(postStepState.surfaceCandidates.empty());
}

// Run with geometry but without resolving
Expand Down Expand Up @@ -168,31 +150,8 @@ BOOST_AUTO_TEST_CASE(DetectorNavigatorTestsInitialization) {
Acts::Experimental::DetectorNavigator>(
stepper, navigator);

auto state = propagator.makeState(startEoW, options);

BOOST_CHECK(navigator.endOfWorldReached(state.navigation));

BOOST_CHECK_THROW(navigator.initialize(state, stepper),
BOOST_CHECK_THROW(propagator.makeState(startEoW, options),
std::invalid_argument);
auto initState = state.navigation;
BOOST_CHECK_EQUAL(initState.currentVolume, nullptr);
BOOST_CHECK_EQUAL(initState.currentSurface, nullptr);
BOOST_CHECK_EQUAL(initState.currentPortal, nullptr);
BOOST_CHECK(initState.surfaceCandidates.empty());

navigator.preStep(state, stepper);
auto preStepState = state.navigation;
BOOST_CHECK_EQUAL(preStepState.currentVolume, nullptr);
BOOST_CHECK_EQUAL(preStepState.currentSurface, nullptr);
BOOST_CHECK_EQUAL(preStepState.currentPortal, nullptr);
BOOST_CHECK(preStepState.surfaceCandidates.empty());

navigator.postStep(state, stepper);
auto postStepState = state.navigation;
BOOST_CHECK_EQUAL(postStepState.currentVolume, nullptr);
BOOST_CHECK_EQUAL(postStepState.currentSurface, nullptr);
BOOST_CHECK_EQUAL(postStepState.currentPortal, nullptr);
BOOST_CHECK(postStepState.surfaceCandidates.empty());
}

// Initialize properly
Expand Down

0 comments on commit 0633b98

Please sign in to comment.