From 8bcfc60da2895d4a762b4176b5d69c66aafca961 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Sat, 30 Nov 2024 09:31:49 +0000 Subject: [PATCH] refactor: improvements to data ownership and lifetimes (#109) * solvers now created from a problem, solvers have a lifetime linked to the problem * data ownership now much clearer across the structs: * equation struct owned by problem * tolerances owned by problem * parameter vector owned by equation * equation trait refactored to allow data to be shared across op structs (owned by struct implementing the equation trait) * builder pattern extended to generic parameters, rhs, root, mass etc can be set independently and adjoint and sensitivitity equations supported --- benches/ode_solvers.rs | 74 +- book/src/SUMMARY.md | 3 +- book/src/choosing_a_solver.md | 105 +- book/src/initialisation.md | 34 - book/src/lib.rs | 1 + book/src/primer/bouncing_ball.md | 26 +- .../compartmental_models_of_drug_delivery.md | 24 +- book/src/primer/electrical_circuits.md | 10 +- book/src/primer/images/prey-predator2.html | 2 +- book/src/primer/population_dynamics.md | 26 +- book/src/primer/spring_mass_systems.md | 10 +- book/src/solving_the_problem.md | 166 +-- book/src/sparse_problems.md | 1 - book/src/specify/custom/constant_functions.md | 22 +- .../specify/custom/custom_problem_structs.md | 19 +- book/src/specify/custom/linear_functions.md | 21 +- .../specify/custom/non_linear_functions.md | 75 +- book/src/specify/custom/ode_systems.md | 371 ++++++ .../specify/custom/putting_it_all_together.md | 102 -- book/src/specify/diffsl.md | 10 +- book/src/specify/forward_sensitivity.md | 59 +- book/src/specify/mass_matrix.md | 36 +- book/src/specify/ode_equations.md | 39 +- book/src/specify/root_finding.md | 41 +- book/src/specify/sparse_problems.md | 166 +-- book/src/specify/specifying_the_problem.md | 2 +- src/jacobian/mod.rs | 38 +- src/lib.rs | 59 +- src/linear_solver/faer/lu.rs | 9 +- src/linear_solver/faer/sparse_lu.rs | 4 - src/linear_solver/mod.rs | 47 +- src/linear_solver/nalgebra/lu.rs | 4 - src/linear_solver/suitesparse/klu.rs | 16 +- src/linear_solver/sundials.rs | 24 +- src/matrix/sundials.rs | 2 +- src/nonlinear_solver/convergence.rs | 13 +- src/nonlinear_solver/mod.rs | 54 +- src/nonlinear_solver/newton.rs | 31 +- src/nonlinear_solver/root.rs | 9 +- src/ode_solver/adjoint_equations.rs | 259 ++-- src/ode_solver/bdf.rs | 1117 ++++++++-------- src/ode_solver/bdf_state.rs | 14 +- src/ode_solver/builder.rs | 1126 +++++++++-------- src/ode_solver/checkpointing.rs | 81 +- src/ode_solver/diffsl.rs | 53 +- src/ode_solver/equations.rs | 276 ++-- src/ode_solver/jacobian_update.rs | 1 + src/ode_solver/method.rs | 364 +++--- src/ode_solver/mod.rs | 408 ++---- src/ode_solver/problem.rs | 320 ++++- src/ode_solver/sdirk.rs | 947 +++++++------- src/ode_solver/sens_equations.rs | 109 +- src/ode_solver/state.rs | 137 +- src/ode_solver/sundials.rs | 571 --------- src/ode_solver/test_models/dydt_y2.rs | 8 +- .../test_models/exponential_decay.rs | 152 +-- .../exponential_decay_with_algebraic.rs | 192 +-- src/ode_solver/test_models/foodweb.rs | 27 +- src/ode_solver/test_models/gaussian_decay.rs | 10 +- src/ode_solver/test_models/heat2d.rs | 18 +- src/ode_solver/test_models/robertson.rs | 91 +- src/ode_solver/test_models/robertson_ode.rs | 13 +- .../test_models/robertson_ode_with_sens.rs | 7 +- src/op/bdf.rs | 54 +- src/op/closure.rs | 63 +- src/op/closure_no_jac.rs | 44 +- src/op/closure_with_adjoint.rs | 94 +- src/op/closure_with_sens.rs | 80 +- src/op/constant_closure.rs | 45 +- src/op/constant_closure_with_adjoint.rs | 53 +- src/op/constant_closure_with_sens.rs | 47 +- src/op/init.rs | 17 +- src/op/linear_closure.rs | 52 +- src/op/linear_closure_with_adjoint.rs | 67 +- src/op/mod.rs | 64 +- src/op/sdirk.rs | 36 +- src/op/unit.rs | 33 +- src/vector/sundials.rs | 10 +- 78 files changed, 4361 insertions(+), 4454 deletions(-) delete mode 100644 book/src/initialisation.md delete mode 100644 book/src/sparse_problems.md create mode 100644 book/src/specify/custom/ode_systems.md delete mode 100644 book/src/specify/custom/putting_it_all_together.md delete mode 100644 src/ode_solver/sundials.rs diff --git a/benches/ode_solvers.rs b/benches/ode_solvers.rs index 0d192f63..b8cd9e23 100644 --- a/benches/ode_solvers.rs +++ b/benches/ode_solvers.rs @@ -17,9 +17,11 @@ fn criterion_benchmark(c: &mut Criterion) { ($name:ident, $solver:ident, $linear_solver:ident, $model:ident, $model_problem:ident, $matrix:ty) => { c.bench_function(stringify!($name), |b| { b.iter(|| { - let ls = $linear_solver::default(); let (problem, soln) = $model_problem::<$matrix>(false); - benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls); + benchmarks::$solver::<_, $linear_solver<_>>( + &problem, + soln.solution_points.last().unwrap().t, + ); }) }); }; @@ -126,9 +128,8 @@ fn criterion_benchmark(c: &mut Criterion) { ($name:ident, $solver:ident, $linear_solver:ident, $model:ident, $model_problem:ident, $matrix:ty, $($N:expr),+) => { $(c.bench_function(concat!(stringify!($name), "_", $N), |b| { b.iter(|| { - let ls = $linear_solver::default(); let (problem, soln) = $model_problem::<$matrix>(false, $N); - benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls); + benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t); }) });)+ }; @@ -219,12 +220,14 @@ fn criterion_benchmark(c: &mut Criterion) { ($name:ident, $solver:ident, $linear_solver:ident, $matrix:ty) => { #[cfg(feature = "diffsl-llvm")] c.bench_function(stringify!($name), |b| { - use diffsol::diffsl::LlvmModule; use diffsol::ode_solver::test_models::robertson::*; + use diffsol::LlvmModule; b.iter(|| { let (problem, soln) = robertson_diffsl_problem::<$matrix, LlvmModule>(); - let ls = $linear_solver::default(); - benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) + benchmarks::$solver::<_, $linear_solver<_>>( + &problem, + soln.solution_points.last().unwrap().t, + ) }) }); }; @@ -242,8 +245,7 @@ fn criterion_benchmark(c: &mut Criterion) { $(c.bench_function(concat!(stringify!($name), "_", $N), |b| { b.iter(|| { let (problem, soln) = $model_problem::<$matrix, $N>(); - let ls = $linear_solver::default(); - benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) + benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t) }) });)+ }; @@ -334,8 +336,7 @@ fn criterion_benchmark(c: &mut Criterion) { $(c.bench_function(concat!(stringify!($name), "_", $N), |b| { b.iter(|| { let (problem, soln) = $model_problem::<$matrix, $N>(); - let ls = $linear_solver::default(); - benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) + benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t) }) });)+ }; @@ -424,11 +425,10 @@ fn criterion_benchmark(c: &mut Criterion) { $(#[cfg(feature = "diffsl-llvm")] c.bench_function(concat!(stringify!($name), "_", $N), |b| { use diffsol::ode_solver::test_models::heat2d::*; - use diffsol::diffsl::LlvmModule; + use diffsol::LlvmModule; b.iter(|| { let (problem, soln) = heat2d_diffsl_problem::<$matrix, LlvmModule, $N>(); - let ls = $linear_solver::default(); - benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) + benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t) }) });)+ }; @@ -499,11 +499,10 @@ fn criterion_benchmark(c: &mut Criterion) { $(#[cfg(feature = "diffsl-llvm")] c.bench_function(concat!(stringify!($name), "_", $N), |b| { use diffsol::ode_solver::test_models::foodweb::*; - use diffsol::diffsl::LlvmModule; + use diffsol::LlvmModule; b.iter(|| { let (problem, soln) = foodweb_diffsl_problem::<$matrix, LlvmModule, $N>(); - let ls = $linear_solver::default(); - benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) + benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t) }) });)+ @@ -542,56 +541,47 @@ mod benchmarks { use diffsol::vector::VectorRef; use diffsol::LinearSolver; use diffsol::{ - Bdf, DefaultDenseMatrix, DefaultSolver, Matrix, NewtonNonlinearSolver, - OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Sdirk, Tableau, + DefaultDenseMatrix, DefaultSolver, Matrix, OdeEquationsImplicit, OdeSolverMethod, + OdeSolverProblem, }; // bdf - pub fn bdf(problem: &OdeSolverProblem, t: Eqn::T, ls: impl LinearSolver) + pub fn bdf(problem: &OdeSolverProblem, t: Eqn::T) where Eqn: OdeEquationsImplicit, Eqn::M: Matrix + DefaultSolver, Eqn::V: DefaultDenseMatrix, + LS: LinearSolver, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { - let nls = NewtonNonlinearSolver::new(ls); - let mut s = Bdf::<::M, _, _>::new(nls); - let state = OdeSolverState::new(problem, &s).unwrap(); - let _y = s.solve(problem, state, t); + let mut s = problem.bdf::().unwrap(); + let _y = s.solve(t); } - pub fn esdirk34( - problem: &OdeSolverProblem, - t: Eqn::T, - linear_solver: impl LinearSolver, - ) where + pub fn esdirk34(problem: &OdeSolverProblem, t: Eqn::T) + where Eqn: OdeEquationsImplicit, Eqn::M: Matrix + DefaultSolver, Eqn::V: DefaultDenseMatrix, + LS: LinearSolver, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { - let tableau = Tableau::<::M>::esdirk34(); - let mut s = Sdirk::new(tableau, linear_solver); - let state = OdeSolverState::new(problem, &s).unwrap(); - let _y = s.solve(problem, state, t); + let mut s = problem.esdirk34::().unwrap(); + let _y = s.solve(t); } - pub fn tr_bdf2( - problem: &OdeSolverProblem, - t: Eqn::T, - linear_solver: impl LinearSolver, - ) where + pub fn tr_bdf2(problem: &OdeSolverProblem, t: Eqn::T) + where Eqn: OdeEquationsImplicit, Eqn::M: Matrix + DefaultSolver, Eqn::V: DefaultDenseMatrix, + LS: LinearSolver, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { - let tableau = Tableau::<::M>::tr_bdf2(); - let mut s = Sdirk::new(tableau, linear_solver); - let state = OdeSolverState::new(problem, &s).unwrap(); - let _y = s.solve(problem, state, t); + let mut s = problem.tr_bdf2::().unwrap(); + let _y = s.solve(t); } } diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index e5f6cf53..1e2aa173 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -18,10 +18,9 @@ - [Non-linear functions](./specify/custom/non_linear_functions.md) - [Constant functions](./specify/custom/constant_functions.md) - [Linear functions](./specify/custom/linear_functions.md) - - [Putting it all together](./specify/custom/putting_it_all_together.md) + - [ODE systems](./specify/custom/ode_systems.md) - [DiffSL](./specify/diffsl.md) - [Sparse problems](./specify/sparse_problems.md) - [Choosing a solver](./choosing_a_solver.md) -- [Initialisation](./initialisation.md) - [Solving the problem](./solving_the_problem.md) - [Benchmarks](./benchmarks.md) diff --git a/book/src/choosing_a_solver.md b/book/src/choosing_a_solver.md index c3f598a0..e040c833 100644 --- a/book/src/choosing_a_solver.md +++ b/book/src/choosing_a_solver.md @@ -3,75 +3,88 @@ Once you have defined the problem, you need to create a solver to solve the problem. The available solvers are: - [`diffsol::Bdf`](https://docs.rs/diffsol/latest/diffsol/ode_solver/bdf/struct.Bdf.html): A Backwards Difference Formulae solver, suitable for stiff problems and singular mass matrices. - [`diffsol::Sdirk`](https://docs.rs/diffsol/latest/diffsol/ode_solver/sdirk/struct.Sdirk.html) A Singly Diagonally Implicit Runge-Kutta (SDIRK or ESDIRK) solver. You can define your own butcher tableau using [`Tableau`](https://docs.rs/diffsol/latest/diffsol/ode_solver/tableau/struct.Tableau.html) or use one of the pre-defined tableaues. - -Each of these solvers has a number of generic arguments, for example the `Bdf` solver has three generic arguments: -- `M`: The matrix type used to define the problem. -- `Eqn`: The type of the equations struct that defines the problem. -- `Nls`: The type of the non-linear solver used to solve the implicit equations in the solver. +For each solver, you will need to specify the linear solver type to use. The available linear solvers are: +- [`diffsol::NalgebraLU`](https://docs.rs/diffsol/latest/diffsol/linear_solver/nalgebra_lu/struct.NalgebraLU.html): A LU decomposition solver using the [nalgebra](https://nalgebra.org) crate. +- [`diffsol::FaerLU`](https://docs.rs/diffsol/latest/diffsol/linear_solver/faer_lu/struct.FaerLU.html): A LU decomposition solver using the [faer](https://github.com/sarah-ek/faer-rs) crate. +- [`diffsol::FaerSparseLU`](https://docs.rs/diffsol/latest/diffsol/linear_solver/faer_sparse_lu/struct.FaerSparseLU.html): A sparse LU decomposition solver using the `faer` crate. -In normal use cases, Rust can infer these from your code so you don't need to specify these explicitly. The `Bdf` solver implements the `Default` trait so can be easily created using: +Each solver can be created directly, but it generally easier to use the methods on the [`OdeSolverProblem`](https://docs.rs/diffsol/latest/diffsol/ode_solver/problem/struct.OdeSolverProblem.html) struct to create the solver. +For example: ```rust # use diffsol::OdeBuilder; # use nalgebra::DVector; +use diffsol::{OdeSolverState, NalgebraLU, BdfState, Tableau, SdirkState}; # type M = nalgebra::DMatrix; -use diffsol::{Bdf, OdeSolverState, OdeSolverMethod}; +type LS = NalgebraLU; # fn main() { # -# let problem = OdeBuilder::new() +# let problem = OdeBuilder::::new() # .p(vec![1.0, 10.0]) -# .build_ode::( +# .rhs_implicit( # |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); -# } -``` +# ) +# .init(|_p, _t| DVector::from_element(1, 0.1)) +# .build() +# .unwrap(); +// Create a BDF solver with an initial state +let solver = problem.bdf::(); -The `Sdirk` solver requires a tableu to be specified so you can use its `new` method to create a new solver, for example using the `tr_bdf2` tableau: +// Create a non-initialised state and manually set the values before +// creating the solver +let state = BdfState::new_without_initialise(&problem).unwrap(); +// ... set the state values manually +let solver = problem.bdf_solver::(state); -```rust -# use diffsol::{OdeBuilder}; -# use nalgebra::DVector; -# type M = nalgebra::DMatrix; -use diffsol::{Sdirk, Tableau, NalgebraLU, OdeSolverState, OdeSolverMethod}; -# fn main() { -# let problem = OdeBuilder::new() -# .p(vec![1.0, 10.0]) -# .build_ode::( -# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), -# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Sdirk::new(Tableau::::tr_bdf2(), NalgebraLU::default()); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); +// Create a SDIRK solver with a pre-defined tableau +let tableau = Tableau::::tr_bdf2(); +let state = problem.sdirk_state::(&tableau).unwrap(); +let solver = problem.sdirk_solver::(state, tableau); + +// Create a tr_bdf2 or esdirk34 solvers directly (both are SDIRK solvers with different tableaus) +let solver = problem.tr_bdf2::(); +let solver = problem.esdirk34::(); + +// Create a non-initialised state and manually set the values before +// creating the solver +let state = SdirkState::new_without_initialise(&problem).unwrap(); +// ... set the state values manually +let solver = problem.tr_bdf2_solver::(state); # } ``` -You can also use one of the helper functions to create a SDIRK solver with a pre-defined tableau, which will create it with the default linear solver: +# Initialisation + +Each solver has an internal state that holds information like the current state vector, the gradient of the state vector, the current time, and the current step size. When you create a solver using the `bdf` or `sdirk` methods on the `OdeSolverProblem` struct, the solver will be initialised with an initial state based on the initial conditions of the problem as well as satisfying any algebraic constraints. An initial time step will also be chosen based on your provided equations. + +Each solver's state struct implements the [`OdeSolverState`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/trait.OdeSolverState.html) trait, and if you wish to manually create and setup a state, you can use the methods on this trait to do so. + +For example, say that you wish to bypass the initialisation of the state as you already have the algebraic constraints and so don't need to solve for them. You can use the `new_without_initialise` method on the `OdeSolverState` trait to create a new state without initialising it. You can then use the `as_mut` method to get a mutable reference to the state and set the values manually. + +Note that each state struct has a [`as_ref`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/trait.OdeSolverState.html#tymethod.as_ref) and [`as_mut`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/trait.OdeSolverState.html#tymethod.as_mut) methods that return a [`StateRef`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/struct.StateRef.html) or ['StateRefMut`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/struct.StateRefMut.html) struct respectively. These structs provide a solver-independent way to access the state values so you can use the same code with different solvers. ```rust -# use diffsol::{OdeBuilder}; +# use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{Sdirk, Tableau, NalgebraLU, OdeSolverState, OdeSolverMethod}; +use diffsol::{OdeSolverState, NalgebraLU, BdfState}; +type LS = NalgebraLU; + # fn main() { -# let problem = OdeBuilder::new() +# +# let problem = OdeBuilder::::new() # .p(vec![1.0, 10.0]) -# .build_ode::( +# .rhs_implicit( # |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Sdirk::tr_bdf2(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); +# ) +# .init(|_p, _t| DVector::from_element(1, 0.1)) +# .build() +# .unwrap(); +let mut state = BdfState::new_without_initialise(&problem).unwrap(); +state.as_mut().y[0] = 0.1; +let mut solver = problem.bdf_solver::(state); # } -``` - - +``` \ No newline at end of file diff --git a/book/src/initialisation.md b/book/src/initialisation.md deleted file mode 100644 index 1d37f1d4..00000000 --- a/book/src/initialisation.md +++ /dev/null @@ -1,34 +0,0 @@ -# Initialisation - -Before you can solve the problem, you need to generate an intitial state for the solution. DiffSol uses the [`OdeSolverState`](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/struct.OdeSolverState.html) -struct to hold the current state of the solution, this is a struct that contains the state vector, the gradient of the state vector, the time, and the current step size. - -You can create a new state for an ODE problem using the [`OdeSolverState::new`](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/struct.OdeSolverState.html#method.new) method, -which takes as arguments the problem and solver instances. -This method uses the \\(y_0(p, t)\\) closure to generate an intial state vector, and the \\(f(y, p, t)\\) closure to generate the gradient of the state vector. It will also set the time to the initial time -given by the `OdeSolverProblem` struct, and will guess a suitable step size based on the initial state vector and the gradient of the state vector. If you want to set the step size manually or have -more control over the initialisation of the state, you can use the [`OdeSolverState::new_without_initialise`](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/struct.OdeSolverState.html#method.new_without_initialise) method. - -Once the state is created then you can use the state and the problem to initialise the solver in preparation for solving the problem. - -```rust -# use diffsol::OdeBuilder; -# use nalgebra::DVector; -# type M = nalgebra::DMatrix; -use diffsol::{OdeSolverState, OdeSolverMethod, Bdf}; - -# fn main() { -# -# let problem = OdeBuilder::new() -# .p(vec![1.0, 10.0]) -# .build_ode::( -# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), -# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); -# } -``` - diff --git a/book/src/lib.rs b/book/src/lib.rs index e69de29b..8b137891 100644 --- a/book/src/lib.rs +++ b/book/src/lib.rs @@ -0,0 +1 @@ + diff --git a/book/src/primer/bouncing_ball.md b/book/src/primer/bouncing_ball.md index 1c1b242a..7fbdf2ec 100644 --- a/book/src/primer/bouncing_ball.md +++ b/book/src/primer/bouncing_ball.md @@ -61,14 +61,14 @@ In code, the bouncing ball problem can be solved using DiffSol as follows: # fn main() { # use std::fs; use diffsol::{ - DiffSl, CraneliftModule, OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod, - OdeSolverStopReason, + DiffSl, CraneliftModule, OdeBuilder, OdeSolverMethod, OdeSolverStopReason, }; use plotly::{ Plot, Scatter, common::Mode, layout::Layout, layout::Axis }; type M = nalgebra::DMatrix; type CG = CraneliftModule; +type LS = diffsol::NalgebraLU; let eqn = DiffSl::::compile(" g { 9.81 } h { 10.0 } @@ -86,10 +86,8 @@ let eqn = DiffSl::::compile(" ").unwrap(); let e = 0.8; -let problem = OdeBuilder::new().build_from_eqn(eqn).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem).unwrap(); +let problem = OdeBuilder::::new().build_from_eqn(eqn).unwrap(); +let mut solver = problem.bdf::().unwrap(); let mut x = Vec::new(); let mut v = Vec::new(); @@ -97,8 +95,8 @@ let mut t = Vec::new(); let final_time = 10.0; // save the initial state -x.push(solver.state().unwrap().y[0]); -v.push(solver.state().unwrap().y[1]); +x.push(solver.state().y[0]); +v.push(solver.state().y[1]); t.push(0.0); // solve until the final time is reached @@ -117,16 +115,16 @@ loop { y[0] = y[0].max(f64::EPSILON); // set the state to the updated state - solver.state_mut().unwrap().y.copy_from(&y); - solver.state_mut().unwrap().dy[0] = y[1]; - *solver.state_mut().unwrap().t = t; + solver.state_mut().y.copy_from(&y); + solver.state_mut().dy[0] = y[1]; + *solver.state_mut().t = t; }, Ok(OdeSolverStopReason::TstopReached) => break, Err(_) => panic!("unexpected solver error"), } - x.push(solver.state().unwrap().y[0]); - v.push(solver.state().unwrap().y[1]); - t.push(solver.state().unwrap().t); + x.push(solver.state().y[0]); + v.push(solver.state().y[1]); + t.push(solver.state().t); } let mut plot = Plot::new(); let x = Scatter::new(t.clone(), x).mode(Mode::Lines).name("x"); diff --git a/book/src/primer/compartmental_models_of_drug_delivery.md b/book/src/primer/compartmental_models_of_drug_delivery.md index 09f992d2..0507a4c7 100644 --- a/book/src/primer/compartmental_models_of_drug_delivery.md +++ b/book/src/primer/compartmental_models_of_drug_delivery.md @@ -80,14 +80,14 @@ Let's now solve this system of ODEs using DiffSol. # fn main() { # use std::fs; use diffsol::{ - DiffSl, CraneliftModule, OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod, - OdeSolverStopReason, + DiffSl, CraneliftModule, OdeBuilder, OdeSolverMethod, OdeSolverStopReason, }; use plotly::{ Plot, Scatter, common::Mode, layout::Layout, layout::Axis }; type M = nalgebra::DMatrix; type CG = CraneliftModule; +type LS = diffsol::NalgebraLU; let eqn = DiffSl::::compile(" Vc { 1000.0 } Vp1 { 1000.0 } CL { 100.0 } Qp1 { 50.0 } @@ -101,20 +101,18 @@ let eqn = DiffSl::::compile(" } ").unwrap(); -let problem = OdeBuilder::new().build_from_eqn(eqn).unwrap(); -let mut solver = Bdf::default(); +let problem = OdeBuilder::::new().build_from_eqn(eqn).unwrap(); +let mut solver = problem.bdf::().unwrap(); let doses = vec![(0.0, 1000.0), (6.0, 1000.0), (12.0, 1000.0), (18.0, 1000.0)]; -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem).unwrap(); let mut q_c = Vec::new(); let mut q_p1 = Vec::new(); let mut time = Vec::new(); // apply the first dose and save the initial state -solver.state_mut().unwrap().y[0] = doses[0].1; -q_c.push(solver.state().unwrap().y[0]); -q_p1.push(solver.state().unwrap().y[1]); +solver.state_mut().y[0] = doses[0].1; +q_c.push(solver.state().y[0]); +q_p1.push(solver.state().y[1]); time.push(0.0); // solve and apply the remaining doses @@ -122,16 +120,16 @@ for (t, dose) in doses.into_iter().skip(1) { solver.set_stop_time(t).unwrap(); loop { let ret = solver.step(); - q_c.push(solver.state().unwrap().y[0]); - q_p1.push(solver.state().unwrap().y[1]); - time.push(solver.state().unwrap().t); + q_c.push(solver.state().y[0]); + q_p1.push(solver.state().y[1]); + time.push(solver.state().t); match ret { Ok(OdeSolverStopReason::InternalTimestep) => continue, Ok(OdeSolverStopReason::TstopReached) => break, _ => panic!("unexpected solver error"), } } - solver.state_mut().unwrap().y[0] += dose; + solver.state_mut().y[0] += dose; } let mut plot = Plot::new(); let q_c = Scatter::new(time.clone(), q_c).mode(Mode::Lines).name("q_c"); diff --git a/book/src/primer/electrical_circuits.md b/book/src/primer/electrical_circuits.md index ec3162d6..7cf4e8d0 100644 --- a/book/src/primer/electrical_circuits.md +++ b/book/src/primer/electrical_circuits.md @@ -85,13 +85,14 @@ We can solve this system of equations using DiffSol and plot the current and vol # fn main() { # use std::fs; use diffsol::{ - DiffSl, CraneliftModule, OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod + DiffSl, CraneliftModule, OdeBuilder, OdeSolverMethod }; use plotly::{ Plot, Scatter, common::Mode, layout::Layout, layout::Axis }; type M = nalgebra::DMatrix; type CG = CraneliftModule; +type LS = diffsol::NalgebraLU; let eqn = DiffSl::::compile(" R { 100.0 } L { 1.0 } C { 0.001 } V0 { 10 } omega { 100.0 } @@ -124,11 +125,10 @@ let eqn = DiffSl::::compile(" iR, } ").unwrap(); -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .build_from_eqn(eqn).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -let (ys, ts) = solver.solve(&problem, state, 1.0).unwrap(); +let mut solver = problem.bdf::().unwrap(); +let (ys, ts) = solver.solve(1.0).unwrap(); let ir: Vec<_> = ys.row(0).into_iter().copied().collect(); let t: Vec<_> = ts.into_iter().collect(); diff --git a/book/src/primer/images/prey-predator2.html b/book/src/primer/images/prey-predator2.html index 54385fa7..d46d1f16 100644 --- a/book/src/primer/images/prey-predator2.html +++ b/book/src/primer/images/prey-predator2.html @@ -1,4 +1,4 @@
\ No newline at end of file diff --git a/book/src/primer/population_dynamics.md b/book/src/primer/population_dynamics.md index 3637e02a..35c18e8f 100644 --- a/book/src/primer/population_dynamics.md +++ b/book/src/primer/population_dynamics.md @@ -65,12 +65,13 @@ Let's solve this system of ODEs using the DiffSol crate. We will use the [DiffSL # fn main() { # use std::fs; use diffsol::{ - DiffSl, CraneliftModule, OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod + DiffSl, CraneliftModule, OdeBuilder, OdeSolverMethod }; use plotly::{ Plot, Scatter, common::Mode, layout::Layout, layout::Axis }; type M = nalgebra::DMatrix; +type LS = diffsol::NalgebraLU; type CG = CraneliftModule; let eqn = DiffSl::::compile(" @@ -84,11 +85,10 @@ let eqn = DiffSl::::compile(" c * y1 * y2 - d * y2, } ").unwrap(); -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .build_from_eqn(eqn).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -let (ys, ts) = solver.solve(&problem, state, 40.0).unwrap(); +let mut solver = problem.bdf::().unwrap(); +let (ys, ts) = solver.solve(40.0).unwrap(); let prey: Vec<_> = ys.row(0).into_iter().copied().collect(); let predator: Vec<_> = ys.row(1).into_iter().copied().collect(); @@ -125,12 +125,13 @@ so we can solve this system for different values of \\(y_0\\) and plot the phase # fn main() { # use std::fs; use diffsol::{ - DiffSl, CraneliftModule, OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod + DiffSl, CraneliftModule, OdeBuilder, OdeSolverMethod, OdeEquations }; use plotly::{ Plot, Scatter, common::Mode, layout::Layout, layout::Axis }; type M = nalgebra::DMatrix; +type LS = diffsol::NalgebraLU; type CG = CraneliftModule; let eqn = DiffSl::::compile(" @@ -147,15 +148,15 @@ let eqn = DiffSl::::compile(" } ").unwrap(); -let mut problem = OdeBuilder::new().p([1.0]).build_from_eqn(eqn).unwrap(); -let mut solver = Bdf::default(); +let mut problem = OdeBuilder::::new().p([1.0]).build_from_eqn(eqn).unwrap(); let mut plot = Plot::new(); for y0 in (1..6).map(f64::from) { - problem.set_params(nalgebra::DVector::from_element(1, y0)).unwrap(); + let p = nalgebra::DVector::from_element(1, y0); + problem.eqn_mut().set_params(&p); - let state = OdeSolverState::new(&problem, &solver).unwrap(); - let (ys, _ts) = solver.solve(&problem, state, 40.0).unwrap(); + let mut solver = problem.bdf::().unwrap(); + let (ys, _ts) = solver.solve(40.0).unwrap(); let prey: Vec<_> = ys.row(0).into_iter().copied().collect(); let predator: Vec<_> = ys.row(1).into_iter().copied().collect(); @@ -163,9 +164,6 @@ for y0 in (1..6).map(f64::from) { let phase = Scatter::new(prey, predator) .mode(Mode::Lines).name(format!("y0 = {}", y0)); plot.add_trace(phase); - - // release problem and state to set new parameters in the next iteration - solver.take_state().unwrap(); } let layout = Layout::new() diff --git a/book/src/primer/spring_mass_systems.md b/book/src/primer/spring_mass_systems.md index 1eee2fe3..3321acc8 100644 --- a/book/src/primer/spring_mass_systems.md +++ b/book/src/primer/spring_mass_systems.md @@ -29,13 +29,14 @@ We can solve this system of ODEs using DiffSol with the following code: # fn main() { # use std::fs; use diffsol::{ - DiffSl, CraneliftModule, OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod + DiffSl, CraneliftModule, OdeBuilder, OdeSolverMethod }; use plotly::{ Plot, Scatter, common::Mode, layout::Layout, layout::Axis }; type M = nalgebra::DMatrix; type CG = CraneliftModule; +type LS = diffsol::NalgebraLU; let eqn = DiffSl::::compile(" k { 1.0 } m { 1.0 } c { 0.1 } @@ -48,11 +49,10 @@ let eqn = DiffSl::::compile(" -k/m * x - c/m * v, } ").unwrap(); -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .build_from_eqn(eqn).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -let (ys, ts) = solver.solve(&problem, state, 40.0).unwrap(); +let mut solver = problem.bdf::().unwrap(); +let (ys, ts) = solver.solve(40.0).unwrap(); let x: Vec<_> = ys.row(0).into_iter().copied().collect(); let time: Vec<_> = ts.into_iter().collect(); diff --git a/book/src/solving_the_problem.md b/book/src/solving_the_problem.md index 603314ce..3e18bb4a 100644 --- a/book/src/solving_the_problem.md +++ b/book/src/solving_the_problem.md @@ -1,28 +1,88 @@ # Solving the Problem Each solver implements the [`OdeSolverMethod`](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/trait.OdeSolverMethod.html) trait, which provides a number of methods to solve the problem. -The fundamental method to solve the problem is the `step` method on the `OdeSolverMethod` trait, which steps the solution forward in time by a single step, with a step size chosen by the solver -in order to satisfy the error tolerances in the `problem` struct. The `step` method returns a `Result` that contains the new state of the solution if the step was successful, or an error if the step failed. + +## Solving the Problem + +DiffSol has a few high-level solution functions on the `OdeSolverMethod` trait that are the easiest way to solve your equations: +- [`solve`](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/trait.OdeSolverMethod.html#method.solve) - solve the problem from an initial state up to a specified time, returning the solution at all the internal timesteps used by the solver. +- [`solve_dense`](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/trait.OdeSolverMethod.html#method.solve_dense) - solve the problem from an initial state, returning the solution at a `Vec` of times provided by the user. +- ['solve_dense_sensitivities`](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/trait.OdeSolverMethod.html#method.solve_dense_sensitivities) - solve the forward sensitivity problem from an initial state, returning the solution at a `Vec` of times provided by the user. +- ['solve_adjoint'](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/trait.OdeSolverMethod.html#method.solve_adjoint) - solve the adjoint sensitivity problem from an initial state to a final time, returning the integration of the output function over time as well as its gradient with respect to the initial state. + +The following example shows how to solve a simple ODE problem using the `solve` method on the `OdeSolverMethod` trait. + +```rust +# use diffsol::OdeBuilder; +# use nalgebra::DVector; +use diffsol::{OdeSolverMethod, NalgebraLU}; +type M = nalgebra::DMatrix; +type LS = NalgebraLU; + +# fn main() { +# let problem = OdeBuilder::::new() +# .p(vec![1.0, 10.0]) +# .rhs_implicit( +# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), +# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), +# ) +# .init(|_p, _t| DVector::from_element(1, 0.1)) +# .build() +# .unwrap(); +let mut solver = problem.bdf::().unwrap(); +let (ys, ts) = solver.solve(10.0).unwrap(); +# } +``` + +`solve_dense` will solve a problem from an initial state, returning the solution at a `Vec` of times provided by the user. ```rust # use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, OdeSolverState, Bdf}; +use diffsol::{OdeSolverMethod, NalgebraLU}; +type LS = NalgebraLU; + +# fn main() { +# let problem = OdeBuilder::::new() +# .p(vec![1.0, 10.0]) +# .rhs_implicit( +# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), +# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), +# ) +# .init(|_p, _t| DVector::from_element(1, 0.1)) +# .build() +# .unwrap(); +let mut solver = problem.bdf::().unwrap(); +let times = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; +let _soln = solver.solve_dense(×).unwrap(); +# } +``` + +## Stepping the Solution + +The fundamental method to step the solver through a solution is the [`step`](https://docs.rs/diffsol/latest/diffsol/ode_solver/method/trait.OdeSolverMethod.html#tymethod.step) method on the `OdeSolverMethod` trait, which steps the solution forward in time by a single step, with a step size chosen by the solver in order to satisfy the error tolerances in the `problem` struct. The `step` method returns a `Result` that contains the new state of the solution if the step was successful, or an error if the step failed. + +```rust +# use diffsol::OdeBuilder; +# use nalgebra::DVector; +# type M = nalgebra::DMatrix; +use diffsol::{OdeSolverMethod, NalgebraLU}; +type LS = NalgebraLU; # fn main() { # -# let problem = OdeBuilder::new() +# let problem = OdeBuilder::::new() # .p(vec![1.0, 10.0]) -# .build_ode::( +# .rhs_implicit( # |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); -while solver.state().unwrap().t < 10.0 { +# ) +# .init(|_p, _t| DVector::from_element(1, 0.1)) +# .build() +# .unwrap(); +let mut solver = problem.bdf::().unwrap(); +while solver.state().t < 10.0 { if let Err(_) = solver.step() { break; } @@ -39,22 +99,23 @@ until you are beyond \\(t_o\\), and then interpolate the solution back to \\(t_o # use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, OdeSolverState, Bdf}; +use diffsol::{OdeSolverMethod, NalgebraLU}; +type LS = NalgebraLU; # fn main() { # -# let problem = OdeBuilder::new() +# let problem = OdeBuilder::::new() # .p(vec![1.0, 10.0]) -# .build_ode::( +# .rhs_implicit( # |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); +# ) +# .init(|_p, _t| DVector::from_element(1, 0.1)) +# .build() +# .unwrap(); +let mut solver = problem.bdf::().unwrap(); let t_o = 10.0; -while solver.state().unwrap().t < t_o { +while solver.state().t < t_o { solver.step().unwrap(); } let _soln = solver.interpolate(t_o).unwrap(); @@ -70,20 +131,21 @@ Once the solver has stopped at the specified time, you can get the current state # use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, OdeSolverStopReason, OdeSolverState, Bdf}; +use diffsol::{OdeSolverMethod, OdeSolverStopReason, NalgebraLU}; +type LS = NalgebraLU; # fn main() { # -# let problem = OdeBuilder::new() +# let problem = OdeBuilder::::new() # .p(vec![1.0, 10.0]) -# .build_ode::( +# .rhs_implicit( # |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); +# ) +# .init(|_p, _t| DVector::from_element(1, 0.1)) +# .build() +# .unwrap(); +let mut solver = problem.bdf::().unwrap(); solver.set_stop_time(10.0).unwrap(); loop { match solver.step() { @@ -93,53 +155,7 @@ loop { Err(e) => panic!("Solver failed to converge: {}", e), } } -let _soln = &solver.state().unwrap().y; +let _soln = &solver.state().y; # } ``` -DiffSol also has two convenience functions `solve` and `solve_dense` on the `OdeSolverMethod` trait. `solve` solve the problem from an initial state up to a specified time, returning the solution at all the -internal timesteps used by the solver. This function returns a tuple that contains a `Vec` of -the solution at each timestep, and a `Vec` of the times at each timestep. - -```rust -# use diffsol::OdeBuilder; -# use nalgebra::DVector; -# type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, Bdf, OdeSolverState}; - -# fn main() { -# let problem = OdeBuilder::new() -# .p(vec![1.0, 10.0]) -# .build_ode::( -# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), -# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -let (ys, ts) = solver.solve(&problem, state, 10.0).unwrap(); -# } -``` - -`solve_dense` will solve a problem from an initial state, returning the solution at a `Vec` of times provided by the user. This function returns a `Vec`, where `V` is the vector type used to define the problem. - -```rust -# use diffsol::OdeBuilder; -# use nalgebra::DVector; -# type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, Bdf, OdeSolverState}; - -# fn main() { -# let problem = OdeBuilder::new() -# .p(vec![1.0, 10.0]) -# .build_ode::( -# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), -# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# ).unwrap(); -let mut solver = Bdf::default(); -let times = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; -let state = OdeSolverState::new(&problem, &solver).unwrap(); -let _soln = solver.solve_dense(&problem, state, ×).unwrap(); -# } -``` \ No newline at end of file diff --git a/book/src/sparse_problems.md b/book/src/sparse_problems.md deleted file mode 100644 index 3aff5e4b..00000000 --- a/book/src/sparse_problems.md +++ /dev/null @@ -1 +0,0 @@ -# Sparse problems diff --git a/book/src/specify/custom/constant_functions.md b/book/src/specify/custom/constant_functions.md index c72e82df..054e26e2 100644 --- a/book/src/specify/custom/constant_functions.md +++ b/book/src/specify/custom/constant_functions.md @@ -22,6 +22,9 @@ impl Op for MyInit { fn nout(&self) -> usize { 1 } + fn nparams(&self) -> usize { + 0 + } } impl ConstantOp for MyInit { @@ -30,21 +33,4 @@ impl ConstantOp for MyInit { } } # } -``` - -Again, we can use the [`ConstantClosure`](https://docs.rs/diffsol/latest/diffsol/op/constant_closure/struct.ConstantClosure.html) struct to implement the `ConstantOp` trait for us if it's not neccessary to use our own struct. - -```rust -# fn main() { -# use std::rc::Rc; -use diffsol::ConstantClosure; - -# type T = f64; -# type V = nalgebra::DVector; -# type M = nalgebra::DMatrix; -# -let p = Rc::new(V::from_vec(vec![1.0, 10.0])); -let init_fn = |_p: &V, _t: T| V::from_element(1, 0.1); -let init = Rc::new(ConstantClosure::::new(init_fn, p.clone())); -# } -``` +``` \ No newline at end of file diff --git a/book/src/specify/custom/custom_problem_structs.md b/book/src/specify/custom/custom_problem_structs.md index f5b574d0..f2e9b440 100644 --- a/book/src/specify/custom/custom_problem_structs.md +++ b/book/src/specify/custom/custom_problem_structs.md @@ -1,19 +1,20 @@ # Custom Problem Structs While the [`OdeBuilder`](https://docs.rs/diffsol/latest/diffsol/ode_solver/builder/struct.OdeBuilder.html) struct is a convenient way to specify the problem, it may not be suitable in all cases. -Often users will want to provide their own struct that can hold custom data structures and methods for evaluating the right-hand side of the ODE, the jacobian, and other functions. +Often users will want to provide their own structs that can hold custom data structures and methods for evaluating the right-hand side of the ODE, the jacobian, and other functions. ## Traits -To use a custom struct to specify a problem, the primary goal is to implement the [`OdeEquations`](https://docs.rs/diffsol/latest/diffsol/ode_solver/equations/trait.OdeEquations.html) trait. -This trait has a number of associated traits that need to be implemented in order to specify each function, depending on if they are: -- Non-linear functions. In this case the [`NonLinearOp`](https://docs.rs/diffsol/latest/diffsol/op/trait.NonLinearOp.html) trait needs to be implemented. -- Linear functions. In this case the [`LinearOp`](https://docs.rs/diffsol/latest/diffsol/op/trait.LinearOp.html) trait needs to be implemented. -- Constant functions. In this case the [`ConstantOp`](https://docs.rs/diffsol/latest/diffsol/op/trait.ConstantOp.html) trait needs to be implemented. +To create your own structs for the ode system, the final goal is to implement the [`OdeEquations`](https://docs.rs/diffsol/latest/diffsol/ode_solver/equations/trait.OdeEquations.html) trait. +When you have done this, you can use the `build_from_eqn` method on the `OdeBuilder` struct to create the problem. + +For each function in your system of equations, you will need to implement the appropriate trait for each function. +- Non-linear functions (rhs, out, root). In this case the [`NonLinearOp`](https://docs.rs/diffsol/latest/diffsol/op/trait.NonLinearOp.html) trait needs to be implemented. +- Linear functions (mass). In this case the [`LinearOp`](https://docs.rs/diffsol/latest/diffsol/op/trait.LinearOp.html) trait needs to be implemented. +- Constant functions (init). In this case the [`ConstantOp`](https://docs.rs/diffsol/latest/diffsol/op/trait.ConstantOp.html) trait needs to be implemented. Additionally, each function needs to implement the base operation trait [`Op`](https://docs.rs/diffsol/latest/diffsol/op/trait.Op.html). -## OdeSolverEquations struct +Once you have implemented the appropriate traits for your custom struct, you can use the [`OdeBuilder`](https://docs.rs/diffsol/latest/diffsol/ode_solver/builder/struct.OdeBuilder.html) struct to specify the problem. + -The [`OdeSolverEquations`](https://docs.rs/diffsol/latest/diffsol/ode_solver/equations/struct.OdeSolverEquations.html) struct is a convenience struct that already implements the `OdeEquations` trait, and can be used as a base struct for custom problem structs. -It is not neccessary to use this struct, but it can be useful to reduce boilerplate code. The example below will use this struct, but if it does not fit your use case, you can implement the `OdeEquations` trait directly. diff --git a/book/src/specify/custom/linear_functions.md b/book/src/specify/custom/linear_functions.md index bd560fc3..d558b993 100644 --- a/book/src/specify/custom/linear_functions.md +++ b/book/src/specify/custom/linear_functions.md @@ -31,6 +31,9 @@ impl Op for MyMass { fn nout(&self) -> usize { 1 } + fn nparams(&self) -> usize { + 0 + } } impl LinearOp for MyMass { @@ -41,22 +44,4 @@ impl LinearOp for MyMass { # } ``` -Alternatively, we can use the [`LinearClosure`](https://docs.rs/diffsol/latest/diffsol/op/linear_closure/struct.LinearClosure.html) struct to implement the `LinearOp` trait for us. - -```rust -# fn main() { -# use std::rc::Rc; -use diffsol::LinearClosure; - -# type T = f64; -# type V = nalgebra::DVector; -# type M = nalgebra::DMatrix; -# -# let p = Rc::new(V::from_vec(vec![1.0, 10.0])); -let mass_fn = |v: &V, _p: &V, _t: T, beta: T, y: &mut V| { - y[0] = v[0] + beta * y[0]; -}; -let mass = Rc::new(LinearClosure::::new(mass_fn, 1, 1, p.clone())); -# } -``` diff --git a/book/src/specify/custom/non_linear_functions.md b/book/src/specify/custom/non_linear_functions.md index d9086eae..eb000b47 100644 --- a/book/src/specify/custom/non_linear_functions.md +++ b/book/src/specify/custom/non_linear_functions.md @@ -5,40 +5,32 @@ To illustrate how to implement a custom problem struct, we will take the familar \\[\frac{dy}{dt} = r y (1 - y/K),\\] Our goal is to implement a custom struct that can evaluate the rhs function \\(f(y, p, t)\\) and the jacobian multiplied by a vector \\(f'(y, p, t, v)\\). -First we define a struct that, for this simple example, only holds the parameters of interest. For a more complex problem, this struct could hold data structures neccessary to compute the rhs. +First we define an empty struct. For a more complex problem, this struct could hold data structures neccessary to compute the rhs. ```rust # fn main() { -use std::rc::Rc; type T = f64; type V = nalgebra::DVector; -struct MyProblem { - p: Rc, -} +struct MyProblem; # } ``` -We use an `Rc` to hold the parameters because these parameters will need to be shared between the different functions that we will implement. - -Now we will implement the base `Op` trait for our struct. This trait specifies the types of the vectors and matrices that will be used, as well as the number of states and outputs in the rhs function. +Now we will implement the base `Op` trait for our struct. The `Op` trait specifies the types of the vectors and matrices that will be used, as well as the number of states and outputs in the rhs function. ```rust # fn main() { -# use std::rc::Rc; use diffsol::Op; type T = f64; type V = nalgebra::DVector; type M = nalgebra::DMatrix; -# struct MyProblem { -# p: Rc, -# } +# struct MyProblem; # # impl MyProblem { -# fn new(p: Rc) -> Self { -# MyProblem { p } +# fn new() -> Self { +# MyProblem {} # } # } # @@ -52,31 +44,32 @@ impl Op for MyProblem { fn nout(&self) -> usize { 1 } + fn nparams(&self) -> usize { + 0 + } } # } ``` + Next we implement the `NonLinearOp` and `NonLinearOpJacobian` trait for our struct. This trait specifies the functions that will be used to evaluate the rhs function and the jacobian multiplied by a vector. ```rust # fn main() { -# use std::rc::Rc; use diffsol::{ - NonLinearOp, NonLinearOpJacobian, OdeSolverEquations, OdeSolverProblem, - Op, UnitCallable, ConstantClosure + NonLinearOp, NonLinearOpJacobian }; +# use diffsol::Op; # type T = f64; # type V = nalgebra::DVector; # type M = nalgebra::DMatrix; # -# struct MyProblem { -# p: Rc, -# } +# struct MyProblem; # # impl MyProblem { -# fn new(p: Rc) -> Self { -# MyProblem { p } +# fn new() -> Self { +# MyProblem { } # } # } # @@ -90,43 +83,23 @@ use diffsol::{ # fn nout(&self) -> usize { # 1 # } +# fn nparams(&self) -> usize { +# 0 +# } # } -# -impl NonLinearOp for MyProblem { + +impl<'a> NonLinearOp for MyProblem { fn call_inplace(&self, x: &V, _t: T, y: &mut V) { - y[0] = self.p[0] * x[0] * (1.0 - x[0] / self.p[1]); + y[0] = x[0] * (1.0 - x[0]); } } -impl NonLinearOpJacobian for MyProblem { +impl<'a> NonLinearOpJacobian for MyProblem { fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) { - y[0] = self.p[0] * v[0] * (1.0 - 2.0 * x[0] / self.p[1]); + y[0] = v[0] * (1.0 - 2.0 * x[0]); } } # } ``` -There we go, all done! This demonstrates how to implement a custom struct to specify a rhs function. But this is a fair bit of boilerplate code, do we really need to do all this for **every** function we want to implement? - -Thankfully, the answer is no. If we didn't need to use our own struct for this particular function, we can alternativly use -the [`Closure`](https://docs.rs/diffsol/latest/diffsol/op/closure/struct.Closure.html) struct to implement the `NonLinearOp` trait for us. - -```rust -# fn main() { -# use std::rc::Rc; -# type T = f64; -# type V = nalgebra::DVector; -# type M = nalgebra::DMatrix; -# -use diffsol::Closure; - -let rhs_fn = |x: &V, p: &V, _t: T, y: &mut V| { - y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]); -}; -let jac_fn = |x: &V, p: &V, _t: T, v: &V, y: &mut V| { - y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]); -}; -let p = Rc::new(V::from_vec(vec![1.0, 10.0])); -let rhs = Rc::new(Closure::::new(rhs_fn, jac_fn, 1, 1, p.clone())); -# } -``` +There we go, all done! This demonstrates how to implement a custom struct to specify a rhs function. diff --git a/book/src/specify/custom/ode_systems.md b/book/src/specify/custom/ode_systems.md new file mode 100644 index 00000000..e8897a3c --- /dev/null +++ b/book/src/specify/custom/ode_systems.md @@ -0,0 +1,371 @@ +# ODE systems + +So far we've focused on using custom structs to specify individual equations, now we need to put these together to specify the entire system of equations. + +## Implementing the OdeEquations trait + +To specify the entire system of equations, you need to implement the `Op`, [`OdeEquations`](https://docs.rs/diffsol/latest/diffsol/ode_solver/equations/trait.OdeEquations.html) +and [`OdeEquationsRef`](https://docs.rs/diffsol/latest/diffsol/ode_solver/equations/trait.OdeEquationsRef.html) traits for your struct. + +## Getting all your traits in order + +The `OdeEquations` trait requires methods that return objects corresponding to the right-hand side function, mass matrix, root function, initial condition, and output functions. +Therefore, you need to already have structs that implement the `NonLinearOp`, `LinearOp`, and `ConstantOp` traits for these functions. For the purposes of this example, we will assume that +you have already implemented these traits for your structs. + +Often, the structs that implement these traits will have to use data defined in the struct that implements the `OdeEquations` trait. For example, they might wish to have a reference to the same parameter vector `p`. Therefore, you will often need to define lifetimes for these structs to ensure that they can access the data they need. + +Note that these struct will need to be lightweight and should not contain a significant amount of data. The data should be stored in the struct that implements the `OdeEquations` trait. This is because these structs will be created and destroyed many times during the course of the simulation (e.g. every time the right-hand side function is called). + + +```rust +# fn main() { +type T = f64; +type V = nalgebra::DVector; +type M = nalgebra::DMatrix; +struct MyRhs<'a> { p: &'a V } // implements NonLinearOp +struct MyMass<'a> { p: &'a V } // implements LinearOp +struct MyInit<'a> { p: &'a V } // implements ConstantOp +struct MyRoot<'a> { p: &'a V } // implements NonLinearOp +struct MyOut<'a> { p: &'a V } // implements NonLinearOp +# } +``` + +## Implementing the OdeEquations traits + +Lets imagine we have a struct `MyProblem` that we want to use to specify the entire system of equations. We can implement the `Op`, `OdeEquations`, and `OdeEquationsRef` traits for this struct like so: + +```rust +use diffsol::{Op, NonLinearOp, LinearOp, ConstantOp, OdeEquations, OdeEquationsRef}; +# fn main() { +# type T = f64; +# type V = nalgebra::DVector; +# type M = nalgebra::DMatrix; +# struct MyRhs<'a> { p: &'a V } // implements NonLinearOp +# struct MyMass<'a> { p: &'a V } // implements LinearOp +# struct MyInit<'a> { p: &'a V } // implements ConstantOp +# struct MyRoot<'a> { p: &'a V } // implements NonLinearOp +# struct MyOut<'a> { p: &'a V } // implements NonLinearOp +# impl Op for MyRhs<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 2 +# } +# } +# impl NonLinearOp for MyRhs<'_> { +# fn call_inplace(&self, x: &V, _t: T, y: &mut V) { +# y[0] = x[0] * x[0]; +# } +# } +# impl Op for MyMass<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 0 +# } +# } +# impl LinearOp for MyMass<'_> { +# fn gemv_inplace(&self, x: &V, _t: T, beta: T, y: &mut V) { +# y[0] = x[0] * beta; +# } +# } +# impl Op for MyInit<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 0 +# } +# } +# impl ConstantOp for MyInit<'_> { +# fn call_inplace(&self, _t: T, y: &mut V) { +# y[0] = 0.1; +# } +# } +# impl Op for MyRoot<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 0 +# } +# } +# impl NonLinearOp for MyRoot<'_> { +# fn call_inplace(&self, x: &V, _t: T, y: &mut V) { +# y[0] = x[0] - 1.0; +# } +# } +# impl Op for MyOut<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 0 +# } +# } +# impl NonLinearOp for MyOut<'_> { +# fn call_inplace(&self, x: &V, _t: T, y: &mut V) { +# y[0] = x[0]; +# } +# } + +struct MyProblem { + p: V, +} + +impl MyProblem { + fn new() -> Self { + MyProblem { p: V::zeros(2) } + } +} + +impl Op for MyProblem { + type T = T; + type V = V; + type M = M; + fn nstates(&self) -> usize { + 1 + } + fn nout(&self) -> usize { + 1 + } + fn nparams(&self) -> usize { + 2 + } +} + +impl<'a> OdeEquationsRef<'a> for MyProblem { + type Rhs = MyRhs<'a>; + type Mass = MyMass<'a>; + type Init = MyInit<'a>; + type Root = MyRoot<'a>; + type Out = MyOut<'a>; +} + +impl OdeEquations for MyProblem { + fn rhs(&self) -> >::Rhs { + MyRhs { p: &self.p } + } + fn mass(&self) -> Option<>::Mass> { + Some(MyMass { p: &self.p }) + } + fn init(&self) -> >::Init { + MyInit { p: &self.p } + } + fn root(&self) -> Option<>::Root> { + Some(MyRoot { p: &self.p }) + } + fn out(&self) -> Option<>::Out> { + Some(MyOut { p: &self.p }) + } + fn set_params(&mut self, p: &V) { + self.p.copy_from(p); + } +} +# } +``` + +## Creating the problem + +Now that we have our custom `OdeEquations` struct, we can use it in an `OdeBuilder` to create the problem. Hint: click the button below to see the full code, which includes the implementation of the `Op`, `NonLinearOp`, `LinearOp`, and `ConstantOp` traits for the `MyRhs`, `MyMass`, `MyInit`, `MyRoot`, and `MyOut` structs. + +```rust +use diffsol::{Op, NonLinearOp, LinearOp, ConstantOp, OdeEquations, OdeEquationsRef}; +# fn main() { +# type T = f64; +# type V = nalgebra::DVector; +# type M = nalgebra::DMatrix; +# struct MyRhs<'a> { p: &'a V } // implements NonLinearOp +# struct MyMass<'a> { p: &'a V } // implements LinearOp +# struct MyInit<'a> { p: &'a V } // implements ConstantOp +# struct MyRoot<'a> { p: &'a V } // implements NonLinearOp +# struct MyOut<'a> { p: &'a V } // implements NonLinearOp +# impl Op for MyRhs<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 2 +# } +# } +# impl NonLinearOp for MyRhs<'_> { +# fn call_inplace(&self, x: &V, _t: T, y: &mut V) { +# y[0] = x[0] * x[0]; +# } +# } +# impl Op for MyMass<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 0 +# } +# } +# impl LinearOp for MyMass<'_> { +# fn gemv_inplace(&self, x: &V, _t: T, beta: T, y: &mut V) { +# y[0] = x[0] * beta; +# } +# } +# impl Op for MyInit<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 0 +# } +# } +# impl ConstantOp for MyInit<'_> { +# fn call_inplace(&self, _t: T, y: &mut V) { +# y[0] = 0.1; +# } +# } +# impl Op for MyRoot<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 0 +# } +# } +# impl NonLinearOp for MyRoot<'_> { +# fn call_inplace(&self, x: &V, _t: T, y: &mut V) { +# y[0] = x[0] - 1.0; +# } +# } +# impl Op for MyOut<'_> { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 0 +# } +# } +# impl NonLinearOp for MyOut<'_> { +# fn call_inplace(&self, x: &V, _t: T, y: &mut V) { +# y[0] = x[0]; +# } +# } +# +# struct MyProblem { +# p: V, +# } +# +# impl MyProblem { +# fn new() -> Self { +# MyProblem { p: V::zeros(2) } +# } +# } +# +# impl Op for MyProblem { +# type T = T; +# type V = V; +# type M = M; +# fn nstates(&self) -> usize { +# 1 +# } +# fn nout(&self) -> usize { +# 1 +# } +# fn nparams(&self) -> usize { +# 2 +# } +# } +# +# impl<'a> OdeEquationsRef<'a> for MyProblem { +# type Rhs = MyRhs<'a>; +# type Mass = MyMass<'a>; +# type Init = MyInit<'a>; +# type Root = MyRoot<'a>; +# type Out = MyOut<'a>; +# } +# +# impl OdeEquations for MyProblem { +# fn rhs(&self) -> >::Rhs { +# MyRhs { p: &self.p } +# } +# fn mass(&self) -> Option<>::Mass> { +# Some(MyMass { p: &self.p }) +# } +# fn init(&self) -> >::Init { +# MyInit { p: &self.p } +# } +# fn root(&self) -> Option<>::Root> { +# Some(MyRoot { p: &self.p }) +# } +# fn out(&self) -> Option<>::Out> { +# Some(MyOut { p: &self.p }) +# } +# fn set_params(&mut self, p: &V) { +# self.p.copy_from(p); +# } +# } +use diffsol::OdeBuilder; +let problem = OdeBuilder::::new() + .p(vec![1.0, 10.0]) + .build_from_eqn(MyProblem::new()) + .unwrap(); +# } +``` \ No newline at end of file diff --git a/book/src/specify/custom/putting_it_all_together.md b/book/src/specify/custom/putting_it_all_together.md deleted file mode 100644 index 396df126..00000000 --- a/book/src/specify/custom/putting_it_all_together.md +++ /dev/null @@ -1,102 +0,0 @@ -# Putting it all together - -Once you have structs implementing the functions for your system of equations, you can use the [`OdeSolverEquations`](https://docs.rs/diffsol/latest/diffsol/ode_solver/equations/struct.OdeSolverEquations.html) struct -to put it all together. This struct implements the [`OdeEquations`](https://docs.rs/diffsol/latest/diffsol/ode_solver/equations/trait.OdeEquations.html) trait, and can be used to specify the problem to the solver. - -Note that it is optional to use the `OdeSolverEquations` struct, you can implement the `OdeEquations` trait directly if you prefer, but the `OdeSolverEquations` struct can be useful to reduce boilerplate code -and make it easier to specify the problem. - -## Getting all your traits in order - -The `OdeSolverEquations` struct requires arguments corresponding to the right-hand side function, mass matrix, root function, initial condition, and output functions. -For those that you want to provide, you can implement `NonLinearOp`, `LinearOp`, and `ConstantOp` traits for your structs, as described in the previous sections. -However, some of these arguments are optional and can be set to `None` if not needed. To do this, you still need to provide a placeholder type for these arguments, so you can use the -included [`UnitCallable`](https://docs.rs/diffsol/latest/diffsol/op/unit/struct.UnitCallable.html) type for this purpose. For example lets assume that we already have objects implementing -the `NonLinearOp` trait for the right-hand side function, and the `ConstantOp` trait for the initial condition, but we don't have a mass matrix, root function, or output function. -We can specify the missing arguments like so: - -```rust -# fn main() { -# type T = f64; -# type V = nalgebra::DVector; -# type M = nalgebra::DMatrix; -# -use diffsol::UnitCallable; - -let mass: Option> = None; -let root: Option> = None; -let out: Option> = None; -# } -``` - -## Creating the equations - -Now we have variables `rhs` and `init` that are structs implementing the required traits, and `mass`, `root`, and `out` set to `None`. Using these, we can create the `OdeSolverEquations` struct, -and then provide it to the `OdeBuilder` struct to create the problem. - -```rust -# fn main() { -# use std::rc::Rc; -# use diffsol::{NonLinearOp, NonLinearOpJacobian, Op, UnitCallable, ConstantClosure}; -use diffsol::{OdeSolverEquations, OdeBuilder}; - -# type T = f64; -# type V = nalgebra::DVector; -# type M = nalgebra::DMatrix; -# -# struct MyProblem { -# p: Rc, -# } -# -# impl MyProblem { -# fn new(p: Rc) -> Self { -# MyProblem { p } -# } -# } -# -# impl Op for MyProblem { -# type T = T; -# type V = V; -# type M = M; -# fn nstates(&self) -> usize { -# 1 -# } -# fn nout(&self) -> usize { -# 1 -# } -# fn nparams(&self) -> usize { -# 2 -# } -# } -# -# impl NonLinearOp for MyProblem { -# fn call_inplace(&self, x: &V, _t: T, y: &mut V) { -# y[0] = self.p[0] * x[0] * (1.0 - x[0] / self.p[1]); -# } -# } -# impl NonLinearOpJacobian for MyProblem { -# fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) { -# y[0] = self.p[0] * v[0] * (1.0 - 2.0 * x[0] / self.p[1]); -# } -# } -# -# -# let p_slice = [1.0, 10.0]; -# let p = Rc::new(V::from_vec(p_slice.to_vec())); -# let rhs = MyProblem::new(p.clone()); -# -# // use the provided constant closure to define the initial condition -# let init_fn = |_p: &V, _t: T| V::from_element(1, 0.1); -# let init = ConstantClosure::new(init_fn, p.clone()); -# -# // we don't have a mass matrix, root or output functions, so we can set to None -# // we still need to give a placeholder type for these, so we use the diffsol::UnitCallable type -# let mass: Option> = None; -# let root: Option> = None; -# let out: Option> = None; -# -# let p = Rc::new(V::zeros(0)); -let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); -let _problem = OdeBuilder::new().p(p_slice).build_from_eqn(eqn).unwrap(); -# } -``` \ No newline at end of file diff --git a/book/src/specify/diffsl.md b/book/src/specify/diffsl.md index 56cad6ad..f2cd7ea4 100644 --- a/book/src/specify/diffsl.md +++ b/book/src/specify/diffsl.md @@ -35,9 +35,10 @@ Once you have created the `DiffSl` struct you can use it to create a problem usi ```rust # fn main() { # use diffsol::{DiffSl, CraneliftModule}; -use diffsol::{OdeBuilder, Bdf, OdeSolverMethod, OdeSolverState}; +use diffsol::{OdeBuilder, OdeSolverMethod, OdeSolverState}; # type M = nalgebra::DMatrix; # type CG = CraneliftModule; +type LS = diffsol::NalgebraLU; # let eqn = DiffSl::::compile(" @@ -47,13 +48,12 @@ use diffsol::{OdeBuilder, Bdf, OdeSolverMethod, OdeSolverState}; # u { 0.1 } # F { r * u * (1.0 - u / k) } # ").unwrap(); -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .rtol(1e-6) .p([1.0, 10.0]) .build_from_eqn(eqn).unwrap(); -let mut solver = Bdf::default(); +let mut solver = problem.bdf::().unwrap(); let t = 0.4; -let state = OdeSolverState::new(&problem, &solver).unwrap(); -let _soln = solver.solve(&problem, state, t).unwrap(); +let _soln = solver.solve(t).unwrap(); # } ``` diff --git a/book/src/specify/forward_sensitivity.md b/book/src/specify/forward_sensitivity.md index a9e1713f..02cca193 100644 --- a/book/src/specify/forward_sensitivity.md +++ b/book/src/specify/forward_sensitivity.md @@ -22,8 +22,7 @@ We also need the partial derivative of the initial state vector with respect to \\[J_{y_0} v = 0.\\] - -We can then use the `OdeBuilder` struct to specify the sensitivity problem. The `build_ode_with_sens` method is used to create a new problem that includes the sensitivity equations. +We can then use the `OdeBuilder` struct to specify the sensitivity problem. The `rhs_sens_implicit` and `init_sens` methods are used to create a new problem that includes the sensitivity equations. ```rust # fn main() { @@ -31,16 +30,20 @@ use diffsol::OdeBuilder; use nalgebra::DVector; type M = nalgebra::DMatrix; -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .p(vec![1.0, 10.0]) - .build_ode_with_sens::( + .rhs_sens_implicit( |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), |x, p, _t, v, y| y[0] = v[0] * x[0] * (1.0 - x[0] / p[1]) + v[1] * p[0] * x[0] * x[0] / (p[1] * p[1]), + ) + .init_sens( |_p, _t| DVector::from_element(1, 0.1), |_p, _t, _v, y| y[0] = 0.0, - ).unwrap(); + ) + .build() + .unwrap(); # } ``` @@ -54,22 +57,26 @@ Lets imagine we want to solve the sensitivity problem up to a time \\(t_o = 10\\ # use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, OdeSolverState, Bdf}; +use diffsol::{OdeSolverMethod, NalgebraLU}; +type LS = NalgebraLU; -# let problem = OdeBuilder::new() +# let problem = OdeBuilder::::new() # .p(vec![1.0, 10.0]) -# .build_ode_with_sens::( +# .rhs_sens_implicit( # |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |x, p, _t, v, y| y[0] = v[0] * x[0] * (1.0 - x[0] / p[1]) + v[1] * p[0] * x[0] * x[0] / (p[1] * p[1]), +# |x, p, _t, v, y| y[0] = v[0] * x[0] * (1.0 - x[0] / p[1]) +# + v[1] * p[0] * x[0] * x[0] / (p[1] * p[1]), +# ) +# .init_sens( # |_p, _t| DVector::from_element(1, 0.1), # |_p, _t, _v, y| y[0] = 0.0, -# ).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); +# ) +# .build() +# .unwrap(); +let mut solver = problem.bdf::().unwrap(); let t_o = 10.0; -while solver.state().unwrap().t < t_o { +while solver.state().t < t_o { solver.step().unwrap(); } # } @@ -84,26 +91,30 @@ If we need the sensitivity at the current internal time step, we can get this fr # use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -# use diffsol::{OdeSolverMethod, OdeSolverState, Bdf}; +# use diffsol::{OdeSolverMethod, OdeSolverState, NalgebraLU}; +# type LS = NalgebraLU; # -# let problem = OdeBuilder::new() +# let problem = OdeBuilder::::new() # .p(vec![1.0, 10.0]) -# .build_ode_with_sens::( +# .rhs_sens_implicit( # |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), # |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |x, p, _t, v, y| y[0] = v[0] * x[0] * (1.0 - x[0] / p[1]) + v[1] * p[0] * x[0] * x[0] / (p[1] * p[1]), +# |x, p, _t, v, y| y[0] = v[0] * x[0] * (1.0 - x[0] / p[1]) +# + v[1] * p[0] * x[0] * x[0] / (p[1] * p[1]), +# ) +# .init_sens( # |_p, _t| DVector::from_element(1, 0.1), # |_p, _t, _v, y| y[0] = 0.0, -# ).unwrap(); -# let mut solver = Bdf::default(); -# let state = OdeSolverState::new(&problem, &solver).unwrap(); -# solver.set_problem(state, &problem); +# ) +# .build() +# .unwrap(); +# let mut solver = problem.bdf::().unwrap(); # let t_o = 10.0; -# while solver.state().unwrap().t < t_o { +# while solver.state().t < t_o { # solver.step().unwrap(); # } let sens_at_t_o = solver.interpolate_sens(t_o).unwrap(); -let sens_at_internal_step = &solver.state().as_ref().unwrap().s; +let sens_at_internal_step = &solver.state().s; # } ``` diff --git a/book/src/specify/mass_matrix.md b/book/src/specify/mass_matrix.md index e19e2efc..9c4ff65d 100644 --- a/book/src/specify/mass_matrix.md +++ b/book/src/specify/mass_matrix.md @@ -38,25 +38,29 @@ use nalgebra::{DMatrix, DVector}; type M = DMatrix; type V = DVector; -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .t0(0.0) .rtol(1e-6) .atol([1e-6]) .p(vec![1.0, 10.0]) - .build_ode_with_mass::( - |x, p, _t, y| { - y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]); - y[1] = x[0] - x[1]; - }, - |x, p, _t, v , y| { - y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]); - y[1] = v[0] - v[1]; - }, - |v, _p, _t, beta, y| { - y[0] = v[0] + beta * y[0]; - y[1] *= beta; - }, - |_p, _t| V::from_element(2, 0.1), - ).unwrap(); + .rhs_implicit( + |x, p, _t, y| { + y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]); + y[1] = x[0] - x[1]; + }, + |x, p, _t, v , y| { + y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]); + y[1] = v[0] - v[1]; + }, + ) + .mass( + |v, _p, _t, beta, y| { + y[0] = v[0] + beta * y[0]; + y[1] *= beta; + }, + ) + .init(|_p, _t| V::from_element(2, 0.1)) + .build() + .unwrap(); # } ``` diff --git a/book/src/specify/ode_equations.md b/book/src/specify/ode_equations.md index 57e579c4..8e23cee3 100644 --- a/book/src/specify/ode_equations.md +++ b/book/src/specify/ode_equations.md @@ -1,9 +1,8 @@ # ODE equations -The simplest way to create a new ode problem in Rust is to use the [`OdeBuilder`](https://docs.rs/diffsol/latest/diffsol/ode_solver/builder/struct.OdeBuilder.html) struct. -You can set the initial time, initial step size, relative tolerance, absolute tolerance, and parameters, or leave them at their default values. -Then, call one of the `build_*` functions to create a new problem, for example the [`build_ode`](https://docs.rs/diffsol/latest/diffsol/ode_solver/builder/struct.OdeBuilder.html#method.build_ode) -function can be used to create an ODE problem of the form \\(dy/dt = f(t, y, p)\\), where \\(y\\) is the state vector, \\(t\\) is the time, and \\(p\\) are the parameters. +The simplest way to create a new ode problem is to use the [`OdeBuilder`](https://docs.rs/diffsol/latest/diffsol/ode_solver/builder/struct.OdeBuilder.html) struct. +You can use methods to set the equations to be solve, initial time, initial step size, relative & absolute tolerances, and parameters, or leave them at their default values. +Then, call the `build` method to create a new problem. Below is an example of how to create a new ODE problem using the `OdeBuilder` struct. The specific problem we will solve is the logistic equation @@ -29,21 +28,28 @@ use diffsol::OdeBuilder; use nalgebra::DVector; type M = nalgebra::DMatrix; -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .t0(0.0) .rtol(1e-6) .atol([1e-6]) .p(vec![1.0, 10.0]) - .build_ode::( - |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), - |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), - |_p, _t| DVector::from_element(1, 0.1), - ).unwrap(); + .rhs_implicit( + |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), + |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), + ) + .init( + |_p, _t| DVector::from_element(1, 0.1), + ) + .build() + .unwrap(); # } ``` -Each `build_*` method requires the user to specify what matrix type they wish to use to define and solve the model (the other types are inferred from the closure types). -Here we use the `nalgebra::DMatrix` type, which is a dense matrix type from the [nalgebra](https://nalgebra.org) crate. Other options are: +The `rhs_implicit` method is used to specify the \\(f(y, p, t)\\) and \\(f'(y, p, t, v)\\) functions, whereas the `init` method is used to specify the initial state vector \\(y_0(p, t)\\). +We also use the `t0`, `rtol`, `atol`, and `p` methods to set the initial time, relative tolerance, absolute tolerance, and parameters, respectively. + +We have also specified the matrix type `M` to be `nalgebra::DMatrix`, using a generic parameter of the `OdeBuilder` struct. +The `nalgebra::DMatrix` type is a dense matrix type from the [nalgebra](https://nalgebra.org) crate. Other options are: - `faer::Mat` from [faer](https://github.com/sarah-ek/faer-rs), which is a dense matrix type. - `diffsol::SparseColMat`, which is a thin wrapper around `faer::sparse::SparseColMat`, a sparse compressed sparse column matrix type. @@ -52,11 +58,4 @@ You can see in the example above that the `DVector` type is explicitly used to c For these matrix types the associated vector type is: - `nalgebra::DVector` for `nalgebra::DMatrix`. - `faer::Col` for `faer::Mat`. -- `faer::Coll` for `diffsol::SparseColMat`. - -The arguments to the `build_ode` method are the equations that define the problem. -The first closure is the function \\(f(y, p, t)\\) this is implemented as a closure that takes the time `t`, -the parameter vector `p`, the state vector `y`, and a mutable reference that the closure can use to place the result (i.e. the derivative of the state vector \\(f(y, p, t)\\)). -The second closure is similar in structure in defines the jacobian multiplied by a vector \\(v\\) function \\(f'(y, p, t, v)\\). -The third closure returns the initial state vector \\(y_0(p, t)\\), this is done so that diffsol can infer the size of the state vector. - +- `faer::Coll` for `diffsol::SparseColMat`. \ No newline at end of file diff --git a/book/src/specify/root_finding.md b/book/src/specify/root_finding.md index 6a53c0fc..0e77ca4a 100644 --- a/book/src/specify/root_finding.md +++ b/book/src/specify/root_finding.md @@ -20,22 +20,23 @@ use diffsol::OdeBuilder; use nalgebra::DVector; type M = nalgebra::DMatrix; -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .t0(0.0) .rtol(1e-6) .atol([1e-6]) .p(vec![1.0, 10.0]) - .build_ode_with_root::( + .rhs_implicit( |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), - |_p, _t| DVector::from_element(1, 0.1), - |x, _p, _t, y| y[0] = x[0] - 0.5, - 1, - ).unwrap(); + ) + .init(|_p, _t| DVector::from_element(1, 0.1)) + .root(|x, _p, _t, y| y[0] = x[0] - 0.5, 1) + .build() + .unwrap(); # } ``` -here we have added the root finding function \\(r(y, p, t) = y - 0.5\\), and also let DiffSol know that we have one root function by passing `1` as the last argument to the `build_ode_with_root` method. +here we have added the root finding function \\(r(y, p, t) = y - 0.5\\), and also let DiffSol know that we have one root function by passing `1` as the last argument to the `root` method. If we had specified more than one root function, the solver would stop when any of the root functions are zero. ## Detecting roots during the solve @@ -49,20 +50,20 @@ If successful the `step` method returns an [`OdeSolverStopReason`](https://docs. # use diffsol::OdeBuilder; # use nalgebra::DVector; # type M = nalgebra::DMatrix; -use diffsol::{OdeSolverMethod, OdeSolverStopReason, OdeSolverState, Bdf}; +use diffsol::{OdeSolverMethod, OdeSolverStopReason, NalgebraLU}; +type LS = NalgebraLU; -# let problem = OdeBuilder::new() +# let problem = OdeBuilder::::new() # .p(vec![1.0, 10.0]) -# .build_ode_with_root::( -# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), -# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), -# |_p, _t| DVector::from_element(1, 0.1), -# |x, _p, _t, y| y[0] = x[0] - 0.5, -# 1, -# ).unwrap(); -let mut solver = Bdf::default(); -let state = OdeSolverState::new(&problem, &solver).unwrap(); -solver.set_problem(state, &problem); +# .rhs_implicit( +# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]), +# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]), +# ) +# .init(|_p, _t| DVector::from_element(1, 0.1)) +# .root(|x, _p, _t, y| y[0] = x[0] - 0.5, 1) +# .build() +# .unwrap(); +let mut solver = problem.bdf::().unwrap(); let t = loop { match solver.step() { Ok(OdeSolverStopReason::InternalTimestep) => continue, @@ -72,7 +73,7 @@ let t = loop { } }; println!("Root found at t = {}", t); -let _soln = &solver.state().unwrap().y; +let _soln = &solver.state().y; # } ``` diff --git a/book/src/specify/sparse_problems.md b/book/src/specify/sparse_problems.md index df653b4a..1e3a05d5 100644 --- a/book/src/specify/sparse_problems.md +++ b/book/src/specify/sparse_problems.md @@ -11,24 +11,28 @@ use diffsol::OdeBuilder; type M = diffsol::SparseColMat; type V = faer::Col; -let problem = OdeBuilder::new() +let problem = OdeBuilder::::new() .t0(0.0) .rtol(1e-6) .atol([1e-6]) .p(vec![1.0, 10.0]) - .build_ode::( - |x, p, _t, y| { - for i in 0..10 { - y[i] = p[0] * x[i] * (1.0 - x[i] / p[1]); - } - }, - |x, p, _t, v , y| { - for i in 0..10 { - y[i] = p[0] * v[i] * (1.0 - 2.0 * x[i] / p[1]); - } - }, - |_p, _t| V::from_fn(10, |_| 0.1), - ).unwrap(); + .rhs_implicit( + |x, p, _t, y| { + for i in 0..10 { + y[i] = p[0] * x[i] * (1.0 - x[i] / p[1]); + } + }, + |x, p, _t, v , y| { + for i in 0..10 { + y[i] = p[0] * v[i] * (1.0 - 2.0 * x[i] / p[1]); + } + }, + ) + .init( + |_p, _t| V::from_fn(10, |_| 0.1), + ) + .build() + .unwrap(); # } ``` @@ -46,24 +50,28 @@ use diffsol::{OdeEquations, NonLinearOp, NonLinearOpJacobian, Matrix, ConstantOp # type V = faer::Col; # # fn main() { -# let problem = OdeBuilder::new() -# .t0(0.0) -# .rtol(1e-6) -# .atol([1e-6]) -# .p(vec![1.0, 10.0]) -# .build_ode::( -# |x, p, _t, y| { -# for i in 0..10 { -# y[i] = p[0] * x[i] * (1.0 - x[i] / p[1]); -# } -# }, -# |x, p, _t, v , y| { -# for i in 0..10 { -# y[i] = p[0] * v[i] * (1.0 - 2.0 * x[i] / p[1]); -# } -# }, -# |_p, _t| V::from_fn(10, |_| 0.1), -# ).unwrap(); +# let problem = OdeBuilder::::new() +# .t0(0.0) +# .rtol(1e-6) +# .atol([1e-6]) +# .p(vec![1.0, 10.0]) +# .rhs_implicit( +# |x, p, _t, y| { +# for i in 0..10 { +# y[i] = p[0] * x[i] * (1.0 - x[i] / p[1]); +# } +# }, +# |x, p, _t, v , y| { +# for i in 0..10 { +# y[i] = p[0] * v[i] * (1.0 - 2.0 * x[i] / p[1]); +# } +# }, +# ) +# .init( +# |_p, _t| V::from_fn(10, |_| 0.1), +# ) +# .build() +# .unwrap(); let t0 = problem.t0; let y0 = problem.eqn.init().call(t0); let jacobian = problem.eqn.rhs().jacobian(&y0, t0); @@ -89,97 +97,9 @@ which will print the jacobian matrix in triplet format: ``` DiffSol attempts to guess the sparsity pattern of your jacobian matrix by calling the \\(f'(y, p, t, v)\\) function repeatedly with different one-hot vectors \\(v\\) -with a `NaN` value at each index. The output of this function (i.e. which elements are `0` and which are `NaN`) is then used to determine the sparsity pattern of the jacobian matrix. +with a `NaN` value at each hot index. The output of this function (i.e. which elements are `0` and which are `NaN`) is then used to determine the sparsity pattern of the jacobian matrix. Due to the fact that for IEEE 754 floating point numbers, `NaN` is propagated through most operations, this method is able to detect which output elements are dependent on which input elements. However, this method is not foolproof, and it may fail to detect the correct sparsity pattern in some cases, particularly if values of `v` are used in control-flow statements. -If DiffSol does not detect the correct sparsity pattern, you can manually specify the jacobian. To do this, you need -to implement the [`diffsol::NonLinearOp`](https://docs.rs/diffsol/latest/diffsol/op/trait.NonLinearOp.html) trait for the rhs function. -This is described in more detail in the ["Custom Problem Structs"](./custom_problem_structs.md) section, but is illustrated below. - -```rust -# fn main() { -use std::rc::Rc; -use faer::sparse::{SparseColMat, SymbolicSparseColMat}; -use diffsol::{NonLinearOp, NonLinearOpJacobian, OdeSolverEquations, Op, UnitCallable, ConstantClosure, OdeBuilder}; - -type T = f64; -type V = faer::Col; -type M = diffsol::SparseColMat; - -struct MyProblem { - jacobian: SparseColMat, - p: Rc, -} - -impl MyProblem { - fn new(p: Rc) -> Self { - let mut triplets = Vec::new(); - for i in 0..10 { - triplets.push((i, i, 1.0)); - } - let jacobian = SparseColMat::try_new_from_triplets(10, 10, triplets.as_slice()).unwrap(); - MyProblem { p, jacobian } - } -} - -impl Op for MyProblem { - type T = T; - type V = V; - type M = M; - fn nstates(&self) -> usize { - 10 - } - fn nout(&self) -> usize { - 10 - } - fn nparams(&self) -> usize { - 2 - } -} - -impl NonLinearOp for MyProblem { - fn call_inplace(&self, x: &V, _t: T, y: &mut V) { - for i in 0..10 { - y[i] = self.p[0] * x[i] * (1.0 - x[i] / self.p[1]); - } - } - - -} -impl NonLinearOpJacobian for MyProblem { - fn jac_mul_inplace(&self, x: &V, _t: T, v: &V, y: &mut V) { - for i in 0..10 { - y[i] = self.p[0] * v[i] * (1.0 - 2.0 * x[i] / self.p[1]); - } - } - fn jacobian_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::M) { - for i in 0..10 { - let row = y.faer().row_indices()[i]; - y.faer_mut().values_mut()[i] = self.p[0] * (1.0 - 2.0 * x[row] / self.p[1]); - } - } - fn jacobian_sparsity(&self) -> Option> { - Some(self.jacobian.symbolic().to_owned().unwrap()) - } -} - -let p_slice = [1.0, 10.0]; -let p = Rc::new(V::from_fn(p_slice.len(), |i| p_slice[i])); -let rhs = MyProblem::new(p.clone()); - -// use the provided constant closure to define the initial condition -let init_fn = |_p: &V, _t: T| V::from_fn(10, |_| 0.1); -let init = ConstantClosure::new(init_fn, p.clone()); - -// we don't have a mass matrix, root or output functions, so we can set to None -// we still need to give a placeholder type for these, so we use the diffsol::UnitCallable type -let mass: Option> = None; -let root: Option> = None; -let out: Option> = None; - -let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p); -let _problem = OdeBuilder::new().p(p_slice).build_from_eqn(eqn).unwrap(); -# } -``` - +If DiffSol does not detect the correct sparsity pattern, you can manually specify the jacobian. To do this, you need to use a custom struct that implements the `OdeEquations` trait, +This is described in more detail in the ["Custom Problem Structs"](./custom_problem_structs.md) section. \ No newline at end of file diff --git a/book/src/specify/specifying_the_problem.md b/book/src/specify/specifying_the_problem.md index f6f9e04a..053ddd49 100644 --- a/book/src/specify/specifying_the_problem.md +++ b/book/src/specify/specifying_the_problem.md @@ -30,7 +30,7 @@ DiffSol has three main APIs for specifying problems: where the user can implement the functions above on their own structs. This API is more flexible than the `OdeBuilder` API, but is more complex to use. It is useful if you have custom data structures and code that you want to use to evaluate your functions that does not fit within the `OdeBuilder` API. -- The [`DiffSlContext`](https://docs.rs/diffsol/latest/diffsol/ode_solver/diffsl/struct.DiffSlContext.html) struct, where the user can specify the functions above using the [DiffSL](https://martinjrobins.github.io/diffsl/) +- The [`DiffSl`](https://docs.rs/diffsol/latest/diffsol/ode_solver/diffsl/struct.DiffSl.html) struct, where the user can specify the functions above using the [DiffSL](https://martinjrobins.github.io/diffsl/) Domain Specific Language (DSL). This API is behind a feature flag (`diffsl` if you want to use the slower cranelift backend, `diffsl-llvm*` if you want to use the faster LLVM backend), but has the best API if you want to use DiffSL from a higher-level language like Python or R while still having similar performance. diff --git a/src/jacobian/mod.rs b/src/jacobian/mod.rs index ddd690cf..f40a06f8 100644 --- a/src/jacobian/mod.rs +++ b/src/jacobian/mod.rs @@ -219,11 +219,11 @@ impl JacobianColoring { #[cfg(test)] mod tests { - use std::rc::Rc; use crate::jacobian::{find_jacobian_non_zeros, JacobianColoring}; use crate::matrix::Matrix; use crate::op::linear_closure::LinearClosure; + use crate::op::ParameterisedOp; use crate::vector::Vector; use crate::{ jacobian::{coloring::nonzeros2graph, greedy_coloring::color_graph_greedy}, @@ -235,11 +235,17 @@ mod tests { use num_traits::{One, Zero}; use std::ops::MulAssign; + #[allow(clippy::type_complexity)] fn helper_triplets2op_nonlinear<'a, M: Matrix + 'a>( triplets: &'a [(usize, usize, M::T)], + p: &'a M::V, nrows: usize, ncols: usize, - ) -> impl NonLinearOpJacobian + 'a { + ) -> Closure< + M, + impl Fn(&M::V, &M::V, M::T, &mut M::V) + use<'a, M>, + impl Fn(&M::V, &M::V, M::T, &M::V, &mut M::V) + use<'a, M>, + > { let nstates = ncols; let nout = nrows; let f = move |x: &M::V, y: &mut M::V| { @@ -258,19 +264,21 @@ mod tests { }, nstates, nout, - Rc::new(M::V::zeros(0)), + p.len(), ); let y0 = M::V::zeros(nstates); let t0 = M::T::zero(); - ret.calculate_sparsity(&y0, t0); + ret.calculate_sparsity(&y0, t0, p); ret } + #[allow(clippy::type_complexity)] fn helper_triplets2op_linear<'a, M: Matrix + 'a>( triplets: &'a [(usize, usize, M::T)], + p: &'a M::V, nrows: usize, ncols: usize, - ) -> impl LinearOp + 'a { + ) -> LinearClosure> { let nstates = ncols; let nout = nrows; let f = move |x: &M::V, y: &mut M::V| { @@ -285,10 +293,10 @@ mod tests { }, nstates, nout, - Rc::new(M::V::zeros(0)), + p.len(), ); let t0 = M::T::zero(); - ret.calculate_sparsity(t0); + ret.calculate_sparsity(t0, p); ret } @@ -308,8 +316,10 @@ mod tests { (1, 1, M::T::one()), ], ]; + let p = M::V::zeros(0); for triplets in test_triplets { - let op = helper_triplets2op_nonlinear::(triplets.as_slice(), 2, 2); + let op = helper_triplets2op_nonlinear::(triplets.as_slice(), &p, 2, 2); + let op = ParameterisedOp::new(&op, &p); let non_zeros = find_jacobian_non_zeros(&op, &M::V::zeros(2), M::T::zero()); let expect = triplets .iter() @@ -346,8 +356,10 @@ mod tests { ], ]; let expect = vec![vec![1, 1], vec![1, 2], vec![1, 1], vec![1, 2]]; + let p = M::V::zeros(0); for (triplets, expect) in test_triplets.iter().zip(expect) { - let op = helper_triplets2op_nonlinear::(triplets.as_slice(), 2, 2); + let op = helper_triplets2op_nonlinear::(triplets.as_slice(), &p, 2, 2); + let op = ParameterisedOp::new(&op, &p); let non_zeros = find_jacobian_non_zeros(&op, &M::V::zeros(2), M::T::zero()); let ncols = op.nstates(); let graph = nonzeros2graph(non_zeros.as_slice(), ncols); @@ -385,8 +397,10 @@ mod tests { let n = 3; // test nonlinear functions + let p = M::V::zeros(0); for triplets in test_triplets.iter() { - let op = helper_triplets2op_nonlinear::(triplets.as_slice(), n, n); + let op = helper_triplets2op_nonlinear::(triplets.as_slice(), &p, n, n); + let op = ParameterisedOp::new(&op, &p); let y0 = M::V::zeros(n); let t0 = M::T::zero(); let nonzeros = triplets @@ -405,8 +419,10 @@ mod tests { } // test linear functions + let p = M::V::zeros(0); for triplets in test_triplets { - let op = helper_triplets2op_linear::(triplets.as_slice(), n, n); + let op = helper_triplets2op_linear::(triplets.as_slice(), &p, n, n); + let op = ParameterisedOp::new(&op, &p); let t0 = M::T::zero(); let nonzeros = triplets .iter() diff --git a/src/lib.rs b/src/lib.rs index 9bd6642b..60ee77d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,15 +5,17 @@ //! //! ## Solving ODEs //! -//! The simplest way to create a new problem is to use the [OdeBuilder] struct. You can set the initial time, initial step size, relative tolerance, absolute tolerance, and parameters, -//! or leave them at their default values. Then, call one of the `build_*` functions (e.g. [OdeBuilder::build_ode], [OdeBuilder::build_ode_with_mass], [OdeBuilder::build_from_eqn]) to create a [OdeSolverProblem]. +//! The simplest way to create a new problem is to use the [OdeBuilder] struct. You can set many configuration options such as the initial time ([OdeBuilder::t0]), initial step size ([OdeBuilder::h0]), +//! relative tolerance ([OdeBuilder::rtol]), absolute tolerance ([OdeBuilder::atol]), parameters ([OdeBuilder::p]) and equations ([OdeBuilder::rhs_implicit], [OdeBuilder::init], [OdeBuilder::mass] etc.) +//! or leave them at their default values. Then, call the [OdeBuilder::build] function to create a [OdeSolverProblem]. //! //! You will also need to choose a matrix type to use. DiffSol can use the [nalgebra](https://nalgebra.org) `DMatrix` type, the [faer](https://github.com/sarah-ek/faer-rs) `Mat` type, or any other type that implements the //! [Matrix] trait. //! //! ## Initial state //! -//! The solver state is held in [OdeSolverState], and contains a state vector, the gradient of the state vector, the time, and the step size. You can intitialise a new state using [OdeSolverState::new], +//! The solver state is held in [OdeSolverState], and contains a state vector, the gradient of the state vector, the time, and the step size. The [OdeSolverProblem] class has a collection of methods to create and initialise +//! a new state for each solver ([OdeSolverProblem::bdf_state], [OdeSolverProblem::tr_bdf2_state], [OdeSolverProblem::esdirk34_state]). Or you can manually intitialise a new state using [OdeSolverState::new], //! or create an uninitialised state using [OdeSolverState::new_without_initialise] and intitialise it manually or using the [OdeSolverState::set_consistent] and [OdeSolverState::set_step_size] methods. //! //! To view the state within a solver, you can use the [OdeSolverMethod::state] or [OdeSolverMethod::state_mut] methods. These will return references to the state using either the [StateRef] or [StateRefMut] structs @@ -24,6 +26,9 @@ //! - A Backwards Difference Formulae [Bdf] solver, suitable for stiff problems and singular mass matrices. //! - A Singly Diagonally Implicit Runge-Kutta (SDIRK or ESDIRK) solver [Sdirk]. You can use your own butcher tableau using [Tableau] or use one of the provided ([Tableau::tr_bdf2], [Tableau::esdirk34]). //! +//! The easiest way to create a solver is to use one of the provided methods on the [OdeSolverProblem] struct ([OdeSolverProblem::bdf_solver], [OdeSolverProblem::tr_bdf2_solver], [OdeSolverProblem::esdirk34_solver]). +//! These create a new solver from a provided state and problem. Alternatively, you can create both the solver and the state at once using [OdeSolverProblem::bdf], [OdeSolverProblem::tr_bdf2], [OdeSolverProblem::esdirk34]. +//! //! See the [OdeSolverMethod] trait for a more detailed description of the available methods on each solver. Possible workflows are: //! - Use the [OdeSolverMethod::step] method to step the solution forward in time with an internal time step chosen by the solver to meet the error tolerances. //! - Use the [OdeSolverMethod::interpolate] method to interpolate the solution between the last two time steps. @@ -58,21 +63,21 @@ //! //! ## Events / Root finding //! -//! DiffSol provides a simple way to detect user-provided events during the integration of the ODEs. You can use this by providing a closure that has a zero-crossing at the event you want to detect, using the [OdeBuilder::build_ode_with_root] builder, +//! DiffSol provides a simple way to detect user-provided events during the integration of the ODEs. You can use this by providing a closure that has a zero-crossing at the event you want to detect, using the [OdeBuilder::root] method, //! or by providing a [NonLinearOp] that has a zero-crossing at the event you want to detect. To use the root finding feature while integrating with the solver, you can use the return value of [OdeSolverMethod::step] to check if an event has been detected. //! //! ## Forward Sensitivity Analysis //! -//! DiffSol provides a way to compute the forward sensitivity of the solution with respect to the parameters. To use this your equations struct must implement the [OdeEquationsSens] trait. -//! Note that by default the sensitivity equations are included in the error control for the solvers, you can change this by setting tolerances using the [OdeBuilder::sens_atol] and [[OdeBuilder::sens_rtol]] methods. -//! You will also need to use [SensitivitiesOdeSolverMethod::set_problem_with_sensitivities] to set the problem with sensitivities. +//! DiffSol provides a way to compute the forward sensitivity of the solution with respect to the parameters. You can provide the requires equations to the builder using [OdeBuilder::rhs_sens_implicit] and [OdeBuilder::init_sens], +//! or your equations struct must implement the [OdeEquationsSens] trait, +//! Note that by default the sensitivity equations are included in the error control for the solvers, you can change this by setting tolerances using the [OdeBuilder::sens_atol] and [OdeBuilder::sens_rtol] methods. //! -//! To obtain the sensitivity solution via interpolation, you can use the [OdeSolverMethod::interpolate_sens] method. Otherwise the sensitivity vectors are stored in the [OdeSolverState] struct. +//! The easiest way to obtain the sensitivity solution is to use the [OdeSolverMethod::solve_dense_sensitivities] method, which will solve the forward problem and the sensitivity equations simultaneously and return the result. +//! If you are manually stepping the solver, you can use the [OdeSolverMethod::interpolate_sens] method to obtain the sensitivity solution at a given time. Otherwise the sensitivity vectors are stored in the [OdeSolverState] struct. //! //! ## Checkpointing //! -//! You can checkpoint the solver at a set of times using the [OdeSolverMethod::checkpoint] method. This will store the state of the solver at the given times, and subsequently use the [OdeSolverMethod::set_problem] -//! method to restore the solver to the state at the given time. +//! You can checkpoint the solver at the current internal time [OdeSolverMethod::checkpoint] method. //! //! ## Interpolation //! @@ -93,10 +98,16 @@ //! and then used to compute the sensitivities of the output function. Checkpointing is typically used to store the forward solution at a set of times as theses are required //! to solve the adjoint equations. //! -//! To use the adjoint sensitivity method, your equations struct must implement the [OdeEquationsAdjoint] trait. When you compute the forward solution, use checkpointing -//! to store the solution at a set of times. From this you should obtain a `Vec` (that can be the start and end of the solution), and -//! a [HermiteInterpolator] that can be used to interpolate the solution between the last two checkpoints. You can then use the [AdjointOdeSolverMethod::into_adjoint_solver] -//! method to create an adjoint solver from the forward solver, and then use this solver to step the adjoint equations backwards in time. Once the adjoint equations have been solved, +//! To provide the builder with the required equations, you can use the [OdeBuilder::rhs_adjoint_implicit], [OdeBuilder::init_adjoint], and [OdeBuilder::out_adjoint_implicit] methods, +//! or your equations struct must implement the [OdeEquationsAdjoint] trait. +//! +//! The easiest way to obtain the adjoint solution is to use the [OdeSolverMethod::solve_adjoint] method, which will solve the forwards problem, then the adjoint problem and return the result. +//! If you wish to manually do the timestepping, then the best place to start is by looking at the source code for the [OdeSolverMethod::solve_adjoint] method. During the solution of the forwards problem +//! you will need to use checkpointing to store the solution at a set of times. +//! From this you should obtain a `Vec` (that can be the start and end of the solution), and +//! a [HermiteInterpolator] that can be used to interpolate the solution between the last two checkpoints. You can then use the [AdjointOdeSolverMethod::adjoint_equations] and then create +//! an adjoint solver either manually or using the [AdjointOdeSolverMethod::default_adjoint_solver] method. You can then use this solver to step the adjoint equations backwards in time using [OdeSolverMethod::step] as normal. +//! Once the adjoint equations have been solved, //! the sensitivities of the output function will be stored in the [StateRef::sg] field of the adjoint solver state. If your parameters are used to calculate the initial conditions //! of the forward problem, then you will need to use the [AdjointEquations::correct_sg_for_init] method to correct the sensitivities for the initial conditions. //! @@ -158,9 +169,6 @@ pub use vector::sundials::SundialsVector; #[cfg(feature = "sundials")] pub use linear_solver::sundials::SundialsLinearSolver; -#[cfg(feature = "sundials")] -pub use ode_solver::sundials::SundialsIda; - #[cfg(feature = "suitesparse")] pub use linear_solver::suitesparse::klu::KLU; @@ -184,14 +192,13 @@ use ode_solver::jacobian_update::JacobianUpdate; pub use ode_solver::state::{StateRef, StateRefMut}; pub use ode_solver::{ adjoint_equations::AdjointContext, adjoint_equations::AdjointEquations, - adjoint_equations::AdjointInit, adjoint_equations::AdjointRhs, bdf::Bdf, bdf::BdfAdj, - bdf_state::BdfState, builder::OdeBuilder, checkpointing::Checkpointing, - checkpointing::HermiteInterpolator, equations::AugmentedOdeEquations, - equations::AugmentedOdeEquationsImplicit, equations::NoAug, equations::OdeEquations, - equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit, equations::OdeEquationsRef, - equations::OdeEquationsSens, equations::OdeSolverEquations, method::AdjointOdeSolverMethod, - method::OdeSolverMethod, method::OdeSolverStopReason, method::SensitivitiesOdeSolverMethod, - problem::OdeSolverProblem, sdirk::Sdirk, sdirk::SdirkAdj, sdirk_state::SdirkState, + adjoint_equations::AdjointInit, adjoint_equations::AdjointRhs, bdf::Bdf, bdf_state::BdfState, + builder::OdeBuilder, checkpointing::Checkpointing, checkpointing::HermiteInterpolator, + equations::AugmentedOdeEquations, equations::AugmentedOdeEquationsImplicit, equations::NoAug, + equations::OdeEquations, equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit, + equations::OdeEquationsRef, equations::OdeEquationsSens, equations::OdeSolverEquations, + method::AdjointOdeSolverMethod, method::AugmentedOdeSolverMethod, method::OdeSolverMethod, + method::OdeSolverStopReason, problem::OdeSolverProblem, sdirk::Sdirk, sdirk_state::SdirkState, sens_equations::SensEquations, sens_equations::SensInit, sens_equations::SensRhs, state::OdeSolverState, tableau::Tableau, }; @@ -203,7 +210,7 @@ pub use op::nonlinear_op::{ pub use op::{ closure::Closure, closure_with_adjoint::ClosureWithAdjoint, constant_closure::ConstantClosure, constant_closure_with_adjoint::ConstantClosureWithAdjoint, linear_closure::LinearClosure, - unit::UnitCallable, Op, + unit::UnitCallable, BuilderOp, Op, ParameterisedOp, }; use op::{ closure_no_jac::ClosureNoJac, closure_with_sens::ClosureWithSens, diff --git a/src/linear_solver/faer/lu.rs b/src/linear_solver/faer/lu.rs index 8e139c9d..d68ffd54 100644 --- a/src/linear_solver/faer/lu.rs +++ b/src/linear_solver/faer/lu.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{error::LinearSolverError, linear_solver_error}; use crate::{ @@ -49,12 +47,7 @@ impl LinearSolver> for LU { Ok(()) } - fn set_problem, M = Mat>>( - &mut self, - op: &C, - _rtol: T, - _atol: Rc>, - ) { + fn set_problem, M = Mat>>(&mut self, op: &C) { let ncols = op.nstates(); let nrows = op.nout(); let matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity()); diff --git a/src/linear_solver/faer/sparse_lu.rs b/src/linear_solver/faer/sparse_lu.rs index 788933b3..a5876c42 100644 --- a/src/linear_solver/faer/sparse_lu.rs +++ b/src/linear_solver/faer/sparse_lu.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{ error::{DiffsolError, LinearSolverError}, linear_solver::LinearSolver, @@ -67,8 +65,6 @@ impl LinearSolver> for FaerSparseLU { fn set_problem, M = SparseColMat>>( &mut self, op: &C, - _rtol: T, - _atol: Rc>, ) { let ncols = op.nstates(); let nrows = op.nout(); diff --git a/src/linear_solver/mod.rs b/src/linear_solver/mod.rs index b583a0db..ca064d39 100644 --- a/src/linear_solver/mod.rs +++ b/src/linear_solver/mod.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{error::DiffsolError, Matrix, NonLinearOpJacobian}; #[cfg(feature = "nalgebra")] @@ -31,12 +29,7 @@ pub trait LinearSolver: Default { /// Set the problem to be solved, any previous problem is discarded. /// Any internal state of the solver is reset. /// This function will normally set the sparsity pattern of the matrix to be solved. - fn set_problem>( - &mut self, - op: &C, - rtol: M::T, - atol: Rc, - ); + fn set_problem>(&mut self, op: &C); /// Solve the problem `Ax = b` and return the solution `x`. /// panics if [Self::set_linearisation] has not been called previously @@ -62,11 +55,9 @@ impl LinearSolveSolution { #[cfg(test)] pub mod tests { - use std::rc::Rc; - use crate::{ linear_solver::{FaerLU, NalgebraLU}, - op::closure::Closure, + op::{closure::Closure, ParameterisedOp}, scalar::scale, vector::VectorRef, LinearSolver, Matrix, NonLinearOpJacobian, Vector, @@ -77,26 +68,30 @@ pub mod tests { #[allow(clippy::type_complexity)] pub fn linear_problem() -> ( - impl NonLinearOpJacobian, + Closure< + M, + impl Fn(&M::V, &M::V, M::T, &mut M::V), + impl Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + >, M::T, - Rc, + M::V, Vec>, ) { let diagonal = M::V::from_vec(vec![2.0.into(), 2.0.into()]); let jac1 = M::from_diagonal(&diagonal); let jac2 = M::from_diagonal(&diagonal); - let p = Rc::new(M::V::zeros(0)); + let p = M::V::zeros(0); let mut op = Closure::new( // f = J * x move |x, _p, _t, y| jac1.gemv(M::T::one(), x, M::T::zero(), y), move |_x, _p, _t, v, y| jac2.gemv(M::T::one(), v, M::T::zero(), y), 2, 2, - p, + p.len(), ); - op.calculate_sparsity(&M::V::from_element(2, M::T::one()), M::T::zero()); + op.calculate_sparsity(&M::V::from_element(2, M::T::one()), M::T::zero(), &p); let rtol = M::T::from(1e-6); - let atol = Rc::new(M::V::from_vec(vec![1e-6.into(), 1e-6.into()])); + let atol = M::V::from_vec(vec![1e-6.into(), 1e-6.into()]); let solns = vec![LinearSolveSolution::new( M::V::from_vec(vec![2.0.into(), 4.0.into()]), M::V::from_vec(vec![1.0.into(), 2.0.into()]), @@ -104,23 +99,23 @@ pub mod tests { (op, rtol, atol, solns) } - pub fn test_linear_solver( + pub fn test_linear_solver<'a, C>( mut solver: impl LinearSolver, op: C, rtol: C::T, - atol: Rc, + atol: &'a C::V, solns: Vec>, ) where C: NonLinearOpJacobian, - for<'a> &'a C::V: VectorRef, + for<'b> &'b C::V: VectorRef, { - solver.set_problem(&op, rtol, atol.clone()); + solver.set_problem(&op); let x = C::V::zeros(op.nout()); let t = C::T::zero(); solver.set_linearisation(&op, &x, t); for soln in solns { let x = solver.solve(&soln.b).unwrap(); - let tol = { &soln.x * scale(rtol) + atol.as_ref() }; + let tol = { &soln.x * scale(rtol) + atol }; x.assert_eq(&soln.x, &tol); } } @@ -131,13 +126,17 @@ pub mod tests { #[test] fn test_lu_nalgebra() { let (op, rtol, atol, solns) = linear_problem::(); + let p = nalgebra::DVector::zeros(0); + let op = ParameterisedOp::new(&op, &p); let s = NalgebraLU::default(); - test_linear_solver(s, op, rtol, atol, solns); + test_linear_solver(s, op, rtol, &atol, solns); } #[test] fn test_lu_faer() { let (op, rtol, atol, solns) = linear_problem::(); + let p = faer::Col::zeros(0); + let op = ParameterisedOp::new(&op, &p); let s = FaerLU::default(); - test_linear_solver(s, op, rtol, atol, solns); + test_linear_solver(s, op, rtol, &atol, solns); } } diff --git a/src/linear_solver/nalgebra/lu.rs b/src/linear_solver/nalgebra/lu.rs index 1bdd92a3..ec6c7f39 100644 --- a/src/linear_solver/nalgebra/lu.rs +++ b/src/linear_solver/nalgebra/lu.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use nalgebra::{DMatrix, DVector, Dyn}; use crate::{ @@ -55,8 +53,6 @@ impl LinearSolver> for LU { fn set_problem, M = DMatrix>>( &mut self, op: &C, - _rtol: T, - _atol: Rc>, ) { let ncols = op.nstates(); let nrows = op.nout(); diff --git a/src/linear_solver/suitesparse/klu.rs b/src/linear_solver/suitesparse/klu.rs index d52b4c22..17086c0a 100644 --- a/src/linear_solver/suitesparse/klu.rs +++ b/src/linear_solver/suitesparse/klu.rs @@ -1,7 +1,5 @@ -use std::cell::RefCell; -use std::rc::Rc; - use faer::Col; +use std::cell::RefCell; #[cfg(target_pointer_width = "32")] use suitesparse_sys::{ @@ -223,12 +221,7 @@ where Ok(()) } - fn set_problem>( - &mut self, - op: &C, - _rtol: M::T, - _atol: Rc, - ) { + fn set_problem>(&mut self, op: &C) { let ncols = op.nstates(); let nrows = op.nout(); let mut matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity()); @@ -242,6 +235,7 @@ where mod tests { use crate::{ linear_solver::tests::{linear_problem, test_linear_solver}, + op::ParameterisedOp, SparseColMat, }; @@ -250,7 +244,9 @@ mod tests { #[test] fn test_klu() { let (op, rtol, atol, solns) = linear_problem::>(); + let p = faer::Col::zeros(0); + let op = ParameterisedOp::new(&op, &p); let s = KLU::default(); - test_linear_solver(s, op, rtol, atol, solns); + test_linear_solver(s, op, rtol, &atol, solns); } } diff --git a/src/linear_solver/sundials.rs b/src/linear_solver/sundials.rs index 0e1179aa..d437ccb8 100644 --- a/src/linear_solver/sundials.rs +++ b/src/linear_solver/sundials.rs @@ -1,19 +1,31 @@ -use std::rc::Rc; - use crate::sundials_sys::{ - realtype, SUNLinSolFree, SUNLinSolSetup, SUNLinSolSolve, SUNLinSol_Dense, SUNLinearSolver, + realtype, IDAGetReturnFlagName, SUNLinSolFree, SUNLinSolSetup, SUNLinSolSolve, SUNLinSol_Dense, + SUNLinearSolver, }; use crate::{ - error::*, linear_solver_error, ode_solver::sundials::sundials_check, - vector::sundials::SundialsVector, Matrix, NonLinearOpJacobian, SundialsMatrix, + error::*, linear_solver_error, vector::sundials::SundialsVector, Matrix, NonLinearOpJacobian, + SundialsMatrix, }; +use std::ffi::{c_int, CStr}; #[cfg(not(sundials_version_major = "5"))] use crate::vector::sundials::get_suncontext; use super::LinearSolver; +pub fn sundials_check(retval: c_int) -> Result<(), DiffsolError> { + if retval < 0 { + let char_ptr = unsafe { IDAGetReturnFlagName(i64::from(retval)) }; + let c_str = unsafe { CStr::from_ptr(char_ptr) }; + Err(DiffsolError::from(OdeSolverError::SundialsError( + c_str.to_str().unwrap().to_string(), + ))) + } else { + Ok(()) + } +} + pub struct SundialsLinearSolver { linear_solver: Option, is_setup: bool, @@ -48,8 +60,6 @@ impl LinearSolver for SundialsLinearSolver { fn set_problem>( &mut self, op: &C, - _rtol: realtype, - _atol: Rc, ) { let matrix = SundialsMatrix::zeros(op.nstates(), op.nstates()); let y0 = SundialsVector::new_serial(op.nstates()); diff --git a/src/matrix/sundials.rs b/src/matrix/sundials.rs index 6224fe32..d35c5c48 100644 --- a/src/matrix/sundials.rs +++ b/src/matrix/sundials.rs @@ -11,7 +11,7 @@ use crate::sundials_sys::{ }; use crate::{ - error::*, matrix_error, ode_solver::sundials::sundials_check, scalar::scale, + error::*, linear_solver::sundials::sundials_check, matrix_error, scalar::scale, vector::sundials::SundialsVector, IndexType, Scale, SundialsLinearSolver, Vector, }; diff --git a/src/nonlinear_solver/convergence.rs b/src/nonlinear_solver/convergence.rs index 73d846cd..2b43c2fc 100644 --- a/src/nonlinear_solver/convergence.rs +++ b/src/nonlinear_solver/convergence.rs @@ -1,13 +1,12 @@ use nalgebra::ComplexField; use num_traits::{One, Pow}; -use std::rc::Rc; use crate::{scalar::IndexType, Scalar, Vector}; #[derive(Clone)] -pub struct Convergence { - rtol: V::T, - atol: Rc, +pub struct Convergence<'a, V: Vector> { + pub rtol: V::T, + pub atol: &'a V, tol: V::T, max_iter: IndexType, niter: IndexType, @@ -21,7 +20,7 @@ pub enum ConvergenceStatus { MaximumIterations, } -impl Convergence { +impl<'a, V: Vector> Convergence<'a, V> { pub fn max_iter(&self) -> IndexType { self.max_iter } @@ -31,7 +30,7 @@ impl Convergence { pub fn niter(&self) -> IndexType { self.niter } - pub fn new(rtol: V::T, atol: Rc) -> Self { + pub fn new(rtol: V::T, atol: &'a V) -> Self { let minimum_tol = V::T::from(10.0) * V::T::EPSILON / rtol; let maximum_tol = V::T::from(0.03); let mut tol = V::T::from(0.33); @@ -57,7 +56,7 @@ impl Convergence { pub fn check_new_iteration(&mut self, dy: &mut V, y: &V) -> ConvergenceStatus { self.niter += 1; - let norm = dy.squared_norm(y, &self.atol, self.rtol).sqrt(); + let norm = dy.squared_norm(y, self.atol, self.rtol).sqrt(); // if norm is zero then we are done if norm <= V::T::EPSILON { return ConvergenceStatus::Converged; diff --git a/src/nonlinear_solver/mod.rs b/src/nonlinear_solver/mod.rs index feb5e104..1c3a0339 100644 --- a/src/nonlinear_solver/mod.rs +++ b/src/nonlinear_solver/mod.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{error::DiffsolError, Matrix, NonLinearOp, NonLinearOpJacobian}; use convergence::Convergence; @@ -16,17 +14,8 @@ impl NonLinearSolveSolution { /// A solver for the nonlinear problem `F(x) = 0`. pub trait NonLinearSolver: Default { - fn convergence(&self) -> &Convergence; - - fn convergence_mut(&mut self) -> &mut Convergence; - /// Set the problem to be solved, any previous problem is discarded. - fn set_problem>( - &mut self, - op: &C, - rtol: M::T, - atol: Rc, - ); + fn set_problem>(&mut self, op: &C); /// Reset the approximation of the Jacobian matrix. fn reset_jacobian>( @@ -43,9 +32,10 @@ pub trait NonLinearSolver: Default { x: &M::V, t: M::T, error_y: &M::V, + convergence: &mut Convergence<'_, M::V>, ) -> Result { let mut x = x.clone(); - self.solve_in_place(op, &mut x, t, error_y)?; + self.solve_in_place(op, &mut x, t, error_y, convergence)?; Ok(x) } @@ -56,6 +46,7 @@ pub trait NonLinearSolver: Default { x: &mut C::V, t: C::T, error_y: &C::V, + convergence: &mut Convergence<'_, M::V>, ) -> Result<(), DiffsolError>; /// Solve the linearised problem `J * x = b`, where `J` was calculated using [Self::reset_jacobian]. @@ -70,12 +61,12 @@ pub mod root; //tests #[cfg(test)] pub mod tests { - use std::rc::Rc; - use self::newton::NewtonNonlinearSolver; use crate::{ - linear_solver::nalgebra::lu::LU, matrix::MatrixCommon, op::closure::Closure, scale, - DenseMatrix, Vector, + linear_solver::nalgebra::lu::LU, + matrix::MatrixCommon, + op::{closure::Closure, ParameterisedOp}, + scale, DenseMatrix, Vector, }; use super::*; @@ -83,9 +74,13 @@ pub mod tests { #[allow(clippy::type_complexity)] pub fn get_square_problem() -> ( - impl NonLinearOpJacobian, + Closure< + M, + impl Fn(&M::V, &M::V, M::T, &mut M::V), + impl Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + >, M::T, - Rc, + M::V, Vec>, ) where @@ -93,7 +88,7 @@ pub mod tests { { let jac1 = M::from_diagonal(&M::V::from_vec(vec![2.0.into(), 2.0.into()])); let jac2 = jac1.clone(); - let p = Rc::new(M::V::zeros(0)); + let p = M::V::zeros(0); let op = Closure::new( // 0 = J * x * x - 8 move |x: &::V, _p: &::V, _t, y| { @@ -108,10 +103,10 @@ pub mod tests { }, 2, 2, - p, + p.len(), ); let rtol = M::T::from(1e-6); - let atol = Rc::new(M::V::from_vec(vec![1e-6.into(), 1e-6.into()])); + let atol = M::V::from_vec(vec![1e-6.into(), 1e-6.into()]); let solns = vec![NonLinearSolveSolution::new( M::V::from_vec(vec![2.1.into(), 2.1.into()]), M::V::from_vec(vec![2.0.into(), 2.0.into()]), @@ -123,17 +118,20 @@ pub mod tests { mut solver: impl NonLinearSolver, op: C, rtol: C::T, - atol: Rc, + atol: &C::V, solns: Vec>, ) where C: NonLinearOpJacobian, { - solver.set_problem(&op, rtol, atol.clone()); + solver.set_problem(&op); + let mut convergence = Convergence::new(rtol, atol); let t = C::T::zero(); solver.reset_jacobian(&op, &solns[0].x0, t); for soln in solns { - let x = solver.solve(&op, &soln.x0, t, &soln.x0).unwrap(); - let tol = x.clone() * scale(rtol) + atol.as_ref(); + let x = solver + .solve(&op, &soln.x0, t, &soln.x0, &mut convergence) + .unwrap(); + let tol = x.clone() * scale(rtol) + atol; x.assert_eq(&soln.x, &tol); } } @@ -144,7 +142,9 @@ pub mod tests { fn test_newton_cpu_square() { let lu = LU::default(); let (op, rtol, atol, soln) = get_square_problem::(); + let p = nalgebra::DVector::zeros(0); + let op = ParameterisedOp::new(&op, &p); let s = NewtonNonlinearSolver::new(lu); - test_nonlinear_solver(s, op, rtol, atol, soln); + test_nonlinear_solver(s, op, rtol, &atol, soln); } } diff --git a/src/nonlinear_solver/newton.rs b/src/nonlinear_solver/newton.rs index d56168d9..4338cf3e 100644 --- a/src/nonlinear_solver/newton.rs +++ b/src/nonlinear_solver/newton.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{ error::{DiffsolError, NonLinearSolverError}, non_linear_solver_error, Convergence, ConvergenceStatus, LinearSolver, Matrix, NonLinearOp, @@ -37,7 +35,6 @@ pub fn newton_iteration( } pub struct NewtonNonlinearSolver> { - convergence: Option>, linear_solver: Ls, is_jacobian_set: bool, tmp: M::V, @@ -46,7 +43,6 @@ pub struct NewtonNonlinearSolver> { impl> NewtonNonlinearSolver { pub fn new(linear_solver: Ls) -> Self { Self { - convergence: None, linear_solver, is_jacobian_set: false, tmp: M::V::zeros(0), @@ -64,26 +60,8 @@ impl> Default for NewtonNonlinearSolver { } impl> NonLinearSolver for NewtonNonlinearSolver { - fn convergence(&self) -> &Convergence { - self.convergence - .as_ref() - .expect("NewtonNonlinearSolver::convergence() called before set_problem") - } - - fn convergence_mut(&mut self) -> &mut Convergence { - self.convergence - .as_mut() - .expect("NewtonNonlinearSolver::convergence_mut() called before set_problem") - } - - fn set_problem>( - &mut self, - op: &C, - rtol: M::T, - atol: Rc, - ) { - self.linear_solver.set_problem(op, rtol, atol.clone()); - self.convergence = Some(Convergence::new(rtol, atol)); + fn set_problem>(&mut self, op: &C) { + self.linear_solver.set_problem(op); self.is_jacobian_set = false; self.tmp = C::V::zeros(op.nstates()); } @@ -108,10 +86,8 @@ impl> NonLinearSolver for NewtonNonlinearSolve xn: &mut M::V, t: M::T, error_y: &M::V, + convergence: &mut Convergence, ) -> Result<(), DiffsolError> { - if self.convergence.is_none() { - panic!("NewtonNonlinearSolver::solve() called before set_problem"); - } if !self.is_jacobian_set { panic!("NewtonNonlinearSolver::solve_in_place() called before reset_jacobian"); } @@ -120,7 +96,6 @@ impl> NonLinearSolver for NewtonNonlinearSolve } let linear_solver = |x: &mut C::V| self.linear_solver.solve_in_place(x); let fun = |x: &C::V, y: &mut C::V| op.call_inplace(x, t, y); - let convergence = self.convergence.as_mut().unwrap(); newton_iteration(xn, &mut self.tmp, error_y, fun, linear_solver, convergence) } } diff --git a/src/nonlinear_solver/root.rs b/src/nonlinear_solver/root.rs index 4cd04fb4..d0da6162 100644 --- a/src/nonlinear_solver/root.rs +++ b/src/nonlinear_solver/root.rs @@ -8,6 +8,7 @@ use crate::{ use num_traits::{abs, One, Zero}; +#[derive(Clone)] pub struct RootFinder { t0: RefCell, g0: RefCell, @@ -159,23 +160,23 @@ impl RootFinder { #[cfg(test)] mod tests { - use std::rc::Rc; - - use crate::{error::DiffsolError, ClosureNoJac, RootFinder, Vector}; + use crate::{error::DiffsolError, op::ParameterisedOp, ClosureNoJac, RootFinder, Vector}; #[test] fn test_root() { type V = nalgebra::DVector; type M = nalgebra::DMatrix; let interpolate = |t: f64| -> Result { Ok(Vector::from_vec(vec![t])) }; + let p = V::zeros(0); let root_fn = ClosureNoJac::::new( |y: &V, _p: &V, _t: f64, g: &mut V| { g[0] = y[0] - 0.4; }, 1, 1, - Rc::new(V::zeros(0)), + p.len(), ); + let root_fn = ParameterisedOp::new(&root_fn, &p); // check no root let root_finder = RootFinder::new(1); diff --git a/src/ode_solver/adjoint_equations.rs b/src/ode_solver/adjoint_equations.rs index 3320b191..a9d5f966 100644 --- a/src/ode_solver/adjoint_equations.rs +++ b/src/ode_solver/adjoint_equations.rs @@ -12,26 +12,26 @@ use crate::{ OdeSolverProblem, Op, Vector, }; -pub struct AdjointContext +pub struct AdjointContext<'a, Eqn, Method> where Eqn: OdeEquations, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { - checkpointer: Checkpointing, + checkpointer: Checkpointing<'a, Eqn, Method>, x: Eqn::V, index: usize, last_t: Option, col: Eqn::V, } -impl AdjointContext +impl<'a, Eqn, Method> AdjointContext<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { - pub fn new(checkpointer: Checkpointing) -> Self { - let x = ::zeros(checkpointer.problem.eqn.rhs().nstates()); - let mut col = ::zeros(checkpointer.problem.eqn.out().unwrap().nout()); + pub fn new(checkpointer: Checkpointing<'a, Eqn, Method>) -> Self { + let x = ::zeros(checkpointer.problem().eqn.rhs().nstates()); + let mut col = ::zeros(checkpointer.problem().eqn.out().unwrap().nout()); let index = 0; col[0] = Eqn::T::one(); Self { @@ -68,23 +68,23 @@ where } } -pub struct AdjointMass +pub struct AdjointMass<'a, Eqn> where Eqn: OdeEquationsAdjoint, { - eqn: Rc, + eqn: &'a Eqn, } -impl AdjointMass +impl<'a, Eqn> AdjointMass<'a, Eqn> where Eqn: OdeEquationsAdjoint, { - pub fn new(eqn: &Rc) -> Self { - Self { eqn: eqn.clone() } + pub fn new(eqn: &'a Eqn) -> Self { + Self { eqn } } } -impl Op for AdjointMass +impl Op for AdjointMass<'_, Eqn> where Eqn: OdeEquationsAdjoint, { @@ -103,7 +103,7 @@ where } } -impl LinearOp for AdjointMass +impl LinearOp for AdjointMass<'_, Eqn> where Eqn: OdeEquationsAdjoint, { @@ -119,23 +119,23 @@ where } } -pub struct AdjointInit +pub struct AdjointInit<'a, Eqn> where Eqn: OdeEquationsAdjoint, { - eqn: Rc, + eqn: &'a Eqn, } -impl AdjointInit +impl<'a, Eqn> AdjointInit<'a, Eqn> where Eqn: OdeEquationsAdjoint, { - pub fn new(eqn: &Rc) -> Self { - Self { eqn: eqn.clone() } + pub fn new(eqn: &'a Eqn) -> Self { + Self { eqn } } } -impl Op for AdjointInit +impl Op for AdjointInit<'_, Eqn> where Eqn: OdeEquationsAdjoint, { @@ -154,7 +154,7 @@ where } } -impl ConstantOp for AdjointInit +impl ConstantOp for AdjointInit<'_, Eqn> where Eqn: OdeEquationsAdjoint, { @@ -171,31 +171,31 @@ where /// g_x is the partial derivative of the functional g with respect to the state vector. /// /// We need the current state x(t), which is obtained from the checkpointed forward solve at the current time step. -pub struct AdjointRhs +pub struct AdjointRhs<'a, Eqn, Method> where Eqn: OdeEquations, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { - eqn: Rc, - context: Rc>>, + eqn: &'a Eqn, + context: Rc>>, tmp: RefCell, with_out: bool, } -impl AdjointRhs +impl<'a, Eqn, Method> AdjointRhs<'a, Eqn, Method> where Eqn: OdeEquations, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { pub fn new( - eqn: &Rc, - context: Rc>>, + eqn: &'a Eqn, + context: Rc>>, with_out: bool, ) -> Self { let tmp_n = if with_out { eqn.rhs().nstates() } else { 0 }; let tmp = RefCell::new(::zeros(tmp_n)); Self { - eqn: eqn.clone(), + eqn, context, tmp, with_out, @@ -203,10 +203,10 @@ where } } -impl Op for AdjointRhs +impl<'a, Eqn, Method> Op for AdjointRhs<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { type T = Eqn::T; type V = Eqn::V; @@ -223,10 +223,10 @@ where } } -impl NonLinearOp for AdjointRhs +impl<'a, Eqn, Method> NonLinearOp for AdjointRhs<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { /// F(λ, x, t) = -f^T_x(x, t) λ - g^T_x(x,t) fn call_inplace(&self, lambda: &Self::V, t: Self::T, y: &mut Self::V) { @@ -250,10 +250,10 @@ where } } -impl NonLinearOpJacobian for AdjointRhs +impl<'a, Eqn, Method> NonLinearOpJacobian for AdjointRhs<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { // J = -f^T_x(x, t) fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { @@ -281,31 +281,31 @@ where /// g_p is the partial derivative of the functional g with respect to the parameter vector /// /// We need the current state x(t), which is obtained from the checkpointed forward solve at the current time step. -pub struct AdjointOut +pub struct AdjointOut<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { - eqn: Rc, - context: Rc>>, + eqn: &'a Eqn, + context: Rc>>, tmp: RefCell, with_out: bool, } -impl AdjointOut +impl<'a, Eqn, Method> AdjointOut<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { pub fn new( - eqn: &Rc, - context: Rc>>, + eqn: &'a Eqn, + context: Rc>>, with_out: bool, ) -> Self { let tmp_n = if with_out { eqn.rhs().nparams() } else { 0 }; let tmp = RefCell::new(::zeros(tmp_n)); Self { - eqn: eqn.clone(), + eqn, context, tmp, with_out, @@ -313,10 +313,10 @@ where } } -impl Op for AdjointOut +impl<'a, Eqn, Method> Op for AdjointOut<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { type T = Eqn::T; type V = Eqn::V; @@ -333,10 +333,10 @@ where } } -impl NonLinearOp for AdjointOut +impl<'a, Eqn, Method> NonLinearOp for AdjointOut<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { /// F(λ, x, t) = -g_p(x, t) - λ^T f_p(x, t) fn call_inplace(&self, lambda: &Self::V, t: Self::T, y: &mut Self::V) { @@ -357,10 +357,10 @@ where } } -impl NonLinearOpJacobian for AdjointOut +impl<'a, Eqn, Method> NonLinearOpJacobian for AdjointOut<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { // J = -f_p(x, t) fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { @@ -386,66 +386,106 @@ where /// λ(T) = 0 /// g(λ, x, t) = -g_p(x, t) - λ^T f_p(x, t) /// -pub struct AdjointEquations +pub struct AdjointEquations<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { - eqn: Rc, - rhs: AdjointRhs, - out: Option>, - mass: Option>, - context: Rc>>, + eqn: &'a Eqn, + rhs: AdjointRhs<'a, Eqn, Method>, + out: Option>, + mass: Option>, + context: Rc>>, tmp: RefCell, tmp2: RefCell, - init: Rc>, - atol: Option>, + init: AdjointInit<'a, Eqn>, + atol: Option<&'a Eqn::V>, rtol: Option, out_rtol: Option, - out_atol: Option>, + out_atol: Option<&'a Eqn::V>, } -impl AdjointEquations +impl<'a, Eqn, Method> Clone for AdjointEquations<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, +{ + fn clone(&self) -> Self { + let context = Rc::new(RefCell::new(AdjointContext::new( + self.context.borrow().checkpointer.clone(), + ))); + let rhs = AdjointRhs::new(self.eqn, context.clone(), self.out.is_some()); + let init = AdjointInit::new(self.eqn); + let out = if self.out.is_some() { + Some(AdjointOut::new(self.eqn, context.clone(), true)) + } else { + None + }; + let tmp = self.tmp.clone(); + let tmp2 = self.tmp2.clone(); + let atol = self.atol; + let rtol = self.rtol; + let out_atol = self.out_atol; + let out_rtol = self.out_rtol; + let mass = self.eqn.mass().map(|_m| AdjointMass::new(self.eqn)); + Self { + rhs, + init, + mass, + context, + out, + tmp, + tmp2, + eqn: self.eqn, + atol, + rtol, + out_rtol, + out_atol, + } + } +} + +impl<'a, Eqn, Method> AdjointEquations<'a, Eqn, Method> +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod<'a, Eqn>, { pub(crate) fn new( - problem: &OdeSolverProblem, - context: Rc>>, + problem: &'a OdeSolverProblem, + context: Rc>>, with_out: bool, ) -> Self { - let eqn = problem.eqn.clone(); - let rhs = AdjointRhs::new(&eqn, context.clone(), with_out); - let init = Rc::new(AdjointInit::new(&eqn)); + let eqn = &problem.eqn; + let rhs = AdjointRhs::new(eqn, context.clone(), with_out); + let init = AdjointInit::new(eqn); let out = if with_out { - Some(AdjointOut::new(&eqn, context.clone(), with_out)) + Some(AdjointOut::new(eqn, context.clone(), with_out)) } else { None }; let tmp = if with_out { - RefCell::new(::zeros(0)) - } else { RefCell::new(::zeros(eqn.rhs().nparams())) + } else { + RefCell::new(::zeros(0)) }; let tmp2 = if with_out { - RefCell::new(::zeros(0)) - } else { RefCell::new(::zeros(eqn.rhs().nstates())) + } else { + RefCell::new(::zeros(0)) }; let atol = if with_out { - problem.sens_atol.clone() + problem.sens_atol.as_ref() } else { None }; let rtol = if with_out { problem.sens_rtol } else { None }; let out_atol = if with_out { - problem.out_atol.clone() + problem.out_atol.as_ref() } else { None }; let out_rtol = if with_out { problem.out_rtol } else { None }; - let mass = eqn.mass().map(|_m| AdjointMass::new(&eqn)); + let mass = eqn.mass().map(|_m| AdjointMass::new(eqn)); Self { rhs, init, @@ -480,20 +520,20 @@ where } } -impl std::fmt::Debug for AdjointEquations +impl<'a, Eqn, Method> std::fmt::Debug for AdjointEquations<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AdjointEquations").finish() } } -impl Op for AdjointEquations +impl<'a, Eqn, Method> Op for AdjointEquations<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { type T = Eqn::T; type V = Eqn::V; @@ -510,45 +550,47 @@ where } } -impl<'a, Eqn, Method> OdeEquationsRef<'a> for AdjointEquations +impl<'a, 'b, Eqn, Method> OdeEquationsRef<'a> for AdjointEquations<'b, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'b, Eqn>, { - type Rhs = &'a AdjointRhs; - type Mass = &'a AdjointMass; + type Rhs = &'a AdjointRhs<'b, Eqn, Method>; + type Mass = &'a AdjointMass<'b, Eqn>; type Root = >::Root; - type Init = &'a AdjointInit; - type Out = &'a AdjointOut; + type Init = &'a AdjointInit<'b, Eqn>; + type Out = &'a AdjointOut<'b, Eqn, Method>; } -impl OdeEquations for AdjointEquations +impl<'a, Eqn, Method> OdeEquations for AdjointEquations<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { - fn rhs(&self) -> &AdjointRhs { + fn rhs(&self) -> &AdjointRhs<'a, Eqn, Method> { &self.rhs } - fn mass(&self) -> Option<&AdjointMass> { + fn mass(&self) -> Option<&AdjointMass<'a, Eqn>> { self.mass.as_ref() } fn root(&self) -> Option<>::Root> { None } - fn init(&self) -> &AdjointInit { + fn init(&self) -> &AdjointInit<'a, Eqn> { &self.init } - fn out(&self) -> Option<&AdjointOut> { + fn out(&self) -> Option<&AdjointOut<'a, Eqn, Method>> { self.out.as_ref() } + fn set_params(&mut self, p: &Self::V) { + self.eqn.set_params(p); + } } -impl AugmentedOdeEquations> - for AdjointEquations +impl<'a, Eqn, Method> AugmentedOdeEquations for AdjointEquations<'a, Eqn, Method> where Eqn: OdeEquationsAdjoint, - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, { fn include_in_error_control(&self) -> bool { self.atol.is_some() && self.rtol.is_some() @@ -557,11 +599,11 @@ where self.out().is_some() && self.out_atol.is_some() && self.out_rtol.is_some() } - fn atol(&self) -> Option<&Rc> { - self.atol.as_ref() + fn atol(&self) -> Option<&Eqn::V> { + self.atol } - fn out_atol(&self) -> Option<&Rc> { - self.out_atol.as_ref() + fn out_atol(&self) -> Option<&Eqn::V> { + self.out_atol } fn out_rtol(&self) -> Option { self.out_rtol @@ -581,6 +623,10 @@ where fn update_rhs_out_state(&mut self, _y: &Eqn::V, _dy: &Eqn::V, _t: Eqn::T) {} fn update_init_state(&mut self, _t: ::T) {} + + fn integrate_main_eqn(&self) -> bool { + false + } } #[cfg(test)] @@ -593,18 +639,17 @@ mod tests { test_models::exponential_decay::exponential_decay_problem_adjoint, }, AdjointContext, AugmentedOdeEquations, Checkpointing, FaerSparseLU, Matrix, MatrixCommon, - NalgebraLU, NonLinearOp, NonLinearOpJacobian, OdeSolverMethod, Sdirk, SdirkState, - SparseColMat, Tableau, Vector, + NonLinearOp, NonLinearOpJacobian, SdirkState, SparseColMat, Vector, }; type Mcpu = nalgebra::DMatrix; type Vcpu = nalgebra::DVector; + type LS = crate::NalgebraLU; #[test] fn test_rhs_exponential() { // dy/dt = -ay (p = [a]) // a = 0.1 let (problem, _soln) = exponential_decay_problem_adjoint::(); - let mut solver = Sdirk::::new(Tableau::esdirk34(), NalgebraLU::default()); let state = SdirkState { t: 0.0, y: Vcpu::from_vec(vec![1.0, 1.0]), @@ -617,7 +662,7 @@ mod tests { ds: Vec::new(), h: 0.0, }; - solver.set_problem(state.clone(), &problem).unwrap(); + let solver = problem.esdirk34_solver::(state.clone()).unwrap(); let checkpointer = Checkpointing::new(solver, 0, vec![state.clone(), state.clone()], None); let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer))); let adj_eqn = AdjointEquations::new(&problem, context.clone(), false); @@ -675,8 +720,6 @@ mod tests { // dy/dt = -ay (p = [a]) // a = 0.1 let (problem, _soln) = exponential_decay_problem_adjoint::>(); - let mut solver = - Sdirk::, _, _>::new(Tableau::esdirk34(), FaerSparseLU::default()); let state = SdirkState { t: 0.0, y: faer::Col::from_vec(vec![1.0, 1.0]), @@ -689,7 +732,9 @@ mod tests { ds: Vec::new(), h: 0.0, }; - solver.set_problem(state.clone(), &problem).unwrap(); + let solver = problem + .esdirk34_solver::>(state.clone()) + .unwrap(); let checkpointer = Checkpointing::new(solver, 0, vec![state.clone(), state.clone()], None); let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer))); let mut adj_eqn = AdjointEquations::new(&problem, context, true); diff --git a/src/ode_solver/bdf.rs b/src/ode_solver/bdf.rs index 28c5205f..1ca258b3 100644 --- a/src/ode_solver/bdf.rs +++ b/src/ode_solver/bdf.rs @@ -1,11 +1,10 @@ use nalgebra::ComplexField; use std::ops::AddAssign; -use std::rc::Rc; use crate::{ error::{DiffsolError, OdeSolverError}, - AdjointEquations, AugmentedOdeEquationsImplicit, NoAug, OdeEquationsAdjoint, OdeEquationsSens, - SensEquations, StateRef, StateRefMut, + AdjointEquations, AugmentedOdeEquationsImplicit, Convergence, DefaultDenseMatrix, LinearSolver, + NoAug, OdeEquationsAdjoint, OdeEquationsSens, SensEquations, StateRef, StateRefMut, }; use num_traits::{abs, One, Pow, Zero}; @@ -13,21 +12,14 @@ use serde::Serialize; use crate::ode_solver_error; use crate::{ - matrix::{default_solver::DefaultSolver, MatrixRef}, - nonlinear_solver::root::RootFinder, - op::bdf::BdfCallable, - scalar::scale, - vector::DefaultDenseMatrix, + matrix::MatrixRef, nonlinear_solver::root::RootFinder, op::bdf::BdfCallable, scalar::scale, AugmentedOdeEquations, BdfState, DenseMatrix, IndexType, JacobianUpdate, MatrixViewMut, - NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquationsImplicit, OdeSolverMethod, - OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Op, Scalar, Vector, VectorRef, - VectorView, VectorViewMut, + NonLinearOp, NonLinearSolver, OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, + OdeSolverState, OdeSolverStopReason, Op, Scalar, Vector, VectorRef, VectorView, VectorViewMut, }; -use super::jacobian_update::SolverState; -use super::method::{ - AdjointOdeSolverMethod, AugmentedOdeSolverMethod, SensitivitiesOdeSolverMethod, -}; +use super::method::{AdjointOdeSolverMethod, AugmentedOdeSolverMethod}; +use super::{jacobian_update::SolverState, method::SensitivitiesOdeSolverMethod}; #[derive(Clone, Debug, Serialize, Default)] pub struct BdfStatistics { @@ -38,19 +30,56 @@ pub struct BdfStatistics { pub number_of_nonlinear_solver_fails: usize, } -pub type BdfSens = Bdf>; -pub type BdfAdj = - Bdf>, Nls, AdjointEquations>>; -impl SensitivitiesOdeSolverMethod for BdfSens +impl<'a, M, Eqn, Nls, AugEqn> AugmentedOdeSolverMethod<'a, Eqn, AugEqn> + for Bdf<'a, Eqn, Nls, M, AugEqn> +where + Eqn: OdeEquationsImplicit, + AugEqn: AugmentedOdeEquationsImplicit, + M: DenseMatrix, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, + Nls: NonLinearSolver, + Eqn::V: DefaultDenseMatrix, +{ + fn into_state_and_eqn(self) -> (Self::State, Option) { + (self.state, self.s_op.map(|op| op.eqn)) + } +} + +impl<'a, M, Eqn, Nls> SensitivitiesOdeSolverMethod<'a, Eqn> + for Bdf<'a, Eqn, Nls, M, SensEquations<'a, Eqn>> where Eqn: OdeEquationsSens, M: DenseMatrix, for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, + Eqn::V: DefaultDenseMatrix, Nls: NonLinearSolver, { } +impl<'a, M, Eqn, Nls> AdjointOdeSolverMethod<'a, Eqn> for Bdf<'a, Eqn, Nls, M> +where + Eqn: OdeEquationsAdjoint, + M: DenseMatrix, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, + Eqn::V: DefaultDenseMatrix, + Nls: NonLinearSolver + 'a, +{ + type DefaultAdjointSolver = + Bdf<'a, Eqn, Nls, M, AdjointEquations<'a, Eqn, Bdf<'a, Eqn, Nls, M>>>; + fn default_adjoint_solver>( + self, + mut aug_eqn: AdjointEquations<'a, Eqn, Self>, + ) -> Result { + let problem = self.problem(); + let nonlinear_solver = self.nonlinear_solver; + let state = self.state.into_adjoint::(problem, &mut aug_eqn)?; + Bdf::new_augmented(state, problem, aug_eqn, nonlinear_solver) + } +} + // notes quadrature. // ndf formula rearranged to [2]: // (1 - kappa) * gamma_k * (y_{n+1} - y^0_{n+1}) + (\sum_{m=1}^k gamma_m * y^m_n) - h * F(t_{n+1}, y_{n+1}) = 0 (1) @@ -78,14 +107,18 @@ where /// \[2\] Shampine, L. F., & Reichelt, M. W. (1997). The matlab ode suite. SIAM journal on scientific computing, 18(1), 1-22. /// \[3\] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., ... & Van Mulbregt, P. (2020). SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature methods, 17(3), 261-272. pub struct Bdf< - M: DenseMatrix, + 'a, Eqn: OdeEquationsImplicit, Nls: NonLinearSolver, + M: DenseMatrix = <::V as DefaultDenseMatrix>::M, AugmentedEqn: AugmentedOdeEquationsImplicit = NoAug, -> { +> where + Eqn::V: DefaultDenseMatrix, +{ nonlinear_solver: Nls, - ode_problem: Option>, - op: Option>, + convergence: Convergence<'a, Eqn::V>, + ode_problem: &'a OdeSolverProblem, + op: Option>, n_equal_steps: usize, y_delta: Eqn::V, g_delta: Eqn::V, @@ -103,59 +136,72 @@ pub struct Bdf< gamma: Vec, error_const2: Vec, statistics: BdfStatistics, - state: Option>, + state: BdfState, tstop: Option, root_finder: Option>, is_state_modified: bool, jacobian_update: JacobianUpdate, } -impl Default - for Bdf< - ::M, - Eqn, - NewtonNonlinearSolver::LS>, - NoAug, - > +impl Clone for Bdf<'_, Eqn, Nls, M, AugmentedEqn> where Eqn: OdeEquationsImplicit, - Eqn::M: DefaultSolver, - Eqn::V: DefaultDenseMatrix, - for<'b> &'b Eqn::V: VectorRef, - for<'b> &'b Eqn::M: MatrixRef, -{ - fn default() -> Self { - let linear_solver = Eqn::M::default_solver(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - Self::new(nonlinear_solver) - } -} - -impl - Bdf< - ::M, - Eqn, - NewtonNonlinearSolver::LS>, - SensEquations, - > -where - Eqn: OdeEquationsSens, - Eqn::M: DefaultSolver, + Nls: NonLinearSolver, + M: DenseMatrix, + AugmentedEqn: AugmentedOdeEquationsImplicit, Eqn::V: DefaultDenseMatrix, - for<'b> &'b Eqn::V: VectorRef, - for<'b> &'b Eqn::M: MatrixRef, { - pub fn with_sensitivities() -> Self { - let linear_solver = Eqn::M::default_solver(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - Self::new(nonlinear_solver) + fn clone(&self) -> Self { + let problem = self.ode_problem; + let mut nonlinear_solver = Nls::default(); + let op = if let Some(op) = self.op.as_ref() { + let op = op.clone_state(&self.ode_problem.eqn); + nonlinear_solver.set_problem(&op); + nonlinear_solver.reset_jacobian(&op, &self.state.y, self.state.t); + Some(op) + } else { + None + }; + let s_op = self.s_op.as_ref().map(|op| { + let op = op.clone_state(op.eqn().clone()); + op + }); + Self { + nonlinear_solver, + ode_problem: problem, + convergence: self.convergence.clone(), + op, + s_op, + n_equal_steps: self.n_equal_steps, + y_delta: self.y_delta.clone(), + g_delta: self.g_delta.clone(), + y_predict: self.y_predict.clone(), + t_predict: self.t_predict, + s_predict: self.s_predict.clone(), + s_deltas: self.s_deltas.clone(), + sg_deltas: self.sg_deltas.clone(), + diff_tmp: self.diff_tmp.clone(), + gdiff_tmp: self.gdiff_tmp.clone(), + sgdiff_tmp: self.sgdiff_tmp.clone(), + u: self.u.clone(), + alpha: self.alpha.clone(), + gamma: self.gamma.clone(), + error_const2: self.error_const2.clone(), + statistics: self.statistics.clone(), + state: self.state.clone(), + tstop: self.tstop, + root_finder: self.root_finder.clone(), + is_state_modified: self.is_state_modified, + jacobian_update: self.jacobian_update.clone(), + } } } -impl Bdf +impl<'a, M, Eqn, Nls, AugmentedEqn> Bdf<'a, Eqn, Nls, M, AugmentedEqn> where AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit, Eqn: OdeEquationsImplicit, + Eqn::V: DefaultDenseMatrix, M: DenseMatrix, for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, @@ -168,9 +214,20 @@ where const MIN_THRESHOLD: f64 = 0.9; const MIN_TIMESTEP: f64 = 1e-32; - pub fn new(nonlinear_solver: Nls) -> Self { - let n = 1; + pub fn new( + problem: &'a OdeSolverProblem, + state: BdfState, + nonlinear_solver: Nls, + ) -> Result { + Self::_new(problem, state, nonlinear_solver, true) + } + fn _new( + problem: &'a OdeSolverProblem, + mut state: BdfState, + mut nonlinear_solver: Nls, + integrate_main_eqn: bool, + ) -> Result { // kappa values for difference orders, taken from Table 1 of [1] let kappa = [ Eqn::T::from(0.0), @@ -196,33 +253,120 @@ where error_const2.push((kappa[i] * gamma[i] + one_over_i_plus_one).powi(2)); } - Self { + state.check_consistent_with_problem(problem)?; + + let mut convergence = Convergence::new(problem.rtol, &problem.atol); + convergence.set_max_iter(Self::NEWTON_MAXITER); + + let op = if integrate_main_eqn { + // setup linear solver for first step + let bdf_callable = BdfCallable::new(&problem.eqn); + bdf_callable.set_c(state.h, alpha[state.order]); + nonlinear_solver.set_problem(&bdf_callable); + nonlinear_solver.reset_jacobian(&bdf_callable, &state.y, state.t); + Some(bdf_callable) + } else { + None + }; + + state.set_problem(problem)?; + + // setup root solver + let mut root_finder = None; + if let Some(root_fn) = problem.eqn.root() { + root_finder = Some(RootFinder::new(root_fn.nout())); + root_finder + .as_ref() + .unwrap() + .init(&root_fn, &state.y, state.t); + } + + // (re)allocate internal state + let nstates = problem.eqn.rhs().nstates(); + let diff_tmp = M::zeros(nstates, BdfState::::MAX_ORDER + 3); + let y_delta = ::zeros(nstates); + let y_predict = ::zeros(nstates); + + let nout = if let Some(out) = problem.eqn.out() { + out.nout() + } else { + 0 + }; + let g_delta = ::zeros(nout); + let gdiff_tmp = M::zeros(nout, BdfState::::MAX_ORDER + 3); + + // init U matrix + let u = Self::_compute_r(state.order, Eqn::T::one()); + let is_state_modified = false; + + Ok(Self { + convergence, s_op: None, - op: None, - ode_problem: None, + op, + ode_problem: problem, nonlinear_solver, n_equal_steps: 0, - diff_tmp: M::zeros(n, max_order + 3), - gdiff_tmp: M::zeros(n, max_order + 3), - sgdiff_tmp: M::zeros(n, max_order + 3), - y_delta: Eqn::V::zeros(n), - y_predict: Eqn::V::zeros(n), + diff_tmp, + gdiff_tmp, + sgdiff_tmp: M::zeros(0, 0), + y_delta, + y_predict, t_predict: Eqn::T::zero(), - s_predict: Eqn::V::zeros(n), + s_predict: Eqn::V::zeros(0), s_deltas: Vec::new(), sg_deltas: Vec::new(), - g_delta: Eqn::V::zeros(n), + g_delta, gamma, alpha, error_const2, - u: M::zeros(max_order + 1, max_order + 1), + u, statistics: BdfStatistics::default(), - state: None, + state, tstop: None, - root_finder: None, - is_state_modified: false, + root_finder, + is_state_modified, jacobian_update: JacobianUpdate::default(), + }) + } + + pub fn new_augmented( + state: BdfState, + problem: &'a OdeSolverProblem, + augmented_eqn: AugmentedEqn, + nonlinear_solver: Nls, + ) -> Result { + state.check_sens_consistent_with_problem(problem, &augmented_eqn)?; + + let mut ret = Self::_new( + problem, + state, + nonlinear_solver, + augmented_eqn.integrate_main_eqn(), + )?; + + ret.state.set_augmented_problem(problem, &augmented_eqn)?; + + // allocate internal state for sensitivities + let naug = augmented_eqn.max_index(); + let nstates = problem.eqn.rhs().nstates(); + + ret.s_op = if augmented_eqn.integrate_main_eqn() { + Some(BdfCallable::new_no_jacobian(augmented_eqn)) + } else { + let bdf_callable = BdfCallable::new(augmented_eqn); + ret.nonlinear_solver.set_problem(&bdf_callable); + ret.nonlinear_solver + .reset_jacobian(&bdf_callable, &ret.state.s[0], ret.state.t); + Some(bdf_callable) + }; + + ret.s_deltas = vec![::zeros(nstates); naug]; + ret.s_predict = ::zeros(nstates); + if let Some(out) = ret.s_op.as_ref().unwrap().eqn().out() { + ret.sg_deltas = vec![::zeros(out.nout()); naug]; + ret.sgdiff_tmp = M::zeros(out.nout(), BdfState::::MAX_ORDER + 3); } + Ok(ret) } pub fn get_statistics(&self) -> &BdfStatistics { @@ -256,19 +400,26 @@ where } fn _jacobian_updates(&mut self, c: Eqn::T, state: SolverState) { - let y = &self.state.as_ref().unwrap().y; - let t = self.state.as_ref().unwrap().t; - //let y = &self.y_predict; - //let t = self.t_predict; if self.jacobian_update.check_rhs_jacobian_update(c, &state) { - self.op.as_mut().unwrap().set_jacobian_is_stale(); - self.nonlinear_solver - .reset_jacobian(self.op.as_ref().unwrap(), y, t); + if let Some(op) = self.op.as_mut() { + op.set_jacobian_is_stale(); + self.nonlinear_solver + .reset_jacobian(op, &self.state.y, self.state.t); + } else if let Some(s_op) = self.s_op.as_mut() { + s_op.set_jacobian_is_stale(); + self.nonlinear_solver + .reset_jacobian(s_op, &self.state.s[0], self.state.t); + } self.jacobian_update.update_rhs_jacobian(); self.jacobian_update.update_jacobian(c); } else if self.jacobian_update.check_jacobian_update(c, &state) { - self.nonlinear_solver - .reset_jacobian(self.op.as_ref().unwrap(), y, t); + if let Some(op) = self.op.as_mut() { + self.nonlinear_solver + .reset_jacobian(op, &self.state.y, self.state.t); + } else if let Some(s_op) = self.s_op.as_mut() { + self.nonlinear_solver + .reset_jacobian(s_op, &self.state.s[0], self.state.t); + } self.jacobian_update.update_jacobian(c); } } @@ -280,36 +431,52 @@ where //- constant c = h / (1-kappa) gamma_k term //- lu factorisation of (M - c * J) used in newton iteration (same equation) - let new_h = factor * self.state.as_ref().unwrap().h; + let new_h = factor * self.state.h; self.n_equal_steps = 0; // update D using equations in section 3.2 of [1] - let order = self.state.as_ref().unwrap().order; + let order = self.state.order; let r = Self::_compute_r(order, factor); let ru = r.mat_mul(&self.u); { - let state = self.state.as_mut().unwrap(); - Self::_update_diff_for_step_size(&ru, &mut state.diff, &mut self.diff_tmp, order); - for diff in state.sdiff.iter_mut() { - Self::_update_diff_for_step_size(&ru, diff, &mut self.diff_tmp, order); + if self.op.is_some() { + Self::_update_diff_for_step_size( + &ru, + &mut self.state.diff, + &mut self.diff_tmp, + order, + ); + if self.ode_problem.integrate_out { + Self::_update_diff_for_step_size( + &ru, + &mut self.state.gdiff, + &mut self.gdiff_tmp, + order, + ); + } } - if self.ode_problem.as_ref().unwrap().integrate_out { - Self::_update_diff_for_step_size(&ru, &mut state.gdiff, &mut self.gdiff_tmp, order); + for diff in self.state.sdiff.iter_mut() { + Self::_update_diff_for_step_size(&ru, diff, &mut self.diff_tmp, order); } - for diff in state.sgdiff.iter_mut() { + + for diff in self.state.sgdiff.iter_mut() { Self::_update_diff_for_step_size(&ru, diff, &mut self.sgdiff_tmp, order); } } - self.op.as_mut().unwrap().set_c(new_h, self.alpha[order]); + if let Some(op) = self.op.as_mut() { + op.set_c(new_h, self.alpha[order]); + } + if let Some(s_op) = self.s_op.as_mut() { + s_op.set_c(new_h, self.alpha[order]); + } - self.state.as_mut().unwrap().h = new_h; + self.state.h = new_h; // if step size too small, then fail - let state = self.state.as_ref().unwrap(); - if state.h.abs() < Eqn::T::from(Self::MIN_TIMESTEP) { + if self.state.h.abs() < Eqn::T::from(Self::MIN_TIMESTEP) { return Err(DiffsolError::from(OdeSolverError::StepSizeTooSmall { - time: state.t.into(), + time: self.state.t.into(), })); } Ok(new_h) @@ -328,8 +495,8 @@ where fn calculate_output_delta(&mut self) { // integrate output function - let state = self.state.as_mut().unwrap(); - let out = self.ode_problem.as_ref().unwrap().eqn.out().unwrap(); + let state = &mut self.state; + let out = self.ode_problem.eqn.out().unwrap(); out.call_inplace(&self.y_predict, self.t_predict, &mut state.dg); self.op.as_ref().unwrap().integrate_out( &state.dg, @@ -342,31 +509,43 @@ where } fn calculate_sens_output_delta(&mut self, i: usize) { - let state = self.state.as_mut().unwrap(); - let op = self.s_op.as_ref().unwrap(); + let state = &mut self.state; + let s_op = self.s_op.as_ref().unwrap(); // integrate sensitivity output equations - let out = op.eqn().out().unwrap(); + let out = s_op.eqn().out().unwrap(); out.call_inplace(&state.s[i], self.t_predict, &mut state.dsg[i]); - self.op.as_ref().unwrap().integrate_out( - &state.dsg[i], - &state.sgdiff[i], - self.gamma.as_slice(), - self.alpha.as_slice(), - state.order, - &mut self.sg_deltas[i], - ); + + if let Some(op) = self.s_op.as_ref() { + op.integrate_out( + &state.dsg[i], + &state.sgdiff[i], + self.gamma.as_slice(), + self.alpha.as_slice(), + state.order, + &mut self.sg_deltas[i], + ); + } else if let Some(s_op) = self.s_op.as_ref() { + s_op.integrate_out( + &state.dsg[i], + &state.sgdiff[i], + self.gamma.as_slice(), + self.alpha.as_slice(), + state.order, + &mut self.sg_deltas[i], + ); + } } fn update_differences_and_integrate_out(&mut self) { - let order = self.state.as_ref().unwrap().order; - let state = self.state.as_mut().unwrap(); + let order = self.state.order; + let state = &mut self.state; // update differences Self::_update_diff(order, &self.y_delta, &mut state.diff); // integrate output function - if self.ode_problem.as_ref().unwrap().integrate_out { + if self.ode_problem.integrate_out { Self::_predict_using_diff(&mut state.g, &state.gdiff, order); state.g.axpy(Eqn::T::one(), &self.g_delta, Eqn::T::one()); @@ -421,17 +600,19 @@ where } fn _predict_forward(&mut self) { - let state = self.state.as_ref().unwrap(); + let state = &self.state; Self::_predict_using_diff(&mut self.y_predict, &state.diff, state.order); // update psi and c (h, D, y0 has changed) - self.op.as_mut().unwrap().set_psi_and_y0( - &state.diff, - self.gamma.as_slice(), - self.alpha.as_slice(), - state.order, - &self.y_predict, - ); + if let Some(op) = self.op.as_mut() { + op.set_psi_and_y0( + &state.diff, + self.gamma.as_slice(), + self.alpha.as_slice(), + state.order, + &self.y_predict, + ); + } // update time let t_new = state.t + state.h; @@ -443,7 +624,7 @@ where tstop: Eqn::T, ) -> Result>, DiffsolError> { // check if the we are at tstop - let state = self.state.as_ref().unwrap(); + let state = &self.state; let troundoff = Eqn::T::from(100.0) * Eqn::T::EPSILON * (abs(state.t) + abs(state.h)); if abs(state.t - tstop) <= troundoff { self.tstop = None; @@ -473,27 +654,15 @@ where fn initialise_to_first_order(&mut self) { self.n_equal_steps = 0; - self.state - .as_mut() - .unwrap() - .initialise_diff_to_first_order(); + self.state.initialise_diff_to_first_order(); - if self.ode_problem.as_ref().unwrap().integrate_out { - self.state - .as_mut() - .unwrap() - .initialise_gdiff_to_first_order(); + if self.ode_problem.integrate_out { + self.state.initialise_gdiff_to_first_order(); } if self.s_op.is_some() { - self.state - .as_mut() - .unwrap() - .initialise_sdiff_to_first_order(); + self.state.initialise_sdiff_to_first_order(); if self.s_op.as_ref().unwrap().eqn().out().is_some() { - self.state - .as_mut() - .unwrap() - .initialise_sgdiff_to_first_order(); + self.state.initialise_sgdiff_to_first_order(); } } @@ -515,9 +684,9 @@ where } fn error_control(&self) -> Eqn::T { - let state = self.state.as_ref().unwrap(); + let state = &self.state; let order = state.order; - let output_in_error_control = self.ode_problem.as_ref().unwrap().output_in_error_control(); + let output_in_error_control = self.ode_problem.output_in_error_control(); let integrate_sens = self.s_op.is_some(); let sens_in_error_control = integrate_sens && self.s_op.as_ref().unwrap().eqn().include_in_error_control(); @@ -531,23 +700,21 @@ where .eqn() .include_out_in_error_control(); - let atol = self.ode_problem.as_ref().unwrap().atol.as_ref(); - let rtol = self.ode_problem.as_ref().unwrap().rtol; - let mut error_norm = - self.y_delta.squared_norm(&state.y, atol, rtol) * self.error_const2[order - 1]; - let mut ncontrib = 1; - if output_in_error_control { - let rtol = self.ode_problem.as_ref().unwrap().out_rtol.unwrap(); - let atol = self - .ode_problem - .as_ref() - .unwrap() - .out_atol - .as_ref() - .unwrap(); + let mut error_norm = M::T::zero(); + let mut ncontrib = 0; + if self.op.is_some() { + let atol = &self.ode_problem.atol; + let rtol = self.ode_problem.rtol; error_norm += - self.g_delta.squared_norm(&state.g, atol, rtol) * self.error_const2[order]; + self.y_delta.squared_norm(&state.y, atol, rtol) * self.error_const2[order - 1]; ncontrib += 1; + if output_in_error_control { + let rtol = self.ode_problem.out_rtol.unwrap(); + let atol = self.ode_problem.out_atol.as_ref().unwrap(); + error_norm += + self.g_delta.squared_norm(&state.g, atol, rtol) * self.error_const2[order]; + ncontrib += 1; + } } if sens_in_error_control { let sens_atol = self.s_op.as_ref().unwrap().eqn().atol().unwrap(); @@ -567,12 +734,15 @@ where } ncontrib += state.sgdiff.len(); } - error_norm / Eqn::T::from(ncontrib as f64) + if ncontrib > 1 { + error_norm /= Eqn::T::from(ncontrib as f64) + } + error_norm } fn predict_error_control(&self, order: usize) -> Eqn::T { - let state = self.state.as_ref().unwrap(); - let output_in_error_control = self.ode_problem.as_ref().unwrap().output_in_error_control(); + let state = &self.state; + let output_in_error_control = self.ode_problem.output_in_error_control(); let integrate_sens = self.s_op.is_some(); let sens_in_error_control = integrate_sens && self.s_op.as_ref().unwrap().eqn().include_in_error_control(); @@ -586,29 +756,27 @@ where .eqn() .include_out_in_error_control(); - let atol = self.ode_problem.as_ref().unwrap().atol.as_ref(); - let rtol = self.ode_problem.as_ref().unwrap().rtol; - let mut error_norm = state - .diff - .column(order + 1) - .squared_norm(&state.y, atol, rtol) - * self.error_const2[order]; - let mut ncontrib = 1; - if output_in_error_control { - let rtol = self.ode_problem.as_ref().unwrap().out_rtol.unwrap(); - let atol = self - .ode_problem - .as_ref() - .unwrap() - .out_atol - .as_ref() - .unwrap(); + let atol = &self.ode_problem.atol; + let rtol = self.ode_problem.rtol; + let mut error_norm = M::T::zero(); + let mut ncontrib = 0; + if self.op.is_some() { error_norm += state - .gdiff + .diff .column(order + 1) - .squared_norm(&state.g, atol, rtol) + .squared_norm(&state.y, atol, rtol) * self.error_const2[order]; ncontrib += 1; + if output_in_error_control { + let rtol = self.ode_problem.out_rtol.unwrap(); + let atol = self.ode_problem.out_atol.as_ref().unwrap(); + error_norm += state + .gdiff + .column(order + 1) + .squared_norm(&state.g, atol, rtol) + * self.error_const2[order]; + ncontrib += 1; + } } if sens_in_error_control { let sens_atol = self.s_op.as_ref().unwrap().eqn().atol().unwrap(); @@ -620,6 +788,7 @@ where sens_rtol, ) * self.error_const2[order]; } + ncontrib += state.sdiff.len(); } if sens_output_in_error_control { let rtol = self.s_op.as_ref().unwrap().eqn().out_rtol().unwrap(); @@ -631,62 +800,66 @@ where .squared_norm(&state.sg[i], atol, rtol) * self.error_const2[order]; } + ncontrib += state.sgdiff.len(); + } + if ncontrib == 0 { + error_norm + } else { + error_norm / Eqn::T::from(ncontrib as f64) } - error_norm / Eqn::T::from(ncontrib as f64) } fn sensitivity_solve(&mut self, t_new: Eqn::T) -> Result<(), DiffsolError> { - let h = self.state.as_ref().unwrap().h; - let order = self.state.as_ref().unwrap().order; - let op = self.s_op.as_mut().unwrap(); + let order = self.state.order; // update for new state - { - let dy_new = self.op.as_ref().unwrap().tmp(); + if let Some(op) = self.op.as_ref() { + let s_op = self.s_op.as_mut().unwrap(); + let dy_new = op.tmp(); let y_new = &self.y_predict; - Rc::get_mut(op.eqn_mut()) - .unwrap() - .update_rhs_out_state(y_new, &dy_new, t_new); - - // construct bdf discretisation of sensitivity equations - op.set_c(h, self.alpha[order]); + s_op.eqn_mut().update_rhs_out_state(y_new, &dy_new, t_new); } // solve for sensitivities equations discretised using BDF - let naug = op.eqn().max_index(); + let naug = self.s_op.as_mut().unwrap().eqn().max_index(); for i in 0..naug { - let op = self.s_op.as_mut().unwrap(); // setup + let s_op = self.s_op.as_mut().unwrap(); { - let state = self.state.as_ref().unwrap(); + let state = &self.state; // predict forward to new step Self::_predict_using_diff(&mut self.s_predict, &state.sdiff[i], order); // setup op - op.set_psi_and_y0( + s_op.set_psi_and_y0( &state.sdiff[i], self.gamma.as_slice(), self.alpha.as_slice(), order, &self.s_predict, ); - Rc::get_mut(op.eqn_mut()).unwrap().set_index(i); + s_op.eqn_mut().set_index(i); } // solve { - let s_new = &mut self.state.as_mut().unwrap().s[i]; + let s_new = &mut self.state.s[i]; s_new.copy_from(&self.s_predict); - self.nonlinear_solver - .solve_in_place(&*op, s_new, t_new, &self.s_predict)?; - self.statistics.number_of_nonlinear_solver_iterations += - self.nonlinear_solver.convergence().niter(); + // todo: should be a separate convergence object? + self.nonlinear_solver.solve_in_place( + &*s_op, + s_new, + t_new, + &self.s_predict, + &mut self.convergence, + )?; + self.statistics.number_of_nonlinear_solver_iterations += self.convergence.niter(); let s_new = &*s_new; self.s_deltas[i].copy_from(s_new); self.s_deltas[i] -= &self.s_predict; } - if op.eqn().out().is_some() && op.eqn().include_out_in_error_control() { + if s_op.eqn().out().is_some() && s_op.eqn().include_out_in_error_control() { self.calculate_sens_output_delta(i); } } @@ -694,11 +867,12 @@ where } } -impl OdeSolverMethod for Bdf +impl<'a, M, Eqn, Nls, AugmentedEqn> OdeSolverMethod<'a, Eqn> for Bdf<'a, Eqn, Nls, M, AugmentedEqn> where Eqn: OdeEquationsImplicit, AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit, M: DenseMatrix, + Eqn::V: DefaultDenseMatrix, Nls: NonLinearSolver, for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, @@ -706,12 +880,32 @@ where type State = BdfState; fn order(&self) -> usize { - self.state.as_ref().map_or(1, |state| state.order) + self.state.order + } + + fn set_state(&mut self, state: Self::State) { + let old_order = self.state.order; + self.state = state; + + if let Some(op) = self.op.as_mut() { + op.set_c(self.state.h, self.alpha[self.state.order]); + } + + // order might have changed + if self.state.order != old_order { + self.u = Self::_compute_r(self.state.order, Eqn::T::one()); + } + + // reinitialise jacobian updates as if a checkpoint was taken + self._jacobian_updates( + self.state.h * self.alpha[self.state.order], + SolverState::Checkpoint, + ); } fn interpolate(&self, t: Eqn::T) -> Result { // state must be set - let state = self.state.as_ref().ok_or(ode_solver_error!(StateNotSet))?; + let state = &self.state; if self.is_state_modified { if t == state.t { return Ok(state.y.clone()); @@ -735,7 +929,7 @@ where fn interpolate_out(&self, t: Eqn::T) -> Result { // state must be set - let state = self.state.as_ref().ok_or(ode_solver_error!(StateNotSet))?; + let state = &self.state; if self.is_state_modified { if t == state.t { return Ok(state.g.clone()); @@ -759,7 +953,7 @@ where fn interpolate_sens(&self, t: ::T) -> Result, DiffsolError> { // state must be set - let state = self.state.as_ref().ok_or(ode_solver_error!(StateNotSet))?; + let state = &self.state; if self.is_state_modified { if t == state.t { return Ok(state.s.clone()); @@ -786,105 +980,35 @@ where Ok(s) } - fn problem(&self) -> Option<&OdeSolverProblem> { - self.ode_problem.as_ref() + fn problem(&self) -> &'a OdeSolverProblem { + self.ode_problem } - fn state(&self) -> Option> { - self.state.as_ref().map(|state| state.as_ref()) + fn state(&self) -> StateRef { + self.state.as_ref() } - fn take_state(&mut self) -> Option> { - self.ode_problem = None; - self.op = None; - self.s_op = None; - Option::take(&mut self.state) + + fn into_state(self) -> BdfState { + self.state } - fn state_mut(&mut self) -> Option> { + fn state_mut(&mut self) -> StateRefMut { self.is_state_modified = true; - self.state.as_mut().map(|state| state.as_mut()) + self.state.as_mut() } - fn checkpoint(&mut self) -> Result { - if self.state.is_none() { - return Err(ode_solver_error!(StateNotSet)); - } + fn checkpoint(&mut self) -> Self::State { self._jacobian_updates( - self.state.as_ref().unwrap().h * self.alpha[self.state.as_ref().unwrap().order], + self.state.h * self.alpha[self.state.order], SolverState::Checkpoint, ); - - Ok(self.state.as_ref().unwrap().clone()) - } - - fn set_problem( - &mut self, - mut state: BdfState, - problem: &OdeSolverProblem, - ) -> Result<(), DiffsolError> { - self.ode_problem = Some(problem.clone()); - - state.check_consistent_with_problem(problem)?; - - // setup linear solver for first step - let bdf_callable = BdfCallable::new(problem); - bdf_callable.set_c(state.h, self.alpha[state.order]); - - self.nonlinear_solver - .set_problem(&bdf_callable, problem.rtol, problem.atol.clone()); - self.nonlinear_solver - .convergence_mut() - .set_max_iter(Self::NEWTON_MAXITER); - self.nonlinear_solver - .reset_jacobian(&bdf_callable, &state.y, state.t); - self.op = Some(bdf_callable); - - // setup root solver - if let Some(root_fn) = problem.eqn.root() { - self.root_finder = Some(RootFinder::new(root_fn.nout())); - self.root_finder - .as_ref() - .unwrap() - .init(&root_fn, &state.y, state.t); - } - - // (re)allocate internal state - let nstates = problem.eqn.rhs().nstates(); - if self.diff_tmp.nrows() != nstates { - self.diff_tmp = M::zeros(nstates, BdfState::::MAX_ORDER + 3); - self.y_delta = ::zeros(nstates); - self.y_predict = ::zeros(nstates); - } - - let nout = if let Some(out) = problem.eqn.out() { - out.nout() - } else { - 0 - }; - if self.g_delta.len() != nout { - self.g_delta = ::zeros(nout); - } - if self.gdiff_tmp.nrows() != nout { - self.gdiff_tmp = M::zeros(nout, BdfState::::MAX_ORDER + 3); - } - - // init U matrix - self.u = Self::_compute_r(state.order, Eqn::T::one()); - self.is_state_modified = false; - - // initialise state and store it - state.set_problem(problem)?; - self.state = Some(state); - Ok(()) + self.state.clone() } fn step(&mut self) -> Result, DiffsolError> { let mut safety: Eqn::T; let mut error_norm: Eqn::T; - if self.state.is_none() { - return Err(ode_solver_error!(StateNotSet)); - } - let problem = self.ode_problem.as_ref().unwrap(); + let problem = self.ode_problem; let integrate_out = problem.integrate_out; let output_in_error_control = problem.output_in_error_control(); let integrate_sens = self.s_op.is_some(); @@ -894,7 +1018,7 @@ where if self.is_state_modified { // reinitalise root finder if needed if let Some(root_fn) = problem.eqn.root() { - let state = self.state.as_ref().unwrap(); + let state = &self.state; self.root_finder .as_ref() .unwrap() @@ -913,37 +1037,42 @@ where // loop until step is accepted loop { - let order = self.state.as_ref().unwrap().order; + let order = self.state.order; self.y_delta.copy_from(&self.y_predict); // solve BDF equation using y0 as starting point - let mut solve_result = self.nonlinear_solver.solve_in_place( - self.op.as_ref().unwrap(), - &mut self.y_delta, - self.t_predict, - &self.y_predict, - ); - // update statistics - self.statistics.number_of_nonlinear_solver_iterations += - self.nonlinear_solver.convergence().niter(); - - // only calculate norm and sensitivities if solve was successful - if solve_result.is_ok() { - // test error is within tolerance - // combine eq 3, 4 and 6 from [1] to obtain error - // Note that error = C_k * h^{k+1} y^{k+1} - // and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} - self.y_delta -= &self.y_predict; - - // deal with output equations - if integrate_out && output_in_error_control { - self.calculate_output_delta(); + let mut solve_result = Ok(()); + if let Some(op) = self.op.as_ref() { + solve_result = self.nonlinear_solver.solve_in_place( + op, + &mut self.y_delta, + self.t_predict, + &self.y_predict, + &mut self.convergence, + ); + // update statistics + self.statistics.number_of_nonlinear_solver_iterations += self.convergence.niter(); + + if solve_result.is_ok() { + // test error is within tolerance + // combine eq 3, 4 and 6 from [1] to obtain error + // Note that error = C_k * h^{k+1} y^{k+1} + // and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} + self.y_delta -= &self.y_predict; + + // deal with output equations + if integrate_out && output_in_error_control { + self.calculate_output_delta(); + } } + } - // sensitivities - if integrate_sens && self.sensitivity_solve(self.t_predict).is_err() { - solve_result = Err(ode_solver_error!(SensitivitySolveFailed)); - } + // only calculate sensitivities if solve was successful + if solve_result.is_ok() + && integrate_sens + && self.sensitivity_solve(self.t_predict).is_err() + { + solve_result = Err(ode_solver_error!(SensitivitySolveFailed)); } // handle case where either nonlinear solve failed @@ -965,7 +1094,7 @@ where } else { // newton iteration did not converge, so update jacobian and try again self._jacobian_updates( - self.state.as_ref().unwrap().h * self.alpha[order], + self.state.h * self.alpha[order], SolverState::FirstConvergenceFail, ); convergence_fail = true; @@ -977,8 +1106,8 @@ where error_norm = self.error_control(); // need to caulate safety even if step is accepted - let maxiter = self.nonlinear_solver.convergence().max_iter() as f64; - let niter = self.nonlinear_solver.convergence().niter() as f64; + let maxiter = self.convergence.max_iter() as f64; + let niter = self.convergence.niter() as f64; safety = Eqn::T::from(0.9 * (2.0 * maxiter + 1.0) / (2.0 * maxiter + niter)); // do the error test @@ -1008,7 +1137,7 @@ where self.update_differences_and_integrate_out(); { - let state = self.state.as_mut().unwrap(); + let state = &mut self.state; state.y.copy_from(&self.y_predict); state.t = self.t_predict; state.dy.copy_from_view(&state.diff.column(1)); @@ -1016,8 +1145,11 @@ where } // update statistics - self.statistics.number_of_linear_solver_setups = - self.op.as_ref().unwrap().number_of_jac_evals(); + if let Some(op) = self.op.as_ref() { + self.statistics.number_of_linear_solver_setups = op.number_of_jac_evals(); + } else if let Some(s_op) = self.s_op.as_ref() { + self.statistics.number_of_linear_solver_setups = s_op.number_of_jac_evals(); + } self.statistics.number_of_steps += 1; self.jacobian_update.step(); @@ -1025,10 +1157,9 @@ where // (see page 83 of [2]) self.n_equal_steps += 1; - if self.n_equal_steps > self.state.as_ref().unwrap().order { + if self.n_equal_steps > self.state.order { let factors = { - let state = self.state.as_mut().unwrap(); - let order = state.order; + let order = self.state.order; // similar to the optimal step size factor we calculated above for the current // order k, we need to calculate the optimal step size factors for orders // k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n @@ -1064,14 +1195,14 @@ where // update order and update the U matrix let order = { - let old_order = self.state.as_ref().unwrap().order; + let old_order = self.state.order; let new_order = match max_index { 0 => old_order - 1, 1 => old_order, 2 => old_order + 1, _ => unreachable!(), }; - self.state.as_mut().unwrap().order = new_order; + self.state.order = new_order; if max_index != 1 { self.u = Self::_compute_r(new_order, Eqn::T::one()); } @@ -1096,12 +1227,12 @@ where } // check for root within accepted step - if let Some(root_fn) = self.problem().as_ref().unwrap().eqn.root() { + if let Some(root_fn) = self.ode_problem.eqn.root() { let ret = self.root_finder.as_ref().unwrap().check_root( &|t: ::T| self.interpolate(t), &root_fn, - &self.state.as_ref().unwrap().y, - self.state.as_ref().unwrap().t, + self.state.as_ref().y, + self.state.as_ref().t, ); if let Some(root) = ret { return Ok(OdeSolverStopReason::RootFound(root)); @@ -1123,7 +1254,7 @@ where if let Some(OdeSolverStopReason::TstopReached) = self.handle_tstop(tstop)? { let error = OdeSolverError::StopTimeBeforeCurrentTime { stop_time: tstop.into(), - state_time: self.state.as_ref().unwrap().t.into(), + state_time: self.state.t.into(), }; self.tstop = None; return Err(DiffsolError::from(error)); @@ -1132,72 +1263,6 @@ where } } -impl AugmentedOdeSolverMethod - for Bdf -where - Eqn: OdeEquationsImplicit, - AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit, - M: DenseMatrix, - Nls: NonLinearSolver, - for<'b> &'b Eqn::V: VectorRef, - for<'b> &'b Eqn::M: MatrixRef, -{ - fn set_augmented_problem( - &mut self, - state: BdfState, - problem: &OdeSolverProblem, - augmented_eqn: AugmentedEqn, - ) -> Result<(), DiffsolError> { - state.check_sens_consistent_with_problem(problem, &augmented_eqn)?; - - self.set_problem(state, problem)?; - - self.state - .as_mut() - .unwrap() - .set_augmented_problem(problem, &augmented_eqn)?; - - // allocate internal state for sensitivities - let naug = augmented_eqn.max_index(); - let nstates = problem.eqn.rhs().nstates(); - let augmented_eqn = Rc::new(augmented_eqn); - self.s_op = Some(BdfCallable::from_sensitivity_eqn(&augmented_eqn)); - - if self.s_deltas.len() != naug || self.s_deltas[0].len() != nstates { - self.s_deltas = vec![::zeros(nstates); naug]; - } - if self.s_predict.len() != nstates { - self.s_predict = ::zeros(nstates); - } - if let Some(out) = self.s_op.as_ref().unwrap().eqn().out() { - if self.sg_deltas.len() != naug || self.sg_deltas[0].len() != out.nout() { - self.sg_deltas = vec![::zeros(out.nout()); naug]; - } - if self.sgdiff_tmp.nrows() != out.nout() { - self.sgdiff_tmp = M::zeros(out.nout(), BdfState::::MAX_ORDER + 3); - } - } - Ok(()) - } -} - -impl AdjointOdeSolverMethod for Bdf -where - Eqn: OdeEquationsAdjoint, - AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit, - M: DenseMatrix, - Nls: NonLinearSolver, - for<'b> &'b Eqn::V: VectorRef, - for<'b> &'b Eqn::M: MatrixRef, -{ - type AdjointSolver = Bdf, Nls, AdjointEquations>; - - fn new_adjoint_solver(&self) -> Self::AdjointSolver { - let adjoint_nls = Nls::default(); - Self::AdjointSolver::new(adjoint_nls) - } -} - #[cfg(test)] mod test { use crate::{ @@ -1222,50 +1287,46 @@ mod test { robertson_ode_with_sens::robertson_ode_with_sens, }, tests::{ - test_checkpointing, test_interpolate, test_no_set_problem, test_ode_solver, - test_ode_solver_adjoint, test_ode_solver_no_sens, test_param_sweep, test_state_mut, - test_state_mut_on_problem, + test_checkpointing, test_interpolate, test_ode_solver, test_ode_solver_adjoint, + test_problem, test_state_mut, test_state_mut_on_problem, }, }, - Bdf, FaerSparseLU, NewtonNonlinearSolver, OdeEquations, Op, SparseColMat, + FaerLU, FaerSparseLU, OdeEquations, OdeSolverMethod, Op, SparseColMat, Vector, }; - use faer::Mat; use num_traits::abs; type M = nalgebra::DMatrix; - #[test] - fn bdf_no_set_problem() { - test_no_set_problem::(Bdf::default()) - } + type LS = crate::NalgebraLU; #[test] fn bdf_state_mut() { - test_state_mut::(Bdf::default()) + test_state_mut(test_problem::().bdf::().unwrap()); } + #[test] fn bdf_test_interpolate() { - test_interpolate::(Bdf::default()) + test_interpolate(test_problem::().bdf::().unwrap()); } #[test] fn bdf_test_state_mut_exponential_decay() { let (p, soln) = exponential_decay_problem::(false); - let s = Bdf::default(); - test_state_mut_on_problem(s, p, soln); + let s = p.bdf_solver::(p.bdf_state::().unwrap()).unwrap(); + test_state_mut_on_problem(s, soln); } #[test] fn bdf_test_nalgebra_negative_exponential_decay() { - let mut s = Bdf::default(); let (problem, soln) = negative_exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[test] fn bdf_test_nalgebra_exponential_decay() { - let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 11 number_of_steps: 47 @@ -1273,7 +1334,7 @@ mod test { number_of_nonlinear_solver_iterations: 82 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 84 number_of_jac_muls: 2 number_of_matrix_evals: 1 @@ -1283,25 +1344,26 @@ mod test { #[test] fn bdf_test_faer_sparse_exponential_decay() { - let linear_solver = FaerSparseLU::default(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = exponential_decay_problem::>(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[test] fn bdf_test_checkpointing() { let (problem, soln) = exponential_decay_problem::(false); - test_checkpointing(Bdf::default(), Bdf::default(), problem, soln); + let solver1 = problem.bdf::().unwrap(); + let solver2 = problem.bdf::().unwrap(); + test_checkpointing(soln, solver1, solver2); } #[test] fn bdf_test_faer_exponential_decay() { type M = faer::Mat; - let mut s = Bdf::default(); + type LS = FaerLU; let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 11 number_of_steps: 47 @@ -1309,7 +1371,7 @@ mod test { number_of_nonlinear_solver_iterations: 82 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 84 number_of_jac_muls: 2 number_of_matrix_evals: 1 @@ -1319,19 +1381,19 @@ mod test { #[test] fn bdf_test_nalgebra_exponential_decay_sens() { - let mut s = Bdf::with_sensitivities(); let (problem, soln) = exponential_decay_problem_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false, true); + let mut s = problem.bdf_sens::().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - number_of_linear_solver_setups: 11 - number_of_steps: 44 - number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 217 + number_of_linear_solver_setups: 13 + number_of_steps: 48 + number_of_error_test_failures: 1 + number_of_nonlinear_solver_iterations: 234 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - number_of_calls: 87 - number_of_jac_muls: 136 + insta::assert_yaml_snapshot!(problem.eqn.statistics(), @r###" + number_of_calls: 89 + number_of_jac_muls: 151 number_of_matrix_evals: 1 number_of_jac_adj_muls: 0 "###); @@ -1339,49 +1401,35 @@ mod test { #[test] fn bdf_test_nalgebra_exponential_decay_adjoint() { - let s = Bdf::default(); let (problem, soln) = exponential_decay_problem_adjoint::(); - let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); + let s = problem.bdf::().unwrap(); + test_ode_solver_adjoint::(s, soln); insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 84 number_of_jac_muls: 6 number_of_matrix_evals: 3 - number_of_jac_adj_muls: 492 - "###); - insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" - number_of_linear_solver_setups: 24 - number_of_steps: 86 - number_of_error_test_failures: 12 - number_of_nonlinear_solver_iterations: 486 - number_of_nonlinear_solver_fails: 0 + number_of_jac_adj_muls: 392 "###); } #[test] fn bdf_test_nalgebra_exponential_decay_algebraic_adjoint() { - let s = Bdf::default(); let (problem, soln) = exponential_decay_with_algebraic_adjoint_problem::(); - let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); + let s = problem.bdf::().unwrap(); + test_ode_solver_adjoint::(s, soln); insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 190 number_of_jac_muls: 24 number_of_matrix_evals: 8 - number_of_jac_adj_muls: 278 - "###); - insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" - number_of_linear_solver_setups: 32 - number_of_steps: 74 - number_of_error_test_failures: 15 - number_of_nonlinear_solver_iterations: 266 - number_of_nonlinear_solver_fails: 0 + number_of_jac_adj_muls: 187 "###); } #[test] fn test_bdf_nalgebra_exponential_decay_algebraic() { - let mut s = Bdf::default(); let (problem, soln) = exponential_decay_with_algebraic_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 20 number_of_steps: 41 @@ -1389,7 +1437,7 @@ mod test { number_of_nonlinear_solver_iterations: 79 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 83 number_of_jac_muls: 6 number_of_matrix_evals: 2 @@ -1399,18 +1447,16 @@ mod test { #[test] fn bdf_test_faer_sparse_exponential_decay_algebraic() { - let linear_solver = FaerSparseLU::default(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = exponential_decay_with_algebraic_problem::>(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[test] fn test_bdf_nalgebra_exponential_decay_algebraic_sens() { - let mut s = Bdf::with_sensitivities(); let (problem, soln) = exponential_decay_with_algebraic_problem_sens::(); - test_ode_solver(&mut s, &problem, soln, None, false, true); + let mut s = problem.bdf_sens::().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 18 number_of_steps: 43 @@ -1418,7 +1464,7 @@ mod test { number_of_nonlinear_solver_iterations: 155 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 71 number_of_jac_muls: 100 number_of_matrix_evals: 3 @@ -1428,9 +1474,9 @@ mod test { #[test] fn test_bdf_nalgebra_robertson() { - let mut s = Bdf::default(); let (problem, soln) = robertson::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 77 number_of_steps: 316 @@ -1438,7 +1484,7 @@ mod test { number_of_nonlinear_solver_iterations: 722 number_of_nonlinear_solver_fails: 19 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 725 number_of_jac_muls: 60 number_of_matrix_evals: 20 @@ -1448,21 +1494,17 @@ mod test { #[test] fn bdf_test_faer_sparse_robertson() { - let linear_solver = FaerSparseLU::default(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = robertson::>(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[cfg(feature = "suitesparse")] #[test] fn bdf_test_faer_sparse_ku_robertson() { - let linear_solver = crate::KLU::default(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = robertson::>(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[cfg(feature = "diffsl-llvm")] @@ -1471,16 +1513,16 @@ mod test { use diffsl::LlvmModule; use crate::ode_solver::test_models::robertson; - let mut s = Bdf::default(); let (problem, soln) = robertson::robertson_diffsl_problem::(); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[test] fn test_bdf_nalgebra_robertson_sens() { - let mut s = Bdf::with_sensitivities(); let (problem, soln) = robertson_sens::(); - test_ode_solver(&mut s, &problem, soln, None, false, true); + let mut s = problem.bdf_sens::().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 160 number_of_steps: 410 @@ -1488,7 +1530,7 @@ mod test { number_of_nonlinear_solver_iterations: 3107 number_of_nonlinear_solver_fails: 81 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 996 number_of_jac_muls: 2495 number_of_matrix_evals: 71 @@ -1498,9 +1540,9 @@ mod test { #[test] fn test_bdf_nalgebra_robertson_colored() { - let mut s = Bdf::default(); let (problem, soln) = robertson::(true); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 77 number_of_steps: 316 @@ -1508,7 +1550,7 @@ mod test { number_of_nonlinear_solver_iterations: 722 number_of_nonlinear_solver_fails: 19 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 725 number_of_jac_muls: 63 number_of_matrix_evals: 20 @@ -1518,9 +1560,9 @@ mod test { #[test] fn test_bdf_nalgebra_robertson_ode() { - let mut s = Bdf::default(); let (problem, soln) = robertson_ode::(false, 3); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 86 number_of_steps: 416 @@ -1528,7 +1570,7 @@ mod test { number_of_nonlinear_solver_iterations: 911 number_of_nonlinear_solver_fails: 15 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 913 number_of_jac_muls: 162 number_of_matrix_evals: 18 @@ -1538,29 +1580,29 @@ mod test { #[test] fn test_bdf_nalgebra_robertson_ode_sens() { - let mut s = Bdf::with_sensitivities(); let (problem, soln) = robertson_ode_with_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false, true); + let mut s = problem.bdf_sens::().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - number_of_linear_solver_setups: 112 - number_of_steps: 467 - number_of_error_test_failures: 2 - number_of_nonlinear_solver_iterations: 3472 - number_of_nonlinear_solver_fails: 49 + number_of_linear_solver_setups: 152 + number_of_steps: 512 + number_of_error_test_failures: 5 + number_of_nonlinear_solver_iterations: 3779 + number_of_nonlinear_solver_fails: 70 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - number_of_calls: 1041 - number_of_jac_muls: 2672 - number_of_matrix_evals: 45 + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 1157 + number_of_jac_muls: 2930 + number_of_matrix_evals: 54 number_of_jac_adj_muls: 0 "###); } #[test] fn test_bdf_nalgebra_dydt_y2() { - let mut s = Bdf::default(); let (problem, soln) = dydt_y2_problem::(false, 10); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 27 number_of_steps: 161 @@ -1568,7 +1610,7 @@ mod test { number_of_nonlinear_solver_iterations: 355 number_of_nonlinear_solver_fails: 3 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 357 number_of_jac_muls: 50 number_of_matrix_evals: 5 @@ -1578,9 +1620,9 @@ mod test { #[test] fn test_bdf_nalgebra_dydt_y2_colored() { - let mut s = Bdf::default(); let (problem, soln) = dydt_y2_problem::(true, 10); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 27 number_of_steps: 161 @@ -1588,7 +1630,7 @@ mod test { number_of_nonlinear_solver_iterations: 355 number_of_nonlinear_solver_fails: 3 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 357 number_of_jac_muls: 15 number_of_matrix_evals: 5 @@ -1598,9 +1640,9 @@ mod test { #[test] fn test_bdf_nalgebra_gaussian_decay() { - let mut s = Bdf::default(); let (problem, soln) = gaussian_decay_problem::(false, 10); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 14 number_of_steps: 66 @@ -1608,7 +1650,7 @@ mod test { number_of_nonlinear_solver_iterations: 130 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 132 number_of_jac_muls: 20 number_of_matrix_evals: 2 @@ -1618,11 +1660,9 @@ mod test { #[test] fn test_bdf_faer_sparse_heat2d() { - let linear_solver = FaerSparseLU::default(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = head2d_problem::, 10>(); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 21 number_of_steps: 167 @@ -1630,7 +1670,7 @@ mod test { number_of_nonlinear_solver_iterations: 330 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 333 number_of_jac_muls: 128 number_of_matrix_evals: 4 @@ -1644,20 +1684,16 @@ mod test { use diffsl::LlvmModule; use crate::ode_solver::test_models::heat2d; - let linear_solver = FaerSparseLU::default(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = heat2d::heat2d_diffsl_problem::, LlvmModule, 10>(); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[test] fn test_bdf_faer_sparse_foodweb() { - let linear_solver = FaerSparseLU::default(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = foodweb_problem::, 10>(); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 45 number_of_steps: 161 @@ -1673,49 +1709,62 @@ mod test { use diffsl::LlvmModule; use crate::ode_solver::test_models::foodweb; - let linear_solver = FaerSparseLU::default(); - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = foodweb::foodweb_diffsl_problem::, LlvmModule, 10>(); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[test] fn test_tstop_bdf() { - let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, true); + let mut s = problem.bdf::().unwrap(); + test_ode_solver(&mut s, soln, None, true, false); } #[test] fn test_root_finder_bdf() { - let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem_with_root::(false); - let y = test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.bdf::().unwrap(); + let y = test_ode_solver(&mut s, soln, None, false, false); assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); } #[test] fn test_param_sweep_bdf() { - let s = Bdf::default(); - let (problem, _soln) = exponential_decay_problem::(false); + let (mut problem, _soln) = exponential_decay_problem::(false); let mut ps = Vec::new(); for y0 in (1..10).map(f64::from) { ps.push(nalgebra::DVector::::from_vec(vec![0.1, y0])); } - test_param_sweep(s, problem, ps); + + let mut old_soln: Option> = None; + for p in ps { + problem.eqn_mut().set_params(&p); + let mut s = problem.bdf::().unwrap(); + let (ys, _ts) = s.solve(10.0).unwrap(); + // check that the new solution is different from the old one + if let Some(old_soln) = &mut old_soln { + let new_soln = ys.column(ys.ncols() - 1).into_owned(); + let error = new_soln - &*old_soln; + let diff = error + .squared_norm(old_soln, &problem.atol, problem.rtol) + .sqrt(); + assert!(diff > 1.0e-6, "diff: {}", diff); + } + old_soln = Some(ys.column(ys.ncols() - 1).into_owned()); + } } #[cfg(feature = "diffsl")] #[test] fn test_ball_bounce_bdf() { + use crate::ode_solver::tests::test_ball_bounce_problem; type M = nalgebra::DMatrix; type LS = crate::NalgebraLU; - type Nls = crate::NewtonNonlinearSolver; - type Eqn = crate::DiffSl; - let s = Bdf::::default(); - let (x, v, t) = crate::ode_solver::tests::test_ball_bounce(s); + let (x, v, t) = crate::ode_solver::tests::test_ball_bounce( + test_ball_bounce_problem::().bdf::().unwrap(), + ); let expected_x = [ 0.003751514915514589, diff --git a/src/ode_solver/bdf_state.rs b/src/ode_solver/bdf_state.rs index 454bf76f..9cc642a5 100644 --- a/src/ode_solver/bdf_state.rs +++ b/src/ode_solver/bdf_state.rs @@ -2,15 +2,19 @@ use crate::{ error::{DiffsolError, OdeSolverError}, ode_solver_error, scalar::IndexType, - scale, AugmentedOdeEquations, DenseMatrix, OdeEquations, OdeSolverProblem, OdeSolverState, Op, - StateRef, StateRefMut, Vector, VectorViewMut, + scale, AugmentedOdeEquations, DefaultDenseMatrix, DenseMatrix, OdeEquations, OdeSolverProblem, + OdeSolverState, Op, StateRef, StateRefMut, Vector, VectorViewMut, }; use std::ops::MulAssign; use super::state::StateCommon; #[derive(Clone)] -pub struct BdfState> { +pub struct BdfState::M> +where + V: Vector + DefaultDenseMatrix, + M: DenseMatrix, +{ pub(crate) order: usize, pub(crate) diff: M, pub(crate) sdiff: Vec, @@ -34,7 +38,7 @@ pub struct BdfState> { impl BdfState where - V: Vector, + V: Vector + DefaultDenseMatrix, M: DenseMatrix, { pub(crate) const MAX_ORDER: IndexType = 5; @@ -83,7 +87,7 @@ where impl OdeSolverState for BdfState where - V: Vector, + V: Vector + DefaultDenseMatrix, M: DenseMatrix, { fn set_problem( diff --git a/src/ode_solver/builder.rs b/src/ode_solver/builder.rs index d637e466..512082b3 100644 --- a/src/ode_solver/builder.rs +++ b/src/ode_solver/builder.rs @@ -1,29 +1,43 @@ +use nalgebra::DMatrix; + use crate::{ error::{DiffsolError, OdeSolverError}, ode_solver_error, - vector::DefaultDenseMatrix, - Closure, ClosureNoJac, ClosureWithSens, ConstantClosure, ConstantClosureWithSens, - LinearClosure, Matrix, OdeEquations, OdeSolverProblem, Op, UnitCallable, Vector, + op::{linear_closure_with_adjoint::LinearClosureWithAdjoint, BuilderOp}, + Closure, ClosureNoJac, ClosureWithAdjoint, ClosureWithSens, ConstantClosure, + ConstantClosureWithAdjoint, ConstantClosureWithSens, ConstantOp, LinearClosure, LinearOp, + Matrix, NonLinearOp, OdeEquations, OdeSolverProblem, Op, ParameterisedOp, UnitCallable, Vector, }; -use std::rc::Rc; use super::equations::OdeSolverEquations; /// Builder for ODE problems. Use methods to set parameters and then call one of the build methods when done. -pub struct OdeBuilder { - t0: f64, - h0: f64, - rtol: f64, - atol: Vec, - sens_atol: Option>, - sens_rtol: Option, - out_rtol: Option, - out_atol: Option>, - param_rtol: Option, - param_atol: Option>, - p: Vec, +pub struct OdeBuilder< + M: Matrix = DMatrix, + Rhs = UnitCallable, + Init = UnitCallable, + Mass = UnitCallable, + Root = UnitCallable, + Out = UnitCallable, +> { + t0: M::T, + h0: M::T, + rtol: M::T, + atol: Vec, + sens_atol: Option>, + sens_rtol: Option, + out_rtol: Option, + out_atol: Option>, + param_rtol: Option, + param_atol: Option>, + p: Vec, use_coloring: bool, integrate_out: bool, + rhs: Option, + init: Option, + mass: Option, + root: Option, + out: Option, } impl Default for OdeBuilder { @@ -37,13 +51,14 @@ impl Default for OdeBuilder { /// # Example /// /// ```rust -/// use diffsol::{OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod}; +/// use diffsol::{OdeBuilder, NalgebraLU, Bdf, OdeSolverState, OdeSolverMethod}; /// type M = nalgebra::DMatrix; +/// type LS = NalgebraLU; /// -/// let problem = OdeBuilder::new() +/// let problem = OdeBuilder::::new() /// .rtol(1e-6) /// .p([0.1]) -/// .build_ode::( +/// .rhs_implicit( /// // dy/dt = -ay /// |x, p, t, y| { /// y[0] = -p[0] * x[0]; @@ -52,23 +67,28 @@ impl Default for OdeBuilder { /// |x, p, t, v, y| { /// y[0] = -p[0] * v[0]; /// }, +/// ) +/// .init( /// // y(0) = 1 /// |p, t| { /// nalgebra::DVector::from_vec(vec![1.0]) /// }, -/// ).unwrap(); +/// ) +/// .build() +/// .unwrap(); /// -/// let mut solver = Bdf::default(); +/// let mut solver = problem.bdf::().unwrap(); /// let t = 0.4; -/// let mut state = OdeSolverState::new(&problem, &solver).unwrap(); -/// solver.set_problem(state, &problem); -/// while solver.state().unwrap().t <= t { +/// while solver.state().t <= t { /// solver.step().unwrap(); /// } /// let y = solver.interpolate(t); /// ``` /// -impl OdeBuilder { +impl OdeBuilder +where + M: Matrix, +{ /// Create a new builder with default parameters: /// - t0 = 0.0 /// - h0 = 1.0 @@ -78,11 +98,16 @@ impl OdeBuilder { /// - use_coloring = false /// - constant_mass = false pub fn new() -> Self { - let default_atol = vec![1e-6]; - let default_rtol = 1e-6; + let default_atol = vec![M::T::from(1e-6)]; + let default_rtol = 1e-6.into(); Self { - t0: 0.0, - h0: 1.0, + rhs: None, + init: None, + mass: None, + root: None, + out: None, + t0: 0.0.into(), + h0: 1.0.into(), rtol: default_rtol, atol: default_atol.clone(), p: vec![], @@ -97,51 +122,508 @@ impl OdeBuilder { } } + /// Set the right-hand side of the ODE. + /// + /// # Arguments + /// + /// - `rhs`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the right-hand side of the ODE. + /// - `rhs_jac`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the Jacobian of the right-hand side with the vector v. + pub fn rhs_implicit( + self, + rhs: F, + rhs_jac: G, + ) -> OdeBuilder, Init, Mass, Root, Out> + where + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + { + let nstates = 0; + OdeBuilder::, Init, Mass, Root, Out> { + rhs: Some(Closure::new(rhs, rhs_jac, nstates, nstates, nstates)), + init: self.init, + mass: self.mass, + root: self.root, + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + /// Set the right-hand side of the ODE for forward sensitivity analysis. + /// + /// # Arguments + /// - `rhs`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the right-hand side of the ODE. + /// - `rhs_jac`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the Jacobian of the right-hand side with the vector v. + /// - `rhs_sens`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the partial derivative of the rhs wrt the parameters, with the vector v. + pub fn rhs_sens_implicit( + self, + rhs: F, + rhs_jac: G, + rhs_sens: H, + ) -> OdeBuilder, Init, Mass, Root, Out> + where + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + { + let nstates = 0; + OdeBuilder::, Init, Mass, Root, Out> { + rhs: Some(ClosureWithSens::new( + rhs, rhs_jac, rhs_sens, nstates, nstates, nstates, + )), + init: self.init, + mass: self.mass, + root: self.root, + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + #[allow(clippy::type_complexity)] + pub fn rhs_adjoint_implicit( + self, + rhs: F, + rhs_jac: G, + rhs_adjoint: H, + rhs_sens_adjoint: I, + ) -> OdeBuilder, Init, Mass, Root, Out> + where + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + { + let nstates = 0; + OdeBuilder::, Init, Mass, Root, Out> { + rhs: Some(ClosureWithAdjoint::new( + rhs, + rhs_jac, + rhs_adjoint, + rhs_sens_adjoint, + nstates, + nstates, + nstates, + )), + init: self.init, + mass: self.mass, + root: self.root, + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + /// Set the initial condition of the ODE. + /// + /// # Arguments + /// - `init`: Function of type Fn(p: &V, t: S) -> V that computes the initial state. + pub fn init(self, init: F) -> OdeBuilder, Mass, Root, Out> + where + F: Fn(&M::V, M::T) -> M::V, + { + let nstates = 0; + OdeBuilder::, Mass, Root, Out> { + rhs: self.rhs, + init: Some(ConstantClosure::new(init, nstates, nstates)), + mass: self.mass, + root: self.root, + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + /// Set the initial condition of the ODE for forward sensitivity analysis. + /// + /// # Arguments + /// - `init`: Function of type Fn(p: &V, t: S) -> V that computes the initial state. + /// - `init_sens`: Function of type Fn(p: &V, t: S, y: &mut V) that computes the multiplication of the partial derivative of the initial state wrt the parameters, with the vector v. + pub fn init_sens( + self, + init: F, + init_sens: G, + ) -> OdeBuilder, Mass, Root, Out> + where + F: Fn(&M::V, M::T) -> M::V, + G: Fn(&M::V, M::T, &M::V, &mut M::V), + { + let nstates = 0; + OdeBuilder::, Mass, Root, Out> { + rhs: self.rhs, + init: Some(ConstantClosureWithSens::new( + init, init_sens, nstates, nstates, + )), + mass: self.mass, + root: self.root, + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + /// Set the initial condition of the ODE for adjoint sensitivity analysis. + /// + /// # Arguments + /// - `init`: Function of type Fn(p: &V, t: S) -> V that computes the initial state. + /// - `init_sens_adjoint`: Function of type Fn(p: &V, t: S, y: &V, y_adj: &mut V) that computes the multiplication of the partial derivative of the initial state wrt the parameters, with the vector v. + /// + pub fn init_adjoint( + self, + init: F, + init_sens_adjoint: G, + ) -> OdeBuilder, Mass, Root, Out> + where + F: Fn(&M::V, M::T) -> M::V, + G: Fn(&M::V, M::T, &M::V, &mut M::V), + { + let nstates = 0; + OdeBuilder::, Mass, Root, Out> { + rhs: self.rhs, + init: Some(ConstantClosureWithAdjoint::new( + init, + init_sens_adjoint, + nstates, + nstates, + )), + mass: self.mass, + root: self.root, + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + /// Set the mass matrix of the ODE. + /// + /// # Arguments + /// - `mass`: Function of type Fn(v: &V, p: &V, t: S, beta: S, y: &mut V) that computes a gemv multiplication of the mass matrix with the vector v (i.e. y = M * v + beta * y). + pub fn mass(self, mass: F) -> OdeBuilder, Root, Out> + where + F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), + { + let nstates = 0; + OdeBuilder::, Root, Out> { + rhs: self.rhs, + init: self.init, + mass: Some(LinearClosure::new(mass, nstates, nstates, nstates)), + root: self.root, + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + /// Set the mass matrix of the ODE for adjoint sensitivity analysis. + /// + /// # Arguments + /// + /// - `mass`: Function of type Fn(v: &V, p: &V, t: S, beta: S, y: &mut V) that computes a gemv multiplication of the mass matrix with + /// the vector v (i.e. y = M * v + beta * y). + /// - `mass_adjoint`: Function of type Fn(v: &V, p: &V, t: S, beta: S, y: &mut V) that computes a gemv multiplication of the transpose of the mass matrix with + /// the vector v (i.e. y = M^T * v + beta * y). + pub fn mass_adjoint( + self, + mass: F, + mass_adjoint: G, + ) -> OdeBuilder, Root, Out> + where + F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), + { + let nstates = 0; + OdeBuilder::, Root, Out> { + rhs: self.rhs, + init: self.init, + mass: Some(LinearClosureWithAdjoint::new( + mass, + mass_adjoint, + nstates, + nstates, + nstates, + )), + root: self.root, + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + /// Set a root equation for the ODE. + /// + /// # Arguments + /// - `root`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the root function. + /// - `nroots`: Number of roots (i.e. number of elements in the `y` arg in `root`), an event is triggered when any of the roots changes sign. + pub fn root( + self, + root: F, + nroots: usize, + ) -> OdeBuilder, Out> + where + F: Fn(&M::V, &M::V, M::T, &mut M::V), + { + let nstates = 0; + OdeBuilder::, Out> { + rhs: self.rhs, + init: self.init, + mass: self.mass, + root: Some(ClosureNoJac::new(root, nstates, nroots, nroots)), + out: self.out, + + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + pub fn out_implicit( + self, + out: F, + out_jac: G, + nout: usize, + ) -> OdeBuilder> + where + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + { + let nstates = 0; + OdeBuilder::> { + rhs: self.rhs, + init: self.init, + mass: self.mass, + root: self.root, + out: Some(Closure::new(out, out_jac, nstates, nout, nstates)), + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + + #[allow(clippy::type_complexity)] + pub fn out_adjoint_implicit( + self, + out: F, + out_jac: G, + out_adjoint: H, + out_sens_adjoint: I, + nout: usize, + ) -> OdeBuilder> + where + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + { + let nstates = 0; + OdeBuilder::> { + rhs: self.rhs, + init: self.init, + mass: self.mass, + root: self.root, + out: Some(ClosureWithAdjoint::new( + out, + out_jac, + out_adjoint, + out_sens_adjoint, + nstates, + nout, + nstates, + )), + t0: self.t0, + h0: self.h0, + rtol: self.rtol, + atol: self.atol, + sens_atol: self.sens_atol, + sens_rtol: self.sens_rtol, + out_rtol: self.out_rtol, + out_atol: self.out_atol, + param_rtol: self.param_rtol, + param_atol: self.param_atol, + p: self.p, + use_coloring: self.use_coloring, + integrate_out: self.integrate_out, + } + } + /// Set the initial time. pub fn t0(mut self, t0: f64) -> Self { - self.t0 = t0; + self.t0 = t0.into(); self } - pub fn sens_rtol(mut self, sens_rtol: Option) -> Self { - self.sens_rtol = sens_rtol; + pub fn sens_rtol(mut self, sens_rtol: f64) -> Self { + self.sens_rtol = Some(sens_rtol.into()); self } - pub fn sens_atol(mut self, sens_atol: Option) -> Self + pub fn sens_atol(mut self, sens_atol: V) -> Self where V: IntoIterator, - f64: From, + M::T: From, { - self.sens_atol = sens_atol.map(|atol| atol.into_iter().map(|x| f64::from(x)).collect()); + self.sens_atol = Some(sens_atol.into_iter().map(|x| M::T::from(x)).collect()); + self + } + + pub fn turn_off_sensitivities_error_control(mut self) -> Self { + self.sens_atol = None; + self.sens_rtol = None; + self + } + + pub fn turn_off_output_error_control(mut self) -> Self { + self.out_atol = None; + self.out_rtol = None; self } - pub fn out_rtol(mut self, out_rtol: Option) -> Self { - self.out_rtol = out_rtol; + pub fn turn_off_param_error_control(mut self) -> Self { + self.param_atol = None; + self.param_rtol = None; self } - pub fn out_atol(mut self, out_atol: Option) -> Self + pub fn out_rtol(mut self, out_rtol: f64) -> Self { + self.out_rtol = Some(out_rtol.into()); + self + } + + pub fn out_atol(mut self, out_atol: V) -> Self where V: IntoIterator, - f64: From, + M::T: From, { - self.out_atol = out_atol.map(|atol| atol.into_iter().map(|x| f64::from(x)).collect()); + self.out_atol = Some(out_atol.into_iter().map(|x| M::T::from(x)).collect()); self } - pub fn param_rtol(mut self, param_rtol: Option) -> Self { - self.param_rtol = param_rtol; + pub fn param_rtol(mut self, param_rtol: f64) -> Self { + self.param_rtol = Some(param_rtol.into()); self } - pub fn param_atol(mut self, param_atol: Option) -> Self + pub fn param_atol(mut self, param_atol: V) -> Self where V: IntoIterator, - f64: From, + M::T: From, { - self.param_atol = param_atol.map(|atol| atol.into_iter().map(|x| f64::from(x)).collect()); + self.param_atol = Some(param_atol.into_iter().map(|x| M::T::from(x)).collect()); self } @@ -154,13 +636,13 @@ impl OdeBuilder { /// Set the initial step size. pub fn h0(mut self, h0: f64) -> Self { - self.h0 = h0; + self.h0 = h0.into(); self } /// Set the relative tolerance. pub fn rtol(mut self, rtol: f64) -> Self { - self.rtol = rtol; + self.rtol = rtol.into(); self } @@ -168,9 +650,9 @@ impl OdeBuilder { pub fn atol(mut self, atol: V) -> Self where V: IntoIterator, - f64: From, + M::T: From, { - self.atol = atol.into_iter().map(|x| f64::from(x)).collect(); + self.atol = atol.into_iter().map(|x| M::T::from(x)).collect(); self } @@ -178,9 +660,9 @@ impl OdeBuilder { pub fn p(mut self, p: V) -> Self where V: IntoIterator, - f64: From, + M::T: From, { - self.p = p.into_iter().map(|x| f64::from(x)).collect(); + self.p = p.into_iter().map(|x| M::T::from(x)).collect(); self } @@ -194,9 +676,9 @@ impl OdeBuilder { self } - fn build_atol(atol: Vec, nstates: usize, ty: &str) -> Result { + fn build_atol(atol: Vec, nstates: usize, ty: &str) -> Result { if atol.len() == 1 { - Ok(V::from_element(nstates, V::T::from(atol[0]))) + Ok(M::V::from_element(nstates, atol[0])) } else if atol.len() != nstates { Err(ode_solver_error!( BuilderError, @@ -208,24 +690,24 @@ impl OdeBuilder { ) )) } else { - let mut v = V::zeros(nstates); + let mut v = M::V::zeros(nstates); for (i, &a) in atol.iter().enumerate() { - v[i] = V::T::from(a); + v[i] = a; } Ok(v) } } #[allow(clippy::type_complexity)] - fn build_atols( - atol: Vec, - sens_atol: Option>, - out_atol: Option>, - param_atol: Option>, + fn build_atols( + atol: Vec, + sens_atol: Option>, + out_atol: Option>, + param_atol: Option>, nstates: usize, nout: Option, nparam: usize, - ) -> Result<(Rc, Option>, Option>, Option>), DiffsolError> { + ) -> Result<(M::V, Option, Option, Option), DiffsolError> { let atol = Self::build_atol(atol, nstates, "states")?; let out_atol = match out_atol { Some(out_atol) => Some(Self::build_atol(out_atol, nout.unwrap_or(0), "output")?), @@ -239,491 +721,111 @@ impl OdeBuilder { Some(sens_atol) => Some(Self::build_atol(sens_atol, nstates, "sensitivity")?), None => None, }; - Ok(( - Rc::new(atol), - sens_atol.map(Rc::new), - out_atol.map(Rc::new), - param_atol.map(Rc::new), - )) + Ok((atol, sens_atol, out_atol, param_atol)) } - fn build_p(p: Vec) -> V { - let mut v = V::zeros(p.len()); + fn build_p(p: Vec) -> M::V { + let mut v = M::V::zeros(p.len()); for (i, &p) in p.iter().enumerate() { - v[i] = V::T::from(p); + v[i] = p; } v } - /// Build an ODE problem with a mass matrix. - /// - /// # Arguments - /// - /// - `rhs`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the right-hand side of the ODE. - /// - `rhs_jac`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the Jacobian of the right-hand side with the vector v. - /// - `mass`: Function of type Fn(v: &V, p: &V, t: S, beta: S, y: &mut V) that computes a gemv multiplication of the mass matrix with the vector v (i.e. y = M * v + beta * y). - /// - `init`: Function of type Fn(p: &V, t: S) -> V that computes the initial state. - /// - /// # Generic Arguments - /// - /// - `M`: Type that implements the `Matrix` trait. Often this must be provided explicitly (i.e. `type M = DMatrix; builder.build_ode::`). - /// - /// # Example - /// - /// ``` - /// use diffsol::OdeBuilder; - /// use nalgebra::DVector; - /// type M = nalgebra::DMatrix; - /// - /// // dy/dt = y - /// // 0 = z - y - /// // y(0) = 0.1 - /// // z(0) = 0.1 - /// let problem = OdeBuilder::new() - /// .build_ode_with_mass::( - /// |x, _p, _t, y| { - /// y[0] = x[0]; - /// y[1] = x[1] - x[0]; - /// }, - /// |x, _p, _t, v, y| { - /// y[0] = v[0]; - /// y[1] = v[1] - v[0]; - /// }, - /// |v, _p, _t, beta, y| { - /// y[0] = v[0] + beta * y[0]; - /// y[1] = beta * y[1]; - /// }, - /// |p, _t| DVector::from_element(2, 0.1), - /// ); - /// ``` #[allow(clippy::type_complexity)] - pub fn build_ode_with_mass( + pub fn build( self, - rhs: F, - rhs_jac: G, - mass: H, - init: I, - ) -> Result< - OdeSolverProblem< - OdeSolverEquations, ConstantClosure, LinearClosure>, - >, - DiffsolError, - > + ) -> Result>, DiffsolError> where M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V), - G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - H: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), - I: Fn(&M::V, M::T) -> M::V, + Rhs: BuilderOp, + Init: BuilderOp, + Mass: BuilderOp, + Root: BuilderOp, + Out: BuilderOp, + for<'a> ParameterisedOp<'a, Rhs>: NonLinearOp, + for<'a> ParameterisedOp<'a, Init>: ConstantOp, + for<'a> ParameterisedOp<'a, Mass>: LinearOp, + for<'a> ParameterisedOp<'a, Root>: NonLinearOp, + for<'a> ParameterisedOp<'a, Out>: NonLinearOp, { - let p = Rc::new(Self::build_p(self.p)); - let t0 = M::T::from(self.t0); - let y0 = init(&p, t0); - let nstates = y0.len(); - let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone()); - let mut mass = LinearClosure::new(mass, nstates, nstates, p.clone()); - let init = ConstantClosure::new(init, p.clone()); - if self.use_coloring || M::is_sparse() { - rhs.calculate_sparsity(&y0, t0); - mass.calculate_sparsity(t0); - } - let mass = Some(mass); + let p = Self::build_p(self.p); let nparams = p.len(); - let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( - self.atol, - self.sens_atol, - self.out_atol, - self.param_atol, - nstates, - None, - nparams, - )?; - let eqn = OdeSolverEquations::new(rhs, mass, None, init, None, p); - OdeSolverProblem::new( - Rc::new(eqn), - M::T::from(self.rtol), - atol, - self.sens_rtol.map(M::T::from), - sens_atol, - self.out_rtol.map(M::T::from), - out_atol, - self.param_rtol.map(M::T::from), - param_atol, - M::T::from(self.t0), - M::T::from(self.h0), - self.integrate_out, - ) - } + let mut rhs = self + .rhs + .ok_or(ode_solver_error!(BuilderError, "Missing right-hand side"))?; + let mut init = self + .init + .ok_or(ode_solver_error!(BuilderError, "Missing initial state"))?; + let mut mass = self.mass; + let mut root = self.root; + let mut out = self.out; - #[allow(clippy::type_complexity)] - #[allow(clippy::too_many_arguments)] - pub fn build_ode_with_mass_and_out( - self, - rhs: F, - rhs_jac: G, - mass: H, - init: I, - out: J, - out_jac: K, - nout: usize, - ) -> Result< - OdeSolverProblem< - OdeSolverEquations< - M, - Closure, - ConstantClosure, - LinearClosure, - UnitCallable, - Closure, - >, - >, - DiffsolError, - > - where - M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V), - G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - H: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), - I: Fn(&M::V, M::T) -> M::V, - J: Fn(&M::V, &M::V, M::T, &mut M::V), - K: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - { - let p = Rc::new(Self::build_p(self.p)); - let t0 = M::T::from(self.t0); - let y0 = init(&p, t0); + let init_op = ParameterisedOp::new(&init, &p); + let y0 = init_op.call(self.t0); let nstates = y0.len(); - let nparams = p.len(); - let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone()); - let out = Closure::new(out, out_jac, nstates, nout, p.clone()); - let mut mass = LinearClosure::new(mass, nstates, nstates, p.clone()); - let init = ConstantClosure::new(init, p.clone()); - if self.use_coloring || M::is_sparse() { - rhs.calculate_sparsity(&y0, t0); - mass.calculate_sparsity(t0); + + rhs.set_nstates(nstates); + rhs.set_nout(nstates); + rhs.set_nparams(nparams); + + init.set_nout(nstates); + init.set_nparams(nparams); + + if let Some(ref mut mass) = mass { + mass.set_nstates(nstates); + mass.set_nparams(nparams); + mass.set_nout(nstates); } - let mass = Some(mass); - let out = Some(out); - let eqn = OdeSolverEquations::new(rhs, mass, None, init, out, p); - let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( - self.atol, - self.sens_atol, - self.out_atol, - self.param_atol, - nstates, - Some(nout), - nparams, - )?; - OdeSolverProblem::new( - Rc::new(eqn), - M::T::from(self.rtol), - atol, - self.sens_rtol.map(M::T::from), - sens_atol, - self.out_rtol.map(M::T::from), - out_atol, - self.param_rtol.map(M::T::from), - param_atol, - M::T::from(self.t0), - M::T::from(self.h0), - self.integrate_out, - ) - } - /// Build an ODE problem with a mass matrix that is the identity matrix. - /// - /// # Arguments - /// - /// - `rhs`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the right-hand side of the ODE. - /// - `rhs_jac`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the Jacobian of the right-hand side with the vector v. - /// - `init`: Function of type Fn(p: &V, t: S) -> V that computes the initial state. - /// - /// # Generic Arguments - /// - /// - `M`: Type that implements the `Matrix` trait. Often this must be provided explicitly (i.e. `type M = DMatrix; builder.build_ode::`). - /// - /// # Example - /// - /// - /// - /// ``` - /// use diffsol::OdeBuilder; - /// use nalgebra::DVector; - /// type M = nalgebra::DMatrix; - /// - /// - /// // dy/dt = y - /// // y(0) = 0.1 - /// let problem = OdeBuilder::new() - /// .build_ode::( - /// |x, _p, _t, y| y[0] = x[0], - /// |x, _p, _t, v , y| y[0] = v[0], - /// |p, _t| DVector::from_element(1, 0.1), - /// ); - /// ``` - #[allow(clippy::type_complexity)] - pub fn build_ode( - self, - rhs: F, - rhs_jac: G, - init: I, - ) -> Result< - OdeSolverProblem, ConstantClosure>>, - DiffsolError, - > - where - M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V), - G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - I: Fn(&M::V, M::T) -> M::V, - { - let p = Rc::new(Self::build_p(self.p)); - let t0 = M::T::from(self.t0); - let y0 = init(&p, t0); - let nstates = y0.len(); - let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone()); - let init = ConstantClosure::new(init, p.clone()); - if self.use_coloring || M::is_sparse() { - rhs.calculate_sparsity(&y0, t0); + if let Some(ref mut root) = root { + root.set_nstates(nstates); + root.set_nparams(nparams); } - let nparams = p.len(); - let eqn = OdeSolverEquations::new(rhs, None, None, init, None, p); - let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( - self.atol, - self.sens_atol, - self.out_atol, - self.param_atol, - nstates, - None, - nparams, - )?; - OdeSolverProblem::new( - Rc::new(eqn), - M::T::from(self.rtol), - atol, - self.sens_rtol.map(M::T::from), - sens_atol, - self.out_rtol.map(M::T::from), - out_atol, - self.param_rtol.map(M::T::from), - param_atol, - M::T::from(self.t0), - M::T::from(self.h0), - self.integrate_out, - ) - } - /// Build an ODE problem with a mass matrix that is the identity matrix and sensitivities. - /// - /// # Arguments - /// - /// - `rhs`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the right-hand side of the ODE. - /// - `rhs_jac`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the Jacobian of the right-hand side with the vector v. - /// - `rhs_sens`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the partial derivative of the rhs wrt the parameters, with the vector v. - /// - `init`: Function of type Fn(p: &V, t: S) -> V that computes the initial state. - /// - `init_sens`: Function of type Fn(p: &V, t: S, y: &mut V) that computes the multiplication of the partial derivative of the initial state wrt the parameters, with the vector v. - /// - /// # Example - /// - /// ``` - /// use diffsol::OdeBuilder; - /// use nalgebra::DVector; - /// type M = nalgebra::DMatrix; - /// - /// - /// // dy/dt = a y - /// // y(0) = 0.1 - /// let problem = OdeBuilder::new() - /// .build_ode_with_sens::( - /// |x, p, _t, y| y[0] = p[0] * x[0], - /// |x, p, _t, v, y| y[0] = p[0] * v[0], - /// |x, p, _t, v, y| y[0] = v[0] * x[0], - /// |p, _t| DVector::from_element(1, 0.1), - /// |p, t, v, y| y.fill(0.0), - /// ); - /// ``` - #[allow(clippy::type_complexity)] - pub fn build_ode_with_sens( - self, - rhs: F, - rhs_jac: G, - rhs_sens: J, - init: I, - init_sens: K, - ) -> Result< - OdeSolverProblem< - OdeSolverEquations, ConstantClosureWithSens>, - >, - DiffsolError, - > - where - M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V), - G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - I: Fn(&M::V, M::T) -> M::V, - J: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - K: Fn(&M::V, M::T, &M::V, &mut M::V), - { - let p = Rc::new(Self::build_p(self.p)); - let t0 = M::T::from(self.t0); - let y0 = init(&p, t0); - let nstates = y0.len(); - let init = ConstantClosureWithSens::new(init, init_sens, nstates, nstates, p.clone()); - let mut rhs = ClosureWithSens::new(rhs, rhs_jac, rhs_sens, nstates, nstates, p.clone()); - if self.use_coloring || M::is_sparse() { - rhs.calculate_jacobian_sparsity(&y0, t0); - rhs.calculate_sens_sparsity(&y0, t0); + if let Some(ref mut out) = out { + out.set_nstates(nstates); + out.set_nparams(nparams); } - let nparams = p.len(); - let eqn = OdeSolverEquations::new(rhs, None, None, init, None, p); - let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( - self.atol, - self.sens_atol, - self.out_atol, - self.param_atol, - nstates, - None, - nparams, - )?; - OdeSolverProblem::new( - Rc::new(eqn), - M::T::from(self.rtol), - atol, - self.sens_rtol.map(M::T::from), - sens_atol, - self.out_rtol.map(M::T::from), - out_atol, - self.param_rtol.map(M::T::from), - param_atol, - M::T::from(self.t0), - M::T::from(self.h0), - self.integrate_out, - ) - } - /// Build an ODE problem with an event. - /// - /// # Arguments - /// - /// - `rhs`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the right-hand side of the ODE. - /// - `rhs_jac`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the Jacobian of the right-hand side with the vector v. - /// - `init`: Function of type Fn(p: &V, t: S) -> V that computes the initial state. - /// - `root`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the root function. - /// - `nroots`: Number of roots (i.e. number of elements in the `y` arg in `root`), an event is triggered when any of the roots changes sign. - /// - /// # Generic Arguments - /// - /// - `M`: Type that implements the `Matrix` trait. Often this must be provided explicitly (i.e. `type M = DMatrix; builder.build_ode::`). - /// - /// # Example - /// - /// - /// - /// ``` - /// use diffsol::OdeBuilder; - /// use nalgebra::DVector; - /// type M = nalgebra::DMatrix; - /// - /// - /// // dy/dt = y - /// // y(0) = 0.1 - /// // event at y = 0.5 - /// let problem = OdeBuilder::new() - /// .build_ode_with_root::( - /// |x, _p, _t, y| y[0] = x[0], - /// |x, _p, _t, v , y| y[0] = v[0], - /// |p, _t| DVector::from_element(1, 0.1), - /// |x, _p, _t, y| y[0] = x[0] - 0.5, - /// 1, - /// ); - /// ``` - #[allow(clippy::type_complexity)] - pub fn build_ode_with_root( - self, - rhs: F, - rhs_jac: G, - init: I, - root: H, - nroots: usize, - ) -> Result< - OdeSolverProblem< - OdeSolverEquations< - M, - Closure, - ConstantClosure, - UnitCallable, - ClosureNoJac, - >, - >, - DiffsolError, - > - where - M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V), - G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - H: Fn(&M::V, &M::V, M::T, &mut M::V), - I: Fn(&M::V, M::T) -> M::V, - { - let p = Rc::new(Self::build_p(self.p)); - let t0 = M::T::from(self.t0); - let y0 = init(&p, t0); - let nstates = y0.len(); - let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone()); - let root = ClosureNoJac::new(root, nstates, nroots, p.clone()); - let init = ConstantClosure::new(init, p.clone()); if self.use_coloring || M::is_sparse() { - rhs.calculate_sparsity(&y0, t0); + rhs.calculate_sparsity(&y0, self.t0, &p); + if let Some(ref mut mass) = mass { + mass.calculate_sparsity(&y0, self.t0, &p); + } } - let nparams = p.len(); - let eqn = OdeSolverEquations::new(rhs, None, Some(root), init, None, p); + let nout = out.as_ref().map(|out| out.nout()); + let eqn = OdeSolverEquations::new(rhs, init, mass, root, out, p); + let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( self.atol, self.sens_atol, self.out_atol, self.param_atol, nstates, - None, + nout, nparams, )?; OdeSolverProblem::new( - Rc::new(eqn), - M::T::from(self.rtol), + eqn, + self.rtol, atol, - self.sens_rtol.map(M::T::from), + self.sens_rtol, sens_atol, - self.out_rtol.map(M::T::from), + self.out_rtol, out_atol, - self.param_rtol.map(M::T::from), + self.param_rtol, param_atol, - M::T::from(self.t0), - M::T::from(self.h0), + self.t0, + self.h0, self.integrate_out, ) } - /// Build an ODE problem using the default dense matrix (see [Self::build_ode]). - #[allow(clippy::type_complexity)] - pub fn build_ode_dense( - self, - rhs: F, - rhs_jac: G, - init: I, - ) -> Result< - OdeSolverProblem, ConstantClosure>>, - DiffsolError, - > - where - V: Vector + DefaultDenseMatrix, - F: Fn(&V, &V, V::T, &mut V), - G: Fn(&V, &V, V::T, &V, &mut V), - I: Fn(&V, V::T) -> V, - { - self.build_ode(rhs, rhs_jac, init) - } - /// Build an ODE problem from a set of equations pub fn build_from_eqn(self, mut eqn: Eqn) -> Result, DiffsolError> where - Eqn: OdeEquations, + Eqn: OdeEquations, { let nparams = eqn.rhs().nparams(); let nstates = eqn.rhs().nstates(); @@ -748,20 +850,20 @@ impl OdeBuilder { )); } - let p = Rc::new(Self::build_p(self.p)); - eqn.set_params(p); + let p = Self::build_p(self.p); + eqn.set_params(&p); OdeSolverProblem::new( - Rc::new(eqn), - Eqn::T::from(self.rtol), + eqn, + self.rtol, atol, - self.sens_rtol.map(Eqn::T::from), + self.sens_rtol, sens_atol, - self.out_rtol.map(Eqn::T::from), + self.out_rtol, out_atol, - self.param_rtol.map(Eqn::T::from), + self.param_rtol, param_atol, - Eqn::T::from(self.t0), - Eqn::T::from(self.h0), + self.t0, + self.h0, self.integrate_out, ) } diff --git a/src/ode_solver/checkpointing.rs b/src/ode_solver/checkpointing.rs index 1c3cc467..40b083c5 100644 --- a/src/ode_solver/checkpointing.rs +++ b/src/ode_solver/checkpointing.rs @@ -6,6 +6,7 @@ use crate::{ }; use num_traits::One; +#[derive(Clone)] pub struct HermiteInterpolator where V: Vector, @@ -35,16 +36,15 @@ where pub fn new(ys: Vec, ydots: Vec, ts: Vec) -> Self { HermiteInterpolator { ys, ydots, ts } } - pub fn reset( + pub fn reset<'a, Eqn, Method, State>( &mut self, - problem: &OdeSolverProblem, solver: &mut Method, state0: &State, state1: &State, ) -> Result<(), DiffsolError> where - Eqn: OdeEquations, - Method: OdeSolverMethod, + Eqn: OdeEquations + 'a, + Method: OdeSolverMethod<'a, Eqn, State = State>, State: OdeSolverState, { let state0_ref = state0.as_ref(); @@ -56,12 +56,12 @@ where self.ydots.push(state0_ref.dy.clone()); self.ts.push(state0_ref.t); - solver.set_problem(state0.clone(), problem)?; - while solver.state().unwrap().t < state1_ref.t { + solver.set_state(state0.clone()); + while solver.state().t < state1_ref.t { solver.step()?; - self.ys.push(solver.state().unwrap().y.clone()); - self.ydots.push(solver.state().unwrap().dy.clone()); - self.ts.push(solver.state().unwrap().t); + self.ys.push(solver.state().y.clone()); + self.ydots.push(solver.state().dy.clone()); + self.ts.push(solver.state().t); } Ok(()) } @@ -106,21 +106,35 @@ where } } -pub struct Checkpointing +pub struct Checkpointing<'a, Eqn, Method> where - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, Eqn: OdeEquations, { checkpoints: Vec, segment: RefCell>, previous_segment: RefCell>>, solver: RefCell, - pub(crate) problem: OdeSolverProblem, } -impl Checkpointing +impl<'a, Eqn, Method> Clone for Checkpointing<'a, Eqn, Method> where - Method: OdeSolverMethod, + Method: OdeSolverMethod<'a, Eqn>, + Eqn: OdeEquations, +{ + fn clone(&self) -> Self { + Checkpointing { + checkpoints: self.checkpoints.clone(), + segment: RefCell::new(self.segment.borrow().clone()), + previous_segment: RefCell::new(self.previous_segment.borrow().clone()), + solver: RefCell::new(self.solver.borrow().clone()), + } + } +} + +impl<'a, Eqn, Method> Checkpointing<'a, Eqn, Method> +where + Method: OdeSolverMethod<'a, Eqn>, Eqn: OdeEquations, { pub fn new( @@ -129,21 +143,16 @@ where checkpoints: Vec, segment: Option>, ) -> Self { - if solver.problem().is_none() { - panic!("Solver must have a problem set"); - } if checkpoints.len() < 2 { panic!("Checkpoints must have at least 2 elements"); } if start_idx >= checkpoints.len() - 1 { panic!("start_idx must be less than checkpoints.len() - 1"); } - let problem = solver.problem().unwrap().clone(); let segment = segment.unwrap_or_else(|| { let mut segment = HermiteInterpolator::default(); segment .reset( - &problem, &mut solver, &checkpoints[start_idx], &checkpoints[start_idx + 1], @@ -159,10 +168,13 @@ where segment, previous_segment, solver, - problem, } } + pub fn problem(&self) -> &'a OdeSolverProblem { + self.solver.borrow().problem() + } + pub fn interpolate(&self, t: Eqn::T, y: &mut Eqn::V) -> Result<(), DiffsolError> { { let segment = self.segment.borrow(); @@ -202,7 +214,6 @@ where let mut previous_segment = self.previous_segment.borrow_mut(); let mut segment = self.segment.borrow_mut(); previous_segment.as_mut().unwrap().reset( - &self.problem, &mut *solver, &self.checkpoints[idx], &self.checkpoints[idx + 1], @@ -218,39 +229,39 @@ mod tests { use nalgebra::{DMatrix, DVector}; use crate::{ - ode_solver::test_models::robertson::robertson, Bdf, BdfState, OdeEquations, - OdeSolverMethod, OdeSolverState, Op, Vector, + ode_solver::test_models::robertson::robertson, NalgebraLU, OdeEquations, OdeSolverMethod, + Op, Vector, }; use super::{Checkpointing, HermiteInterpolator}; #[test] fn test_checkpointing() { - let mut solver = Bdf::default(); - let (problem, soln) = robertson::>(false); + type M = DMatrix; + type LS = NalgebraLU; + let (problem, soln) = robertson::(false); let t_final = soln.solution_points.last().unwrap().t; let n_steps = 30; - let state0: BdfState<_, _> = OdeSolverState::new(&problem, &solver).unwrap(); - solver.set_problem(state0.clone(), &problem).unwrap(); - let mut checkpoints = vec![state0]; + let mut solver = problem.bdf::().unwrap(); + let mut checkpoints = vec![solver.checkpoint()]; let mut i = 0; let mut ys = Vec::new(); let mut ts = Vec::new(); let mut ydots = Vec::new(); - while solver.state().unwrap().t < t_final { - ts.push(solver.state().unwrap().t); - ys.push(solver.state().unwrap().y.clone()); - ydots.push(solver.state().unwrap().dy.clone()); + while solver.state().t < t_final { + ts.push(solver.state().t); + ys.push(solver.state().y.clone()); + ydots.push(solver.state().dy.clone()); solver.step().unwrap(); i += 1; - if i % n_steps == 0 && solver.state().unwrap().t < t_final { - checkpoints.push(solver.checkpoint().unwrap()); + if i % n_steps == 0 && solver.state().t < t_final { + checkpoints.push(solver.checkpoint()); ts.clear(); ys.clear(); ydots.clear(); } } - checkpoints.push(solver.checkpoint().unwrap()); + checkpoints.push(solver.checkpoint()); let segment = HermiteInterpolator::new(ys, ydots, ts); let checkpointer = Checkpointing::new(solver, checkpoints.len() - 2, checkpoints, Some(segment)); diff --git a/src/ode_solver/diffsl.rs b/src/ode_solver/diffsl.rs index 0d42b0e7..1123dd4d 100644 --- a/src/ode_solver/diffsl.rs +++ b/src/ode_solver/diffsl.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use diffsl::{execution::module::CodegenModule, Compiler}; @@ -28,7 +28,7 @@ pub struct DiffSlContext, CG: CodegenModule> { impl, CG: CodegenModule> DiffSlContext { /// Create a new context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/). - /// The input parameters are not initialized and must be set using the [Op::set_params] function before solving the ODE. + /// The input parameters are not initialized and must be set using the [OdeEquations::set_params] function before solving the ODE. pub fn new(text: &str) -> Result { let compiler = Compiler::from_discrete_str(text).map_err(|e| DiffsolError::Other(e.to_string()))?; @@ -343,19 +343,6 @@ impl, CG: CodegenModule> Op for DiffSl { fn nparams(&self) -> usize { self.context.nparams } - fn set_params(&mut self, p: Rc) { - // set the parameters in data - self.context - .compiler - .set_inputs(p.as_slice(), self.context.data.borrow_mut().as_mut_slice()); - - // set_u0 will calculate all the constants in the equations based on the params - let mut dummy = M::V::zeros(self.context.nstates); - self.context.compiler.set_u0( - dummy.as_mut_slice(), - self.context.data.borrow_mut().as_mut_slice(), - ); - } } impl<'a, M: Matrix, CG: CodegenModule> OdeEquationsRef<'a> for DiffSl { @@ -386,18 +373,30 @@ impl, CG: CodegenModule> OdeEquations for DiffSl { fn out(&self) -> Option> { Some(DiffSlOut(self)) } + + fn set_params(&mut self, p: &Self::V) { + // set the parameters in data + self.context + .compiler + .set_inputs(p.as_slice(), self.context.data.borrow_mut().as_mut_slice()); + + // set_u0 will calculate all the constants in the equations based on the params + let mut dummy = M::V::zeros(self.context.nstates); + self.context.compiler.set_u0( + dummy.as_mut_slice(), + self.context.data.borrow_mut().as_mut_slice(), + ); + } } #[cfg(test)] mod tests { - use std::rc::Rc; - use diffsl::{execution::module::CodegenModule, CraneliftModule}; use nalgebra::DVector; use crate::{ - Bdf, ConstantOp, LinearOp, NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations, - OdeSolverMethod, OdeSolverState, Op, Vector, + ConstantOp, LinearOp, NalgebraLU, NonLinearOp, NonLinearOpJacobian, OdeBuilder, + OdeEquations, OdeSolverMethod, Vector, }; use super::{DiffSl, DiffSlContext}; @@ -445,7 +444,7 @@ mod tests { let context = DiffSlContext::, CG>::new(text).unwrap(); let p = DVector::from_vec(vec![r, k]); let mut eqn = DiffSl::from_context(context); - eqn.set_params(Rc::new(p)); + eqn.set_params(&p); // test that the initial values look ok let y0 = 0.1; @@ -466,11 +465,13 @@ mod tests { mass_y.assert_eq_st(&mass_y_expect, 1e-10); // solver a bit and check the state and output - let problem = OdeBuilder::new().p([r, k]).build_from_eqn(eqn).unwrap(); - let mut solver = Bdf::default(); + let problem = OdeBuilder::>::new() + .p([r, k]) + .build_from_eqn(eqn) + .unwrap(); + let mut solver = problem.bdf::>().unwrap(); let t = 1.0; - let state = OdeSolverState::new(&problem, &solver).unwrap(); - let (ys, ts) = solver.solve(&problem, state, t).unwrap(); + let (ys, ts) = solver.solve(t).unwrap(); for (i, t) in ts.iter().enumerate() { let y_expect = k / (1.0 + (k - y0) * (-r * t).exp() / y0); let z_expect = 2.0 * y_expect; @@ -480,8 +481,8 @@ mod tests { // do it again with some explicit t_evals let t_evals = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0]; - let state = OdeSolverState::new(&problem, &solver).unwrap(); - let ys = solver.solve_dense(&problem, state, &t_evals).unwrap(); + let mut solver = problem.bdf::>().unwrap(); + let ys = solver.solve_dense(&t_evals).unwrap(); for (i, t) in t_evals.iter().enumerate() { let y_expect = k / (1.0 + (k - y0) * (-r * t).exp() / y0); let z_expect = 2.0 * y_expect; diff --git a/src/ode_solver/equations.rs b/src/ode_solver/equations.rs index 53b3d4cd..55f85da9 100644 --- a/src/ode_solver/equations.rs +++ b/src/ode_solver/equations.rs @@ -1,9 +1,7 @@ -use std::rc::Rc; - use crate::{ - op::{constant_op::ConstantOpSensAdjoint, linear_op::LinearOpTranspose}, + op::{constant_op::ConstantOpSensAdjoint, linear_op::LinearOpTranspose, ParameterisedOp}, ConstantOp, ConstantOpSens, LinearOp, Matrix, NonLinearOp, NonLinearOpAdjoint, - NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Op, UnitCallable, + NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Op, Vector, }; use serde::Serialize; @@ -35,7 +33,7 @@ impl OdeEquationsStatistics { } pub trait AugmentedOdeEquations: - OdeEquations + OdeEquations + Clone { fn update_rhs_out_state(&mut self, y: &Eqn::V, dy: &Eqn::V, t: Eqn::T); fn update_init_state(&mut self, t: Eqn::T); @@ -44,9 +42,10 @@ pub trait AugmentedOdeEquations: fn include_in_error_control(&self) -> bool; fn include_out_in_error_control(&self) -> bool; fn rtol(&self) -> Option; - fn atol(&self) -> Option<&Rc>; + fn atol(&self) -> Option<&Eqn::V>; fn out_rtol(&self) -> Option; - fn out_atol(&self) -> Option<&Rc>; + fn out_atol(&self) -> Option<&Eqn::V>; + fn integrate_main_eqn(&self) -> bool; } pub trait AugmentedOdeEquationsImplicit: @@ -65,6 +64,14 @@ pub struct NoAug { _phantom: std::marker::PhantomData, } +impl Clone for NoAug { + fn clone(&self) -> Self { + Self { + _phantom: std::marker::PhantomData, + } + } +} + impl Op for NoAug where Eqn: OdeEquations, @@ -85,10 +92,6 @@ where fn statistics(&self) -> crate::op::OpStatistics { panic!("This should never be called") } - - fn set_params(&mut self, _p: Rc) { - panic!("This should never be called") - } } impl<'a, Eqn: OdeEquations> OdeEquationsRef<'a> for NoAug { @@ -119,6 +122,10 @@ impl OdeEquations for NoAug { fn init(&self) -> >::Init { panic!("This should never be called") } + + fn set_params(&mut self, _p: &Self::V) { + panic!("This should never be called") + } } impl AugmentedOdeEquations for NoAug { @@ -131,13 +138,13 @@ impl AugmentedOdeEquations for NoAug { fn set_index(&mut self, _index: usize) { panic!("This should never be called") } - fn atol(&self) -> Option<&Rc<::V>> { + fn atol(&self) -> Option<&::V> { panic!("This should never be called") } fn include_out_in_error_control(&self) -> bool { panic!("This should never be called") } - fn out_atol(&self) -> Option<&Rc<::V>> { + fn out_atol(&self) -> Option<&::V> { panic!("This should never be called") } fn out_rtol(&self) -> Option<::T> { @@ -152,6 +159,9 @@ impl AugmentedOdeEquations for NoAug { fn include_in_error_control(&self) -> bool { panic!("This should never be called") } + fn integrate_main_eqn(&self) -> bool { + panic!("This should never be called") + } } /// this is the reference trait that defines the ODE equations of the form, this is used to define the ODE equations for a given lifetime. @@ -178,6 +188,14 @@ pub trait OdeEquationsRef<'a, ImplicitBounds: Sealed = Bounds<&'a Self>>: Op { type Out: NonLinearOp; } +impl<'a, T: OdeEquationsRef<'a>> OdeEquationsRef<'a> for &T { + type Mass = >::Mass; + type Rhs = >::Rhs; + type Root = >::Root; + type Init = >::Init; + type Out = >::Out; +} + // seal the trait so that users must use the provided default type for ImplicitBounds mod sealed { pub trait Sealed: Sized {} @@ -220,6 +238,69 @@ pub trait OdeEquations: for<'a> OdeEquationsRef<'a> { /// returns the initial condition, i.e. `y(t)`, where `t` is the initial time fn init(&self) -> >::Init; + + /// sets the current parameters of the equations + fn set_params(&mut self, p: &Self::V); +} + +//impl<'a, T: OdeEquations> OdeEquationsRef<'a> for &'a mut T { +// type Mass = >::Mass; +// type Rhs = >::Rhs; +// type Root = >::Root; +// type Init = >::Init; +// type Out = >::Out; +//} +// +//impl OdeEquations for &'_ mut T { +// fn rhs(&self) -> >::Rhs { +// (*self).rhs() +// } +// +// fn mass(&self) -> Option<>::Mass> { +// (*self).mass() +// } +// +// fn root(&self) -> Option<>::Root> { +// (*self).root() +// } +// +// fn out(&self) -> Option<>::Out> { +// (*self).out() +// } +// +// fn init(&self) -> >::Init { +// (*self).init() +// } +// +// fn set_params(&mut self, p: &Self::V) { +// (*self).set_params(p) +// } +//} + +impl OdeEquations for &'_ T { + fn rhs(&self) -> >::Rhs { + (*self).rhs() + } + + fn mass(&self) -> Option<>::Mass> { + (*self).mass() + } + + fn root(&self) -> Option<>::Root> { + (*self).root() + } + + fn out(&self) -> Option<>::Out> { + (*self).out() + } + + fn init(&self) -> >::Init { + (*self).init() + } + + fn set_params(&mut self, _p: &Self::V) { + unimplemented!() + } } pub trait OdeEquationsImplicit: @@ -282,76 +363,8 @@ impl OdeEquationsAdjoint for T where /// which define a nonlinear operator or function `F` that maps an input vector `x` to an output vector `y`, (i.e. `y = F(x)`). /// Once you have implemented this trait, you can then pass an instance of your struct to the `rhs` argument of the [Self::new] method. /// Once you have created an instance of [OdeSolverEquations], you can then use [crate::OdeBuilder::build_from_eqn] to create a problem. -/// -/// For example: -/// -/// ```rust -/// use std::rc::Rc; -/// use diffsol::{Bdf, OdeSolverState, OdeSolverMethod, NonLinearOp, NonLinearOpJacobian, OdeSolverEquations, OdeSolverProblem, Op, UnitCallable, ConstantClosure, OdeBuilder}; -/// type M = nalgebra::DMatrix; -/// type V = nalgebra::DVector; -/// -/// struct MyProblem; -/// impl Op for MyProblem { -/// type V = V; -/// type T = f64; -/// type M = M; -/// fn nstates(&self) -> usize { -/// 1 -/// } -/// fn nout(&self) -> usize { -/// 1 -/// } -/// } -/// -/// // implement rhs equations for the problem -/// impl NonLinearOp for MyProblem { -/// fn call_inplace(&self, x: &V, _t: f64, y: &mut V) { -/// y[0] = -0.1 * x[0]; -/// } -/// } -/// impl NonLinearOpJacobian for MyProblem { -/// fn jac_mul_inplace(&self, x: &V, _t: f64, v: &V, y: &mut V) { -/// y[0] = -0.1 * v[0]; -/// } -/// } -/// -/// -/// let rhs = MyProblem; -/// -/// // use the provided constant closure to define the initial condition -/// let init_fn = |p: &V, _t: f64| V::from_vec(vec![1.0]); -/// let init = ConstantClosure::new(init_fn, Rc::new(V::from_vec(vec![]))); -/// -/// // we don't have a mass matrix, root or output functions, so we can set to None -/// // we still need to give a placeholder type for these, so we use the diffsol::UnitCallable type -/// let mass: Option> = None; -/// let root: Option> = None; -/// let out: Option> = None; -/// -/// let p = Rc::new(V::from_vec(vec![])); -/// let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p); -/// -/// let problem = OdeBuilder::new().build_from_eqn(eqn).unwrap(); -/// -/// let mut solver = Bdf::default(); -/// let t = 0.4; -/// let state = OdeSolverState::new(&problem, &solver).unwrap(); -/// solver.set_problem(state, &problem); -/// while solver.state().unwrap().t <= t { -/// solver.step().unwrap(); -/// } -/// let y = solver.interpolate(t); -/// ``` -/// -pub struct OdeSolverEquations< - M, - Rhs, - Init, - Mass = UnitCallable, - Root = UnitCallable, - Out = UnitCallable, -> where +pub struct OdeSolverEquations +where M: Matrix, { rhs: Rhs, @@ -359,7 +372,7 @@ pub struct OdeSolverEquations< root: Option, init: Init, out: Option, - p: Rc, + p: M::V, } impl OdeSolverEquations @@ -369,11 +382,11 @@ where #[allow(clippy::too_many_arguments)] pub fn new( rhs: Rhs, + init: Init, mass: Option, root: Option, - init: Init, out: Option, - p: Rc, + p: M::V, ) -> Self { Self { rhs, @@ -384,6 +397,12 @@ where p, } } + fn params_mut(&mut self) -> &mut M::V { + &mut self.p + } + fn params(&self) -> &M::V { + &self.p + } } impl Op for OdeSolverEquations @@ -410,65 +429,68 @@ where fn statistics(&self) -> crate::op::OpStatistics { self.rhs.statistics() } - fn set_params(&mut self, p: Rc) { - self.rhs.set_params(p.clone()); - self.init.set_params(p.clone()); - if let Some(mass) = self.mass.as_mut() { - mass.set_params(p.clone()); - } - if let Some(root) = self.root.as_mut() { - root.set_params(p.clone()); - } - - if let Some(out) = self.out.as_mut() { - out.set_params(p.clone()); - } - self.p = p; - } } impl<'a, M, Rhs, Init, Mass, Root, Out> OdeEquationsRef<'a> for OdeSolverEquations where M: Matrix, - Rhs: NonLinearOp, - Mass: LinearOp, - Root: NonLinearOp, - Init: ConstantOp, - Out: NonLinearOp, + Rhs: Op, + Init: Op, + Mass: Op, + Root: Op, + Out: Op, + ParameterisedOp<'a, Rhs>: NonLinearOp, + ParameterisedOp<'a, Init>: ConstantOp, + ParameterisedOp<'a, Mass>: LinearOp, + ParameterisedOp<'a, Root>: NonLinearOp, + ParameterisedOp<'a, Out>: NonLinearOp, { - type Rhs = &'a Rhs; - type Mass = &'a Mass; - type Root = &'a Root; - type Init = &'a Init; - type Out = &'a Out; + type Rhs = ParameterisedOp<'a, Rhs>; + type Mass = ParameterisedOp<'a, Mass>; + type Root = ParameterisedOp<'a, Root>; + type Init = ParameterisedOp<'a, Init>; + type Out = ParameterisedOp<'a, Out>; } impl OdeEquations for OdeSolverEquations where M: Matrix, - Rhs: NonLinearOp, - Mass: LinearOp, - Root: NonLinearOp, - Init: ConstantOp, - Out: NonLinearOp, + Rhs: Op, + Init: Op, + Mass: Op, + Root: Op, + Out: Op, + for<'a> ParameterisedOp<'a, Rhs>: NonLinearOp, + for<'a> ParameterisedOp<'a, Init>: ConstantOp, + for<'a> ParameterisedOp<'a, Mass>: LinearOp, + for<'a> ParameterisedOp<'a, Root>: NonLinearOp, + for<'a> ParameterisedOp<'a, Out>: NonLinearOp, { - fn rhs(&self) -> &Rhs { - &self.rhs + fn rhs(&self) -> ParameterisedOp<'_, Rhs> { + ParameterisedOp::new(&self.rhs, self.params()) } - fn mass(&self) -> Option<&Mass> { - self.mass.as_ref() + fn mass(&self) -> Option> { + self.mass + .as_ref() + .map(|mass| ParameterisedOp::new(mass, self.params())) } - fn root(&self) -> Option<&Root> { - self.root.as_ref() + fn root(&self) -> Option> { + self.root + .as_ref() + .map(|root| ParameterisedOp::new(root, self.params())) } - fn init(&self) -> &Init { - &self.init + fn init(&self) -> ParameterisedOp<'_, Init> { + ParameterisedOp::new(&self.init, self.params()) } - - fn out(&self) -> Option<&Out> { - self.out.as_ref() + fn out(&self) -> Option> { + self.out + .as_ref() + .map(|out| ParameterisedOp::new(out, self.params())) + } + fn set_params(&mut self, p: &Self::V) { + self.params_mut().copy_from(p); } } diff --git a/src/ode_solver/jacobian_update.rs b/src/ode_solver/jacobian_update.rs index 24e8c0cd..beeed30a 100644 --- a/src/ode_solver/jacobian_update.rs +++ b/src/ode_solver/jacobian_update.rs @@ -8,6 +8,7 @@ pub enum SolverState { Checkpoint, } +#[derive(Clone)] pub struct JacobianUpdate { steps_since_jacobian_eval: usize, steps_since_rhs_jacobian_eval: usize, diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index 51176845..0e44ad55 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -6,14 +6,12 @@ use crate::{ matrix::default_solver::DefaultSolver, ode_solver_error, scalar::Scalar, - AdjointContext, AdjointEquations, Checkpointing, DefaultDenseMatrix, DenseMatrix, Matrix, - NewtonNonlinearSolver, NonLinearOp, OdeEquations, OdeEquationsAdjoint, OdeEquationsSens, - OdeSolverProblem, OdeSolverState, Op, SensEquations, StateRef, StateRefMut, Vector, - VectorViewMut, + AdjointContext, AdjointEquations, AugmentedOdeEquations, Checkpointing, DefaultDenseMatrix, + DenseMatrix, HermiteInterpolator, LinearSolver, Matrix, NonLinearOp, OdeEquations, + OdeEquationsAdjoint, OdeEquationsSens, OdeSolverProblem, OdeSolverState, Op, SensEquations, + StateRef, StateRefMut, Vector, VectorViewMut, }; -use super::checkpointing::HermiteInterpolator; - #[derive(Debug, PartialEq)] pub enum OdeSolverStopReason { InternalTimestep, @@ -32,54 +30,47 @@ pub enum OdeSolverStopReason { /// ``` /// use diffsol::{ OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeEquationsImplicit, DefaultSolver }; /// -/// fn solve_ode(solver: &mut impl OdeSolverMethod, problem: &OdeSolverProblem, t: Eqn::T) -> Eqn::V +/// fn solve_ode<'a, Eqn>(solver: &mut impl OdeSolverMethod<'a, Eqn>, t: Eqn::T) -> Eqn::V /// where -/// Eqn: OdeEquationsImplicit, +/// Eqn: OdeEquationsImplicit + 'a, /// Eqn::M: DefaultSolver, /// { -/// let state = OdeSolverState::new(problem, solver).unwrap(); -/// solver.set_problem(state, problem); -/// while solver.state().unwrap().t <= t { +/// while solver.state().t <= t { /// solver.step().unwrap(); /// } /// solver.interpolate(t).unwrap() /// } /// ``` -pub trait OdeSolverMethod +pub trait OdeSolverMethod<'a, Eqn: OdeEquations>: Clone where Self: Sized, + Eqn: 'a, { type State: OdeSolverState; - /// Get the current problem if it has been set - fn problem(&self) -> Option<&OdeSolverProblem>; - - /// Set the problem to solve, this performs any initialisation required by the solver. Call this before calling `step` or `solve`. - /// The solver takes ownership of the initial state given by `state`, this is assumed to be consistent with any algebraic constraints, - /// and the time step `h` is assumed to be set appropriately for the problem - fn set_problem( - &mut self, - state: Self::State, - problem: &OdeSolverProblem, - ) -> Result<(), DiffsolError>; + /// Get the current problem + fn problem(&self) -> &'a OdeSolverProblem; /// Take a checkpoint of the current state of the solver, returning it to the user. This is useful if you want to use this /// state in another solver or problem but want to keep this solver active. If you don't need to use this solver again, you can use `take_state` instead. /// Note that this will force a reinitialisation of the internal Jacobian for the solver, if it has one. - fn checkpoint(&mut self) -> Result; + fn checkpoint(&mut self) -> Self::State; + + /// Replace the current state of the solver with a new state. + fn set_state(&mut self, state: Self::State); /// Take the current state of the solver, if it exists, returning it to the user. This is useful if you want to use this /// state in another solver or problem. Note that this will unset the current problem and solver state, so you will need to call /// `set_problem` again before calling `step` or `solve`. - fn take_state(&mut self) -> Option; + fn into_state(self) -> Self::State; - /// Get the current state of the solver, if it exists - fn state(&self) -> Option>; + /// Get the current state of the solver + fn state(&self) -> StateRef; - /// Get a mutable reference to the current state of the solver, if it exists + /// Get a mutable reference to the current state of the solver /// Note that calling this will cause the next call to `step` to perform some reinitialisation to take into /// account the mutated state, this could be expensive for multi-step methods. - fn state_mut(&mut self) -> Option>; + fn state_mut(&mut self) -> StateRefMut; /// Step the solution forward by one step, altering the internal state of the solver. /// The return value is a `Result` containing the reason for stopping the solver, possible reasons are: @@ -110,23 +101,26 @@ where #[allow(clippy::type_complexity)] fn solve( &mut self, - problem: &OdeSolverProblem, - state: Self::State, final_time: Eqn::T, ) -> Result<(::M, Vec), DiffsolError> where - Eqn::M: DefaultSolver, Eqn::V: DefaultDenseMatrix, Self: Sized, { - self.set_problem(state, problem)?; let mut ret_t = Vec::new(); let mut ret_y = Vec::new(); - let mut write_out = |t: Eqn::T, y: &Eqn::V, g: &Eqn::V| { + fn write_out( + p: &OdeSolverProblem, + ret_y: &mut Vec, + ret_t: &mut Vec, + t: Eqn::T, + y: &Eqn::V, + g: &Eqn::V, + ) { ret_t.push(t); - match problem.eqn.out() { + match p.eqn.out() { Some(out) => { - if problem.integrate_out { + if p.integrate_out { ret_y.push(g.clone()); } else { ret_y.push(out.call(y, t)); @@ -134,28 +128,37 @@ where } None => ret_y.push(y.clone()), } - }; + } // do the main loop write_out( - self.state().unwrap().t, - self.state().unwrap().y, - self.state().unwrap().g, + self.problem(), + &mut ret_y, + &mut ret_t, + self.state().t, + self.state().y, + self.state().g, ); self.set_stop_time(final_time)?; while self.step()? != OdeSolverStopReason::TstopReached { write_out( - self.state().unwrap().t, - self.state().unwrap().y, - self.state().unwrap().g, + self.problem(), + &mut ret_y, + &mut ret_t, + self.state().t, + self.state().y, + self.state().g, ); } // store the final step write_out( - self.state().unwrap().t, - self.state().unwrap().y, - self.state().unwrap().g, + self.problem(), + &mut ret_y, + &mut ret_t, + self.state().t, + self.state().y, + self.state().g, ); let ntimes = ret_t.len(); let nrows = ret_y[0].len(); @@ -171,54 +174,60 @@ where /// After the solver has finished, the internal state of the solver is at time `t_eval[t_eval.len()-1]`. fn solve_dense( &mut self, - problem: &OdeSolverProblem, - state: Self::State, t_eval: &[Eqn::T], ) -> Result<::M, DiffsolError> where - Eqn::M: DefaultSolver, Eqn::V: DefaultDenseMatrix, Self: Sized, { - self.set_problem(state, problem)?; - let nrows = if problem.eqn.out().is_some() { - problem.eqn.out().unwrap().nout() + let nrows = if self.problem().eqn.out().is_some() { + self.problem().eqn.out().unwrap().nout() } else { - problem.eqn.rhs().nstates() + self.problem().eqn.rhs().nstates() }; let mut ret = <::M as Matrix>::zeros(nrows, t_eval.len()); // check t_eval is increasing and all values are greater than or equal to the current time - let t0 = self.state().unwrap().t; + let t0 = self.state().t; if t_eval.windows(2).any(|w| w[0] > w[1] || w[0] < t0) { return Err(ode_solver_error!(InvalidTEval)); } - let mut write_out = |i: usize, y: Option<&Eqn::V>, g: Option<&Eqn::V>| { - let mut y_out = ret.column_mut(i); + fn write_out( + p: &OdeSolverProblem, + y_out: &mut ::M, + t_eval: &[Eqn::T], + i: usize, + y: Option<&Eqn::V>, + g: Option<&Eqn::V>, + ) where + Eqn::V: DefaultDenseMatrix, + { + let mut y_out = y_out.column_mut(i); if let Some(g) = g { y_out.copy_from(g); } else if let Some(y) = y { - match problem.eqn.out() { + match p.eqn.out() { Some(out) => y_out.copy_from(&out.call(y, t_eval[i])), None => y_out.copy_from(y), } } - }; + } // do loop self.set_stop_time(t_eval[t_eval.len() - 1])?; let mut step_reason = OdeSolverStopReason::InternalTimestep; for (i, t) in t_eval.iter().take(t_eval.len() - 1).enumerate() { - while self.state().unwrap().t < *t { + while self.state().t < *t { step_reason = self.step()?; } - if problem.integrate_out { + if self.problem().integrate_out { let g = self.interpolate_out(*t)?; - write_out(i, None, Some(&g)); + + write_out(self.problem(), &mut ret, t_eval, i, None, Some(&g)); } else { let y = self.interpolate(*t)?; - write_out(i, Some(&y), None); + write_out(self.problem(), &mut ret, t_eval, i, Some(&y), None); } } @@ -226,10 +235,24 @@ where while step_reason != OdeSolverStopReason::TstopReached { step_reason = self.step()?; } - if problem.integrate_out { - write_out(t_eval.len() - 1, None, Some(self.state().unwrap().g)); + if self.problem().integrate_out { + write_out( + self.problem(), + &mut ret, + t_eval, + t_eval.len() - 1, + None, + Some(self.state().g), + ); } else { - write_out(t_eval.len() - 1, Some(self.state().unwrap().y), None); + write_out( + self.problem(), + &mut ret, + t_eval, + t_eval.len() - 1, + Some(self.state().y), + None, + ); } Ok(ret) } @@ -241,79 +264,75 @@ where /// the ith element is the sensitivities of the ith element of `g` with respect to the /// parameters. #[allow(clippy::type_complexity)] - fn solve_adjoint( + fn solve_adjoint>( mut self, - problem: &OdeSolverProblem, - state: Self::State, final_time: Eqn::T, max_steps_between_checkpoints: Option, ) -> Result<(Eqn::V, Vec), DiffsolError> where - Self: AdjointOdeSolverMethod, + Self: AdjointOdeSolverMethod<'a, Eqn>, Eqn: OdeEquationsAdjoint, Eqn::M: DefaultSolver, Eqn::V: DefaultDenseMatrix, Self: Sized, { - if problem.eqn.out().is_none() { + if self.problem().eqn.out().is_none() { return Err(ode_solver_error!( Other, "Cannot solve adjoint without output function" )); } - if !problem.integrate_out { + if !self.problem().integrate_out { return Err(ode_solver_error!( Other, "Cannot solve adjoint without integrating out" )); } let max_steps_between_checkpoints = max_steps_between_checkpoints.unwrap_or(500); - self.set_problem(state, problem)?; - let t0 = self.state().unwrap().t; + let t0 = self.state().t; let mut ts = vec![t0]; - let mut ys = vec![self.state().unwrap().y.clone()]; - let mut ydots = vec![self.state().unwrap().dy.clone()]; + let mut ys = vec![self.state().y.clone()]; + let mut ydots = vec![self.state().dy.clone()]; // do the main forward solve, saving checkpoints self.set_stop_time(final_time)?; let mut nsteps = 0; - let mut checkpoints = vec![self.checkpoint().unwrap()]; + let mut checkpoints = vec![self.checkpoint()]; while self.step()? != OdeSolverStopReason::TstopReached { - ts.push(self.state().unwrap().t); - ys.push(self.state().unwrap().y.clone()); - ydots.push(self.state().unwrap().dy.clone()); + ts.push(self.state().t); + ys.push(self.state().y.clone()); + ydots.push(self.state().dy.clone()); nsteps += 1; if nsteps > max_steps_between_checkpoints { - checkpoints.push(self.checkpoint().unwrap()); + checkpoints.push(self.checkpoint()); nsteps = 0; ts.clear(); ys.clear(); ydots.clear(); } } - ts.push(self.state().unwrap().t); - ys.push(self.state().unwrap().y.clone()); - ydots.push(self.state().unwrap().dy.clone()); - checkpoints.push(self.checkpoint().unwrap()); + ts.push(self.state().t); + ys.push(self.state().y.clone()); + ydots.push(self.state().dy.clone()); + checkpoints.push(self.checkpoint()); // save integrateed out function - let g = self.state().unwrap().g.clone(); + let g = self.state().g.clone(); // construct the adjoint solver let last_segment = HermiteInterpolator::new(ys, ydots, ts); - let mut adjoint_solver = self.into_adjoint_solver(checkpoints, last_segment)?; + let adjoint_aug_eqn = self.adjoint_equations(checkpoints, last_segment)?; + let mut adjoint_solver = self.default_adjoint_solver::(adjoint_aug_eqn)?; // solve the adjoint problem adjoint_solver.set_stop_time(t0).unwrap(); while adjoint_solver.step()? != OdeSolverStopReason::TstopReached {} // correct the adjoint solution for the initial conditions - let adjoint_problem = adjoint_solver.problem().unwrap().clone(); - let mut state = adjoint_solver.take_state().unwrap(); + let (mut state, aug_eqn) = adjoint_solver.into_state_and_eqn(); + let aug_eqn = aug_eqn.unwrap(); let state_mut = state.as_mut(); - adjoint_problem - .eqn - .correct_sg_for_init(t0, state_mut.s, state_mut.sg); + aug_eqn.correct_sg_for_init(t0, state_mut.s, state_mut.sg); // return the solution Ok((g, state_mut.sg.to_owned())) @@ -326,8 +345,6 @@ where #[allow(clippy::type_complexity)] fn solve_dense_sensitivities( &mut self, - problem: &OdeSolverProblem, - state: Self::State, t_eval: &[Eqn::T], ) -> Result< ( @@ -337,29 +354,28 @@ where DiffsolError, > where - Self: SensitivitiesOdeSolverMethod, + Self: SensitivitiesOdeSolverMethod<'a, Eqn>, Eqn: OdeEquationsSens, Eqn::M: DefaultSolver, Eqn::V: DefaultDenseMatrix, Self: Sized, { - if problem.integrate_out { + if self.problem().integrate_out { return Err(ode_solver_error!( Other, "Cannot integrate out when solving for sensitivities" )); } - self.set_problem_with_sensitivities(state, problem)?; - let nrows = problem.eqn.rhs().nstates(); + let nrows = self.problem().eqn.rhs().nstates(); let mut ret = <::M as Matrix>::zeros(nrows, t_eval.len()); let mut ret_sens = vec![ <::M as Matrix>::zeros(nrows, t_eval.len()); - problem.eqn.rhs().nparams() + self.problem().eqn.rhs().nparams() ]; // check t_eval is increasing and all values are greater than or equal to the current time - let t0 = self.state().unwrap().t; + let t0 = self.state().t; if t_eval.windows(2).any(|w| w[0] > w[1] || w[0] < t0) { return Err(ode_solver_error!(InvalidTEval)); } @@ -368,7 +384,7 @@ where self.set_stop_time(t_eval[t_eval.len() - 1])?; let mut step_reason = OdeSolverStopReason::InternalTimestep; for (i, t) in t_eval.iter().take(t_eval.len() - 1).enumerate() { - while self.state().unwrap().t < *t { + while self.state().t < *t { step_reason = self.step()?; } let y = self.interpolate(*t)?; @@ -383,9 +399,9 @@ where while step_reason != OdeSolverStopReason::TstopReached { step_reason = self.step()?; } - let y = self.state().unwrap().y; + let y = self.state().y; ret.column_mut(t_eval.len() - 1).copy_from(y); - let s = self.state().unwrap().s; + let s = self.state().s; for (j, s_j) in s.iter().enumerate() { ret_sens[j].column_mut(t_eval.len() - 1).copy_from(s_j); } @@ -393,97 +409,63 @@ where } } -pub trait AugmentedOdeSolverMethod: OdeSolverMethod +pub trait AugmentedOdeSolverMethod<'a, Eqn, AugmentedEqn>: OdeSolverMethod<'a, Eqn> where - Eqn: OdeEquations, + Eqn: OdeEquations + 'a, + AugmentedEqn: AugmentedOdeEquations, { - fn set_augmented_problem( - &mut self, - state: Self::State, - ode_problem: &OdeSolverProblem, - augmented_eqn: AugmentedEqn, - ) -> Result<(), DiffsolError>; + fn into_state_and_eqn(self) -> (Self::State, Option); } -pub trait SensitivitiesOdeSolverMethod: - AugmentedOdeSolverMethod> +pub trait SensitivitiesOdeSolverMethod<'a, Eqn>: + AugmentedOdeSolverMethod<'a, Eqn, SensEquations<'a, Eqn>> where - Eqn: OdeEquationsSens, + Eqn: OdeEquationsSens + 'a, { - fn set_problem_with_sensitivities( - &mut self, - state: Self::State, - problem: &OdeSolverProblem, - ) -> Result<(), DiffsolError> { - let augmented_eqn = SensEquations::new(problem); - self.set_augmented_problem(state, problem, augmented_eqn) - } } -pub trait AdjointOdeSolverMethod: OdeSolverMethod +pub trait AdjointOdeSolverMethod<'a, Eqn>: OdeSolverMethod<'a, Eqn> where - Eqn: OdeEquationsAdjoint, + Eqn: OdeEquationsAdjoint + 'a, + Self: 'a, { - type AdjointSolver: AugmentedOdeSolverMethod< - AdjointEquations, - AdjointEquations, + type DefaultAdjointSolver: AugmentedOdeSolverMethod< + 'a, + Eqn, + AdjointEquations<'a, Eqn, Self>, State = Self::State, >; - fn new_adjoint_solver(&self) -> Self::AdjointSolver; - - fn into_adjoint_solver( + fn default_adjoint_solver>( self, + aug_eqn: AdjointEquations<'a, Eqn, Self>, + ) -> Result; + + fn adjoint_equations( + &self, checkpoints: Vec, last_segment: HermiteInterpolator, - ) -> Result + ) -> Result, DiffsolError> where Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, { - // create the adjoint solver - let mut adjoint_solver = self.new_adjoint_solver(); - - let problem = self - .problem() - .ok_or(ode_solver_error!(ProblemNotSet))? - .clone(); - let t = self.state().unwrap().t; - let h = self.state().unwrap().h; + let problem = self.problem(); + let checkpointer_solver = self.clone(); // construct checkpointing - let checkpointer = - Checkpointing::new(self, checkpoints.len() - 2, checkpoints, Some(last_segment)); + let checkpointer = Checkpointing::new( + checkpointer_solver, + checkpoints.len() - 2, + checkpoints, + Some(last_segment), + ); // construct adjoint equations and problem let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer))); - let new_eqn = AdjointEquations::new(&problem, context.clone(), false); - let mut new_augmented_eqn = AdjointEquations::new(&problem, context, true); - let adj_problem = OdeSolverProblem { - eqn: Rc::new(new_eqn), - rtol: problem.rtol, - atol: problem.atol, - t0: t, - h0: -h, - integrate_out: false, - sens_rtol: None, - sens_atol: None, - out_rtol: None, - out_atol: None, - param_rtol: None, - param_atol: None, - }; + let new_augmented_eqn = AdjointEquations::new(problem, context, true); - // initialise adjoint state - let mut state = - Self::State::new_without_initialise_augmented(&adj_problem, &mut new_augmented_eqn)?; - let mut init_nls = - NewtonNonlinearSolver::::LS>::default(); - let new_augmented_eqn = - state.set_consistent_augmented(&adj_problem, new_augmented_eqn, &mut init_nls)?; - - // set the adjoint problem - adjoint_solver.set_augmented_problem(state, &adj_problem, new_augmented_eqn)?; - Ok(adjoint_solver) + Ok(new_augmented_eqn) } } @@ -494,31 +476,30 @@ mod test { exponential_decay_problem, exponential_decay_problem_adjoint, exponential_decay_problem_sens, }, - scale, Bdf, OdeSolverMethod, OdeSolverState, Vector, + scale, NalgebraLU, OdeSolverMethod, Vector, }; #[test] fn test_solve() { - let mut s = Bdf::default(); let (problem, _soln) = exponential_decay_problem::>(false); + let mut s = problem.bdf::>().unwrap(); let k = 0.1; let y0 = nalgebra::DVector::from_vec(vec![1.0, 1.0]); let expect = |t: f64| &y0 * scale(f64::exp(-k * t)); - let state = OdeSolverState::new(&problem, &s).unwrap(); - let (y, t) = s.solve(&problem, state, 10.0).unwrap(); + let (y, t) = s.solve(10.0).unwrap(); assert!((t[0] - 0.0).abs() < 1e-10); assert!((t[t.len() - 1] - 10.0).abs() < 1e-10); for (i, t_i) in t.iter().enumerate() { let y_i = y.column(i).into_owned(); - y_i.assert_eq_norm(&expect(*t_i), problem.atol.as_ref(), problem.rtol, 15.0); + y_i.assert_eq_norm(&expect(*t_i), &problem.atol, problem.rtol, 15.0); } } #[test] fn test_solve_integrate_out() { - let mut s = Bdf::default(); let (problem, _soln) = exponential_decay_problem_adjoint::>(); + let mut s = problem.bdf::>().unwrap(); let k = 0.1; let y0 = nalgebra::DVector::from_vec(vec![1.0, 1.0]); @@ -530,55 +511,49 @@ mod test { 3.0 * g[0] + 4.0 * g[1], ]) }; - let state = OdeSolverState::new(&problem, &s).unwrap(); - let (y, t) = s.solve(&problem, state, 10.0).unwrap(); + let (y, t) = s.solve(10.0).unwrap(); for (i, t_i) in t.iter().enumerate() { let y_i = y.column(i).into_owned(); - y_i.assert_eq_norm(&expect(*t_i), problem.atol.as_ref(), problem.rtol, 15.0); + y_i.assert_eq_norm(&expect(*t_i), &problem.atol, problem.rtol, 15.0); } } #[test] fn test_dense_solve() { - let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem::>(false); + let mut s = problem.bdf::>().unwrap(); - let state = OdeSolverState::new(&problem, &s).unwrap(); let t_eval = soln.solution_points.iter().map(|p| p.t).collect::>(); - let y = s.solve_dense(&problem, state, t_eval.as_slice()).unwrap(); + let y = s.solve_dense(t_eval.as_slice()).unwrap(); for (i, soln_pt) in soln.solution_points.iter().enumerate() { let y_i = y.column(i).into_owned(); - y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0); + y_i.assert_eq_norm(&soln_pt.state, &problem.atol, problem.rtol, 15.0); } } #[test] fn test_dense_solve_integrate_out() { - let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem_adjoint::>(); + let mut s = problem.bdf::>().unwrap(); - let state = OdeSolverState::new(&problem, &s).unwrap(); let t_eval = soln.solution_points.iter().map(|p| p.t).collect::>(); - let y = s.solve_dense(&problem, state, t_eval.as_slice()).unwrap(); + let y = s.solve_dense(t_eval.as_slice()).unwrap(); for (i, soln_pt) in soln.solution_points.iter().enumerate() { let y_i = y.column(i).into_owned(); - y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0); + y_i.assert_eq_norm(&soln_pt.state, &problem.atol, problem.rtol, 15.0); } } #[test] fn test_dense_solve_sensitivities() { - let mut s = Bdf::with_sensitivities(); let (problem, soln) = exponential_decay_problem_sens::>(false); + let mut s = problem.bdf_sens::>().unwrap(); - let state = OdeSolverState::new_with_sensitivities(&problem, &s).unwrap(); let t_eval = soln.solution_points.iter().map(|p| p.t).collect::>(); - let (y, sens) = s - .solve_dense_sensitivities(&problem, state, t_eval.as_slice()) - .unwrap(); + let (y, sens) = s.solve_dense_sensitivities(t_eval.as_slice()).unwrap(); for (i, soln_pt) in soln.solution_points.iter().enumerate() { let y_i = y.column(i).into_owned(); - y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0); + y_i.assert_eq_norm(&soln_pt.state, &problem.atol, problem.rtol, 15.0); } for (j, soln_pts) in soln.sens_solution_points.unwrap().iter().enumerate() { for (i, soln_pt) in soln_pts.iter().enumerate() { @@ -595,12 +570,13 @@ mod test { #[test] fn test_solve_adjoint() { - let s = Bdf::default(); let (problem, soln) = exponential_decay_problem_adjoint::>(); + let s = problem.bdf::>().unwrap(); - let state = OdeSolverState::new(&problem, &s).unwrap(); let final_time = soln.solution_points[soln.solution_points.len() - 1].t; - let (g, gs_adj) = s.solve_adjoint(&problem, state, final_time, None).unwrap(); + let (g, gs_adj) = s + .solve_adjoint::>(final_time, None) + .unwrap(); g.assert_eq_norm( &soln.solution_points[soln.solution_points.len() - 1].state, problem.out_atol.as_ref().unwrap(), diff --git a/src/ode_solver/mod.rs b/src/ode_solver/mod.rs index 9265b165..f5e59a3b 100644 --- a/src/ode_solver/mod.rs +++ b/src/ode_solver/mod.rs @@ -17,34 +17,31 @@ pub mod test_models; #[cfg(feature = "diffsl")] pub mod diffsl; -#[cfg(feature = "sundials")] -pub mod sundials; - #[cfg(test)] mod tests { use std::rc::Rc; use self::problem::OdeSolverSolution; use checkpointing::HermiteInterpolator; - use method::{AdjointOdeSolverMethod, SensitivitiesOdeSolverMethod}; use nalgebra::ComplexField; use super::*; use crate::matrix::Matrix; use crate::op::unit::UnitCallable; + use crate::op::ParameterisedOp; + use crate::{ + op::OpStatistics, AdjointOdeSolverMethod, AugmentedOdeSolverMethod, CraneliftModule, + NonLinearOpJacobian, OdeBuilder, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit, + OdeEquationsRef, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, + }; use crate::{ - op::OpStatistics, CraneliftModule, DenseMatrix, DiffSl, MatrixCommon, NonLinearOpJacobian, - OdeBuilder, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsRef, - OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, - VectorView, + ConstantOp, DefaultDenseMatrix, DefaultSolver, LinearSolver, NonLinearOp, Op, Vector, }; - use crate::{ConstantOp, DefaultDenseMatrix, DefaultSolver, NonLinearOp, Op, Vector}; use num_traits::One; use num_traits::Zero; - pub fn test_ode_solver( - method: &mut impl SensitivitiesOdeSolverMethod, - problem: &OdeSolverProblem, + pub fn test_ode_solver<'a, M, Eqn, Method>( + method: &mut Method, solution: OdeSolverSolution, override_tol: Option, use_tstop: bool, @@ -52,19 +49,11 @@ mod tests { ) -> Eqn::V where M: Matrix, - Eqn: OdeEquationsSens, + Eqn: OdeEquations + 'a, Eqn::M: DefaultSolver, + Method: OdeSolverMethod<'a, Eqn>, { - if solve_for_sensitivities { - let state = OdeSolverState::new_with_sensitivities(problem, method).unwrap(); - method - .set_problem_with_sensitivities(state, problem) - .unwrap(); - } else { - let state = OdeSolverState::new(problem, method).unwrap(); - method.set_problem(state, problem).unwrap(); - } - let have_root = problem.eqn.as_ref().root().is_some(); + let have_root = method.problem().eqn.root().is_some(); for (i, point) in solution.solution_points.iter().enumerate() { let (soln, sens_soln) = if use_tstop { match method.set_stop_time(point.t) { @@ -72,24 +61,18 @@ mod tests { match method.step() { Ok(OdeSolverStopReason::RootFound(_)) => { assert!(have_root); - return method.state().unwrap().y.clone(); + return method.state().y.clone(); } Ok(OdeSolverStopReason::TstopReached) => { - break ( - method.state().unwrap().y.clone(), - method.state().unwrap().s.to_vec(), - ); + break (method.state().y.clone(), method.state().s.to_vec()); } _ => (), } }, - Err(_) => ( - method.state().unwrap().y.clone(), - method.state().unwrap().s.to_vec(), - ), + Err(_) => (method.state().y.clone(), method.state().s.to_vec()), } } else { - while method.state().unwrap().t.abs() < point.t.abs() { + while method.state().t.abs() < point.t.abs() { if let OdeSolverStopReason::RootFound(t) = method.step().unwrap() { assert!(have_root); return method.interpolate(t).unwrap(); @@ -99,7 +82,7 @@ mod tests { let sens_soln = method.interpolate_sens(point.t).unwrap(); (soln, sens_soln) }; - let soln = if let Some(out) = problem.eqn.out() { + let soln = if let Some(out) = method.problem().eqn.out() { out.call(&soln, point.t) } else { soln @@ -112,11 +95,11 @@ mod tests { if let Some(override_tol) = override_tol { soln.assert_eq_st(&point.state, override_tol); } else { - let (rtol, atol) = if problem.eqn.out().is_some() { + let (rtol, atol) = if method.problem().eqn.out().is_some() { // problem rtol and atol is on the state, so just use solution tolerance here (solution.rtol, &solution.atol) } else { - (problem.rtol, problem.atol.as_ref()) + (method.problem().rtol, &method.problem().atol) }; let error = soln.clone() - &point.state; let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt(); @@ -147,115 +130,38 @@ mod tests { } } } - method.state().unwrap().y.clone() + method.state().y.clone() } - pub fn test_ode_solver_no_sens( - method: &mut impl OdeSolverMethod, - problem: &OdeSolverProblem, - solution: OdeSolverSolution, - override_tol: Option, - use_tstop: bool, - ) -> Eqn::V - where - M: Matrix, - Eqn: OdeEquationsImplicit, - Eqn::M: DefaultSolver, - { - let state = OdeSolverState::new(problem, method).unwrap(); - method.set_problem(state, problem).unwrap(); - let have_root = problem.eqn.as_ref().root().is_some(); - for point in solution.solution_points.iter() { - let soln = if use_tstop { - match method.set_stop_time(point.t) { - Ok(_) => loop { - match method.step() { - Ok(OdeSolverStopReason::RootFound(_)) => { - assert!(have_root); - return method.state().unwrap().y.clone(); - } - Ok(OdeSolverStopReason::TstopReached) => { - break method.state().unwrap().y.clone(); - } - _ => (), - } - }, - Err(_) => method.state().unwrap().y.clone(), - } - } else { - while method.state().unwrap().t.abs() < point.t.abs() { - if let OdeSolverStopReason::RootFound(t) = method.step().unwrap() { - assert!(have_root); - return method.interpolate(t).unwrap(); - } - } - method.interpolate(point.t).unwrap() - }; - let soln = if let Some(out) = problem.eqn.out() { - out.call(&soln, point.t) - } else { - soln - }; - assert_eq!( - soln.len(), - point.state.len(), - "soln.len() != point.state.len()" - ); - if let Some(override_tol) = override_tol { - soln.assert_eq_st(&point.state, override_tol); - } else { - let (rtol, atol) = if problem.eqn.out().is_some() { - // problem rtol and atol is on the state, so just use solution tolerance here - (solution.rtol, &solution.atol) - } else { - (problem.rtol, problem.atol.as_ref()) - }; - let error = soln.clone() - &point.state; - let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt(); - assert!( - error_norm < M::T::from(15.0), - "error_norm: {} at t = {}. soln: {:?}, expected: {:?}", - error_norm, - point.t, - soln, - point.state - ); - } - } - method.state().unwrap().y.clone() - } - - pub fn test_ode_solver_adjoint( + pub fn test_ode_solver_adjoint<'a, 'b, LS, M, Eqn, Method>( mut method: Method, - problem: &OdeSolverProblem, solution: OdeSolverSolution, - ) -> Method::AdjointSolver - where + ) where M: Matrix, - Method: AdjointOdeSolverMethod, - Eqn: OdeEquationsAdjoint, + Method: AdjointOdeSolverMethod<'a, Eqn>, + Eqn: OdeEquationsAdjoint + 'a, + Eqn::V: DefaultDenseMatrix, Eqn::M: DefaultSolver, + LS: LinearSolver, { - let state = OdeSolverState::new(problem, &method).unwrap(); - method.set_problem(state, problem).unwrap(); let t0 = solution.solution_points.first().unwrap().t; let t1 = solution.solution_points.last().unwrap().t; method.set_stop_time(t1).unwrap(); let mut nsteps = 0; let (rtol, atol) = (solution.rtol, &solution.atol); - let mut checkpoints = vec![method.checkpoint().unwrap()]; + let mut checkpoints = vec![method.checkpoint()]; let mut ts = Vec::new(); let mut ys = Vec::new(); let mut ydots = Vec::new(); for point in solution.solution_points.iter() { - while method.state().unwrap().t.abs() < point.t.abs() { - ts.push(method.state().unwrap().t); - ys.push(method.state().unwrap().y.clone()); - ydots.push(method.state().unwrap().dy.clone()); + while method.state().t.abs() < point.t.abs() { + ts.push(method.state().t); + ys.push(method.state().y.clone()); + ydots.push(method.state().dy.clone()); method.step().unwrap(); nsteps += 1; - if nsteps > 50 && method.state().unwrap().t.abs() < t1.abs() { - checkpoints.push(method.checkpoint().unwrap()); + if nsteps > 50 && method.state().t.abs() < t1.abs() { + checkpoints.push(method.checkpoint()); nsteps = 0; ts.clear(); ys.clear(); @@ -275,35 +181,31 @@ mod tests { point.state ); } - ts.push(method.state().unwrap().t); - ys.push(method.state().unwrap().y.clone()); - ydots.push(method.state().unwrap().dy.clone()); - checkpoints.push(method.checkpoint().unwrap()); + ts.push(method.state().t); + ys.push(method.state().y.clone()); + ydots.push(method.state().dy.clone()); + checkpoints.push(method.checkpoint()); let last_segment = HermiteInterpolator::new(ys, ydots, ts); + + let problem = method.problem(); + let adjoint_aug_eqn = method.adjoint_equations(checkpoints, last_segment).unwrap(); let mut adjoint_solver = method - .into_adjoint_solver(checkpoints, last_segment) + .default_adjoint_solver::(adjoint_aug_eqn) .unwrap(); - let y_expect = M::V::from_element(problem.eqn.rhs().nstates(), M::T::zero()); - adjoint_solver - .state() - .unwrap() - .y - .assert_eq_st(&y_expect, M::T::from(1e-9)); + let g_expect = M::V::from_element(problem.eqn.rhs().nparams(), M::T::zero()); - for i in 0..problem.eqn.out().unwrap().nout() { - adjoint_solver.state().unwrap().sg[i].assert_eq_st(&g_expect, M::T::from(1e-9)); + for sgi in adjoint_solver.state().sg.iter() { + sgi.assert_eq_st(&g_expect, M::T::from(1e-9)); } adjoint_solver.set_stop_time(t0).unwrap(); - while adjoint_solver.state().unwrap().t.abs() > t0 { + while adjoint_solver.state().t.abs() > t0 { adjoint_solver.step().unwrap(); } - let adjoint_problem = adjoint_solver.problem().unwrap().clone(); - let mut state = adjoint_solver.take_state().unwrap(); + let (mut state, aug_eqn) = adjoint_solver.into_state_and_eqn(); + let aug_eqn = aug_eqn.unwrap(); let state_mut = state.as_mut(); - adjoint_problem - .eqn - .correct_sg_for_init(t0, state_mut.s, state_mut.sg); + aug_eqn.correct_sg_for_init(t0, state_mut.s, state_mut.sg); let points = solution .sens_solution_points @@ -324,7 +226,6 @@ mod tests { point.state ); } - adjoint_solver } pub struct TestEqnInit { @@ -407,7 +308,6 @@ mod tests { type T = M::T; type V = M::V; type M = M; - fn set_params(&mut self, _p: Rc) {} fn nout(&self) -> usize { 1 } @@ -424,10 +324,10 @@ mod tests { impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn { type Rhs = &'a TestEqnRhs; - type Mass = &'a UnitCallable; - type Root = &'a UnitCallable; + type Mass = ParameterisedOp<'a, UnitCallable>; + type Root = ParameterisedOp<'a, UnitCallable>; type Init = &'a TestEqnInit; - type Out = &'a UnitCallable; + type Out = ParameterisedOp<'a, UnitCallable>; } impl OdeEquations for TestEqn { @@ -435,11 +335,11 @@ mod tests { &self.rhs } - fn mass(&self) -> Option<&UnitCallable> { + fn mass(&self) -> Option<>::Mass> { None } - fn root(&self) -> Option<&UnitCallable> { + fn root(&self) -> Option<>::Root> { None } @@ -447,16 +347,19 @@ mod tests { &self.init } - fn out(&self) -> Option<&UnitCallable> { + fn out(&self) -> Option<>::Out> { None } + fn set_params(&mut self, _p: &Self::V) { + unimplemented!() + } } - pub fn test_interpolate>>(mut s: Method) { - let problem = OdeSolverProblem::new( - Rc::new(TestEqn::new()), + pub fn test_problem() -> OdeSolverProblem> { + OdeSolverProblem::new( + TestEqn::new(), M::T::from(1e-6), - Rc::new(M::V::from_element(1, M::T::from(1e-6))), + M::V::from_element(1, M::T::from(1e-6)), None, None, None, @@ -467,64 +370,34 @@ mod tests { M::T::one(), false, ) - .unwrap(); - let state = Method::State::new_without_initialise(&problem).unwrap(); - s.set_problem(state.clone(), &problem).unwrap(); - let t0 = M::T::zero(); - let t1 = M::T::one(); + .unwrap() + } + + pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn>>(mut s: Method) { + let state = s.checkpoint(); + let t0 = state.as_ref().t; + let t1 = t0 + M::T::from(1e6); s.interpolate(t0) .unwrap() .assert_eq_st(state.as_ref().y, M::T::from(1e-9)); assert!(s.interpolate(t1).is_err()); s.step().unwrap(); - assert!(s.interpolate(s.state().unwrap().t).is_ok()); - assert!(s.interpolate(s.state().unwrap().t + t1).is_err()); - } - - pub fn test_no_set_problem>>(mut s: Method) { - assert!(s.state().is_none()); - assert!(s.problem().is_none()); - assert!(s.state().is_none()); - assert!(s.step().is_err()); - assert!(s.interpolate(M::T::one()).is_err()); + assert!(s.interpolate(s.state().t).is_ok()); + assert!(s.interpolate(s.state().t + t1).is_err()); } - pub fn test_state_mut>>(mut s: Method) { - let problem = OdeSolverProblem::new( - Rc::new(TestEqn::new()), - M::T::from(1e-6), - Rc::new(M::V::from_element(1, M::T::from(1e-6))), - None, - None, - None, - None, - None, - None, - M::T::zero(), - M::T::one(), - false, - ) - .unwrap(); - let state = Method::State::new_without_initialise(&problem).unwrap(); - s.set_problem(state.clone(), &problem).unwrap(); - let state2 = s.state().unwrap(); + pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn>>(mut s: Method) { + let state = s.checkpoint(); + let state2 = s.state(); state2.y.assert_eq_st(state.as_ref().y, M::T::from(1e-9)); - s.state_mut().unwrap().y[0] = M::T::from(std::f64::consts::PI); - assert_eq!( - s.state_mut().unwrap().y[0], - M::T::from(std::f64::consts::PI) - ); + s.state_mut().y[0] = M::T::from(std::f64::consts::PI); + assert_eq!(s.state_mut().y[0], M::T::from(std::f64::consts::PI)); } #[cfg(feature = "diffsl")] - pub fn test_ball_bounce(mut solver: Method) -> (Vec, Vec, Vec) - where - M: Matrix, - M: DefaultSolver, - M::V: DefaultDenseMatrix, - Method: OdeSolverMethod>, - { - let eqn = DiffSl::compile( + pub fn test_ball_bounce_problem>( + ) -> OdeSolverProblem> { + let eqn = crate::DiffSl::compile( " g { 9.81 } h { 10.0 } u_i { @@ -541,11 +414,18 @@ mod tests { ", ) .unwrap(); + OdeBuilder::::new().build_from_eqn(eqn).unwrap() + } + #[cfg(feature = "diffsl")] + pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec, Vec, Vec) + where + M: Matrix, + M: DefaultSolver, + M::V: DefaultDenseMatrix, + Method: OdeSolverMethod<'a, crate::DiffSl>, + { let e = 0.8; - let problem = OdeBuilder::new().build_from_eqn(eqn).unwrap(); - let state = OdeSolverState::new(&problem, &solver).unwrap(); - solver.set_problem(state, &problem).unwrap(); let final_time = 2.5; @@ -565,9 +445,9 @@ mod tests { y[0] = y[0].max(f64::EPSILON); // set the state to the updated state - solver.state_mut().unwrap().y.copy_from(&y); - solver.state_mut().unwrap().dy[0] = y[1]; - *solver.state_mut().unwrap().t = t; + solver.state_mut().y.copy_from(&y); + solver.state_mut().dy[0] = y[1]; + *solver.state_mut().t = t; break; } @@ -581,9 +461,9 @@ mod tests { let mut t = vec![]; for _ in 0..3 { let ret = solver.step(); - x.push(solver.state().unwrap().y[0]); - v.push(solver.state().unwrap().y[1]); - t.push(solver.state().unwrap().t); + x.push(solver.state().y[0]); + v.push(solver.state().y[1]); + t.push(solver.state().t); match ret { Ok(OdeSolverStopReason::InternalTimestep) => (), Ok(OdeSolverStopReason::RootFound(_)) => { @@ -596,110 +476,86 @@ mod tests { (x, v, t) } - pub fn test_checkpointing( + pub fn test_checkpointing<'a, M, Method, Eqn>( + soln: OdeSolverSolution, mut solver1: Method, mut solver2: Method, - problem: OdeSolverProblem, - soln: OdeSolverSolution, ) where M: Matrix + DefaultSolver, - Method: OdeSolverMethod, - Problem: OdeEquationsImplicit, + Method: OdeSolverMethod<'a, Eqn>, + Eqn: OdeEquationsImplicit + 'a, { - let state = OdeSolverState::new(&problem, &solver1).unwrap(); - solver1.set_problem(state, &problem).unwrap(); let half_i = soln.solution_points.len() / 2; let half_t = soln.solution_points[half_i].t; - while solver1.state().unwrap().t <= half_t { + while solver1.state().t <= half_t { solver1.step().unwrap(); } - let checkpoint = solver1.checkpoint().unwrap(); - solver2.set_problem(checkpoint, &problem).unwrap(); + let checkpoint = solver1.checkpoint(); + solver2.set_state(checkpoint); // carry on solving with both solvers, they should produce about the same results (probably might diverge a bit, but should always match the solution) for point in soln.solution_points.iter().skip(half_i + 1) { - while solver2.state().unwrap().t < point.t { + while solver2.state().t < point.t { solver1.step().unwrap(); solver2.step().unwrap(); - let time_error = (solver1.state().unwrap().t - solver2.state().unwrap().t).abs() - / (solver1.state().unwrap().t.abs() * problem.rtol + problem.atol[0]); + let time_error = (solver1.state().t - solver2.state().t).abs() + / (solver1.state().t.abs() * solver1.problem().rtol + + solver1.problem().atol[0]); assert!( time_error < M::T::from(20.0), "time_error: {} at t = {}", time_error, - solver1.state().unwrap().t + solver1.state().t ); - solver1.state().unwrap().y.assert_eq_norm( - solver2.state().unwrap().y, - &problem.atol, - problem.rtol, + solver1.state().y.assert_eq_norm( + solver2.state().y, + &solver1.problem().atol, + solver1.problem().rtol, M::T::from(20.0), ); } let soln = solver1.interpolate(point.t).unwrap(); - soln.assert_eq_norm(&point.state, &problem.atol, problem.rtol, M::T::from(15.0)); + soln.assert_eq_norm( + &point.state, + &solver1.problem().atol, + solver1.problem().rtol, + M::T::from(15.0), + ); let soln = solver2.interpolate(point.t).unwrap(); - soln.assert_eq_norm(&point.state, &problem.atol, problem.rtol, M::T::from(15.0)); - } - } - - pub fn test_param_sweep( - mut s: Method, - mut problem: OdeSolverProblem, - ps: Vec, - ) where - Method: OdeSolverMethod, - Eqn: OdeEquationsImplicit, - Eqn::M: DefaultSolver, - Eqn::V: DefaultDenseMatrix, - { - let mut old_soln = None; - for p in ps { - problem.set_params(p).unwrap(); - let state = OdeSolverState::new(&problem, &s).unwrap(); - let (ys, _ts) = s.solve(&problem, state, Eqn::T::from(10.0)).unwrap(); - // check that the new solution is different from the old one - if let Some(old_soln) = &mut old_soln { - let new_soln = ys.column(ys.ncols() - 1).into_owned(); - let diff = (new_soln - &*old_soln) - .squared_norm(old_soln, &problem.atol, problem.rtol) - .sqrt(); - assert!(diff > Eqn::T::from(1.0e-6), "diff: {}", diff); - } - old_soln = Some(ys.column(ys.ncols() - 1).into_owned()); - s.take_state().unwrap(); - assert!(s.problem().is_none()); + soln.assert_eq_norm( + &point.state, + &solver1.problem().atol, + solver1.problem().rtol, + M::T::from(15.0), + ); } } - pub fn test_state_mut_on_problem( + pub fn test_state_mut_on_problem<'a, Eqn, Method>( mut s: Method, - problem: OdeSolverProblem, soln: OdeSolverSolution, ) where - Eqn: OdeEquationsImplicit, - Method: OdeSolverMethod, - Eqn::M: DefaultSolver, + Eqn: OdeEquationsImplicit + 'a, + Method: OdeSolverMethod<'a, Eqn>, Eqn::V: DefaultDenseMatrix, { - // solve for a little bit - let state = OdeSolverState::new(&problem, &s).unwrap(); - s.solve(&problem, state, Eqn::T::from(1.0)).unwrap(); + // save state and solve for a little bit + let state = s.checkpoint(); + s.solve(Eqn::T::from(1.0)).unwrap(); // reinit using state_mut - let state = Method::State::new_without_initialise(&problem).unwrap(); - s.state_mut().unwrap().y.copy_from(state.as_ref().y); - *s.state_mut().unwrap().t = state.as_ref().t; + s.state_mut().y.copy_from(state.as_ref().y); + *s.state_mut().t = state.as_ref().t; // solve and check against solution for point in soln.solution_points.iter() { - while s.state().unwrap().t < point.t { + while s.state().t < point.t { s.step().unwrap(); } let soln = s.interpolate(point.t).unwrap(); let error = soln.clone() - &point.state; let error_norm = error - .squared_norm(&error, &problem.atol, problem.rtol) + .squared_norm(&error, &s.problem().atol, s.problem().rtol) .sqrt(); assert!( error_norm < Eqn::T::from(17.0), diff --git a/src/ode_solver/problem.rs b/src/ode_solver/problem.rs index 45d2128c..f2f74e21 100644 --- a/src/ode_solver/problem.rs +++ b/src/ode_solver/problem.rs @@ -1,48 +1,106 @@ -use std::rc::Rc; - use crate::{ - error::{DiffsolError, OdeSolverError}, - ode_solver_error, - vector::Vector, - OdeEquations, + error::DiffsolError, vector::Vector, AugmentedOdeEquationsImplicit, Bdf, BdfState, + DefaultDenseMatrix, DenseMatrix, LinearSolver, MatrixRef, NewtonNonlinearSolver, OdeEquations, + OdeEquationsImplicit, OdeEquationsSens, OdeSolverState, Sdirk, SdirkState, SensEquations, + Tableau, VectorRef, }; -pub struct OdeSolverProblem { - pub eqn: Rc, +pub struct OdeSolverProblem +where + Eqn: OdeEquations, +{ + pub eqn: Eqn, pub rtol: Eqn::T, - pub atol: Rc, + pub atol: Eqn::V, pub t0: Eqn::T, pub h0: Eqn::T, pub integrate_out: bool, pub sens_rtol: Option, - pub sens_atol: Option>, + pub sens_atol: Option, pub out_rtol: Option, - pub out_atol: Option>, + pub out_atol: Option, pub param_rtol: Option, - pub param_atol: Option>, + pub param_atol: Option, } -// impl clone -impl Clone for OdeSolverProblem { - fn clone(&self) -> Self { - Self { - eqn: self.eqn.clone(), - rtol: self.rtol, - atol: self.atol.clone(), - t0: self.t0, - h0: self.h0, - integrate_out: self.integrate_out, - out_atol: self.out_atol.clone(), - out_rtol: self.out_rtol, - param_atol: self.param_atol.clone(), - param_rtol: self.param_rtol, - sens_atol: self.sens_atol.clone(), - sens_rtol: self.sens_rtol, +macro_rules! sdirk_solver_from_tableau { + ($state:ident, $state_sens:ident, $method:ident, $method_sens:ident, $method_solver:ident, $method_solver_sens:ident, $tableau:ident) => { + pub fn $state>( + &self, + ) -> Result, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + self.sdirk_state::(&Tableau::<::M>::$tableau()) } - } + + pub fn $state_sens>( + &self, + ) -> Result, DiffsolError> + where + Eqn: OdeEquationsSens, + { + self.sdirk_state_sens::(&Tableau::<::M>::$tableau()) + } + + pub fn $method_solver>( + &self, + state: SdirkState, + ) -> Result, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + self.sdirk_solver( + state, + Tableau::<::M>::$tableau(), + ) + } + + pub fn $method_solver_sens>( + &self, + state: SdirkState, + ) -> Result< + Sdirk<'_, Eqn, LS, ::M, SensEquations>, + DiffsolError, + > + where + Eqn: OdeEquationsSens, + { + self.sdirk_solver_sens( + state, + Tableau::<::M>::$tableau(), + ) + } + + pub fn $method>( + &self, + ) -> Result, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + let state = self.$state::()?; + self.$method_solver::(state) + } + + pub fn $method_sens>( + &self, + ) -> Result< + Sdirk<'_, Eqn, LS, ::M, SensEquations>, + DiffsolError, + > + where + Eqn: OdeEquationsSens, + { + let state = self.$state_sens::()?; + self.$method_solver_sens::(state) + } + }; } -impl OdeSolverProblem { +impl OdeSolverProblem +where + Eqn: OdeEquations, +{ pub fn default_rtol() -> Eqn::T { Eqn::T::from(1e-6) } @@ -57,15 +115,15 @@ impl OdeSolverProblem { } #[allow(clippy::too_many_arguments)] pub(crate) fn new( - eqn: Rc, + eqn: Eqn, rtol: Eqn::T, - atol: Rc, + atol: Eqn::V, sens_rtol: Option, - sens_atol: Option>, + sens_atol: Option, out_rtol: Option, - out_atol: Option>, + out_atol: Option, param_rtol: Option, - param_atol: Option>, + param_atol: Option, t0: Eqn::T, h0: Eqn::T, integrate_out: bool, @@ -86,12 +144,194 @@ impl OdeSolverProblem { }) } - pub fn set_params(&mut self, p: Eqn::V) -> Result<(), DiffsolError> { - let eqn = - Rc::get_mut(&mut self.eqn).ok_or(ode_solver_error!(FailedToGetMutableReference))?; - eqn.set_params(Rc::new(p)); - Ok(()) + pub fn eqn(&self) -> &Eqn { + &self.eqn + } + pub fn eqn_mut(&mut self) -> &mut Eqn { + &mut self.eqn + } +} + +impl OdeSolverProblem +where + Eqn: OdeEquations, + Eqn::V: DefaultDenseMatrix, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + pub fn bdf_state>(&self) -> Result, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + BdfState::new::(self, 1) + } + + pub fn bdf_state_sens>(&self) -> Result, DiffsolError> + where + Eqn: OdeEquationsSens, + { + BdfState::new_with_sensitivities::(self, 1) + } + + pub fn bdf_solver>( + &self, + state: BdfState, + ) -> Result>, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + let newton_solver = NewtonNonlinearSolver::new(LS::default()); + Bdf::new(self, state, newton_solver) + } + + pub fn bdf>( + &self, + ) -> Result>, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + let state = self.bdf_state::()?; + self.bdf_solver(state) + } + + #[allow(clippy::type_complexity)] + pub(crate) fn bdf_solver_aug< + LS: LinearSolver, + Aug: AugmentedOdeEquationsImplicit, + >( + &self, + state: BdfState, + aug_eqn: Aug, + ) -> Result< + Bdf<'_, Eqn, NewtonNonlinearSolver, ::M, Aug>, + DiffsolError, + > + where + Eqn: OdeEquationsImplicit, + { + let newton_solver = NewtonNonlinearSolver::new(LS::default()); + Bdf::new_augmented(state, self, aug_eqn, newton_solver) + } + + #[allow(clippy::type_complexity)] + pub fn bdf_solver_sens>( + &self, + state: BdfState, + ) -> Result< + Bdf< + '_, + Eqn, + NewtonNonlinearSolver, + ::M, + SensEquations, + >, + DiffsolError, + > + where + Eqn: OdeEquationsSens, + { + let sens_eqn = SensEquations::new(self); + self.bdf_solver_aug(state, sens_eqn) + } + + #[allow(clippy::type_complexity)] + pub fn bdf_sens>( + &self, + ) -> Result< + Bdf< + '_, + Eqn, + NewtonNonlinearSolver, + ::M, + SensEquations, + >, + DiffsolError, + > + where + Eqn: OdeEquationsSens, + { + let state = self.bdf_state_sens::()?; + self.bdf_solver_sens(state) + } + + pub fn sdirk_state, DM: DenseMatrix>( + &self, + tableau: &Tableau, + ) -> Result, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + SdirkState::new::(self, tableau.order()) + } + + pub fn sdirk_state_sens, DM: DenseMatrix>( + &self, + tableau: &Tableau, + ) -> Result, DiffsolError> + where + Eqn: OdeEquationsSens, + { + SdirkState::new_with_sensitivities::(self, tableau.order()) + } + + pub fn sdirk_solver, DM: DenseMatrix>( + &self, + state: SdirkState, + tableau: Tableau, + ) -> Result, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + let linear_solver = LS::default(); + Sdirk::new(self, state, tableau, linear_solver) } + + pub(crate) fn sdirk_solver_aug< + LS: LinearSolver, + DM: DenseMatrix, + Aug: AugmentedOdeEquationsImplicit, + >( + &self, + state: SdirkState, + tableau: Tableau, + aug_eqn: Aug, + ) -> Result, DiffsolError> + where + Eqn: OdeEquationsImplicit, + { + Sdirk::new_augmented(self, state, tableau, LS::default(), aug_eqn) + } + + pub fn sdirk_solver_sens, DM: DenseMatrix>( + &self, + state: SdirkState, + tableau: Tableau, + ) -> Result>, DiffsolError> + where + Eqn: OdeEquationsSens, + { + let sens_eqn = SensEquations::new(self); + self.sdirk_solver_aug::(state, tableau, sens_eqn) + } + + sdirk_solver_from_tableau!( + tr_bdf2_state, + tr_bdf2_state_sens, + tr_bdf2, + tr_bdf2_sens, + tr_bdf2_solver, + tr_bdf2_solver_sens, + tr_bdf2 + ); + sdirk_solver_from_tableau!( + esdirk34_state, + esdirk34_state_sens, + esdirk34, + esdirk34_sens, + esdirk34_solver, + esdirk34_solver_sens, + esdirk34 + ); } #[derive(Debug, Clone)] diff --git a/src/ode_solver/sdirk.rs b/src/ode_solver/sdirk.rs index 072ddac7..4cc102b9 100644 --- a/src/ode_solver/sdirk.rs +++ b/src/ode_solver/sdirk.rs @@ -1,10 +1,3 @@ -use num_traits::abs; -use num_traits::One; -use num_traits::Pow; -use num_traits::Zero; -use std::ops::MulAssign; -use std::rc::Rc; - use crate::error::DiffsolError; use crate::error::OdeSolverError; use crate::matrix::MatrixRef; @@ -12,42 +5,68 @@ use crate::ode_solver_error; use crate::vector::VectorRef; use crate::AdjointEquations; use crate::DefaultDenseMatrix; -use crate::DefaultSolver; use crate::LinearSolver; use crate::NewtonNonlinearSolver; use crate::NoAug; use crate::OdeSolverStopReason; use crate::RootFinder; use crate::SdirkState; -use crate::SensEquations; use crate::Tableau; use crate::{ nonlinear_solver::NonLinearSolver, op::sdirk::SdirkCallable, scale, AdjointOdeSolverMethod, - AugmentedOdeEquations, AugmentedOdeEquationsImplicit, DenseMatrix, JacobianUpdate, NonLinearOp, - OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, + AugmentedOdeEquations, AugmentedOdeEquationsImplicit, Convergence, DenseMatrix, JacobianUpdate, + NonLinearOp, OdeEquationsAdjoint, OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Op, Scalar, StateRef, StateRefMut, Vector, VectorViewMut, }; +use num_traits::abs; +use num_traits::One; +use num_traits::Pow; +use num_traits::Zero; +use std::ops::MulAssign; use super::bdf::BdfStatistics; use super::jacobian_update::SolverState; use super::method::AugmentedOdeSolverMethod; -use super::method::SensitivitiesOdeSolverMethod; -// make a few convenience type aliases -pub type SdirkAdj = Sdirk< - M, - AdjointEquations>, - LS, - AdjointEquations>, ->; -impl SensitivitiesOdeSolverMethod for Sdirk> +impl<'a, M, Eqn, LS, AugEqn> AugmentedOdeSolverMethod<'a, Eqn, AugEqn> + for Sdirk<'a, Eqn, LS, M, AugEqn> where + Eqn: OdeEquationsImplicit, + AugEqn: AugmentedOdeEquationsImplicit, M: DenseMatrix, LS: LinearSolver, - Eqn: OdeEquationsSens, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, + Eqn::V: DefaultDenseMatrix, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + fn into_state_and_eqn(self) -> (Self::State, Option) { + (self.state, self.s_op.map(|op| op.eqn)) + } +} + +impl<'a, M, Eqn, LS> AdjointOdeSolverMethod<'a, Eqn> for Sdirk<'a, Eqn, LS, M> +where + Eqn: OdeEquationsAdjoint, + M: DenseMatrix, + LS: LinearSolver + 'a, + Eqn::V: DefaultDenseMatrix, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, { + type DefaultAdjointSolver = + Sdirk<'a, Eqn, LS, M, AdjointEquations<'a, Eqn, Sdirk<'a, Eqn, LS, M>>>; + + fn default_adjoint_solver>( + self, + mut aug_eqn: AdjointEquations<'a, Eqn, Self>, + ) -> Result { + let problem = self.problem(); + let tableau = self.tableau; + let state = self + .state + .into_adjoint::(problem, &mut aug_eqn)?; + Sdirk::new_augmented(self.problem, state, tableau, LS::default(), aug_eqn) + } } /// A singly diagonally implicit Runge-Kutta method. Can optionally have an explicit first stage for ESDIRK methods. @@ -59,20 +78,25 @@ where /// - The upper triangular part of the `a` matrix must be zero (i.e. not fully implicit). /// - The diagonal of the `a` matrix must be the same non-zero value for all rows (i.e. an SDIRK method), except for the first row which can be zero for ESDIRK methods. /// - The last row of the `a` matrix must be the same as the `b` vector, and the last element of the `c` vector must be 1 (i.e. a stiffly accurate method) -pub struct Sdirk> -where +pub struct Sdirk< + 'a, + Eqn, + LS, + M = <::V as DefaultDenseMatrix>::M, + AugmentedEqn = NoAug, +> where M: DenseMatrix, LS: LinearSolver, Eqn: OdeEquationsImplicit, + Eqn::V: DefaultDenseMatrix, AugmentedEqn: AugmentedOdeEquations, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, { tableau: Tableau, - problem: Option>, + problem: &'a OdeSolverProblem, nonlinear_solver: NewtonNonlinearSolver, - op: Option>, - state: Option>, + convergence: Convergence<'a, Eqn::V>, + op: Option>, + state: SdirkState, diff: M, sdiff: Vec, sgdiff: Vec, @@ -94,90 +118,92 @@ where jacobian_update: JacobianUpdate, } -impl Sdirk<::M, Eqn, ::LS, NoAug> -where - Eqn: OdeEquationsImplicit, - Eqn::M: DefaultSolver, - Eqn::V: DefaultDenseMatrix, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, -{ - pub fn tr_bdf2() -> Self { - let tableau = Tableau::<::M>::tr_bdf2(); - let linear_solver = Eqn::M::default_solver(); - Self::new(tableau, linear_solver) - } - pub fn esdirk34() -> Self { - let tableau = Tableau::<::M>::esdirk34(); - let linear_solver = Eqn::M::default_solver(); - Self::new(tableau, linear_solver) - } -} - -impl - Sdirk<::M, Eqn, ::LS, SensEquations> -where - Eqn: OdeEquationsSens, - Eqn::M: DefaultSolver, - Eqn::V: DefaultDenseMatrix, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, -{ - pub fn tr_bdf2_with_sensitivities() -> Self { - let tableau = Tableau::<::M>::tr_bdf2(); - let linear_solver = Eqn::M::default_solver(); - Self::new_common(tableau, linear_solver) - } - pub fn esdirk34_with_sensitivities() -> Self { - let tableau = Tableau::<::M>::esdirk34(); - let linear_solver = Eqn::M::default_solver(); - Self::new_common(tableau, linear_solver) - } -} - -impl Sdirk> +impl Clone for Sdirk<'_, Eqn, LS, M, AugmentedEqn> where - LS: LinearSolver, M: DenseMatrix, - Eqn: OdeEquationsImplicit, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, -{ - pub fn new(tableau: Tableau, linear_solver: LS) -> Self { - Self::new_common(tableau, linear_solver) - } -} - -impl Sdirk> -where LS: LinearSolver, - M: DenseMatrix, - Eqn: OdeEquationsSens, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, + Eqn: OdeEquationsImplicit, + Eqn::V: DefaultDenseMatrix, + AugmentedEqn: AugmentedOdeEquationsImplicit, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, { - pub fn new_with_sensitivities(tableau: Tableau, linear_solver: LS) -> Self { - Self::new_common(tableau, linear_solver) + fn clone(&self) -> Self { + let problem = self.problem; + let mut nonlinear_solver = NewtonNonlinearSolver::new(LS::default()); + let op = if let Some(op) = &self.op { + let op = op.clone_state(&problem.eqn); + nonlinear_solver.set_problem(&op); + nonlinear_solver.reset_jacobian(&op, &self.state.y, self.state.t); + Some(op) + } else { + None + }; + let s_op = self.s_op.as_ref().map(|op| { + let op = op.clone_state(op.eqn().clone()); + op + }); + Self { + tableau: self.tableau.clone(), + problem: self.problem, + convergence: self.convergence.clone(), + nonlinear_solver, + op, + state: self.state.clone(), + diff: self.diff.clone(), + sdiff: self.sdiff.clone(), + sgdiff: self.sgdiff.clone(), + gdiff: self.gdiff.clone(), + old_g: self.old_g.clone(), + gamma: self.gamma, + is_sdirk: self.is_sdirk, + s_op, + old_t: self.old_t, + old_y: self.old_y.clone(), + old_y_sens: self.old_y_sens.clone(), + old_f: self.old_f.clone(), + old_f_sens: self.old_f_sens.clone(), + a_rows: self.a_rows.clone(), + statistics: self.statistics.clone(), + root_finder: self.root_finder.clone(), + tstop: self.tstop, + is_state_mutated: self.is_state_mutated, + jacobian_update: self.jacobian_update.clone(), + } } } -impl Sdirk +impl<'a, M, Eqn, LS, AugmentedEqn> Sdirk<'a, Eqn, LS, M, AugmentedEqn> where LS: LinearSolver, M: DenseMatrix, Eqn: OdeEquationsImplicit, + Eqn::V: DefaultDenseMatrix, AugmentedEqn: AugmentedOdeEquationsImplicit, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, { const NEWTON_MAXITER: usize = 10; const MIN_FACTOR: f64 = 0.2; const MAX_FACTOR: f64 = 10.0; const MIN_TIMESTEP: f64 = 1e-13; - fn new_common(tableau: Tableau, linear_solver: LS) -> Self { - let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); + pub fn new( + problem: &'a OdeSolverProblem, + state: SdirkState, + tableau: Tableau, + linear_solver: LS, + ) -> Result { + Self::_new(problem, state, tableau, linear_solver, true) + } + fn _new( + problem: &'a OdeSolverProblem, + mut state: SdirkState, + tableau: Tableau, + linear_solver: LS, + integrate_main_eqn: bool, + ) -> Result { // check that the upper triangular part of a is zero let s = tableau.s(); for i in 0..s { @@ -242,30 +268,73 @@ where ); } - let n = 1; - let old_t = Eqn::T::zero(); - let old_y = ::zeros(n); - let old_g = ::zeros(n); - let old_f = ::zeros(n); + // setup linear solver for first step + let mut jacobian_update = JacobianUpdate::default(); + jacobian_update.update_jacobian(state.h); + jacobian_update.update_rhs_jacobian(); + + let mut nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); + + // set max iterations for nonlinear solver + let mut convergence = Convergence::new(problem.rtol, &problem.atol); + convergence.set_max_iter(Self::NEWTON_MAXITER); + + let op = if integrate_main_eqn { + let callable = SdirkCallable::new(&problem.eqn, gamma); + callable.set_h(state.h); + nonlinear_solver.set_problem(&callable); + nonlinear_solver.reset_jacobian(&callable, &state.y, state.t); + Some(callable) + } else { + None + }; + + // update statistics let statistics = BdfStatistics::default(); - let old_f_sens = Vec::new(); - let old_y_sens = Vec::new(); - let diff = M::zeros(n, s); - let sdiff = Vec::new(); - let sgdiff = Vec::new(); - let gdiff = M::zeros(n, s); - Self { - old_y_sens, - old_f_sens, + + state.check_consistent_with_problem(problem)?; + + let nstates = state.y.len(); + let order = tableau.s(); + let diff = M::zeros(nstates, order); + let gdiff_rows = if problem.integrate_out { + problem.eqn.out().unwrap().nout() + } else { + 0 + }; + let gdiff = M::zeros(gdiff_rows, order); + + let old_f = state.dy.clone(); + let old_t = state.t; + let old_y = state.y.clone(); + let old_g = if problem.integrate_out { + state.g.clone() + } else { + ::zeros(0) + }; + + state.set_problem(problem)?; + let root_finder = if let Some(root_fn) = problem.eqn.root() { + let root_finder = RootFinder::new(root_fn.nout()); + root_finder.init(&root_fn, &state.y, state.t); + Some(root_finder) + } else { + None + }; + + Ok(Self { + old_y_sens: vec![], + old_f_sens: vec![], old_g, + convergence, diff, - sdiff, - sgdiff, + sdiff: vec![], + sgdiff: vec![], tableau, nonlinear_solver, - op: None, - state: None, - problem: None, + op, + state, + problem, s_op: None, gdiff, gamma, @@ -275,11 +344,48 @@ where a_rows, old_f, statistics, - root_finder: None, + root_finder, tstop: None, is_state_mutated: false, - jacobian_update: JacobianUpdate::default(), + jacobian_update, + }) + } + + pub fn new_augmented( + problem: &'a OdeSolverProblem, + state: SdirkState, + tableau: Tableau, + linear_solver: LS, + augmented_eqn: AugmentedEqn, + ) -> Result { + state.check_sens_consistent_with_problem(problem, &augmented_eqn)?; + let mut ret = Self::_new( + problem, + state, + tableau, + linear_solver, + augmented_eqn.integrate_main_eqn(), + )?; + let naug = augmented_eqn.max_index(); + let nstates = augmented_eqn.rhs().nstates(); + let order = ret.tableau.s(); + ret.sdiff = vec![M::zeros(nstates, order); naug]; + ret.old_f_sens = vec![::zeros(nstates); naug]; + ret.old_y_sens = ret.state.s.clone(); + if let Some(out) = augmented_eqn.out() { + ret.sgdiff = vec![M::zeros(out.nout(), order); naug]; } + + ret.s_op = if augmented_eqn.integrate_main_eqn() { + Some(SdirkCallable::new_no_jacobian(augmented_eqn, ret.gamma)) + } else { + let callable = SdirkCallable::new(augmented_eqn, ret.gamma); + ret.nonlinear_solver.set_problem(&callable); + ret.nonlinear_solver + .reset_jacobian(&callable, &ret.state.s[0], ret.state.t); + Some(callable) + }; + Ok(ret) } pub fn get_statistics(&self) -> &BdfStatistics { @@ -290,7 +396,7 @@ where &mut self, tstop: Eqn::T, ) -> Result>, DiffsolError> { - let state = self.state.as_mut().unwrap(); + let state = &mut self.state; // check if the we are at tstop let troundoff = Eqn::T::from(100.0) * Eqn::T::EPSILON * (abs(state.t) + abs(state.h)); @@ -314,7 +420,12 @@ where { let factor = (tstop - state.t) / state.h; state.h *= factor; - self.op.as_mut().unwrap().set_h(state.h); + if let Some(op) = self.op.as_mut() { + op.set_h(state.h); + } + if let Some(s_op) = self.s_op.as_mut() { + s_op.set_h(state.h); + } } Ok(None) } @@ -334,38 +445,34 @@ where } fn solve_for_sensitivities(&mut self, i: usize, t: Eqn::T) -> Result<(), DiffsolError> { - let h = self.state.as_ref().unwrap().h; + let h = self.state.h; // update for new state { let op = self.s_op.as_mut().unwrap(); - Rc::get_mut(op.eqn_mut()) - .unwrap() + op.eqn_mut() .update_rhs_out_state(&self.old_y, &self.old_f, t); - - // construct bdf discretisation of sensitivity equations - op.set_h(h); } // solve for sensitivities equations discretised using sdirk equation for j in 0..self.sdiff.len() { - let s0 = &self.state.as_ref().unwrap().s[j]; + let s0 = &self.state.s[j]; let op = self.s_op.as_mut().unwrap(); op.set_phi(&self.sdiff[j].columns(0, i), s0, &self.a_rows[i]); - Rc::get_mut(op.eqn_mut()).unwrap().set_index(j); + op.eqn_mut().set_index(j); let ds = &mut self.old_f_sens[j]; Self::predict_stage(i, &self.sdiff[j], ds, &self.tableau); // solve let op = self.s_op.as_ref().unwrap(); - self.nonlinear_solver.solve_in_place(op, ds, t, s0)?; + self.nonlinear_solver + .solve_in_place(op, ds, t, s0, &mut self.convergence)?; self.old_y_sens[j].copy_from(&op.get_last_f_eval()); - self.statistics.number_of_nonlinear_solver_iterations += - self.nonlinear_solver.convergence().niter(); + self.statistics.number_of_nonlinear_solver_iterations += self.convergence.niter(); // calculate sdg and store in sgdiff if let Some(out) = self.s_op.as_ref().unwrap().eqn().out() { - let dsg = &mut self.state.as_mut().unwrap().dsg[j]; + let dsg = &mut self.state.dsg[j]; out.call_inplace(&self.old_y_sens[j], t, dsg); self.sgdiff[j].column_mut(i).axpy(h, dsg, Eqn::T::zero()); } @@ -408,149 +515,102 @@ where fn _jacobian_updates(&mut self, h: Eqn::T, state: SolverState) { if self.jacobian_update.check_rhs_jacobian_update(h, &state) { - self.op.as_mut().unwrap().set_jacobian_is_stale(); - self.nonlinear_solver.reset_jacobian( - self.op.as_ref().unwrap(), - &self.old_f, - self.state.as_ref().unwrap().t, - ); + if let Some(op) = self.op.as_mut() { + op.set_jacobian_is_stale(); + self.nonlinear_solver + .reset_jacobian(op, &self.old_f, self.state.t); + } else if let Some(s_op) = self.s_op.as_mut() { + s_op.set_jacobian_is_stale(); + self.nonlinear_solver + .reset_jacobian(s_op, &self.old_f_sens[0], self.state.t); + } self.jacobian_update.update_rhs_jacobian(); self.jacobian_update.update_jacobian(h); } else if self.jacobian_update.check_jacobian_update(h, &state) { - self.nonlinear_solver.reset_jacobian( - self.op.as_ref().unwrap(), - &self.old_f, - self.state.as_ref().unwrap().t, - ); + if let Some(op) = self.op.as_ref() { + self.nonlinear_solver + .reset_jacobian(op, &self.old_f, self.state.t); + } else if let Some(s_op) = self.s_op.as_ref() { + self.nonlinear_solver + .reset_jacobian(s_op, &self.old_f_sens[0], self.state.t); + } self.jacobian_update.update_jacobian(h); } } fn _update_step_size(&mut self, factor: Eqn::T) -> Result { - let new_h = self.state.as_ref().unwrap().h * factor; + let new_h = self.state.h * factor; // if step size too small, then fail if abs(new_h) < Eqn::T::from(Self::MIN_TIMESTEP) { return Err(DiffsolError::from(OdeSolverError::StepSizeTooSmall { - time: self.state.as_ref().unwrap().t.into(), + time: self.state.t.into(), })); } // update h for new step size - self.op.as_mut().unwrap().set_h(new_h); + if let Some(op) = self.op.as_mut() { + op.set_h(new_h); + } + if let Some(s_op) = self.s_op.as_mut() { + s_op.set_h(new_h); + } // update state - self.state.as_mut().unwrap().h = new_h; + self.state.h = new_h; Ok(new_h) } } -impl OdeSolverMethod for Sdirk +impl<'a, M, Eqn, AugmentedEqn, LS> OdeSolverMethod<'a, Eqn> for Sdirk<'a, Eqn, LS, M, AugmentedEqn> where LS: LinearSolver, M: DenseMatrix, Eqn: OdeEquationsImplicit, + Eqn::V: DefaultDenseMatrix, AugmentedEqn: AugmentedOdeEquationsImplicit, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, { type State = SdirkState; - fn problem(&self) -> Option<&OdeSolverProblem> { - self.problem.as_ref() + fn problem(&self) -> &'a OdeSolverProblem { + self.problem } fn order(&self) -> usize { self.tableau.order() } - fn take_state(&mut self) -> Option> { - self.problem = None; - self.op = None; - self.s_op = None; - Option::take(&mut self.state) - } + fn set_state(&mut self, state: Self::State) { + self.state = state; - fn checkpoint(&mut self) -> Result { - if self.state.is_none() { - return Err(ode_solver_error!(StateNotSet)); + // update the op with the new state + if let Some(op) = self.op.as_mut() { + op.set_h(self.state.h); } - self._jacobian_updates(self.state.as_ref().unwrap().h, SolverState::Checkpoint); - Ok(self.state.as_ref().unwrap().clone()) - } - - fn set_problem( - &mut self, - mut state: SdirkState, - problem: &OdeSolverProblem, - ) -> Result<(), DiffsolError> { - // setup linear solver for first step - let callable = SdirkCallable::new(problem, self.gamma); - callable.set_h(state.h); - self.jacobian_update.update_jacobian(state.h); - self.jacobian_update.update_rhs_jacobian(); - self.nonlinear_solver - .set_problem(&callable, problem.rtol, problem.atol.clone()); - - // set max iterations for nonlinear solver - self.nonlinear_solver - .convergence_mut() - .set_max_iter(Self::NEWTON_MAXITER); - self.nonlinear_solver - .reset_jacobian(&callable, &state.y, state.t); - self.op = Some(callable); - // update statistics - self.statistics = BdfStatistics::default(); - - state.check_consistent_with_problem(problem)?; - - let nstates = state.y.len(); - let order = self.tableau.s(); - if self.diff.nrows() != nstates || self.diff.ncols() != order { - self.diff = M::zeros(nstates, order); - } - let gdiff_rows = if problem.integrate_out { - problem.eqn.out().unwrap().nout() - } else { - 0 - }; - if self.gdiff.nrows() != gdiff_rows || self.gdiff.ncols() != order { - self.gdiff = M::zeros(gdiff_rows, order); - } + // reinitialise jacobian updates as if a checkpoint was taken + self._jacobian_updates(self.state.h, SolverState::Checkpoint); + } - self.old_f = state.dy.clone(); - self.old_t = state.t; - self.old_y = state.y.clone(); - if problem.integrate_out { - self.old_g = state.g.clone(); - } + fn into_state(self) -> SdirkState { + self.state + } - state.set_problem(problem)?; - self.state = Some(state); - self.problem = Some(problem.clone()); - if let Some(root_fn) = problem.eqn.root() { - let state = self.state.as_ref().unwrap(); - self.root_finder = Some(RootFinder::new(root_fn.nout())); - self.root_finder - .as_ref() - .unwrap() - .init(&root_fn, &state.y, state.t); - } - Ok(()) + fn checkpoint(&mut self) -> Self::State { + self._jacobian_updates(self.state.h, SolverState::Checkpoint); + self.state.clone() } fn step(&mut self) -> Result, DiffsolError> { - if self.state.is_none() { - return Err(ode_solver_error!(StateNotSet)); - } - let n = self.state.as_ref().unwrap().y.len(); + let n = self.state.y.len(); if self.is_state_mutated { // reinitalise root finder if needed - if let Some(root_fn) = self.problem.as_ref().unwrap().eqn.root() { - let state = self.state.as_ref().unwrap(); + if let Some(root_fn) = self.problem.eqn.root() { + let state = &self.state; self.root_finder .as_ref() .unwrap() @@ -570,9 +630,9 @@ where // dont' reset jacobian for the first attempt at the step let mut error = ::zeros(n); - let out_error_control = self.problem().as_ref().unwrap().output_in_error_control(); + let out_error_control = self.problem().output_in_error_control(); let mut out_error = if out_error_control { - ::zeros(self.problem().as_ref().unwrap().eqn.out().unwrap().nout()) + ::zeros(self.problem().eqn.out().unwrap().nout()) } else { ::zeros(0) }; @@ -600,13 +660,13 @@ where // loop until step is accepted 'step: loop { - let t0 = self.state.as_ref().unwrap().t; - let h = self.state.as_ref().unwrap().h; + let t0 = self.state.t; + let h = self.state.h; // if start == 1, then we need to compute the first stage // from the last stage of the previous step if start == 1 { { - let state = self.state.as_ref().unwrap(); + let state = &self.state; let mut hf = self.diff.column_mut(0); hf.copy_from(&state.dy); hf *= scale(h); @@ -614,20 +674,12 @@ where // sensitivities too if self.s_op.is_some() { - for (diff, dy) in self - .sdiff - .iter_mut() - .zip(self.state.as_ref().unwrap().ds.iter()) - { + for (diff, dy) in self.sdiff.iter_mut().zip(self.state.ds.iter()) { let mut hf = diff.column_mut(0); hf.copy_from(dy); hf *= scale(h); } - for (diff, dg) in self - .sgdiff - .iter_mut() - .zip(self.state.as_ref().unwrap().dsg.iter()) - { + for (diff, dg) in self.sgdiff.iter_mut().zip(self.state.dsg.iter()) { let mut hf = diff.column_mut(0); hf.copy_from(dg); hf *= scale(h); @@ -635,8 +687,8 @@ where } // output function - if self.problem.as_ref().unwrap().integrate_out { - let state = self.state.as_ref().unwrap(); + if self.problem.integrate_out { + let state = &self.state; let mut hf = self.gdiff.column_mut(0); hf.copy_from(&state.dg); hf *= scale(h); @@ -645,28 +697,29 @@ where for i in start..self.tableau.s() { let t = t0 + self.tableau.c()[i] * h; - self.op.as_mut().unwrap().set_phi( - &self.diff.columns(0, i), - &self.state.as_ref().unwrap().y, - &self.a_rows[i], - ); - - Self::predict_stage(i, &self.diff, &mut self.old_f, &self.tableau); - let mut solve_result = self.nonlinear_solver.solve_in_place( - self.op.as_ref().unwrap(), - &mut self.old_f, - t, - &self.state.as_ref().unwrap().y, - ); - self.statistics.number_of_nonlinear_solver_iterations += - self.nonlinear_solver.convergence().niter(); + // main equation + let mut solve_result = Ok(()); + if let Some(op) = self.op.as_mut() { + op.set_phi(&self.diff.columns(0, i), &self.state.y, &self.a_rows[i]); + Self::predict_stage(i, &self.diff, &mut self.old_f, &self.tableau); + solve_result = self.nonlinear_solver.solve_in_place( + op, + &mut self.old_f, + t, + &self.state.y, + &mut self.convergence, + ); + self.statistics.number_of_nonlinear_solver_iterations += + self.convergence.niter(); + } // only calculate sensitivities if the solve succeeded if solve_result.is_ok() { + if let Some(op) = self.op.as_ref() { + self.old_y.copy_from(&op.get_last_f_eval()); + } // old_y now has the new y soln and old_f has the new dy soln - self.old_y - .copy_from(&self.op.as_ref().unwrap().get_last_f_eval()); if self.s_op.is_some() { solve_result = self.solve_for_sensitivities(i, t); } @@ -688,18 +741,18 @@ where continue 'step; }; - // update diff with solved dy - self.diff.column_mut(i).copy_from(&self.old_f); - - // calculate dg and store in gdiff - if self.problem.as_ref().unwrap().integrate_out { - let out = self.problem.as_ref().unwrap().eqn.out().unwrap(); - out.call_inplace(&self.old_y, t, &mut self.state.as_mut().unwrap().dg); - self.gdiff.column_mut(i).axpy( - h, - &self.state.as_mut().unwrap().dg, - Eqn::T::zero(), - ); + if self.op.is_some() { + // update diff with solved dy + self.diff.column_mut(i).copy_from(&self.old_f); + + // calculate dg and store in gdiff + if self.problem.integrate_out { + let out = self.problem.eqn.out().unwrap(); + out.call_inplace(&self.old_y, t, &mut self.state.dg); + self.gdiff + .column_mut(i) + .axpy(h, &self.state.dg, Eqn::T::zero()); + } } if self.s_op.is_some() { @@ -708,29 +761,33 @@ where } } } + let mut ncontributions = 0; + let mut error_norm = Eqn::T::zero(); // successfully solved for all stages, now compute error - self.diff - .gemv(Eqn::T::one(), self.tableau.d(), Eqn::T::zero(), &mut error); - - // compute error norm - let atol = self.problem().as_ref().unwrap().atol.as_ref(); - let rtol = self.problem().as_ref().unwrap().rtol; - let mut error_norm = error.squared_norm(&self.old_y, atol, rtol); - let mut ncontributions = 1; - - // output errors - if out_error_control { - self.gdiff.gemv( - Eqn::T::one(), - self.tableau.d(), - Eqn::T::zero(), - &mut out_error, - ); - let atol = self.problem().as_ref().unwrap().out_atol.as_ref().unwrap(); - let rtol = self.problem().as_ref().unwrap().out_rtol.unwrap(); - let out_error_norm = out_error.squared_norm(&self.old_g, atol, rtol); - error_norm += out_error_norm; + if self.op.is_some() { + self.diff + .gemv(Eqn::T::one(), self.tableau.d(), Eqn::T::zero(), &mut error); + + // compute error norm + let atol = &self.problem().atol; + let rtol = self.problem().rtol; + error_norm += error.squared_norm(&self.old_y, atol, rtol); ncontributions += 1; + + // output errors + if out_error_control { + self.gdiff.gemv( + Eqn::T::one(), + self.tableau.d(), + Eqn::T::zero(), + &mut out_error, + ); + let atol = self.problem().out_atol.as_ref().unwrap(); + let rtol = self.problem().out_rtol.unwrap(); + let out_error_norm = out_error.squared_norm(&self.old_g, atol, rtol); + error_norm += out_error_norm; + ncontributions += 1; + } } // sensitivity errors @@ -761,20 +818,19 @@ where Eqn::T::zero(), &mut sens_out_error, ); - let sens_error_norm = sens_out_error.squared_norm( - &self.state.as_ref().unwrap().sg[i], - atol, - rtol, - ); + let sens_error_norm = + sens_out_error.squared_norm(&self.state.sg[i], atol, rtol); error_norm += sens_error_norm; ncontributions += 1; } } - error_norm /= Eqn::T::from(ncontributions as f64); + if ncontributions > 1 { + error_norm /= Eqn::T::from(ncontributions as f64); + } // adjust step size based on error - let maxiter = self.nonlinear_solver.convergence().max_iter() as f64; - let niter = self.nonlinear_solver.convergence().niter() as f64; + let maxiter = self.convergence.max_iter() as f64; + let niter = self.convergence.niter() as f64; let safety = Eqn::T::from(0.9 * (2.0 * maxiter + 1.0) / (2.0 * maxiter + niter)); let order = self.tableau.order() as f64; factor = safety * error_norm.pow(Eqn::T::from(-0.5 / (order + 1.0))); @@ -797,7 +853,7 @@ where // take the step { - let state = self.state.as_mut().unwrap(); + let state = &mut self.state; self.old_t = state.t; state.t += state.h; @@ -825,7 +881,7 @@ where } // integrate output function - if self.problem.as_ref().unwrap().integrate_out { + if self.problem.integrate_out { self.old_g.copy_from(&state.g); self.gdiff .gemv(Eqn::T::one(), self.tableau.b(), Eqn::T::one(), &mut state.g); @@ -837,18 +893,21 @@ where self._jacobian_updates(new_h, SolverState::StepSuccess); // update statistics - self.statistics.number_of_linear_solver_setups = - self.op.as_ref().unwrap().number_of_jac_evals(); + if let Some(op) = self.op.as_ref() { + self.statistics.number_of_linear_solver_setups = op.number_of_jac_evals(); + } else if let Some(s_op) = self.s_op.as_ref() { + self.statistics.number_of_linear_solver_setups = s_op.number_of_jac_evals(); + } self.statistics.number_of_steps += 1; self.jacobian_update.step(); // check for root within accepted step - if let Some(root_fn) = self.problem.as_ref().unwrap().eqn.root() { + if let Some(root_fn) = self.problem.eqn.root() { let ret = self.root_finder.as_ref().unwrap().check_root( &|t| self.interpolate(t), &root_fn, - &self.state.as_ref().unwrap().y, - self.state.as_ref().unwrap().t, + &self.state.y, + self.state.t, ); if let Some(root) = ret { return Ok(OdeSolverStopReason::RootFound(root)); @@ -871,7 +930,7 @@ where if let Some(OdeSolverStopReason::TstopReached) = self.handle_tstop(tstop)? { let error = OdeSolverError::StopTimeBeforeCurrentTime { stop_time: tstop.into(), - state_time: self.state.as_ref().unwrap().t.into(), + state_time: self.state.t.into(), }; self.tstop = None; return Err(DiffsolError::from(error)); @@ -880,10 +939,7 @@ where } fn interpolate_sens(&self, t: ::T) -> Result::V>, DiffsolError> { - if self.state.is_none() { - return Err(ode_solver_error!(StateNotSet)); - } - let state = self.state.as_ref().unwrap(); + let state = &self.state; if self.is_state_mutated { if t == state.t { @@ -929,10 +985,7 @@ where } fn interpolate(&self, t: ::T) -> Result<::V, DiffsolError> { - if self.state.is_none() { - return Err(ode_solver_error!(StateNotSet)); - } - let state = self.state.as_ref().unwrap(); + let state = &self.state; if self.is_state_mutated { if t == state.t { @@ -968,10 +1021,7 @@ where } fn interpolate_out(&self, t: ::T) -> Result<::V, DiffsolError> { - if self.state.is_none() { - return Err(ode_solver_error!(StateNotSet)); - } - let state = self.state.as_ref().unwrap(); + let state = &self.state; if self.is_state_mutated { if t == state.t { @@ -1006,74 +1056,13 @@ where } } - fn state(&self) -> Option> { - self.state.as_ref().map(|s| s.as_ref()) + fn state(&self) -> StateRef { + self.state.as_ref() } - fn state_mut(&mut self) -> Option> { + fn state_mut(&mut self) -> StateRefMut { self.is_state_mutated = true; - self.state.as_mut().map(|s| s.as_mut()) - } -} - -impl AugmentedOdeSolverMethod - for Sdirk -where - LS: LinearSolver, - M: DenseMatrix, - Eqn: OdeEquationsImplicit, - AugmentedEqn: AugmentedOdeEquationsImplicit, - for<'a> &'a Eqn::V: VectorRef, - for<'a> &'a Eqn::M: MatrixRef, -{ - fn set_augmented_problem( - &mut self, - state: Self::State, - ode_problem: &OdeSolverProblem, - augmented_eqn: AugmentedEqn, - ) -> Result<(), DiffsolError> { - state.check_sens_consistent_with_problem(ode_problem, &augmented_eqn)?; - self.set_problem(state, ode_problem)?; - let naug = augmented_eqn.max_index(); - let nstates = augmented_eqn.rhs().nstates(); - let order = self.tableau.s(); - if self.sdiff.len() != naug - || self.sdiff[0].nrows() != nstates - || self.sdiff[0].ncols() != order - { - self.sdiff = vec![M::zeros(nstates, order); naug]; - self.old_f_sens = vec![::zeros(nstates); naug]; - self.old_y_sens = self.state.as_ref().unwrap().s.clone(); - } - if let Some(out) = augmented_eqn.out() { - if self.sgdiff.len() != naug - || self.sgdiff[0].nrows() != out.nout() - || self.sgdiff[0].ncols() != order - { - self.sgdiff = vec![M::zeros(out.nout(), order); naug]; - } - } - let augmented_eqn = Rc::new(augmented_eqn); - self.s_op = Some(SdirkCallable::from_eqn(augmented_eqn, self.gamma)); - Ok(()) - } -} - -impl AdjointOdeSolverMethod for Sdirk -where - Eqn: OdeEquationsAdjoint, - AugmentedEqn: AugmentedOdeEquations + OdeEquationsAdjoint, - M: DenseMatrix, - LS: LinearSolver, - for<'b> &'b Eqn::V: VectorRef, - for<'b> &'b Eqn::M: MatrixRef, -{ - type AdjointSolver = Sdirk, LS, AdjointEquations>; - - fn new_adjoint_solver(&self) -> Self::AdjointSolver { - let tableau = self.tableau.clone(); - let linear_solver = LS::default(); - Self::AdjointSolver::new_common(tableau, linear_solver) + self.state.as_mut() } } @@ -1093,57 +1082,54 @@ mod test { robertson_ode::robertson_ode, }, tests::{ - test_checkpointing, test_interpolate, test_no_set_problem, test_ode_solver, - test_ode_solver_adjoint, test_ode_solver_no_sens, test_param_sweep, test_state_mut, - test_state_mut_on_problem, + test_checkpointing, test_interpolate, test_ode_solver, test_ode_solver_adjoint, + test_problem, test_state_mut, test_state_mut_on_problem, }, }, - OdeEquations, Op, Sdirk, SparseColMat, + FaerSparseLU, NalgebraLU, OdeEquations, OdeSolverMethod, Op, SparseColMat, Vector, }; use num_traits::abs; type M = nalgebra::DMatrix; - #[test] - fn sdirk_no_set_problem() { - test_no_set_problem::(Sdirk::tr_bdf2()); - } + type LS = NalgebraLU; + #[test] fn sdirk_state_mut() { - test_state_mut::(Sdirk::tr_bdf2()); + test_state_mut(test_problem::().tr_bdf2::().unwrap()); } #[test] fn sdirk_test_interpolate() { - test_interpolate::(Sdirk::tr_bdf2()); + test_interpolate(test_problem::().tr_bdf2::().unwrap()); } #[test] fn sdirk_test_checkpointing() { - let s1 = Sdirk::tr_bdf2(); - let s2 = Sdirk::tr_bdf2(); let (problem, soln) = exponential_decay_problem::(false); - test_checkpointing(s1, s2, problem, soln); + let s1 = problem.tr_bdf2::().unwrap(); + let s2 = problem.tr_bdf2::().unwrap(); + test_checkpointing(soln, s1, s2); } #[test] fn sdirk_test_state_mut_exponential_decay() { let (p, soln) = exponential_decay_problem::(false); - let s = Sdirk::tr_bdf2(); - test_state_mut_on_problem(s, p, soln); + let s = p.tr_bdf2::().unwrap(); + test_state_mut_on_problem(s, soln); } #[test] fn sdirk_test_nalgebra_negative_exponential_decay() { - let mut s = Sdirk::esdirk34(); let (problem, soln) = negative_exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.esdirk34::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[test] fn test_tr_bdf2_nalgebra_exponential_decay() { - let mut s = Sdirk::tr_bdf2(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.tr_bdf2::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 4 number_of_steps: 29 @@ -1151,7 +1137,7 @@ mod test { number_of_nonlinear_solver_iterations: 116 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 118 number_of_jac_muls: 2 number_of_matrix_evals: 1 @@ -1161,19 +1147,19 @@ mod test { #[test] fn test_tr_bdf2_nalgebra_exponential_decay_sens() { - let mut s = Sdirk::tr_bdf2_with_sensitivities(); let (problem, soln) = exponential_decay_problem_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false, true); + let mut s = problem.tr_bdf2_sens::().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - number_of_linear_solver_setups: 7 - number_of_steps: 52 - number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 520 + number_of_linear_solver_setups: 8 + number_of_steps: 53 + number_of_error_test_failures: 1 + number_of_nonlinear_solver_iterations: 540 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - number_of_calls: 210 - number_of_jac_muls: 318 + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 218 + number_of_jac_muls: 330 number_of_matrix_evals: 2 number_of_jac_adj_muls: 0 "###); @@ -1181,9 +1167,9 @@ mod test { #[test] fn test_esdirk34_nalgebra_exponential_decay() { - let mut s = Sdirk::esdirk34(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.esdirk34::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 3 number_of_steps: 13 @@ -1191,7 +1177,7 @@ mod test { number_of_nonlinear_solver_iterations: 84 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 86 number_of_jac_muls: 2 number_of_matrix_evals: 1 @@ -1201,19 +1187,19 @@ mod test { #[test] fn test_esdirk34_nalgebra_exponential_decay_sens() { - let mut s = Sdirk::esdirk34_with_sensitivities(); let (problem, soln) = exponential_decay_problem_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false, true); + let mut s = problem.esdirk34_sens::().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - number_of_linear_solver_setups: 5 + number_of_linear_solver_setups: 6 number_of_steps: 20 - number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 317 + number_of_error_test_failures: 1 + number_of_nonlinear_solver_iterations: 332 number_of_nonlinear_solver_fails: 0 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - number_of_calls: 122 - number_of_jac_muls: 201 + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 128 + number_of_jac_muls: 210 number_of_matrix_evals: 1 number_of_jac_adj_muls: 0 "###); @@ -1221,49 +1207,35 @@ mod test { #[test] fn sdirk_test_esdirk34_exponential_decay_adjoint() { - let s = Sdirk::esdirk34(); let (problem, soln) = exponential_decay_problem_adjoint::(); - let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); + let s = problem.esdirk34::().unwrap(); + test_ode_solver_adjoint::(s, soln); insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 196 number_of_jac_muls: 6 number_of_matrix_evals: 3 - number_of_jac_adj_muls: 599 - "###); - insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" - number_of_linear_solver_setups: 18 - number_of_steps: 29 - number_of_error_test_failures: 10 - number_of_nonlinear_solver_iterations: 595 - number_of_nonlinear_solver_fails: 0 + number_of_jac_adj_muls: 474 "###); } #[test] fn sdirk_test_esdirk34_exponential_decay_algebraic_adjoint() { - let s = Sdirk::esdirk34(); let (problem, soln) = exponential_decay_with_algebraic_adjoint_problem::(); - let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); + let s = problem.esdirk34::().unwrap(); + test_ode_solver_adjoint::(s, soln); insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 171 number_of_jac_muls: 12 number_of_matrix_evals: 4 - number_of_jac_adj_muls: 287 - "###); - insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" - number_of_linear_solver_setups: 18 - number_of_steps: 20 - number_of_error_test_failures: 11 - number_of_nonlinear_solver_iterations: 278 - number_of_nonlinear_solver_fails: 0 + number_of_jac_adj_muls: 191 "###); } #[test] fn test_tr_bdf2_nalgebra_robertson() { - let mut s = Sdirk::tr_bdf2(); let (problem, soln) = robertson::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.tr_bdf2::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 97 number_of_steps: 232 @@ -1271,7 +1243,7 @@ mod test { number_of_nonlinear_solver_iterations: 1921 number_of_nonlinear_solver_fails: 18 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 1924 number_of_jac_muls: 36 number_of_matrix_evals: 12 @@ -1281,29 +1253,29 @@ mod test { #[test] fn test_tr_bdf2_nalgebra_robertson_sens() { - let mut s = Sdirk::tr_bdf2_with_sensitivities(); let (problem, soln) = robertson_sens::(); - test_ode_solver(&mut s, &problem, soln, None, false, true); + let mut s = problem.tr_bdf2_sens::().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - number_of_linear_solver_setups: 112 - number_of_steps: 216 + number_of_linear_solver_setups: 109 + number_of_steps: 215 number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 4529 - number_of_nonlinear_solver_fails: 37 + number_of_nonlinear_solver_iterations: 4544 + number_of_nonlinear_solver_fails: 36 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - number_of_calls: 1420 - number_of_jac_muls: 3277 - number_of_matrix_evals: 27 + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + number_of_calls: 1443 + number_of_jac_muls: 3268 + number_of_matrix_evals: 28 number_of_jac_adj_muls: 0 "###); } #[test] fn test_esdirk34_nalgebra_robertson() { - let mut s = Sdirk::esdirk34(); let (problem, soln) = robertson::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.esdirk34::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 100 number_of_steps: 141 @@ -1311,7 +1283,7 @@ mod test { number_of_nonlinear_solver_iterations: 1793 number_of_nonlinear_solver_fails: 24 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 1796 number_of_jac_muls: 54 number_of_matrix_evals: 18 @@ -1321,9 +1293,9 @@ mod test { #[test] fn test_esdirk34_nalgebra_robertson_sens() { - let mut s = Sdirk::esdirk34_with_sensitivities(); let (problem, soln) = robertson_sens::(); - test_ode_solver(&mut s, &problem, soln, None, false, true); + let mut s = problem.esdirk34_sens::().unwrap(); + test_ode_solver(&mut s, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 114 number_of_steps: 131 @@ -1331,7 +1303,7 @@ mod test { number_of_nonlinear_solver_iterations: 4442 number_of_nonlinear_solver_fails: 44 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 1492 number_of_jac_muls: 3136 number_of_matrix_evals: 33 @@ -1341,9 +1313,9 @@ mod test { #[test] fn test_tr_bdf2_nalgebra_robertson_ode() { - let mut s = Sdirk::tr_bdf2(); let (problem, soln) = robertson_ode::(false, 1); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.tr_bdf2::().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" number_of_linear_solver_setups: 113 number_of_steps: 304 @@ -1351,7 +1323,7 @@ mod test { number_of_nonlinear_solver_iterations: 2601 number_of_nonlinear_solver_fails: 15 "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" number_of_calls: 2603 number_of_jac_muls: 39 number_of_matrix_evals: 13 @@ -1361,35 +1333,50 @@ mod test { #[test] fn test_tr_bdf2_faer_sparse_heat2d() { - let mut s = Sdirk::tr_bdf2(); let (problem, soln) = head2d_problem::, 10>(); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.tr_bdf2::>().unwrap(); + test_ode_solver(&mut s, soln, None, false, false); } #[test] fn test_tstop_tr_bdf2() { - let mut s = Sdirk::tr_bdf2(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, true); + let mut s = problem.tr_bdf2::().unwrap(); + test_ode_solver(&mut s, soln, None, true, false); } #[test] fn test_root_finder_tr_bdf2() { - let mut s = Sdirk::tr_bdf2(); let (problem, soln) = exponential_decay_problem_with_root::(false); - let y = test_ode_solver_no_sens(&mut s, &problem, soln, None, false); + let mut s = problem.tr_bdf2::().unwrap(); + let y = test_ode_solver(&mut s, soln, None, false, false); assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); } #[test] fn test_param_sweep_tr_bdf2() { - let s = Sdirk::tr_bdf2(); - let (problem, _soln) = exponential_decay_problem::(false); + let (mut problem, _soln) = exponential_decay_problem::(false); let mut ps = Vec::new(); for y0 in (1..10).map(f64::from) { ps.push(nalgebra::DVector::::from_vec(vec![0.1, y0])); } - test_param_sweep(s, problem, ps); + + let mut old_soln: Option> = None; + for p in ps { + problem.eqn_mut().set_params(&p); + let mut s = problem.tr_bdf2::().unwrap(); + let (ys, _ts) = s.solve(10.0).unwrap(); + // check that the new solution is different from the old one + if let Some(old_soln) = &mut old_soln { + let new_soln = ys.column(ys.ncols() - 1).into_owned(); + let error = new_soln - &*old_soln; + let diff = error + .squared_norm(old_soln, &problem.atol, problem.rtol) + .sqrt(); + assert!(diff > 1.0e-6, "diff: {}", diff); + } + old_soln = Some(ys.column(ys.ncols() - 1).into_owned()); + } } #[cfg(feature = "diffsl")] @@ -1397,9 +1384,11 @@ mod test { fn test_ball_bounce_tr_bdf2() { type M = nalgebra::DMatrix; type LS = crate::NalgebraLU; - type Eqn = crate::DiffSl; - let s = Sdirk::::tr_bdf2(); - let (x, v, t) = crate::ode_solver::tests::test_ball_bounce(s); + let (x, v, t) = crate::ode_solver::tests::test_ball_bounce( + crate::ode_solver::tests::test_ball_bounce_problem::() + .tr_bdf2::() + .unwrap(), + ); let expected_x = [6.375884661615263]; let expected_v = [0.6878538646461059]; let expected_t = [2.5]; diff --git a/src/ode_solver/sens_equations.rs b/src/ode_solver/sens_equations.rs index ea7ced62..8229e575 100644 --- a/src/ode_solver/sens_equations.rs +++ b/src/ode_solver/sens_equations.rs @@ -1,5 +1,5 @@ use num_traits::Zero; -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use crate::{ op::nonlinear_op::NonLinearOpJacobian, AugmentedOdeEquations, ConstantOp, ConstantOpSens, @@ -7,26 +7,26 @@ use crate::{ OdeSolverProblem, Op, Vector, }; -pub struct SensInit +pub struct SensInit<'a, Eqn> where Eqn: OdeEquationsSens, { - eqn: Rc, + eqn: &'a Eqn, init_sens: Eqn::M, index: usize, } -impl SensInit +impl<'a, Eqn> SensInit<'a, Eqn> where Eqn: OdeEquationsSens, { - pub fn new(eqn: &Rc) -> Self { + pub fn new(eqn: &'a Eqn) -> Self { let nstates = eqn.rhs().nstates(); let nparams = eqn.rhs().nparams(); let init_sens = Eqn::M::new_from_sparsity(nstates, nparams, eqn.init().sens_sparsity()); let index = 0; Self { - eqn: eqn.clone(), + eqn, init_sens, index, } @@ -39,7 +39,7 @@ where } } -impl Op for SensInit +impl Op for SensInit<'_, Eqn> where Eqn: OdeEquationsSens, { @@ -58,7 +58,7 @@ where } } -impl ConstantOp for SensInit +impl ConstantOp for SensInit<'_, Eqn> where Eqn: OdeEquationsSens, { @@ -81,24 +81,24 @@ where /// Strategy is to pre-compute S = f_p from the state at given time step and store it in a matrix using [Self::update_state]. /// Then the ith column of function F(s, t) is evaluated as J * s_i + S_i, where s_i is the ith column of the sensitivity matrix /// and S_i is the ith column of the matrix S. The column to evaluate is set using [Self::set_param_index]. -pub struct SensRhs +pub struct SensRhs<'a, Eqn> where Eqn: OdeEquations, { - eqn: Rc, + eqn: &'a Eqn, sens: RefCell, y: RefCell, index: RefCell, } -impl SensRhs +impl<'a, Eqn> SensRhs<'a, Eqn> where Eqn: OdeEquationsSens, { - pub fn new(eqn: &Rc, allocate: bool) -> Self { + pub fn new(eqn: &'a Eqn, allocate: bool) -> Self { if !allocate { return Self { - eqn: eqn.clone(), + eqn, sens: RefCell::new(::zeros(0, 0)), y: RefCell::new(::zeros(0)), index: RefCell::new(0), @@ -114,7 +114,7 @@ where let y = RefCell::new(::zeros(nstates)); let index = RefCell::new(0); Self { - eqn: eqn.clone(), + eqn, sens: RefCell::new(rhs_sens), y, index, @@ -133,7 +133,7 @@ where } } -impl Op for SensRhs +impl Op for SensRhs<'_, Eqn> where Eqn: OdeEquationsSens, { @@ -152,7 +152,7 @@ where } } -impl NonLinearOp for SensRhs +impl NonLinearOp for SensRhs<'_, Eqn> where Eqn: OdeEquationsSens, { @@ -166,7 +166,7 @@ where } } -impl NonLinearOpJacobian for SensRhs +impl NonLinearOpJacobian for SensRhs<'_, Eqn> where Eqn: OdeEquationsSens, { @@ -194,18 +194,33 @@ where /// f_p is the partial derivative of the right-hand side with respect to the parameters /// dy(0)/dp is the partial derivative of the state at the initial time wrt the parameters /// -pub struct SensEquations +pub struct SensEquations<'a, Eqn> where Eqn: OdeEquationsSens, { - eqn: Rc, - rhs: Rc>, - init: Rc>, - atol: Option>, + eqn: &'a Eqn, + rhs: SensRhs<'a, Eqn>, + init: SensInit<'a, Eqn>, + atol: Option<&'a Eqn::V>, rtol: Option, } -impl std::fmt::Debug for SensEquations +impl Clone for SensEquations<'_, Eqn> +where + Eqn: OdeEquationsSens, +{ + fn clone(&self) -> Self { + Self { + eqn: self.eqn, + rhs: SensRhs::new(self.eqn, false), + init: SensInit::new(self.eqn), + rtol: self.rtol, + atol: self.atol, + } + } +} + +impl std::fmt::Debug for SensEquations<'_, Eqn> where Eqn: OdeEquationsSens, { @@ -214,27 +229,27 @@ where } } -impl SensEquations +impl<'a, Eqn> SensEquations<'a, Eqn> where Eqn: OdeEquationsSens, { - pub(crate) fn new(problem: &OdeSolverProblem) -> Self { + pub(crate) fn new(problem: &'a OdeSolverProblem) -> Self { let eqn = &problem.eqn; let rtol = problem.sens_rtol; - let atol = problem.sens_atol.clone(); - let rhs = Rc::new(SensRhs::new(eqn, true)); - let init = Rc::new(SensInit::new(eqn)); + let atol = problem.sens_atol.as_ref(); + let rhs = SensRhs::new(eqn, true); + let init = SensInit::new(eqn); Self { rhs, init, - eqn: eqn.clone(), + eqn, rtol, atol, } } } -impl Op for SensEquations +impl Op for SensEquations<'_, Eqn> where Eqn: OdeEquationsSens, { @@ -253,22 +268,22 @@ where } } -impl<'a, Eqn> OdeEquationsRef<'a> for SensEquations +impl<'a, 'b, Eqn> OdeEquationsRef<'a> for SensEquations<'b, Eqn> where Eqn: OdeEquationsSens, { - type Rhs = &'a SensRhs; + type Rhs = &'a SensRhs<'b, Eqn>; type Mass = >::Mass; type Root = >::Root; - type Init = &'a SensInit; + type Init = &'a SensInit<'b, Eqn>; type Out = >::Out; } -impl OdeEquations for SensEquations +impl<'a, Eqn> OdeEquations for SensEquations<'a, Eqn> where Eqn: OdeEquationsSens, { - fn rhs(&self) -> &SensRhs { + fn rhs(&self) -> &SensRhs<'a, Eqn> { &self.rhs } fn mass(&self) -> Option<>::Mass> { @@ -277,15 +292,18 @@ where fn root(&self) -> Option<>::Root> { None } - fn init(&self) -> &SensInit { + fn init(&self) -> &SensInit<'a, Eqn> { &self.init } fn out(&self) -> Option<>::Out> { None } + fn set_params(&mut self, p: &Self::V) { + self.eqn.set_params(p); + } } -impl AugmentedOdeEquations for SensEquations { +impl AugmentedOdeEquations for SensEquations<'_, Eqn> { fn include_in_error_control(&self) -> bool { self.rtol.is_some() && self.atol.is_some() } @@ -295,10 +313,10 @@ impl AugmentedOdeEquations for SensEquations { fn rtol(&self) -> Option { self.rtol } - fn atol(&self) -> Option<&Rc> { - self.atol.as_ref() + fn atol(&self) -> Option<&Eqn::V> { + self.atol } - fn out_atol(&self) -> Option<&Rc> { + fn out_atol(&self) -> Option<&Eqn::V> { None } fn out_rtol(&self) -> Option { @@ -309,14 +327,17 @@ impl AugmentedOdeEquations for SensEquations { self.nparams() } fn update_rhs_out_state(&mut self, y: &Eqn::V, dy: &Eqn::V, t: Eqn::T) { - Rc::get_mut(&mut self.rhs).unwrap().update_state(y, dy, t); + self.rhs.update_state(y, dy, t); } fn update_init_state(&mut self, t: Eqn::T) { - Rc::get_mut(&mut self.init).unwrap().update_state(t); + self.init.update_state(t); } fn set_index(&mut self, index: usize) { - Rc::get_mut(&mut self.rhs).unwrap().set_param_index(index); - Rc::get_mut(&mut self.init).unwrap().set_param_index(index); + self.rhs.set_param_index(index); + self.init.set_param_index(index); + } + fn integrate_main_eqn(&self) -> bool { + true } } diff --git a/src/ode_solver/state.rs b/src/ode_solver/state.rs index 8801439f..b7c4b660 100644 --- a/src/ode_solver/state.rs +++ b/src/ode_solver/state.rs @@ -1,17 +1,14 @@ use nalgebra::ComplexField; use num_traits::{One, Pow, Zero}; -use std::rc::Rc; use crate::{ error::{DiffsolError, OdeSolverError}, - nonlinear_solver::NonLinearSolver, + nonlinear_solver::{convergence::Convergence, NonLinearSolver}, ode_solver_error, scale, AugmentedOdeEquations, AugmentedOdeEquationsImplicit, ConstantOp, - DefaultSolver, InitOp, NewtonNonlinearSolver, NonLinearOp, OdeEquations, OdeEquationsImplicit, - OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, Op, SensEquations, Vector, + InitOp, LinearSolver, NewtonNonlinearSolver, NonLinearOp, OdeEquations, OdeEquationsImplicit, + OdeEquationsSens, OdeSolverProblem, Op, SensEquations, Vector, }; -use super::method::SensitivitiesOdeSolverMethod; - /// A state holding those variables that are common to all ODE solver states, /// can be used to create a new state for a specific solver. pub struct StateCommon { @@ -144,54 +141,100 @@ pub trait OdeSolverState: Clone + Sized { /// It will also set the initial step size based on the given solver. /// If you want to create a state without this default initialisation, use [Self::new_without_initialise] instead. /// You can then use [Self::set_consistent] and [Self::set_step_size] to set the state up if you need to. - fn new(ode_problem: &OdeSolverProblem, solver: &S) -> Result + fn new( + ode_problem: &OdeSolverProblem, + solver_order: usize, + ) -> Result where Eqn: OdeEquationsImplicit, - Eqn::M: DefaultSolver, - S: OdeSolverMethod, + LS: LinearSolver, { let mut ret = Self::new_without_initialise(ode_problem)?; - let mut root_solver = - NewtonNonlinearSolver::new(::default_solver()); + let mut root_solver = NewtonNonlinearSolver::new(LS::default()); ret.set_consistent(ode_problem, &mut root_solver)?; - ret.set_step_size(ode_problem, solver.order()); + ret.set_step_size(ode_problem, solver_order); Ok(ret) } - fn new_with_sensitivities( + fn new_with_sensitivities( ode_problem: &OdeSolverProblem, - solver: &S, + solver_order: usize, ) -> Result where Eqn: OdeEquationsSens, - Eqn::M: DefaultSolver, - S: SensitivitiesOdeSolverMethod, + LS: LinearSolver, { - let augmented_eqn = SensEquations::new(ode_problem); - Self::new_with_augmented(ode_problem, augmented_eqn, solver).map(|(state, _)| state) + let mut augmented_eqn = SensEquations::new(ode_problem); + Self::new_with_augmented::(ode_problem, &mut augmented_eqn, solver_order) } - fn new_with_augmented( + fn new_with_augmented( ode_problem: &OdeSolverProblem, - mut augmented_eqn: AugmentedEqn, - solver: &S, - ) -> Result<(Self, AugmentedEqn), DiffsolError> + augmented_eqn: &mut AugmentedEqn, + solver_order: usize, + ) -> Result where Eqn: OdeEquationsImplicit, AugmentedEqn: AugmentedOdeEquationsImplicit + std::fmt::Debug, - Eqn::M: DefaultSolver, - S: OdeSolverMethod, + LS: LinearSolver, { - let mut ret = Self::new_without_initialise_augmented(ode_problem, &mut augmented_eqn)?; - let mut root_solver = - NewtonNonlinearSolver::new(::default_solver()); + let mut ret = Self::new_without_initialise_augmented(ode_problem, augmented_eqn)?; + let mut root_solver = NewtonNonlinearSolver::new(LS::default()); ret.set_consistent(ode_problem, &mut root_solver)?; - let mut root_solver_sens = - NewtonNonlinearSolver::new(::default_solver()); - let augmented_eqn = - ret.set_consistent_augmented(ode_problem, augmented_eqn, &mut root_solver_sens)?; - ret.set_step_size(ode_problem, solver.order()); - Ok((ret, augmented_eqn)) + let mut root_solver_sens = NewtonNonlinearSolver::new(LS::default()); + ret.set_consistent_augmented(ode_problem, augmented_eqn, &mut root_solver_sens)?; + ret.set_step_size(ode_problem, solver_order); + Ok(ret) + } + + fn into_adjoint( + self, + ode_problem: &OdeSolverProblem, + augmented_eqn: &mut AugmentedEqn, + ) -> Result + where + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquationsImplicit + std::fmt::Debug, + LS: LinearSolver, + { + let mut state = self.into_common(); + state.h = -state.h; + let naug = augmented_eqn.max_index(); + let mut s = Vec::with_capacity(naug); + let mut ds = Vec::with_capacity(naug); + let nstates = augmented_eqn.rhs().nstates(); + for i in 0..naug { + augmented_eqn.set_index(i); + let si = augmented_eqn.init().call(state.t); + let dsi = V::zeros(nstates); + s.push(si); + ds.push(dsi); + } + state.s = s; + state.ds = ds; + let (dsg, sg) = if augmented_eqn.out().is_some() { + let mut sg = Vec::with_capacity(naug); + let mut dsg = Vec::with_capacity(naug); + for i in 0..naug { + augmented_eqn.set_index(i); + let out = augmented_eqn + .out() + .ok_or(ode_solver_error!(StateProblemMismatch))?; + let dsgi = out.call(&state.s[i], state.t); + let sgi = V::zeros(out.nout()); + sg.push(sgi); + dsg.push(dsgi); + } + (dsg, sg) + } else { + (vec![], vec![]) + }; + state.sg = sg; + state.dsg = dsg; + let mut state = Self::new_from_common(state); + let mut root_solver_sens = NewtonNonlinearSolver::new(LS::default()); + state.set_consistent_augmented(ode_problem, augmented_eqn, &mut root_solver_sens)?; + Ok(state) } /// Create a new solver state from an ODE problem, without any initialisation apart from setting the initial time state vector y, @@ -298,13 +341,14 @@ pub trait OdeSolverState: Clone + Sized { } let f = InitOp::new(&ode_problem.eqn, ode_problem.t0, state.y); let rtol = ode_problem.rtol; - let atol = ode_problem.atol.clone(); - root_solver.set_problem(&f, rtol, atol); + let atol = &ode_problem.atol; + root_solver.set_problem(&f); let mut y_tmp = state.dy.clone(); y_tmp.copy_from_indices(state.y, &f.algebraic_indices); let yerr = y_tmp.clone(); root_solver.reset_jacobian(&f, &y_tmp, *state.t); - root_solver.solve_in_place(&f, &mut y_tmp, *state.t, &yerr)?; + let mut convergence = Convergence::new(rtol, atol); + root_solver.solve_in_place(&f, &mut y_tmp, *state.t, &yerr, &mut convergence)?; f.scatter_soln(&y_tmp, state.y, state.dy); Ok(()) } @@ -315,9 +359,9 @@ pub trait OdeSolverState: Clone + Sized { fn set_consistent_augmented( &mut self, ode_problem: &OdeSolverProblem, - mut augmented_eqn: AugmentedEqn, + augmented_eqn: &mut AugmentedEqn, root_solver: &mut S, - ) -> Result + ) -> Result<(), DiffsolError> where Eqn: OdeEquationsImplicit, AugmentedEqn: AugmentedOdeEquationsImplicit + std::fmt::Debug, @@ -334,24 +378,23 @@ pub trait OdeSolverState: Clone + Sized { } if ode_problem.eqn.mass().is_none() { - return Ok(augmented_eqn); + return Ok(()); } - let mut augmented_eqn_rc = Rc::new(augmented_eqn); - + let mut convergence = Convergence::new(ode_problem.rtol, &ode_problem.atol); for i in 0..naug { - Rc::get_mut(&mut augmented_eqn_rc).unwrap().set_index(i); - let f = InitOp::new(&augmented_eqn_rc, ode_problem.t0, &state.s[i]); - root_solver.set_problem(&f, ode_problem.rtol, ode_problem.atol.clone()); + augmented_eqn.set_index(i); + let f = InitOp::new(augmented_eqn, *state.t, &state.s[i]); + root_solver.set_problem(&f); let mut y = state.ds[i].clone(); y.copy_from_indices(state.y, &f.algebraic_indices); let yerr = y.clone(); root_solver.reset_jacobian(&f, &y, *state.t); - root_solver.solve_in_place(&f, &mut y, *state.t, &yerr)?; + root_solver.solve_in_place(&f, &mut y, *state.t, &yerr, &mut convergence)?; f.scatter_soln(&y, &mut state.s[i], &mut state.ds[i]); } - Ok(Rc::try_unwrap(augmented_eqn_rc).unwrap()) + Ok(()) } /// compute size of first step based on alg in Hairer, Norsett, Wanner @@ -371,7 +414,7 @@ pub trait OdeSolverState: Clone + Sized { let f0 = state.dy; let rtol = ode_problem.rtol; - let atol = ode_problem.atol.as_ref(); + let atol = &ode_problem.atol; let d0 = y0.squared_norm(y0, atol, rtol).sqrt(); let d1 = f0.squared_norm(y0, atol, rtol).sqrt(); diff --git a/src/ode_solver/sundials.rs b/src/ode_solver/sundials.rs deleted file mode 100644 index b4eea999..00000000 --- a/src/ode_solver/sundials.rs +++ /dev/null @@ -1,571 +0,0 @@ -use crate::{ - sundials_sys::{ - realtype, IDACalcIC, IDACreate, IDAFree, IDAGetDky, IDAGetIntegratorStats, - IDAGetNonlinSolvStats, IDAGetReturnFlagName, IDAInit, IDAReInit, IDASVtolerances, IDASetId, - IDASetJacFn, IDASetLinearSolver, IDASetStopTime, IDASetUserData, IDASolve, N_Vector, - SUNLinSolFree, SUNLinSolInitialize, SUNLinSol_Dense, SUNLinearSolver, SUNMatrix, - IDA_CONSTR_FAIL, IDA_CONV_FAIL, IDA_ERR_FAIL, IDA_ILL_INPUT, IDA_LINIT_FAIL, - IDA_LSETUP_FAIL, IDA_LSOLVE_FAIL, IDA_MEM_NULL, IDA_ONE_STEP, IDA_REP_RES_ERR, - IDA_RES_FAIL, IDA_ROOT_RETURN, IDA_RTFUNC_FAIL, IDA_SUCCESS, IDA_TOO_MUCH_ACC, - IDA_TOO_MUCH_WORK, IDA_TSTOP_RETURN, IDA_YA_YDP_INIT, - }, - SdirkState, StateRef, StateRefMut, -}; -use num_traits::Zero; -use serde::Serialize; -use std::{ - ffi::{c_int, c_long, c_void, CStr}, - rc::Rc, -}; - -use crate::{ - error::*, ode_solver_error, scale, LinearOp, Matrix, NonLinearOp, NonLinearOpJacobian, - OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, - Op, SundialsMatrix, SundialsVector, Vector, -}; - -#[cfg(not(sundials_version_major = "5"))] -use crate::vector::sundials::get_suncontext; - -pub fn sundials_check(retval: c_int) -> Result<(), DiffsolError> { - if retval < 0 { - let char_ptr = unsafe { IDAGetReturnFlagName(i64::from(retval)) }; - let c_str = unsafe { CStr::from_ptr(char_ptr) }; - Err(DiffsolError::from(OdeSolverError::SundialsError( - c_str.to_str().unwrap().to_string(), - ))) - } else { - Ok(()) - } -} - -#[derive(Clone, Debug, Serialize)] -pub struct SundialsStatistics { - pub number_of_linear_solver_setups: usize, - pub number_of_steps: usize, - pub number_of_error_test_failures: usize, - pub number_of_nonlinear_solver_iterations: usize, - pub number_of_nonlinear_solver_fails: usize, -} - -impl SundialsStatistics { - fn new() -> Self { - Self { - number_of_linear_solver_setups: 0, - number_of_steps: 0, - number_of_error_test_failures: 0, - number_of_nonlinear_solver_iterations: 0, - number_of_nonlinear_solver_fails: 0, - } - } - fn new_from_ida(ida_mem: *mut c_void) -> Result { - let mut nsteps: c_long = 0; - let mut nrevals: c_long = 0; - let mut nlinsetups: c_long = 0; - let mut netfails: c_long = 0; - let mut klast: c_int = 0; - let mut kcur: c_int = 0; - let mut hinused: realtype = 0.; - let mut hlast: realtype = 0.; - let mut hcur: realtype = 0.; - let mut tcur: realtype = 0.; - - sundials_check(unsafe { - IDAGetIntegratorStats( - ida_mem, - &mut nsteps, - &mut nrevals, - &mut nlinsetups, - &mut netfails, - &mut klast, - &mut kcur, - &mut hinused, - &mut hlast, - &mut hcur, - &mut tcur, - ) - })?; - - let mut nniters: c_long = 0; - let mut nncfails: c_long = 0; - sundials_check(unsafe { IDAGetNonlinSolvStats(ida_mem, &mut nniters, &mut nncfails) })?; - - Ok(Self { - number_of_linear_solver_setups: nlinsetups.try_into().unwrap(), - number_of_steps: nsteps.try_into().unwrap(), - number_of_error_test_failures: netfails.try_into().unwrap(), - number_of_nonlinear_solver_iterations: nniters.try_into().unwrap(), - number_of_nonlinear_solver_fails: nncfails.try_into().unwrap(), - }) - } -} - -struct SundialsData -where - Eqn: OdeEquationsImplicit, -{ - eqn: Rc, - rhs_jac: SundialsMatrix, - mass: SundialsMatrix, -} - -impl SundialsData -where - Eqn: OdeEquationsImplicit, -{ - fn new(eqn: Rc) -> Self { - let n = eqn.rhs().nstates(); - let rhs_jac_sparsity = eqn.rhs().jacobian_sparsity(); - let rhs_jac = SundialsMatrix::new_from_sparsity(n, n, rhs_jac_sparsity); - let mass = if let Some(mass) = eqn.mass() { - let mass_sparsity = mass.sparsity(); - SundialsMatrix::new_from_sparsity(n, n, mass_sparsity) - } else { - let ones = SundialsVector::from_element(n, 1.0); - SundialsMatrix::from_diagonal(&ones) - }; - Self { rhs_jac, mass, eqn } - } -} - -pub struct SundialsIda -where - Eqn: OdeEquationsImplicit, -{ - ida_mem: *mut c_void, - linear_solver: SUNLinearSolver, - data: Option>, - problem: Option>, - yp: SundialsVector, - jacobian: SundialsMatrix, - statistics: SundialsStatistics, - state: Option>, - is_state_modified: bool, -} - -impl SundialsIda -where - Eqn: OdeEquationsImplicit, -{ - extern "C" fn residual( - t: realtype, - y: N_Vector, - yp: N_Vector, - rr: N_Vector, - user_data: *mut c_void, - ) -> i32 { - let data = unsafe { &*(user_data as *const SundialsData) }; - let y = SundialsVector::new_not_owned(y); - let yp = SundialsVector::new_not_owned(yp); - let mut rr = SundialsVector::new_not_owned(rr); - // F(t, y, y') = M y' - f(t, y) - // rr = f(t, y) - data.eqn.rhs().call_inplace(&y, t, &mut rr); - // rr = M y' - rr - if let Some(mass) = data.eqn.mass() { - mass.gemv_inplace(&yp, t, -1.0, &mut rr); - } else { - rr.axpy(1.0, &yp, -1.0); - } - 0 - } - - extern "C" fn jacobian( - t: realtype, - c_j: realtype, - y: N_Vector, - _yp: N_Vector, - _r: N_Vector, - jac: SUNMatrix, - user_data: *mut c_void, - _tmp1: N_Vector, - _tmp2: N_Vector, - _tmp3: N_Vector, - ) -> i32 { - let data = unsafe { &mut *(user_data as *mut SundialsData) }; - let eqn = &data.eqn; - - // jac = c_j * M - rhs_jac - let y = SundialsVector::new_not_owned(y); - let mut jac = SundialsMatrix::new_not_owned(jac); - if let Some(mass) = eqn.mass() { - mass.matrix_inplace(t, &mut data.mass); - } - eqn.rhs().jacobian_inplace(&y, t, &mut data.rhs_jac); - data.rhs_jac *= scale(-1.0); - jac.scale_add_and_assign(&data.rhs_jac, c_j, &data.mass); - 0 - } - - fn check(retval: c_int) -> Result<(), DiffsolError> { - sundials_check(retval) - } - - pub fn new() -> Self { - #[cfg(not(sundials_version_major = "5"))] - let ida_mem = unsafe { IDACreate(*get_suncontext()) }; - - #[cfg(sundials_version_major = "5")] - let ida_mem = unsafe { IDACreate() }; - - let yp = SundialsVector::new_serial(0); - let jacobian = SundialsMatrix::new_dense(0, 0); - - Self { - ida_mem, - data: None, - problem: None, - yp, - linear_solver: std::ptr::null_mut(), - statistics: SundialsStatistics::new(), - jacobian, - state: None, - is_state_modified: false, - } - } - - pub fn get_statistics(&self) -> &SundialsStatistics { - &self.statistics - } - - pub fn calc_ic(&mut self, t: realtype) -> Result<(), DiffsolError> { - if self.problem.is_none() { - return Err(ode_solver_error!(ProblemNotSet)); - } - if self.problem.as_ref().unwrap().eqn.mass().is_none() { - return Ok(()); - } - let diag = self - .problem - .as_ref() - .unwrap() - .eqn - .mass() - .unwrap() - .matrix(t) - .diagonal(); - let id = diag.filter_indices(|x| x == Eqn::T::zero()); - let number_of_states = self.problem.as_ref().unwrap().eqn.rhs().nstates(); - // need to convert to realtype sundials vector - let mut id_realtype = SundialsVector::new_serial(number_of_states); - for i in 0..number_of_states { - match id[i] { - 1 => id_realtype[i] = 1.0, - _ => id_realtype[i] = 0.0, - } - } - Self::check(unsafe { IDASetId(self.ida_mem, id_realtype.sundials_vector()) })?; - Self::check(unsafe { IDACalcIC(self.ida_mem, IDA_YA_YDP_INIT, t) })?; - Ok(()) - } -} - -impl Default for SundialsIda -where - Eqn: OdeEquationsImplicit, -{ - fn default() -> Self { - Self::new() - } -} - -impl Drop for SundialsIda -where - Eqn: OdeEquationsImplicit, -{ - fn drop(&mut self) { - if !self.linear_solver.is_null() { - unsafe { SUNLinSolFree(self.linear_solver) }; - } - unsafe { IDAFree(&mut self.ida_mem) }; - } -} - -impl OdeSolverMethod for SundialsIda -where - Eqn: OdeEquationsImplicit, -{ - type State = SdirkState; - - fn checkpoint(&mut self) -> Result { - self.state - .as_ref() - .cloned() - .ok_or(ode_solver_error!(StateNotSet)) - } - - fn problem(&self) -> Option<&OdeSolverProblem> { - self.problem.as_ref() - } - - fn state(&self) -> Option> { - self.state.as_ref().map(|s| s.as_ref()) - } - - fn order(&self) -> usize { - 1 - } - - fn state_mut(&mut self) -> Option> { - self.is_state_modified = true; - self.state.as_mut().map(|s| s.as_mut()) - } - - fn take_state(&mut self) -> Option { - Option::take(&mut self.state) - } - - fn set_problem( - &mut self, - state: Self::State, - problem: &OdeSolverProblem, - ) -> Result<(), DiffsolError> { - self.state = Some(state); - let state = self.state.as_ref().unwrap(); - self.problem = Some(problem.clone()); - let eqn = problem.eqn.as_ref(); - let number_of_states = eqn.rhs().nstates(); - let ida_mem = self.ida_mem; - - // set user data - self.data = Some(SundialsData::new(problem.eqn.clone())); - Self::check(unsafe { IDASetUserData(self.ida_mem, &self.data as *const _ as *mut c_void) }) - .unwrap(); - - // initialize - self.yp = ::zeros(number_of_states); - Self::check(unsafe { - IDAInit( - ida_mem, - Some(Self::residual), - state.t, - state.y.sundials_vector(), - self.yp.sundials_vector(), - ) - }) - .unwrap(); - - // tolerances - let rtol = problem.rtol; - let atol = problem.atol.as_ref(); - Self::check(unsafe { IDASVtolerances(ida_mem, rtol, atol.sundials_vector()) }).unwrap(); - - // linear solver - self.jacobian = SundialsMatrix::new_dense(number_of_states, number_of_states); - - self.linear_solver = unsafe { - #[cfg(not(sundials_version_major = "5"))] - { - SUNLinSol_Dense( - state.y.sundials_vector(), - self.jacobian.sundials_matrix(), - *get_suncontext(), - ) - } - #[cfg(sundials_version_major = "5")] - { - SUNLinSol_Dense(state.y.sundials_vector(), self.jacobian.sundials_matrix()) - } - }; - - Self::check(unsafe { SUNLinSolInitialize(self.linear_solver) }).unwrap(); - Self::check(unsafe { - IDASetLinearSolver(ida_mem, self.linear_solver, self.jacobian.sundials_matrix()) - }) - .unwrap(); - - // set jacobian function - Self::check(unsafe { IDASetJacFn(ida_mem, Some(Self::jacobian)) }).unwrap(); - - Ok(()) - } - - fn set_stop_time(&mut self, tstop: Eqn::T) -> Result<(), DiffsolError> { - Self::check(unsafe { IDASetStopTime(self.ida_mem, tstop) }) - } - - fn step(&mut self) -> Result, DiffsolError> { - let state = self.state.as_mut().ok_or(ode_solver_error!(StateNotSet))?; - if self.problem.is_none() { - return Err(ode_solver_error!(ProblemNotSet)); - } - if self.is_state_modified { - // reinit as state has been modified - Self::check(unsafe { - IDAReInit( - self.ida_mem, - state.t, - state.y.sundials_vector(), - self.yp.sundials_vector(), - ) - })? - } - let itask = IDA_ONE_STEP; - let retval = unsafe { - IDASolve( - self.ida_mem, - state.t + 1.0, - &mut state.t as *mut realtype, - state.y.sundials_vector(), - self.yp.sundials_vector(), - itask, - ) - }; - - // update stats - self.statistics = SundialsStatistics::new_from_ida(self.ida_mem).unwrap(); - - // check return value - match retval { - IDA_SUCCESS => Ok(OdeSolverStopReason::InternalTimestep), - IDA_TSTOP_RETURN => Ok(OdeSolverStopReason::TstopReached), - IDA_ROOT_RETURN => Ok(OdeSolverStopReason::RootFound(state.t)), - IDA_MEM_NULL => Err(ode_solver_error!(SundialsError, "The ida_mem argument was NULL.")), - IDA_ILL_INPUT => Err(ode_solver_error!(SundialsError, "One of the inputs to IDASolve() was illegal, or some other input to the solver was either illegal or missing.")), - IDA_TOO_MUCH_WORK => Err(ode_solver_error!(SundialsError, "The solver took mxstep internal steps but could not reach tout.")), - IDA_TOO_MUCH_ACC => Err(ode_solver_error!(SundialsError, "The solver could not satisfy the accuracy demanded by the user for some internal step.")), - IDA_ERR_FAIL => Err(ode_solver_error!(SundialsError, "Error test failures occurred too many times (MXNEF = 10) during one internal time step or occurred with.")), - IDA_CONV_FAIL => Err(ode_solver_error!(SundialsError, "Convergence test failures occurred too many times (MXNCF = 10) during one internal time step or occurred with.")), - IDA_LINIT_FAIL => Err(ode_solver_error!(SundialsError, "The linear solver’s initialization function failed.")), - IDA_LSETUP_FAIL => Err(ode_solver_error!(SundialsError, "The linear solver’s setup function failed in an unrecoverable manner.")), - IDA_LSOLVE_FAIL => Err(ode_solver_error!(SundialsError, "The linear solver’s solve function failed in an unrecoverable manner.")), - IDA_CONSTR_FAIL => Err(ode_solver_error!(SundialsError, "The inequality constraints were violated and the solver was unable to recover.")), - IDA_REP_RES_ERR => Err(ode_solver_error!(SundialsError, "The user’s residual function repeatedly returned a recoverable error flag, but the solver was unable to recover.")), - IDA_RES_FAIL => Err(ode_solver_error!(SundialsError, "The user’s residual function returned a nonrecoverable error flag.")), - IDA_RTFUNC_FAIL => Err(ode_solver_error!(SundialsError, "The rootfinding function failed.")), - _ => Err(ode_solver_error!(SundialsError, "Unknown error")), - } - } - - fn interpolate(&self, t: ::T) -> Result { - if self.data.is_none() { - return Err(ode_solver_error!(ProblemNotSet)); - } - let state = self.state.as_ref().ok_or(ode_solver_error!(StateNotSet))?; - if t > state.t { - return Err(ode_solver_error!(InterpolationTimeGreaterThanCurrentTime)); - } - let ret = SundialsVector::new_serial(self.data.as_ref().unwrap().eqn.rhs().nstates()); - Self::check(unsafe { IDAGetDky(self.ida_mem, t, 0, ret.sundials_vector()) }).unwrap(); - Ok(ret) - } - - fn interpolate_out(&self, _t: Eqn::T) -> Result { - unimplemented!() - } - - fn interpolate_sens(&self, _t: Eqn::T) -> Result, DiffsolError> { - unimplemented!() - } -} - -#[cfg(test)] -mod test { - - use crate::{ - ode_solver::{ - test_models::{ - exponential_decay::exponential_decay_problem, foodweb::foodweb_problem, - heat2d::head2d_problem, robertson::robertson, - }, - tests::{ - test_interpolate, test_no_set_problem, test_ode_solver_no_sens, test_state_mut, - }, - }, - OdeEquations, Op, SundialsIda, SundialsMatrix, - }; - - type M = SundialsMatrix; - #[test] - fn sundials_no_set_problem() { - test_no_set_problem::(SundialsIda::default()) - } - #[test] - fn sundials_state_mut() { - test_state_mut::(SundialsIda::default()) - } - #[test] - fn sundials_interpolate() { - test_interpolate::(SundialsIda::default()) - } - - #[test] - fn test_sundials_exponential_decay() { - let mut s = crate::SundialsIda::default(); - let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 18 - number_of_steps: 43 - number_of_error_test_failures: 3 - number_of_nonlinear_solver_iterations: 63 - number_of_nonlinear_solver_fails: 0 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 65 - number_of_jac_muls: 36 - number_of_matrix_evals: 18 - number_of_jac_adj_muls: 0 - "###); - } - - #[test] - fn test_sundials_robertson() { - let mut s = crate::SundialsIda::default(); - let (problem, soln) = robertson::(false); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 59 - number_of_steps: 355 - number_of_error_test_failures: 15 - number_of_nonlinear_solver_iterations: 506 - number_of_nonlinear_solver_fails: 1 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 509 - number_of_jac_muls: 180 - number_of_matrix_evals: 60 - number_of_jac_adj_muls: 0 - "###); - } - - #[test] - fn test_sundials_foodweb() { - let mut s = crate::SundialsIda::default(); - let (problem, soln) = foodweb_problem::(); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 42 - number_of_steps: 256 - number_of_error_test_failures: 9 - number_of_nonlinear_solver_iterations: 458 - number_of_nonlinear_solver_fails: 1 - "###); - } - #[test] - fn test_sundials_heat2d() { - let mut s = crate::SundialsIda::default(); - let (problem, soln) = head2d_problem::(); - test_ode_solver_no_sens(&mut s, &problem, soln, None, false); - insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- - number_of_linear_solver_setups: 42 - number_of_steps: 165 - number_of_error_test_failures: 11 - number_of_nonlinear_solver_iterations: 214 - number_of_nonlinear_solver_fails: 0 - "###); - insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- - number_of_calls: 217 - number_of_jac_muls: 4300 - number_of_matrix_evals: 43 - number_of_jac_adj_muls: 0 - "###); - } -} diff --git a/src/ode_solver/test_models/dydt_y2.rs b/src/ode_solver/test_models/dydt_y2.rs index 6149ce29..8b5c5732 100644 --- a/src/ode_solver/test_models/dydt_y2.rs +++ b/src/ode_solver/test_models/dydt_y2.rs @@ -29,12 +29,12 @@ pub fn dydt_y2_problem( let size2 = size; let y0 = -200.; let tlast = 20.0; - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .use_coloring(use_coloring) .rtol(1e-4) - .build_ode(rhs::, rhs_jac::, move |_p, _t| { - M::V::from_vec([y0.into()].repeat(size2)) - }) + .rhs_implicit(rhs::, rhs_jac::) + .init(move |_p, _t| M::V::from_vec([y0.into()].repeat(size2))) + .build() .unwrap(); let mut soln = OdeSolverSolution::default(); let y0 = M::V::from_vec([y0.into()].repeat(size)); diff --git a/src/ode_solver/test_models/exponential_decay.rs b/src/ode_solver/test_models/exponential_decay.rs index aaccf7f2..f5b3a707 100644 --- a/src/ode_solver/test_models/exponential_decay.rs +++ b/src/ode_solver/test_models/exponential_decay.rs @@ -1,12 +1,11 @@ use crate::{ - matrix::Matrix, ode_solver::problem::OdeSolverSolution, - op::closure_with_adjoint::ClosureWithAdjoint, scalar::scale, ConstantClosureWithAdjoint, - ConstantOp, OdeBuilder, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit, - OdeEquationsSens, OdeSolverEquations, OdeSolverProblem, UnitCallable, Vector, + matrix::Matrix, ode_solver::problem::OdeSolverSolution, scalar::scale, ConstantOp, OdeBuilder, + OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsSens, OdeSolverProblem, + Vector, }; use nalgebra::ComplexField; use num_traits::{One, Zero}; -use std::{ops::MulAssign, rc::Rc}; +use std::ops::MulAssign; // exponential decay problem // dy/dt = -ay (p = [a, y0]) @@ -141,15 +140,13 @@ pub fn negative_exponential_decay_problem( let h = -1.0; let k = 0.1; let y0 = (-10.0 * k).exp(); - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .h0(h) .p([k, y0]) .use_coloring(use_coloring) - .build_ode( - exponential_decay::, - exponential_decay_jacobian::, - exponential_decay_init::, - ) + .rhs_implicit(exponential_decay::, exponential_decay_jacobian::) + .init(exponential_decay_init::) + .build() .unwrap(); let p = [M::T::from(k), M::T::from(y0)]; let mut soln = OdeSolverSolution { @@ -175,15 +172,13 @@ pub fn exponential_decay_problem( let h = 1.0; let k = 0.1; let y0 = 1.0; - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .h0(h) .p([k, y0]) .use_coloring(use_coloring) - .build_ode( - exponential_decay::, - exponential_decay_jacobian::, - exponential_decay_init::, - ) + .rhs_implicit(exponential_decay::, exponential_decay_jacobian::) + .init(exponential_decay_init::) + .build() .unwrap(); let p = [M::T::from(k), M::T::from(y0)]; let mut soln = OdeSolverSolution::default(); @@ -205,16 +200,13 @@ pub fn exponential_decay_problem_with_root( ) { let k = 0.1; let y0 = 1.0; - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .p([k, y0]) .use_coloring(use_coloring) - .build_ode_with_root( - exponential_decay::, - exponential_decay_jacobian::, - exponential_decay_init::, - exponential_decay_root::, - 1, - ) + .rhs_implicit(exponential_decay::, exponential_decay_jacobian::) + .init(exponential_decay_init::) + .root(exponential_decay_root::, 1) + .build() .unwrap(); let p = [M::T::from(k), M::T::from(y0)]; let mut soln = OdeSolverSolution::default(); @@ -232,85 +224,38 @@ pub fn exponential_decay_problem_adjoint() -> ( OdeSolverProblem>, OdeSolverSolution, ) { - let k = M::T::from(0.1); - let y0 = M::T::from(1.0); - let t0 = M::T::from(0.0); - let h0 = M::T::from(1.0); - let p = Rc::new(M::V::from_vec(vec![k, y0])); - let init = exponential_decay_init::; - let y0 = init(&p, t0); - let nstates = y0.len(); - let rhs = exponential_decay::; - let rhs_jac = exponential_decay_jacobian::; - let rhs_adj_jac = exponential_decay_jacobian_adjoint::; - let rhs_sens_adj = exponential_decay_sens_transpose::; - let mut rhs = ClosureWithAdjoint::new( - rhs, - rhs_jac, - rhs_adj_jac, - rhs_sens_adj, - nstates, - nstates, - p.clone(), - ); - let nout = 2; - let out = exponential_decay_out::; - let out_jac = exponential_decay_out_jac_mul::; - let out_jac_adj = exponential_decay_out_adj_mul::; - let out_sens_adj = exponential_decay_out_sens_adj::; - let out = ClosureWithAdjoint::new( - out, - out_jac, - out_jac_adj, - out_sens_adj, - nstates, - nout, - p.clone(), - ); - let init = ConstantClosureWithAdjoint::new( - exponential_decay_init::, - exponential_decay_init_sens_adjoint::, - p.clone(), - ); - if M::is_sparse() { - rhs.calculate_jacobian_sparsity(&y0, t0); - rhs.calculate_adjoint_sparsity(&y0, t0); - } - let out = Some(out); - let mass: Option> = None; - let root: Option> = None; - let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); - let rtol = M::T::from(1e-6); - let atol = Rc::new(M::V::from_element(nstates, M::T::from(1e-6))); - let out_rtol = Some(M::T::from(1e-6)); - let out_atol = Some(Rc::new(M::V::from_element(nout, M::T::from(1e-6)))); - let param_rtol = Some(M::T::from(1e-6)); - let param_atol = Some(Rc::new(M::V::from_element(p.len(), M::T::from(1e-6)))); - let sens_rtol = Some(M::T::from(1e-6)); - let sens_atol = Some(Rc::new(M::V::from_element(nstates, M::T::from(1e-6)))); - let integrate_out = true; - let problem = OdeSolverProblem::new( - Rc::new(eqn), - rtol, - atol, - sens_rtol, - sens_atol, - out_rtol, - out_atol, - param_rtol, - param_atol, - t0, - h0, - integrate_out, - ) - .unwrap(); + let k = 0.1; + let y0 = 1.0; + let problem = OdeBuilder::::new() + .p([k, y0]) + .integrate_out(true) + .rhs_adjoint_implicit( + exponential_decay::, + exponential_decay_jacobian::, + exponential_decay_jacobian_adjoint::, + exponential_decay_sens_transpose::, + ) + .init_adjoint( + exponential_decay_init::, + exponential_decay_init_sens_adjoint::, + ) + .out_adjoint_implicit( + exponential_decay_out::, + exponential_decay_out_jac_mul::, + exponential_decay_out_adj_mul::, + exponential_decay_out_sens_adj::, + 2, + ) + .build() + .unwrap(); let mut soln = OdeSolverSolution { - atol: problem.atol.as_ref().clone(), + atol: problem.atol.clone(), rtol: problem.rtol, ..Default::default() }; let t0 = M::T::from(0.0); let t1 = M::T::from(9.0); + let p = [M::T::from(k), M::T::from(y0)]; for i in 0..10 { let t = M::T::from(i as f64); let y0: M::V = problem.eqn.init().call(M::T::zero()); @@ -347,18 +292,21 @@ pub fn exponential_decay_problem_sens( ) { let k = 0.1; let y0 = 1.0; - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .p([k, y0]) - .sens_rtol(Some(1e-6)) - .sens_atol(Some([1e-6, 1e-6])) + .sens_rtol(1e-6) + .sens_atol([1e-6, 1e-6]) .use_coloring(use_coloring) - .build_ode_with_sens( + .rhs_sens_implicit( exponential_decay::, exponential_decay_jacobian::, exponential_decay_sens::, + ) + .init_sens( exponential_decay_init::, exponential_decay_init_sens::, ) + .build() .unwrap(); let p = [M::T::from(k), M::T::from(y0)]; let mut soln = OdeSolverSolution::default(); diff --git a/src/ode_solver/test_models/exponential_decay_with_algebraic.rs b/src/ode_solver/test_models/exponential_decay_with_algebraic.rs index ae5053d1..286d614b 100644 --- a/src/ode_solver/test_models/exponential_decay_with_algebraic.rs +++ b/src/ode_solver/test_models/exponential_decay_with_algebraic.rs @@ -1,18 +1,10 @@ use crate::{ - matrix::Matrix, - ode_solver::problem::OdeSolverSolution, - op::{ - closure_with_sens::ClosureWithSens, constant_closure_with_sens::ConstantClosureWithSens, - linear_closure_with_adjoint::LinearClosureWithAdjoint, - }, - scalar::scale, - ClosureWithAdjoint, ConstantClosureWithAdjoint, ConstantOp, LinearClosure, OdeBuilder, - OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsSens, OdeSolverEquations, - OdeSolverProblem, UnitCallable, Vector, + matrix::Matrix, ode_solver::problem::OdeSolverSolution, scalar::scale, OdeBuilder, + OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsSens, OdeSolverProblem, Vector, }; use nalgebra::ComplexField; use num_traits::{One, Zero}; -use std::{ops::MulAssign, rc::Rc}; +use std::ops::MulAssign; // exponential decay problem with algebraic constraint // dy/dt = -ay @@ -208,15 +200,16 @@ pub fn exponential_decay_with_algebraic_problem( OdeSolverSolution, ) { let p = M::V::from_vec(vec![0.1.into()]); - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .p([0.1]) .use_coloring(use_coloring) - .build_ode_with_mass( + .rhs_implicit( exponential_decay_with_algebraic::, exponential_decay_with_algebraic_jacobian::, - exponential_decay_with_algebraic_mass::, - exponential_decay_with_algebraic_init::, ) + .mass(exponential_decay_with_algebraic_mass::) + .init(exponential_decay_with_algebraic_init::) + .build() .unwrap(); let mut soln = OdeSolverSolution::default(); @@ -234,87 +227,36 @@ pub fn exponential_decay_with_algebraic_adjoint_problem() - OdeSolverProblem>, OdeSolverSolution, ) { - let a = M::T::from(0.1); - let t0 = M::T::from(0.0); - let h0 = M::T::from(1.0); - let p = Rc::new(M::V::from_vec(vec![a])); - let init = exponential_decay_with_algebraic_init::; - let y0 = init(&p, t0); - let nstates = y0.len(); - let rhs = exponential_decay_with_algebraic::; - let rhs_jac = exponential_decay_with_algebraic_jacobian::; - let rhs_adj_jac = exponential_decay_with_algebraic_adjoint::; - let rhs_sens_adj = exponential_decay_with_algebraic_sens_adjoint::; - let mut rhs = ClosureWithAdjoint::new( - rhs, - rhs_jac, - rhs_adj_jac, - rhs_sens_adj, - nstates, - nstates, - p.clone(), - ); + let a = 0.1; let nout = 1; - let out = exponential_decay_with_algebraic_out::; - let out_jac = exponential_decay_with_algebraic_out_jac_mul::; - let out_jac_adj = exponential_decay_with_algebraic_out_jac_adj_mul::; - let out_sens_adj = exponential_decay_with_algebraic_out_sens_adj::; - let out = ClosureWithAdjoint::new( - out, - out_jac, - out_jac_adj, - out_sens_adj, - nstates, - nout, - p.clone(), - ); - let init = ConstantClosureWithAdjoint::new( - exponential_decay_with_algebraic_init::, - exponential_decay_with_algebraic_init_sens_adjoint::, - p.clone(), - ); - let mut mass = LinearClosureWithAdjoint::new( - exponential_decay_with_algebraic_mass::, - exponential_decay_with_algebraic_mass_transpose::, - nstates, - nstates, - p.clone(), - ); - if M::is_sparse() { - rhs.calculate_jacobian_sparsity(&y0, t0); - rhs.calculate_adjoint_sparsity(&y0, t0); - mass.calculate_sparsity(t0); - mass.calculate_adjoint_sparsity(t0); - } - let out = Some(out); + let problem = OdeBuilder::::new() + .p([a]) + .integrate_out(true) + .rhs_adjoint_implicit( + exponential_decay_with_algebraic::, + exponential_decay_with_algebraic_jacobian::, + exponential_decay_with_algebraic_adjoint::, + exponential_decay_with_algebraic_sens_adjoint::, + ) + .init_adjoint( + exponential_decay_with_algebraic_init::, + exponential_decay_with_algebraic_init_sens_adjoint::, + ) + .mass_adjoint( + exponential_decay_with_algebraic_mass::, + exponential_decay_with_algebraic_mass_transpose::, + ) + .out_adjoint_implicit( + exponential_decay_with_algebraic_out::, + exponential_decay_with_algebraic_out_jac_mul::, + exponential_decay_with_algebraic_out_jac_adj_mul::, + exponential_decay_with_algebraic_out_sens_adj::, + nout, + ) + .build() + .unwrap(); - let root: Option> = None; - let mass = Some(mass); - let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); - let rtol = M::T::from(1e-6); - let atol = Rc::new(M::V::from_element(nstates, M::T::from(1e-6))); - let out_rtol = Some(M::T::from(1e-6)); - let out_atol = Some(Rc::new(M::V::from_element(nout, M::T::from(1e-6)))); - let param_rtol = Some(M::T::from(1e-6)); - let param_atol = Some(Rc::new(M::V::from_element(1, M::T::from(1e-6)))); - let sens_atol = Some(Rc::new(M::V::from_element(nstates, M::T::from(1e-6)))); - let sens_rtol = Some(M::T::from(1e-6)); - let integrate_out = true; - let problem = OdeSolverProblem::new( - Rc::new(eqn), - rtol, - atol, - sens_rtol, - sens_atol, - out_rtol, - out_atol, - param_rtol, - param_atol, - t0, - h0, - integrate_out, - ) - .unwrap(); + let p = M::V::from_vec(vec![a.into()]); let atol_out = M::V::from_element(nout, M::T::from(1e-6)); let mut soln = OdeSolverSolution { atol: atol_out, @@ -340,53 +282,23 @@ pub fn exponential_decay_with_algebraic_problem_sens() -> ( OdeSolverProblem>, OdeSolverSolution, ) { - let p = Rc::new(M::V::from_vec(vec![0.1.into()])); - let mut rhs = ClosureWithSens::new( - exponential_decay_with_algebraic::, - exponential_decay_with_algebraic_jacobian::, - exponential_decay_with_algebraic_sens::, - 3, - 3, - p.clone(), - ); - let mut mass = LinearClosure::new(exponential_decay_with_algebraic_mass::, 3, 3, p.clone()); - let init = ConstantClosureWithSens::new( - exponential_decay_with_algebraic_init::, - exponential_decay_with_algebraic_init_sens::, - 3, - 3, - p.clone(), - ); - let t0 = M::T::zero(); - - if M::is_sparse() { - let y0 = init.call(t0); - rhs.calculate_jacobian_sparsity(&y0, t0); - rhs.calculate_sens_sparsity(&y0, t0); - mass.calculate_sparsity(t0); - } - - let out: Option> = None; - let root: Option> = None; - let eqn = OdeSolverEquations::new(rhs, Some(mass), root, init, out, p.clone()); - let sens_rtol = Some(M::T::from(1e-6)); - let sens_atol = Some(Rc::new(M::V::from_element(3, M::T::from(1e-6)))); - let problem = OdeSolverProblem::new( - Rc::new(eqn), - M::T::from(1e-6), - Rc::new(M::V::from_element(3, M::T::from(1e-6))), - sens_rtol, - sens_atol, - None, - None, - None, - None, - t0, - M::T::from(1.0), - false, - ) - .unwrap(); + let k = 0.1; + let problem = OdeBuilder::::new() + .p([k]) + .rhs_sens_implicit( + exponential_decay_with_algebraic::, + exponential_decay_with_algebraic_jacobian::, + exponential_decay_with_algebraic_sens::, + ) + .init_sens( + exponential_decay_with_algebraic_init::, + exponential_decay_with_algebraic_init_sens::, + ) + .mass(exponential_decay_with_algebraic_mass::) + .build() + .unwrap(); + let p = M::V::from_vec(vec![k.into()]); let mut soln = OdeSolverSolution::default(); for i in 0..10 { let t = M::T::from(i as f64 / 10.0); diff --git a/src/ode_solver/test_models/foodweb.rs b/src/ode_solver/test_models/foodweb.rs index 9977dfd3..702e7056 100644 --- a/src/ode_solver/test_models/foodweb.rs +++ b/src/ode_solver/test_models/foodweb.rs @@ -1,10 +1,8 @@ -use std::rc::Rc; - use crate::{ find_jacobian_non_zeros, find_matrix_non_zeros, ode_solver::problem::OdeSolverSolution, ConstantOp, JacobianColoring, LinearOp, Matrix, MatrixSparsity, NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit, OdeEquationsRef, OdeSolverProblem, Op, - UnitCallable, Vector, + ParameterisedOp, UnitCallable, Vector, }; use num_traits::Zero; @@ -139,7 +137,7 @@ where ); let eqn: DiffSl = DiffSl::from_context(DiffSlContext::new(code.as_str()).unwrap()); - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .rtol(1e-5) .atol([1e-5]) .build_from_eqn(eqn) @@ -790,7 +788,7 @@ where type Init = FoodWebInit<'a, M, NX>; type Rhs = FoodWebRhs<'a, M, NX>; type Mass = FoodWebMass<'a, M, NX>; - type Root = &'a UnitCallable; + type Root = ParameterisedOp<'a, UnitCallable>; type Out = FoodWebOut<'a, M, NX>; } @@ -810,6 +808,12 @@ where fn out(&self) -> Option> { Some(FoodWebOut::new(self)) } + fn root(&self) -> Option<>::Root> { + None + } + fn set_params(&mut self, _p: &Self::V) { + unimplemented!() + } } #[cfg(feature = "diffsl")] @@ -1032,18 +1036,7 @@ where let context = FoodWebContext::::new(); let eqn = FoodWeb::new(context, t0); let problem = OdeSolverProblem::new( - Rc::new(eqn), - rtol, - Rc::new(atol), - None, - None, - None, - None, - None, - None, - t0, - h0, - false, + eqn, rtol, atol, None, None, None, None, None, None, t0, h0, false, ) .unwrap(); let soln = soln::(); diff --git a/src/ode_solver/test_models/gaussian_decay.rs b/src/ode_solver/test_models/gaussian_decay.rs index 42664ab3..58409ffb 100644 --- a/src/ode_solver/test_models/gaussian_decay.rs +++ b/src/ode_solver/test_models/gaussian_decay.rs @@ -31,14 +31,12 @@ pub fn gaussian_decay_problem( OdeSolverSolution, ) { let size2 = size; - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .p([0.1].repeat(size)) .use_coloring(use_coloring) - .build_ode( - gaussian_decay::, - gaussian_decay_jacobian::, - move |_p, _t| M::V::from_vec([1.0.into()].repeat(size2)), - ) + .rhs_implicit(gaussian_decay::, gaussian_decay_jacobian::) + .init(move |_p, _t| M::V::from_vec([1.0.into()].repeat(size2))) + .build() .unwrap(); let p = [M::T::from(0.1)].repeat(size); let mut soln = OdeSolverSolution::default(); diff --git a/src/ode_solver/test_models/heat2d.rs b/src/ode_solver/test_models/heat2d.rs index 17db6758..f4ab039c 100644 --- a/src/ode_solver/test_models/heat2d.rs +++ b/src/ode_solver/test_models/heat2d.rs @@ -93,7 +93,7 @@ pub fn heat2d_diffsl_problem< let context: DiffSlContext = DiffSlContext::new(code.as_str()).unwrap(); let eqn = DiffSl::from_context(context); - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .rtol(1e-7) .atol([1e-7]) .build_from_eqn(eqn) @@ -250,18 +250,14 @@ pub fn head2d_problem() -> ( OdeSolverProblem>, OdeSolverSolution, ) { - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .rtol(1e-7) .atol([1e-7]) - .build_ode_with_mass_and_out( - heat2d_rhs::, - heat2d_jac_mul::, - heat2d_mass::, - heat2d_init::, - heat2d_out::, - heat2d_out_jac_mul::, - 1, - ) + .rhs_implicit(heat2d_rhs::, heat2d_jac_mul::) + .mass(heat2d_mass::) + .init(heat2d_init::) + .out_implicit(heat2d_out::, heat2d_out_jac_mul::, 1) + .build() .unwrap(); (problem, soln::()) diff --git a/src/ode_solver/test_models/robertson.rs b/src/ode_solver/test_models/robertson.rs index 51799ae2..cc0219d1 100644 --- a/src/ode_solver/test_models/robertson.rs +++ b/src/ode_solver/test_models/robertson.rs @@ -1,11 +1,6 @@ -use std::rc::Rc; - use crate::{ - matrix::Matrix, - ode_solver::problem::OdeSolverSolution, - op::{closure_with_sens::ClosureWithSens, constant_closure_with_sens::ConstantClosureWithSens}, - ConstantOp, LinearClosure, OdeBuilder, OdeEquationsImplicit, OdeEquationsSens, - OdeSolverEquations, OdeSolverProblem, UnitCallable, Vector, + matrix::Matrix, ode_solver::problem::OdeSolverSolution, OdeBuilder, OdeEquationsImplicit, + OdeEquationsSens, OdeSolverProblem, Vector, }; use num_traits::Zero; @@ -53,7 +48,7 @@ pub fn robertson_diffsl_problem< let context = DiffSlContext::::new(code).unwrap(); let eqn = DiffSl::from_context(context); - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) @@ -61,7 +56,7 @@ pub fn robertson_diffsl_problem< .unwrap(); let mut soln = soln::(); soln.rtol = problem.rtol; - soln.atol = problem.atol.as_ref().clone(); + soln.atol = problem.atol.clone(); (problem, soln) } @@ -103,23 +98,21 @@ fn robertson_init_sens(_p: &M::V, _t: M::T, _v: &M::V, y: &mut M::V) } #[allow(clippy::type_complexity)] -pub fn robertson( +pub fn robertson( use_coloring: bool, ) -> ( OdeSolverProblem>, OdeSolverSolution, ) { - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) .use_coloring(use_coloring) - .build_ode_with_mass( - robertson_rhs::, - robertson_jac_mul::, - robertson_mass::, - robertson_init::, - ) + .rhs_implicit(robertson_rhs::, robertson_jac_mul::) + .mass(robertson_mass::) + .init(robertson_init::) + .build() .unwrap(); (problem, soln()) @@ -157,56 +150,20 @@ pub fn robertson_sens() -> ( OdeSolverProblem>, OdeSolverSolution, ) { - let p = Rc::new(M::V::from_vec(vec![ - M::T::from(0.04), - M::T::from(1.0e4), - M::T::from(3.0e7), - ])); - let mut rhs = ClosureWithSens::new( - robertson_rhs::, - robertson_jac_mul::, - robertson_sens_mul::, - 3, - 3, - p.clone(), - ); - let mut mass = LinearClosure::new(robertson_mass::, 3, 3, p.clone()); - let init = ConstantClosureWithSens::new( - robertson_init::, - robertson_init_sens::, - 3, - 3, - p.clone(), - ); - let t0 = M::T::zero(); - - if M::is_sparse() { - let y0 = init.call(t0); - rhs.calculate_jacobian_sparsity(&y0, t0); - rhs.calculate_sens_sparsity(&y0, t0); - mass.calculate_sparsity(t0); - } - - let out: Option> = None; - let root: Option> = None; - let eqn = OdeSolverEquations::new(rhs, Some(mass), root, init, out, p.clone()); - let rtol = M::T::from(1e-4); - let atol = M::V::from_vec(vec![M::T::from(1e-8), M::T::from(1e-6), M::T::from(1e-6)]); - let problem = OdeSolverProblem::new( - Rc::new(eqn), - rtol, - Rc::new(atol), - None, - None, - None, - None, - None, - None, - t0, - M::T::from(1.0), - false, - ) - .unwrap(); + let problem = OdeBuilder::::new() + .atol([1e-8, 1e-6, 1e-6]) + .rtol(1e-4) + .p([0.04, 1.0e4, 3.0e7]) + .turn_off_sensitivities_error_control() + .rhs_sens_implicit( + robertson_rhs::, + robertson_jac_mul::, + robertson_sens_mul::, + ) + .init_sens(robertson_init::, robertson_init_sens::) + .mass(robertson_mass::) + .build() + .unwrap(); let mut soln = OdeSolverSolution::default(); let data = vec![ diff --git a/src/ode_solver/test_models/robertson_ode.rs b/src/ode_solver/test_models/robertson_ode.rs index b653d215..9e02f817 100644 --- a/src/ode_solver/test_models/robertson_ode.rs +++ b/src/ode_solver/test_models/robertson_ode.rs @@ -13,7 +13,7 @@ pub fn robertson_ode( OdeSolverSolution, ) { const N: usize = 3; - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol( @@ -25,7 +25,7 @@ pub fn robertson_ode( .collect::>(), ) .use_coloring(use_coloring) - .build_ode( + .rhs_implicit( // dy1/dt = -.04*y1 + 1.e4*y2*y3 //* dy2/dt = .04*y1 - 1.e4*y2*y3 - 3.e7*(y2)^2 //* dy3/dt = 3.e7*(y2)^2 @@ -49,11 +49,12 @@ pub fn robertson_ode( y[i + 2] = M::T::from(2.0) * p[2] * x[i + 1] * v[i + 1]; } }, - move |_p: &M::V, _t: M::T| { - let init = [M::T::one(), M::T::zero(), M::T::zero()]; - M::V::from_vec(init.iter().cycle().take(ngroups * N).cloned().collect()) - }, ) + .init(move |_p: &M::V, _t: M::T| { + let init = [M::T::one(), M::T::zero(), M::T::zero()]; + M::V::from_vec(init.iter().cycle().take(ngroups * N).cloned().collect()) + }) + .build() .unwrap(); let mut soln = OdeSolverSolution::default(); diff --git a/src/ode_solver/test_models/robertson_ode_with_sens.rs b/src/ode_solver/test_models/robertson_ode_with_sens.rs index 6b0b58e3..8c0c973f 100644 --- a/src/ode_solver/test_models/robertson_ode_with_sens.rs +++ b/src/ode_solver/test_models/robertson_ode_with_sens.rs @@ -11,12 +11,12 @@ pub fn robertson_ode_with_sens( OdeSolverProblem>, OdeSolverSolution, ) { - let problem = OdeBuilder::new() + let problem = OdeBuilder::::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) .use_coloring(use_coloring) - .build_ode_with_sens( + .rhs_sens_implicit( // dy1/dt = -.04*y1 + 1.e4*y2*y3 //* dy2/dt = .04*y1 - 1.e4*y2*y3 - 3.e7*(y2)^2 //* dy3/dt = 3.e7*(y2)^2 @@ -38,9 +38,12 @@ pub fn robertson_ode_with_sens( y[1] = v[0] * x[0] - v[1] * x[1] * x[2] - v[2] * x[1] * x[1]; y[2] = v[2] * x[1] * x[1]; }, + ) + .init_sens( |_p: &M::V, _t: M::T| M::V::from_vec(vec![1.0.into(), 0.0.into(), 0.0.into()]), |_p: &M::V, _t: M::T, _v: &M::V, y: &mut M::V| y.fill(M::T::zero()), ) + .build() .unwrap(); let mut soln = OdeSolverSolution::default(); diff --git a/src/op/bdf.rs b/src/op/bdf.rs index 17cf143a..adad7780 100644 --- a/src/op/bdf.rs +++ b/src/op/bdf.rs @@ -1,19 +1,17 @@ use crate::{ matrix::DenseMatrix, ode_solver::equations::OdeEquationsImplicit, scale, LinearOp, Matrix, - MatrixRef, MatrixSparsity, NonLinearOp, NonLinearOpJacobian, OdeSolverProblem, Op, Vector, - VectorRef, + MatrixSparsity, NonLinearOp, NonLinearOpJacobian, Op, Vector, }; use num_traits::{One, Zero}; use std::ops::MulAssign; use std::{ cell::{Ref, RefCell}, ops::{AddAssign, Deref, SubAssign}, - rc::Rc, }; // callable to solve for F(y) = M (y' + psi) - c * f(y) = 0 pub struct BdfCallable { - eqn: Rc, + pub(crate) eqn: Eqn, psi_neg_y0: RefCell, c: RefCell, tmp: RefCell, @@ -25,6 +23,19 @@ pub struct BdfCallable { } impl BdfCallable { + pub fn clone_state(&self, eqn: Eqn) -> Self { + Self { + eqn, + psi_neg_y0: RefCell::new(self.psi_neg_y0.borrow().clone()), + c: RefCell::new(*self.c.borrow()), + tmp: RefCell::new(self.tmp.borrow().clone()), + rhs_jac: RefCell::new(self.rhs_jac.borrow().clone()), + mass_jac: RefCell::new(self.mass_jac.borrow().clone()), + jacobian_is_stale: RefCell::new(*self.jacobian_is_stale.borrow()), + number_of_jac_evals: RefCell::new(*self.number_of_jac_evals.borrow()), + sparsity: self.sparsity.clone(), + } + } // F(y) = M (y - y0 + psi) - c * f(y) = 0 // M = I // dg = f(y) @@ -43,8 +54,7 @@ impl BdfCallable { let c = self.c.borrow(); d.axpy(*c, dg, -Eqn::T::one()); } - pub fn from_sensitivity_eqn(eqn: &Rc) -> Self { - let eqn = eqn.clone(); + pub fn new_no_jacobian(eqn: Eqn) -> Self { let n = eqn.rhs().nstates(); let c = RefCell::new(Eqn::T::zero()); let psi_neg_y0 = RefCell::new(::zeros(n)); @@ -66,15 +76,14 @@ impl BdfCallable { sparsity, } } - pub fn eqn_mut(&mut self) -> &mut Rc { - &mut self.eqn - } - pub fn eqn(&self) -> &Rc { + pub fn eqn(&self) -> &Eqn { &self.eqn } - pub fn new(ode_problem: &OdeSolverProblem) -> Self { - let eqn = ode_problem.eqn.clone(); - let n = ode_problem.eqn.rhs().nstates(); + pub fn eqn_mut(&mut self) -> &mut Eqn { + &mut self.eqn + } + pub fn new(eqn: Eqn) -> Self { + let n = eqn.rhs().nstates(); let c = RefCell::new(Eqn::T::zero()); let psi_neg_y0 = RefCell::new(::zeros(n)); let jacobian_is_stale = RefCell::new(true); @@ -146,10 +155,7 @@ impl BdfCallable { pub fn number_of_jac_evals(&self) -> usize { *self.number_of_jac_evals.borrow() } - pub fn set_c(&self, h: Eqn::T, alpha: Eqn::T) - where - for<'b> &'b Eqn::M: MatrixRef, - { + pub fn set_c(&self, h: Eqn::T, alpha: Eqn::T) { self.c.replace(h * alpha); } fn set_psi>( @@ -204,11 +210,7 @@ impl Op for BdfCallable { // dF(y)/dp = dM/dp (y - y0 + psi) + Ms - c * df(y)/dp - c df(y)/dy s = 0 // jac is M - c * df(y)/dy, same // callable to solve for F(y) = M (y' + psi) - f(y) = 0 -impl NonLinearOp for BdfCallable -where - for<'b> &'b Eqn::V: VectorRef, - for<'b> &'b Eqn::M: MatrixRef, -{ +impl NonLinearOp for BdfCallable { // F(y) = M (y - y0 + psi) - c * f(y) = 0 fn call_inplace(&self, x: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) { let psi_neg_y0_ref = self.psi_neg_y0.borrow(); @@ -229,11 +231,7 @@ where } } -impl NonLinearOpJacobian for BdfCallable -where - for<'b> &'b Eqn::V: VectorRef, - for<'b> &'b Eqn::M: MatrixRef, -{ +impl NonLinearOpJacobian for BdfCallable { // (M - c * f'(y)) v fn jac_mul_inplace(&self, x: &Eqn::V, t: Eqn::T, v: &Eqn::V, y: &mut Eqn::V) { self.eqn.rhs().jac_mul_inplace(x, t, v, y); @@ -290,7 +288,7 @@ mod tests { #[test] fn test_bdf_callable() { let (problem, _soln) = exponential_decay_problem::(false); - let mut bdf_callable = BdfCallable::new(&problem); + let mut bdf_callable = BdfCallable::new(&problem.eqn); let c = 0.1; let phi_neg_y0 = Vcpu::from_vec(vec![1.1, 1.2]); bdf_callable.set_c_direct(c); diff --git a/src/op/closure.rs b/src/op/closure.rs index 05545852..5ffdbda8 100644 --- a/src/op/closure.rs +++ b/src/op/closure.rs @@ -1,11 +1,11 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use crate::{ find_jacobian_non_zeros, jacobian::JacobianColoring, Matrix, MatrixSparsity, NonLinearOp, - NonLinearOpJacobian, Op, Vector, + NonLinearOpJacobian, Op, }; -use super::OpStatistics; +use super::{BuilderOp, OpStatistics, ParameterisedOp}; pub struct Closure where @@ -18,7 +18,6 @@ where nstates: usize, nout: usize, nparams: usize, - p: Rc, coloring: Option>, sparsity: Option, statistics: RefCell, @@ -30,23 +29,21 @@ where F: Fn(&M::V, &M::V, M::T, &mut M::V), G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { - pub fn new(func: F, jacobian_action: G, nstates: usize, nout: usize, p: Rc) -> Self { - let nparams = p.len(); + pub fn new(func: F, jacobian_action: G, nstates: usize, nout: usize, nparams: usize) -> Self { Self { func, jacobian_action, nstates, - nout, nparams, - p, + nout, statistics: RefCell::new(OpStatistics::default()), coloring: None, sparsity: None, } } - - pub fn calculate_sparsity(&mut self, y0: &M::V, t0: M::T) { - let non_zeros = find_jacobian_non_zeros(self, y0, t0); + pub fn calculate_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) { + let param_op = ParameterisedOp { op: self, p }; + let non_zeros = find_jacobian_non_zeros(¶m_op, y0, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -58,6 +55,27 @@ where } } +impl BuilderOp for Closure +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + fn calculate_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) { + self.calculate_sparsity(y0, t0, p); + } + + fn set_nstates(&mut self, nstates: usize) { + self.nstates = nstates; + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; + } +} + impl Op for Closure where M: Matrix, @@ -76,47 +94,42 @@ where fn nparams(&self) -> usize { self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; - } - fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } } -impl NonLinearOp for Closure +impl NonLinearOp for ParameterisedOp<'_, Closure> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) { - self.statistics.borrow_mut().increment_call(); - (self.func)(x, self.p.as_ref(), t, y) + self.op.statistics.borrow_mut().increment_call(); + (self.op.func)(x, self.p, t, y) } } -impl NonLinearOpJacobian for Closure +impl NonLinearOpJacobian for ParameterisedOp<'_, Closure> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) { - self.statistics.borrow_mut().increment_jac_mul(); - (self.jacobian_action)(x, self.p.as_ref(), t, v, y) + self.op.statistics.borrow_mut().increment_jac_mul(); + (self.op.jacobian_action)(x, self.p, t, v, y) } fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - self.statistics.borrow_mut().increment_matrix(); - if let Some(coloring) = self.coloring.as_ref() { + self.op.statistics.borrow_mut().increment_matrix(); + if let Some(coloring) = self.op.coloring.as_ref() { coloring.jacobian_inplace(self, x, t, y); } else { self._default_jacobian_inplace(x, t, y); } } fn jacobian_sparsity(&self) -> Option<::Sparsity> { - self.sparsity.clone() + self.op.sparsity.clone() } } diff --git a/src/op/closure_no_jac.rs b/src/op/closure_no_jac.rs index 08cb94b6..aafdbc82 100644 --- a/src/op/closure_no_jac.rs +++ b/src/op/closure_no_jac.rs @@ -1,8 +1,8 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; -use crate::{Matrix, NonLinearOp, Op, Vector}; +use crate::{Matrix, NonLinearOp, Op}; -use super::OpStatistics; +use super::{BuilderOp, OpStatistics, ParameterisedOp}; pub struct ClosureNoJac where @@ -13,8 +13,8 @@ where nstates: usize, nout: usize, nparams: usize, - p: Rc, statistics: RefCell, + _phantom: std::marker::PhantomData, } impl ClosureNoJac @@ -22,19 +22,37 @@ where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), { - pub fn new(func: F, nstates: usize, nout: usize, p: Rc) -> Self { - let nparams = p.len(); + pub fn new(func: F, nstates: usize, nout: usize, nparams: usize) -> Self { Self { func, nstates, - nout, nparams, - p, + nout, statistics: RefCell::new(OpStatistics::default()), + _phantom: std::marker::PhantomData, } } } +impl BuilderOp for ClosureNoJac +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), +{ + fn calculate_sparsity(&mut self, _y0: &Self::V, _t0: Self::T, _p: &Self::V) { + // Do nothing + } + fn set_nstates(&mut self, nstates: usize) { + self.nstates = nstates; + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; + } +} + impl Op for ClosureNoJac where M: Matrix, @@ -52,22 +70,18 @@ where fn nparams(&self) -> usize { self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; - } fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } } -impl NonLinearOp for ClosureNoJac +impl NonLinearOp for ParameterisedOp<'_, ClosureNoJac> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), { fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) { - self.statistics.borrow_mut().increment_call(); - (self.func)(x, self.p.as_ref(), t, y) + self.op.statistics.borrow_mut().increment_call(); + (self.op.func)(x, self.p, t, y) } } diff --git a/src/op/closure_with_adjoint.rs b/src/op/closure_with_adjoint.rs index cb5c9bd2..2771e8a3 100644 --- a/src/op/closure_with_adjoint.rs +++ b/src/op/closure_with_adjoint.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use crate::{ jacobian::{ @@ -9,8 +9,9 @@ use crate::{ NonLinearOpSensAdjoint, Op, Vector, }; -use super::OpStatistics; +use super::{BuilderOp, OpStatistics, ParameterisedOp}; +#[derive(Clone)] pub struct ClosureWithAdjoint where M: Matrix, @@ -26,7 +27,6 @@ where nstates: usize, nout: usize, nparams: usize, - p: Rc, coloring: Option>, sparsity: Option, sparsity_adjoint: Option, @@ -51,9 +51,8 @@ where sens_adjoint_action: I, nstates: usize, nout: usize, - p: Rc, + nparams: usize, ) -> Self { - let nparams = p.len(); Self { func, jacobian_action, @@ -62,7 +61,6 @@ where nstates, nout, nparams, - p, statistics: RefCell::new(OpStatistics::default()), coloring: None, sparsity: None, @@ -73,8 +71,9 @@ where } } - pub fn calculate_jacobian_sparsity(&mut self, y0: &M::V, t0: M::T) { - let non_zeros = find_jacobian_non_zeros(self, y0, t0); + pub fn calculate_jacobian_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) { + let op = ParameterisedOp { op: self, p }; + let non_zeros = find_jacobian_non_zeros(&op, y0, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -85,8 +84,9 @@ where )); } - pub fn calculate_adjoint_sparsity(&mut self, y0: &M::V, t0: M::T) { - let non_zeros = find_adjoint_non_zeros(self, y0, t0); + pub fn calculate_adjoint_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) { + let op = ParameterisedOp { op: self, p }; + let non_zeros = find_adjoint_non_zeros(&op, y0, t0); self.sparsity_adjoint = Some( MatrixSparsity::try_from_indices(self.nstates, self.nout, non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -97,10 +97,12 @@ where )); } - pub fn calculate_sens_adjoint_sparsity(&mut self, y0: &M::V, t0: M::T) { - let non_zeros = find_sens_adjoint_non_zeros(self, y0, t0); + pub fn calculate_sens_adjoint_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) { + let op = ParameterisedOp { op: self, p }; + let non_zeros = find_sens_adjoint_non_zeros(&op, y0, t0); + let nparams = p.len(); self.sens_sparsity = Some( - MatrixSparsity::try_from_indices(self.nstates, self.nparams, non_zeros.clone()) + MatrixSparsity::try_from_indices(self.nstates, nparams, non_zeros.clone()) .expect("invalid sparsity pattern"), ); self.coloring_sens_adjoint = Some(JacobianColoring::new( @@ -130,17 +132,36 @@ where fn nparams(&self) -> usize { self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; - } - fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } } -impl NonLinearOp for ClosureWithAdjoint +impl BuilderOp for ClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + fn calculate_sparsity(&mut self, y0: &Self::V, t0: Self::T, p: &Self::V) { + self.calculate_jacobian_sparsity(y0, t0, p); + self.calculate_adjoint_sparsity(y0, t0, p); + self.calculate_sens_adjoint_sparsity(y0, t0, p); + } + fn set_nstates(&mut self, nstates: usize) { + self.nstates = nstates; + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; + } +} + +impl NonLinearOp for ParameterisedOp<'_, ClosureWithAdjoint> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), @@ -149,12 +170,12 @@ where I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) { - self.statistics.borrow_mut().increment_call(); - (self.func)(x, self.p.as_ref(), t, y) + self.op.statistics.borrow_mut().increment_call(); + (self.op.func)(x, self.p, t, y) } } -impl NonLinearOpJacobian for ClosureWithAdjoint +impl NonLinearOpJacobian for ParameterisedOp<'_, ClosureWithAdjoint> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), @@ -163,23 +184,23 @@ where I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) { - self.statistics.borrow_mut().increment_jac_mul(); - (self.jacobian_action)(x, self.p.as_ref(), t, v, y) + self.op.statistics.borrow_mut().increment_jac_mul(); + (self.op.jacobian_action)(x, self.p, t, v, y) } fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - self.statistics.borrow_mut().increment_matrix(); - if let Some(coloring) = self.coloring.as_ref() { + self.op.statistics.borrow_mut().increment_matrix(); + if let Some(coloring) = self.op.coloring.as_ref() { coloring.jacobian_inplace(self, x, t, y); } else { self._default_jacobian_inplace(x, t, y); } } fn jacobian_sparsity(&self) -> Option<::Sparsity> { - self.sparsity.clone() + self.op.sparsity.clone() } } -impl NonLinearOpAdjoint for ClosureWithAdjoint +impl NonLinearOpAdjoint for ParameterisedOp<'_, ClosureWithAdjoint> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), @@ -188,23 +209,24 @@ where I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { - self.statistics.borrow_mut().increment_jac_adj_mul(); - (self.jacobian_adjoint_action)(x, self.p.as_ref(), t, v, y); + self.op.statistics.borrow_mut().increment_jac_adj_mul(); + (self.op.jacobian_adjoint_action)(x, self.p, t, v, y); } fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - if let Some(coloring) = self.coloring_adjoint.as_ref() { + if let Some(coloring) = self.op.coloring_adjoint.as_ref() { coloring.adjoint_inplace(self, x, t, y); } else { self._default_adjoint_inplace(x, t, y); } } fn adjoint_sparsity(&self) -> Option<::Sparsity> { - self.sparsity_adjoint.clone() + self.op.sparsity_adjoint.clone() } } -impl NonLinearOpSensAdjoint for ClosureWithAdjoint +impl NonLinearOpSensAdjoint + for ParameterisedOp<'_, ClosureWithAdjoint> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), @@ -213,16 +235,16 @@ where I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn sens_transpose_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) { - (self.sens_adjoint_action)(_x, self.p.as_ref(), _t, _v, y); + (self.op.sens_adjoint_action)(_x, self.p, _t, _v, y); } fn sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - if let Some(coloring) = self.coloring_sens_adjoint.as_ref() { + if let Some(coloring) = self.op.coloring_sens_adjoint.as_ref() { coloring.sens_adjoint_inplace(self, x, t, y); } else { self._default_sens_adjoint_inplace(x, t, y); } } fn sens_adjoint_sparsity(&self) -> Option<::Sparsity> { - self.sens_sparsity.clone() + self.op.sens_sparsity.clone() } } diff --git a/src/op/closure_with_sens.rs b/src/op/closure_with_sens.rs index bbf2020d..75c96857 100644 --- a/src/op/closure_with_sens.rs +++ b/src/op/closure_with_sens.rs @@ -1,26 +1,22 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use crate::{ jacobian::{find_jacobian_non_zeros, find_sens_non_zeros, JacobianColoring}, Matrix, MatrixSparsity, NonLinearOp, NonLinearOpJacobian, NonLinearOpSens, Op, Vector, }; -use super::OpStatistics; +use super::{BuilderOp, OpStatistics, ParameterisedOp}; pub struct ClosureWithSens where M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V), - G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { func: F, jacobian_action: G, sens_action: H, nstates: usize, - nout: usize, nparams: usize, - p: Rc, + nout: usize, coloring: Option>, sens_coloring: Option>, sparsity: Option, @@ -40,10 +36,9 @@ where jacobian_action: G, sens_action: H, nstates: usize, + nparams: usize, nout: usize, - p: Rc, ) -> Self { - let nparams = p.len(); Self { func, jacobian_action, @@ -51,7 +46,6 @@ where nstates, nout, nparams, - p, statistics: RefCell::new(OpStatistics::default()), coloring: None, sparsity: None, @@ -60,8 +54,9 @@ where } } - pub fn calculate_jacobian_sparsity(&mut self, y0: &M::V, t0: M::T) { - let non_zeros = find_jacobian_non_zeros(self, y0, t0); + pub fn calculate_jacobian_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) { + let op = ParameterisedOp { op: self, p }; + let non_zeros = find_jacobian_non_zeros(&op, y0, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -71,10 +66,12 @@ where &non_zeros, )); } - pub fn calculate_sens_sparsity(&mut self, y0: &M::V, t0: M::T) { - let non_zeros = find_sens_non_zeros(self, y0, t0); + pub fn calculate_sens_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) { + let op = ParameterisedOp { op: self, p }; + let non_zeros = find_sens_non_zeros(&op, y0, t0); + let nparams = p.len(); self.sens_sparsity = Some( - MatrixSparsity::try_from_indices(self.nout(), self.nparams, non_zeros.clone()) + MatrixSparsity::try_from_indices(self.nout(), nparams, non_zeros.clone()) .expect("invalid sparsity pattern"), ); self.sens_coloring = Some(JacobianColoring::new( @@ -84,12 +81,32 @@ where } } -impl Op for ClosureWithSens +impl BuilderOp for ClosureWithSens where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + fn set_nstates(&mut self, nstates: usize) { + self.nstates = nstates; + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; + } + + fn calculate_sparsity(&mut self, y0: &Self::V, t0: Self::T, p: &Self::V) { + self.calculate_jacobian_sparsity(y0, t0, p); + self.calculate_sens_sparsity(y0, t0, p); + } +} + +impl Op for ClosureWithSens +where + M: Matrix, { type V = M::V; type T = M::T; @@ -103,17 +120,12 @@ where fn nparams(&self) -> usize { self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; - } - fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } } -impl NonLinearOp for ClosureWithSens +impl NonLinearOp for ParameterisedOp<'_, ClosureWithSens> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), @@ -121,12 +133,12 @@ where H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) { - self.statistics.borrow_mut().increment_call(); - (self.func)(x, self.p.as_ref(), t, y) + self.op.statistics.borrow_mut().increment_call(); + (self.op.func)(x, self.p, t, y) } } -impl NonLinearOpJacobian for ClosureWithSens +impl NonLinearOpJacobian for ParameterisedOp<'_, ClosureWithSens> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), @@ -134,23 +146,23 @@ where H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) { - self.statistics.borrow_mut().increment_jac_mul(); - (self.jacobian_action)(x, self.p.as_ref(), t, v, y) + self.op.statistics.borrow_mut().increment_jac_mul(); + (self.op.jacobian_action)(x, self.p, t, v, y) } fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - self.statistics.borrow_mut().increment_matrix(); - if let Some(coloring) = self.coloring.as_ref() { + self.op.statistics.borrow_mut().increment_matrix(); + if let Some(coloring) = self.op.coloring.as_ref() { coloring.jacobian_inplace(self, x, t, y); } else { self._default_jacobian_inplace(x, t, y); } } fn jacobian_sparsity(&self) -> Option<::Sparsity> { - self.sparsity.clone() + self.op.sparsity.clone() } } -impl NonLinearOpSens for ClosureWithSens +impl NonLinearOpSens for ParameterisedOp<'_, ClosureWithSens> where M: Matrix, F: Fn(&M::V, &M::V, M::T, &mut M::V), @@ -158,17 +170,17 @@ where H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), { fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { - (self.sens_action)(x, self.p.as_ref(), t, v, y); + (self.op.sens_action)(x, self.p, t, v, y); } fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - if let Some(coloring) = self.sens_coloring.as_ref() { + if let Some(coloring) = self.op.sens_coloring.as_ref() { coloring.jacobian_inplace(self, x, t, y); } else { self._default_sens_inplace(x, t, y); } } fn sens_sparsity(&self) -> Option<::Sparsity> { - self.sens_sparsity.clone() + self.op.sens_sparsity.clone() } } diff --git a/src/op/constant_closure.rs b/src/op/constant_closure.rs index 2f0dd43a..1a9f0dee 100644 --- a/src/op/constant_closure.rs +++ b/src/op/constant_closure.rs @@ -1,6 +1,4 @@ -use num_traits::Zero; -use std::rc::Rc; - +use super::{BuilderOp, ParameterisedOp}; use crate::{ConstantOp, Matrix, Op, Vector}; pub struct ConstantClosure @@ -9,10 +7,9 @@ where I: Fn(&M::V, M::T) -> M::V, { func: I, - nstates: usize, nout: usize, nparams: usize, - p: Rc, + _phantom: std::marker::PhantomData, } impl ConstantClosure @@ -20,17 +17,12 @@ where M: Matrix, I: Fn(&M::V, M::T) -> M::V, { - pub fn new(func: I, p: Rc) -> Self { - let nparams = p.len(); - let y0 = (func)(p.as_ref(), M::T::zero()); - let nstates = y0.len(); - let nout = nstates; + pub fn new(func: I, nout: usize, nparams: usize) -> Self { Self { func, - nstates, nout, nparams, - p, + _phantom: std::marker::PhantomData, } } } @@ -44,7 +36,7 @@ where type T = M::T; type M = M; fn nstates(&self) -> usize { - self.nstates + 0 } fn nout(&self) -> usize { self.nout @@ -52,21 +44,36 @@ where fn nparams(&self) -> usize { self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; +} + +impl BuilderOp for ConstantClosure +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, +{ + fn calculate_sparsity(&mut self, _y0: &Self::V, _t0: Self::T, _p: &Self::V) { + // do nothing + } + fn set_nstates(&mut self, _nstates: usize) { + // do nothing + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; } } -impl ConstantOp for ConstantClosure +impl ConstantOp for ParameterisedOp<'_, ConstantClosure> where M: Matrix, I: Fn(&M::V, M::T) -> M::V, { fn call_inplace(&self, t: Self::T, y: &mut Self::V) { - y.copy_from(&(self.func)(self.p.as_ref(), t)); + y.copy_from(&(self.op.func)(self.p, t)); } fn call(&self, t: Self::T) -> Self::V { - (self.func)(self.p.as_ref(), t) + (self.op.func)(self.p, t) } } diff --git a/src/op/constant_closure_with_adjoint.rs b/src/op/constant_closure_with_adjoint.rs index 0d1f1f1c..af6a59c4 100644 --- a/src/op/constant_closure_with_adjoint.rs +++ b/src/op/constant_closure_with_adjoint.rs @@ -1,8 +1,7 @@ -use num_traits::Zero; -use std::rc::Rc; - use crate::{ConstantOp, ConstantOpSensAdjoint, Matrix, Op, Vector}; +use super::{BuilderOp, ParameterisedOp}; + pub struct ConstantClosureWithAdjoint where M: Matrix, @@ -11,10 +10,9 @@ where { func: I, func_sens_adjoint: J, - nstates: usize, nout: usize, nparams: usize, - p: Rc, + _phantom: std::marker::PhantomData, } impl ConstantClosureWithAdjoint @@ -23,22 +21,37 @@ where I: Fn(&M::V, M::T) -> M::V, J: Fn(&M::V, M::T, &M::V, &mut M::V), { - pub fn new(func: I, func_sens_adjoint: J, p: Rc) -> Self { - let nparams = p.len(); - let y0 = (func)(p.as_ref(), M::T::zero()); - let nstates = y0.len(); - let nout = nstates; + pub fn new(func: I, func_sens_adjoint: J, nout: usize, nparams: usize) -> Self { Self { func, func_sens_adjoint, - nstates, nout, nparams, - p, + _phantom: std::marker::PhantomData, } } } +impl BuilderOp for ConstantClosureWithAdjoint +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, + J: Fn(&M::V, M::T, &M::V, &mut M::V), +{ + fn calculate_sparsity(&mut self, _y0: &Self::V, _t0: Self::T, _p: &Self::V) { + // Do nothing + } + fn set_nstates(&mut self, _nstates: usize) { + // Do nothing + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; + } +} + impl Op for ConstantClosureWithAdjoint where M: Matrix, @@ -49,7 +62,7 @@ where type T = M::T; type M = M; fn nstates(&self) -> usize { - self.nstates + 0 } fn nout(&self) -> usize { self.nout @@ -57,33 +70,29 @@ where fn nparams(&self) -> usize { self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; - } } -impl ConstantOp for ConstantClosureWithAdjoint +impl ConstantOp for ParameterisedOp<'_, ConstantClosureWithAdjoint> where M: Matrix, I: Fn(&M::V, M::T) -> M::V, J: Fn(&M::V, M::T, &M::V, &mut M::V), { fn call_inplace(&self, t: Self::T, y: &mut Self::V) { - y.copy_from(&(self.func)(self.p.as_ref(), t)); + y.copy_from(&(self.op.func)(self.p, t)); } fn call(&self, t: Self::T) -> Self::V { - (self.func)(self.p.as_ref(), t) + (self.op.func)(self.p, t) } } -impl ConstantOpSensAdjoint for ConstantClosureWithAdjoint +impl ConstantOpSensAdjoint for ParameterisedOp<'_, ConstantClosureWithAdjoint> where M: Matrix, I: Fn(&M::V, M::T) -> M::V, J: Fn(&M::V, M::T, &M::V, &mut M::V), { fn sens_transpose_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) { - (self.func_sens_adjoint)(self.p.as_ref(), t, v, y); + (self.op.func_sens_adjoint)(self.p, t, v, y); } } diff --git a/src/op/constant_closure_with_sens.rs b/src/op/constant_closure_with_sens.rs index 45fd3c76..7d2938db 100644 --- a/src/op/constant_closure_with_sens.rs +++ b/src/op/constant_closure_with_sens.rs @@ -1,7 +1,7 @@ -use std::rc::Rc; - use crate::{ConstantOp, ConstantOpSens, Matrix, Op, Vector}; +use super::{BuilderOp, ParameterisedOp}; + pub struct ConstantClosureWithSens where M: Matrix, @@ -10,10 +10,9 @@ where { func: I, func_sens: J, - nstates: usize, nout: usize, nparams: usize, - p: Rc, + _phantom: std::marker::PhantomData, } impl ConstantClosureWithSens @@ -22,15 +21,13 @@ where I: Fn(&M::V, M::T) -> M::V, J: Fn(&M::V, M::T, &M::V, &mut M::V), { - pub fn new(func: I, func_sens: J, nstates: usize, nout: usize, p: Rc) -> Self { - let nparams = p.len(); + pub fn new(func: I, func_sens: J, nout: usize, nparams: usize) -> Self { Self { func, func_sens, - nstates, nout, nparams, - p, + _phantom: std::marker::PhantomData, } } } @@ -45,7 +42,7 @@ where type T = M::T; type M = M; fn nstates(&self) -> usize { - self.nstates + 0 } fn nout(&self) -> usize { self.nout @@ -53,33 +50,49 @@ where fn nparams(&self) -> usize { self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; +} + +impl BuilderOp for ConstantClosureWithSens +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, + J: Fn(&M::V, M::T, &M::V, &mut M::V), +{ + fn calculate_sparsity(&mut self, _y0: &Self::V, _t0: Self::T, _p: &Self::V) { + // do nothing + } + fn set_nstates(&mut self, _nstates: usize) { + // do nothing + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; } } -impl ConstantOp for ConstantClosureWithSens +impl ConstantOp for ParameterisedOp<'_, ConstantClosureWithSens> where M: Matrix, I: Fn(&M::V, M::T) -> M::V, J: Fn(&M::V, M::T, &M::V, &mut M::V), { fn call_inplace(&self, t: Self::T, y: &mut Self::V) { - y.copy_from(&(self.func)(self.p.as_ref(), t)); + y.copy_from(&(self.op.func)(self.p, t)); } fn call(&self, t: Self::T) -> Self::V { - (self.func)(self.p.as_ref(), t) + (self.op.func)(self.p, t) } } -impl ConstantOpSens for ConstantClosureWithSens +impl ConstantOpSens for ParameterisedOp<'_, ConstantClosureWithSens> where M: Matrix, I: Fn(&M::V, M::T) -> M::V, J: Fn(&M::V, M::T, &M::V, &mut M::V), { fn sens_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) { - (self.func_sens)(self.p.as_ref(), t, v, y); + (self.op.func_sens)(self.p, t, v, y); } } diff --git a/src/op/init.rs b/src/op/init.rs index 87672f4c..0f88ad38 100644 --- a/src/op/init.rs +++ b/src/op/init.rs @@ -3,7 +3,7 @@ use crate::{ VectorIndex, }; use num_traits::{One, Zero}; -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use super::{NonLinearOp, Op}; @@ -11,17 +11,16 @@ use super::{NonLinearOp, Op}; /// /// We calculate consistent initial conditions following the approach of /// Brown, P. N., Hindmarsh, A. C., & Petzold, L. R. (1998). Consistent initial condition calculation for differential-algebraic systems. SIAM Journal on Scientific Computing, 19(5), 1495-1512. -pub struct InitOp { - eqn: Rc, +pub struct InitOp<'a, Eqn: OdeEquationsImplicit> { + eqn: &'a Eqn, jac: Eqn::M, pub y0: RefCell, pub algebraic_indices: ::Index, neg_mass: Eqn::M, } -impl InitOp { - pub fn new(eqn: &Rc, t0: Eqn::T, y0: &Eqn::V) -> Self { - let eqn = eqn.clone(); +impl<'a, Eqn: OdeEquationsImplicit> InitOp<'a, Eqn> { + pub fn new(eqn: &'a Eqn, t0: Eqn::T, y0: &Eqn::V) -> Self { let n = eqn.rhs().nstates(); let mass_diagonal = eqn.mass().unwrap().matrix(t0).diagonal(); let algebraic_indices = mass_diagonal.filter_indices(|x| x == Eqn::T::zero()); @@ -71,7 +70,7 @@ impl InitOp { } } -impl Op for InitOp { +impl Op for InitOp<'_, Eqn> { type V = Eqn::V; type T = Eqn::T; type M = Eqn::M; @@ -86,7 +85,7 @@ impl Op for InitOp { } } -impl NonLinearOp for InitOp { +impl NonLinearOp for InitOp<'_, Eqn> { // -M_u du + f(u, v) // g(t, u, v) fn call_inplace(&self, x: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) { @@ -103,7 +102,7 @@ impl NonLinearOp for InitOp { } } -impl NonLinearOpJacobian for InitOp { +impl NonLinearOpJacobian for InitOp<'_, Eqn> { // J v fn jac_mul_inplace(&self, _x: &Eqn::V, _t: Eqn::T, v: &Eqn::V, y: &mut Eqn::V) { self.jac.gemv(Eqn::T::one(), v, Eqn::T::one(), y); diff --git a/src/op/linear_closure.rs b/src/op/linear_closure.rs index 82416b3c..f67bcd5f 100644 --- a/src/op/linear_closure.rs +++ b/src/op/linear_closure.rs @@ -1,11 +1,11 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use crate::{ find_matrix_non_zeros, jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity, LinearOp, - Matrix, Op, Vector, + Matrix, Op, }; -use super::OpStatistics; +use super::{BuilderOp, OpStatistics, ParameterisedOp}; pub struct LinearClosure where @@ -16,7 +16,6 @@ where nstates: usize, nout: usize, nparams: usize, - p: Rc, coloring: Option>, sparsity: Option, statistics: RefCell, @@ -27,22 +26,21 @@ where M: Matrix, F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { - pub fn new(func: F, nstates: usize, nout: usize, p: Rc) -> Self { - let nparams = p.len(); + pub fn new(func: F, nstates: usize, nout: usize, nparams: usize) -> Self { Self { func, nstates, statistics: RefCell::new(OpStatistics::default()), nout, nparams, - p, coloring: None, sparsity: None, } } - pub fn calculate_sparsity(&mut self, t0: M::T) { - let non_zeros = find_matrix_non_zeros(self, t0); + pub fn calculate_sparsity(&mut self, t0: M::T, p: &M::V) { + let op = ParameterisedOp { op: self, p }; + let non_zeros = find_matrix_non_zeros(&op, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -72,35 +70,49 @@ where self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; - } - fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } } -impl LinearOp for LinearClosure +impl BuilderOp for LinearClosure +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), +{ + fn calculate_sparsity(&mut self, _y0: &Self::V, t0: Self::T, p: &Self::V) { + self.calculate_sparsity(t0, p); + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; + } + fn set_nstates(&mut self, nstates: usize) { + self.nstates = nstates; + } +} + +impl LinearOp for ParameterisedOp<'_, LinearClosure> where M: Matrix, F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { fn gemv_inplace(&self, x: &M::V, t: M::T, beta: M::T, y: &mut M::V) { - self.statistics.borrow_mut().increment_call(); - (self.func)(x, self.p.as_ref(), t, beta, y) + self.op.statistics.borrow_mut().increment_call(); + (self.op.func)(x, self.p, t, beta, y) } fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { - self.statistics.borrow_mut().increment_matrix(); - if let Some(coloring) = &self.coloring { + self.op.statistics.borrow_mut().increment_matrix(); + if let Some(coloring) = &self.op.coloring { coloring.matrix_inplace(self, t, y); } else { self._default_matrix_inplace(t, y); } } fn sparsity(&self) -> Option<::Sparsity> { - self.sparsity.clone() + self.op.sparsity.clone() } } diff --git a/src/op/linear_closure_with_adjoint.rs b/src/op/linear_closure_with_adjoint.rs index ae328d28..fff69881 100644 --- a/src/op/linear_closure_with_adjoint.rs +++ b/src/op/linear_closure_with_adjoint.rs @@ -1,11 +1,11 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use crate::{ find_matrix_non_zeros, find_transpose_non_zeros, jacobian::JacobianColoring, - matrix::sparsity::MatrixSparsity, LinearOp, LinearOpTranspose, Matrix, Op, Vector, + matrix::sparsity::MatrixSparsity, LinearOp, LinearOpTranspose, Matrix, Op, }; -use super::OpStatistics; +use super::{BuilderOp, OpStatistics, ParameterisedOp}; pub struct LinearClosureWithAdjoint where @@ -18,7 +18,6 @@ where nstates: usize, nout: usize, nparams: usize, - p: Rc, coloring: Option>, sparsity: Option, coloring_adjoint: Option>, @@ -32,8 +31,7 @@ where F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { - pub fn new(func: F, func_adjoint: G, nstates: usize, nout: usize, p: Rc) -> Self { - let nparams = p.len(); + pub fn new(func: F, func_adjoint: G, nstates: usize, nout: usize, nparams: usize) -> Self { Self { func, func_adjoint, @@ -41,7 +39,6 @@ where statistics: RefCell::new(OpStatistics::default()), nout, nparams, - p, coloring: None, sparsity: None, coloring_adjoint: None, @@ -49,8 +46,9 @@ where } } - pub fn calculate_sparsity(&mut self, t0: M::T) { - let non_zeros = find_matrix_non_zeros(self, t0); + pub fn calculate_sparsity(&mut self, t0: M::T, p: &M::V) { + let op = ParameterisedOp { op: self, p }; + let non_zeros = find_matrix_non_zeros(&op, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -60,8 +58,9 @@ where &non_zeros, )); } - pub fn calculate_adjoint_sparsity(&mut self, t0: M::T) { - let non_zeros = find_transpose_non_zeros(self, t0); + pub fn calculate_adjoint_sparsity(&mut self, t0: M::T, p: &M::V) { + let op = ParameterisedOp { op: self, p }; + let non_zeros = find_transpose_non_zeros(&op, t0); self.sparsity_adjoint = Some( MatrixSparsity::try_from_indices(self.nstates, self.nout, non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -92,51 +91,67 @@ where self.nparams } - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams); - self.p = p; - } - fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } } -impl LinearOp for LinearClosureWithAdjoint +impl BuilderOp for LinearClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), +{ + fn calculate_sparsity(&mut self, _y0: &Self::V, t0: Self::T, p: &Self::V) { + self.calculate_sparsity(t0, p); + self.calculate_adjoint_sparsity(t0, p); + } + fn set_nout(&mut self, nout: usize) { + self.nout = nout; + } + fn set_nparams(&mut self, nparams: usize) { + self.nparams = nparams; + } + fn set_nstates(&mut self, nstates: usize) { + self.nstates = nstates; + } +} + +impl LinearOp for ParameterisedOp<'_, LinearClosureWithAdjoint> where M: Matrix, F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { fn gemv_inplace(&self, x: &M::V, t: M::T, beta: M::T, y: &mut M::V) { - self.statistics.borrow_mut().increment_call(); - (self.func)(x, self.p.as_ref(), t, beta, y) + self.op.statistics.borrow_mut().increment_call(); + (self.op.func)(x, self.p, t, beta, y) } fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { - self.statistics.borrow_mut().increment_matrix(); - if let Some(coloring) = &self.coloring { + self.op.statistics.borrow_mut().increment_matrix(); + if let Some(coloring) = &self.op.coloring { coloring.matrix_inplace(self, t, y); } else { self._default_matrix_inplace(t, y); } } fn sparsity(&self) -> Option<::Sparsity> { - self.sparsity.clone() + self.op.sparsity.clone() } } -impl LinearOpTranspose for LinearClosureWithAdjoint +impl LinearOpTranspose for ParameterisedOp<'_, LinearClosureWithAdjoint> where M: Matrix, F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) { - (self.func_adjoint)(x, self.p.as_ref(), t, beta, y) + (self.op.func_adjoint)(x, self.p, t, beta, y) } fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) { - if let Some(coloring) = &self.coloring_adjoint { + if let Some(coloring) = &self.op.coloring_adjoint { coloring.matrix_inplace(self, t, y); } else { self._default_transpose_inplace(t, y); @@ -144,6 +159,6 @@ where } fn transpose_sparsity(&self) -> Option<::Sparsity> { - self.sparsity_adjoint.clone() + self.op.sparsity_adjoint.clone() } } diff --git a/src/op/mod.rs b/src/op/mod.rs index 5a82c46f..09b9e72a 100644 --- a/src/op/mod.rs +++ b/src/op/mod.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{ ConstantOp, ConstantOpSens, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, Matrix, NonLinearOp, NonLinearOpAdjoint, NonLinearOpSens, NonLinearOpSensAdjoint, Scalar, Vector, @@ -44,18 +42,48 @@ pub trait Op { fn nout(&self) -> usize; /// Return the number of parameters of the operator. - fn nparams(&self) -> usize { - 0 + fn nparams(&self) -> usize; + + /// Return statistics about the operator (e.g. how many times it was called, how many times the jacobian was computed, etc.) + fn statistics(&self) -> OpStatistics { + OpStatistics::default() } +} - /// Set the parameters of the operator to the given value. - fn set_params(&mut self, p: Rc) { - assert_eq!(p.len(), self.nparams()); +/// A wrapper for an operator that parameterises it with a parameter vector. +pub struct ParameterisedOp<'a, C: Op> { + pub op: &'a C, + pub p: &'a C::V, +} + +impl<'a, C: Op> ParameterisedOp<'a, C> { + pub fn new(op: &'a C, p: &'a C::V) -> Self { + Self { op, p } } +} - /// Return statistics about the operator (e.g. how many times it was called, how many times the jacobian was computed, etc.) +pub trait BuilderOp: Op { + fn set_nstates(&mut self, nstates: usize); + fn set_nparams(&mut self, nparams: usize); + fn set_nout(&mut self, nout: usize); + fn calculate_sparsity(&mut self, y0: &Self::V, t0: Self::T, p: &Self::V); +} + +impl Op for ParameterisedOp<'_, C> { + type V = C::V; + type T = C::T; + type M = C::M; + fn nstates(&self) -> usize { + self.op.nstates() + } + fn nout(&self) -> usize { + self.op.nout() + } + fn nparams(&self) -> usize { + self.op.nparams() + } fn statistics(&self) -> OpStatistics { - OpStatistics::default() + self.op.statistics() } } @@ -112,6 +140,24 @@ impl Op for &C { } } +impl Op for &mut C { + type T = C::T; + type V = C::V; + type M = C::M; + fn nstates(&self) -> usize { + C::nstates(*self) + } + fn nout(&self) -> usize { + C::nout(*self) + } + fn nparams(&self) -> usize { + C::nparams(*self) + } + fn statistics(&self) -> OpStatistics { + C::statistics(*self) + } +} + impl NonLinearOp for &C { fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) { C::call_inplace(*self, x, t, y) diff --git a/src/op/sdirk.rs b/src/op/sdirk.rs index 8826d218..f662e53d 100644 --- a/src/op/sdirk.rs +++ b/src/op/sdirk.rs @@ -1,22 +1,21 @@ use crate::{ matrix::{MatrixRef, MatrixView}, ode_solver::equations::OdeEquations, - scale, LinearOp, Matrix, MatrixSparsity, NonLinearOpJacobian, OdeEquationsImplicit, - OdeSolverProblem, Vector, VectorRef, + scale, LinearOp, Matrix, MatrixSparsity, NonLinearOpJacobian, OdeEquationsImplicit, Vector, + VectorRef, }; use num_traits::{One, Zero}; use std::{ cell::{Ref, RefCell}, ops::Deref, ops::MulAssign, - rc::Rc, }; use super::{NonLinearOp, Op}; // callable to solve for F(y) = M (y) - h f(phi + a * y) = 0 pub struct SdirkCallable { - eqn: Rc, + pub(crate) eqn: Eqn, c: Eqn::T, h: RefCell, phi: RefCell, @@ -29,12 +28,26 @@ pub struct SdirkCallable { } impl SdirkCallable { + pub fn clone_state(&self, eqn: Eqn) -> Self { + Self { + eqn, + c: self.c, + h: RefCell::new(*self.h.borrow()), + phi: RefCell::new(self.phi.borrow().clone()), + tmp: RefCell::new(self.tmp.borrow().clone()), + rhs_jac: RefCell::new(self.rhs_jac.borrow().clone()), + mass_jac: RefCell::new(self.mass_jac.borrow().clone()), + jacobian_is_stale: RefCell::new(*self.jacobian_is_stale.borrow()), + number_of_jac_evals: RefCell::new(*self.number_of_jac_evals.borrow()), + sparsity: self.sparsity.clone(), + } + } // y = h g(phi + c * y_s) pub fn integrate_out(&self, ys: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) { self.eqn.out().unwrap().call_inplace(ys, t, y); y.mul_assign(scale(*(self.h.borrow()))); } - pub fn from_eqn(eqn: Rc, c: Eqn::T) -> Self { + pub fn new_no_jacobian(eqn: Eqn, c: Eqn::T) -> Self { let n = eqn.rhs().nstates(); let h = RefCell::new(Eqn::T::zero()); let phi = RefCell::new(::zeros(n)); @@ -58,13 +71,12 @@ impl SdirkCallable { } } - pub fn eqn_mut(&mut self) -> &mut Rc { + pub fn eqn_mut(&mut self) -> &mut Eqn { &mut self.eqn } - pub fn new(ode_problem: &OdeSolverProblem, c: Eqn::T) -> Self { - let eqn = ode_problem.eqn.clone(); - let n = ode_problem.eqn.rhs().nstates(); + pub fn new(eqn: Eqn, c: Eqn::T) -> Self { + let n = eqn.rhs().nstates(); let h = RefCell::new(Eqn::T::zero()); let phi = RefCell::new(::zeros(n)); let jacobian_is_stale = RefCell::new(true); @@ -130,7 +142,7 @@ impl SdirkCallable { pub fn get_last_f_eval(&self) -> Ref { self.tmp.borrow() } - pub fn eqn(&self) -> &Rc { + pub fn eqn(&self) -> &Eqn { &self.eqn } #[allow(dead_code)] @@ -277,7 +289,7 @@ mod tests { let c = 0.1; let h = 1.3; let phi = Vcpu::from_vec(vec![1.1, 1.2, 1.3]); - let sdirk_callable = SdirkCallable::new(&problem, c); + let sdirk_callable = SdirkCallable::new(&problem.eqn, c); sdirk_callable.set_h(h); sdirk_callable.set_phi_direct(phi); let t = 0.9; @@ -297,7 +309,7 @@ mod tests { let (problem, _soln) = exponential_decay_problem::(false); let c = 0.1; let h = 1.0; - let sdirk_callable = SdirkCallable::new(&problem, c); + let sdirk_callable = SdirkCallable::new(&problem.eqn, c); sdirk_callable.set_h(h); let phi = Vcpu::from_vec(vec![1.1, 1.2]); diff --git a/src/op/unit.rs b/src/op/unit.rs index 2056cc40..a686d100 100644 --- a/src/op/unit.rs +++ b/src/op/unit.rs @@ -6,6 +6,8 @@ use crate::{ }; use num_traits::{One, Zero}; +use super::{BuilderOp, ParameterisedOp}; + /// A dummy operator that returns the input vector. Can be used either as a [NonLinearOp] or [LinearOp]. pub struct UnitCallable { n: usize, @@ -42,49 +44,64 @@ impl Op for UnitCallable { } } -impl LinearOp for UnitCallable { +impl BuilderOp for UnitCallable { + fn calculate_sparsity(&mut self, _y0: &Self::V, _t0: Self::T, _p: &Self::V) { + // Do nothing + } + fn set_nout(&mut self, nout: usize) { + self.n = nout; + } + fn set_nparams(&mut self, _nparams: usize) { + // Do nothing + } + fn set_nstates(&mut self, nstates: usize) { + self.n = nstates; + } +} + +impl LinearOp for ParameterisedOp<'_, UnitCallable> { fn gemv_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) { y.axpy(Self::T::one(), x, beta); } } -impl NonLinearOp for UnitCallable { +impl NonLinearOp for ParameterisedOp<'_, UnitCallable> { fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) { y.copy_from(x); } } -impl NonLinearOpJacobian for UnitCallable { +impl NonLinearOpJacobian for ParameterisedOp<'_, UnitCallable> { fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) { y.copy_from(v); } } -impl NonLinearOpAdjoint for UnitCallable { +impl NonLinearOpAdjoint for ParameterisedOp<'_, UnitCallable> { fn jac_transpose_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) { y.copy_from(v); } } -impl NonLinearOpSens for UnitCallable { +impl NonLinearOpSens for ParameterisedOp<'_, UnitCallable> { fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) { y.fill(Self::T::zero()); } } -impl NonLinearOpSensAdjoint for UnitCallable { +impl NonLinearOpSensAdjoint for ParameterisedOp<'_, UnitCallable> { fn sens_transpose_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) { y.fill(Self::T::zero()); } } -impl LinearOpSens for UnitCallable { +impl LinearOpSens for ParameterisedOp<'_, UnitCallable> { fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) { y.fill(Self::T::zero()); } } -impl LinearOpTranspose for UnitCallable { +impl LinearOpTranspose for ParameterisedOp<'_, UnitCallable> { fn gemv_transpose_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) { y.axpy(Self::T::one(), x, beta); } diff --git a/src/vector/sundials.rs b/src/vector/sundials.rs index cbe9a723..c0cfdee7 100644 --- a/src/vector/sundials.rs +++ b/src/vector/sundials.rs @@ -470,8 +470,14 @@ impl VectorIndex for SundialsIndexVector { } impl Vector for SundialsVector { - type View<'a> = SundialsVectorView<'a> where Self: 'a; - type ViewMut<'a> = SundialsVectorViewMut<'a> where Self: 'a; + type View<'a> + = SundialsVectorView<'a> + where + Self: 'a; + type ViewMut<'a> + = SundialsVectorViewMut<'a> + where + Self: 'a; type Index = SundialsIndexVector; fn len(&self) -> IndexType { unsafe { N_VGetLength_Serial(self.sundials_vector()) as IndexType }