Skip to content

Commit

Permalink
use pass by reference, add consts where possible (steadystateproblem) (
Browse files Browse the repository at this point in the history
…#1745)

* use pass by reference, add consts where possible

* Apply suggestions from code review

Co-authored-by: Daniel Weindl <[email protected]>

* fixup

Co-authored-by: Daniel Weindl <[email protected]>
  • Loading branch information
Fabian Fröhlich and dweindl authored Mar 25, 2022
1 parent 3683d3e commit 82c5d22
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 154 deletions.
30 changes: 15 additions & 15 deletions include/amici/newton_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class NewtonSolver {
*
* @param model pointer to the model object
*/
explicit NewtonSolver(const Model *model);
explicit NewtonSolver(const Model &model);

/**
* @brief Factory method to create a NewtonSolver based on linsolType
Expand All @@ -39,7 +39,7 @@ class NewtonSolver {
* @return solver NewtonSolver according to the specified linsolType
*/
static std::unique_ptr<NewtonSolver>
getSolver(const Solver &simulationSolver, const Model *model);
getSolver(const Solver &simulationSolver, const Model &model);

/**
* @brief Computes the solution of one Newton iteration
Expand All @@ -49,7 +49,7 @@ class NewtonSolver {
* @param model pointer to the model instance
* @param state current simulation state
*/
void getStep(AmiVector &delta, Model *model, const SimulationState &state);
void getStep(AmiVector &delta, Model &model, const SimulationState &state);

/**
* @brief Computes steady state sensitivities
Expand All @@ -58,7 +58,7 @@ class NewtonSolver {
* @param model pointer to the model instance
* @param state current simulation state
*/
void computeNewtonSensis(AmiVectorArray &sx, Model *model,
void computeNewtonSensis(AmiVectorArray &sx, Model &model,
const SimulationState &state);

/**
Expand All @@ -68,7 +68,7 @@ class NewtonSolver {
* @param model pointer to the model instance
* @param state current simulation state
*/
virtual void prepareLinearSystem(Model *model,
virtual void prepareLinearSystem(Model &model,
const SimulationState &state) = 0;

/**
Expand All @@ -78,7 +78,7 @@ class NewtonSolver {
* @param model pointer to the model instance
* @param state current simulation state
*/
virtual void prepareLinearSystemB(Model *model,
virtual void prepareLinearSystemB(Model &model,
const SimulationState &state) = 0;

/**
Expand All @@ -103,7 +103,7 @@ class NewtonSolver {
* @return boolean indicating whether the linear system is singular
* (condition number < 1/machine precision)
*/
virtual bool is_singular(Model *model,
virtual bool is_singular(Model &model,
const SimulationState &state) const = 0;

virtual ~NewtonSolver() = default;
Expand Down Expand Up @@ -133,7 +133,7 @@ class NewtonSolverDense : public NewtonSolver {
*
* @param model model instance that provides problem dimensions
*/
explicit NewtonSolverDense(const Model *model);
explicit NewtonSolverDense(const Model &model);

NewtonSolverDense(const NewtonSolverDense &) = delete;

Expand All @@ -143,15 +143,15 @@ class NewtonSolverDense : public NewtonSolver {

void solveLinearSystem(AmiVector &rhs) override;

void prepareLinearSystem(Model *model,
void prepareLinearSystem(Model &model,
const SimulationState &state) override;

void prepareLinearSystemB(Model *model,
void prepareLinearSystemB(Model &model,
const SimulationState &state) override;

void reinitialize() override;

bool is_singular(Model *model, const SimulationState &state) const override;
bool is_singular(Model &model, const SimulationState &state) const override;

private:
/** temporary storage of Jacobian */
Expand All @@ -174,7 +174,7 @@ class NewtonSolverSparse : public NewtonSolver {
*
* @param model model instance that provides problem dimensions
*/
explicit NewtonSolverSparse(const Model *model);
explicit NewtonSolverSparse(const Model &model);

NewtonSolverSparse(const NewtonSolverSparse &) = delete;

Expand All @@ -184,13 +184,13 @@ class NewtonSolverSparse : public NewtonSolver {

void solveLinearSystem(AmiVector &rhs) override;

void prepareLinearSystem(Model *model,
void prepareLinearSystem(Model &model,
const SimulationState &state) override;

void prepareLinearSystemB(Model *model,
void prepareLinearSystemB(Model &model,
const SimulationState &state) override;

bool is_singular(Model *model, const SimulationState &state) const override;
bool is_singular(Model &model, const SimulationState &state) const override;

void reinitialize() override;

Expand Down
2 changes: 1 addition & 1 deletion include/amici/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class Solver {
*
* @param model pointer to the model instance
*/
void updateAndReinitStatesAndSensitivities(Model *model);
void updateAndReinitStatesAndSensitivities(Model *model) const;

/**
* getRootInfo extracts information which event occurred
Expand Down
40 changes: 20 additions & 20 deletions include/amici/steadystateproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SteadystateProblem {
* @param solver Solver instance
* @param model Model instance
*/
explicit SteadystateProblem(const Solver &solver, Model &model);
explicit SteadystateProblem(const Solver &solver, const Model &model);

/**
* @brief Handles steady state computation in the forward case:
Expand All @@ -39,7 +39,7 @@ class SteadystateProblem {
* @param model pointer to the model object
* @param it integer with the index of the current time step
*/
void workSteadyStateProblem(Solver *solver, Model *model, int it);
void workSteadyStateProblem(const Solver &solver, Model &model, int it);

/**
* Integrates over the adjoint state backward in time by solving a linear
Expand All @@ -49,7 +49,7 @@ class SteadystateProblem {
* @param model pointer to the model object
* @param bwd backward problem
*/
void workSteadyStateBackwardProblem(Solver *solver, Model *model,
void workSteadyStateBackwardProblem(const Solver &solver, Model &model,
const BackwardProblem *bwd);

/**
Expand Down Expand Up @@ -166,22 +166,22 @@ class SteadystateProblem {
* @param model pointer to the model object
* @param it integer with the index of the current time step
*/
void findSteadyState(const Solver *solver, Model *model, int it);
void findSteadyState(const Solver &solver, Model &model, int it);

/**
* @brief Tries to determine the steady state by using Newton's method
* @param model pointer to the model object
* @param newton_retry bool flag indicating whether being relaunched
*/
void findSteadyStateByNewtonsMethod(Model *model, bool newton_retry);
void findSteadyStateByNewtonsMethod(Model &model, bool newton_retry);

/**
* @brief Tries to determine the steady state by using forward simulation
* @param solver pointer to the solver object
* @param model pointer to the model object
* @param it integer with the index of the current time step
*/
void findSteadyStateBySimulation(const Solver *solver, Model *model,
void findSteadyStateBySimulation(const Solver &solver, Model &model,
int it);

/**
Expand All @@ -190,22 +190,22 @@ class SteadystateProblem {
* @param solver pointer to the solver object
* @param model pointer to the model object
*/
void computeSteadyStateQuadrature(const Solver *solver, Model *model);
void computeSteadyStateQuadrature(const Solver &solver, Model &model);

/**
* @brief Computes the quadrature in steady state backward mode by
* solving the linear system defined by the backward Jacobian
* @param model pointer to the model object
*/
void getQuadratureByLinSolve(Model *model);
void getQuadratureByLinSolve(Model &model);

/**
* @brief Computes the quadrature in steady state backward mode by
* numerical integration of xB forward in time
* @param solver pointer to the solver object
* @param model pointer to the model object
*/
void getQuadratureBySimulation(const Solver *solver, Model *model);
void getQuadratureBySimulation(const Solver &solver, Model &model);

/**
* @brief Stores state and throws an exception if equilibration failed
Expand All @@ -230,7 +230,7 @@ class SteadystateProblem {
* @param context SteadyStateContext giving the situation for the flag
* @return flag telling how to process state sensitivities
*/
bool getSensitivityFlag(const Model *model, const Solver *solver, int it,
bool getSensitivityFlag(const Model &model, const Solver &solver, int it,
SteadyStateContext context);

/**
Expand All @@ -254,22 +254,22 @@ class SteadystateProblem {
* @param sensi_method sensitivity method
* @return weighted root mean squared residuals of the RHS
*/
realtype getWrms(Model *model, SensitivityMethod sensi_method);
realtype getWrms(Model &model, SensitivityMethod sensi_method);

/**
* @brief Checks convergence for state sensitivities
* @param model Model instance
* @return weighted root mean squared residuals of the RHS
*/
realtype getWrmsFSA(Model *model);
realtype getWrmsFSA(Model &model);

/**
* @brief Runs the Newton solver iterations and checks for convergence
* to steady state
* @param model pointer to the model object
* @param newton_retry flag indicating if Newton solver is rerun
*/
void applyNewtonsMethod(Model *model, bool newton_retry);
void applyNewtonsMethod(Model &model, bool newton_retry);

/**
* @brief Simulation is launched, if Newton solver or linear system solve
Expand All @@ -278,7 +278,7 @@ class SteadystateProblem {
* @param model pointer to the model object
* @param backward flag indicating adjoint mode (including quadrature)
*/
void runSteadystateSimulation(const Solver *solver, Model *model,
void runSteadystateSimulation(const Solver &solver, Model &model,
bool backward);

/**
Expand All @@ -289,8 +289,8 @@ class SteadystateProblem {
* @param backward flag switching on quadratures computation
* @return solver instance
*/
std::unique_ptr<Solver> createSteadystateSimSolver(const Solver *solver,
Model *model,
std::unique_ptr<Solver> createSteadystateSimSolver(const Solver &solver,
Model &model,
bool forwardSensis,
bool backward) const;

Expand All @@ -300,7 +300,7 @@ class SteadystateProblem {
* @param solver pointer to the solver object
* @param model pointer to the model object
*/
void initializeForwardProblem(int it, const Solver *solver, Model *model);
void initializeForwardProblem(int it, const Solver &solver, Model &model);

/**
* @brief Initialize backward computation
Expand All @@ -309,7 +309,7 @@ class SteadystateProblem {
* @param bwd pointer to backward problem
* @return flag indicating whether backward computation to be carried out
*/
bool initializeBackwardProblem(Solver *solver, Model *model,
bool initializeBackwardProblem(const Solver &solver, Model &model,
const BackwardProblem *bwd);

/**
Expand All @@ -319,15 +319,15 @@ class SteadystateProblem {
* @param yQ vector to be multiplied with dxdotdp
* @param yQB resulting vector after multiplication
*/
void computeQBfromQ(Model *model, const AmiVector &yQ,
void computeQBfromQ(Model &model, const AmiVector &yQ,
AmiVector &yQB) const;

/**
* @brief Ensures state positivity, if requested and repeats convergence
* check, if necessary
* @param model pointer to the model object
*/
bool makePositiveAndCheckConvergence(Model *model);
bool makePositiveAndCheckConvergence(Model &model);

/**
* @brief Updates the damping factor gamma that determines step size
Expand Down
8 changes: 4 additions & 4 deletions src/amici.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ AmiciApplication::runAmiciSimulation(Solver& solver,
);

preeq = std::make_unique<SteadystateProblem>(solver, model);
preeq->workSteadyStateProblem(&solver, &model, -1);
preeq->workSteadyStateProblem(solver, model, -1);
}


Expand All @@ -142,7 +142,7 @@ AmiciApplication::runAmiciSimulation(Solver& solver,

if (fwd->getCurrentTimeIteration() < model.nt()) {
posteq = std::make_unique<SteadystateProblem>(solver, model);
posteq->workSteadyStateProblem(&solver, &model,
posteq->workSteadyStateProblem(solver, model,
fwd->getCurrentTimeIteration());
}

Expand All @@ -151,7 +151,7 @@ AmiciApplication::runAmiciSimulation(Solver& solver,
fwd->getAdjointUpdates(model, *edata);
if (posteq) {
posteq->getAdjointUpdates(model, *edata);
posteq->workSteadyStateBackwardProblem(&solver, &model,
posteq->workSteadyStateBackwardProblem(solver, model,
bwd.get());
}

Expand All @@ -165,7 +165,7 @@ AmiciApplication::runAmiciSimulation(Solver& solver,
if (preeq) {
ConditionContext cc2(&model, edata,
FixedParameterContext::preequilibration);
preeq->workSteadyStateBackwardProblem(&solver, &model,
preeq->workSteadyStateBackwardProblem(solver, model,
bwd.get());
}
}
Expand Down
Loading

0 comments on commit 82c5d22

Please sign in to comment.