diff --git a/src/lib.rs b/src/lib.rs index b7ef130d..a0ac280a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,15 +11,19 @@ //! You will also need to choose a matrix type to use. DiffSol can use the [nalgebra](https://nalgebra.org) `DMatrix` type, the [faer](https://github.com/sarah-ek/faer-rs) `Mat` type, or any other type that implements the //! [Matrix] trait. You can also use the [sundials](https://computation.llnl.gov/projects/sundials) library for the matrix and vector types (see [SundialsMatrix]). //! +//! ## Initial state +//! +//! The solver state is held in [OdeSolverState], and contains a state vector, the gradient of the state vector, the time, and the step size. You can intitialise a new state using [OdeSolverState::new], +//! or create an uninitialised state using [OdeSolverState::new_without_initialise] and intitialise it manually or using the [OdeSolverState::set_consistent] and [OdeSolverState::set_step_size] methods. +//! //! ## The solver //! -//! To solve the problem, you need to choose a solver. DiffSol provides the following solvers: +//! To solve the problem given the initial state, you need to choose a solver. DiffSol provides the following solvers: //! - A Backwards Difference Formulae [Bdf] solver, suitable for stiff problems and singular mass matrices. //! - A Singly Diagonally Implicit Runge-Kutta (SDIRK or ESDIRK) solver [Sdirk]. You can use your own butcher tableau using [Tableau] or use one of the provided ([Tableau::tr_bdf2], [Tableau::esdirk34]). //! - A BDF solver that wraps the IDA solver solver from the sundials library ([SundialsIda], requires the `sundials` feature). //! //! See the [OdeSolverMethod] trait for a more detailed description of the available methods on each solver. Possible workflows are: -//! - Initialise the problem using [OdeSolverState::new] or [OdeSolverState::new_consistent], and then use [OdeSolverMethod::set_problem] to setup the solver with the problem and [OdeSolverState] instance. //! - Use the [OdeSolverMethod::step] method to step the solution forward in time with an internal time step chosen by the solver to meet the error tolerances. //! - Use the [OdeSolverMethod::interpolate] method to interpolate the solution between the last two time steps. //! - Use the [OdeSolverMethod::set_stop_time] method to stop the solver at a specific time (i.e. this will override the internal time step so that the solver stops at the specified time). @@ -141,6 +145,7 @@ pub use ode_solver::sundials::SundialsIda; #[cfg(feature = "diffsl")] pub use ode_solver::diffsl::DiffSlContext; +pub use matrix::default_solver::DefaultSolver; use matrix::{DenseMatrix, Matrix, MatrixCommon, MatrixSparsity, MatrixView, MatrixViewMut}; pub use nonlinear_solver::newton::NewtonNonlinearSolver; use nonlinear_solver::{root::RootFinder, NonLinearSolver}; @@ -196,7 +201,7 @@ mod tests { let t = 0.4; let y = solver.solve(&problem, t).unwrap(); - let state = OdeSolverState::new(&problem); + let state = OdeSolverState::new(&problem, &solver).unwrap(); solver.set_problem(state, &problem); while solver.state().unwrap().t <= t { solver.step().unwrap(); @@ -237,7 +242,7 @@ mod tests { let t = 0.4; let y = solver.solve(&problem, t).unwrap(); - let state = OdeSolverState::new(&problem); + let state = OdeSolverState::new(&problem, &solver).unwrap(); solver.set_problem(state, &problem); while solver.state().unwrap().t <= t { solver.step().unwrap(); diff --git a/src/ode_solver/bdf/mod.rs b/src/ode_solver/bdf.rs similarity index 63% rename from src/ode_solver/bdf/mod.rs rename to src/ode_solver/bdf.rs index 0a53ebb5..82ad668b 100644 --- a/src/ode_solver/bdf/mod.rs +++ b/src/ode_solver/bdf.rs @@ -1,5 +1,5 @@ -use std::ops::AddAssign; use std::rc::Rc; +use std::{ops::AddAssign, ops::MulAssign, panic}; use anyhow::{anyhow, Result}; @@ -12,14 +12,11 @@ use crate::{ op::bdf::BdfCallable, scalar::scale, vector::DefaultDenseMatrix, - DenseMatrix, IndexType, MatrixViewMut, NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, - OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Op, Scalar, - SolverProblem, Vector, VectorRef, VectorView, VectorViewMut, + DenseMatrix, IndexType, MatrixViewMut, NewtonNonlinearSolver, NonLinearSolver, OdeSolverMethod, + OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Op, Scalar, SolverProblem, Vector, + VectorRef, VectorView, VectorViewMut, }; -pub mod faer; -pub mod nalgebra; - use super::equations::OdeEquations; #[derive(Clone, Debug, Serialize)] @@ -81,6 +78,7 @@ pub struct Bdf< state: Option>, tstop: Option, root_finder: Option>, + is_state_modified: bool, } impl Default @@ -102,6 +100,30 @@ where let mut nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); type M = ::M; + + // kappa values for difference orders, taken from Table 1 of [1] + let kappa = [ + Eqn::T::from(0.0), + Eqn::T::from(-0.1850), + Eqn::T::from(-1.0) / Eqn::T::from(9.0), + Eqn::T::from(-0.0823), + Eqn::T::from(-0.0415), + Eqn::T::from(0.0), + ]; + let mut alpha = vec![Eqn::T::zero()]; + let mut gamma = vec![Eqn::T::zero()]; + let mut error_const = vec![Eqn::T::one()]; + + #[allow(clippy::needless_range_loop)] + for i in 1..=Self::MAX_ORDER { + let i_t = Eqn::T::from(i as f64); + let one_over_i = Eqn::T::one() / i_t; + let one_over_i_plus_one = Eqn::T::one() / (i_t + Eqn::T::one()); + gamma.push(gamma[i - 1] + one_over_i); + alpha.push(Eqn::T::one() / ((Eqn::T::one() - kappa[i]) * gamma[i])); + error_const.push(kappa[i] * gamma[i] + one_over_i_plus_one); + } + Self { ode_problem: None, nonlinear_solver, @@ -109,14 +131,15 @@ where n_equal_steps: 0, diff: as Matrix>::zeros(n, Self::MAX_ORDER + 3), //DMatrix::::zeros(n, Self::MAX_ORDER + 3), diff_tmp: as Matrix>::zeros(n, Self::MAX_ORDER + 3), - gamma: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], - alpha: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], - error_const: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], + gamma, + alpha, + error_const, u: as Matrix>::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), statistics: BdfStatistics::default(), state: None, tstop: None, root_finder: None, + is_state_modified: false, } } } @@ -274,6 +297,31 @@ where } Ok(None) } + + fn initialise_to_first_order(&mut self) { + if self.state.as_ref().unwrap().y.len() != self.problem().unwrap().eqn.rhs().nstates() { + panic!("State vector length does not match number of states in problem"); + } + let nstates = self.ode_problem.as_ref().unwrap().eqn.rhs().nstates(); + let state = self.state.as_ref().unwrap(); + self.order = 1usize; + self.n_equal_steps = 0; + if self.diff.nrows() != nstates { + self.diff = M::zeros(nstates, Self::MAX_ORDER + 3); + self.diff_tmp = M::zeros(nstates, Self::MAX_ORDER + 3); + } + self.diff.column_mut(0).copy_from(&state.y); + self.diff.column_mut(1).copy_from(&state.f); + self.diff.column_mut(1).mul_assign(scale(state.h)); + + // setup U + self.u = Self::_compute_r(self.order, Eqn::T::one()); + + // update statistics + self.statistics.initial_step_size = state.h; + + self.is_state_modified = false; + } } impl, Eqn: OdeEquations, Nls> OdeSolverMethod @@ -283,6 +331,10 @@ where for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, { + fn order(&self) -> usize { + self.order + } + fn interpolate(&self, t: Eqn::T) -> Result { //interpolate solution at time values t* where t-h < t* < t // @@ -291,6 +343,14 @@ where // state must be set let state = self.state.as_ref().ok_or(anyhow!("State not set"))?; + if self.is_state_modified { + if t == state.t { + return Ok(state.y.clone()); + } else { + return Err(anyhow::anyhow!("Interpolation time is not within the current step. Step size is zero after calling state_mut()")); + } + } + // check that t is before the current time if t > state.t { return Err(anyhow!("Interpolation time is after current time")); @@ -313,71 +373,17 @@ where fn state(&self) -> Option<&OdeSolverState> { self.state.as_ref() } - - fn take_state(&mut self) -> Option::V>> { + fn take_state(&mut self) -> Option> { Option::take(&mut self.state) } + fn state_mut(&mut self) -> Option<&mut OdeSolverState> { + self.is_state_modified = true; + self.state.as_mut() + } + fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { - let mut state = state; self.ode_problem = Some(problem.clone()); - let nstates = problem.eqn.rhs().nstates(); - self.order = 1usize; - self.n_equal_steps = 0; - self.diff = M::zeros(nstates, Self::MAX_ORDER + 3); - self.diff_tmp = M::zeros(nstates, Self::MAX_ORDER + 3); - self.diff.column_mut(0).copy_from(&state.y); - - // kappa values for difference orders, taken from Table 1 of [1] - let kappa = [ - Eqn::T::from(0.0), - Eqn::T::from(-0.1850), - Eqn::T::from(-1.0) / Eqn::T::from(9.0), - Eqn::T::from(-0.0823), - Eqn::T::from(-0.0415), - Eqn::T::from(0.0), - ]; - self.alpha = vec![Eqn::T::zero()]; - self.gamma = vec![Eqn::T::zero()]; - self.error_const = vec![Eqn::T::one()]; - - #[allow(clippy::needless_range_loop)] - for i in 1..=Self::MAX_ORDER { - let i_t = Eqn::T::from(i as f64); - let one_over_i = Eqn::T::one() / i_t; - let one_over_i_plus_one = Eqn::T::one() / (i_t + Eqn::T::one()); - self.gamma.push(self.gamma[i - 1] + one_over_i); - self.alpha - .push(Eqn::T::one() / ((Eqn::T::one() - kappa[i]) * self.gamma[i])); - self.error_const - .push(kappa[i] * self.gamma[i] + one_over_i_plus_one); - } - - // update initial step size based on function - let mut scale_factor = state.y.abs(); - scale_factor *= scale(problem.rtol); - scale_factor += problem.atol.as_ref(); - - let f0 = problem.eqn.rhs().call(&state.y, state.t); - let hf0 = &f0 * scale(state.h); - let y1 = &state.y + &hf0; - let t1 = state.t + state.h; - let f1 = problem.eqn.rhs().call(&y1, t1); - - // store f1 in diff[1] for use in step size control - self.diff.column_mut(1).copy_from(&hf0); - - let mut df = f1 - f0; - df.component_div_assign(&scale_factor); - let d2 = df.norm(); - - let one_over_order_plus_one = - Eqn::T::one() / (Eqn::T::from(self.order as f64) + Eqn::T::one()); - let mut new_h = state.h * d2.pow(-one_over_order_plus_one); - if new_h > Eqn::T::from(100.0) * state.h { - new_h = Eqn::T::from(100.0) * state.h; - } - state.h = new_h; // setup linear solver for first step let bdf_callable = Rc::new(BdfCallable::new(problem)); @@ -386,13 +392,7 @@ where let nonlinear_problem = SolverProblem::new_from_ode_problem(bdf_callable, problem); self.nonlinear_solver.set_problem(&nonlinear_problem); - // setup U - self.u = Self::_compute_r(self.order, Eqn::T::one()); - - // update statistics - self.statistics.initial_step_size = state.h; - - // store state + // store state and setup root solver self.state = Some(state); if let Some(root_fn) = problem.eqn.root() { let state = self.state.as_ref().unwrap(); @@ -402,6 +402,9 @@ where .unwrap() .init(root_fn.as_ref(), &state.y, state.t); } + + // initialise solver to first order + self.initialise_to_first_order(); } fn step(&mut self) -> Result> { @@ -414,6 +417,10 @@ where return Err(anyhow!("State not set")); } + if self.is_state_modified { + self.initialise_to_first_order(); + } + let (mut y_predict, mut t_new) = self._predict_forward(); // loop until step is accepted @@ -601,3 +608,280 @@ where Ok(()) } } + +#[cfg(test)] +mod test { + use crate::{ + ode_solver::{ + test_models::dydt_y2::dydt_y2_problem, + test_models::exponential_decay::exponential_decay_problem, + test_models::exponential_decay::exponential_decay_problem_with_root, + test_models::exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem, + test_models::gaussian_decay::gaussian_decay_problem, + test_models::robertson::robertson, + test_models::robertson_ode::robertson_ode, + tests::{ + test_interpolate, test_no_set_problem, test_ode_solver, test_state_mut, + test_state_mut_on_problem, + }, + }, + Bdf, FaerLU, NalgebraLU, NewtonNonlinearSolver, OdeEquations, Op, + }; + + use num_traits::abs; + + type M = nalgebra::DMatrix; + #[test] + fn bdf_no_set_problem() { + test_no_set_problem::(Bdf::default()) + } + #[test] + fn bdf_state_mut() { + test_state_mut::(Bdf::default()) + } + #[test] + fn bdf_test_interpolate() { + test_interpolate::(Bdf::default()) + } + + #[test] + fn bdf_test_state_mut_exponential_decay() { + let (p, soln) = exponential_decay_problem::(false); + let s = Bdf::default(); + test_state_mut_on_problem(s, p, soln); + } + + #[test] + fn bdf_test_nalgebra_exponential_decay() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = exponential_decay_problem::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 8 + number_of_steps: 25 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 50 + number_of_nonlinear_solver_fails: 0 + initial_step_size: 0.001189207115002721 + final_step_size: 0.9861196765479318 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 52 + number_of_jac_muls: 2 + number_of_matrix_evals: 1 + "###); + } + + #[test] + fn bdf_test_faer_exponential_decay() { + type M = faer::Mat; + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(FaerLU::default()); + let (problem, soln) = exponential_decay_problem::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 8 + number_of_steps: 25 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 50 + number_of_nonlinear_solver_fails: 0 + initial_step_size: 0.001189207115002721 + final_step_size: 0.9861196765889989 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 52 + number_of_jac_muls: 2 + number_of_matrix_evals: 1 + "###); + } + + #[test] + fn test_bdf_nalgebra_exponential_decay_algebraic() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = exponential_decay_with_algebraic_problem::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 11 + number_of_steps: 17 + number_of_error_test_failures: 4 + number_of_nonlinear_solver_iterations: 42 + number_of_nonlinear_solver_fails: 0 + initial_step_size: 0.00014907855910877986 + final_step_size: 0.2008052778053449 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 46 + number_of_jac_muls: 4 + number_of_matrix_evals: 1 + "###); + } + + #[test] + fn test_bdf_nalgebra_robertson() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = robertson::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 103 + number_of_steps: 352 + number_of_error_test_failures: 3 + number_of_nonlinear_solver_iterations: 1003 + number_of_nonlinear_solver_fails: 21 + initial_step_size: 0.000000005427827356796531 + final_step_size: 5943224095.574959 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 1006 + number_of_jac_muls: 55 + number_of_matrix_evals: 18 + "###); + } + + #[test] + fn test_bdf_nalgebra_robertson_colored() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = robertson::(true); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 103 + number_of_steps: 352 + number_of_error_test_failures: 3 + number_of_nonlinear_solver_iterations: 1003 + number_of_nonlinear_solver_fails: 21 + initial_step_size: 0.000000005427827356796531 + final_step_size: 5943224095.574959 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 1006 + number_of_jac_muls: 58 + number_of_matrix_evals: 18 + "###); + } + + #[test] + fn test_bdf_nalgebra_robertson_ode() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = robertson_ode::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 94 + number_of_steps: 340 + number_of_error_test_failures: 2 + number_of_nonlinear_solver_iterations: 950 + number_of_nonlinear_solver_fails: 15 + initial_step_size: 0.000000004564240566951627 + final_step_size: 6155729544.745563 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 952 + number_of_jac_muls: 48 + number_of_matrix_evals: 16 + "###); + } + + #[test] + fn test_bdf_nalgebra_dydt_y2() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = dydt_y2_problem::(false, 10); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 45 + number_of_steps: 192 + number_of_error_test_failures: 1 + number_of_nonlinear_solver_iterations: 538 + number_of_nonlinear_solver_fails: 3 + initial_step_size: 0.00000028403960645516395 + final_step_size: 1.0749050435964294 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 540 + number_of_jac_muls: 40 + number_of_matrix_evals: 4 + "###); + } + + #[test] + fn test_bdf_nalgebra_dydt_y2_colored() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = dydt_y2_problem::(true, 10); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 45 + number_of_steps: 192 + number_of_error_test_failures: 1 + number_of_nonlinear_solver_iterations: 538 + number_of_nonlinear_solver_fails: 3 + initial_step_size: 0.00000028403960645516395 + final_step_size: 1.0749050435964294 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 540 + number_of_jac_muls: 14 + number_of_matrix_evals: 4 + "###); + } + + #[test] + fn test_bdf_nalgebra_gaussian_decay() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = gaussian_decay_problem::(false, 10); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 16 + number_of_steps: 60 + number_of_error_test_failures: 3 + number_of_nonlinear_solver_iterations: 165 + number_of_nonlinear_solver_fails: 0 + initial_step_size: 0.00009999999999999999 + final_step_size: 0.19565537798887184 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 167 + number_of_jac_muls: 10 + number_of_matrix_evals: 1 + "###); + } + + #[test] + fn test_tstop_bdf() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = exponential_decay_problem::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, true); + } + + #[test] + fn test_root_finder_bdf() { + let mut s = Bdf::default(); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = exponential_decay_problem_with_root::(false); + let y = test_ode_solver(&mut s, rs, &problem, soln, None, false); + assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); + } +} diff --git a/src/ode_solver/bdf/faer.rs b/src/ode_solver/bdf/faer.rs deleted file mode 100644 index 4a6b7ac7..00000000 --- a/src/ode_solver/bdf/faer.rs +++ /dev/null @@ -1,21 +0,0 @@ -#[cfg(test)] -mod test { - use crate::{ - ode_solver::tests::{test_interpolate, test_no_set_problem, test_take_state}, - Bdf, - }; - - type M = faer::Mat; - #[test] - fn bdf_no_set_problem() { - test_no_set_problem::(Bdf::::default()) - } - #[test] - fn bdf_take_state() { - test_take_state::(Bdf::::default()) - } - #[test] - fn bdf_test_interpolate() { - test_interpolate::(Bdf::::default()) - } -} diff --git a/src/ode_solver/bdf/nalgebra.rs b/src/ode_solver/bdf/nalgebra.rs deleted file mode 100644 index ed53a9b1..00000000 --- a/src/ode_solver/bdf/nalgebra.rs +++ /dev/null @@ -1,21 +0,0 @@ -#[cfg(test)] -mod test { - use crate::{ - ode_solver::tests::{test_interpolate, test_no_set_problem, test_take_state}, - Bdf, - }; - - type M = nalgebra::DMatrix; - #[test] - fn bdf_no_set_problem() { - test_no_set_problem::(Bdf::default()) - } - #[test] - fn bdf_take_state() { - test_take_state::(Bdf::default()) - } - #[test] - fn bdf_test_interpolate() { - test_interpolate::(Bdf::default()) - } -} diff --git a/src/ode_solver/builder.rs b/src/ode_solver/builder.rs index 953778bd..bb181e77 100644 --- a/src/ode_solver/builder.rs +++ b/src/ode_solver/builder.rs @@ -53,7 +53,7 @@ impl Default for OdeBuilder { /// /// let mut solver = Bdf::default(); /// let t = 0.4; -/// let state = OdeSolverState::new(&problem); +/// let mut state = OdeSolverState::new(&problem, &solver).unwrap(); /// solver.set_problem(state, &problem); /// while solver.state().unwrap().t <= t { /// solver.step().unwrap(); diff --git a/src/ode_solver/diffsl.rs b/src/ode_solver/diffsl.rs index 94bb6800..e8a25091 100644 --- a/src/ode_solver/diffsl.rs +++ b/src/ode_solver/diffsl.rs @@ -37,7 +37,7 @@ pub type M = nalgebra::DMatrix; /// .build_diffsl(&context).unwrap(); /// let mut solver = Bdf::default(); /// let t = 0.4; -/// let state = OdeSolverState::new(&problem); +/// let state = OdeSolverState::new(&problem, &solver).unwrap(); /// solver.set_problem(state, &problem); /// while solver.state().unwrap().t <= t { /// solver.step().unwrap(); diff --git a/src/ode_solver/equations.rs b/src/ode_solver/equations.rs index 255fde9c..f789887d 100644 --- a/src/ode_solver/equations.rs +++ b/src/ode_solver/equations.rs @@ -126,7 +126,7 @@ pub trait OdeEquations { /// /// let mut solver = Bdf::default(); /// let t = 0.4; -/// let state = OdeSolverState::new(&problem); +/// let state = OdeSolverState::new(&problem, &solver).unwrap(); /// solver.set_problem(state, &problem); /// while solver.state().unwrap().t <= t { /// solver.step().unwrap(); diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index fbe5e4ce..aaca443e 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -1,9 +1,10 @@ use anyhow::Result; -use num_traits::Zero; +use num_traits::{One, Pow, Zero}; use std::rc::Rc; use crate::{ - op::filter::FilterCallable, scalar::Scalar, LinearOp, Matrix, NonLinearSolver, OdeEquations, + matrix::default_solver::DefaultSolver, op::filter::FilterCallable, scalar::Scalar, scale, + LinearOp, Matrix, NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquations, OdeSolverProblem, SolverProblem, Vector, VectorIndex, }; @@ -21,10 +22,14 @@ pub enum OdeSolverStopReason { /// # Example /// /// ``` -/// use diffsol::{ OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeEquations }; +/// use diffsol::{ OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeEquations, DefaultSolver }; /// -/// fn solve_ode(solver: &mut impl OdeSolverMethod, problem: &OdeSolverProblem, t: Eqn::T) -> Eqn::V { -/// let state = OdeSolverState::new(problem); +/// fn solve_ode(solver: &mut impl OdeSolverMethod, problem: &OdeSolverProblem, t: Eqn::T) -> Eqn::V +/// where +/// Eqn: OdeEquations, +/// Eqn::M: DefaultSolver, +/// { +/// let state = OdeSolverState::new(problem, solver).unwrap(); /// solver.set_problem(state, problem); /// while solver.state().unwrap().t <= t { /// solver.step().unwrap(); @@ -37,7 +42,8 @@ pub trait OdeSolverMethod { fn problem(&self) -> Option<&OdeSolverProblem>; /// Set the problem to solve, this performs any initialisation required by the solver. Call this before calling `step` or `solve`. - /// The solver takes ownership of the initial state given by `state`, this is assumed to be consistent with any algebraic constraints. + /// The solver takes ownership of the initial state given by `state`, this is assumed to be consistent with any algebraic constraints, + /// and the time step `h` is assumed to be set appropriately for the problem fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem); /// Step the solution forward by one step, altering the internal state of the solver. @@ -57,6 +63,14 @@ pub trait OdeSolverMethod { /// Get the current state of the solver, if it exists fn state(&self) -> Option<&OdeSolverState>; + /// Get a mutable reference to the current state of the solver, if it exists + /// Note that calling this will cause the next call to `step` to perform some reinitialisation to take into + /// account the mutated state, this could be expensive for multi-step methods. + fn state_mut(&mut self) -> Option<&mut OdeSolverState>; + + /// Get the current order of accuracy of the solver (e.g. explict euler method is first-order) + fn order(&self) -> usize; + /// Take the current state of the solver, if it exists, returning it to the user. This is useful if you want to use this /// state in another solver or problem. Note that this will unset the current problem and solver state, so you will need to call /// `set_problem` again before calling `step` or `solve`. @@ -64,7 +78,8 @@ pub trait OdeSolverMethod { /// Reinitialise the solver state and solve the problem up to time `t` fn solve(&mut self, problem: &OdeSolverProblem, t: Eqn::T) -> Result { - let state = OdeSolverState::new(problem); + let mut state = OdeSolverState::new_without_initialise(problem); + state.set_step_size(problem, self.order()); self.set_problem(state, problem); self.set_stop_time(t)?; loop { @@ -82,7 +97,9 @@ pub trait OdeSolverMethod { t: Eqn::T, root_solver: &mut RS, ) -> Result { - let state = OdeSolverState::new_consistent(problem, root_solver)?; + let mut state = OdeSolverState::new_without_initialise(problem); + state.set_consistent(problem, root_solver)?; + state.set_step_size(problem, self.order()); self.set_problem(state, problem); self.set_stop_time(t)?; loop { @@ -98,51 +115,161 @@ pub trait OdeSolverMethod { #[derive(Clone)] pub struct OdeSolverState { pub y: V, + pub f: V, pub t: V::T, pub h: V::T, } impl OdeSolverState { - /// Create a new solver state from an ODE problem. Note that this does not make the state consistent with any algebraic constraints. - /// If you need to make the state consistent, use [Self::new_consistent]. - pub fn new(ode_problem: &OdeSolverProblem) -> Self + /// Create a new solver state from an ODE problem. + /// This function will make the state consistent with any algebraic constraints using a default nonlinear solver. + /// It will also set the initial step size based on the given solver. + /// If you want to create a state without this default initialisation, use [Self::new_without_initialise] instead. + /// You can then use [Self::set_consistent] and [Self::set_step_size] to set the state up if you need to. + pub fn new(ode_problem: &OdeSolverProblem, solver: &S) -> Result + where + Eqn: OdeEquations, + Eqn::M: DefaultSolver, + S: OdeSolverMethod, + { + let t = ode_problem.t0; + let h = ode_problem.h0; + let y = ode_problem.eqn.init(t); + let f = ode_problem.eqn.rhs().call(&y, t); + + let mut ret = Self { y, t, h, f }; + let ls = ::default_solver(); + let mut root_solver = NewtonNonlinearSolver::new(ls); + ret.set_consistent(ode_problem, &mut root_solver)?; + ret.set_step_size(ode_problem, solver.order()); + Ok(ret) + } + + /// Create a new solver state from an ODE problem, without any initialisation. + /// This is useful if you want to set up the state yourself, or if you want to use a different nonlinear solver to make the state consistent, + /// or if you want to set the step size yourself or based on the exact order of the solver. + pub fn new_without_initialise(ode_problem: &OdeSolverProblem) -> Self where Eqn: OdeEquations, { let t = ode_problem.t0; let h = ode_problem.h0; let y = ode_problem.eqn.init(t); - Self { y, t, h } + let f = ode_problem.eqn.rhs().call(&y, t); + Self { y, t, h, f } } - /// Create a new solver state from an [OdeSolverProblem], making the state consistent with the algebraic constraints using a solver that implements [NonLinearSolver]. - /// If there are no algebraic constraints, please use [Self::new] instead. - pub fn new_consistent( + /// Set the state to be consistent with the algebraic constraints of the problem. + pub fn set_consistent( + &mut self, ode_problem: &OdeSolverProblem, root_solver: &mut S, - ) -> Result + ) -> Result<()> where Eqn: OdeEquations, S: NonLinearSolver> + ?Sized, { - let t = ode_problem.t0; - let h = ode_problem.h0; - let mass_diagonal = ode_problem.eqn.mass().matrix(t).diagonal(); + let mass_diagonal = ode_problem.eqn.mass().matrix(self.t).diagonal(); let indices = mass_diagonal.filter_indices(|x| x == Eqn::T::zero()); - let mut y = ode_problem.eqn.init(t); if indices.len() == 0 { - return Ok(Self { y, t, h }); + return Ok(()); } - let mut y_filtered = y.filter(&indices); + let mut y_filtered = self.y.filter(&indices); let atol = Rc::new(ode_problem.atol.as_ref().filter(&indices)); let rhs = ode_problem.eqn.rhs().clone(); - let f = Rc::new(FilterCallable::new(rhs, &y, indices)); + let f = Rc::new(FilterCallable::new(rhs, &self.y, indices)); let rtol = ode_problem.rtol; let init_problem = SolverProblem::new(f, atol, rtol); root_solver.set_problem(&init_problem); - root_solver.solve_in_place(&mut y_filtered, t)?; + root_solver.solve_in_place(&mut y_filtered, self.t)?; let indices = init_problem.f.indices(); - y.scatter_from(&y_filtered, indices); - Ok(Self { y, t, h }) + self.y.scatter_from(&y_filtered, indices); + Ok(()) + } + + /// compute size of first step based on alg in Hairer, Norsett, Wanner + /// Solving Ordinary Differential Equations I, Nonstiff Problems + /// Section II.4.2 + pub fn set_step_size(&mut self, ode_problem: &OdeSolverProblem, solver_order: usize) + where + Eqn: OdeEquations, + { + let y0 = &self.y; + let t0 = self.t; + let f0 = &self.f; + + let mut scale_factor = y0.abs(); + scale_factor *= scale(ode_problem.rtol); + scale_factor += ode_problem.atol.as_ref(); + + let mut tmp = f0.clone(); + tmp.component_div_assign(&scale_factor); + let d0 = tmp.norm(); + + tmp = y0.clone(); + tmp.component_div_assign(&scale_factor); + let d1 = f0.norm(); + + let h0 = if d0 < Eqn::T::from(1e-5) || d1 < Eqn::T::from(1e-5) { + Eqn::T::from(1e-6) + } else { + Eqn::T::from(0.01) * (d0 / d1) + }; + + let y1 = f0.clone() * scale(h0) + y0; + let t1 = t0 + h0; + let f1 = ode_problem.eqn.rhs().call(&y1, t1); + + let mut df = f1 - f0; + df *= scale(Eqn::T::one() / h0); + df.component_div_assign(&scale_factor); + let d2 = df.norm(); + + let mut max_d = d2; + if max_d < d1 { + max_d = d1; + } + let h1 = if max_d < Eqn::T::from(1e-15) { + let h1 = h0 * Eqn::T::from(1e-3); + if h1 < Eqn::T::from(1e-6) { + Eqn::T::from(1e-6) + } else { + h1 + } + } else { + (Eqn::T::from(0.01) / max_d) + .pow(Eqn::T::one() / Eqn::T::from(1.0 + solver_order as f64)) + }; + + self.h = Eqn::T::from(100.0) * h0; + if self.h > h1 { + self.h = h1; + } + + // update initial step size based on function + //let mut scale_factor = state.y.abs(); + //scale_factor *= scale(problem.rtol); + //scale_factor += problem.atol.as_ref(); + + //let f0 = problem.eqn.rhs().call(&state.y, state.t); + //let hf0 = &f0 * scale(state.h); + //let y1 = &state.y + &hf0; + //let t1 = state.t + state.h; + //let f1 = problem.eqn.rhs().call(&y1, t1); + + //// store f1 in diff[1] for use in step size control + //self.diff.column_mut(1).copy_from(&hf0); + + //let mut df = f1 - f0; + //df.component_div_assign(&scale_factor); + //let d2 = df.norm(); + + //let one_over_order_plus_one = + // Eqn::T::one() / (Eqn::T::from(self.order as f64) + Eqn::T::one()); + //let mut new_h = state.h * d2.pow(-one_over_order_plus_one); + //if new_h > Eqn::T::from(100.0) * state.h { + // new_h = Eqn::T::from(100.0) * state.h; + //} + //state.h = new_h; } } diff --git a/src/ode_solver/mod.rs b/src/ode_solver/mod.rs index 178d5974..a776054e 100644 --- a/src/ode_solver/mod.rs +++ b/src/ode_solver/mod.rs @@ -18,33 +18,22 @@ mod tests { use std::rc::Rc; use self::problem::OdeSolverSolution; - use self::test_models::exponential_decay::exponential_decay_problem_with_root; - use super::test_models::{ - exponential_decay::exponential_decay_problem, - exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem, - robertson::robertson, robertson_ode::robertson_ode, - }; use super::*; - use crate::linear_solver::nalgebra::lu::LU; use crate::matrix::Matrix; - use crate::nonlinear_solver::newton::NewtonNonlinearSolver; use crate::op::filter::FilterCallable; use crate::op::unit::UnitCallable; use crate::op::{NonLinearOp, Op}; use crate::scalar::scale; + use crate::Vector; use crate::{ NonLinearSolver, OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, }; - use crate::{Sdirk, Tableau, Vector}; + use num_traits::One; use num_traits::Zero; - use num_traits::{abs, One}; - use tests::bdf::Bdf; - use tests::test_models::dydt_y2::dydt_y2_problem; - use tests::test_models::gaussian_decay::gaussian_decay_problem; - fn test_ode_solver( + pub fn test_ode_solver( method: &mut impl OdeSolverMethod, mut root_solver: impl NonLinearSolver>, problem: &OdeSolverProblem, @@ -56,7 +45,9 @@ mod tests { M: Matrix, Eqn: OdeEquations, { - let state = OdeSolverState::new_consistent(problem, &mut root_solver).unwrap(); + let mut state = OdeSolverState::new_without_initialise(problem); + state.set_consistent(problem, &mut root_solver).unwrap(); + state.set_step_size(problem, method.order()); method.set_problem(state, problem); let have_root = problem.eqn.as_ref().root().is_some(); for point in solution.solution_points.iter() { @@ -107,411 +98,6 @@ mod tests { method.state().unwrap().y.clone() } - type Mcpu = nalgebra::DMatrix; - - #[test] - fn test_tr_bdf2_nalgebra_exponential_decay() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, LU::default()); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 27 - number_of_steps: 27 - number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 0 - number_of_nonlinear_solver_fails: 0 - initial_step_size: 0.1919383103666485 - final_step_size: 0.37881820951293194 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 110 - number_of_jac_muls: 2 - number_of_matrix_evals: 1 - "###); - } - - #[test] - fn test_esdirk34_nalgebra_exponential_decay() { - let tableau = Tableau::::esdirk34(); - let mut s = Sdirk::new(tableau, LU::default()); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 12 - number_of_steps: 12 - number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 0 - number_of_nonlinear_solver_fails: 0 - initial_step_size: 0.28998214001102113 - final_step_size: 0.9543072149538415 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 74 - number_of_jac_muls: 2 - number_of_matrix_evals: 1 - "###); - } - - #[test] - fn test_bdf_nalgebra_exponential_decay() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 19 - number_of_steps: 31 - number_of_error_test_failures: 8 - number_of_nonlinear_solver_iterations: 78 - number_of_nonlinear_solver_fails: 0 - initial_step_size: 0.011892071150027213 - final_step_size: 0.9795994412020951 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 80 - number_of_jac_muls: 2 - number_of_matrix_evals: 1 - "###); - } - - #[cfg(feature = "sundials")] - #[test] - fn test_sundials_exponential_decay() { - let mut s = crate::SundialsIda::default(); - let rs = NewtonNonlinearSolver::new(crate::SundialsLinearSolver::new_dense()); - let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 18 - number_of_steps: 43 - number_of_error_test_failures: 3 - number_of_nonlinear_solver_iterations: 63 - number_of_nonlinear_solver_fails: 0 - initial_step_size: 0.001 - final_step_size: 0.7770043351266953 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 63 - number_of_jac_muls: 36 - number_of_matrix_evals: 18 - "###); - } - - #[test] - fn test_bdf_nalgebra_exponential_decay_algebraic() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = exponential_decay_with_algebraic_problem::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 17 - number_of_steps: 21 - number_of_error_test_failures: 8 - number_of_nonlinear_solver_iterations: 58 - number_of_nonlinear_solver_fails: 0 - initial_step_size: 0.004450050658086208 - final_step_size: 0.20995860176773154 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 62 - number_of_jac_muls: 4 - number_of_matrix_evals: 1 - "###); - } - - #[test] - fn test_tr_bdf2_nalgebra_robertson() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, LU::default()); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = robertson::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 433 - number_of_steps: 415 - number_of_error_test_failures: 6 - number_of_nonlinear_solver_iterations: 0 - number_of_nonlinear_solver_fails: 12 - initial_step_size: 0.0011378590984747281 - final_step_size: 35000974461.348206 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 3157 - number_of_jac_muls: 40 - number_of_matrix_evals: 13 - "###); - } - - #[test] - fn test_esdirk34_nalgebra_robertson() { - let tableau = Tableau::::esdirk34(); - let mut s = Sdirk::new(tableau, LU::default()); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = robertson::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 370 - number_of_steps: 346 - number_of_error_test_failures: 3 - number_of_nonlinear_solver_iterations: 0 - number_of_nonlinear_solver_fails: 21 - initial_step_size: 0.00619535739618413 - final_step_size: 57384898746.15714 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 3406 - number_of_jac_muls: 58 - number_of_matrix_evals: 19 - "###); - } - - #[test] - fn test_bdf_nalgebra_robertson() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = robertson::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 106 - number_of_steps: 345 - number_of_error_test_failures: 5 - number_of_nonlinear_solver_iterations: 985 - number_of_nonlinear_solver_fails: 22 - initial_step_size: 0.0000045643545698038086 - final_step_size: 5435491162.573224 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 988 - number_of_jac_muls: 55 - number_of_matrix_evals: 18 - "###); - } - - #[cfg(feature = "sundials")] - #[test] - fn test_sundials_robertson() { - let mut s = crate::SundialsIda::default(); - let rs = NewtonNonlinearSolver::new(crate::SundialsLinearSolver::new_dense()); - let (problem, soln) = robertson::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 59 - number_of_steps: 355 - number_of_error_test_failures: 15 - number_of_nonlinear_solver_iterations: 506 - number_of_nonlinear_solver_fails: 5 - initial_step_size: 0.001 - final_step_size: 11535117835.253025 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 507 - number_of_jac_muls: 178 - number_of_matrix_evals: 59 - "###); - } - - #[test] - fn test_bdf_nalgebra_robertson_colored() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = robertson::(true); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 106 - number_of_steps: 345 - number_of_error_test_failures: 5 - number_of_nonlinear_solver_iterations: 985 - number_of_nonlinear_solver_fails: 22 - initial_step_size: 0.0000045643545698038086 - final_step_size: 5435491162.573224 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 988 - number_of_jac_muls: 58 - number_of_matrix_evals: 18 - "###); - } - - #[test] - fn test_tr_bdf2_nalgebra_robertson_ode() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, LU::default()); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = robertson_ode::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 242 - number_of_steps: 230 - number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 0 - number_of_nonlinear_solver_fails: 12 - initial_step_size: 0.0010137172178872197 - final_step_size: 45212162967.124176 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 2368 - number_of_jac_muls: 39 - number_of_matrix_evals: 13 - "###); - } - - #[test] - fn test_bdf_nalgebra_robertson_ode() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = robertson_ode::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 106 - number_of_steps: 345 - number_of_error_test_failures: 5 - number_of_nonlinear_solver_iterations: 981 - number_of_nonlinear_solver_fails: 22 - initial_step_size: 0.0000038381494276795106 - final_step_size: 5636682847.540523 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 983 - number_of_jac_muls: 54 - number_of_matrix_evals: 18 - "###); - } - - #[test] - fn test_bdf_nalgebra_dydt_y2() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = dydt_y2_problem::(false, 10); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 65 - number_of_steps: 205 - number_of_error_test_failures: 10 - number_of_nonlinear_solver_iterations: 593 - number_of_nonlinear_solver_fails: 7 - initial_step_size: 0.0000019982428436469115 - final_step_size: 1.0781694150073 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 595 - number_of_jac_muls: 60 - number_of_matrix_evals: 6 - "###); - } - - #[test] - fn test_bdf_nalgebra_dydt_y2_colored() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = dydt_y2_problem::(true, 10); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 65 - number_of_steps: 205 - number_of_error_test_failures: 10 - number_of_nonlinear_solver_iterations: 593 - number_of_nonlinear_solver_fails: 7 - initial_step_size: 0.0000019982428436469115 - final_step_size: 1.0781694150073 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 595 - number_of_jac_muls: 16 - number_of_matrix_evals: 6 - "###); - } - - #[test] - fn test_bdf_nalgebra_gaussian_decay() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = gaussian_decay_problem::(false, 10); - test_ode_solver(&mut s, rs, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 14 - number_of_steps: 58 - number_of_error_test_failures: 2 - number_of_nonlinear_solver_iterations: 159 - number_of_nonlinear_solver_fails: 0 - initial_step_size: 0.0025148668593658707 - final_step_size: 0.19566316816600493 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 161 - number_of_jac_muls: 10 - number_of_matrix_evals: 1 - "###); - } - - #[test] - fn test_tstop_tr_bdf2() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, LU::default()); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, true); - } - - #[test] - fn test_tstop_bdf() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, rs, &problem, soln, None, true); - } - - #[test] - fn test_root_finder_tr_bdf2() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, LU::default()); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = exponential_decay_problem_with_root::(false); - let y = test_ode_solver(&mut s, rs, &problem, soln, None, false); - assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); - } - - #[test] - fn test_root_finder_bdf() { - let mut s = Bdf::default(); - let rs = NewtonNonlinearSolver::new(LU::default()); - let (problem, soln) = exponential_decay_problem_with_root::(false); - let y = test_ode_solver(&mut s, rs, &problem, soln, None, false); - assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); - } - pub struct TestEqnRhs { _m: std::marker::PhantomData, } @@ -593,7 +179,7 @@ mod tests { M::T::zero(), M::T::one(), ); - let state = OdeSolverState::new(&problem); + let state = OdeSolverState::new_without_initialise(&problem); s.set_problem(state.clone(), &problem); let t0 = M::T::zero(); let t1 = M::T::one(); @@ -609,12 +195,12 @@ mod tests { pub fn test_no_set_problem>>(mut s: Method) { assert!(s.state().is_none()); assert!(s.problem().is_none()); - assert!(s.take_state().is_none()); + assert!(s.state().is_none()); assert!(s.step().is_err()); assert!(s.interpolate(M::T::one()).is_err()); } - pub fn test_take_state>>(mut s: Method) { + pub fn test_state_mut>>(mut s: Method) { let problem = OdeSolverProblem::new( TestEqn::new(), M::T::from(1e-6), @@ -622,13 +208,47 @@ mod tests { M::T::zero(), M::T::one(), ); - let state = OdeSolverState::new(&problem); + let state = OdeSolverState::new_without_initialise(&problem); s.set_problem(state.clone(), &problem); - let state2 = s.take_state().unwrap(); + let state2 = s.state().unwrap(); state2.y.assert_eq_st(&state.y, M::T::from(1e-9)); - assert!(s.take_state().is_none()); - assert!(s.state().is_none()); - assert!(s.step().is_err()); - assert!(s.interpolate(M::T::one()).is_err()); + s.state_mut().unwrap().y[0] = M::T::from(std::f64::consts::PI); + assert_eq!(s.state().unwrap().y[0], M::T::from(std::f64::consts::PI)); + } + + pub fn test_state_mut_on_problem>( + mut s: Method, + problem: OdeSolverProblem, + soln: OdeSolverSolution, + ) { + // solve for a little bit + s.solve(&problem, Eqn::T::from(1.0)).unwrap(); + + // reinit using state_mut + let state = OdeSolverState::new_without_initialise(&problem); + s.state_mut().unwrap().y.copy_from(&state.y); + s.state_mut().unwrap().t = state.t; + + // solve and check against solution + for point in soln.solution_points.iter() { + while s.state().unwrap().t < point.t { + s.step().unwrap(); + } + let soln = s.interpolate(point.t).unwrap(); + + let scale = { + let problem = s.problem().unwrap(); + point.state.abs() * scale(problem.rtol) + problem.atol.as_ref() + }; + let mut error = soln.clone() - &point.state; + error.component_div_assign(&scale); + let error_norm = error.norm() / Eqn::T::from((point.state.len() as f64).sqrt()); + assert!( + error_norm < Eqn::T::from(15.0), + "error_norm: {} at t = {}", + error_norm, + point.t + ); + } } } diff --git a/src/ode_solver/sdirk.rs b/src/ode_solver/sdirk.rs index 4cdd34d4..a63404b5 100644 --- a/src/ode_solver/sdirk.rs +++ b/src/ode_solver/sdirk.rs @@ -1,3 +1,4 @@ +use anyhow::anyhow; use anyhow::Result; use num_traits::abs; use num_traits::One; @@ -8,6 +9,7 @@ use std::rc::Rc; use crate::matrix::MatrixRef; use crate::vector::VectorRef; +use crate::LinearSolver; use crate::NewtonNonlinearSolver; use crate::OdeSolverStopReason; use crate::RootFinder; @@ -17,7 +19,6 @@ use crate::{ DenseMatrix, OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Op, Scalar, Vector, VectorViewMut, }; -use crate::{LinearSolver, NonLinearOp}; use super::bdf::BdfStatistics; @@ -47,11 +48,11 @@ where old_t: Eqn::T, old_y: Eqn::V, old_f: Eqn::V, - f: Eqn::V, a_rows: Vec, statistics: BdfStatistics, root_finder: Option>, tstop: Option, + is_state_mutated: bool, } impl Sdirk @@ -142,7 +143,6 @@ where let old_t = Eqn::T::zero(); let old_y = ::zeros(n); let old_f = ::zeros(n); - let f = ::zeros(n); let statistics = BdfStatistics::default(); Self { tableau, @@ -156,10 +156,10 @@ where old_y, a_rows, old_f, - f, statistics, root_finder: None, tstop: None, + is_state_mutated: false, } } @@ -205,66 +205,15 @@ where self.problem.as_ref() } - fn set_problem( - &mut self, - mut state: OdeSolverState<::V>, - problem: &OdeSolverProblem, - ) { - // update initial step size based on function - let mut scale_factor = state.y.abs(); - scale_factor *= scale(problem.rtol); - scale_factor += problem.atol.as_ref(); - - // compute first step based on alg in Hairer, Norsett, Wanner - // Solving Ordinary Differential Equations I, Nonstiff Problems - // Section II.4.2 - let f0 = problem.eqn.rhs().call(&state.y, state.t); - let hf0 = &f0 * scale(state.h); - - let mut tmp = f0.clone(); - tmp.component_div_assign(&scale_factor); - let d0 = tmp.norm(); - - tmp = state.y.clone(); - tmp.component_div_assign(&scale_factor); - let d1 = f0.norm(); - - let h0 = if d0 < Eqn::T::from(1e-5) || d1 < Eqn::T::from(1e-5) { - Eqn::T::from(1e-6) - } else { - Eqn::T::from(0.01) * (d0 / d1) - }; - - let y1 = &state.y + hf0; - let t1 = state.t + h0; - let f1 = problem.eqn.rhs().call(&y1, t1); - - let mut df = f1 - &f0; - df *= scale(Eqn::T::one() / h0); - df.component_div_assign(&scale_factor); - let d2 = df.norm(); - - let mut max_d = d2; - if max_d < d1 { - max_d = d1; - } - let h1 = if max_d < Eqn::T::from(1e-15) { - let h1 = h0 * Eqn::T::from(1e-3); - if h1 < Eqn::T::from(1e-6) { - Eqn::T::from(1e-6) - } else { - h1 - } - } else { - (Eqn::T::from(0.01) / max_d) - .pow(Eqn::T::one() / Eqn::T::from(1.0 + self.tableau.order() as f64)) - }; + fn order(&self) -> usize { + self.tableau.order() + } - state.h = Eqn::T::from(100.0) * h0; - if state.h > h1 { - state.h = h1; - } + fn take_state(&mut self) -> Option> { + Option::take(&mut self.state) + } + fn set_problem(&mut self, state: OdeSolverState<::V>, problem: &OdeSolverProblem) { // setup linear solver for first step let callable = Rc::new(SdirkCallable::new(problem, self.gamma)); callable.set_h(state.h); @@ -276,8 +225,7 @@ where self.statistics.initial_step_size = state.h; self.diff = M::zeros(state.y.len(), self.tableau.s()); - self.old_f = f0.clone(); - self.f = f0; + self.old_f = state.f.clone(); self.old_t = state.t; self.old_y = state.y.clone(); self.state = Some(state); @@ -294,6 +242,9 @@ where fn step(&mut self) -> Result> { // optionally do the first step + if self.state.is_none() { + return Err(anyhow!("State not set")); + } let state = self.state.as_mut().unwrap(); let n = state.y.len(); let y0 = &state.y; @@ -310,7 +261,7 @@ where // if start == 1, then we need to compute the first stage if start == 1 { let mut hf = self.diff.column_mut(0); - hf.copy_from(&self.f); + hf.copy_from(&state.f); hf *= scale(state.h); } for i in start..self.tableau.s() { @@ -451,7 +402,7 @@ where self.old_f .copy_from_view(&self.diff.column(self.diff.ncols() - 1)); self.old_f.mul_assign(scale(Eqn::T::one() / dt)); - std::mem::swap(&mut self.old_f, &mut self.f); + std::mem::swap(&mut self.old_f, &mut state.f); { let y1 = self.nonlinear_solver.problem().f.get_last_f_eval(); @@ -459,6 +410,8 @@ where std::mem::swap(&mut self.old_y, &mut state.y); } + self.is_state_mutated = false; + // update statistics self.statistics.number_of_linear_solver_setups = self.nonlinear_solver.problem().f.number_of_jac_evals(); @@ -502,7 +455,18 @@ where } fn interpolate(&self, t: ::T) -> anyhow::Result<::V> { - let state = self.state.as_ref().expect("State not set"); + if self.state.is_none() { + return Err(anyhow!("State not set")); + } + let state = self.state.as_ref().unwrap(); + + if self.is_state_mutated { + if t == state.t { + return Ok(state.y.clone()); + } else { + return Err(anyhow::anyhow!("Interpolation time is not within the current step. Step size is zero after calling state_mut()")); + } + } // check that t is within the current step if t > state.t || t < self.old_t { @@ -550,11 +514,251 @@ where } } - fn state(&self) -> Option<&OdeSolverState<::V>> { + fn state(&self) -> Option<&OdeSolverState> { self.state.as_ref() } - fn take_state(&mut self) -> Option::V>> { - Option::take(&mut self.state) + fn state_mut(&mut self) -> Option<&mut OdeSolverState> { + self.is_state_mutated = true; + self.state.as_mut() + } +} + +#[cfg(test)] +mod test { + use crate::{ + ode_solver::{ + test_models::{ + exponential_decay::exponential_decay_problem, + exponential_decay::exponential_decay_problem_with_root, robertson::robertson, + robertson_ode::robertson_ode, + }, + tests::{ + test_interpolate, test_no_set_problem, test_ode_solver, test_state_mut, + test_state_mut_on_problem, + }, + }, + NalgebraLU, NewtonNonlinearSolver, OdeEquations, Op, Sdirk, Tableau, + }; + + use num_traits::abs; + + type M = nalgebra::DMatrix; + #[test] + fn sdirk_no_set_problem() { + let tableau = Tableau::::tr_bdf2(); + test_no_set_problem::(Sdirk::::new(tableau, NalgebraLU::default())); + } + #[test] + fn sdirk_state_mut() { + let tableau = Tableau::::tr_bdf2(); + test_state_mut::(Sdirk::::new(tableau, NalgebraLU::default())); + } + #[test] + fn sdirk_test_interpolate() { + let tableau = Tableau::::tr_bdf2(); + test_interpolate::(Sdirk::::new(tableau, NalgebraLU::default())); + } + + #[test] + fn sdirk_test_state_mut_exponential_decay() { + let (p, soln) = exponential_decay_problem::(false); + let tableau = Tableau::::tr_bdf2(); + let s = Sdirk::::new(tableau, NalgebraLU::default()); + test_state_mut_on_problem(s, p, soln); + } + + #[test] + fn test_tr_bdf2_nalgebra_exponential_decay() { + let tableau = Tableau::::tr_bdf2(); + let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = exponential_decay_problem::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 28 + number_of_steps: 28 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 0 + number_of_nonlinear_solver_fails: 0 + initial_step_size: 0.011224620483093733 + final_step_size: 0.37808462088748845 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 114 + number_of_jac_muls: 2 + number_of_matrix_evals: 1 + "###); + } + + #[test] + fn test_esdirk34_nalgebra_exponential_decay() { + let tableau = Tableau::::esdirk34(); + let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = exponential_decay_problem::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 12 + number_of_steps: 12 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 0 + number_of_nonlinear_solver_fails: 0 + initial_step_size: 0.034484882412482154 + final_step_size: 0.9398383410208245 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 74 + number_of_jac_muls: 2 + number_of_matrix_evals: 1 + "###); + } + + #[cfg(feature = "sundials")] + #[test] + fn test_sundials_exponential_decay() { + let mut s = crate::SundialsIda::default(); + let rs = NewtonNonlinearSolver::new(crate::SundialsLinearSolver::new_dense()); + let (problem, soln) = exponential_decay_problem::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 18 + number_of_steps: 43 + number_of_error_test_failures: 3 + number_of_nonlinear_solver_iterations: 63 + number_of_nonlinear_solver_fails: 0 + initial_step_size: 0.001 + final_step_size: 0.7770043351266953 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 65 + number_of_jac_muls: 36 + number_of_matrix_evals: 18 + "###); + } + + #[test] + fn test_tr_bdf2_nalgebra_robertson() { + let tableau = Tableau::::tr_bdf2(); + let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = robertson::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 427 + number_of_steps: 412 + number_of_error_test_failures: 3 + number_of_nonlinear_solver_iterations: 0 + number_of_nonlinear_solver_fails: 12 + initial_step_size: 0.0000030885218897033307 + final_step_size: 35655827121.9909 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 3063 + number_of_jac_muls: 40 + number_of_matrix_evals: 13 + "###); + } + + #[test] + fn test_esdirk34_nalgebra_robertson() { + let tableau = Tableau::::esdirk34(); + let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = robertson::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 212 + number_of_steps: 193 + number_of_error_test_failures: 1 + number_of_nonlinear_solver_iterations: 0 + number_of_nonlinear_solver_fails: 18 + initial_step_size: 0.00007367379016174295 + final_step_size: 44328923924.83207 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 2478 + number_of_jac_muls: 58 + number_of_matrix_evals: 19 + "###); + } + + #[cfg(feature = "sundials")] + #[test] + fn test_sundials_robertson() { + let mut s = crate::SundialsIda::default(); + let rs = NewtonNonlinearSolver::new(crate::SundialsLinearSolver::new_dense()); + let (problem, soln) = robertson::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 59 + number_of_steps: 355 + number_of_error_test_failures: 15 + number_of_nonlinear_solver_iterations: 506 + number_of_nonlinear_solver_fails: 5 + initial_step_size: 0.001 + final_step_size: 11535117835.253025 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 509 + number_of_jac_muls: 178 + number_of_matrix_evals: 59 + "###); + } + + #[test] + fn test_tr_bdf2_nalgebra_robertson_ode() { + let tableau = Tableau::::tr_bdf2(); + let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = robertson_ode::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, false); + insta::assert_yaml_snapshot!(s.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 248 + number_of_steps: 233 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 0 + number_of_nonlinear_solver_fails: 15 + initial_step_size: 0.0000027515601924872376 + final_step_size: 31858152718.061752 + "###); + insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + --- + number_of_calls: 2398 + number_of_jac_muls: 42 + number_of_matrix_evals: 14 + "###); + } + + #[test] + fn test_tstop_tr_bdf2() { + let tableau = Tableau::::tr_bdf2(); + let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = exponential_decay_problem::(false); + test_ode_solver(&mut s, rs, &problem, soln, None, true); + } + + #[test] + fn test_root_finder_tr_bdf2() { + let tableau = Tableau::::tr_bdf2(); + let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let rs = NewtonNonlinearSolver::new(NalgebraLU::default()); + let (problem, soln) = exponential_decay_problem_with_root::(false); + let y = test_ode_solver(&mut s, rs, &problem, soln, None, false); + assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); } } diff --git a/src/ode_solver/sundials.rs b/src/ode_solver/sundials.rs index 0ad09284..9fa27ded 100644 --- a/src/ode_solver/sundials.rs +++ b/src/ode_solver/sundials.rs @@ -7,12 +7,13 @@ use std::{ }; use sundials_sys::{ realtype, IDACalcIC, IDACreate, IDAFree, IDAGetDky, IDAGetIntegratorStats, - IDAGetNonlinSolvStats, IDAGetReturnFlagName, IDAInit, IDASVtolerances, IDASetId, IDASetJacFn, - IDASetLinearSolver, IDASetStopTime, IDASetUserData, IDASolve, N_Vector, SUNLinSolFree, - SUNLinSolInitialize, SUNLinSol_Dense, SUNLinearSolver, SUNMatrix, IDA_CONSTR_FAIL, - IDA_CONV_FAIL, IDA_ERR_FAIL, IDA_ILL_INPUT, IDA_LINIT_FAIL, IDA_LSETUP_FAIL, IDA_LSOLVE_FAIL, - IDA_MEM_NULL, IDA_ONE_STEP, IDA_REP_RES_ERR, IDA_RES_FAIL, IDA_ROOT_RETURN, IDA_RTFUNC_FAIL, - IDA_SUCCESS, IDA_TOO_MUCH_ACC, IDA_TOO_MUCH_WORK, IDA_TSTOP_RETURN, IDA_YA_YDP_INIT, + IDAGetNonlinSolvStats, IDAGetReturnFlagName, IDAInit, IDAReInit, IDASVtolerances, IDASetId, + IDASetJacFn, IDASetLinearSolver, IDASetStopTime, IDASetUserData, IDASolve, N_Vector, + SUNLinSolFree, SUNLinSolInitialize, SUNLinSol_Dense, SUNLinearSolver, SUNMatrix, + IDA_CONSTR_FAIL, IDA_CONV_FAIL, IDA_ERR_FAIL, IDA_ILL_INPUT, IDA_LINIT_FAIL, IDA_LSETUP_FAIL, + IDA_LSOLVE_FAIL, IDA_MEM_NULL, IDA_ONE_STEP, IDA_REP_RES_ERR, IDA_RES_FAIL, IDA_ROOT_RETURN, + IDA_RTFUNC_FAIL, IDA_SUCCESS, IDA_TOO_MUCH_ACC, IDA_TOO_MUCH_WORK, IDA_TSTOP_RETURN, + IDA_YA_YDP_INIT, }; use crate::{ @@ -134,6 +135,7 @@ where jacobian: SundialsMatrix, statistics: SundialsStatistics, state: Option>, + is_state_modified: bool, } impl SundialsIda @@ -203,6 +205,7 @@ where statistics: SundialsStatistics::new(), jacobian, state: None, + is_state_modified: false, } } @@ -267,10 +270,19 @@ where self.problem.as_ref() } - fn state(&self) -> Option<&OdeSolverState<::V>> { + fn state(&self) -> Option<&OdeSolverState> { self.state.as_ref() } + fn order(&self) -> usize { + 1 + } + + fn state_mut(&mut self) -> Option<&mut OdeSolverState> { + self.is_state_modified = true; + self.state.as_mut() + } + fn take_state(&mut self) -> Option::V>> { Option::take(&mut self.state) } @@ -290,7 +302,7 @@ where .unwrap(); // initialize - self.yp = SundialsVector::zeros(number_of_states); + self.yp = ::zeros(number_of_states); Self::check(unsafe { IDAInit( ida_mem, @@ -335,6 +347,17 @@ where if self.problem.is_none() { return Err(anyhow!("Problem not set")); } + if self.is_state_modified { + // reinit as state has been modified + Self::check(unsafe { + IDAReInit( + self.ida_mem, + state.t, + state.y.sundials_vector(), + self.yp.sundials_vector(), + ) + })? + } let itask = IDA_ONE_STEP; let retval = unsafe { IDASolve( @@ -389,7 +412,7 @@ where #[cfg(test)] mod test { use crate::{ - ode_solver::tests::{test_interpolate, test_no_set_problem, test_take_state}, + ode_solver::tests::{test_interpolate, test_no_set_problem, test_state_mut}, SundialsIda, SundialsMatrix, }; @@ -399,8 +422,8 @@ mod test { test_no_set_problem::(SundialsIda::default()) } #[test] - fn sundials_take_state() { - test_take_state::(SundialsIda::default()) + fn sundials_state_mut() { + test_state_mut::(SundialsIda::default()) } #[test] fn sundials_interpolate() {