From 29a29aec1bffd3685b905442f1915a4e9d366bd0 Mon Sep 17 00:00:00 2001
From: Martin Robinson
Date: Sun, 20 Oct 2024 07:57:45 +0100
Subject: [PATCH] feat: adjoint equations (#99)

This is a big set of features, most of which were required to support solving the adjoint equations:

* Generalises the sensitivity and adjoint equations into `AugmentedOdeEquations`, a set of n equations with the same jacobian as the main equations
* Splits `OdeEquations` into multiple traits: `OdeEquations` (for explicit solvers), `OdeEquationsImplicit` (for implicit solvers, currently all of them), `OdeEquationsSens` (if you want to calculate forward sensitivities), `OdeEquationsAdjoint` (if you want to calculate adjoint sensitivities)
* Solvers can now work in reverse time (many equations will be unstable, but this is used to solve the adjoint equations)
* (internal) nonlinear and linear solvers no longer have an `Op` as a generic parameter, so they can be reused for different operators (as long as the number of rows/cols is the same)
* Solvers can be checkpointed so they can be re-started from a checkpoint
* Hermite interpolation struct added for saving an interpolant along a solution trajectory
* Solvers can integrate an output function along the solution trajectory
* Output function integration, adjoint sensitivity integration and forward sensitivity integration can be added to or removed from error control for all the solvers. Tolerances can be set via the builder or the `OdeSolverEquations` structs.
---
 .github/workflows/rust.yml | 2 -
 benches/ode_solvers.rs | 34 +-
 benches/plot.py | 13 +-
 src/error.rs | 4 +-
 src/jacobian/mod.rs | 168 +++-
 src/lib.rs | 85 +-
 src/linear_solver/faer/lu.rs | 50 +-
 src/linear_solver/faer/sparse_lu.rs | 47 +-
 src/linear_solver/mod.rs | 67 +-
 src/linear_solver/nalgebra/lu.rs | 52 +-
 src/linear_solver/suitesparse/klu.rs | 56 +-
 src/linear_solver/sundials.rs | 62 +-
 src/matrix/default_solver.rs | 6 +-
 src/matrix/dense_faer_serial.rs | 5 +-
 src/matrix/dense_nalgebra_serial.rs | 8 +-
 src/matrix/mod.rs | 4 +-
 src/matrix/sparse_faer.rs | 4 +-
 src/matrix/sundials.rs | 5 +-
 src/nonlinear_solver/convergence.rs | 7 +-
 src/nonlinear_solver/mod.rs | 83 +-
 src/nonlinear_solver/newton.rs | 85 +-
 src/ode_solver/adjoint_equations.rs | 724 ++++++++++++++
 src/ode_solver/bdf.rs | 911 ++++++++++++------
 src/ode_solver/bdf_state.rs | 246 +++--
 src/ode_solver/builder.rs | 386 +++++---
 src/ode_solver/checkpointing.rs | 263 +++++
 src/ode_solver/diffsl.rs | 31 +-
 src/ode_solver/equations.rs | 174 +++-
 src/ode_solver/method.rs | 346 +++++--
 src/ode_solver/mod.rs | 307 +++++-
 src/ode_solver/problem.rs | 69 +-
 src/ode_solver/sdirk.rs | 714 ++++++++++----
 src/ode_solver/sdirk_state.rs | 114 ++-
 src/ode_solver/sens_equations.rs | 226 +++--
 src/ode_solver/state.rs | 421 +++++---
 src/ode_solver/sundials.rs | 58 +-
 src/ode_solver/test.rs | 58 ++
 src/ode_solver/test_models/dydt_y2.rs | 7 +-
 .../test_models/exponential_decay.rs | 238 ++++-
 .../exponential_decay_with_algebraic.rs | 317 +++++-
 src/ode_solver/test_models/foodweb.rs | 54 +-
 src/ode_solver/test_models/gaussian_decay.rs | 7 +-
 src/ode_solver/test_models/heat2d.rs | 14 +-
 src/ode_solver/test_models/mod.rs | 1 -
 src/ode_solver/test_models/robertson.rs | 169 +++-
 src/ode_solver/test_models/robertson_ode.rs | 7 +-
 .../test_models/robertson_ode_with_sens.rs | 5 +-
 src/ode_solver/test_models/robertson_sens.rs | 174 ----
 src/op/bdf.rs | 63 +-
 src/op/closure.rs | 16 +-
 src/op/closure_no_jac.rs | 7 +-
 src/op/closure_with_adjoint.rs | 218 +++++
 src/op/closure_with_sens.rs | 60 +-
 src/op/constant_closure.rs | 4 +-
 src/op/constant_closure_with_adjoint.rs | 89 ++
 src/op/constant_closure_with_sens.rs | 15 +-
 src/op/constant_op.rs | 54 ++
 src/op/init.rs | 16 +-
 src/op/linear_closure.rs | 10 +-
 ...sens.rs => linear_closure_with_adjoint.rs} | 67 +-
 src/op/linear_op.rs | 124 +++
 src/op/linearise.rs | 12 +-
 src/op/matrix.rs | 4 +-
 src/op/mod.rs | 291 +----
 src/op/nonlinear_op.rs | 175 ++++
 src/op/sdirk.rs | 22 +-
 src/op/unit.rs | 42 +-
 src/solver/mod.rs | 59 +-
 src/vector/faer_serial.rs | 4 +
 src/vector/mod.rs | 5 +-
 src/vector/nalgebra_serial.rs | 3 +
 src/vector/sundials.rs | 21 +-
 72 files changed, 6057 insertions(+), 2212 deletions(-)
 create mode 100644 src/ode_solver/adjoint_equations.rs
 create mode 100644 src/ode_solver/checkpointing.rs
 create mode 100644 src/ode_solver/test.rs
 delete mode 100644 src/ode_solver/test_models/robertson_sens.rs
 create mode 100644 src/op/closure_with_adjoint.rs
 create mode 100644 src/op/constant_closure_with_adjoint.rs
 create mode 100644 src/op/constant_op.rs
 rename src/op/{linear_closure_with_sens.rs => linear_closure_with_adjoint.rs} (51%)
 create mode 100644 src/op/linear_op.rs
 create mode 100644 src/op/nonlinear_op.rs

diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 1622f9e7..3f16d7a2 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -37,8 +37,6 @@ jobs:
         include:
           - toolchain: beta
             os: ubuntu-latest
-          - toolchain: nightly
-            os: ubuntu-latest
 
     steps:
       - uses: actions/checkout@v4
diff --git a/benches/ode_solvers.rs b/benches/ode_solvers.rs
index e7d69725..224433d4 100644
--- a/benches/ode_solvers.rs
+++ b/benches/ode_solvers.rs
@@ -546,23 +546,18 @@ criterion_group!(benches, criterion_benchmark);
 criterion_main!(benches);
 
 mod benchmarks {
-    use diffsol::linear_solver::LinearSolver;
     use diffsol::matrix::MatrixRef;
-    use diffsol::op::bdf::BdfCallable;
-    use diffsol::op::sdirk::SdirkCallable;
     use diffsol::vector::VectorRef;
+    use diffsol::LinearSolver;
    use diffsol::{
-        Bdf, DefaultDenseMatrix, DefaultSolver, Matrix, NewtonNonlinearSolver, OdeEquations,
-        OdeSolverMethod, OdeSolverProblem, Sdirk, Tableau,
+        Bdf, DefaultDenseMatrix, DefaultSolver, Matrix, NewtonNonlinearSolver,
+        OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Sdirk, Tableau,
     };
 
     // bdf
-    pub fn bdf<Eqn>(
-        problem: &OdeSolverProblem<Eqn>,
-        t: Eqn::T,
-        ls: impl LinearSolver<BdfCallable<Eqn>>,
-    ) where
-        Eqn: OdeEquations,
+    pub fn bdf<Eqn>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T, ls: impl LinearSolver<Eqn::M>)
+    where
+        Eqn: OdeEquationsImplicit,
         Eqn::M: Matrix + DefaultSolver,
         Eqn::V: DefaultDenseMatrix,
         for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
     {
         let nls = NewtonNonlinearSolver::new(ls);
         let mut s = Bdf::<<Eqn::V as DefaultDenseMatrix>::M, _, _>::new(nls);
-        let _y = s.solve(problem, t);
+        let state = OdeSolverState::new(problem, &s).unwrap();
+        let _y = s.solve(problem, state, t);
     }
 
     pub fn esdirk34<Eqn>(
         problem: &OdeSolverProblem<Eqn>,
         t: Eqn::T,
-        linear_solver: impl LinearSolver<SdirkCallable<Eqn>>,
+        linear_solver: impl LinearSolver<Eqn::M>,
     ) where
-        Eqn: OdeEquations,
+        Eqn: OdeEquationsImplicit,
         Eqn::M: Matrix + DefaultSolver,
         Eqn::V: DefaultDenseMatrix,
         for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
     {
         let tableau = Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::esdirk34();
         let mut s = Sdirk::new(tableau, linear_solver);
-        let _y = s.solve(problem, t);
+        let state = OdeSolverState::new(problem, &s).unwrap();
+        let _y = s.solve(problem, state, t);
     }
 
     pub fn tr_bdf2<Eqn>(
         problem: &OdeSolverProblem<Eqn>,
         t: Eqn::T,
-        linear_solver: impl LinearSolver<SdirkCallable<Eqn>>,
+        linear_solver: impl LinearSolver<Eqn::M>,
     )
 where
-        Eqn: OdeEquations,
+        Eqn: OdeEquationsImplicit,
         Eqn::M: Matrix + DefaultSolver,
         Eqn::V: DefaultDenseMatrix,
         for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
@@ -602,6 +599,7 @@
     {
         let tableau = Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::tr_bdf2();
         let mut s = Sdirk::new(tableau, linear_solver);
-        let _y = s.solve(problem, t);
+        let state = OdeSolverState::new(problem, &s).unwrap();
+        let _y = s.solve(problem, state, t);
     }
 }
diff --git a/benches/plot.py b/benches/plot.py
index 916d6020..2b751f97 100644
--- a/benches/plot.py
+++ b/benches/plot.py
@@ -17,19 +17,22 @@
         "name": "robertson_ode",
         "reference_name": "robertson_ode_klu",
         "arg": [25, 100, 400, 900],
-        "solvers": ["faer_sparse_bdf_klu"],
+        #"solvers": ["faer_sparse_bdf_klu"],
+        "solvers": ["faer_sparse_bdf"],
     },
     {
         "name": "heat2d",
         "reference_name": "heat2d_klu",
         "arg": [5, 10, 20, 30],
-        "solvers": ["faer_sparse_esdirk_klu", "faer_sparse_tr_bdf2_klu", "faer_sparse_bdf_klu", "faer_sparse_bdf_klu_diffsl"]
+        "solvers": ["faer_sparse_esdirk", "faer_sparse_tr_bdf2", "faer_sparse_bdf", "faer_sparse_bdf_diffsl"]
+        #"solvers": ["faer_sparse_esdirk_klu", "faer_sparse_tr_bdf2_klu", "faer_sparse_bdf_klu", "faer_sparse_bdf_klu_diffsl"]
     },
     {
         "name": "foodweb",
         "reference_name": "foodweb_bnd",
         "arg": [5, 10, 20, 30],
-        "solvers": ["faer_sparse_esdirk_klu", "faer_sparse_tr_bdf2_klu", "faer_sparse_bdf_klu", "faer_sparse_bdf_klu_diffsl"]
+        "solvers": ["faer_sparse_esdirk", "faer_sparse_tr_bdf2", "faer_sparse_bdf", "faer_sparse_bdf_diffsl"]
+        #"solvers": ["faer_sparse_esdirk_klu", "faer_sparse_tr_bdf2_klu", "faer_sparse_bdf_klu", "faer_sparse_bdf_klu_diffsl"]
     },
 ]
 estimates = {}
@@ -116,5 +119,9 @@
 fig1.savefig(f"{basedir}/bench_tr_bdf2_esdirk.svg")
 fig2.savefig(f"{basedir}/bench_bdf.svg")
 fig3.savefig(f"{basedir}/bench_bdf_diffsl.svg")
+basedir = "."
+fig1.savefig(f"{basedir}/bench_tr_bdf2_esdirk.png")
+fig2.savefig(f"{basedir}/bench_bdf.png")
+fig3.savefig(f"{basedir}/bench_bdf_diffsl.png")
\ No newline at end of file
diff --git a/src/error.rs b/src/error.rs
index f2208d54..ac80ea30 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -71,8 +71,8 @@ pub enum OdeSolverError {
     SensitivityNotSupported,
     #[error("Failed to get mutable reference to equations, is there a solver created with this problem?")]
     FailedToGetMutableReference,
-    #[error("atol must have length 1 or equal to the number of states")]
-    AtolLengthMismatch,
+    #[error("Builder error: {0}")]
+    BuilderError(String),
     #[error("t_eval must be increasing and all values must be greater than or equal to the current time")]
     StateProblemMismatch,
     #[error("State is not consistent with the problem equations")]
diff --git a/src/jacobian/mod.rs b/src/jacobian/mod.rs
index 0142fb6a..ae79ac19 100644
--- a/src/jacobian/mod.rs
+++ b/src/jacobian/mod.rs
@@ -1,9 +1,9 @@
 use std::collections::HashSet;
 
-use crate::op::{LinearOp, Op};
-use crate::vector::Vector;
-use crate::Scalar;
-use crate::{op::NonLinearOp, Matrix, MatrixSparsityRef, VectorIndex};
+use crate::{
+    LinearOp, LinearOpTranspose, Matrix, MatrixSparsityRef, NonLinearOp, NonLinearOpAdjoint,
+    NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Op, Scalar, Vector, VectorIndex,
+};
 use num_traits::{One, Zero};
 
 use self::{coloring::nonzeros2graph, greedy_coloring::color_graph_greedy};
@@ -12,48 +12,80 @@
 pub mod coloring;
 pub mod graph;
 pub mod greedy_coloring;
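// Editor's note: the `find_*_non_zeros` functions below detect sparsity by seeding
// an input vector with `NaN` and watching where it propagates. A minimal, self-contained
// sketch of that idea (illustrative only; `jac_mul` is a hypothetical closure computing
// `J(x) * v`, not an API from this patch):
fn probe_jacobian_sparsity(
    jac_mul: impl Fn(&[f64]) -> Vec<f64>, // hypothetical closure computing J(x) * v
    nstates: usize,
) -> Vec<(usize, usize)> {
    let mut v = vec![0.0; nstates];
    let mut triplets = Vec::new();
    for j in 0..nstates {
        v[j] = f64::NAN; // seed column j; NaN propagates through most arithmetic
        for (i, y) in jac_mul(&v).iter().enumerate() {
            if y.is_nan() {
                triplets.push((i, j)); // entry (i, j) is structurally non-zero
            }
        }
        v[j] = 0.0; // reset before probing the next column
    }
    triplets
}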
-/// Find the non-zero entries of the Jacobian matrix of a non-linear operator.
-pub fn find_non_zeros_nonlinear<F: NonLinearOp + ?Sized>(
-    op: &F,
-    x: &F::V,
-    t: F::T,
-) -> Vec<(usize, usize)> {
-    let mut v = F::V::zeros(op.nstates());
-    let mut col = F::V::zeros(op.nout());
-    let mut triplets = Vec::with_capacity(op.nstates());
-    for j in 0..op.nstates() {
-        v[j] = F::T::NAN;
-        op.jac_mul_inplace(x, t, &v, &mut col);
-        for i in 0..op.nout() {
-            if col[i].is_nan() {
-                triplets.push((i, j));
+macro_rules! gen_find_non_zeros_nonlinear {
+    ($name:ident, $op_fn:ident, $op_trait:ident) => {
+        /// Find the non-zero entries of the $name matrix of a non-linear operator.
+        pub fn $name<F: $op_trait + ?Sized>(
+            op: &F,
+            x: &F::V,
+            t: F::T,
+        ) -> Vec<(usize, usize)> {
+            let mut v = F::V::zeros(op.nstates());
+            let mut col = F::V::zeros(op.nout());
+            let mut triplets = Vec::with_capacity(op.nstates());
+            for j in 0..op.nstates() {
+                v[j] = F::T::NAN;
+                op.$op_fn(x, t, &v, &mut col);
+                for i in 0..op.nout() {
+                    if col[i].is_nan() {
+                        triplets.push((i, j));
+                    }
+                    col[i] = F::T::zero();
+                }
+                v[j] = F::T::zero();
             }
-            col[i] = F::T::zero();
+            triplets
         }
-        v[j] = F::T::zero();
-    }
-    triplets
+    };
 }
-/// Find the non-zero entries of the matrix of a linear operator.
-pub fn find_non_zeros_linear<F: LinearOp + ?Sized>(op: &F, t: F::T) -> Vec<(usize, usize)> {
-    let mut v = F::V::zeros(op.nstates());
-    let mut col = F::V::zeros(op.nout());
-    let mut triplets = Vec::with_capacity(op.nstates());
-    for j in 0..op.nstates() {
-        v[j] = F::T::NAN;
-        op.call_inplace(&v, t, &mut col);
-        for i in 0..op.nout() {
-            if col[i].is_nan() {
-                triplets.push((i, j));
+gen_find_non_zeros_nonlinear!(
+    find_jacobian_non_zeros,
+    jac_mul_inplace,
+    NonLinearOpJacobian
+);
+gen_find_non_zeros_nonlinear!(
+    find_adjoint_non_zeros,
+    jac_transpose_mul_inplace,
+    NonLinearOpAdjoint
+);
+gen_find_non_zeros_nonlinear!(find_sens_non_zeros, sens_mul_inplace, NonLinearOpSens);
+gen_find_non_zeros_nonlinear!(
+    find_sens_adjoint_non_zeros,
+    sens_transpose_mul_inplace,
+    NonLinearOpSensAdjoint
+);
+
+macro_rules! gen_find_non_zeros_linear {
+    ($name:ident, $op_fn:ident $(, $op_trait:tt )?) => {
+        /// Find the non-zero entries of the $name matrix of a linear operator.
+ pub fn $name(op: &F, t: F::T) -> Vec<(usize, usize)> { + let mut v = F::V::zeros(op.nstates()); + let mut col = F::V::zeros(op.nout()); + let mut triplets = Vec::with_capacity(op.nstates()); + for j in 0..op.nstates() { + v[j] = F::T::NAN; + op.$op_fn(&v, t, &mut col); + for i in 0..op.nout() { + if col[i].is_nan() { + triplets.push((i, j)); + } + col[i] = F::T::zero(); + } + v[j] = F::T::zero(); } - col[i] = F::T::zero(); + triplets } - v[j] = F::T::zero(); - } - triplets + }; } +gen_find_non_zeros_linear!(find_matrix_non_zeros, call_inplace); +gen_find_non_zeros_linear!( + find_transpose_non_zeros, + call_transpose_inplace, + LinearOpTranspose +); + pub struct JacobianColoring { dst_indices_per_color: Vec<::Index>, src_indices_per_color: Vec<::Index>, @@ -107,7 +139,7 @@ impl JacobianColoring { // Self::new_from_non_zeros(op, non_zeros) //} - pub fn jacobian_inplace>( + pub fn jacobian_inplace>( &self, op: &F, x: &F::V, @@ -127,6 +159,46 @@ impl JacobianColoring { } } + pub fn adjoint_inplace>( + &self, + op: &F, + x: &F::V, + t: F::T, + y: &mut F::M, + ) { + let mut v = F::V::zeros(op.nstates()); + let mut col = F::V::zeros(op.nout()); + for c in 0..self.dst_indices_per_color.len() { + let input = &self.input_indices_per_color[c]; + let dst_indices = &self.dst_indices_per_color[c]; + let src_indices = &self.src_indices_per_color[c]; + v.assign_at_indices(input, F::T::one()); + op.jac_transpose_mul_inplace(x, t, &v, &mut col); + y.set_data_with_indices(dst_indices, src_indices, &col); + v.assign_at_indices(input, F::T::zero()); + } + } + + pub fn sens_adjoint_inplace>( + &self, + op: &F, + x: &F::V, + t: F::T, + y: &mut F::M, + ) { + let mut v = F::V::zeros(op.nstates()); + let mut col = F::V::zeros(op.nout()); + for c in 0..self.dst_indices_per_color.len() { + let input = &self.input_indices_per_color[c]; + let dst_indices = &self.dst_indices_per_color[c]; + let src_indices = &self.src_indices_per_color[c]; + v.assign_at_indices(input, F::T::one()); + op.sens_transpose_mul_inplace(x, t, &v, &mut col); + y.set_data_with_indices(dst_indices, src_indices, &col); + v.assign_at_indices(input, F::T::zero()); + } + } + pub fn matrix_inplace>( &self, op: &F, @@ -151,26 +223,28 @@ impl JacobianColoring { mod tests { use std::rc::Rc; - use crate::jacobian::{find_non_zeros_linear, find_non_zeros_nonlinear, JacobianColoring}; + use crate::jacobian::{find_jacobian_non_zeros, JacobianColoring}; use crate::matrix::sparsity::MatrixSparsityRef; use crate::matrix::Matrix; use crate::op::linear_closure::LinearClosure; - use crate::op::{LinearOp, Op}; use crate::vector::Vector; use crate::{ jacobian::{coloring::nonzeros2graph, greedy_coloring::color_graph_greedy}, op::closure::Closure, + LinearOp, Op, }; - use crate::{scale, NonLinearOp, SparseColMat}; + use crate::{scale, NonLinearOpJacobian, SparseColMat}; use nalgebra::DMatrix; use num_traits::{One, Zero}; use std::ops::MulAssign; + use super::find_matrix_non_zeros; + fn helper_triplets2op_nonlinear<'a, M: Matrix + 'a>( triplets: &'a [(usize, usize, M::T)], nrows: usize, ncols: usize, - ) -> impl NonLinearOp + 'a { + ) -> impl NonLinearOpJacobian + 'a { let nstates = ncols; let nout = nrows; let f = move |x: &M::V, y: &mut M::V| { @@ -241,7 +315,7 @@ mod tests { ]; for triplets in test_triplets { let op = helper_triplets2op_nonlinear::(triplets.as_slice(), 2, 2); - let non_zeros = find_non_zeros_nonlinear(&op, &M::V::zeros(2), M::T::zero()); + let non_zeros = find_jacobian_non_zeros(&op, &M::V::zeros(2), M::T::zero()); let expect = triplets .iter() 
                .map(|(i, j, _v)| (*i, *j))
@@ -279,7 +353,7 @@
         let expect = vec![vec![1, 1], vec![1, 2], vec![1, 1], vec![1, 2]];
         for (triplets, expect) in test_triplets.iter().zip(expect) {
             let op = helper_triplets2op_nonlinear::<M>(triplets.as_slice(), 2, 2);
-            let non_zeros = find_non_zeros_nonlinear(&op, &M::V::zeros(2), M::T::zero());
+            let non_zeros = find_jacobian_non_zeros(&op, &M::V::zeros(2), M::T::zero());
             let ncols = op.nstates();
             let graph = nonzeros2graph(non_zeros.as_slice(), ncols);
             let coloring = color_graph_greedy(&graph);
@@ -320,7 +394,7 @@
             let op = helper_triplets2op_nonlinear::<M>(triplets.as_slice(), n, n);
             let y0 = M::V::zeros(n);
             let t0 = M::T::zero();
-            let non_zeros = find_non_zeros_nonlinear(&op, &y0, t0);
+            let non_zeros = find_jacobian_non_zeros(&op, &y0, t0);
             let coloring = JacobianColoring::new_from_non_zeros(&op, non_zeros);
             let mut jac = M::new_from_sparsity(3, 3, op.sparsity().map(|s| s.to_owned()));
             coloring.jacobian_inplace(&op, &y0, t0, &mut jac);
@@ -336,7 +410,7 @@
         for triplets in test_triplets {
             let op = helper_triplets2op_linear::<M>(triplets.as_slice(), n, n);
             let t0 = M::T::zero();
-            let non_zeros = find_non_zeros_linear(&op, t0);
+            let non_zeros = find_matrix_non_zeros(&op, t0);
             let coloring = JacobianColoring::new_from_non_zeros(&op, non_zeros);
             let mut jac = M::new_from_sparsity(3, 3, op.sparsity().map(|s| s.to_owned()));
             coloring.matrix_inplace(&op, t0, &mut jac);
diff --git a/src/lib.rs b/src/lib.rs
index d0b1b9c4..cf844e90 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -16,6 +16,8 @@
 //! The solver state is held in [OdeSolverState], and contains a state vector, the gradient of the state vector, the time, and the step size. You can initialise a new state using [OdeSolverState::new],
 //! or create an uninitialised state using [OdeSolverState::new_without_initialise] and initialise it manually or using the [OdeSolverState::set_consistent] and [OdeSolverState::set_step_size] methods.
 //!
+//! To view the state within a solver, you can use the [OdeSolverMethod::state] or [OdeSolverMethod::state_mut] methods. These will return references to the state using either the [StateRef] or [StateRefMut] structs.
+//!
 //! ## The solver
 //!
 //! To solve the problem given the initial state, you need to choose a solver. DiffSol provides the following solvers:
@@ -46,12 +48,13 @@
 //!
 //! ## Sparsity pattern for Jacobians and Mass matrices
 //!
-//! Via an implementation of [OdeEquations], the user provides the action of the jacobian on a vector `J(x) v`. By default DiffSol uses this to generate a jacobian matrix for the ODE solver.
+//! Via an implementation of [OdeEquationsImplicit], the user provides the action of the jacobian on a vector `J(x) v`. By default DiffSol uses this to generate a jacobian matrix for the ODE solver.
 //! For sparse jacobians, DiffSol will attempt to detect the sparsity pattern of the jacobian using this function and use a sparse matrix representation internally.
 //! It attempts to determine the sparsity pattern of the jacobian (i.e. its non-zero values) by passing in `NaNs` for the input vector `x` and checking which elements
 //! of the output vector `J(x) v` are also `NaN`, using the fact that `NaN`s propagate through most operations. However, this method is not foolproof and will fail if,
 //! for example, your jacobian function uses any control flow that depends on the input vector. If this is the case, you can provide the jacobian matrix directly by
implementing the optional [NonLinearOp::jacobian_inplace] and the [LinearOp::matrix_inplace] (if applicable) functions, or by providing a sparsity pattern using the [Op::sparsity] function. +//! implementing the optional [NonLinearOpJacobian::jacobian_inplace] and the [LinearOp::matrix_inplace] (if applicable) functions, +//! or by providing a sparsity pattern using the [Op::sparsity] function. //! //! ## Events / Root finding //! @@ -60,15 +63,47 @@ //! //! ## Forward Sensitivity Analysis //! -//! DiffSol provides a way to compute the forward sensitivity of the solution with respect to the parameters. You can use this by using the [OdeBuilder::build_ode_with_sens] or [OdeBuilder::build_ode_with_mass_and_sens] builder functions. -//! Note that by default the sensitivity equations are not included in the error control for the solvers, you can change this by using the [OdeBuilder::sensitivities_error_control] method. +//! DiffSol provides a way to compute the forward sensitivity of the solution with respect to the parameters. To use this your equations struct must implement the [OdeEquationsSens] trait. +//! Note that by default the sensitivity equations are included in the error control for the solvers, you can change this by setting tolerances using the [OdeBuilder::sens_atol] and [[OdeBuilder::sens_rtol]] methods. +//! You will also need to use [SensitivitiesOdeSolverMethod::set_problem_with_sensitivities] to set the problem with sensitivities. //! //! To obtain the sensitivity solution via interpolation, you can use the [OdeSolverMethod::interpolate_sens] method. Otherwise the sensitivity vectors are stored in the [OdeSolverState] struct. //! +//! ## Checkpointing +//! +//! You can checkpoint the solver at a set of times using the [OdeSolverMethod::checkpoint] method. This will store the state of the solver at the given times, and subsequently use the [OdeSolverMethod::set_problem] +//! method to restore the solver to the state at the given time. +//! +//! ## Interpolation +//! +//! The [HermiteInterpolator] struct provides a way to interpolate a solution between a sequence of steps. If the number of steps in your solution is too large to fit in memory, +//! you can instead use checkpointing to store the solution at a reduced set of times and dynamically interpolate between these checkpoints using the [Checkpointing] struct +//! (at the cost of recomputing the solution between the checkpoints). +//! +//! ## Quadrature and Output functions +//! +//! The [OdeSolverEquations::Out] associated type can be used to define an output function. DiffSol will optionally integrate this function over the solution trajectory by +//! using the [OdeBuilder::integrate_out] method. By default, the output integration is added to the error control of the solver, and the tolerances can be +//! adjusted using the [OdeBuilder::out_atol] and [OdeBuilder::out_rtol] methods. It can be removed from the error control by setting the tolerances to `None`. +//! +//! ## Adjoint Sensitivity Analysis +//! +//! If you require the partial gradient of the output function with respect to the parameters and your parameter vector is sufficiently large, then it is more efficient +//! to use the adjoint sensitivity method. This method uses a lagrange multiplier to derive a set of adjoint ode equations that are solved backwards in time, +//! and then used to compute the sensitivities of the output function. Checkpointing is typically used to store the forward solution at a set of times as theses are required +//! 
+//! ## Adjoint Sensitivity Analysis
+//!
+//! If you require the partial gradient of the output function with respect to the parameters and your parameter vector is sufficiently large, then it is more efficient
+//! to use the adjoint sensitivity method. This method uses a Lagrange multiplier to derive a set of adjoint ODE equations that are solved backwards in time,
+//! and then used to compute the sensitivities of the output function. Checkpointing is typically used to store the forward solution at a set of times, as these are required
+//! to solve the adjoint equations.
+//!
+//! To use the adjoint sensitivity method, your equations struct must implement the [OdeEquationsAdjoint] trait. When you compute the forward solution, use checkpointing
+//! to store the solution at a set of times. From this you should obtain a `Vec` of checkpointed states (this can be as few as the start and end of the solution), and
+//! a [HermiteInterpolator] that can be used to interpolate the solution between the last two checkpoints. You can then use the [AdjointOdeSolverMethod::into_adjoint_solver]
+//! method to create an adjoint solver from the forward solver, and then use this solver to step the adjoint equations backwards in time. Once the adjoint equations have been solved,
+//! the sensitivities of the output function will be stored in the [StateRef::sg] field of the adjoint solver state. If your parameters are used to calculate the initial conditions
+//! of the forward problem, then you will need to use the [AdjointEquations::correct_sg_for_init] method to correct the sensitivities for the initial conditions.
+//!
 //! ## Nonlinear and linear solvers
 //!
 //! DiffSol provides generic nonlinear and linear solvers that are used internally by the ODE solver. You can use the solvers provided by DiffSol, or implement your own following the provided traits.
-//! The linear solver trait is [LinearSolver], and the nonlinear solver trait is [NonLinearSolver]. The [SolverProblem] struct is used to define the problem to solve.
+//! The linear solver trait is [LinearSolver], and the nonlinear solver trait is [NonLinearSolver].
 //!
 //! The provided linear solvers are:
 //! - [NalgebraLU]: a direct solver that uses the LU decomposition implemented in the [nalgebra](https://nalgebra.org) library.
@@ -117,7 +152,7 @@ pub mod vector;
 
 #[cfg(feature = "sundials")]
 pub mod sundials_sys;
 
-use linear_solver::LinearSolver;
+pub use linear_solver::LinearSolver;
 pub use linear_solver::{faer::sparse_lu::FaerSparseLU, FaerLU, NalgebraLU};
 
 pub use matrix::sparse_faer::SparseColMat;
@@ -140,36 +175,48 @@
 pub use linear_solver::suitesparse::klu::KLU;
 
 #[cfg(feature = "diffsl")]
 pub use ode_solver::diffsl::DiffSlContext;
 
-pub use jacobian::{find_non_zeros_linear, find_non_zeros_nonlinear, JacobianColoring};
+pub use jacobian::{
+    find_adjoint_non_zeros, find_jacobian_non_zeros, find_matrix_non_zeros,
+    find_sens_adjoint_non_zeros, find_sens_non_zeros, find_transpose_non_zeros, JacobianColoring,
+};
 pub use matrix::{default_solver::DefaultSolver, Matrix};
 use matrix::{
     sparsity::Dense, sparsity::DenseRef, sparsity::MatrixSparsity, sparsity::MatrixSparsityRef,
     DenseMatrix, MatrixCommon, MatrixRef, MatrixView, MatrixViewMut,
 };
-pub use nonlinear_solver::newton::NewtonNonlinearSolver;
 use nonlinear_solver::{
-    convergence::Convergence, convergence::ConvergenceStatus, newton::newton_iteration,
-    root::RootFinder, NonLinearSolver,
+    convergence::Convergence, convergence::ConvergenceStatus, root::RootFinder,
 };
+pub use nonlinear_solver::{newton::NewtonNonlinearSolver, NonLinearSolver};
 use ode_solver::jacobian_update::JacobianUpdate;
+pub use ode_solver::state::{StateRef, StateRefMut};
 pub use ode_solver::{
-    bdf::Bdf, bdf_state::BdfState, builder::OdeBuilder, equations::OdeEquations,
-    equations::OdeSolverEquations, method::OdeSolverMethod, method::OdeSolverStopReason,
-    problem::OdeSolverProblem, sdirk::Sdirk, sdirk_state::SdirkState,
-    sens_equations::SensEquations, sens_equations::SensInit, sens_equations::SensRhs,
-    state::OdeSolverState, tableau::Tableau,
+    adjoint_equations::AdjointContext,
adjoint_equations::AdjointEquations, + adjoint_equations::AdjointInit, adjoint_equations::AdjointRhs, bdf::Bdf, bdf::BdfAdj, + bdf_state::BdfState, builder::OdeBuilder, checkpointing::Checkpointing, + checkpointing::HermiteInterpolator, equations::AugmentedOdeEquations, + equations::AugmentedOdeEquationsImplicit, equations::NoAug, equations::OdeEquations, + equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit, equations::OdeEquationsSens, + equations::OdeSolverEquations, method::AdjointOdeSolverMethod, method::OdeSolverMethod, + method::OdeSolverStopReason, method::SensitivitiesOdeSolverMethod, problem::OdeSolverProblem, + sdirk::Sdirk, sdirk::SdirkAdj, sdirk_state::SdirkState, sens_equations::SensEquations, + sens_equations::SensInit, sens_equations::SensRhs, state::OdeSolverState, tableau::Tableau, +}; +use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint}; +use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose}; +pub use op::nonlinear_op::{ + NonLinearOp, NonLinearOpAdjoint, NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, }; pub use op::{ - closure::Closure, constant_closure::ConstantClosure, linear_closure::LinearClosure, - unit::UnitCallable, ConstantOp, LinearOp, NonLinearOp, Op, + closure::Closure, closure_with_adjoint::ClosureWithAdjoint, constant_closure::ConstantClosure, + constant_closure_with_adjoint::ConstantClosureWithAdjoint, linear_closure::LinearClosure, + unit::UnitCallable, Op, }; use op::{ closure_no_jac::ClosureNoJac, closure_with_sens::ClosureWithSens, constant_closure_with_sens::ConstantClosureWithSens, init::InitOp, - linear_closure_with_sens::LinearClosureWithSens, }; use scalar::{IndexType, Scalar, Scale}; -use solver::SolverProblem; pub use vector::DefaultDenseMatrix; use vector::{Vector, VectorCommon, VectorIndex, VectorRef, VectorView, VectorViewMut}; diff --git a/src/linear_solver/faer/lu.rs b/src/linear_solver/faer/lu.rs index fe26a412..500eb328 100644 --- a/src/linear_solver/faer/lu.rs +++ b/src/linear_solver/faer/lu.rs @@ -1,48 +1,47 @@ -use crate::{error::LinearSolverError, linear_solver_error}; use std::rc::Rc; +use crate::{error::LinearSolverError, linear_solver_error}; + use crate::{ - error::DiffsolError, linear_solver::LinearSolver, op::linearise::LinearisedOp, - solver::SolverProblem, LinearOp, Matrix, MatrixSparsityRef, NonLinearOp, Op, Scalar, + error::DiffsolError, linear_solver::LinearSolver, Matrix, MatrixSparsityRef, + NonLinearOpJacobian, Scalar, }; use faer::{linalg::solvers::FullPivLu, solvers::SpSolver, Col, Mat}; /// A [LinearSolver] that uses the LU decomposition in the [`faer`](https://github.com/sarah-ek/faer-rs) library to solve the linear system. 
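// Editor's note: an illustrative, comment-only sketch of the adjoint workflow described
// in the module docs above. The names `into_adjoint_solver`, `state()` and the `sg` field
// come from those docs; the checkpointing helper and all argument lists are hypothetical
// placeholders, not the crate's exact API:
//
//     // forward solve, storing checkpoints and a HermiteInterpolator for the final segment
//     let (checkpoints, interpolator) = forward_solve_with_checkpoints(&mut solver, times); // hypothetical helper
//     // convert the forward solver into an adjoint solver and integrate backwards in time
//     let mut adjoint_solver = solver.into_adjoint_solver(checkpoints, interpolator);
//     while adjoint_solver.state().t > t0 {
//         adjoint_solver.step()?;
//     }
//     // parameter sensitivities of the integrated output accumulate in state().sg
//     let sensitivities = adjoint_solver.state().sg;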
-pub struct LU +pub struct LU where T: Scalar, - C: NonLinearOp, V = Col, T = T>, { lu: Option>, - problem: Option>>, matrix: Option>, } -impl Default for LU +impl Default for LU where T: Scalar, - C: NonLinearOp, V = Col, T = T>, { fn default() -> Self { Self { lu: None, - problem: None, matrix: None, } } } -impl, V = Col, T = T>> LinearSolver for LU { - fn set_linearisation(&mut self, x: &C::V, t: C::T) { - Rc::>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f) - .unwrap() - .set_x(x); +impl LinearSolver> for LU { + fn set_linearisation, M = Mat>>( + &mut self, + op: &C, + x: &Col, + t: T, + ) { let matrix = self.matrix.as_mut().expect("Matrix not set"); - self.problem.as_ref().unwrap().f.matrix_inplace(t, matrix); + op.jacobian_inplace(x, t, matrix); self.lu = Some(matrix.full_piv_lu()); } - fn solve_in_place(&self, x: &mut C::V) -> Result<(), DiffsolError> { + fn solve_in_place(&self, x: &mut Col) -> Result<(), DiffsolError> { if self.lu.is_none() { return Err(linear_solver_error!(LuNotInitialized))?; } @@ -51,16 +50,15 @@ impl, V = Col, T = T>> LinearSolver f Ok(()) } - fn set_problem(&mut self, problem: &SolverProblem) { - let linearised_problem = problem.linearise(); - let ncols = linearised_problem.f.nstates(); - let nrows = linearised_problem.f.nout(); - let matrix = C::M::new_from_sparsity( - nrows, - ncols, - linearised_problem.f.sparsity().map(|s| s.to_owned()), - ); - self.problem = Some(linearised_problem); + fn set_problem, M = Mat>>( + &mut self, + op: &C, + _rtol: T, + _atol: Rc>, + ) { + let ncols = op.nstates(); + let nrows = op.nout(); + let matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned())); self.matrix = Some(matrix); } } diff --git a/src/linear_solver/faer/sparse_lu.rs b/src/linear_solver/faer/sparse_lu.rs index dcb922ad..719d216d 100644 --- a/src/linear_solver/faer/sparse_lu.rs +++ b/src/linear_solver/faer/sparse_lu.rs @@ -5,10 +5,8 @@ use crate::{ linear_solver::LinearSolver, linear_solver_error, matrix::sparsity::MatrixSparsityRef, - op::linearise::LinearisedOp, scalar::IndexType, - solver::SolverProblem, - LinearOp, Matrix, NonLinearOp, Op, Scalar, SparseColMat, + Matrix, NonLinearOpJacobian, Scalar, SparseColMat, }; use faer::{ @@ -18,41 +16,37 @@ use faer::{ }; /// A [LinearSolver] that uses the LU decomposition in the [`faer`](https://github.com/sarah-ek/faer-rs) library to solve the linear system. 
-pub struct FaerSparseLU +pub struct FaerSparseLU where T: Scalar, - C: NonLinearOp, V = Col, T = T>, { lu: Option>, lu_symbolic: Option>, - problem: Option>>, matrix: Option>, } -impl Default for FaerSparseLU +impl Default for FaerSparseLU where T: Scalar, - C: NonLinearOp, V = Col, T = T>, { fn default() -> Self { Self { lu: None, - problem: None, matrix: None, lu_symbolic: None, } } } -impl, V = Col, T = T>> LinearSolver - for FaerSparseLU -{ - fn set_linearisation(&mut self, x: &C::V, t: C::T) { - Rc::>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f) - .unwrap() - .set_x(x); +impl LinearSolver> for FaerSparseLU { + fn set_linearisation, M = SparseColMat>>( + &mut self, + op: &C, + x: &Col, + t: T, + ) { let matrix = self.matrix.as_mut().expect("Matrix not set"); - self.problem.as_ref().unwrap().f.matrix_inplace(t, matrix); + op.jacobian_inplace(x, t, matrix); self.lu = Some( Lu::try_new_with_symbolic( self.lu_symbolic.as_ref().unwrap().clone(), @@ -62,7 +56,7 @@ impl, V = Col, T = T>> LinearSo ) } - fn solve_in_place(&self, x: &mut C::V) -> Result<(), DiffsolError> { + fn solve_in_place(&self, x: &mut Col) -> Result<(), DiffsolError> { if self.lu.is_none() { return Err(linear_solver_error!(LuNotInitialized))?; } @@ -71,19 +65,20 @@ impl, V = Col, T = T>> LinearSo Ok(()) } - fn set_problem(&mut self, problem: &SolverProblem) { - let linearised_problem = problem.linearise(); - let ncols = linearised_problem.f.nstates(); - let nrows = linearised_problem.f.nout(); + fn set_problem, M = SparseColMat>>( + &mut self, + op: &C, + _rtol: T, + _atol: Rc>, + ) { + let ncols = op.nstates(); + let nrows = op.nout(); let matrix = C::M::new_from_sparsity( nrows, ncols, - linearised_problem - .f - .sparsity() + op.sparsity() .map(|s| MatrixSparsityRef::>::to_owned(&s)), ); - self.problem = Some(linearised_problem); self.matrix = Some(matrix); self.lu_symbolic = Some( SymbolicLu::try_new(self.matrix.as_ref().unwrap().faer().symbolic()) diff --git a/src/linear_solver/mod.rs b/src/linear_solver/mod.rs index 01245084..b583a0db 100644 --- a/src/linear_solver/mod.rs +++ b/src/linear_solver/mod.rs @@ -1,4 +1,6 @@ -use crate::{error::DiffsolError, op::Op, solver::SolverProblem}; +use std::rc::Rc; + +use crate::{error::DiffsolError, Matrix, NonLinearOpJacobian}; #[cfg(feature = "nalgebra")] pub mod nalgebra; @@ -16,23 +18,35 @@ pub use faer::lu::LU as FaerLU; pub use nalgebra::lu::LU as NalgebraLU; /// A solver for the linear problem `Ax = b`, where `A` is a linear operator that is obtained by taking the linearisation of a nonlinear operator `C` -pub trait LinearSolver { +pub trait LinearSolver: Default { + // sets the point at which the linearisation of the operator is evaluated + // the operator is assumed to have the same sparsity as that given to [Self::set_problem] + fn set_linearisation>( + &mut self, + op: &C, + x: &M::V, + t: M::T, + ); + /// Set the problem to be solved, any previous problem is discarded. /// Any internal state of the solver is reset. - fn set_problem(&mut self, problem: &SolverProblem); - - // sets the point at which the linearisation of the operator is evaluated - fn set_linearisation(&mut self, x: &C::V, t: C::T); + /// This function will normally set the sparsity pattern of the matrix to be solved. + fn set_problem>( + &mut self, + op: &C, + rtol: M::T, + atol: Rc, + ); /// Solve the problem `Ax = b` and return the solution `x`. 
/// panics if [Self::set_linearisation] has not been called previously - fn solve(&self, b: &C::V) -> Result { + fn solve(&self, b: &M::V) -> Result { let mut b = b.clone(); self.solve_in_place(&mut b)?; Ok(b) } - fn solve_in_place(&self, b: &mut C::V) -> Result<(), DiffsolError>; + fn solve_in_place(&self, b: &mut M::V) -> Result<(), DiffsolError>; } pub struct LinearSolveSolution { @@ -52,17 +66,20 @@ pub mod tests { use crate::{ linear_solver::{FaerLU, NalgebraLU}, - op::{closure::Closure, NonLinearOp}, + op::closure::Closure, scalar::scale, vector::VectorRef, - LinearSolver, Matrix, SolverProblem, Vector, + LinearSolver, Matrix, NonLinearOpJacobian, Vector, }; use num_traits::{One, Zero}; use super::LinearSolveSolution; + #[allow(clippy::type_complexity)] pub fn linear_problem() -> ( - SolverProblem>, + impl NonLinearOpJacobian, + M::T, + Rc, Vec>, ) { let diagonal = M::V::from_vec(vec![2.0.into(), 2.0.into()]); @@ -78,32 +95,32 @@ pub mod tests { p, ); op.calculate_sparsity(&M::V::from_element(2, M::T::one()), M::T::zero()); - let op = Rc::new(op); let rtol = M::T::from(1e-6); let atol = Rc::new(M::V::from_vec(vec![1e-6.into(), 1e-6.into()])); - let problem = SolverProblem::new(op, atol, rtol); let solns = vec![LinearSolveSolution::new( M::V::from_vec(vec![2.0.into(), 4.0.into()]), M::V::from_vec(vec![1.0.into(), 2.0.into()]), )]; - (problem, solns) + (op, rtol, atol, solns) } pub fn test_linear_solver( - mut solver: impl LinearSolver, - problem: SolverProblem, + mut solver: impl LinearSolver, + op: C, + rtol: C::T, + atol: Rc, solns: Vec>, ) where - C: NonLinearOp, + C: NonLinearOpJacobian, for<'a> &'a C::V: VectorRef, { - solver.set_problem(&problem); - let x = C::V::zeros(problem.f.nout()); + solver.set_problem(&op, rtol, atol.clone()); + let x = C::V::zeros(op.nout()); let t = C::T::zero(); - solver.set_linearisation(&x, t); + solver.set_linearisation(&op, &x, t); for soln in solns { let x = solver.solve(&soln.b).unwrap(); - let tol = { &soln.x * scale(problem.rtol) + problem.atol.as_ref() }; + let tol = { &soln.x * scale(rtol) + atol.as_ref() }; x.assert_eq(&soln.x, &tol); } } @@ -113,14 +130,14 @@ pub mod tests { #[test] fn test_lu_nalgebra() { - let (p, solns) = linear_problem::(); + let (op, rtol, atol, solns) = linear_problem::(); let s = NalgebraLU::default(); - test_linear_solver(s, p, solns); + test_linear_solver(s, op, rtol, atol, solns); } #[test] fn test_lu_faer() { - let (p, solns) = linear_problem::(); + let (op, rtol, atol, solns) = linear_problem::(); let s = FaerLU::default(); - test_linear_solver(s, p, solns); + test_linear_solver(s, op, rtol, atol, solns); } } diff --git a/src/linear_solver/nalgebra/lu.rs b/src/linear_solver/nalgebra/lu.rs index ce9c74f3..48c0b67f 100644 --- a/src/linear_solver/nalgebra/lu.rs +++ b/src/linear_solver/nalgebra/lu.rs @@ -1,43 +1,38 @@ -use nalgebra::{DMatrix, DVector, Dyn}; use std::rc::Rc; +use nalgebra::{DMatrix, DVector, Dyn}; + use crate::{ error::{DiffsolError, LinearSolverError}, linear_solver_error, matrix::sparsity::MatrixSparsityRef, - op::{linearise::LinearisedOp, NonLinearOp}, - LinearOp, LinearSolver, Matrix, Op, Scalar, SolverProblem, + LinearSolver, Matrix, NonLinearOpJacobian, Scalar, }; /// A [LinearSolver] that uses the LU decomposition in the [`nalgebra` library](https://nalgebra.org/) to solve the linear system. 
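// Editor's note: the call order for the refactored `LinearSolver` trait, mirroring
// `test_linear_solver` above (comment sketch; `op`, `rtol`, `atol` and `b` are as
// produced by the `linear_problem` test helper):
//
//     let mut solver = NalgebraLU::default();
//     solver.set_problem(&op, rtol, atol.clone()); // allocates the matrix using op's sparsity
//     solver.set_linearisation(&op, &x, t);        // computes J(x) and factorises it
//     let sol = solver.solve(&b)?;                 // solves J * sol = b, reusing the factorisation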
-pub struct LU +#[derive(Clone)] +pub struct LU where T: Scalar, - C: NonLinearOp, V = DVector, T = T>, { matrix: Option>, lu: Option>, - problem: Option>>, } -impl Default for LU +impl Default for LU where T: Scalar, - C: NonLinearOp, V = DVector, T = T>, { fn default() -> Self { Self { lu: None, - problem: None, matrix: None, } } } -impl, V = DVector, T = T>> LinearSolver - for LU -{ - fn solve_in_place(&self, state: &mut C::V) -> Result<(), DiffsolError> { +impl LinearSolver> for LU { + fn solve_in_place(&self, state: &mut DVector) -> Result<(), DiffsolError> { if self.lu.is_none() { return Err(linear_solver_error!(LuNotInitialized))?; } @@ -48,25 +43,26 @@ impl, V = DVector, T = T>> LinearSol } } - fn set_linearisation(&mut self, x: &::V, t: ::T) { - Rc::>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f) - .unwrap() - .set_x(x); + fn set_linearisation, M = DMatrix>>( + &mut self, + op: &C, + x: &DVector, + t: T, + ) { let matrix = self.matrix.as_mut().expect("Matrix not set"); - self.problem.as_ref().unwrap().f.matrix_inplace(t, matrix); + op.jacobian_inplace(x, t, matrix); self.lu = Some(matrix.clone().lu()); } - fn set_problem(&mut self, problem: &SolverProblem) { - let linearised_problem = problem.linearise(); - let ncols = linearised_problem.f.nstates(); - let nrows = linearised_problem.f.nout(); - let matrix = C::M::new_from_sparsity( - nrows, - ncols, - linearised_problem.f.sparsity().map(|s| s.to_owned()), - ); - self.problem = Some(linearised_problem); + fn set_problem, M = DMatrix>>( + &mut self, + op: &C, + _rtol: T, + _atol: Rc>, + ) { + let ncols = op.nstates(); + let nrows = op.nout(); + let matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned())); self.matrix = Some(matrix); } } diff --git a/src/linear_solver/suitesparse/klu.rs b/src/linear_solver/suitesparse/klu.rs index 0b5dc260..05939424 100644 --- a/src/linear_solver/suitesparse/klu.rs +++ b/src/linear_solver/suitesparse/klu.rs @@ -1,4 +1,5 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; +use std::rc::Rc; use faer::Col; @@ -27,9 +28,8 @@ use crate::{ linear_solver::LinearSolver, linear_solver_error, matrix::MatrixCommon, - op::linearise::LinearisedOp, vector::Vector, - LinearOp, Matrix, MatrixSparsityRef, NonLinearOp, Op, SolverProblem, SparseColMat, + Matrix, MatrixSparsityRef, NonLinearOpJacobian, SparseColMat, }; trait MatrixKLU: Matrix { @@ -109,6 +109,7 @@ impl KluNumeric { symbolic: &mut KluSymbolic, mat: &mut impl MatrixKLU, ) -> Result { + // TODO: there is also klu_refactor which is faster and reuses inner let inner = unsafe { klu_factor( mat.column_pointers_mut_ptr(), @@ -136,6 +137,7 @@ impl Drop for KluNumeric { } } +#[derive(Clone)] struct KluCommon { inner: klu_common, } @@ -154,22 +156,19 @@ impl KluCommon { } } -pub struct KLU +pub struct KLU where M: Matrix, - C: NonLinearOp, { klu_common: RefCell, klu_symbolic: Option, klu_numeric: Option, - problem: Option>>, matrix: Option, } -impl Default for KLU +impl Default for KLU where M: Matrix, - C: NonLinearOp, { fn default() -> Self { let klu_common = KluCommon::default(); @@ -178,24 +177,24 @@ where klu_common, klu_numeric: None, klu_symbolic: None, - problem: None, matrix: None, } } } -impl LinearSolver for KLU +impl LinearSolver for KLU where M: MatrixKLU, M::V: VectorKLU, - C: NonLinearOp, { - fn set_linearisation(&mut self, x: &C::V, t: C::T) { - Rc::>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f) - .unwrap() - .set_x(x); + fn set_linearisation>( + &mut self, + op: &C, + x: 
&M::V, + t: M::T, + ) { let matrix = self.matrix.as_mut().expect("Matrix not set"); - self.problem.as_ref().unwrap().f.matrix_inplace(t, matrix); + op.jacobian_inplace(x, t, matrix); self.klu_numeric = KluNumeric::try_from_symbolic( self.klu_symbolic.as_mut().expect("Symbolic not set"), matrix, @@ -203,7 +202,7 @@ where .ok(); } - fn solve_in_place(&self, x: &mut C::V) -> Result<(), DiffsolError> { + fn solve_in_place(&self, x: &mut M::V) -> Result<(), DiffsolError> { if self.klu_numeric.is_none() { return Err(linear_solver_error!(LuNotInitialized)); } @@ -224,16 +223,15 @@ where Ok(()) } - fn set_problem(&mut self, problem: &SolverProblem) { - let linearised_problem = problem.linearise(); - let ncols = linearised_problem.f.nstates(); - let nrows = linearised_problem.f.nout(); - let mut matrix = C::M::new_from_sparsity( - nrows, - ncols, - linearised_problem.f.sparsity().map(|s| s.to_owned()), - ); - self.problem = Some(linearised_problem); + fn set_problem>( + &mut self, + op: &C, + _rtol: M::T, + _atol: Rc, + ) { + let ncols = op.nstates(); + let nrows = op.nout(); + let mut matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned())); let mut klu_common = self.klu_common.borrow_mut(); self.klu_symbolic = KluSymbolic::try_from_matrix(&mut matrix, klu_common.as_mut()).ok(); self.matrix = Some(matrix); @@ -251,8 +249,8 @@ mod tests { #[test] fn test_klu() { - let (p, solns) = linear_problem::>(); + let (op, rtol, atol, solns) = linear_problem::>(); let s = KLU::default(); - test_linear_solver(s, p, solns); + test_linear_solver(s, op, rtol, atol, solns); } } diff --git a/src/linear_solver/sundials.rs b/src/linear_solver/sundials.rs index 3d1b1276..0e1179aa 100644 --- a/src/linear_solver/sundials.rs +++ b/src/linear_solver/sundials.rs @@ -6,8 +6,7 @@ use crate::sundials_sys::{ use crate::{ error::*, linear_solver_error, ode_solver::sundials::sundials_check, - op::linearise::LinearisedOp, vector::sundials::SundialsVector, LinearOp, Matrix, NonLinearOp, - Op, SolverProblem, SundialsMatrix, + vector::sundials::SundialsVector, Matrix, NonLinearOpJacobian, SundialsMatrix, }; #[cfg(not(sundials_version_major = "5"))] @@ -15,43 +14,29 @@ use crate::vector::sundials::get_suncontext; use super::LinearSolver; -pub struct SundialsLinearSolver -where - Op: NonLinearOp, -{ +pub struct SundialsLinearSolver { linear_solver: Option, - problem: Option>>, is_setup: bool, matrix: Option, } -impl Default for SundialsLinearSolver -where - Op: NonLinearOp, -{ +impl Default for SundialsLinearSolver { fn default() -> Self { Self::new_dense() } } -impl SundialsLinearSolver -where - Op: NonLinearOp, -{ +impl SundialsLinearSolver { pub fn new_dense() -> Self { Self { linear_solver: None, - problem: None, is_setup: false, matrix: None, } } } -impl Drop for SundialsLinearSolver -where - Op: NonLinearOp, -{ +impl Drop for SundialsLinearSolver { fn drop(&mut self) { if let Some(linear_solver) = self.linear_solver { unsafe { SUNLinSolFree(linear_solver) }; @@ -59,17 +44,15 @@ where } } -impl LinearSolver for SundialsLinearSolver -where - Op: NonLinearOp, -{ - fn set_problem(&mut self, problem: &SolverProblem) { - let linearised_problem = problem.linearise(); - let matrix = SundialsMatrix::zeros( - linearised_problem.f.nstates(), - linearised_problem.f.nstates(), - ); - let y0 = SundialsVector::new_serial(linearised_problem.f.nstates()); +impl LinearSolver for SundialsLinearSolver { + fn set_problem>( + &mut self, + op: &C, + _rtol: realtype, + _atol: Rc, + ) { + let matrix = 
SundialsMatrix::zeros(op.nstates(), op.nstates()); + let y0 = SundialsVector::new_serial(op.nstates()); #[cfg(not(sundials_version_major = "5"))] let linear_solver = { @@ -82,22 +65,25 @@ where unsafe { SUNLinSol_Dense(y0.sundials_vector(), matrix.sundials_matrix()) }; self.matrix = Some(matrix); - self.problem = Some(linearised_problem); self.linear_solver = Some(linear_solver); } - fn set_linearisation(&mut self, x: &Op::V, t: Op::T) { - Rc::>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f) - .unwrap() - .set_x(x); + fn set_linearisation< + C: NonLinearOpJacobian, + >( + &mut self, + op: &C, + x: &SundialsVector, + t: realtype, + ) { let matrix = self.matrix.as_mut().expect("Matrix not set"); let linear_solver = self.linear_solver.expect("Linear solver not set"); - self.problem.as_ref().unwrap().f.matrix_inplace(t, matrix); + op.jacobian_inplace(x, t, matrix); sundials_check(unsafe { SUNLinSolSetup(linear_solver, matrix.sundials_matrix()) }).unwrap(); self.is_setup = true; } - fn solve_in_place(&self, b: &mut Op::V) -> Result<(), DiffsolError> { + fn solve_in_place(&self, b: &mut SundialsVector) -> Result<(), DiffsolError> { if !self.is_setup { return Err(linear_solver_error!(LinearSolverNotSetup)); } diff --git a/src/matrix/default_solver.rs b/src/matrix/default_solver.rs index 48fea375..cfccceda 100644 --- a/src/matrix/default_solver.rs +++ b/src/matrix/default_solver.rs @@ -1,10 +1,10 @@ -use crate::{LinearSolver, NonLinearOp}; +use crate::LinearSolver; use super::Matrix; pub trait DefaultSolver: Matrix { - type LS>: LinearSolver + Default; - fn default_solver>() -> Self::LS { + type LS: LinearSolver + Default; + fn default_solver() -> Self::LS { Self::LS::default() } } diff --git a/src/matrix/dense_faer_serial.rs b/src/matrix/dense_faer_serial.rs index 2dcfed36..822bf0d8 100644 --- a/src/matrix/dense_faer_serial.rs +++ b/src/matrix/dense_faer_serial.rs @@ -3,7 +3,6 @@ use std::ops::{AddAssign, Mul, MulAssign}; use super::default_solver::DefaultSolver; use super::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut}; use crate::error::DiffsolError; -use crate::op::NonLinearOp; use crate::scalar::{IndexType, Scalar, Scale}; use crate::FaerLU; use crate::{Dense, DenseRef, Vector}; @@ -15,7 +14,7 @@ use faer::{ use faer::{unzipped, zipped}; impl DefaultSolver for Mat { - type LS, V = Col, T = T>> = FaerLU; + type LS = FaerLU; } macro_rules! impl_matrix_common { @@ -54,7 +53,7 @@ impl_mul_scale!(MatRef<'a, T>); impl_mul_scale!(Mat); impl_mul_scale!(&Mat); -impl<'a, T: Scalar> MulAssign> for MatMut<'a, T> { +impl MulAssign> for MatMut<'_, T> { fn mul_assign(&mut self, rhs: Scale) { let scale: faer::Scale = rhs.into(); *self *= scale; diff --git a/src/matrix/dense_nalgebra_serial.rs b/src/matrix/dense_nalgebra_serial.rs index d2940904..d834fa94 100644 --- a/src/matrix/dense_nalgebra_serial.rs +++ b/src/matrix/dense_nalgebra_serial.rs @@ -5,9 +5,7 @@ use nalgebra::{ RawStorageMut, }; -use crate::op::NonLinearOp; -use crate::vector::Vector; -use crate::{scalar::Scale, IndexType, Scalar}; +use crate::{scalar::Scale, IndexType, Scalar, Vector}; use super::default_solver::DefaultSolver; use super::sparsity::{Dense, DenseRef}; @@ -15,7 +13,7 @@ use crate::error::DiffsolError; use crate::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut, NalgebraLU}; impl DefaultSolver for DMatrix { - type LS, V = DVector, T = T>> = NalgebraLU; + type LS = NalgebraLU; } macro_rules! impl_matrix_common { @@ -60,7 +58,7 @@ macro_rules! 
impl_mul_scale { impl_mul_scale!(DMatrixView<'a, T>); impl_mul_scale!(DMatrix); -impl<'a, T: Scalar> MulAssign> for DMatrixViewMut<'a, T> { +impl MulAssign> for DMatrixViewMut<'_, T> { fn mul_assign(&mut self, rhs: Scale) { *self *= rhs.value(); } diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs index e481771e..854dfd06 100644 --- a/src/matrix/mod.rs +++ b/src/matrix/mod.rs @@ -32,7 +32,7 @@ pub trait MatrixCommon: Sized + Debug { fn ncols(&self) -> IndexType; } -impl<'a, M> MatrixCommon for &'a M +impl MatrixCommon for &M where M: MatrixCommon, { @@ -47,7 +47,7 @@ where } } -impl<'a, M> MatrixCommon for &'a mut M +impl MatrixCommon for &mut M where M: MatrixCommon, { diff --git a/src/matrix/sparse_faer.rs b/src/matrix/sparse_faer.rs index 62178efe..0fddb0b3 100644 --- a/src/matrix/sparse_faer.rs +++ b/src/matrix/sparse_faer.rs @@ -5,7 +5,7 @@ use super::sparsity::MatrixSparsityRef; use super::{Matrix, MatrixCommon, MatrixSparsity}; use crate::error::{DiffsolError, MatrixError}; use crate::vector::Vector; -use crate::{DefaultSolver, FaerSparseLU, IndexType, NonLinearOp, Scalar, Scale}; +use crate::{DefaultSolver, FaerSparseLU, IndexType, Scalar, Scale}; use faer::sparse::ops::{ternary_op_assign_into, union_symbolic}; use faer::sparse::{SymbolicSparseColMat, SymbolicSparseColMatRef}; @@ -37,7 +37,7 @@ impl SparseColMat { } impl DefaultSolver for SparseColMat { - type LS, V = Col, T = T>> = FaerSparseLU; + type LS = FaerSparseLU; } impl MatrixCommon for SparseColMat { diff --git a/src/matrix/sundials.rs b/src/matrix/sundials.rs index 528e6a7a..6224fe32 100644 --- a/src/matrix/sundials.rs +++ b/src/matrix/sundials.rs @@ -11,7 +11,7 @@ use crate::sundials_sys::{ }; use crate::{ - error::*, matrix_error, ode_solver::sundials::sundials_check, op::NonLinearOp, scalar::scale, + error::*, matrix_error, ode_solver::sundials::sundials_check, scalar::scale, vector::sundials::SundialsVector, IndexType, Scale, SundialsLinearSolver, Vector, }; @@ -94,8 +94,7 @@ impl Display for SundialsMatrix { } impl DefaultSolver for SundialsMatrix { - type LS> = - SundialsLinearSolver; + type LS = SundialsLinearSolver; } impl MatrixCommon for SundialsMatrix { diff --git a/src/nonlinear_solver/convergence.rs b/src/nonlinear_solver/convergence.rs index 29732163..73d846cd 100644 --- a/src/nonlinear_solver/convergence.rs +++ b/src/nonlinear_solver/convergence.rs @@ -2,7 +2,7 @@ use nalgebra::ComplexField; use num_traits::{One, Pow}; use std::rc::Rc; -use crate::{scalar::IndexType, solver::SolverProblem, NonLinearOp, Scalar, Vector}; +use crate::{scalar::IndexType, Scalar, Vector}; #[derive(Clone)] pub struct Convergence { @@ -31,11 +31,6 @@ impl Convergence { pub fn niter(&self) -> IndexType { self.niter } - pub fn new_from_problem>(problem: &SolverProblem) -> Self { - let rtol = problem.rtol; - let atol = problem.atol.clone(); - Self::new(rtol, atol) - } pub fn new(rtol: V::T, atol: Rc) -> Self { let minimum_tol = V::T::from(10.0) * V::T::EPSILON / rtol; let maximum_tol = V::T::from(0.03); diff --git a/src/nonlinear_solver/mod.rs b/src/nonlinear_solver/mod.rs index 3d07719f..feb5e104 100644 --- a/src/nonlinear_solver/mod.rs +++ b/src/nonlinear_solver/mod.rs @@ -1,4 +1,6 @@ -use crate::{error::DiffsolError, op::Op, solver::SolverProblem}; +use std::rc::Rc; + +use crate::{error::DiffsolError, Matrix, NonLinearOp, NonLinearOpJacobian}; use convergence::Convergence; pub struct NonLinearSolveSolution { @@ -13,34 +15,52 @@ impl NonLinearSolveSolution { } /// A solver for the nonlinear problem `F(x) = 0`. 
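// Editor's note: the equivalent call order for the refactored `NonLinearSolver` trait
// defined below, where the operator is now passed to each call rather than stored in a
// `SolverProblem` (comment sketch; names as in `test_nonlinear_solver`):
//
//     let mut solver = NewtonNonlinearSolver::new(linear_solver);
//     solver.set_problem(&op, rtol, atol.clone()); // sets up convergence and the inner linear solver
//     solver.reset_jacobian(&op, &x0, t);          // factorises J at the current point
//     let x = solver.solve(&op, &x0, t, &error_y)?; // Newton iterations for F(x) = 0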
-pub trait NonLinearSolver { - /// Get the problem to be solved. - fn problem(&self) -> &SolverProblem; - - fn convergence(&self) -> &Convergence; +pub trait NonLinearSolver: Default { + fn convergence(&self) -> &Convergence; - fn convergence_mut(&mut self) -> &mut Convergence; + fn convergence_mut(&mut self) -> &mut Convergence; /// Set the problem to be solved, any previous problem is discarded. - fn set_problem(&mut self, problem: &SolverProblem); + fn set_problem>( + &mut self, + op: &C, + rtol: M::T, + atol: Rc, + ); /// Reset the approximation of the Jacobian matrix. - fn reset_jacobian(&mut self, x: &C::V, t: C::T); + fn reset_jacobian>( + &mut self, + op: &C, + x: &M::V, + t: M::T, + ); // Solve the problem `F(x, t) = 0` for fixed t, and return the solution `x`. - fn solve(&mut self, x: &C::V, t: C::T, error_y: &C::V) -> Result { + fn solve>( + &mut self, + op: &C, + x: &M::V, + t: M::T, + error_y: &M::V, + ) -> Result { let mut x = x.clone(); - self.solve_in_place(&mut x, t, error_y)?; + self.solve_in_place(op, &mut x, t, error_y)?; Ok(x) } /// Solve the problem `F(x) = 0` in place. - fn solve_in_place(&mut self, x: &mut C::V, t: C::T, error_y: &C::V) - -> Result<(), DiffsolError>; + fn solve_in_place>( + &mut self, + op: &C, + x: &mut C::V, + t: C::T, + error_y: &C::V, + ) -> Result<(), DiffsolError>; /// Solve the linearised problem `J * x = b`, where `J` was calculated using [Self::reset_jacobian]. /// The input `b` is provided in `x`, and the solution is returned in `x`. - fn solve_linearised_in_place(&self, x: &mut C::V) -> Result<(), DiffsolError>; + fn solve_linearised_in_place(&self, x: &mut M::V) -> Result<(), DiffsolError>; } pub mod convergence; @@ -54,17 +74,18 @@ pub mod tests { use self::newton::NewtonNonlinearSolver; use crate::{ - linear_solver::nalgebra::lu::LU, - matrix::MatrixCommon, - op::{closure::Closure, NonLinearOp}, - scale, DenseMatrix, Vector, + linear_solver::nalgebra::lu::LU, matrix::MatrixCommon, op::closure::Closure, scale, + DenseMatrix, Vector, }; use super::*; use num_traits::{One, Zero}; + #[allow(clippy::type_complexity)] pub fn get_square_problem() -> ( - SolverProblem>, + impl NonLinearOpJacobian, + M::T, + Rc, Vec>, ) where @@ -90,27 +111,29 @@ pub mod tests { p, ); let rtol = M::T::from(1e-6); - let atol = M::V::from_vec(vec![1e-6.into(), 1e-6.into()]); - let problem = SolverProblem::new(Rc::new(op), Rc::new(atol), rtol); + let atol = Rc::new(M::V::from_vec(vec![1e-6.into(), 1e-6.into()])); let solns = vec![NonLinearSolveSolution::new( M::V::from_vec(vec![2.1.into(), 2.1.into()]), M::V::from_vec(vec![2.0.into(), 2.0.into()]), )]; - (problem, solns) + (op, rtol, atol, solns) } pub fn test_nonlinear_solver( - mut solver: impl NonLinearSolver, - problem: SolverProblem, + mut solver: impl NonLinearSolver, + op: C, + rtol: C::T, + atol: Rc, solns: Vec>, ) where - C: NonLinearOp, + C: NonLinearOpJacobian, { - solver.set_problem(&problem); + solver.set_problem(&op, rtol, atol.clone()); let t = C::T::zero(); + solver.reset_jacobian(&op, &solns[0].x0, t); for soln in solns { - let x = solver.solve(&soln.x0, t, &soln.x0).unwrap(); - let tol = x.clone() * scale(problem.rtol) + problem.atol.as_ref(); + let x = solver.solve(&op, &soln.x0, t, &soln.x0).unwrap(); + let tol = x.clone() * scale(rtol) + atol.as_ref(); x.assert_eq(&soln.x, &tol); } } @@ -120,8 +143,8 @@ pub mod tests { #[test] fn test_newton_cpu_square() { let lu = LU::default(); - let (prob, soln) = get_square_problem::(); + let (op, rtol, atol, soln) = get_square_problem::(); let s = 
NewtonNonlinearSolver::new(lu); - test_nonlinear_solver(s, prob, soln); + test_nonlinear_solver(s, op, rtol, atol, soln); } } diff --git a/src/nonlinear_solver/newton.rs b/src/nonlinear_solver/newton.rs index d66134bb..d56168d9 100644 --- a/src/nonlinear_solver/newton.rs +++ b/src/nonlinear_solver/newton.rs @@ -1,8 +1,9 @@ +use std::rc::Rc; + use crate::{ error::{DiffsolError, NonLinearSolverError}, - non_linear_solver_error, - op::NonLinearOp, - Convergence, ConvergenceStatus, LinearSolver, NonLinearSolver, SolverProblem, Vector, + non_linear_solver_error, Convergence, ConvergenceStatus, LinearSolver, Matrix, NonLinearOp, + NonLinearOpJacobian, NonLinearSolver, Vector, }; pub fn newton_iteration( @@ -35,80 +36,90 @@ pub fn newton_iteration( Err(non_linear_solver_error!(NewtonDidNotConverge)) } -pub struct NewtonNonlinearSolver> { - convergence: Option>, +pub struct NewtonNonlinearSolver> { + convergence: Option>, linear_solver: Ls, - problem: Option>, is_jacobian_set: bool, - tmp: C::V, + tmp: M::V, } -impl> NewtonNonlinearSolver { +impl> NewtonNonlinearSolver { pub fn new(linear_solver: Ls) -> Self { Self { - problem: None, convergence: None, linear_solver, is_jacobian_set: false, - tmp: C::V::zeros(0), + tmp: M::V::zeros(0), } } + pub fn linear_solver(&self) -> &Ls { + &self.linear_solver + } +} + +impl> Default for NewtonNonlinearSolver { + fn default() -> Self { + Self::new(Ls::default()) + } } -impl> NonLinearSolver for NewtonNonlinearSolver { - fn convergence(&self) -> &Convergence { +impl> NonLinearSolver for NewtonNonlinearSolver { + fn convergence(&self) -> &Convergence { self.convergence .as_ref() .expect("NewtonNonlinearSolver::convergence() called before set_problem") } - fn convergence_mut(&mut self) -> &mut Convergence { + fn convergence_mut(&mut self) -> &mut Convergence { self.convergence .as_mut() .expect("NewtonNonlinearSolver::convergence_mut() called before set_problem") } - fn problem(&self) -> &SolverProblem { - self.problem - .as_ref() - .expect("NewtonNonlinearSolver::problem() called before set_problem") - } - fn set_problem(&mut self, problem: &SolverProblem) { - self.problem = Some(problem.clone()); - self.linear_solver.set_problem(problem); - let problem = self.problem.as_ref().unwrap(); - self.convergence = Some(Convergence::new_from_problem(problem)); + fn set_problem>( + &mut self, + op: &C, + rtol: M::T, + atol: Rc, + ) { + self.linear_solver.set_problem(op, rtol, atol.clone()); + self.convergence = Some(Convergence::new(rtol, atol)); self.is_jacobian_set = false; - self.tmp = C::V::zeros(problem.f.nstates()); + self.tmp = C::V::zeros(op.nstates()); } - fn reset_jacobian(&mut self, x: &C::V, t: C::T) { - self.linear_solver.set_linearisation(x, t); + fn reset_jacobian>( + &mut self, + op: &C, + x: &C::V, + t: C::T, + ) { + self.linear_solver.set_linearisation(op, x, t); self.is_jacobian_set = true; } - fn solve_linearised_in_place(&self, x: &mut C::V) -> Result<(), DiffsolError> { + fn solve_linearised_in_place(&self, x: &mut M::V) -> Result<(), DiffsolError> { self.linear_solver.solve_in_place(x) } - fn solve_in_place( + fn solve_in_place>( &mut self, - xn: &mut C::V, - t: C::T, - error_y: &C::V, + op: &C, + xn: &mut M::V, + t: M::T, + error_y: &M::V, ) -> Result<(), DiffsolError> { - if self.convergence.is_none() || self.problem.is_none() { + if self.convergence.is_none() { panic!("NewtonNonlinearSolver::solve() called before set_problem"); } if !self.is_jacobian_set { - self.reset_jacobian(xn, t); + panic!("NewtonNonlinearSolver::solve_in_place() called 
before reset_jacobian"); } - if xn.len() != self.problem.as_ref().unwrap().f.nstates() { - panic!("NewtonNonlinearSolver::solve() called with state of wrong size, expected {}, got {}", self.problem.as_ref().unwrap().f.nstates(), xn.len()); + if xn.len() != op.nstates() { + panic!("NewtonNonlinearSolver::solve() called with state of wrong size, expected {}, got {}", op.nstates(), xn.len()); } let linear_solver = |x: &mut C::V| self.linear_solver.solve_in_place(x); - let problem = self.problem.as_ref().unwrap(); - let fun = |x: &C::V, y: &mut C::V| problem.f.call_inplace(x, t, y); + let fun = |x: &C::V, y: &mut C::V| op.call_inplace(x, t, y); let convergence = self.convergence.as_mut().unwrap(); newton_iteration(xn, &mut self.tmp, error_y, fun, linear_solver, convergence) } diff --git a/src/ode_solver/adjoint_equations.rs b/src/ode_solver/adjoint_equations.rs new file mode 100644 index 00000000..6dc6b74a --- /dev/null +++ b/src/ode_solver/adjoint_equations.rs @@ -0,0 +1,724 @@ +use num_traits::{One, Zero}; +use std::{cell::RefCell, ops::AddAssign, ops::SubAssign, rc::Rc}; + +use crate::{ + op::nonlinear_op::NonLinearOpJacobian, AugmentedOdeEquations, Checkpointing, ConstantOp, + ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, Matrix, NonLinearOp, NonLinearOpAdjoint, + NonLinearOpSensAdjoint, OdeEquations, OdeEquationsAdjoint, OdeSolverMethod, OdeSolverProblem, + Op, Vector, +}; + +pub struct AdjointContext +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + checkpointer: Checkpointing, + x: Eqn::V, + index: usize, + last_t: Option, + col: Eqn::V, +} + +impl AdjointContext +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + pub fn new(checkpointer: Checkpointing) -> Self { + let x = ::zeros(checkpointer.problem.eqn.rhs().nstates()); + let mut col = ::zeros(checkpointer.problem.eqn.out().unwrap().nout()); + let index = 0; + col[0] = Eqn::T::one(); + Self { + checkpointer, + x, + index, + col, + last_t: None, + } + } + + pub fn set_state(&mut self, t: Eqn::T) { + if let Some(last_t) = self.last_t { + if last_t == t { + return; + } + } + self.last_t = Some(t); + self.checkpointer.interpolate(t, &mut self.x).unwrap(); + } + + pub fn state(&self) -> &Eqn::V { + &self.x + } + + pub fn col(&self) -> &Eqn::V { + &self.col + } + + pub fn set_index(&mut self, index: usize) { + self.col[self.index] = Eqn::T::zero(); + self.index = index; + self.col[self.index] = Eqn::T::one(); + } +} + +pub struct AdjointMass +where + Eqn: OdeEquationsAdjoint, +{ + eqn: Rc, +} + +impl AdjointMass +where + Eqn: OdeEquationsAdjoint, +{ + pub fn new(eqn: &Rc) -> Self { + Self { eqn: eqn.clone() } + } +} + +impl Op for AdjointMass +where + Eqn: OdeEquationsAdjoint, +{ + type T = Eqn::T; + type V = Eqn::V; + type M = Eqn::M; + + fn nstates(&self) -> usize { + self.eqn.rhs().nstates() + } + fn nout(&self) -> usize { + self.eqn.rhs().nstates() + } + fn nparams(&self) -> usize { + self.eqn.rhs().nparams() + } +} + +impl LinearOp for AdjointMass +where + Eqn: OdeEquationsAdjoint, +{ + fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) { + self.eqn + .mass() + .unwrap() + .gemv_transpose_inplace(x, t, beta, y); + } + + fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { + self.eqn.mass().unwrap().transpose_inplace(t, y); + } +} + +pub struct AdjointInit +where + Eqn: OdeEquationsAdjoint, +{ + eqn: Rc, +} + +impl AdjointInit +where + Eqn: OdeEquationsAdjoint, +{ + pub fn new(eqn: &Rc) -> Self { + Self { eqn: eqn.clone() } + } +} + +impl Op for AdjointInit +where + 
Eqn: OdeEquationsAdjoint, +{ + type T = Eqn::T; + type V = Eqn::V; + type M = Eqn::M; + + fn nstates(&self) -> usize { + self.eqn.rhs().nstates() + } + fn nout(&self) -> usize { + self.eqn.rhs().nstates() + } + fn nparams(&self) -> usize { + self.eqn.rhs().nparams() + } +} + +impl ConstantOp for AdjointInit +where + Eqn: OdeEquationsAdjoint, +{ + fn call_inplace(&self, _t: Self::T, y: &mut Self::V) { + y.fill(Eqn::T::zero()); + } +} + +/// Right-hand side of the adjoint equations is: +/// +/// F(λ, x, t) = -f^T_x(x, t) λ - g^T_x(x,t) +/// +/// f_x is the partial derivative of the right-hand side with respect to the state vector. +/// g_x is the partial derivative of the functional g with respect to the state vector. +/// +/// We need the current state x(t), which is obtained from the checkpointed forward solve at the current time step. +pub struct AdjointRhs +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + eqn: Rc, + context: Rc>>, + tmp: RefCell, + with_out: bool, +} + +impl AdjointRhs +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + pub fn new( + eqn: &Rc, + context: Rc>>, + with_out: bool, + ) -> Self { + let tmp_n = if with_out { eqn.rhs().nstates() } else { 0 }; + let tmp = RefCell::new(::zeros(tmp_n)); + Self { + eqn: eqn.clone(), + context, + tmp, + with_out, + } + } +} + +impl Op for AdjointRhs +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + type T = Eqn::T; + type V = Eqn::V; + type M = Eqn::M; + + fn nstates(&self) -> usize { + self.eqn.rhs().nstates() + } + fn nout(&self) -> usize { + self.eqn.rhs().nstates() + } + fn nparams(&self) -> usize { + self.eqn.rhs().nparams() + } + fn sparsity(&self) -> Option<::SparsityRef<'_>> { + self.eqn.rhs().sparsity_adjoint() + } +} + +impl NonLinearOp for AdjointRhs +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + /// F(λ, x, t) = -f^T_x(x, t) λ - g^T_x(x,t) + fn call_inplace(&self, lambda: &Self::V, t: Self::T, y: &mut Self::V) { + self.context.borrow_mut().set_state(t); + let context = self.context.borrow(); + let x = context.state(); + + // y = -f^T_x(x, t) λ + self.eqn.rhs().jac_transpose_mul_inplace(x, t, lambda, y); + + // y = -f^T_x(x, t) λ - g^T_x(x,t) + if self.with_out { + let col = context.col(); + let mut tmp = self.tmp.borrow_mut(); + self.eqn + .out() + .unwrap() + .jac_transpose_mul_inplace(x, t, col, &mut tmp); + y.add_assign(&*tmp); + } + } +} + +impl NonLinearOpJacobian for AdjointRhs +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + // J = -f^T_x(x, t) + fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { + self.context.borrow_mut().set_state(t); + let context = self.context.borrow(); + let x = context.state(); + self.eqn.rhs().jac_transpose_mul_inplace(x, t, v, y); + } + fn jacobian_inplace(&self, _x: &Self::V, t: Self::T, y: &mut Self::M) { + self.context.borrow_mut().set_state(t); + let context = self.context.borrow(); + let x = context.state(); + self.eqn.rhs().adjoint_inplace(x, t, y); + } +} + +/// Output of the adjoint equations is: +/// +/// F(λ, x, t) = -g_p^T(x, t) - f_p^T(x, t) λ +/// +/// f_p is the partial derivative of the right-hand side with respect to the parameter vector +/// g_p is the partial derivative of the functional g with respect to the parameter vector +/// +/// We need the current state x(t), which is obtained from the checkpointed forward solve at the current time step. 
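+///
+/// As a minimal numeric check of this formula and the `AdjointRhs` formula above
+/// (an editorial sketch, not part of the solver API; the scalar model dy/dt = -a*y
+/// with a single output g(y) = y and the values below are assumptions chosen for
+/// exact arithmetic):
+///
+/// ```
+/// let (a, y, lambda) = (0.25, 2.0, 0.5); // parameter, forward state x(t), adjoint λ(t)
+/// // adjoint rhs: F(λ, x, t) = -f^T_x λ - g^T_x, with f_x = -a and g_x = 1
+/// let dlambda = a * lambda - 1.0; // = -0.875
+/// // adjoint output: -g^T_p - f^T_p λ, with f_p = -y and g_p = 0
+/// let dgdp = y * lambda; // = 1.0
+/// assert_eq!((dlambda, dgdp), (-0.875, 1.0));
+/// ```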
+pub struct AdjointOut +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + eqn: Rc, + context: Rc>>, + tmp: RefCell, + with_out: bool, +} + +impl AdjointOut +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + pub fn new( + eqn: &Rc, + context: Rc>>, + with_out: bool, + ) -> Self { + let tmp_n = if with_out { eqn.rhs().nparams() } else { 0 }; + let tmp = RefCell::new(::zeros(tmp_n)); + Self { + eqn: eqn.clone(), + context, + tmp, + with_out, + } + } +} + +impl Op for AdjointOut +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + type T = Eqn::T; + type V = Eqn::V; + type M = Eqn::M; + + fn nstates(&self) -> usize { + self.eqn.rhs().nstates() + } + fn nout(&self) -> usize { + self.eqn.rhs().nparams() + } + fn nparams(&self) -> usize { + self.eqn.rhs().nparams() + } + fn sparsity(&self) -> Option<::SparsityRef<'_>> { + self.eqn.rhs().sparsity_sens_adjoint() + } +} + +impl NonLinearOp for AdjointOut +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + /// F(λ, x, t) = -g_p(x, t) - λ^T f_p(x, t) + fn call_inplace(&self, lambda: &Self::V, t: Self::T, y: &mut Self::V) { + self.context.borrow_mut().set_state(t); + let context = self.context.borrow(); + let x = context.state(); + self.eqn.rhs().sens_transpose_mul_inplace(x, t, lambda, y); + + if self.with_out { + let col = context.col(); + let mut tmp = self.tmp.borrow_mut(); + self.eqn + .out() + .unwrap() + .sens_transpose_mul_inplace(x, t, col, &mut tmp); + y.add_assign(&*tmp); + } + } +} + +impl NonLinearOpJacobian for AdjointOut +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + // J = -f_p(x, t) + fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { + self.context.borrow_mut().set_state(t); + let context = self.context.borrow(); + let x = context.state(); + self.eqn.rhs().sens_transpose_mul_inplace(x, t, v, y); + } + fn jacobian_inplace(&self, _x: &Self::V, t: Self::T, y: &mut Self::M) { + self.context.borrow_mut().set_state(t); + let context = self.context.borrow(); + let x = context.state(); + self.eqn.rhs().sens_adjoint_inplace(x, t, y); + } +} + +/// Adjoint equations for ODEs +/// +/// M * dλ/dt = -f^T_x(x, t) λ - g^T_x(x,t) +/// λ(T) = 0 +/// g(λ, x, t) = -g_p(x, t) - λ^T f_p(x, t) +/// +pub struct AdjointEquations +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + eqn: Rc, + rhs: Rc>, + out: Option>>, + mass: Option>>, + context: Rc>>, + tmp: RefCell, + tmp2: RefCell, + init: Rc>, + atol: Option>, + rtol: Option, + out_rtol: Option, + out_atol: Option>, +} + +impl AdjointEquations +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + pub(crate) fn new( + problem: &OdeSolverProblem, + context: Rc>>, + with_out: bool, + ) -> Self { + let eqn = problem.eqn.clone(); + let rhs = Rc::new(AdjointRhs::new(&eqn, context.clone(), with_out)); + let init = Rc::new(AdjointInit::new(&eqn)); + let out = if with_out { + Some(Rc::new(AdjointOut::new(&eqn, context.clone(), with_out))) + } else { + None + }; + let tmp = if with_out { + RefCell::new(::zeros(0)) + } else { + RefCell::new(::zeros(eqn.rhs().nparams())) + }; + let tmp2 = if with_out { + RefCell::new(::zeros(0)) + } else { + RefCell::new(::zeros(eqn.rhs().nstates())) + }; + let atol = if with_out { + problem.sens_atol.clone() + } else { + None + }; + let rtol = if with_out { problem.sens_rtol } else { None }; + let out_atol = if with_out { + problem.out_atol.clone() + } else { + None + }; + let out_rtol = if with_out { problem.out_rtol } else { None }; + let 
mass = eqn.mass().map(|_m| Rc::new(AdjointMass::new(&eqn))); + Self { + rhs, + init, + mass, + context, + out, + tmp, + tmp2, + eqn, + atol, + rtol, + out_rtol, + out_atol, + } + } + + pub fn correct_sg_for_init(&self, t: Eqn::T, s: &[Eqn::V], sg: &mut [Eqn::V]) { + let mut tmp = self.tmp.borrow_mut(); + for (s_i, sg_i) in s.iter().zip(sg.iter_mut()) { + if let Some(mass) = self.eqn.mass() { + let mut tmp2 = self.tmp2.borrow_mut(); + mass.call_transpose_inplace(s_i, t, &mut tmp2); + self.eqn + .init() + .sens_mul_transpose_inplace(t, &tmp2, &mut tmp); + sg_i.sub_assign(&*tmp); + } else { + self.eqn.init().sens_mul_transpose_inplace(t, s_i, &mut tmp); + sg_i.sub_assign(&*tmp); + } + } + } +} + +impl std::fmt::Debug for AdjointEquations +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AdjointEquations").finish() + } +} + +impl Op for AdjointEquations +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + type T = Eqn::T; + type V = Eqn::V; + type M = Eqn::M; + + fn nstates(&self) -> usize { + self.eqn.rhs().nstates() + } + fn nout(&self) -> usize { + self.eqn.rhs().nout() + } + fn nparams(&self) -> usize { + self.eqn.rhs().nparams() + } +} + +impl OdeEquations for AdjointEquations +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + type T = Eqn::T; + type V = Eqn::V; + type M = Eqn::M; + type Rhs = AdjointRhs; + type Mass = AdjointMass; + type Root = Eqn::Root; + type Init = AdjointInit; + type Out = AdjointOut; + + fn rhs(&self) -> &Rc { + &self.rhs + } + fn mass(&self) -> Option<&Rc> { + self.mass.as_ref() + } + fn root(&self) -> Option<&Rc> { + None + } + fn init(&self) -> &Rc { + &self.init + } + fn set_params(&mut self, _p: Self::V) { + panic!("Not implemented for AdjointEquations"); + } + fn out(&self) -> Option<&Rc> { + self.out.as_ref() + } +} + +impl AugmentedOdeEquations> + for AdjointEquations +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + fn include_in_error_control(&self) -> bool { + self.atol.is_some() && self.rtol.is_some() + } + fn include_out_in_error_control(&self) -> bool { + self.out().is_some() && self.out_atol.is_some() && self.out_rtol.is_some() + } + + fn atol(&self) -> Option<&Rc> { + self.atol.as_ref() + } + fn out_atol(&self) -> Option<&Rc> { + self.out_atol.as_ref() + } + fn out_rtol(&self) -> Option { + self.out_rtol + } + fn rtol(&self) -> Option { + self.rtol + } + + fn max_index(&self) -> usize { + self.eqn.out().map(|o| o.nout()).unwrap_or(0) + } + + fn set_index(&mut self, index: usize) { + self.context.borrow_mut().set_index(index); + } + + fn update_rhs_out_state(&mut self, _y: &Eqn::V, _dy: &Eqn::V, _t: Eqn::T) {} + + fn update_init_state(&mut self, _t: ::T) {} +} + +#[cfg(test)] +mod tests { + use std::{cell::RefCell, rc::Rc}; + + use crate::{ + ode_solver::{ + adjoint_equations::AdjointEquations, + test_models::exponential_decay::exponential_decay_problem_adjoint, + }, + AdjointContext, AugmentedOdeEquations, Checkpointing, FaerSparseLU, Matrix, MatrixCommon, + NalgebraLU, NonLinearOp, NonLinearOpJacobian, OdeSolverMethod, Sdirk, SdirkState, + SparseColMat, Tableau, Vector, + }; + type Mcpu = nalgebra::DMatrix; + type Vcpu = nalgebra::DVector; + + #[test] + fn test_rhs_exponential() { + // dy/dt = -ay (p = [a]) + // a = 0.1 + let (problem, _soln) = exponential_decay_problem_adjoint::(); + let mut solver = Sdirk::::new(Tableau::esdirk34(), NalgebraLU::default()); + let state = SdirkState { + t: 0.0, + y: 
Vcpu::from_vec(vec![1.0, 1.0]), + dy: Vcpu::from_vec(vec![1.0, 1.0]), + g: Vcpu::zeros(0), + dg: Vcpu::zeros(0), + sg: Vec::new(), + dsg: Vec::new(), + s: Vec::new(), + ds: Vec::new(), + h: 0.0, + }; + solver.set_problem(state.clone(), &problem).unwrap(); + let checkpointer = Checkpointing::new(solver, 0, vec![state.clone(), state.clone()], None); + let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer))); + let adj_eqn = AdjointEquations::new(&problem, context.clone(), false); + // F(λ, x, t) = -f^T_x(x, t) λ + // f_x = |-a 0| + // |0 -a| + // F(s, t)_0 = |a 0| |1| = |a| = |0.1| + // |0 a| |2| |2a| = |0.2| + let v = Vcpu::from_vec(vec![1.0, 2.0]); + let f = adj_eqn.rhs.call(&v, state.t); + let f_expect = Vcpu::from_vec(vec![0.1, 0.2]); + f.assert_eq_st(&f_expect, 1e-10); + + let mut adj_eqn = AdjointEquations::new(&problem, context.clone(), true); + + // f_x^T = |-a 0| + // |0 -a| + // J = -f_x^T + let adjoint = adj_eqn.rhs.jacobian(&state.y, state.t); + assert_eq!(adjoint.nrows(), 2); + assert_eq!(adjoint.ncols(), 2); + assert_eq!(adjoint[(0, 0)], 0.1); + assert_eq!(adjoint[(1, 1)], 0.1); + + // g_x = |1 2| + // |3 4| + // S = -g^T_x(x,t) + // so S = |-1 -3| + // |-2 -4| + + // f_p^T = |-x_1 -x_2 | + // |0 0 | + // g_p = |0 0| + // |0 0| + // g(λ, x, t) = -g_p(x, t) - λ^T f_p(x, t) + // = |1 1| |1| + |0| = |3| + // |0 0| |2| |0| = |0| + adj_eqn.set_index(0); + let out = adj_eqn.out.unwrap().call(&v, state.t); + let out_expect = Vcpu::from_vec(vec![3.0, 0.0]); + out.assert_eq_st(&out_expect, 1e-10); + + // F(λ, x, t) = -f^T_x(x, t) λ - g^T_x(x,t) + // f_x = |-a 0| + // |0 -a| + // F(s, t)_0 = |a 0| |1| - |1.0| = | a - 1| = |-0.9| + // |0 a| |2| |2.0| |2a - 2| = |-1.8| + let f = adj_eqn.rhs.call(&v, state.t); + let f_expect = Vcpu::from_vec(vec![-0.9, -1.8]); + f.assert_eq_st(&f_expect, 1e-10); + } + + #[test] + fn test_rhs_exponential_sparse() { + // dy/dt = -ay (p = [a]) + // a = 0.1 + let (problem, _soln) = exponential_decay_problem_adjoint::>(); + let mut solver = + Sdirk::, _, _>::new(Tableau::esdirk34(), FaerSparseLU::default()); + let state = SdirkState { + t: 0.0, + y: faer::Col::from_vec(vec![1.0, 1.0]), + dy: faer::Col::from_vec(vec![1.0, 1.0]), + g: faer::Col::zeros(0), + dg: faer::Col::zeros(0), + sg: Vec::new(), + dsg: Vec::new(), + s: Vec::new(), + ds: Vec::new(), + h: 0.0, + }; + solver.set_problem(state.clone(), &problem).unwrap(); + let checkpointer = Checkpointing::new(solver, 0, vec![state.clone(), state.clone()], None); + let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer))); + let mut adj_eqn = AdjointEquations::new(&problem, context, true); + + // f_x^T = |-a 0| + // |0 -a| + // J = -f_x^T + let adjoint = adj_eqn.rhs.jacobian(&state.y, state.t); + assert_eq!(adjoint.nrows(), 2); + assert_eq!(adjoint.ncols(), 2); + for (i, j, v) in adjoint.triplet_iter() { + if i == j { + assert_eq!(*v, 0.1); + } else { + assert_eq!(*v, 0.0); + } + } + + // g_x = |1 2| + // |3 4| + // S = -g^T_x(x,t) + // so S = |-1 -3| + // |-2 -4| + + // F(λ, x, t) = -f^T_x(x, t) λ - g^T_x(x,t) + // f_x = |-a 0| + // |0 -a| + // F(s, t)_0 = |a 0| |1| - |1.0| = |a - 1| = |-0.9| + // |0 a| |2| |2.0| |2a - 2| = |-1.8| + adj_eqn.set_index(0); + let v = faer::Col::from_vec(vec![1.0, 2.0]); + let f = adj_eqn.rhs.call(&v, state.t); + let f_expect = faer::Col::from_vec(vec![-0.9, -1.8]); + f.assert_eq_st(&f_expect, 1e-10); + } +} diff --git a/src/ode_solver/bdf.rs b/src/ode_solver/bdf.rs index ce1c1798..994e8ca4 100644 --- a/src/ode_solver/bdf.rs +++ b/src/ode_solver/bdf.rs @@ 
-2,26 +2,33 @@ use nalgebra::ComplexField; use std::ops::AddAssign; use std::rc::Rc; -use crate::error::{DiffsolError, OdeSolverError}; +use crate::{ + error::{DiffsolError, OdeSolverError}, + AdjointEquations, NoAug, OdeEquationsAdjoint, OdeEquationsSens, SensEquations, StateRef, + StateRefMut, +}; use num_traits::{abs, One, Pow, Zero}; use serde::Serialize; +use crate::ode_solver_error; use crate::{ matrix::{default_solver::DefaultSolver, MatrixRef}, - newton_iteration, nonlinear_solver::root::RootFinder, op::bdf::BdfCallable, scalar::scale, vector::DefaultDenseMatrix, - BdfState, DenseMatrix, IndexType, JacobianUpdate, MatrixViewMut, NewtonNonlinearSolver, - NonLinearSolver, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Op, - Scalar, SolverProblem, Vector, VectorRef, VectorView, VectorViewMut, + AugmentedOdeEquations, BdfState, DenseMatrix, IndexType, JacobianUpdate, MatrixViewMut, + NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquationsImplicit, OdeSolverMethod, + OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Op, Scalar, Vector, VectorRef, + VectorView, VectorViewMut, }; -use crate::{ode_solver_error, NonLinearOp, SensEquations}; -use super::equations::OdeEquations; use super::jacobian_update::SolverState; +use super::{ + equations::OdeEquations, + method::{AdjointOdeSolverMethod, AugmentedOdeSolverMethod, SensitivitiesOdeSolverMethod}, +}; #[derive(Clone, Debug, Serialize, Default)] pub struct BdfStatistics { @@ -32,6 +39,29 @@ pub struct BdfStatistics { pub number_of_nonlinear_solver_fails: usize, } +pub type BdfSens = Bdf>; +pub type BdfAdj = + Bdf>, Nls, AdjointEquations>>; +impl SensitivitiesOdeSolverMethod for BdfSens +where + Eqn: OdeEquationsSens, + M: DenseMatrix, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, + Nls: NonLinearSolver, +{ +} + +// notes quadrature. +// ndf formula rearranged to [2]: +// (1 - kappa) * gamma_k * (y_{n+1} - y^0_{n+1}) + (\sum_{m=1}^k gamma_m * y^m_n) - h * F(t_{n+1}, y_{n+1}) = 0 (1) +// where d = y_{n+1} - y^0_{n+1} +// and y^0_{n+1} = \sum_{m=0}^k y^m_n +// +// 1. use (1) to calculate d explicitly +// 2. use d to update the differences matrix +// 3. use d to calculate the predicted solution y_{n+1} + /// Implements a Backward Difference formula (BDF) implicit multistep integrator. /// /// The basic algorithm is derived in \[1\]. This @@ -50,19 +80,25 @@ pub struct BdfStatistics { /// \[3\] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., ... & Van Mulbregt, P. (2020). SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature methods, 17(3), 261-272. 
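+///
+/// As a sketch of the prediction step described in the quadrature notes above
+/// (editorial, assuming the difference matrix `diff` stores the backward
+/// differences y^m_n in its columns), the zero-order prediction is just the sum
+/// of the first k+1 columns, y^0_{n+1} = \sum_{m=0}^{k} y^m_n:
+///
+/// ```
+/// fn predict(diff: &[Vec<f64>], order: usize) -> Vec<f64> {
+///     let mut y = vec![0.0; diff[0].len()];
+///     for col in diff.iter().take(order + 1) {
+///         y.iter_mut().zip(col).for_each(|(yi, di)| *yi += di);
+///     }
+///     y
+/// }
+/// // order 1: column 0 is y_n, column 1 is h * dy_n
+/// assert_eq!(predict(&[vec![1.0, 2.0], vec![0.5, 0.25]], 1), [1.5, 2.25]);
+/// ```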
pub struct Bdf< M: DenseMatrix, - Eqn: OdeEquations, - Nls: NonLinearSolver>, + Eqn: OdeEquationsImplicit, + Nls: NonLinearSolver, + AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit = NoAug, > { nonlinear_solver: Nls, ode_problem: Option>, + op: Option>, n_equal_steps: usize, y_delta: Eqn::V, + g_delta: Eqn::V, y_predict: Eqn::V, t_predict: Eqn::T, s_predict: Eqn::V, - s_op: Option>>, + s_op: Option>, s_deltas: Vec, + sg_deltas: Vec, diff_tmp: M, + gdiff_tmp: M, + sgdiff_tmp: M, u: M, alpha: Vec, gamma: Vec, @@ -79,10 +115,11 @@ impl Default for Bdf< ::M, Eqn, - NewtonNonlinearSolver, ::LS>>, + NewtonNonlinearSolver::LS>, + NoAug, > where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, Eqn::M: DefaultSolver, Eqn::V: DefaultDenseMatrix, for<'b> &'b Eqn::V: VectorRef, @@ -95,11 +132,35 @@ where } } -impl, Eqn: OdeEquations, Nls> Bdf +impl + Bdf< + ::M, + Eqn, + NewtonNonlinearSolver::LS>, + SensEquations, + > +where + Eqn: OdeEquationsSens, + Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + pub fn with_sensitivities() -> Self { + let linear_solver = Eqn::M::default_solver(); + let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); + Self::new(nonlinear_solver) + } +} + +impl Bdf where + AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit, + Eqn: OdeEquationsImplicit, + M: DenseMatrix, for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, - Nls: NonLinearSolver>, + Nls: NonLinearSolver, { const NEWTON_MAXITER: IndexType = 4; const MIN_FACTOR: f64 = 0.5; @@ -138,15 +199,20 @@ where Self { s_op: None, + op: None, ode_problem: None, nonlinear_solver, n_equal_steps: 0, diff_tmp: M::zeros(n, max_order + 3), + gdiff_tmp: M::zeros(n, max_order + 3), + sgdiff_tmp: M::zeros(n, max_order + 3), y_delta: Eqn::V::zeros(n), y_predict: Eqn::V::zeros(n), t_predict: Eqn::T::zero(), s_predict: Eqn::V::zeros(n), s_deltas: Vec::new(), + sg_deltas: Vec::new(), + g_delta: Eqn::V::zeros(n), gamma, alpha, error_const2, @@ -164,10 +230,6 @@ where &self.statistics } - fn nonlinear_problem_op(&self) -> &Rc> { - &self.nonlinear_solver.problem().f - } - fn _compute_r(order: usize, factor: Eqn::T) -> M { //computes the R matrix with entries //given by the first equation on page 8 of [1] @@ -200,12 +262,14 @@ where //let y = &self.y_predict; //let t = self.t_predict; if self.jacobian_update.check_rhs_jacobian_update(c, &state) { - self.nonlinear_solver.problem().f.set_jacobian_is_stale(); - self.nonlinear_solver.reset_jacobian(y, t); + self.op.as_mut().unwrap().set_jacobian_is_stale(); + self.nonlinear_solver + .reset_jacobian(self.op.as_ref().unwrap(), y, t); self.jacobian_update.update_rhs_jacobian(); self.jacobian_update.update_jacobian(c); } else if self.jacobian_update.check_jacobian_update(c, &state) { - self.nonlinear_solver.reset_jacobian(y, t); + self.nonlinear_solver + .reset_jacobian(self.op.as_ref().unwrap(), y, t); self.jacobian_update.update_jacobian(c); } } @@ -227,17 +291,18 @@ where { let state = self.state.as_mut().unwrap(); Self::_update_diff_for_step_size(&ru, &mut state.diff, &mut self.diff_tmp, order); - for i in 0..state.sdiff.len() { - Self::_update_diff_for_step_size( - &ru, - &mut state.sdiff[i], - &mut self.diff_tmp, - order, - ); + for diff in state.sdiff.iter_mut() { + Self::_update_diff_for_step_size(&ru, diff, &mut self.diff_tmp, order); + } + if self.ode_problem.as_ref().unwrap().integrate_out { + Self::_update_diff_for_step_size(&ru, &mut state.gdiff, &mut self.gdiff_tmp, order); 
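+                // note (editorial): the solution, output, sensitivity and
+                // sensitivity-output difference tables share the same history
+                // layout, so the same R(factor) * U rescaling (section 3.2 of
+                // [1]) is applied to each table when the step size changes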
+ } + for diff in state.sgdiff.iter_mut() { + Self::_update_diff_for_step_size(&ru, diff, &mut self.sgdiff_tmp, order); } } - self.nonlinear_problem_op().set_c(new_h, self.alpha[order]); + self.op.as_mut().unwrap().set_c(new_h, self.alpha[order]); self.state.as_mut().unwrap().h = new_h; @@ -262,29 +327,69 @@ where std::mem::swap(diff, diff_tmp); } - fn _update_sens_step_size(&mut self, factor: Eqn::T) { - //If step size h is changed then also need to update the terms in - //the first equation of page 9 of [1]: - // - //- constant c = h / (1-kappa) gamma_k term - //- lu factorisation of (M - c * J) used in newton iteration (same equation) + fn calculate_output_delta(&mut self) { + // integrate output function + let state = self.state.as_mut().unwrap(); + let out = self.ode_problem.as_ref().unwrap().eqn.out().unwrap(); + out.call_inplace(&self.y_predict, self.t_predict, &mut state.dg); + self.op.as_ref().unwrap().integrate_out( + &state.dg, + &state.gdiff, + self.gamma.as_slice(), + self.alpha.as_slice(), + state.order, + &mut self.g_delta, + ); + } - // update D using equations in section 3.2 of [1] - let order = self.state.as_ref().unwrap().order; - let r = Self::_compute_r(order, factor); - let ru = r.mat_mul(&self.u); + fn calculate_sens_output_delta(&mut self, i: usize) { let state = self.state.as_mut().unwrap(); - for sdiff in state.sdiff.iter_mut() { - Self::_update_diff_for_step_size(&ru, sdiff, &mut self.diff_tmp, order); - } + let op = self.s_op.as_ref().unwrap(); + + // integrate sensitivity output equations + let out = op.eqn().out().unwrap(); + out.call_inplace(&state.s[i], self.t_predict, &mut state.dsg[i]); + self.op.as_ref().unwrap().integrate_out( + &state.dsg[i], + &state.sgdiff[i], + self.gamma.as_slice(), + self.alpha.as_slice(), + state.order, + &mut self.sg_deltas[i], + ); } - fn update_differences(&mut self) { + fn update_differences_and_integrate_out(&mut self) { let order = self.state.as_ref().unwrap().order; let state = self.state.as_mut().unwrap(); + + // update differences Self::_update_diff(order, &self.y_delta, &mut state.diff); - for i in 0..state.sdiff.len() { - Self::_update_diff(order, &self.s_deltas[i], &mut state.sdiff[i]); + + // integrate output function + if self.ode_problem.as_ref().unwrap().integrate_out { + Self::_predict_using_diff(&mut state.g, &state.gdiff, order); + state.g.axpy(Eqn::T::one(), &self.g_delta, Eqn::T::one()); + + // update output difference + Self::_update_diff(order, &self.g_delta, &mut state.gdiff); + } + + // do the same for sensitivities + if self.s_op.is_some() { + for i in 0..self.s_op.as_ref().unwrap().eqn().max_index() { + // update sensitivity differences + Self::_update_diff(order, &self.s_deltas[i], &mut state.sdiff[i]); + + // integrate sensitivity output equations + if self.s_op.as_ref().unwrap().eqn().out().is_some() { + Self::_predict_using_diff(&mut state.sg[i], &state.sgdiff[i], order); + state.sg[i].axpy(Eqn::T::one(), &self.sg_deltas[i], Eqn::T::one()); + + // update sensitivity output difference + Self::_update_diff(order, &self.sg_deltas[i], &mut state.sgdiff[i]); + } + } } } @@ -321,7 +426,7 @@ where Self::_predict_using_diff(&mut self.y_predict, &state.diff, state.order); // update psi and c (h, D, y0 has changed) - self.nonlinear_problem_op().set_psi_and_y0( + self.op.as_mut().unwrap().set_psi_and_y0( &state.diff, self.gamma.as_slice(), self.alpha.as_slice(), @@ -344,7 +449,9 @@ where if abs(state.t - tstop) <= troundoff { self.tstop = None; return Ok(Some(OdeSolverStopReason::TstopReached)); - } else if 
tstop < state.t - troundoff { + } else if (state.h > M::T::zero() && tstop < state.t - troundoff) + || (state.h < M::T::zero() && tstop > state.t + troundoff) + { let error = OdeSolverError::StopTimeBeforeCurrentTime { stop_time: self.tstop.unwrap().into(), state_time: state.t.into(), @@ -355,7 +462,9 @@ where } // check if the next step will be beyond tstop, if so adjust the step size - if state.t + state.h > tstop + troundoff { + if (state.h > M::T::zero() && state.t + state.h > tstop + troundoff) + || (state.h < M::T::zero() && state.t + state.h < tstop - troundoff) + { let factor = (tstop - state.t) / state.h; // update step size ignoring the possible "step size too small" error _ = self._update_step_size(factor); @@ -368,7 +477,27 @@ where self.state .as_mut() .unwrap() - .initialise_diff_to_first_order(self.ode_problem.as_ref().unwrap().eqn_sens.is_some()); + .initialise_diff_to_first_order(); + + if self.ode_problem.as_ref().unwrap().integrate_out { + self.state + .as_mut() + .unwrap() + .initialise_gdiff_to_first_order(); + } + if self.s_op.is_some() { + self.state + .as_mut() + .unwrap() + .initialise_sdiff_to_first_order(); + if self.s_op.as_ref().unwrap().eqn().out().is_some() { + self.state + .as_mut() + .unwrap() + .initialise_sgdiff_to_first_order(); + } + } + self.u = Self::_compute_r(1, Eqn::T::one()); self.is_state_modified = false; } @@ -386,44 +515,148 @@ where order_summation } - fn sensitivity_solve( - &mut self, - t_new: Eqn::T, - mut error_norm: Eqn::T, - ) -> Result { - let h = self.state.as_ref().unwrap().h; - let order = self.state.as_ref().unwrap().order; + fn error_control(&self) -> Eqn::T { + let state = self.state.as_ref().unwrap(); + let order = state.order; + let output_in_error_control = self.ode_problem.as_ref().unwrap().output_in_error_control(); + let integrate_sens = self.s_op.is_some(); + let sens_in_error_control = + integrate_sens && self.s_op.as_ref().unwrap().eqn().include_in_error_control(); + let integrate_sens_out = + integrate_sens && self.s_op.as_ref().unwrap().eqn().out().is_some(); + let sens_output_in_error_control = integrate_sens_out + && self + .s_op + .as_ref() + .unwrap() + .eqn() + .include_out_in_error_control(); + + let atol = self.ode_problem.as_ref().unwrap().atol.as_ref(); + let rtol = self.ode_problem.as_ref().unwrap().rtol; + let mut error_norm = + self.y_delta.squared_norm(&state.y, atol, rtol) * self.error_const2[order - 1]; + let mut ncontrib = 1; + if output_in_error_control { + let rtol = self.ode_problem.as_ref().unwrap().out_rtol.unwrap(); + let atol = self + .ode_problem + .as_ref() + .unwrap() + .out_atol + .as_ref() + .unwrap(); + error_norm += + self.g_delta.squared_norm(&state.g, atol, rtol) * self.error_const2[order]; + ncontrib += 1; + } + if sens_in_error_control { + let sens_atol = self.s_op.as_ref().unwrap().eqn().atol().unwrap(); + let sens_rtol = self.s_op.as_ref().unwrap().eqn().rtol().unwrap(); + for i in 0..state.sdiff.len() { + error_norm += self.s_deltas[i].squared_norm(&state.s[i], sens_atol, sens_rtol) + * self.error_const2[order]; + } + ncontrib += state.sdiff.len(); + } + if sens_output_in_error_control { + let rtol = self.s_op.as_ref().unwrap().eqn().out_rtol().unwrap(); + let atol = self.s_op.as_ref().unwrap().eqn().out_atol().unwrap(); + for i in 0..state.sgdiff.len() { + error_norm += self.sg_deltas[i].squared_norm(&state.sg[i], atol, rtol) + * self.error_const2[order]; + } + ncontrib += state.sgdiff.len(); + } + error_norm / Eqn::T::from(ncontrib as f64) + } - // update for new state - { - let 
dy_new = self.nonlinear_problem_op().as_ref().tmp(); - let y_new = &self.y_predict; - self.ode_problem + fn predict_error_control(&self, order: usize) -> Eqn::T { + let state = self.state.as_ref().unwrap(); + let output_in_error_control = self.ode_problem.as_ref().unwrap().output_in_error_control(); + let integrate_sens = self.s_op.is_some(); + let sens_in_error_control = + integrate_sens && self.s_op.as_ref().unwrap().eqn().include_in_error_control(); + let integrate_sens_out = + integrate_sens && self.s_op.as_ref().unwrap().eqn().out().is_some(); + let sens_output_in_error_control = integrate_sens_out + && self + .s_op .as_ref() .unwrap() - .eqn_sens + .eqn() + .include_out_in_error_control(); + + let atol = self.ode_problem.as_ref().unwrap().atol.as_ref(); + let rtol = self.ode_problem.as_ref().unwrap().rtol; + let mut error_norm = state + .diff + .column(order + 1) + .squared_norm(&state.y, atol, rtol) + * self.error_const2[order]; + let mut ncontrib = 1; + if output_in_error_control { + let rtol = self.ode_problem.as_ref().unwrap().out_rtol.unwrap(); + let atol = self + .ode_problem .as_ref() .unwrap() - .rhs() - .update_state(y_new, &dy_new, t_new); + .out_atol + .as_ref() + .unwrap(); + error_norm += state + .gdiff + .column(order + 1) + .squared_norm(&state.g, atol, rtol) + * self.error_const2[order]; + ncontrib += 1; } + if sens_in_error_control { + let sens_atol = self.s_op.as_ref().unwrap().eqn().atol().unwrap(); + let sens_rtol = self.s_op.as_ref().unwrap().eqn().rtol().unwrap(); + for i in 0..state.sdiff.len() { + error_norm += state.sdiff[i].column(order + 1).squared_norm( + &state.s[i], + sens_atol, + sens_rtol, + ) * self.error_const2[order]; + } + } + if sens_output_in_error_control { + let rtol = self.s_op.as_ref().unwrap().eqn().out_rtol().unwrap(); + let atol = self.s_op.as_ref().unwrap().eqn().out_atol().unwrap(); + for i in 0..state.sgdiff.len() { + error_norm += + state.sgdiff[i] + .column(order + 1) + .squared_norm(&state.sg[i], atol, rtol) + * self.error_const2[order]; + } + } + error_norm / Eqn::T::from(ncontrib as f64) + } - // reuse linear solver from nonlinear solver - let ls = |x: &mut Eqn::V| -> Result<(), DiffsolError> { - self.nonlinear_solver.solve_linearised_in_place(x) - }; + fn sensitivity_solve(&mut self, t_new: Eqn::T) -> Result<(), DiffsolError> { + let h = self.state.as_ref().unwrap().h; + let order = self.state.as_ref().unwrap().order; + let op = self.s_op.as_mut().unwrap(); - // construct bdf discretisation of sensitivity equations - let op = self.s_op.as_ref().unwrap(); - op.set_c(h, self.alpha[order]); + // update for new state + { + let dy_new = self.op.as_ref().unwrap().tmp(); + let y_new = &self.y_predict; + Rc::get_mut(op.eqn_mut()) + .unwrap() + .update_rhs_out_state(y_new, &dy_new, t_new); + + // construct bdf discretisation of sensitivity equations + op.set_c(h, self.alpha[order]); + } // solve for sensitivities equations discretised using BDF - let fun = |x: &Eqn::V, y: &mut Eqn::V| op.call_inplace(x, t_new, y); - let rtol = self.problem().as_ref().unwrap().rtol; - let atol = self.problem().as_ref().unwrap().atol.clone(); - let mut convergence = self.nonlinear_solver.convergence().clone(); - let nparams = self.problem().as_ref().unwrap().eqn.rhs().nparams(); - for i in 0..nparams { + let naug = op.eqn().max_index(); + for i in 0..naug { + let op = self.s_op.as_mut().unwrap(); // setup { let state = self.state.as_ref().unwrap(); @@ -438,44 +671,36 @@ where order, &self.s_predict, ); - op.eqn().as_ref().rhs().set_param_index(i); + 
Rc::get_mut(op.eqn_mut()).unwrap().set_index(i); } // solve { let s_new = &mut self.state.as_mut().unwrap().s[i]; s_new.copy_from(&self.s_predict); - newton_iteration( - s_new, - &mut self.s_deltas[i], - &self.s_predict, - fun, - ls, - &mut convergence, - )?; - self.statistics.number_of_nonlinear_solver_iterations += convergence.niter(); + self.nonlinear_solver + .solve_in_place(&*op, s_new, t_new, &self.s_predict)?; + self.statistics.number_of_nonlinear_solver_iterations += + self.nonlinear_solver.convergence().niter(); let s_new = &*s_new; self.s_deltas[i].copy_from(s_new); self.s_deltas[i] -= &self.s_predict; } - let s_new = &self.state.as_ref().unwrap().s[i]; - - if self.problem().as_ref().unwrap().sens_error_control { - error_norm += self.s_deltas[i].squared_norm(s_new, atol.as_ref(), rtol); + if op.eqn().out().is_some() && op.eqn().include_out_in_error_control() { + self.calculate_sens_output_delta(i); } } - if self.problem().as_ref().unwrap().sens_error_control { - error_norm /= Eqn::T::from(nparams as f64 + 1.0); - } - Ok(error_norm) + Ok(()) } } -impl, Eqn: OdeEquations, Nls> OdeSolverMethod - for Bdf +impl OdeSolverMethod for Bdf where - Nls: NonLinearSolver>, + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit, + M: DenseMatrix, + Nls: NonLinearSolver, for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, { @@ -509,6 +734,30 @@ where )) } + fn interpolate_out(&self, t: Eqn::T) -> Result { + // state must be set + let state = self.state.as_ref().ok_or(ode_solver_error!(StateNotSet))?; + if self.is_state_modified { + if t == state.t { + return Ok(state.g.clone()); + } else { + return Err(ode_solver_error!(InterpolationTimeOutsideCurrentStep)); + } + } + // check that t is before/after the current time depending on the direction + let is_forward = state.h > Eqn::T::zero(); + if (is_forward && t > state.t) || (!is_forward && t < state.t) { + return Err(ode_solver_error!(InterpolationTimeAfterCurrentTime)); + } + Ok(Self::interpolate_from_diff( + t, + &state.gdiff, + state.t, + state.h, + state.order, + )) + } + fn interpolate_sens(&self, t: ::T) -> Result, DiffsolError> { // state must be set let state = self.state.as_ref().ok_or(ode_solver_error!(StateNotSet))?; @@ -542,16 +791,16 @@ where self.ode_problem.as_ref() } - fn state(&self) -> Option<&BdfState> { - self.state.as_ref() + fn state(&self) -> Option> { + self.state.as_ref().map(|state| state.as_ref()) } fn take_state(&mut self) -> Option> { Option::take(&mut self.state) } - fn state_mut(&mut self) -> Option<&mut BdfState> { + fn state_mut(&mut self) -> Option> { self.is_state_modified = true; - self.state.as_mut() + self.state.as_mut().map(|state| state.as_mut()) } fn checkpoint(&mut self) -> Result { @@ -576,14 +825,17 @@ where state.check_consistent_with_problem(problem)?; // setup linear solver for first step - let bdf_callable = Rc::new(BdfCallable::new(problem)); + let bdf_callable = BdfCallable::new(problem); bdf_callable.set_c(state.h, self.alpha[state.order]); - let nonlinear_problem = SolverProblem::new_from_ode_problem(bdf_callable, problem); - self.nonlinear_solver.set_problem(&nonlinear_problem); + self.nonlinear_solver + .set_problem(&bdf_callable, problem.rtol, problem.atol.clone()); self.nonlinear_solver .convergence_mut() .set_max_iter(Self::NEWTON_MAXITER); + self.nonlinear_solver + .reset_jacobian(&bdf_callable, &state.y, state.t); + self.op = Some(bdf_callable); // setup root solver if let Some(root_fn) = problem.eqn.root() { @@ -602,25 +854,16 @@ where 
self.y_predict = ::zeros(nstates); } - // allocate internal state for sensitivities - if self.ode_problem.as_ref().unwrap().eqn_sens.is_some() { - state.check_sens_consistent_with_problem(problem)?; - let nparams = self.ode_problem.as_ref().unwrap().eqn.rhs().nparams(); - self.s_op = Some(BdfCallable::from_eqn( - self.ode_problem - .as_ref() - .unwrap() - .eqn_sens - .as_ref() - .unwrap(), - )); - - if self.s_deltas.len() != nparams || self.s_deltas[0].len() != nstates { - self.s_deltas = vec![::zeros(nstates); nparams]; - } - if self.s_predict.len() != nstates { - self.s_predict = ::zeros(nstates); - } + let nout = if let Some(out) = problem.eqn.out() { + out.nout() + } else { + 0 + }; + if self.g_delta.len() != nout { + self.g_delta = ::zeros(nout); + } + if self.gdiff_tmp.nrows() != nout { + self.gdiff_tmp = M::zeros(nout, BdfState::::MAX_ORDER + 3); } // init U matrix @@ -639,6 +882,10 @@ where if self.state.is_none() { return Err(ode_solver_error!(StateNotSet)); } + let problem = self.ode_problem.as_ref().unwrap(); + let integrate_out = problem.integrate_out; + let output_in_error_control = problem.output_in_error_control(); + let integrate_sens = self.s_op.is_some(); + let mut convergence_fail = false; @@ -653,11 +900,9 @@ where let order = self.state.as_ref().unwrap().order; self.y_delta.copy_from(&self.y_predict); - // initialise error_norm to quieten the compiler - error_norm = Eqn::T::from(2.0); - // solve BDF equation using y0 as starting point let mut solve_result = self.nonlinear_solver.solve_in_place( + self.op.as_ref().unwrap(), &mut self.y_delta, self.t_predict, &self.y_predict, @@ -674,27 +919,14 @@ where // and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} self.y_delta -= &self.y_predict; - // calculate error norm - { - let rtol = self.problem().as_ref().unwrap().rtol; - let atol = self.ode_problem.as_ref().unwrap().atol.as_ref(); - error_norm = - self.y_delta - .squared_norm(&self.state.as_mut().unwrap().y, atol, rtol) - * self.error_const2[order]; + // deal with output equations + if integrate_out && output_in_error_control { + self.calculate_output_delta(); } - // only bother doing sensitivity calculations if we might keep the step - if self.ode_problem.as_ref().unwrap().eqn_sens.is_some() - && error_norm <= Eqn::T::from(1.0) - { - error_norm = match self.sensitivity_solve(self.t_predict, error_norm) { - Ok(en) => en, - Err(_) => { - solve_result = Err(ode_solver_error!(SensitivitySolveFailed)); - Eqn::T::from(2.0) - } - } + // sensitivities + if integrate_sens && self.sensitivity_solve(self.t_predict).is_err() { + solve_result = Err(ode_solver_error!(SensitivitySolveFailed)); } } @@ -726,6 +958,8 @@ where continue; } + error_norm = self.error_control(); + // need to calculate safety even if step is accepted let maxiter = self.nonlinear_solver.convergence().max_iter() as f64; let niter = self.nonlinear_solver.convergence().niter() as f64; @@ -753,8 +987,9 @@ where self.statistics.number_of_error_test_failures += 1; } } + // take the accepted step - self.update_differences(); + self.update_differences_and_integrate_out(); { let state = self.state.as_mut().unwrap(); @@ -766,7 +1001,7 @@ where // update statistics self.statistics.number_of_linear_solver_setups = - self.nonlinear_problem_op().number_of_jac_evals(); + self.op.as_ref().unwrap().number_of_jac_evals(); self.statistics.number_of_steps += 1; self.jacobian_update.step(); @@ -777,41 +1012,17 @@ where if self.n_equal_steps > self.state.as_ref().unwrap().order { let factors = { let state = self.state.as_mut().unwrap(); - 
let atol = self.ode_problem.as_ref().unwrap().atol.as_ref(); - let rtol = self.ode_problem.as_ref().unwrap().rtol; let order = state.order; // similar to the optimal step size factor we calculated above for the current // order k, we need to calculate the optimal step size factors for orders // k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n let error_m_norm = if order > 1 { - let mut error_m_norm = - state.diff.column(order).squared_norm(&state.y, atol, rtol) - * self.error_const2[order - 1]; - for i in 0..state.sdiff.len() { - error_m_norm += - state.sdiff[i] - .column(order) - .squared_norm(&state.s[i], atol, rtol) - * self.error_const2[order - 1]; - } - error_m_norm / Eqn::T::from((state.sdiff.len() + 1) as f64) + self.predict_error_control(order - 1) } else { Eqn::T::INFINITY }; let error_p_norm = if order < BdfState::::MAX_ORDER { - let mut error_p_norm = state - .diff - .column(order + 2) - .squared_norm(&state.y, atol, rtol) - * self.error_const2[order + 1]; - for i in 0..state.sdiff.len() { - error_p_norm = - state.sdiff[i] - .column(order + 2) - .squared_norm(&state.s[i], atol, rtol) - * self.error_const2[order + 1]; - } - error_p_norm / Eqn::T::from((state.sdiff.len() + 1) as f64) + self.predict_error_control(order + 1) } else { Eqn::T::INFINITY }; @@ -905,6 +1116,72 @@ where } } +impl AugmentedOdeSolverMethod + for Bdf +where + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit, + M: DenseMatrix, + Nls: NonLinearSolver, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + fn set_augmented_problem( + &mut self, + state: BdfState, + problem: &OdeSolverProblem, + augmented_eqn: AugmentedEqn, + ) -> Result<(), DiffsolError> { + state.check_sens_consistent_with_problem(problem, &augmented_eqn)?; + + self.set_problem(state, problem)?; + + self.state + .as_mut() + .unwrap() + .set_augmented_problem(problem, &augmented_eqn)?; + + // allocate internal state for sensitivities + let naug = augmented_eqn.max_index(); + let nstates = problem.eqn.rhs().nstates(); + let augmented_eqn = Rc::new(augmented_eqn); + self.s_op = Some(BdfCallable::from_sensitivity_eqn(&augmented_eqn)); + + if self.s_deltas.len() != naug || self.s_deltas[0].len() != nstates { + self.s_deltas = vec![::zeros(nstates); naug]; + } + if self.s_predict.len() != nstates { + self.s_predict = ::zeros(nstates); + } + if let Some(out) = self.s_op.as_ref().unwrap().eqn().out() { + if self.sg_deltas.len() != naug || self.sg_deltas[0].len() != out.nout() { + self.sg_deltas = vec![::zeros(out.nout()); naug]; + } + if self.sgdiff_tmp.nrows() != out.nout() { + self.sgdiff_tmp = M::zeros(out.nout(), BdfState::::MAX_ORDER + 3); + } + } + Ok(()) + } +} + +impl AdjointOdeSolverMethod for Bdf +where + Eqn: OdeEquationsAdjoint, + AugmentedEqn: AugmentedOdeEquations + OdeEquationsAdjoint, + M: DenseMatrix, + Nls: NonLinearSolver, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + type AdjointSolver = Bdf, Nls, AdjointEquations>; + + fn new_adjoint_solver(&self) -> Self::AdjointSolver { + let adjoint_nls = Nls::default(); + Self::AdjointSolver::new(adjoint_nls) + } +} + #[cfg(test)] mod test { use crate::{ @@ -912,24 +1189,26 @@ mod test { test_models::{ dydt_y2::dydt_y2_problem, exponential_decay::{ - exponential_decay_problem, exponential_decay_problem_sens, - exponential_decay_problem_with_root, negative_exponential_decay_problem, + exponential_decay_problem, exponential_decay_problem_adjoint, + exponential_decay_problem_sens, 
exponential_decay_problem_with_root, + negative_exponential_decay_problem, }, exponential_decay_with_algebraic::{ + exponential_decay_with_algebraic_adjoint_problem, exponential_decay_with_algebraic_problem, exponential_decay_with_algebraic_problem_sens, }, foodweb::{foodweb_problem, FoodWebContext}, gaussian_decay::gaussian_decay_problem, heat2d::head2d_problem, - robertson::robertson, + robertson::{robertson, robertson_sens}, robertson_ode::robertson_ode, robertson_ode_with_sens::robertson_ode_with_sens, - robertson_sens::robertson_sens, }, tests::{ test_checkpointing, test_interpolate, test_no_set_problem, test_ode_solver, - test_state_mut, test_state_mut_on_problem, + test_ode_solver_adjoint, test_ode_solver_no_sens, test_state_mut, + test_state_mut_on_problem, }, }, Bdf, FaerSparseLU, NewtonNonlinearSolver, OdeEquations, Op, SparseColMat, @@ -963,18 +1242,18 @@ mod test { fn bdf_test_nalgebra_negative_exponential_decay() { let mut s = Bdf::default(); let (problem, soln) = negative_exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] fn bdf_test_nalgebra_exponential_decay() { let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 11 - number_of_steps: 41 + number_of_steps: 47 number_of_error_test_failures: 0 number_of_nonlinear_solver_iterations: 82 number_of_nonlinear_solver_fails: 0 @@ -984,6 +1263,7 @@ mod test { number_of_calls: 84 number_of_jac_muls: 2 number_of_matrix_evals: 1 + number_of_jac_adj_muls: 0 "###); } @@ -993,7 +1273,7 @@ mod test { let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = exponential_decay_problem::>(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] @@ -1007,11 +1287,11 @@ mod test { type M = faer::Mat; let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 11 - number_of_steps: 41 + number_of_steps: 47 number_of_error_test_failures: 0 number_of_nonlinear_solver_iterations: 82 number_of_nonlinear_solver_fails: 0 @@ -1021,27 +1301,73 @@ mod test { number_of_calls: 84 number_of_jac_muls: 2 number_of_matrix_evals: 1 + number_of_jac_adj_muls: 0 "###); } #[test] fn bdf_test_nalgebra_exponential_decay_sens() { - let mut s = Bdf::default(); + let mut s = Bdf::with_sensitivities(); let (problem, soln) = exponential_decay_problem_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 12 - number_of_steps: 55 + number_of_linear_solver_setups: 11 + number_of_steps: 44 number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 273 + number_of_nonlinear_solver_iterations: 217 number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 110 - 
number_of_jac_muls: 171 - number_of_matrix_evals: 2 + number_of_calls: 87 + number_of_jac_muls: 136 + number_of_matrix_evals: 1 + number_of_jac_adj_muls: 0 + "###); + } + + #[test] + fn bdf_test_nalgebra_exponential_decay_adjoint() { + let s = Bdf::default(); + let (problem, soln) = exponential_decay_problem_adjoint::(); + let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + --- + number_of_calls: 84 + number_of_jac_muls: 6 + number_of_matrix_evals: 3 + number_of_jac_adj_muls: 492 + "###); + insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 24 + number_of_steps: 86 + number_of_error_test_failures: 12 + number_of_nonlinear_solver_iterations: 486 + number_of_nonlinear_solver_fails: 0 + "###); + } + + #[test] + fn bdf_test_nalgebra_exponential_decay_algebraic_adjoint() { + let s = Bdf::default(); + let (problem, soln) = exponential_decay_with_algebraic_adjoint_problem::(); + let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + --- + number_of_calls: 190 + number_of_jac_muls: 24 + number_of_matrix_evals: 8 + number_of_jac_adj_muls: 278 + "###); + insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 32 + number_of_steps: 74 + number_of_error_test_failures: 15 + number_of_nonlinear_solver_iterations: 266 + number_of_nonlinear_solver_fails: 0 "###); } @@ -1049,20 +1375,21 @@ mod test { fn test_bdf_nalgebra_exponential_decay_algebraic() { let mut s = Bdf::default(); let (problem, soln) = exponential_decay_with_algebraic_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 16 - number_of_steps: 36 - number_of_error_test_failures: 2 - number_of_nonlinear_solver_iterations: 71 + number_of_linear_solver_setups: 20 + number_of_steps: 41 + number_of_error_test_failures: 4 + number_of_nonlinear_solver_iterations: 79 number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 75 + number_of_calls: 83 number_of_jac_muls: 6 number_of_matrix_evals: 2 + number_of_jac_adj_muls: 0 "###); } @@ -1072,27 +1399,28 @@ mod test { let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = exponential_decay_with_algebraic_problem::>(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] fn test_bdf_nalgebra_exponential_decay_algebraic_sens() { - let mut s = Bdf::default(); - let (problem, soln) = exponential_decay_with_algebraic_problem_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + let mut s = Bdf::with_sensitivities(); + let (problem, soln) = exponential_decay_with_algebraic_problem_sens::(); + test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 21 - number_of_steps: 49 - number_of_error_test_failures: 5 - number_of_nonlinear_solver_iterations: 163 + number_of_linear_solver_setups: 18 + number_of_steps: 43 + number_of_error_test_failures: 3 + number_of_nonlinear_solver_iterations: 155 
number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- number_of_calls: 71 - number_of_jac_muls: 108 + number_of_jac_muls: 100 number_of_matrix_evals: 3 + number_of_jac_adj_muls: 0 "###); } @@ -1100,20 +1428,21 @@ mod test { fn test_bdf_nalgebra_robertson() { let mut s = Bdf::default(); let (problem, soln) = robertson::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 79 - number_of_steps: 330 - number_of_error_test_failures: 1 - number_of_nonlinear_solver_iterations: 748 + number_of_linear_solver_setups: 77 + number_of_steps: 316 + number_of_error_test_failures: 3 + number_of_nonlinear_solver_iterations: 722 number_of_nonlinear_solver_fails: 19 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 751 + number_of_calls: 725 number_of_jac_muls: 60 number_of_matrix_evals: 20 + number_of_jac_adj_muls: 0 "###); } @@ -1123,7 +1452,7 @@ mod test { let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = robertson::>(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[cfg(feature = "suitesparse")] @@ -1133,7 +1462,7 @@ mod test { let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = robertson::>(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[cfg(feature = "diffsl-llvm")] @@ -1146,27 +1475,28 @@ mod test { let mut s = Bdf::default(); robertson::robertson_diffsl_compile(&mut context); let (problem, soln) = robertson::robertson_diffsl_problem::(&context, false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] fn test_bdf_nalgebra_robertson_sens() { - let mut s = Bdf::default(); - let (problem, soln) = robertson_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + let mut s = Bdf::with_sensitivities(); + let (problem, soln) = robertson_sens::(); + test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 193 - number_of_steps: 526 - number_of_error_test_failures: 42 - number_of_nonlinear_solver_iterations: 4021 - number_of_nonlinear_solver_fails: 66 + number_of_linear_solver_setups: 160 + number_of_steps: 410 + number_of_error_test_failures: 4 + number_of_nonlinear_solver_iterations: 3107 + number_of_nonlinear_solver_fails: 81 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 1259 - number_of_jac_muls: 3083 - number_of_matrix_evals: 60 + number_of_calls: 996 + number_of_jac_muls: 2495 + number_of_matrix_evals: 71 + number_of_jac_adj_muls: 0 "###); } @@ -1174,20 +1504,21 @@ mod test { fn test_bdf_nalgebra_robertson_colored() { let mut s = Bdf::default(); let (problem, soln) = robertson::(true); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 79 - number_of_steps: 330 - 
number_of_error_test_failures: 1 - number_of_nonlinear_solver_iterations: 748 + number_of_linear_solver_setups: 77 + number_of_steps: 316 + number_of_error_test_failures: 3 + number_of_nonlinear_solver_iterations: 722 number_of_nonlinear_solver_fails: 19 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 751 + number_of_calls: 725 number_of_jac_muls: 63 number_of_matrix_evals: 20 + number_of_jac_adj_muls: 0 "###); } @@ -1195,41 +1526,43 @@ mod test { fn test_bdf_nalgebra_robertson_ode() { let mut s = Bdf::default(); let (problem, soln) = robertson_ode::(false, 3); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 90 - number_of_steps: 412 - number_of_error_test_failures: 2 - number_of_nonlinear_solver_iterations: 908 + number_of_linear_solver_setups: 86 + number_of_steps: 416 + number_of_error_test_failures: 1 + number_of_nonlinear_solver_iterations: 911 number_of_nonlinear_solver_fails: 15 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 910 + number_of_calls: 913 number_of_jac_muls: 162 number_of_matrix_evals: 18 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_bdf_nalgebra_robertson_ode_sens() { - let mut s = Bdf::default(); + let mut s = Bdf::with_sensitivities(); let (problem, soln) = robertson_ode_with_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 176 - number_of_steps: 483 - number_of_error_test_failures: 8 - number_of_nonlinear_solver_iterations: 3638 - number_of_nonlinear_solver_fails: 92 + number_of_linear_solver_setups: 112 + number_of_steps: 467 + number_of_error_test_failures: 2 + number_of_nonlinear_solver_iterations: 3472 + number_of_nonlinear_solver_fails: 49 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 1170 - number_of_jac_muls: 2894 - number_of_matrix_evals: 79 + number_of_calls: 1041 + number_of_jac_muls: 2672 + number_of_matrix_evals: 45 + number_of_jac_adj_muls: 0 "###); } @@ -1237,20 +1570,21 @@ mod test { fn test_bdf_nalgebra_dydt_y2() { let mut s = Bdf::default(); let (problem, soln) = dydt_y2_problem::(false, 10); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 30 - number_of_steps: 160 - number_of_error_test_failures: 2 - number_of_nonlinear_solver_iterations: 356 - number_of_nonlinear_solver_fails: 2 + number_of_linear_solver_setups: 27 + number_of_steps: 161 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 355 + number_of_nonlinear_solver_fails: 3 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 358 - number_of_jac_muls: 40 - number_of_matrix_evals: 4 + number_of_calls: 357 + number_of_jac_muls: 50 + number_of_matrix_evals: 5 + number_of_jac_adj_muls: 0 "###); } @@ -1258,20 +1592,21 @@ mod test { fn test_bdf_nalgebra_dydt_y2_colored() { let mut s = Bdf::default(); let (problem, soln) = dydt_y2_problem::(true, 10); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut 
s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 30 - number_of_steps: 160 - number_of_error_test_failures: 2 - number_of_nonlinear_solver_iterations: 356 - number_of_nonlinear_solver_fails: 2 + number_of_linear_solver_setups: 27 + number_of_steps: 161 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 355 + number_of_nonlinear_solver_fails: 3 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 358 - number_of_jac_muls: 14 - number_of_matrix_evals: 4 + number_of_calls: 357 + number_of_jac_muls: 15 + number_of_matrix_evals: 5 + number_of_jac_adj_muls: 0 "###); } @@ -1279,20 +1614,21 @@ mod test { fn test_bdf_nalgebra_gaussian_decay() { let mut s = Bdf::default(); let (problem, soln) = gaussian_decay_problem::(false, 10); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 14 - number_of_steps: 60 + number_of_steps: 66 number_of_error_test_failures: 1 - number_of_nonlinear_solver_iterations: 124 + number_of_nonlinear_solver_iterations: 130 number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 126 + number_of_calls: 132 number_of_jac_muls: 20 number_of_matrix_evals: 2 + number_of_jac_adj_muls: 0 "###); } @@ -1302,20 +1638,21 @@ mod test { let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = head2d_problem::, 10>(); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 21 - number_of_steps: 173 + number_of_steps: 167 number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 343 + number_of_nonlinear_solver_iterations: 330 number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 346 - number_of_jac_muls: 135 - number_of_matrix_evals: 5 + number_of_calls: 333 + number_of_jac_muls: 128 + number_of_matrix_evals: 4 + number_of_jac_adj_muls: 0 "###); } @@ -1331,7 +1668,7 @@ mod test { let mut s = Bdf::, _, _>::new(nonlinear_solver); heat2d_diffsl_compile::, LlvmModule, 10>(&mut context); let (problem, soln) = heat2d::heat2d_diffsl_problem(&context); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] @@ -1341,13 +1678,13 @@ mod test { let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); let mut s = Bdf::, _, _>::new(nonlinear_solver); let (problem, soln) = foodweb_problem::, 10>(&foodweb_context); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 40 - number_of_steps: 148 - number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 330 + number_of_linear_solver_setups: 45 + number_of_steps: 161 + number_of_error_test_failures: 2 + number_of_nonlinear_solver_iterations: 355 number_of_nonlinear_solver_fails: 14 "###); } @@ -1364,21 +1701,21 @@ mod test { let mut s = Bdf::, _, 
_>::new(nonlinear_solver); foodweb::foodweb_diffsl_compile::, LlvmModule, 10>(&mut context); let (problem, soln) = foodweb::foodweb_diffsl_problem(&context); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] fn test_tstop_bdf() { let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, true); + test_ode_solver_no_sens(&mut s, &problem, soln, None, true); } #[test] fn test_root_finder_bdf() { let mut s = Bdf::default(); let (problem, soln) = exponential_decay_problem_with_root::(false); - let y = test_ode_solver(&mut s, &problem, soln, None, false); + let y = test_ode_solver_no_sens(&mut s, &problem, soln, None, false); assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); } } diff --git a/src/ode_solver/bdf_state.rs b/src/ode_solver/bdf_state.rs index 279ee95d..454bf76f 100644 --- a/src/ode_solver/bdf_state.rs +++ b/src/ode_solver/bdf_state.rs @@ -1,20 +1,35 @@ use crate::{ - error::DiffsolError, scalar::IndexType, scale, DenseMatrix, OdeEquations, OdeSolverProblem, - OdeSolverState, Op, Vector, VectorViewMut, + error::{DiffsolError, OdeSolverError}, + ode_solver_error, + scalar::IndexType, + scale, AugmentedOdeEquations, DenseMatrix, OdeEquations, OdeSolverProblem, OdeSolverState, Op, + StateRef, StateRefMut, Vector, VectorViewMut, }; use std::ops::MulAssign; +use super::state::StateCommon; + #[derive(Clone)] pub struct BdfState> { pub(crate) order: usize, pub(crate) diff: M, pub(crate) sdiff: Vec, + pub(crate) gdiff: M, + pub(crate) sgdiff: Vec, pub(crate) y: V, pub(crate) dy: V, + pub(crate) g: V, + pub(crate) dg: V, pub(crate) s: Vec, pub(crate) ds: Vec, + pub(crate) sg: Vec, + pub(crate) dsg: Vec, pub(crate) t: V::T, pub(crate) h: V::T, + pub(crate) diff_initialised: bool, + pub(crate) sdiff_initialised: bool, + pub(crate) gdiff_initialised: bool, + pub(crate) sgdiff_initialised: bool, } impl BdfState @@ -24,23 +39,45 @@ where { pub(crate) const MAX_ORDER: IndexType = 5; - pub fn initialise_diff_to_first_order(&mut self, has_sens: bool) { + pub fn initialise_diff_to_first_order(&mut self) { self.order = 1usize; - self.diff.column_mut(0).copy_from(&self.y); self.diff.column_mut(1).copy_from(&self.dy); self.diff.column_mut(1).mul_assign(scale(self.h)); - let nparams = self.s.len(); - if has_sens { - for i in 0..nparams { - let sdiff = &mut self.sdiff[i]; - let s = &self.s[i]; - let ds = &self.ds[i]; - sdiff.column_mut(0).copy_from(s); - sdiff.column_mut(1).copy_from(ds); - sdiff.column_mut(1).mul_assign(scale(self.h)); - } + self.diff_initialised = true; + } + + pub fn initialise_sdiff_to_first_order(&mut self) { + let naug = self.sdiff.len(); + for i in 0..naug { + let sdiff = &mut self.sdiff[i]; + let s = &self.s[i]; + let ds = &self.ds[i]; + sdiff.column_mut(0).copy_from(s); + sdiff.column_mut(1).copy_from(ds); + sdiff.column_mut(1).mul_assign(scale(self.h)); } + self.sdiff_initialised = true; + } + + pub fn initialise_gdiff_to_first_order(&mut self) { + self.gdiff.column_mut(0).copy_from(&self.g); + self.gdiff.column_mut(1).copy_from(&self.dg); + self.gdiff.column_mut(1).mul_assign(scale(self.h)); + self.gdiff_initialised = true; + } + + pub fn initialise_sgdiff_to_first_order(&mut self) { + let naug = self.sgdiff.len(); + for i in 0..naug { + let sgdiff = &mut self.sgdiff[i]; + let sg = &self.sg[i]; + let dsg = &self.dsg[i]; + sgdiff.column_mut(0).copy_from(sg); + sgdiff.column_mut(1).copy_from(dsg); + 
sgdiff.column_mut(1).mul_assign(scale(self.h)); + } + self.sgdiff_initialised = true; } } @@ -53,82 +90,151 @@ where &mut self, ode_problem: &OdeSolverProblem, ) -> Result<(), DiffsolError> { - let not_initialised = self.diff.ncols() == 0; let nstates = ode_problem.eqn.rhs().nstates(); - let nparams = ode_problem.eqn.rhs().nparams(); - let has_sens = ode_problem.eqn_sens.is_some(); - if not_initialised { - self.diff = M::zeros(nstates, Self::MAX_ORDER + 3); - if has_sens { - self.sdiff = vec![M::zeros(nstates, Self::MAX_ORDER + 3); nparams]; + if self.diff.nrows() != nstates { + return Err(ode_solver_error!(StateProblemMismatch)); + } + let expected_gdiff_len = if let Some(out) = ode_problem.eqn.out() { + if ode_problem.integrate_out { + out.nout() + } else { + 0 + } + } else { + 0 + }; + if self.gdiff.nrows() != expected_gdiff_len { + return Err(ode_solver_error!(StateProblemMismatch)); + } + if !self.diff_initialised { + self.initialise_diff_to_first_order(); + } + if !self.gdiff_initialised { + self.initialise_gdiff_to_first_order(); + } + Ok(()) + } + + fn set_augmented_problem>( + &mut self, + ode_problem: &OdeSolverProblem, + augmented_eqn: &AugmentedEqn, + ) -> Result<(), DiffsolError> { + let naug = augmented_eqn.max_index(); + let nstates = ode_problem.eqn.rhs().nstates(); + if self.sdiff.len() != naug || self.sdiff[0].nrows() != nstates { + return Err(ode_solver_error!(StateProblemMismatch)); + } + let (sgdiff_len, sgdiff_size) = if let Some(_out) = augmented_eqn.out() { + if let Some(out) = augmented_eqn.out() { + (naug, out.nout()) + } else { + (0, 0) } - self.initialise_diff_to_first_order(has_sens); + } else { + (0, 0) + }; + if self.sgdiff.len() != sgdiff_len + || (sgdiff_len > 0 && self.sgdiff[0].nrows() != sgdiff_size) + { + return Err(ode_solver_error!(StateProblemMismatch)); + } + if !self.sdiff_initialised { + self.initialise_sdiff_to_first_order(); + } + if !self.sgdiff_initialised { + self.initialise_sgdiff_to_first_order(); } Ok(()) } - fn new_internal_state(y: V, dy: V, s: Vec, ds: Vec, t: ::T, h: ::T) -> Self { + fn new_from_common(state: super::state::StateCommon) -> Self { + let StateCommon { + y, + dy, + g, + dg, + s, + ds, + sg, + dsg, + t, + h, + } = state; + let nstates = y.len(); + let diff = M::zeros(nstates, Self::MAX_ORDER + 3); + let sdiff = vec![M::zeros(nstates, Self::MAX_ORDER + 3); s.len()]; + let gdiff = M::zeros(g.len(), Self::MAX_ORDER + 3); + let sgdiff = if !sg.is_empty() { + vec![M::zeros(sg[0].len(), Self::MAX_ORDER + 3); sg.len()] + } else { + Vec::new() + }; Self { order: 1, - diff: M::zeros(0, 0), - sdiff: Vec::new(), + diff, + sdiff, + gdiff, + sgdiff, y, dy, + g, + dg, s, ds, + sg, + dsg, t, h, + diff_initialised: false, + sdiff_initialised: false, + gdiff_initialised: false, + sgdiff_initialised: false, } } - fn s(&self) -> &[V] { - self.s.as_slice() - } - fn s_mut(&mut self) -> &mut [V] { - &mut self.s - } - fn ds_mut(&mut self) -> &mut [V] { - &mut self.ds - } - fn ds(&self) -> &[V] { - self.ds.as_slice() - } - fn s_ds_mut(&mut self) -> (&mut [V], &mut [V]) { - (&mut self.s, &mut self.ds) - } - fn y(&self) -> &V { - &self.y - } - - fn y_mut(&mut self) -> &mut V { - &mut self.y - } - - fn dy(&self) -> &V { - &self.dy - } - - fn dy_mut(&mut self) -> &mut V { - &mut self.dy - } - - fn y_dy_mut(&mut self) -> (&mut V, &mut V) { - (&mut self.y, &mut self.dy) - } - - fn t(&self) -> V::T { - self.t - } - - fn t_mut(&mut self) -> &mut V::T { - &mut self.t + fn into_common(self) -> StateCommon { + StateCommon { + y: self.y, + dy: self.dy, + g: 
self.g, + dg: self.dg, + s: self.s, + ds: self.ds, + sg: self.sg, + dsg: self.dsg, + t: self.t, + h: self.h, + } } - fn h(&self) -> V::T { - self.h + fn as_ref(&self) -> StateRef { + StateRef { + y: &self.y, + dy: &self.dy, + g: &self.g, + dg: &self.dg, + s: &self.s, + ds: &self.ds, + sg: &self.sg, + dsg: &self.dsg, + t: self.t, + h: self.h, + } } - fn h_mut(&mut self) -> &mut V::T { - &mut self.h + fn as_mut(&mut self) -> StateRefMut { + StateRefMut { + y: &mut self.y, + dy: &mut self.dy, + g: &mut self.g, + dg: &mut self.dg, + s: &mut self.s, + ds: &mut self.ds, + sg: &mut self.sg, + dsg: &mut self.dsg, + t: &mut self.t, + h: &mut self.h, + } } } diff --git a/src/ode_solver/builder.rs b/src/ode_solver/builder.rs index 6f6c881e..6236b6b7 100644 --- a/src/ode_solver/builder.rs +++ b/src/ode_solver/builder.rs @@ -1,13 +1,11 @@ -use std::rc::Rc; - use crate::{ error::{DiffsolError, OdeSolverError}, ode_solver_error, vector::DefaultDenseMatrix, Closure, ClosureNoJac, ClosureWithSens, ConstantClosure, ConstantClosureWithSens, - LinearClosure, LinearClosureWithSens, Matrix, OdeEquations, OdeSolverProblem, Op, UnitCallable, - Vector, + LinearClosure, Matrix, OdeEquations, OdeSolverProblem, Op, UnitCallable, Vector, }; +use std::rc::Rc; use super::equations::OdeSolverEquations; @@ -17,10 +15,15 @@ pub struct OdeBuilder { h0: f64, rtol: f64, atol: Vec, + sens_atol: Option>, + sens_rtol: Option, + out_rtol: Option, + out_atol: Option>, + param_rtol: Option, + param_atol: Option>, p: Vec, use_coloring: bool, - sensitivities: bool, - sensitivities_error_control: bool, + integrate_out: bool, } impl Default for OdeBuilder { @@ -59,13 +62,12 @@ impl Default for OdeBuilder { /// let t = 0.4; /// let mut state = OdeSolverState::new(&problem, &solver).unwrap(); /// solver.set_problem(state, &problem); -/// while solver.state().unwrap().t() <= t { +/// while solver.state().unwrap().t <= t { /// solver.step().unwrap(); /// } /// let y = solver.interpolate(t); /// ``` /// - impl OdeBuilder { /// Create a new builder with default parameters: /// - t0 = 0.0 @@ -76,15 +78,22 @@ impl OdeBuilder { /// - use_coloring = false /// - constant_mass = false pub fn new() -> Self { + let default_atol = vec![1e-6]; + let default_rtol = 1e-6; Self { t0: 0.0, h0: 1.0, - rtol: 1e-6, - atol: vec![1e-6], + rtol: default_rtol, + atol: default_atol.clone(), p: vec![], use_coloring: false, - sensitivities: false, - sensitivities_error_control: false, + integrate_out: false, + out_rtol: Some(default_rtol), + out_atol: Some(default_atol.clone()), + param_rtol: Some(default_rtol), + param_atol: Some(default_atol.clone()), + sens_atol: Some(default_atol), + sens_rtol: Some(default_rtol), } } @@ -94,13 +103,52 @@ impl OdeBuilder { self } - pub fn sensitivities(mut self, sensitivities: bool) -> Self { - self.sensitivities = sensitivities; + pub fn sens_rtol(mut self, sens_rtol: Option) -> Self { + self.sens_rtol = sens_rtol; + self + } + + pub fn sens_atol(mut self, sens_atol: Option) -> Self + where + V: IntoIterator, + f64: From, + { + self.sens_atol = sens_atol.map(|atol| atol.into_iter().map(|x| f64::from(x)).collect()); + self + } + + pub fn out_rtol(mut self, out_rtol: Option) -> Self { + self.out_rtol = out_rtol; + self + } + + pub fn out_atol(mut self, out_atol: Option) -> Self + where + V: IntoIterator, + f64: From, + { + self.out_atol = out_atol.map(|atol| atol.into_iter().map(|x| f64::from(x)).collect()); + self + } + + pub fn param_rtol(mut self, param_rtol: Option) -> Self { + self.param_rtol = param_rtol; self } - pub fn 
sensitivities_error_control(mut self, sensitivities_error_control: bool) -> Self { - self.sensitivities_error_control = sensitivities_error_control; + pub fn param_atol(mut self, param_atol: Option) -> Self + where + V: IntoIterator, + f64: From, + { + self.param_atol = param_atol.map(|atol| atol.into_iter().map(|x| f64::from(x)).collect()); + self + } + + /// Set whether to integrate the output. + /// If true, the output will be integrated using the same method as the ODE. + pub fn integrate_out(mut self, integrate_out: bool) -> Self { + self.integrate_out = integrate_out; self } @@ -146,11 +194,19 @@ impl OdeBuilder { self } - fn build_atol(atol: Vec, nstates: usize) -> Result { + fn build_atol(atol: Vec, nstates: usize, ty: &str) -> Result { if atol.len() == 1 { Ok(V::from_element(nstates, V::T::from(atol[0]))) } else if atol.len() != nstates { - Err(ode_solver_error!(AtolLengthMismatch)) + Err(ode_solver_error!( + BuilderError, + format!( + "Invalid number of {} absolute tolerances. Expected 1 or {}, got {}.", + ty, + nstates, + atol.len() + ) + )) } else { let mut v = V::zeros(nstates); for (i, &a) in atol.iter().enumerate() { @@ -160,6 +216,32 @@ impl OdeBuilder { } } + #[allow(clippy::type_complexity)] + fn build_atols( + atol: Vec, + sens_atol: Option>, + out_atol: Option>, + param_atol: Option>, + nstates: usize, + nout: Option, + nparam: usize, + ) -> Result<(V, Option, Option, Option), DiffsolError> { + let atol = Self::build_atol(atol, nstates, "states")?; + let out_atol = match out_atol { + Some(out_atol) => Some(Self::build_atol(out_atol, nout.unwrap_or(0), "output")?), + None => None, + }; + let param_atol = match param_atol { + Some(param_atol) => Some(Self::build_atol(param_atol, nparam, "parameters")?), + None => None, + }; + let sens_atol = match sens_atol { + Some(sens_atol) => Some(Self::build_atol(sens_atol, nstates, "sensitivity")?), + None => None, + }; + Ok((atol, sens_atol, out_atol, param_atol)) + } + fn build_p(p: Vec) -> V { let mut v = V::zeros(p.len()); for (i, &p) in p.iter().enumerate() { @@ -243,16 +325,30 @@ impl OdeBuilder { let mass = Some(Rc::new(mass)); let rhs = Rc::new(rhs); let init = Rc::new(init); + let nparams = p.len(); + let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( + self.atol, + self.sens_atol, + self.out_atol, + self.param_atol, + nstates, + None, + nparams, + )?; let eqn = OdeSolverEquations::new(rhs, mass, None, init, None, p); - let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?; OdeSolverProblem::new( eqn, M::T::from(self.rtol), atol, + self.sens_rtol.map(M::T::from), + sens_atol, + self.out_rtol.map(M::T::from), + out_atol, + self.param_rtol.map(M::T::from), + param_atol, M::T::from(self.t0), M::T::from(self.h0), - false, - self.sensitivities_error_control, + self.integrate_out, ) } @@ -293,6 +389,7 @@ impl OdeBuilder { let t0 = M::T::from(self.t0); let y0 = init(&p, t0); let nstates = y0.len(); + let nparams = p.len(); let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone()); let out = Closure::new(out, out_jac, nstates, nout, p.clone()); let mut mass = LinearClosure::new(mass, nstates, nstates, p.clone()); @@ -306,123 +403,28 @@ impl OdeBuilder { let init = Rc::new(init); let out = Some(Rc::new(out)); let eqn = OdeSolverEquations::new(rhs, mass, None, init, out, p); - let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?; + let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( + self.atol, + self.sens_atol, + self.out_atol, + self.param_atol, + nstates, + Some(nout), + 
nparams, + )?; OdeSolverProblem::new( eqn, M::T::from(self.rtol), atol, + self.sens_rtol.map(M::T::from), + sens_atol, + self.out_rtol.map(M::T::from), + out_atol, + self.param_rtol.map(M::T::from), + param_atol, M::T::from(self.t0), M::T::from(self.h0), - false, - self.sensitivities_error_control, - ) - } - - /// Build an ODE problem with a mass matrix and sensitivities. - /// - /// # Arguments - /// - /// - `rhs`: Function of type Fn(x: &V, p: &V, t: S, y: &mut V) that computes the right-hand side of the ODE. - /// - `rhs_jac`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the Jacobian of the right-hand side with the vector v. - /// - `rhs_sens`: Function of type Fn(x: &V, p: &V, t: S, v: &V, y: &mut V) that computes the multiplication of the partial derivative of the rhs wrt the parameters, with the vector v. - /// - `mass`: Function of type Fn(v: &V, p: &V, t: S, beta: S, y: &mut V) that computes a gemv multiplication of the mass matrix with the vector v (i.e. y = M * v + beta * y). - /// - `mass_sens`: Function of type Fn(v: &V, p: &V, t: S, y: &mut V) that computes the multiplication of the partial derivative of the mass matrix wrt the parameters, with the vector v. - /// - `init`: Function of type Fn(p: &V, t: S) -> V that computes the initial state. - /// - `init_sens`: Function of type Fn(p: &V, t: S, y: &mut V) that computes the multiplication of the partial derivative of the initial state wrt the parameters, with the vector v. - /// - /// # Example - /// - /// ``` - /// use diffsol::OdeBuilder; - /// use nalgebra::DVector; - /// type M = nalgebra::DMatrix; - /// - /// // dy/dt = a y - /// // 0 = z - y - /// // y(0) = 0.1 - /// // z(0) = 0.1 - /// let problem = OdeBuilder::new() - /// .build_ode_with_mass_and_sens::( - /// |x, p, _t, y| { - /// y[0] = p[0] * x[0]; - /// y[1] = x[1] - x[0]; - /// }, - /// |x, p, _t, v, y| { - /// y[0] = p[0] * v[0]; - /// y[1] = v[1] - v[0]; - /// }, - /// |x, _p, _t, v, y| { - /// y[0] = v[0] * x[0]; - /// y[1] = 0.0; - /// }, - /// |x, _p, _t, beta, y| { - /// y[0] = x[0] + beta * y[0]; - /// y[1] = beta * y[1]; - /// }, - /// |x, p, t, v, y| { - /// y.fill(0.0); - /// }, - /// |p, _t| DVector::from_element(2, 0.1), - /// |p, t, v, y| { - /// y.fill(0.0); - /// } - /// ); - /// ``` - #[allow(clippy::type_complexity, clippy::too_many_arguments)] - pub fn build_ode_with_mass_and_sens( - self, - rhs: F, - rhs_jac: G, - rhs_sens: J, - mass: H, - mass_sens: L, - init: I, - init_sens: K, - ) -> Result< - OdeSolverProblem< - OdeSolverEquations< - M, - ClosureWithSens, - ConstantClosureWithSens, - LinearClosureWithSens, - >, - >, - DiffsolError, - > - where - M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V), - G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - H: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), - I: Fn(&M::V, M::T) -> M::V, - J: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - K: Fn(&M::V, M::T, &M::V, &mut M::V), - L: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - { - let p = Rc::new(Self::build_p(self.p)); - let t0 = M::T::from(self.t0); - let y0 = init(&p, t0); - let nstates = y0.len(); - let mut rhs = ClosureWithSens::new(rhs, rhs_jac, rhs_sens, nstates, nstates, p.clone()); - let mut mass = LinearClosureWithSens::new(mass, mass_sens, nstates, nstates, p.clone()); - let init = ConstantClosureWithSens::new(init, init_sens, nstates, nstates, p.clone()); - if self.use_coloring || M::is_sparse() { - rhs.calculate_sparsity(&y0, t0); - mass.calculate_sparsity(t0); - } - let mass = Some(Rc::new(mass)); - let 
rhs = Rc::new(rhs); - let init = Rc::new(init); - let eqn = OdeSolverEquations::new(rhs, mass, None, init, None, p); - let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?; - OdeSolverProblem::new( - eqn, - M::T::from(self.rtol), - atol, - M::T::from(self.t0), - M::T::from(self.h0), - true, - self.sensitivities_error_control, + self.integrate_out, ) } @@ -484,16 +486,30 @@ impl OdeBuilder { } let rhs = Rc::new(rhs); let init = Rc::new(init); + let nparams = p.len(); let eqn = OdeSolverEquations::new(rhs, None, None, init, None, p); - let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?; + let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( + self.atol, + self.sens_atol, + self.out_atol, + self.param_atol, + nstates, + None, + nparams, + )?; OdeSolverProblem::new( eqn, M::T::from(self.rtol), atol, + self.sens_rtol.map(M::T::from), + sens_atol, + self.out_rtol.map(M::T::from), + out_atol, + self.param_rtol.map(M::T::from), + param_atol, M::T::from(self.t0), M::T::from(self.h0), - false, - self.sensitivities_error_control, + self.integrate_out, ) } @@ -526,7 +542,6 @@ impl OdeBuilder { /// |p, t, v, y| y.fill(0.0), /// ); /// ``` - #[allow(clippy::type_complexity)] pub fn build_ode_with_sens( self, @@ -556,20 +571,35 @@ impl OdeBuilder { let init = ConstantClosureWithSens::new(init, init_sens, nstates, nstates, p.clone()); let mut rhs = ClosureWithSens::new(rhs, rhs_jac, rhs_sens, nstates, nstates, p.clone()); if self.use_coloring || M::is_sparse() { - rhs.calculate_sparsity(&y0, t0); + rhs.calculate_jacobian_sparsity(&y0, t0); + rhs.calculate_sens_sparsity(&y0, t0); } let rhs = Rc::new(rhs); let init = Rc::new(init); + let nparams = p.len(); let eqn = OdeSolverEquations::new(rhs, None, None, init, None, p); - let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?; + let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( + self.atol, + self.sens_atol, + self.out_atol, + self.param_atol, + nstates, + None, + nparams, + )?; OdeSolverProblem::new( eqn, M::T::from(self.rtol), atol, + self.sens_rtol.map(M::T::from), + sens_atol, + self.out_rtol.map(M::T::from), + out_atol, + self.param_rtol.map(M::T::from), + param_atol, M::T::from(self.t0), M::T::from(self.h0), - true, - self.sensitivities_error_control, + self.integrate_out, ) } @@ -609,7 +639,6 @@ impl OdeBuilder { /// 1, /// ); /// ``` - #[allow(clippy::type_complexity)] pub fn build_ode_with_root( self, @@ -649,16 +678,30 @@ impl OdeBuilder { } let rhs = Rc::new(rhs); let init = Rc::new(init); + let nparams = p.len(); let eqn = OdeSolverEquations::new(rhs, None, Some(root), init, None, p); - let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?; + let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( + self.atol, + self.sens_atol, + self.out_atol, + self.param_atol, + nstates, + None, + nparams, + )?; OdeSolverProblem::new( eqn, M::T::from(self.rtol), atol, + self.sens_rtol.map(M::T::from), + sens_atol, + self.out_rtol.map(M::T::from), + out_atol, + self.param_rtol.map(M::T::from), + param_atol, M::T::from(self.t0), M::T::from(self.h0), - false, - self.sensitivities_error_control, + self.integrate_out, ) } @@ -682,6 +725,39 @@ impl OdeBuilder { self.build_ode(rhs, rhs_jac, init) } + /// Build an ODE problem from a set of equations + pub fn build_from_eqn(self, eqn: Eqn) -> Result, DiffsolError> + where + Eqn: OdeEquations, + { + let nparams = eqn.rhs().nparams(); + let nstates = eqn.rhs().nstates(); + let nout = eqn.out().map(|out| out.nout()); + let (atol, sens_atol, out_atol, 
param_atol) = Self::build_atols( + self.atol, + self.sens_atol, + self.out_atol, + self.param_atol, + nstates, + nout, + nparams, + )?; + OdeSolverProblem::new( + eqn, + Eqn::T::from(self.rtol), + atol, + self.sens_rtol.map(Eqn::T::from), + sens_atol, + self.out_rtol.map(Eqn::T::from), + out_atol, + self.param_rtol.map(Eqn::T::from), + param_atol, + Eqn::T::from(self.t0), + Eqn::T::from(self.h0), + self.integrate_out, + ) + } + /// Build an ODE problem using the DiffSL language (requires either the `diffsl-cranelift` or `diffsl-llvm` features). /// The source code is provided as a string; please see the [DiffSL documentation](https://martinjrobins.github.io/diffsl/) for more information. #[cfg(feature = "diffsl")] @@ -695,17 +771,33 @@ impl OdeBuilder { { use crate::ode_solver::diffsl; let p = Self::build_p::<M::V>(self.p); + let nparams = p.len(); let mut eqn = diffsl::DiffSl::new(context, self.use_coloring || M::is_sparse()); + let nstates = eqn.rhs().nstates(); + let nout = eqn.out().map(|out| out.nout()); eqn.set_params(p); - let atol = Self::build_atol::<M::V>(self.atol, eqn.rhs().nstates())?; + let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( + self.atol, + self.sens_atol, + self.out_atol, + self.param_atol, + nstates, + nout, + nparams, + )?; OdeSolverProblem::new( eqn, self.rtol, atol, + self.sens_rtol.map(M::T::from), + sens_atol, + self.out_rtol.map(M::T::from), + out_atol, + self.param_rtol.map(M::T::from), + param_atol, self.t0, self.h0, - self.sensitivities, - self.sensitivities_error_control, + self.integrate_out, ) } } diff --git a/src/ode_solver/checkpointing.rs b/src/ode_solver/checkpointing.rs new file mode 100644 index 00000000..1c3cc467 --- /dev/null +++ b/src/ode_solver/checkpointing.rs @@ -0,0 +1,263 @@ +use std::cell::RefCell; + +use crate::{ + error::DiffsolError, other_error, OdeEquations, OdeSolverMethod, OdeSolverProblem, + OdeSolverState, Vector, +}; +use num_traits::One; + +pub struct HermiteInterpolator<V> +where + V: Vector, +{ + ys: Vec<V>, + ydots: Vec<V>, + ts: Vec<V::T>, +} + +impl<V> Default for HermiteInterpolator<V> +where + V: Vector, +{ + fn default() -> Self { + HermiteInterpolator { + ys: Vec::new(), + ydots: Vec::new(), + ts: Vec::new(), + } + } +} + +impl<V> HermiteInterpolator<V> +where + V: Vector, +{ + pub fn new(ys: Vec<V>, ydots: Vec<V>, ts: Vec<V::T>) -> Self { + HermiteInterpolator { ys, ydots, ts } + } + pub fn reset<Eqn, Method, State>( + &mut self, + problem: &OdeSolverProblem<Eqn>, + solver: &mut Method, + state0: &State, + state1: &State, + ) -> Result<(), DiffsolError> + where + Eqn: OdeEquations<T = V::T, V = V>, + Method: OdeSolverMethod<Eqn, State = State>, + State: OdeSolverState<V>, + { + let state0_ref = state0.as_ref(); + let state1_ref = state1.as_ref(); + self.ys.clear(); + self.ydots.clear(); + self.ts.clear(); + self.ys.push(state0_ref.y.clone()); + self.ydots.push(state0_ref.dy.clone()); + self.ts.push(state0_ref.t); + + solver.set_problem(state0.clone(), problem)?; + while solver.state().unwrap().t < state1_ref.t { + solver.step()?; + self.ys.push(solver.state().unwrap().y.clone()); + self.ydots.push(solver.state().unwrap().dy.clone()); + self.ts.push(solver.state().unwrap().t); + } + Ok(()) + } + + pub fn interpolate(&self, t: V::T, y: &mut V) -> Option<()> { + if t < self.ts[0] || t > self.ts[self.ts.len() - 1] { + return None; + } + if t == self.ts[0] { + y.copy_from(&self.ys[0]); + return Some(()); + } + let idx = self + .ts + .iter() + .position(|&t0| t0 > t) + .unwrap_or(self.ts.len() - 1); + let t0 = self.ts[idx - 1]; + let t1 = self.ts[idx]; + let h = t1 - t0; + let theta = (t - t0) / h; + let u0 = &self.ys[idx
- 1]; + let u1 = &self.ys[idx]; + let f0 = &self.ydots[idx - 1]; + let f1 = &self.ydots[idx]; + + y.copy_from(u0); + y.axpy(V::T::one(), u1, -V::T::one()); + y.axpy( + h * (theta - V::T::from(1.0)), + f0, + V::T::one() - V::T::from(2.0) * theta, + ); + y.axpy(h * theta, f1, V::T::one()); + y.axpy( + V::T::from(1.0) - theta, + u0, + theta * (theta - V::T::from(1.0)), + ); + y.axpy(theta, u1, V::T::one()); + Some(()) + } +} + +pub struct Checkpointing +where + Method: OdeSolverMethod, + Eqn: OdeEquations, +{ + checkpoints: Vec, + segment: RefCell>, + previous_segment: RefCell>>, + solver: RefCell, + pub(crate) problem: OdeSolverProblem, +} + +impl Checkpointing +where + Method: OdeSolverMethod, + Eqn: OdeEquations, +{ + pub fn new( + mut solver: Method, + start_idx: usize, + checkpoints: Vec, + segment: Option>, + ) -> Self { + if solver.problem().is_none() { + panic!("Solver must have a problem set"); + } + if checkpoints.len() < 2 { + panic!("Checkpoints must have at least 2 elements"); + } + if start_idx >= checkpoints.len() - 1 { + panic!("start_idx must be less than checkpoints.len() - 1"); + } + let problem = solver.problem().unwrap().clone(); + let segment = segment.unwrap_or_else(|| { + let mut segment = HermiteInterpolator::default(); + segment + .reset( + &problem, + &mut solver, + &checkpoints[start_idx], + &checkpoints[start_idx + 1], + ) + .unwrap(); + segment + }); + let segment = RefCell::new(segment); + let previous_segment = RefCell::new(None); + let solver = RefCell::new(solver); + Checkpointing { + checkpoints, + segment, + previous_segment, + solver, + problem, + } + } + + pub fn interpolate(&self, t: Eqn::T, y: &mut Eqn::V) -> Result<(), DiffsolError> { + { + let segment = self.segment.borrow(); + if segment.interpolate(t, y).is_some() { + return Ok(()); + } + } + + { + let previous_segment = self.previous_segment.borrow(); + if let Some(previous_segment) = previous_segment.as_ref() { + if previous_segment.interpolate(t, y).is_some() { + return Ok(()); + } + } + } + + // if t is before first segment or after last segment, return error + if t < self.checkpoints[0].as_ref().t + || t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t + { + return Err(other_error!("t is outside of the checkpoints")); + } + + // else find idx of segment + let idx = self + .checkpoints + .iter() + .skip(1) + .position(|state| state.as_ref().t > t) + .expect("t is not in checkpoints"); + if self.previous_segment.borrow().is_none() { + self.previous_segment + .replace(Some(HermiteInterpolator::default())); + } + let mut solver = self.solver.borrow_mut(); + let mut previous_segment = self.previous_segment.borrow_mut(); + let mut segment = self.segment.borrow_mut(); + previous_segment.as_mut().unwrap().reset( + &self.problem, + &mut *solver, + &self.checkpoints[idx], + &self.checkpoints[idx + 1], + )?; + std::mem::swap(&mut *segment, previous_segment.as_mut().unwrap()); + segment.interpolate(t, y).unwrap(); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use nalgebra::{DMatrix, DVector}; + + use crate::{ + ode_solver::test_models::robertson::robertson, Bdf, BdfState, OdeEquations, + OdeSolverMethod, OdeSolverState, Op, Vector, + }; + + use super::{Checkpointing, HermiteInterpolator}; + + #[test] + fn test_checkpointing() { + let mut solver = Bdf::default(); + let (problem, soln) = robertson::>(false); + let t_final = soln.solution_points.last().unwrap().t; + let n_steps = 30; + let state0: BdfState<_, _> = OdeSolverState::new(&problem, &solver).unwrap(); + solver.set_problem(state0.clone(), 
&problem).unwrap(); + let mut checkpoints = vec![state0]; + let mut i = 0; + let mut ys = Vec::new(); + let mut ts = Vec::new(); + let mut ydots = Vec::new(); + while solver.state().unwrap().t < t_final { + ts.push(solver.state().unwrap().t); + ys.push(solver.state().unwrap().y.clone()); + ydots.push(solver.state().unwrap().dy.clone()); + solver.step().unwrap(); + i += 1; + if i % n_steps == 0 && solver.state().unwrap().t < t_final { + checkpoints.push(solver.checkpoint().unwrap()); + ts.clear(); + ys.clear(); + ydots.clear(); + } + } + checkpoints.push(solver.checkpoint().unwrap()); + let segment = HermiteInterpolator::new(ys, ydots, ts); + let checkpointer = + Checkpointing::new(solver, checkpoints.len() - 2, checkpoints, Some(segment)); + let mut y = DVector::zeros(problem.eqn.rhs().nstates()); + for point in soln.solution_points.iter().rev() { + checkpointer.interpolate(point.t, &mut y).unwrap(); + y.assert_eq_norm(&point.state, &problem.atol, problem.rtol, 10.0); + } + } +} diff --git a/src/ode_solver/diffsl.rs b/src/ode_solver/diffsl.rs index 9889fb3a..f9e5c96d 100644 --- a/src/ode_solver/diffsl.rs +++ b/src/ode_solver/diffsl.rs @@ -3,11 +3,10 @@ use std::{cell::RefCell, rc::Rc}; use diffsl::{execution::module::CodegenModule, Compiler}; use crate::{ - error::DiffsolError, - jacobian::{find_non_zeros_linear, find_non_zeros_nonlinear, JacobianColoring}, - matrix::sparsity::MatrixSparsity, - op::{LinearOp, NonLinearOp, Op}, - ConstantOp, Matrix, OdeEquations, Vector, + error::DiffsolError, find_jacobian_non_zeros, find_matrix_non_zeros, + jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity, + op::nonlinear_op::NonLinearOpJacobian, ConstantOp, LinearOp, Matrix, NonLinearOp, OdeEquations, + Op, Vector, }; pub type T = f64; @@ -41,7 +40,7 @@ pub type T = f64; /// let t = 0.4; /// let state = OdeSolverState::new(&problem, &solver).unwrap(); /// solver.set_problem(state, &problem); -/// while solver.state().unwrap().t() <= t { +/// while solver.state().unwrap().t <= t { /// solver.step().unwrap(); /// } /// let y = solver.interpolate(t); @@ -188,7 +187,7 @@ impl<'a, M: Matrix, CG: CodegenModule> DiffSlRhs<'a, M, CG> { if use_coloring { let x0 = M::V::zeros(context.nstates); let t0 = 0.0; - let non_zeros = find_non_zeros_nonlinear(&ret, &x0, t0); + let non_zeros = find_jacobian_non_zeros(&ret, &x0, t0); ret.sparsity = Some( MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -212,7 +211,7 @@ impl<'a, M: Matrix, CG: CodegenModule> DiffSlMass<'a, M, CG> { if use_coloring { let t0 = 0.0; - let non_zeros = find_non_zeros_linear(&ret, t0); + let non_zeros = find_matrix_non_zeros(&ret, t0); ret.sparsity = Some( MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -318,7 +317,9 @@ impl, CG: CodegenModule> NonLinearOp for DiffSlRoot<'_, M, CG> y.as_mut_slice(), ); } +} +impl, CG: CodegenModule> NonLinearOpJacobian for DiffSlRoot<'_, M, CG> { fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) { y.fill(0.0); } @@ -337,6 +338,9 @@ impl, CG: CodegenModule> NonLinearOp for DiffSlOut<'_, M, CG> { .get_out(self.context.data.borrow().as_slice()); y.copy_from_slice(out); } +} + +impl, CG: CodegenModule> NonLinearOpJacobian for DiffSlOut<'_, M, CG> { fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { self.context.compiler.calc_out_grad( t, @@ -362,7 +366,9 @@ impl, CG: CodegenModule> NonLinearOp for 
DiffSlRhs<'_, M, CG> { y.as_mut_slice(), ); } +} +impl, CG: CodegenModule> NonLinearOpJacobian for DiffSlRhs<'_, M, CG> { fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { let mut dummy_rhs = Self::V::zeros(self.nstates()); self.context.compiler.rhs_grad( @@ -459,7 +465,8 @@ mod tests { use nalgebra::DVector; use crate::{ - Bdf, ConstantOp, LinearOp, NonLinearOp, OdeBuilder, OdeEquations, OdeSolverMethod, Vector, + Bdf, ConstantOp, LinearOp, NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations, + OdeSolverMethod, OdeSolverState, Vector, }; use super::{DiffSl, DiffSlContext}; @@ -531,7 +538,8 @@ mod tests { let problem = OdeBuilder::new().p([r, k]).build_diffsl(&context).unwrap(); let mut solver = Bdf::default(); let t = 1.0; - let (ys, ts) = solver.solve(&problem, t).unwrap(); + let state = OdeSolverState::new(&problem, &solver).unwrap(); + let (ys, ts) = solver.solve(&problem, state, t).unwrap(); for (i, t) in ts.iter().enumerate() { let y_expect = k / (1.0 + (k - y0) * (-r * t).exp() / y0); let z_expect = 2.0 * y_expect; @@ -541,7 +549,8 @@ mod tests { // do it again with some explicit t_evals let t_evals = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0]; - let ys = solver.solve_dense(&problem, &t_evals).unwrap(); + let state = OdeSolverState::new(&problem, &solver).unwrap(); + let ys = solver.solve_dense(&problem, state, &t_evals).unwrap(); for (i, t) in t_evals.iter().enumerate() { let y_expect = k / (1.0 + (k - y0) * (-r * t).exp() / y0); let z_expect = 2.0 * y_expect; diff --git a/src/ode_solver/equations.rs b/src/ode_solver/equations.rs index 138d16fc..0f16b9f1 100644 --- a/src/ode_solver/equations.rs +++ b/src/ode_solver/equations.rs @@ -1,9 +1,9 @@ use std::rc::Rc; use crate::{ - op::{unit::UnitCallable, ConstantOp}, - scalar::Scalar, - LinearOp, Matrix, NonLinearOp, Vector, + op::{constant_op::ConstantOpSensAdjoint, linear_op::LinearOpTranspose}, + ConstantOp, ConstantOpSens, LinearOp, Matrix, NonLinearOp, NonLinearOpAdjoint, + NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Scalar, UnitCallable, Vector, }; use serde::Serialize; @@ -34,6 +34,105 @@ impl OdeEquationsStatistics { } } +pub trait AugmentedOdeEquations: + OdeEquations +{ + fn update_rhs_out_state(&mut self, y: &Eqn::V, dy: &Eqn::V, t: Eqn::T); + fn update_init_state(&mut self, t: Eqn::T); + fn set_index(&mut self, index: usize); + fn max_index(&self) -> usize; + fn include_in_error_control(&self) -> bool; + fn include_out_in_error_control(&self) -> bool; + fn rtol(&self) -> Option; + fn atol(&self) -> Option<&Rc>; + fn out_rtol(&self) -> Option; + fn out_atol(&self) -> Option<&Rc>; +} + +pub trait AugmentedOdeEquationsImplicit: + AugmentedOdeEquations + OdeEquationsImplicit +{ +} + +impl AugmentedOdeEquationsImplicit for Aug +where + Aug: AugmentedOdeEquations + OdeEquationsImplicit, + Eqn: OdeEquations, +{ +} + +pub struct NoAug { + _phantom: std::marker::PhantomData, +} + +impl OdeEquations for NoAug { + type T = Eqn::T; + type V = Eqn::V; + type M = Eqn::M; + type Mass = Eqn::Mass; + type Rhs = Eqn::Rhs; + type Root = Eqn::Root; + type Init = Eqn::Init; + type Out = Eqn::Out; + + fn set_params(&mut self, _p: Self::V) { + panic!("This should never be called") + } + + fn rhs(&self) -> &Rc { + panic!("This should never be called") + } + + fn mass(&self) -> Option<&Rc> { + panic!("This should never be called") + } + + fn root(&self) -> Option<&Rc> { + panic!("This should never be called") + } + + fn out(&self) -> Option<&Rc> { + panic!("This should never be called") + } + + 
fn init(&self) -> &Rc { + panic!("This should never be called") + } +} + +impl AugmentedOdeEquations for NoAug { + fn update_rhs_out_state(&mut self, _y: &Eqn::V, _dy: &Eqn::V, _t: Eqn::T) { + panic!("This should never be called") + } + fn update_init_state(&mut self, _t: Eqn::T) { + panic!("This should never be called") + } + fn set_index(&mut self, _index: usize) { + panic!("This should never be called") + } + fn atol(&self) -> Option<&Rc<::V>> { + panic!("This should never be called") + } + fn include_out_in_error_control(&self) -> bool { + panic!("This should never be called") + } + fn out_atol(&self) -> Option<&Rc<::V>> { + panic!("This should never be called") + } + fn out_rtol(&self) -> Option<::T> { + panic!("This should never be called") + } + fn rtol(&self) -> Option<::T> { + panic!("This should never be called") + } + fn max_index(&self) -> usize { + panic!("This should never be called") + } + fn include_in_error_control(&self) -> bool { + panic!("This should never be called") + } +} + /// this is the trait that defines the ODE equations of the form /// /// $$ @@ -83,6 +182,56 @@ pub trait OdeEquations { fn init(&self) -> &Rc; } +pub trait OdeEquationsImplicit: + OdeEquations> +{ +} + +impl OdeEquationsImplicit for T where + T: OdeEquations> +{ +} + +pub trait OdeEquationsSens: + OdeEquationsImplicit< + Rhs: NonLinearOpSens, + Init: ConstantOpSens, +> +{ +} + +impl OdeEquationsSens for T where + T: OdeEquationsImplicit< + Rhs: NonLinearOpSens, + Init: ConstantOpSens, + > +{ +} + +pub trait OdeEquationsAdjoint: + OdeEquationsImplicit< + Rhs: NonLinearOpAdjoint + + NonLinearOpSensAdjoint, + Init: ConstantOpSensAdjoint, + Out: NonLinearOpAdjoint + + NonLinearOpSensAdjoint, + Mass: LinearOpTranspose, +> +{ +} + +impl OdeEquationsAdjoint for T where + T: OdeEquationsImplicit< + Rhs: NonLinearOpAdjoint + + NonLinearOpSensAdjoint, + Init: ConstantOpSensAdjoint, + Out: NonLinearOpAdjoint + + NonLinearOpSensAdjoint, + Mass: LinearOpTranspose, + > +{ +} + /// This struct implements the ODE equation trait [OdeEquations] for a given right-hand side op, mass op, optional root op, and initial condition function. /// /// While the [crate::OdeBuilder] struct is the easiest way to define an ODE problem, @@ -92,13 +241,13 @@ pub trait OdeEquations { /// The main traits that you need to implement are the [crate::Op] and [NonLinearOp] trait, /// which define a nonlinear operator or function `F` that maps an input vector `x` to an output vector `y`, (i.e. `y = F(x)`). /// Once you have implemented this trait, you can then pass an instance of your struct to the `rhs` argument of the [Self::new] method. -/// Once you have created an instance of [OdeSolverEquations], you can then use [crate::OdeSolverProblem::new] to create a problem. +/// Once you have created an instance of [OdeSolverEquations], you can then use [crate::OdeBuilder::build_from_eqn] to create a problem. 
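The value of this trait split is easiest to see from the consumer side. Below is a minimal sketch of how downstream code is meant to opt in to solver capabilities; it assumes the traits are re-exported at the crate root (as the benches in this patch do) and elides the associated-type bounds spelled out above:

```rust
use diffsol::{OdeEquations, OdeEquationsImplicit, OdeEquationsSens};

// Explicit methods can accept any equations...
fn explicit_method<Eqn: OdeEquations>(_eqn: &Eqn) {}

// ...implicit methods additionally require Jacobian-vector products
// on the right-hand side...
fn implicit_method<Eqn: OdeEquationsImplicit>(_eqn: &Eqn) {}

// ...and forward sensitivities additionally require parameter
// derivatives of the rhs and of the initial condition.
fn forward_sens_method<Eqn: OdeEquationsSens>(_eqn: &Eqn) {}
```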
/// /// For example: /// /// ```rust /// use std::rc::Rc; -/// use diffsol::{Bdf, OdeSolverState, OdeSolverMethod, NonLinearOp, OdeSolverEquations, OdeSolverProblem, Op, UnitCallable, ConstantClosure}; +/// use diffsol::{Bdf, OdeSolverState, OdeSolverMethod, NonLinearOp, NonLinearOpJacobian, OdeSolverEquations, OdeSolverProblem, Op, UnitCallable, ConstantClosure, OdeBuilder}; /// type M = nalgebra::DMatrix; /// type V = nalgebra::DVector; /// @@ -120,6 +269,8 @@ pub trait OdeEquations { /// fn call_inplace(&self, x: &V, _t: f64, y: &mut V) { /// y[0] = -0.1 * x[0]; /// } +/// } +/// impl NonLinearOpJacobian for MyProblem { /// fn jac_mul_inplace(&self, x: &V, _t: f64, v: &V, y: &mut V) { /// y[0] = -0.1 * v[0]; /// } @@ -141,19 +292,13 @@ pub trait OdeEquations { /// let p = Rc::new(V::from_vec(vec![])); /// let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p); /// -/// let rtol = 1e-6; -/// let atol = V::from_vec(vec![1e-6]); -/// let t0 = 0.0; -/// let h0 = 0.1; -/// let with_sensitivity = false; -/// let sensitivity_error_control = false; -/// let problem = OdeSolverProblem::new(eqn, rtol, atol, t0, h0, with_sensitivity, sensitivity_error_control).unwrap(); +/// let problem = OdeBuilder::new().build_from_eqn(eqn).unwrap(); /// /// let mut solver = Bdf::default(); /// let t = 0.4; /// let state = OdeSolverState::new(&problem, &solver).unwrap(); /// solver.set_problem(state, &problem); -/// while solver.state().unwrap().t() <= t { +/// while solver.state().unwrap().t <= t { /// solver.step().unwrap(); /// } /// let y = solver.interpolate(t); @@ -272,8 +417,7 @@ mod tests { use crate::ode_solver::test_models::exponential_decay::exponential_decay_problem; use crate::ode_solver::test_models::exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem; use crate::vector::Vector; - use crate::LinearOp; - use crate::NonLinearOp; + use crate::{LinearOp, NonLinearOp, NonLinearOpJacobian}; type Mcpu = nalgebra::DMatrix; type Vcpu = nalgebra::DVector; diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index 5b0e3c00..a542e3cc 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -1,11 +1,19 @@ -use nalgebra::ComplexField; +use std::cell::RefCell; +use std::rc::Rc; use crate::{ - error::DiffsolError, error::OdeSolverError, matrix::default_solver::DefaultSolver, - ode_solver_error, scalar::Scalar, DefaultDenseMatrix, DenseMatrix, Matrix, MatrixCommon, - NonLinearOp, OdeEquations, OdeSolverProblem, OdeSolverState, Op, VectorViewMut, + error::{DiffsolError, OdeSolverError}, + matrix::default_solver::DefaultSolver, + ode_solver_error, + scalar::Scalar, + AdjointContext, AdjointEquations, AugmentedOdeEquations, Checkpointing, DefaultDenseMatrix, + DenseMatrix, Matrix, NewtonNonlinearSolver, NonLinearOp, OdeEquations, OdeEquationsAdjoint, + OdeEquationsSens, OdeSolverProblem, OdeSolverState, Op, SensEquations, StateRef, StateRefMut, + Vector, VectorViewMut, }; +use super::checkpointing::HermiteInterpolator; + #[derive(Debug, PartialEq)] pub enum OdeSolverStopReason { InternalTimestep, @@ -22,22 +30,25 @@ pub enum OdeSolverStopReason { /// # Example /// /// ``` -/// use diffsol::{ OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeEquations, DefaultSolver }; +/// use diffsol::{ OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeEquationsImplicit, DefaultSolver }; /// /// fn solve_ode(solver: &mut impl OdeSolverMethod, problem: &OdeSolverProblem, t: Eqn::T) -> Eqn::V /// where -/// Eqn: OdeEquations, +/// Eqn: OdeEquationsImplicit, /// Eqn::M: 
DefaultSolver, /// { /// let state = OdeSolverState::new(problem, solver).unwrap(); /// solver.set_problem(state, problem); -/// while solver.state().unwrap().t() <= t { +/// while solver.state().unwrap().t <= t { /// solver.step().unwrap(); /// } /// solver.interpolate(t).unwrap() /// } /// ``` -pub trait OdeSolverMethod<Eqn: OdeEquations> { +pub trait OdeSolverMethod<Eqn: OdeEquations> +where + Self: Sized, +{ type State: OdeSolverState<Eqn::V>; /// Get the current problem if it has been set @@ -63,12 +74,12 @@ pub trait OdeSolverMethod<Eqn: OdeEquations> { fn take_state(&mut self) -> Option<Self::State>; /// Get the current state of the solver, if it exists - fn state(&self) -> Option<&Self::State>; + fn state(&self) -> Option<StateRef<Eqn::V>>; /// Get a mutable reference to the current state of the solver, if it exists /// Note that calling this will cause the next call to `step` to perform some reinitialisation to take into /// account the mutated state; this could be expensive for multi-step methods. - fn state_mut(&mut self) -> Option<&mut Self::State>; + fn state_mut(&mut self) -> Option<StateRefMut<Eqn::V>>; /// Step the solution forward by one step, altering the internal state of the solver. /// The return value is a `Result` containing the reason for stopping the solver, possible reasons are: @@ -84,19 +95,23 @@ pub trait OdeSolverMethod<Eqn: OdeEquations> { /// Interpolate the solution at a given time. This time should be between the current time and the last solver time step fn interpolate(&self, t: Eqn::T) -> Result<Eqn::V, DiffsolError>; + /// Interpolate the integral of the output function at a given time. This time should be between the current time and the last solver time step + fn interpolate_out(&self, t: Eqn::T) -> Result<Eqn::V, DiffsolError>; + /// Interpolate the sensitivity vectors at a given time. This time should be between the current time and the last solver time step fn interpolate_sens(&self, t: Eqn::T) -> Result<Vec<Eqn::V>, DiffsolError>; /// Get the current order of accuracy of the solver (e.g. the explicit Euler method is first-order) fn order(&self) -> usize; - /// Reinitialise the solver state and solve the problem up to time `final_time` + /// Using the provided state, solve the problem up to time `final_time` /// Returns a Vec of solution values at timepoints chosen by the solver. /// After the solver has finished, the internal state of the solver is at time `final_time`.
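Since `solve` now takes the state as an argument rather than constructing it internally, every call site gains one extra line. A minimal sketch of the new calling convention (the decay problem here is illustrative, not taken from the patch):

```rust
use diffsol::{Bdf, OdeBuilder, OdeSolverMethod, OdeSolverState};
type M = nalgebra::DMatrix<f64>;

// dy/dt = -0.1 * y, y(0) = 1.0
let problem = OdeBuilder::new()
    .build_ode::<M, _, _, _>(
        |x, _p, _t, y| y[0] = -0.1 * x[0],
        |_x, _p, _t, v, y| y[0] = -0.1 * v[0],
        |_p, _t| nalgebra::DVector::from_element(1, 1.0),
    )
    .unwrap();
let mut solver = Bdf::default();
// The caller now owns state creation, which is what makes the
// checkpointing and restarting added elsewhere in this patch possible.
let state = OdeSolverState::new(&problem, &solver).unwrap();
let (_ys, _ts) = solver.solve(&problem, state, 10.0).unwrap();
```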
#[allow(clippy::type_complexity)] fn solve( &mut self, problem: &OdeSolverProblem, + state: Self::State, final_time: Eqn::T, ) -> Result<(::M, Vec), DiffsolError> where @@ -104,74 +119,60 @@ pub trait OdeSolverMethod { Eqn::V: DefaultDenseMatrix, Self: Sized, { - let state = OdeSolverState::new(problem, self)?; self.set_problem(state, problem)?; - let mut ret_t = vec![self.state().unwrap().t()]; - let nstates = problem.eqn.rhs().nstates(); - let ntimes_guess = std::cmp::max( - 10, - ((final_time - self.state().unwrap().t()).abs() / self.state().unwrap().h()) - .into() - .ceil() as usize, - ); - let mut ret_y = <::M as Matrix>::zeros(nstates, ntimes_guess); - { - let mut y_i = ret_y.column_mut(0); + let mut ret_t = Vec::new(); + let mut ret_y = Vec::new(); + let mut write_out = |t: Eqn::T, y: &Eqn::V, g: &Eqn::V| { + ret_t.push(t); match problem.eqn.out() { Some(out) => { - y_i.copy_from(&out.call(self.state().unwrap().y(), self.state().unwrap().t())) + if problem.integrate_out { + ret_y.push(g.clone()); + } else { + ret_y.push(out.call(y, t)); + } } - None => y_i.copy_from(self.state().unwrap().y()), + None => ret_y.push(y.clone()), } - } + }; + + // do the main loop + write_out( + self.state().unwrap().t, + self.state().unwrap().y, + self.state().unwrap().g, + ); self.set_stop_time(final_time)?; while self.step()? != OdeSolverStopReason::TstopReached { - ret_t.push(self.state().unwrap().t()); - let mut y_i = { - let max_i = ret_y.ncols(); - let curr_i = ret_t.len() - 1; - if curr_i >= max_i { - ret_y = - <::M as Matrix>::zeros(nstates, max_i * 2); - } - ret_y.column_mut(curr_i) - }; - match problem.eqn.out() { - Some(out) => { - y_i.copy_from(&out.call(self.state().unwrap().y(), self.state().unwrap().t())) - } - None => y_i.copy_from(self.state().unwrap().y()), - } + write_out( + self.state().unwrap().t, + self.state().unwrap().y, + self.state().unwrap().g, + ); } // store the final step - ret_t.push(self.state().unwrap().t()); - { - let mut y_i = { - let max_i = ret_y.ncols(); - let curr_i = ret_t.len() - 1; - if curr_i >= max_i { - ret_y = - <::M as Matrix>::zeros(nstates, max_i + 1); - } - ret_y.column_mut(curr_i) - }; - match problem.eqn.out() { - Some(out) => { - y_i.copy_from(&out.call(self.state().unwrap().y(), self.state().unwrap().t())) - } - None => y_i.copy_from(self.state().unwrap().y()), - } + write_out( + self.state().unwrap().t, + self.state().unwrap().y, + self.state().unwrap().g, + ); + let ntimes = ret_t.len(); + let nrows = ret_y[0].len(); + let mut ret_y_matrix = <::M as Matrix>::zeros(nrows, ntimes); + for (i, y) in ret_y.iter().enumerate() { + ret_y_matrix.column_mut(i).copy_from(y); } - Ok((ret_y, ret_t)) + Ok((ret_y_matrix, ret_t)) } - /// Reinitialise the solver state and solve the problem up to time `t_eval[t_eval.len()-1]` + /// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]` /// Returns a Vec of solution values at timepoints given by `t_eval`. /// After the solver has finished, the internal state of the solver is at time `t_eval[t_eval.len()-1]`. 
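`solve_dense` follows the same pattern; continuing the sketch above (note that `t_eval` must be non-decreasing and must not start before the state's initial time, otherwise `InvalidTEval` is returned):

```rust
// Caller-chosen output times; one column of the result per time point.
// When `integrate_out` is set on the problem, the columns hold the
// integrated output instead of the state or output itself.
let t_eval = [0.0, 0.25, 0.5, 1.0];
let state = OdeSolverState::new(&problem, &solver).unwrap();
let ys = solver.solve_dense(&problem, state, &t_eval).unwrap();
assert_eq!(ys.ncols(), t_eval.len());
```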
fn solve_dense( &mut self, problem: &OdeSolverProblem, + state: Self::State, t_eval: &[Eqn::T], ) -> Result<::M, DiffsolError> where @@ -179,29 +180,45 @@ pub trait OdeSolverMethod { Eqn::V: DefaultDenseMatrix, Self: Sized, { - let state = OdeSolverState::new(problem, self)?; self.set_problem(state, problem)?; - let nstates = problem.eqn.rhs().nstates(); - let mut ret = <::M as Matrix>::zeros(nstates, t_eval.len()); + let nrows = if problem.eqn.out().is_some() { + problem.eqn.out().unwrap().nout() + } else { + problem.eqn.rhs().nstates() + }; + let mut ret = <::M as Matrix>::zeros(nrows, t_eval.len()); // check t_eval is increasing and all values are greater than or equal to the current time - let t0 = self.state().unwrap().t(); + let t0 = self.state().unwrap().t; if t_eval.windows(2).any(|w| w[0] > w[1] || w[0] < t0) { return Err(ode_solver_error!(InvalidTEval)); } + let mut write_out = |i: usize, y: &Eqn::V, g: Option<&Eqn::V>| { + let mut y_out = ret.column_mut(i); + if let Some(g) = g { + y_out.copy_from(g); + } else { + match problem.eqn.out() { + Some(out) => y_out.copy_from(&out.call(y, t_eval[i])), + None => y_out.copy_from(y), + } + } + }; + // do loop self.set_stop_time(t_eval[t_eval.len() - 1])?; let mut step_reason = OdeSolverStopReason::InternalTimestep; for (i, t) in t_eval.iter().take(t_eval.len() - 1).enumerate() { - while self.state().unwrap().t() < *t { + while self.state().unwrap().t < *t { step_reason = self.step()?; } let y = self.interpolate(*t)?; - let mut y_out = ret.column_mut(i); - match problem.eqn.out() { - Some(out) => y_out.copy_from(&out.call(&y, *t)), - None => y_out.copy_from(&y), + if problem.integrate_out { + let g = self.interpolate_out(*t)?; + write_out(i, &y, Some(&g)); + } else { + write_out(i, &y, None); } } @@ -209,15 +226,188 @@ pub trait OdeSolverMethod { while step_reason != OdeSolverStopReason::TstopReached { step_reason = self.step()?; } - { - let mut y_out = ret.column_mut(t_eval.len() - 1); - match problem.eqn.out() { - Some(out) => { - y_out.copy_from(&out.call(self.state().unwrap().y(), self.state().unwrap().t())) - } - None => y_out.copy_from(self.state().unwrap().y()), - } + if problem.integrate_out { + write_out( + t_eval.len() - 1, + self.state().unwrap().y, + Some(self.state().unwrap().g), + ); + } else { + write_out(t_eval.len() - 1, self.state().unwrap().y, None); } Ok(ret) } } + +pub trait AugmentedOdeSolverMethod: OdeSolverMethod +where + Eqn: OdeEquations, + AugmentedEqn: AugmentedOdeEquations, +{ + fn set_augmented_problem( + &mut self, + state: Self::State, + ode_problem: &OdeSolverProblem, + augmented_eqn: AugmentedEqn, + ) -> Result<(), DiffsolError>; +} + +pub trait SensitivitiesOdeSolverMethod: + AugmentedOdeSolverMethod> +where + Eqn: OdeEquationsSens, +{ + fn set_problem_with_sensitivities( + &mut self, + state: Self::State, + problem: &OdeSolverProblem, + ) -> Result<(), DiffsolError> { + let augmented_eqn = SensEquations::new(problem); + self.set_augmented_problem(state, problem, augmented_eqn) + } +} + +pub trait AdjointOdeSolverMethod: OdeSolverMethod +where + Eqn: OdeEquationsAdjoint, +{ + type AdjointSolver: AugmentedOdeSolverMethod< + AdjointEquations, + AdjointEquations, + State = Self::State, + >; + + fn new_adjoint_solver(&self) -> Self::AdjointSolver; + + fn into_adjoint_solver( + self, + checkpoints: Vec, + last_segment: HermiteInterpolator, + ) -> Result + where + Eqn::M: DefaultSolver, + { + // create the adjoint solver + let mut adjoint_solver = self.new_adjoint_solver(); + + let problem = self + .problem() 
+ .ok_or(ode_solver_error!(ProblemNotSet))? + .clone(); + let t = self.state().unwrap().t; + let h = self.state().unwrap().h; + + // construct checkpointing + let checkpointer = + Checkpointing::new(self, checkpoints.len() - 2, checkpoints, Some(last_segment)); + + // construct adjoint equations and problem + let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer))); + let new_eqn = AdjointEquations::new(&problem, context.clone(), false); + let mut new_augmented_eqn = AdjointEquations::new(&problem, context, true); + let adj_problem = OdeSolverProblem { + eqn: Rc::new(new_eqn), + rtol: problem.rtol, + atol: problem.atol, + t0: t, + h0: -h, + integrate_out: false, + sens_rtol: None, + sens_atol: None, + out_rtol: None, + out_atol: None, + param_rtol: None, + param_atol: None, + }; + + // initialise adjoint state + let mut state = + Self::State::new_without_initialise_augmented(&adj_problem, &mut new_augmented_eqn)?; + let mut init_nls = + NewtonNonlinearSolver::::LS>::default(); + let new_augmented_eqn = + state.set_consistent_augmented(&adj_problem, new_augmented_eqn, &mut init_nls)?; + + // set the adjoint problem + adjoint_solver.set_augmented_problem(state, &adj_problem, new_augmented_eqn)?; + Ok(adjoint_solver) + } +} + +#[cfg(test)] +mod test { + use crate::{ + ode_solver::test_models::exponential_decay::exponential_decay_problem, + ode_solver::test_models::exponential_decay::exponential_decay_problem_adjoint, scale, Bdf, + OdeSolverMethod, OdeSolverState, Vector, + }; + + #[test] + fn test_solve() { + let mut s = Bdf::default(); + let (problem, _soln) = exponential_decay_problem::>(false); + + let k = 0.1; + let y0 = nalgebra::DVector::from_vec(vec![1.0, 1.0]); + let expect = |t: f64| &y0 * scale(f64::exp(-k * t)); + let state = OdeSolverState::new(&problem, &s).unwrap(); + let (y, t) = s.solve(&problem, state, 10.0).unwrap(); + assert!((t[0] - 0.0).abs() < 1e-10); + assert!((t[t.len() - 1] - 10.0).abs() < 1e-10); + for (i, t_i) in t.iter().enumerate() { + let y_i = y.column(i).into_owned(); + y_i.assert_eq_norm(&expect(*t_i), problem.atol.as_ref(), problem.rtol, 15.0); + } + } + + #[test] + fn test_solve_integrate_out() { + let mut s = Bdf::default(); + let (problem, _soln) = exponential_decay_problem_adjoint::>(); + + let k = 0.1; + let y0 = nalgebra::DVector::from_vec(vec![1.0, 1.0]); + let t0 = 0.0; + let expect = |t: f64| { + let g = &y0 * scale((f64::exp(-k * t0) - f64::exp(-k * t)) / k); + nalgebra::DVector::::from_vec(vec![ + 1.0 * g[0] + 2.0 * g[1], + 3.0 * g[0] + 4.0 * g[1], + ]) + }; + let state = OdeSolverState::new(&problem, &s).unwrap(); + let (y, t) = s.solve(&problem, state, 10.0).unwrap(); + for (i, t_i) in t.iter().enumerate() { + let y_i = y.column(i).into_owned(); + y_i.assert_eq_norm(&expect(*t_i), problem.atol.as_ref(), problem.rtol, 15.0); + } + } + + #[test] + fn test_dense_solve() { + let mut s = Bdf::default(); + let (problem, soln) = exponential_decay_problem::>(false); + + let state = OdeSolverState::new(&problem, &s).unwrap(); + let t_eval = soln.solution_points.iter().map(|p| p.t).collect::>(); + let y = s.solve_dense(&problem, state, t_eval.as_slice()).unwrap(); + for (i, soln_pt) in soln.solution_points.iter().enumerate() { + let y_i = y.column(i).into_owned(); + y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0); + } + } + + #[test] + fn test_dense_solve_integrate_out() { + let mut s = Bdf::default(); + let (problem, soln) = exponential_decay_problem_adjoint::>(); + + let state = OdeSolverState::new(&problem, 
&s).unwrap(); + let t_eval = soln.solution_points.iter().map(|p| p.t).collect::<Vec<_>>(); + let y = s.solve_dense(&problem, state, t_eval.as_slice()).unwrap(); + for (i, soln_pt) in soln.solution_points.iter().enumerate() { + let y_i = y.column(i).into_owned(); + y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0); + } + } +} diff --git a/src/ode_solver/mod.rs b/src/ode_solver/mod.rs index 0ba8f0b7..7e8517c9 100644 --- a/src/ode_solver/mod.rs +++ b/src/ode_solver/mod.rs @@ -1,6 +1,8 @@ +pub mod adjoint_equations; pub mod bdf; pub mod bdf_state; pub mod builder; +pub mod checkpointing; pub mod equations; pub mod jacobian_update; pub mod method; @@ -23,33 +25,43 @@ mod tests { use std::rc::Rc; use self::problem::OdeSolverSolution; + use checkpointing::HermiteInterpolator; + use method::{AdjointOdeSolverMethod, SensitivitiesOdeSolverMethod}; use nalgebra::ComplexField; use super::*; use crate::matrix::Matrix; use crate::op::unit::UnitCallable; - use crate::op::{NonLinearOp, Op}; - use crate::{ConstantOp, DefaultDenseMatrix, DefaultSolver, Vector}; + use crate::{ConstantOp, DefaultDenseMatrix, DefaultSolver, NonLinearOp, Op, Vector}; use crate::{ - OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, + NonLinearOpJacobian, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit, + OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, }; use num_traits::One; use num_traits::Zero; pub fn test_ode_solver<M, Eqn>( - method: &mut impl OdeSolverMethod<Eqn>, + method: &mut impl SensitivitiesOdeSolverMethod<Eqn>, problem: &OdeSolverProblem<Eqn>, solution: OdeSolverSolution<M::V>, override_tol: Option<M::T>, use_tstop: bool, + solve_for_sensitivities: bool, ) -> Eqn::V where M: Matrix, - Eqn: OdeEquations<M = M, V = M::V, T = M::T>, + Eqn: OdeEquationsSens<M = M, V = M::V, T = M::T>, Eqn::M: DefaultSolver, { - let state = OdeSolverState::new(problem, method).unwrap(); - method.set_problem(state, problem).unwrap(); + if solve_for_sensitivities { + let state = OdeSolverState::new_with_sensitivities(problem, method).unwrap(); + method + .set_problem_with_sensitivities(state, problem) + .unwrap(); + } else { + let state = OdeSolverState::new(problem, method).unwrap(); + method.set_problem(state, problem).unwrap(); + } let have_root = problem.eqn.as_ref().root().is_some(); for (i, point) in solution.solution_points.iter().enumerate() { let (soln, sens_soln) = if use_tstop { @@ -58,24 +70,24 @@ mod tests { match method.step() { Ok(OdeSolverStopReason::RootFound(_)) => { assert!(have_root); - return method.state().unwrap().y().clone(); + return method.state().unwrap().y.clone(); } Ok(OdeSolverStopReason::TstopReached) => { break ( - method.state().unwrap().y().clone(), - method.state().unwrap().s().to_vec(), + method.state().unwrap().y.clone(), + method.state().unwrap().s.to_vec(), ); } _ => (), } }, Err(_) => ( - method.state().unwrap().y().clone(), - method.state().unwrap().s().to_vec(), + method.state().unwrap().y.clone(), + method.state().unwrap().s.to_vec(), ), } } else { - while method.state().unwrap().t().abs() < point.t.abs() { + while method.state().unwrap().t.abs() < point.t.abs() { if let OdeSolverStopReason::RootFound(t) = method.step().unwrap() { assert!(have_root); return method.interpolate(t).unwrap(); @@ -114,23 +126,204 @@ mod tests { soln, point.state ); - if let Some(sens_soln_points) = &solution.sens_solution_points { - for (j, sens_points) in sens_soln_points.iter().enumerate() { - let sens_point = &sens_points[i]; - let sens_soln = &sens_soln[j]; - let error = sens_soln.clone() - 
&sens_point.state; - let error_norm = error.squared_norm(&sens_point.state, atol, rtol).sqrt(); - assert!( - error_norm < M::T::from(24.0), - "error_norm: {} at t = {}", - error_norm, - point.t - ); + if solve_for_sensitivities { + if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() { + for (j, sens_points) in sens_soln_points.iter().enumerate() { + let sens_point = &sens_points[i]; + let sens_soln = &sens_soln[j]; + let error = sens_soln.clone() - &sens_point.state; + let error_norm = + error.squared_norm(&sens_point.state, atol, rtol).sqrt(); + assert!( + error_norm < M::T::from(29.0), + "error_norm: {} at t = {}", + error_norm, + point.t + ); + } + } } } } - method.state().unwrap().y().clone() + method.state().unwrap().y.clone() + } + + pub fn test_ode_solver_no_sens<M, Eqn>( + method: &mut impl OdeSolverMethod<Eqn>, + problem: &OdeSolverProblem<Eqn>, + solution: OdeSolverSolution<M::V>, + override_tol: Option<M::T>, + use_tstop: bool, + ) -> Eqn::V + where + M: Matrix, + Eqn: OdeEquationsImplicit<M = M, V = M::V, T = M::T>, + Eqn::M: DefaultSolver, + { + let state = OdeSolverState::new(problem, method).unwrap(); + method.set_problem(state, problem).unwrap(); + let have_root = problem.eqn.as_ref().root().is_some(); + for point in solution.solution_points.iter() { + let soln = if use_tstop { + match method.set_stop_time(point.t) { + Ok(_) => loop { + match method.step() { + Ok(OdeSolverStopReason::RootFound(_)) => { + assert!(have_root); + return method.state().unwrap().y.clone(); + } + Ok(OdeSolverStopReason::TstopReached) => { + break method.state().unwrap().y.clone(); + } + _ => (), + } + }, + Err(_) => method.state().unwrap().y.clone(), + } + } else { + while method.state().unwrap().t.abs() < point.t.abs() { + if let OdeSolverStopReason::RootFound(t) = method.step().unwrap() { + assert!(have_root); + return method.interpolate(t).unwrap(); + } + } + method.interpolate(point.t).unwrap() + }; + let soln = if let Some(out) = problem.eqn.out() { + out.call(&soln, point.t) + } else { + soln + }; + assert_eq!( + soln.len(), + point.state.len(), + "soln.len() != point.state.len()" + ); + if let Some(override_tol) = override_tol { + soln.assert_eq_st(&point.state, override_tol); + } else { + let (rtol, atol) = if problem.eqn.out().is_some() { + // the problem rtol and atol are defined on the state, so use the solution tolerances here + (solution.rtol, &solution.atol) + } else { + (problem.rtol, problem.atol.as_ref()) + }; + let error = soln.clone() - &point.state; + let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt(); + assert!( + error_norm < M::T::from(15.0), + "error_norm: {} at t = {}. soln: {:?}, expected: {:?}", + error_norm, + point.t, + soln, + point.state + ); + } + } + method.state().unwrap().y.clone() + } +
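+    // this helper drives the forward solve while saving a checkpoint roughly
+    // every 50 steps and recording (t, y, dy) for the final Hermite segment,
+    // then builds the adjoint solver via into_adjoint_solver, integrates it
+    // back to t0, and checks the corrected sg against the expected sensitivities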
+    pub fn test_ode_solver_adjoint<M, Eqn, Method>( + mut method: Method, + problem: &OdeSolverProblem<Eqn>, + solution: OdeSolverSolution<M::V>, + ) -> Method::AdjointSolver + where + M: Matrix, + Method: AdjointOdeSolverMethod<Eqn>, + Eqn: OdeEquationsAdjoint<M = M, V = M::V, T = M::T>, + Eqn::M: DefaultSolver, + { + let state = OdeSolverState::new(problem, &method).unwrap(); + method.set_problem(state, problem).unwrap(); + let t0 = solution.solution_points.first().unwrap().t; + let t1 = solution.solution_points.last().unwrap().t; + method.set_stop_time(t1).unwrap(); + let mut nsteps = 0; + let (rtol, atol) = (solution.rtol, &solution.atol); + let mut checkpoints = vec![method.checkpoint().unwrap()]; + let mut ts = Vec::new(); + let mut ys = Vec::new(); + let mut ydots = Vec::new(); + for point in solution.solution_points.iter() { + while method.state().unwrap().t.abs() < point.t.abs() { + ts.push(method.state().unwrap().t); + ys.push(method.state().unwrap().y.clone()); + ydots.push(method.state().unwrap().dy.clone()); + method.step().unwrap(); + nsteps += 1; + if nsteps > 50 && method.state().unwrap().t.abs() < t1.abs() { + checkpoints.push(method.checkpoint().unwrap()); + nsteps = 0; + ts.clear(); + ys.clear(); + ydots.clear(); + } + } + let soln = method.interpolate_out(point.t).unwrap(); + // the problem rtol and atol are defined on the state, so use the solution tolerances here + let error = soln.clone() - &point.state; + let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt(); + assert!( + error_norm < M::T::from(15.0), + "error_norm: {} at t = {}. soln: {:?}, expected: {:?}", + error_norm, + point.t, + soln, + point.state + ); + } + ts.push(method.state().unwrap().t); + ys.push(method.state().unwrap().y.clone()); + ydots.push(method.state().unwrap().dy.clone()); + checkpoints.push(method.checkpoint().unwrap()); + let last_segment = HermiteInterpolator::new(ys, ydots, ts); + let mut adjoint_solver = method + .into_adjoint_solver(checkpoints, last_segment) + .unwrap(); + let y_expect = M::V::from_element(problem.eqn.rhs().nstates(), M::T::zero()); + adjoint_solver + .state() + .unwrap() + .y + .assert_eq_st(&y_expect, M::T::from(1e-9)); + let g_expect = M::V::from_element(problem.eqn.rhs().nparams(), M::T::zero()); + for i in 0..problem.eqn.out().unwrap().nout() { + adjoint_solver.state().unwrap().sg[i].assert_eq_st(&g_expect, M::T::from(1e-9)); + } + + adjoint_solver.set_stop_time(t0).unwrap(); + while adjoint_solver.state().unwrap().t.abs() > t0 { + adjoint_solver.step().unwrap(); + } + let mut state = adjoint_solver.take_state().unwrap(); + let state_mut = state.as_mut(); + adjoint_solver + .problem() + .unwrap() + .eqn + .correct_sg_for_init(t0, state_mut.s, state_mut.sg); + + let points = solution + .sens_solution_points + .as_ref() + .unwrap() + .iter() + .map(|x| &x[0]) + .collect::<Vec<_>>(); + for (soln, point) in state_mut.sg.iter().zip(points.iter()) { + let error = soln.clone() - &point.state; + let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt(); + assert!( + error_norm < M::T::from(15.0), + "error_norm: {} at t = {}. 
soln: {:?}, expected: {:?}", + error_norm, + point.t, + soln, + point.state + ); + } + adjoint_solver } pub struct TestEqnInit { @@ -183,7 +376,9 @@ mod tests { fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) { y[0] = M::T::zero(); } + } + impl NonLinearOpJacobian for TestEqnRhs { fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) { y[0] = M::T::zero(); } @@ -245,23 +440,28 @@ mod tests { TestEqn::new(), M::T::from(1e-6), M::V::from_element(1, M::T::from(1e-6)), + None, + None, + None, + None, + None, + None, M::T::zero(), M::T::one(), false, - false, ) .unwrap(); - let state = Method::State::new_without_initialise(&problem); + let state = Method::State::new_without_initialise(&problem).unwrap(); s.set_problem(state.clone(), &problem).unwrap(); let t0 = M::T::zero(); let t1 = M::T::one(); s.interpolate(t0) .unwrap() - .assert_eq_st(state.y(), M::T::from(1e-9)); + .assert_eq_st(state.as_ref().y, M::T::from(1e-9)); assert!(s.interpolate(t1).is_err()); s.step().unwrap(); - assert!(s.interpolate(s.state().unwrap().t()).is_ok()); - assert!(s.interpolate(s.state().unwrap().t() + t1).is_err()); + assert!(s.interpolate(s.state().unwrap().t).is_ok()); + assert!(s.interpolate(s.state().unwrap().t + t1).is_err()); } pub fn test_no_set_problem>>(mut s: Method) { @@ -277,19 +477,24 @@ mod tests { TestEqn::new(), M::T::from(1e-6), M::V::from_element(1, M::T::from(1e-6)), + None, + None, + None, + None, + None, + None, M::T::zero(), M::T::one(), false, - false, ) .unwrap(); - let state = Method::State::new_without_initialise(&problem); + let state = Method::State::new_without_initialise(&problem).unwrap(); s.set_problem(state.clone(), &problem).unwrap(); let state2 = s.state().unwrap(); - state2.y().assert_eq_st(state.y(), M::T::from(1e-9)); - s.state_mut().unwrap().y_mut()[0] = M::T::from(std::f64::consts::PI); + state2.y.assert_eq_st(state.as_ref().y, M::T::from(1e-9)); + s.state_mut().unwrap().y[0] = M::T::from(std::f64::consts::PI); assert_eq!( - s.state_mut().unwrap().y_mut()[0], + s.state_mut().unwrap().y[0], M::T::from(std::f64::consts::PI) ); } @@ -302,13 +507,13 @@ mod tests { ) where M: Matrix + DefaultSolver, Method: OdeSolverMethod, - Problem: OdeEquations, + Problem: OdeEquationsImplicit, { let state = OdeSolverState::new(&problem, &solver1).unwrap(); solver1.set_problem(state, &problem).unwrap(); let half_i = soln.solution_points.len() / 2; let half_t = soln.solution_points[half_i].t; - while solver1.state().unwrap().t() <= half_t { + while solver1.state().unwrap().t <= half_t { solver1.step().unwrap(); } let checkpoint = solver1.checkpoint().unwrap(); @@ -316,20 +521,19 @@ mod tests { // carry on solving with both solvers, they should produce about the same results (probably might diverge a bit, but should always match the solution) for point in soln.solution_points.iter().skip(half_i + 1) { - while solver2.state().unwrap().t() < point.t { + while solver2.state().unwrap().t < point.t { solver1.step().unwrap(); solver2.step().unwrap(); - let time_error = (solver1.state().unwrap().t() - solver2.state().unwrap().t()) - .abs() - / (solver1.state().unwrap().t().abs() * problem.rtol + problem.atol[0]); + let time_error = (solver1.state().unwrap().t - solver2.state().unwrap().t).abs() + / (solver1.state().unwrap().t.abs() * problem.rtol + problem.atol[0]); assert!( time_error < M::T::from(20.0), "time_error: {} at t = {}", time_error, - solver1.state().unwrap().t() + solver1.state().unwrap().t ); - solver1.state().unwrap().y().assert_eq_norm( - 
solver2.state().unwrap().y(), + solver1.state().unwrap().y.assert_eq_norm( + solver2.state().unwrap().y, &problem.atol, problem.rtol, M::T::from(20.0), @@ -347,22 +551,23 @@ mod tests { problem: OdeSolverProblem, soln: OdeSolverSolution, ) where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, Method: OdeSolverMethod, Eqn::M: DefaultSolver, Eqn::V: DefaultDenseMatrix, { // solve for a little bit - s.solve(&problem, Eqn::T::from(1.0)).unwrap(); + let state = OdeSolverState::new(&problem, &s).unwrap(); + s.solve(&problem, state, Eqn::T::from(1.0)).unwrap(); // reinit using state_mut - let state = Method::State::new_without_initialise(&problem); - s.state_mut().unwrap().y_mut().copy_from(state.y()); - *s.state_mut().unwrap().t_mut() = state.t(); + let state = Method::State::new_without_initialise(&problem).unwrap(); + s.state_mut().unwrap().y.copy_from(state.as_ref().y); + *s.state_mut().unwrap().t = state.as_ref().t; // solve and check against solution for point in soln.solution_points.iter() { - while s.state().unwrap().t() < point.t { + while s.state().unwrap().t < point.t { s.step().unwrap(); } let soln = s.interpolate(point.t).unwrap(); diff --git a/src/ode_solver/problem.rs b/src/ode_solver/problem.rs index fd3bdaac..40fa74dd 100644 --- a/src/ode_solver/problem.rs +++ b/src/ode_solver/problem.rs @@ -4,7 +4,7 @@ use crate::{ error::{DiffsolError, OdeSolverError}, ode_solver_error, vector::Vector, - ConstantOp, LinearOp, NonLinearOp, OdeEquations, SensEquations, + OdeEquations, }; pub struct OdeSolverProblem { @@ -13,8 +13,13 @@ pub struct OdeSolverProblem { pub atol: Rc, pub t0: Eqn::T, pub h0: Eqn::T, - pub eqn_sens: Option>>, - pub sens_error_control: bool, + pub integrate_out: bool, + pub sens_rtol: Option, + pub sens_atol: Option>, + pub out_rtol: Option, + pub out_atol: Option>, + pub param_rtol: Option, + pub param_atol: Option>, } // impl clone @@ -26,8 +31,13 @@ impl Clone for OdeSolverProblem { atol: self.atol.clone(), t0: self.t0, h0: self.h0, - eqn_sens: self.eqn_sens.clone(), - sens_error_control: self.sens_error_control, + integrate_out: self.integrate_out, + out_atol: self.out_atol.clone(), + out_rtol: self.out_rtol, + param_atol: self.param_atol.clone(), + param_rtol: self.param_rtol, + sens_atol: self.sens_atol.clone(), + sens_rtol: self.sens_rtol, } } } @@ -39,39 +49,45 @@ impl OdeSolverProblem { pub fn default_atol(nstates: usize) -> Eqn::V { Eqn::V::from_element(nstates, Eqn::T::from(1e-6)) } - pub fn new( + pub fn output_in_error_control(&self) -> bool { + self.integrate_out + && self.eqn.out().is_some() + && self.out_rtol.is_some() + && self.out_atol.is_some() + } + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( eqn: Eqn, rtol: Eqn::T, atol: Eqn::V, + sens_rtol: Option, + sens_atol: Option, + out_rtol: Option, + out_atol: Option, + param_rtol: Option, + param_atol: Option, t0: Eqn::T, h0: Eqn::T, - with_sensitivity: bool, - sens_error_control: bool, + integrate_out: bool, ) -> Result { let eqn = Rc::new(eqn); let atol = Rc::new(atol); - let mass_has_sens = if let Some(mass) = eqn.mass() { - mass.has_sens() - } else { - true - }; - let eqn_has_sens = eqn.rhs().has_sens() && eqn.init().has_sens() && mass_has_sens; - if with_sensitivity && !eqn_has_sens { - return Err(ode_solver_error!(SensitivityNotSupported)); - } - let eqn_sens = if with_sensitivity { - Some(Rc::new(SensEquations::new(&eqn))) - } else { - None - }; + let out_atol = out_atol.map(Rc::new); + let param_atol = param_atol.map(Rc::new); + let sens_atol = sens_atol.map(Rc::new); Ok(Self { eqn, 
rtol, atol, + out_atol, + out_rtol, + param_atol, + param_rtol, + sens_atol, + sens_rtol, t0, h0, - eqn_sens, - sens_error_control, + integrate_out, }) } @@ -107,11 +123,10 @@ impl OdeSolverSolution { } fn get_index(&self, t: V::T) -> usize { if self.negative_time { - return self - .solution_points + self.solution_points .iter() .position(|x| x.t < t) - .unwrap_or(self.solution_points.len()); + .unwrap_or(self.solution_points.len()) } else { self.solution_points .iter() diff --git a/src/ode_solver/sdirk.rs b/src/ode_solver/sdirk.rs index 31462742..c666fabe 100644 --- a/src/ode_solver/sdirk.rs +++ b/src/ode_solver/sdirk.rs @@ -8,24 +8,47 @@ use std::rc::Rc; use crate::error::DiffsolError; use crate::error::OdeSolverError; use crate::matrix::MatrixRef; -use crate::nonlinear_solver::newton::newton_iteration; use crate::ode_solver_error; use crate::vector::VectorRef; +use crate::AdjointEquations; +use crate::DefaultDenseMatrix; +use crate::DefaultSolver; use crate::LinearSolver; use crate::NewtonNonlinearSolver; +use crate::NoAug; use crate::OdeSolverStopReason; use crate::RootFinder; use crate::SdirkState; use crate::SensEquations; use crate::Tableau; use crate::{ - nonlinear_solver::NonLinearSolver, op::sdirk::SdirkCallable, scale, solver::SolverProblem, - DenseMatrix, JacobianUpdate, NonLinearOp, OdeEquations, OdeSolverMethod, OdeSolverProblem, - OdeSolverState, Op, Scalar, Vector, VectorViewMut, + nonlinear_solver::NonLinearSolver, op::sdirk::SdirkCallable, scale, AdjointOdeSolverMethod, + AugmentedOdeEquations, DenseMatrix, JacobianUpdate, NonLinearOp, OdeEquations, + OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, + OdeSolverState, Op, Scalar, StateRef, StateRefMut, Vector, VectorViewMut, }; use super::bdf::BdfStatistics; use super::jacobian_update::SolverState; +use super::method::AugmentedOdeSolverMethod; +use super::method::SensitivitiesOdeSolverMethod; + +// make a few convenience type aliases +pub type SdirkAdj = Sdirk< + M, + AdjointEquations>, + LS, + AdjointEquations>, +>; +impl SensitivitiesOdeSolverMethod for Sdirk> +where + M: DenseMatrix, + LS: LinearSolver, + Eqn: OdeEquationsSens, + for<'a> &'a Eqn::V: VectorRef, + for<'a> &'a Eqn::M: MatrixRef, +{ +} /// A singly diagonally implicit Runge-Kutta method. Can optionally have an explicit first stage for ESDIRK methods. /// @@ -36,23 +59,28 @@ use super::jacobian_update::SolverState; /// - The upper triangular part of the `a` matrix must be zero (i.e. not fully implicit). /// - The diagonal of the `a` matrix must be the same non-zero value for all rows (i.e. an SDIRK method), except for the first row which can be zero for ESDIRK methods. /// - The last row of the `a` matrix must be the same as the `b` vector, and the last element of the `c` vector must be 1 (i.e. 
a stiffly accurate method) -pub struct Sdirk +pub struct Sdirk> where M: DenseMatrix, - LS: LinearSolver>, - Eqn: OdeEquations, + LS: LinearSolver, + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquations, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { tableau: Tableau, problem: Option>, - nonlinear_solver: NewtonNonlinearSolver, LS>, + nonlinear_solver: NewtonNonlinearSolver, + op: Option>, state: Option>, diff: M, sdiff: Vec, + sgdiff: Vec, + gdiff: M, + old_g: Eqn::V, gamma: Eqn::T, is_sdirk: bool, - s_op: Option>>, + s_op: Option>, old_t: Eqn::T, old_y: Eqn::V, old_y_sens: Vec, @@ -66,11 +94,79 @@ where jacobian_update: JacobianUpdate, } -impl Sdirk +impl Sdirk<::M, Eqn, ::LS, NoAug> +where + Eqn: OdeEquationsImplicit, + Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, + for<'a> &'a Eqn::V: VectorRef, + for<'a> &'a Eqn::M: MatrixRef, +{ + pub fn tr_bdf2() -> Self { + let tableau = Tableau::<::M>::tr_bdf2(); + let linear_solver = Eqn::M::default_solver(); + Self::new(tableau, linear_solver) + } + pub fn esdirk34() -> Self { + let tableau = Tableau::<::M>::esdirk34(); + let linear_solver = Eqn::M::default_solver(); + Self::new(tableau, linear_solver) + } +} + +impl + Sdirk<::M, Eqn, ::LS, SensEquations> +where + Eqn: OdeEquationsSens, + Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, + for<'a> &'a Eqn::V: VectorRef, + for<'a> &'a Eqn::M: MatrixRef, +{ + pub fn tr_bdf2_with_sensitivities() -> Self { + let tableau = Tableau::<::M>::tr_bdf2(); + let linear_solver = Eqn::M::default_solver(); + Self::new_common(tableau, linear_solver) + } + pub fn esdirk34_with_sensitivities() -> Self { + let tableau = Tableau::<::M>::esdirk34(); + let linear_solver = Eqn::M::default_solver(); + Self::new_common(tableau, linear_solver) + } +} + +impl Sdirk> where - LS: LinearSolver>, + LS: LinearSolver, M: DenseMatrix, - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, + for<'a> &'a Eqn::V: VectorRef, + for<'a> &'a Eqn::M: MatrixRef, +{ + pub fn new(tableau: Tableau, linear_solver: LS) -> Self { + Self::new_common(tableau, linear_solver) + } +} + +impl Sdirk> +where + LS: LinearSolver, + M: DenseMatrix, + Eqn: OdeEquationsSens, + for<'a> &'a Eqn::V: VectorRef, + for<'a> &'a Eqn::M: MatrixRef, +{ + pub fn new_with_sensitivities(tableau: Tableau, linear_solver: LS) -> Self { + Self::new_common(tableau, linear_solver) + } +} + +impl Sdirk +where + LS: LinearSolver, + M: DenseMatrix, + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquations, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { @@ -79,7 +175,7 @@ where const MAX_FACTOR: f64 = 10.0; const MIN_TIMESTEP: f64 = 1e-13; - pub fn new(tableau: Tableau, linear_solver: LS) -> Self { + fn new_common(tableau: Tableau, linear_solver: LS) -> Self { let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); // check that the upper triangular part of a is zero @@ -149,22 +245,29 @@ where let n = 1; let old_t = Eqn::T::zero(); let old_y = ::zeros(n); + let old_g = ::zeros(n); let old_f = ::zeros(n); let statistics = BdfStatistics::default(); let old_f_sens = Vec::new(); let old_y_sens = Vec::new(); let diff = M::zeros(n, s); let sdiff = Vec::new(); + let sgdiff = Vec::new(); + let gdiff = M::zeros(n, s); Self { old_y_sens, old_f_sens, + old_g, diff, sdiff, + sgdiff, tableau, nonlinear_solver, + op: None, state: None, problem: None, s_op: None, + gdiff, gamma, is_sdirk, old_t, @@ -194,7 +297,9 @@ where if abs(state.t - tstop) <= troundoff { self.tstop = None; return 
Ok(Some(OdeSolverStopReason::TstopReached)); - } else if tstop < state.t - troundoff { + } else if (state.h > M::T::zero() && tstop < state.t - troundoff) + || (state.h < M::T::zero() && tstop > state.t + troundoff) + { return Err(DiffsolError::from( OdeSolverError::StopTimeBeforeCurrentTime { stop_time: tstop.into(), @@ -204,10 +309,12 @@ where } // check if the next step will be beyond tstop, if so adjust the step size - if state.t + state.h > tstop + troundoff { + if (state.h > M::T::zero() && state.t + state.h > tstop + troundoff) + || (state.h < M::T::zero() && state.t + state.h < tstop - troundoff) + { let factor = (tstop - state.t) / state.h; state.h *= factor; - self.nonlinear_solver.problem().f.set_h(state.h); + self.op.as_mut().unwrap().set_h(state.h); } Ok(None) } @@ -227,43 +334,40 @@ where } fn solve_for_sensitivities(&mut self, i: usize, t: Eqn::T) -> Result<(), DiffsolError> { + let h = self.state.as_ref().unwrap().h; // update for new state { - self.problem() - .as_ref() + let op = self.s_op.as_mut().unwrap(); + Rc::get_mut(op.eqn_mut()) .unwrap() - .eqn_sens - .as_ref() - .unwrap() - .rhs() - .update_state(&self.old_y, &self.old_f, t); - } - - // reuse linear solver from nonlinear solver - let ls = |x: &mut Eqn::V| -> Result<(), DiffsolError> { - self.nonlinear_solver.solve_linearised_in_place(x) - }; + .update_rhs_out_state(&self.old_y, &self.old_f, t); - // construct bdf discretisation of sensitivity equations - let op = self.s_op.as_ref().unwrap(); - op.set_h(self.state.as_ref().unwrap().h); + // construct bdf discretisation of sensitivity equations + op.set_h(h); + } // solve for sensitivities equations discretised using sdirk equation - let fun = |x: &Eqn::V, y: &mut Eqn::V| op.call_inplace(x, t, y); - let mut convergence = self.nonlinear_solver.convergence().clone(); - let nparams = self.problem().as_ref().unwrap().eqn.rhs().nparams(); - for j in 0..nparams { + for j in 0..self.sdiff.len() { let s0 = &self.state.as_ref().unwrap().s[j]; + let op = self.s_op.as_mut().unwrap(); op.set_phi(&self.sdiff[j].columns(0, i), s0, &self.a_rows[i]); - op.eqn().as_ref().rhs().set_param_index(j); + Rc::get_mut(op.eqn_mut()).unwrap().set_index(j); let ds = &mut self.old_f_sens[j]; Self::predict_stage(i, &self.sdiff[j], ds, &self.tableau); // solve - { - newton_iteration(ds, &mut self.old_y_sens[j], s0, fun, ls, &mut convergence)?; - self.old_y_sens[j].copy_from(&op.get_last_f_eval()); - self.statistics.number_of_nonlinear_solver_iterations += convergence.niter(); + let op = self.s_op.as_ref().unwrap(); + self.nonlinear_solver.solve_in_place(op, ds, t, s0)?; + + self.old_y_sens[j].copy_from(&op.get_last_f_eval()); + self.statistics.number_of_nonlinear_solver_iterations += + self.nonlinear_solver.convergence().niter(); + + // calculate sdg and store in sgdiff + if let Some(out) = self.s_op.as_ref().unwrap().eqn().out() { + let dsg = &mut self.state.as_mut().unwrap().dsg[j]; + out.call_inplace(&self.old_y_sens[j], t, dsg); + self.sgdiff[j].column_mut(i).axpy(h, dsg, Eqn::T::zero()); } } Ok(()) @@ -304,14 +408,20 @@ where fn _jacobian_updates(&mut self, h: Eqn::T, state: SolverState) { if self.jacobian_update.check_rhs_jacobian_update(h, &state) { - self.nonlinear_solver.problem().f.set_jacobian_is_stale(); - self.nonlinear_solver - .reset_jacobian(&self.old_f, self.state.as_ref().unwrap().t); + self.op.as_mut().unwrap().set_jacobian_is_stale(); + self.nonlinear_solver.reset_jacobian( + self.op.as_ref().unwrap(), + &self.old_f, + self.state.as_ref().unwrap().t, + ); 
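+            // the callable is now owned by the solver itself (self.op) and is
+            // passed to the nonlinear solver on each call instead of being
+            // stored inside it, so one nonlinear solver can be re-pointed at
+            // different operators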
self.jacobian_update.update_rhs_jacobian(); self.jacobian_update.update_jacobian(h); } else if self.jacobian_update.check_jacobian_update(h, &state) { - self.nonlinear_solver - .reset_jacobian(&self.old_f, self.state.as_ref().unwrap().t); + self.nonlinear_solver.reset_jacobian( + self.op.as_ref().unwrap(), + &self.old_f, + self.state.as_ref().unwrap().t, + ); self.jacobian_update.update_jacobian(h); } } @@ -327,7 +437,7 @@ where } // update h for new step size - self.nonlinear_solver.problem().f.set_h(new_h); + self.op.as_mut().unwrap().set_h(new_h); // update state self.state.as_mut().unwrap().h = new_h; @@ -336,11 +446,12 @@ where } } -impl OdeSolverMethod for Sdirk +impl OdeSolverMethod for Sdirk where - LS: LinearSolver>, + LS: LinearSolver, M: DenseMatrix, - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquations, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { @@ -362,7 +473,7 @@ where if self.state.is_none() { return Err(ode_solver_error!(StateNotSet)); } - self._jacobian_updates(self.state.as_ref().unwrap().h(), SolverState::Checkpoint); + self._jacobian_updates(self.state.as_ref().unwrap().h, SolverState::Checkpoint); Ok(self.state.as_ref().unwrap().clone()) } @@ -372,17 +483,20 @@ where problem: &OdeSolverProblem, ) -> Result<(), DiffsolError> { // setup linear solver for first step - let callable = Rc::new(SdirkCallable::new(problem, self.gamma)); + let callable = SdirkCallable::new(problem, self.gamma); callable.set_h(state.h); self.jacobian_update.update_jacobian(state.h); self.jacobian_update.update_rhs_jacobian(); - let nonlinear_problem = SolverProblem::new_from_ode_problem(callable, problem); - self.nonlinear_solver.set_problem(&nonlinear_problem); + self.nonlinear_solver + .set_problem(&callable, problem.rtol, problem.atol.clone()); // set max iterations for nonlinear solver self.nonlinear_solver .convergence_mut() .set_max_iter(Self::NEWTON_MAXITER); + self.nonlinear_solver + .reset_jacobian(&callable, &state.y, state.t); + self.op = Some(callable); // update statistics self.statistics = BdfStatistics::default(); @@ -394,26 +508,21 @@ where if self.diff.nrows() != nstates || self.diff.ncols() != order { self.diff = M::zeros(nstates, order); } - if problem.eqn_sens.is_some() { - state.check_sens_consistent_with_problem(problem)?; - let nparams = problem.eqn.rhs().nparams(); - if self.sdiff.len() != nparams - || self.sdiff[0].nrows() != nstates - || self.sdiff[0].ncols() != order - { - self.sdiff = vec![M::zeros(nstates, order); nparams]; - self.old_f_sens = vec![::zeros(nstates); nparams]; - self.old_y_sens = state.s.clone(); - self.s_op = Some(SdirkCallable::from_eqn( - problem.eqn_sens.as_ref().unwrap().clone(), - self.gamma, - )); - } + let gdiff_rows = if problem.integrate_out { + problem.eqn.out().unwrap().nout() + } else { + 0 + }; + if self.gdiff.nrows() != gdiff_rows || self.gdiff.ncols() != order { + self.gdiff = M::zeros(gdiff_rows, order); } self.old_f = state.dy.clone(); self.old_t = state.t; self.old_y = state.y.clone(); + if problem.integrate_out { + self.old_g = state.g.clone(); + } state.set_problem(problem)?; self.state = Some(state); @@ -441,6 +550,31 @@ where // dont' reset jacobian for the first attempt at the step let mut error = ::zeros(n); + let out_error_control = self.problem().as_ref().unwrap().output_in_error_control(); + let mut out_error = if out_error_control { + ::zeros(self.problem().as_ref().unwrap().eqn.out().unwrap().nout()) + } else { + ::zeros(0) + }; + let sens_error_control = + 
self.s_op.is_some() && self.s_op.as_ref().unwrap().eqn().include_in_error_control(); + let mut sens_error = if sens_error_control { + ::zeros(self.s_op.as_ref().unwrap().eqn().rhs().nstates()) + } else { + ::zeros(0) + }; + let sens_out_error_control = self.s_op.is_some() + && self + .s_op + .as_ref() + .unwrap() + .eqn() + .include_out_in_error_control(); + let mut sens_out_error = if sens_out_error_control { + ::zeros(self.s_op.as_ref().unwrap().eqn().out().unwrap().nout()) + } else { + ::zeros(0) + }; let mut factor: Eqn::T; @@ -459,7 +593,7 @@ where } // sensitivities too - if self.problem().as_ref().unwrap().eqn_sens.is_some() { + if self.s_op.is_some() { for (diff, dy) in self .sdiff .iter_mut() @@ -469,12 +603,29 @@ where hf.copy_from(dy); hf *= scale(h); } + for (diff, dg) in self + .sgdiff + .iter_mut() + .zip(self.state.as_ref().unwrap().dsg.iter()) + { + let mut hf = diff.column_mut(0); + hf.copy_from(dg); + hf *= scale(h); + } + } + + // output function + if self.problem.as_ref().unwrap().integrate_out { + let state = self.state.as_ref().unwrap(); + let mut hf = self.gdiff.column_mut(0); + hf.copy_from(&state.dg); + hf *= scale(h); } } for i in start..self.tableau.s() { let t = t0 + self.tableau.c()[i] * h; - self.nonlinear_solver.problem().f.set_phi( + self.op.as_mut().unwrap().set_phi( &self.diff.columns(0, i), &self.state.as_ref().unwrap().y, &self.a_rows[i], @@ -483,6 +634,7 @@ where Self::predict_stage(i, &self.diff, &mut self.old_f, &self.tableau); let mut solve_result = self.nonlinear_solver.solve_in_place( + self.op.as_ref().unwrap(), &mut self.old_f, t, &self.state.as_ref().unwrap().y, @@ -494,8 +646,8 @@ where if solve_result.is_ok() { // old_y now has the new y soln and old_f has the new dy soln self.old_y - .copy_from(&self.nonlinear_solver.problem().f.get_last_f_eval()); - if self.problem().as_ref().unwrap().eqn_sens.is_some() { + .copy_from(&self.op.as_ref().unwrap().get_last_f_eval()); + if self.s_op.is_some() { solve_result = self.solve_for_sensitivities(i, t); } } @@ -519,7 +671,18 @@ where // update diff with solved dy self.diff.column_mut(i).copy_from(&self.old_f); - if self.problem().as_ref().unwrap().eqn_sens.is_some() { + // calculate dg and store in gdiff + if self.problem.as_ref().unwrap().integrate_out { + let out = self.problem.as_ref().unwrap().eqn.out().unwrap(); + out.call_inplace(&self.old_y, t, &mut self.state.as_mut().unwrap().dg); + self.gdiff.column_mut(i).axpy( + h, + &self.state.as_mut().unwrap().dg, + Eqn::T::zero(), + ); + } + + if self.s_op.is_some() { for (diff, old_f_sens) in self.sdiff.iter_mut().zip(self.old_f_sens.iter()) { diff.column_mut(i).copy_from(old_f_sens); } @@ -533,19 +696,62 @@ where let atol = self.problem().as_ref().unwrap().atol.as_ref(); let rtol = self.problem().as_ref().unwrap().rtol; let mut error_norm = error.squared_norm(&self.old_y, atol, rtol); + let mut ncontributions = 1; + + // output errors + if out_error_control { + self.gdiff.gemv( + Eqn::T::one(), + self.tableau.d(), + Eqn::T::zero(), + &mut out_error, + ); + let atol = self.problem().as_ref().unwrap().out_atol.as_ref().unwrap(); + let rtol = self.problem().as_ref().unwrap().out_rtol.unwrap(); + let out_error_norm = out_error.squared_norm(&self.old_g, atol, rtol); + error_norm += out_error_norm; + ncontributions += 1; + } // sensitivity errors - if self.problem().as_ref().unwrap().eqn_sens.is_some() - && self.problem().as_ref().unwrap().sens_error_control - { + if sens_error_control { + let atol = self.s_op.as_ref().unwrap().eqn().atol().unwrap(); + let 
rtol = self.s_op.as_ref().unwrap().eqn().rtol().unwrap(); for i in 0..self.sdiff.len() { - self.sdiff[i].gemv(Eqn::T::one(), self.tableau.d(), Eqn::T::zero(), &mut error); - let sens_error_norm = error.squared_norm(&self.old_y_sens[i], atol, rtol); + self.sdiff[i].gemv( + Eqn::T::one(), + self.tableau.d(), + Eqn::T::zero(), + &mut sens_error, + ); + let sens_error_norm = sens_error.squared_norm(&self.old_y_sens[i], atol, rtol); error_norm += sens_error_norm; + ncontributions += 1; } - error_norm /= Eqn::T::from((self.sdiff.len() + 1) as f64); } + // sensitivity output errors + if sens_out_error_control { + let atol = self.s_op.as_ref().unwrap().eqn().out_atol().unwrap(); + let rtol = self.s_op.as_ref().unwrap().eqn().out_rtol().unwrap(); + for i in 0..self.sgdiff.len() { + self.sgdiff[i].gemv( + Eqn::T::one(), + self.tableau.d(), + Eqn::T::zero(), + &mut sens_out_error, + ); + let sens_error_norm = sens_out_error.squared_norm( + &self.state.as_ref().unwrap().sg[i], + atol, + rtol, + ); + error_norm += sens_error_norm; + ncontributions += 1; + } + } + error_norm /= Eqn::T::from(ncontributions as f64); + // adjust step size based on error let maxiter = self.nonlinear_solver.convergence().max_iter() as f64; let niter = self.nonlinear_solver.convergence().niter() as f64; @@ -588,6 +794,22 @@ where std::mem::swap(&mut self.old_f_sens[i], &mut state.ds[i]); std::mem::swap(&mut self.old_y_sens[i], &mut state.s[i]); } + + for i in 0..self.sgdiff.len() { + self.sgdiff[i].gemv( + Eqn::T::one(), + self.tableau.b(), + Eqn::T::one(), + &mut state.sg[i], + ); + } + + // integrate output function + if self.problem.as_ref().unwrap().integrate_out { + self.old_g.copy_from(&state.g); + self.gdiff + .gemv(Eqn::T::one(), self.tableau.b(), Eqn::T::one(), &mut state.g); + } } // update step size for next step @@ -598,7 +820,7 @@ where // update statistics self.statistics.number_of_linear_solver_setups = - self.nonlinear_solver.problem().f.number_of_jac_evals(); + self.op.as_ref().unwrap().number_of_jac_evals(); self.statistics.number_of_steps += 1; self.jacobian_update.step(); @@ -730,13 +952,113 @@ where } } - fn state(&self) -> Option<&SdirkState> { - self.state.as_ref() + fn interpolate_out(&self, t: ::T) -> Result<::V, DiffsolError> { + if self.state.is_none() { + return Err(ode_solver_error!(StateNotSet)); + } + let state = self.state.as_ref().unwrap(); + + if self.is_state_mutated { + if t == state.t { + return Ok(state.g.clone()); + } else { + return Err(ode_solver_error!(InterpolationTimeOutsideCurrentStep)); + } + } + + // check that t is within the current step depending on the direction + let is_forward = state.h > Eqn::T::zero(); + if (is_forward && (t > state.t || t < self.old_t)) + || (!is_forward && (t < state.t || t > self.old_t)) + { + return Err(ode_solver_error!(InterpolationTimeOutsideCurrentStep)); + } + + let dt = state.t - self.old_t; + let theta = if dt == Eqn::T::zero() { + Eqn::T::one() + } else { + (t - self.old_t) / dt + }; + + if let Some(beta) = self.tableau.beta() { + let beta_f = Self::interpolate_beta_function(theta, beta); + let ret = Self::interpolate_from_diff(&self.old_g, &beta_f, &self.gdiff); + Ok(ret) + } else { + let ret = Self::interpolate_hermite(theta, &self.old_g, &state.g, &self.gdiff); + Ok(ret) + } } - fn state_mut(&mut self) -> Option<&mut SdirkState> { + fn state(&self) -> Option> { + self.state.as_ref().map(|s| s.as_ref()) + } + + fn state_mut(&mut self) -> Option> { self.is_state_mutated = true; - self.state.as_mut() + self.state.as_mut().map(|s| s.as_mut()) + 
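+        // note: handing out state_mut sets is_state_mutated, after which
+        // interpolate_out above only answers queries at exactly t == state.t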
} +} + +impl AugmentedOdeSolverMethod + for Sdirk +where + LS: LinearSolver, + M: DenseMatrix, + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquations, + for<'a> &'a Eqn::V: VectorRef, + for<'a> &'a Eqn::M: MatrixRef, +{ + fn set_augmented_problem( + &mut self, + state: Self::State, + ode_problem: &OdeSolverProblem, + augmented_eqn: AugmentedEqn, + ) -> Result<(), DiffsolError> { + state.check_sens_consistent_with_problem(ode_problem, &augmented_eqn)?; + self.set_problem(state, ode_problem)?; + let naug = augmented_eqn.max_index(); + let nstates = augmented_eqn.rhs().nstates(); + let order = self.tableau.s(); + if self.sdiff.len() != naug + || self.sdiff[0].nrows() != nstates + || self.sdiff[0].ncols() != order + { + self.sdiff = vec![M::zeros(nstates, order); naug]; + self.old_f_sens = vec![::zeros(nstates); naug]; + self.old_y_sens = self.state.as_ref().unwrap().s.clone(); + } + if let Some(out) = augmented_eqn.out() { + if self.sgdiff.len() != naug + || self.sgdiff[0].nrows() != out.nout() + || self.sgdiff[0].ncols() != order + { + self.sgdiff = vec![M::zeros(out.nout(), order); naug]; + } + } + let augmented_eqn = Rc::new(augmented_eqn); + self.s_op = Some(SdirkCallable::from_eqn(augmented_eqn, self.gamma)); + Ok(()) + } +} + +impl AdjointOdeSolverMethod for Sdirk +where + Eqn: OdeEquationsAdjoint, + AugmentedEqn: AugmentedOdeEquations + OdeEquationsAdjoint, + M: DenseMatrix, + LS: LinearSolver, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + type AdjointSolver = Sdirk, LS, AdjointEquations>; + + fn new_adjoint_solver(&self) -> Self::AdjointSolver { + let tableau = self.tableau.clone(); + let linear_solver = LS::default(); + Self::AdjointSolver::new_common(tableau, linear_solver) } } @@ -746,47 +1068,44 @@ mod test { ode_solver::{ test_models::{ exponential_decay::{ - exponential_decay_problem, exponential_decay_problem_sens, - exponential_decay_problem_with_root, negative_exponential_decay_problem, + exponential_decay_problem, exponential_decay_problem_adjoint, + exponential_decay_problem_sens, exponential_decay_problem_with_root, + negative_exponential_decay_problem, }, + exponential_decay_with_algebraic::exponential_decay_with_algebraic_adjoint_problem, heat2d::head2d_problem, - robertson::robertson, + robertson::{robertson, robertson_sens}, robertson_ode::robertson_ode, - robertson_sens::robertson_sens, }, tests::{ test_checkpointing, test_interpolate, test_no_set_problem, test_ode_solver, - test_state_mut, test_state_mut_on_problem, + test_ode_solver_adjoint, test_ode_solver_no_sens, test_state_mut, + test_state_mut_on_problem, }, }, - FaerSparseLU, NalgebraLU, OdeEquations, Op, Sdirk, SparseColMat, Tableau, + OdeEquations, Op, Sdirk, SparseColMat, }; - use faer::Mat; use num_traits::abs; type M = nalgebra::DMatrix; #[test] fn sdirk_no_set_problem() { - let tableau = Tableau::::tr_bdf2(); - test_no_set_problem::(Sdirk::::new(tableau, NalgebraLU::default())); + test_no_set_problem::(Sdirk::tr_bdf2()); } #[test] fn sdirk_state_mut() { - let tableau = Tableau::::tr_bdf2(); - test_state_mut::(Sdirk::::new(tableau, NalgebraLU::default())); + test_state_mut::(Sdirk::tr_bdf2()); } #[test] fn sdirk_test_interpolate() { - let tableau = Tableau::::tr_bdf2(); - test_interpolate::(Sdirk::::new(tableau, NalgebraLU::default())); + test_interpolate::(Sdirk::tr_bdf2()); } #[test] fn sdirk_test_checkpointing() { - let tableau = Tableau::::tr_bdf2(); - let s1 = Sdirk::::new(tableau.clone(), NalgebraLU::default()); - let s2 = Sdirk::::new(tableau, 
NalgebraLU::default()); + let s1 = Sdirk::tr_bdf2(); + let s2 = Sdirk::tr_bdf2(); let (problem, soln) = exponential_decay_problem::(false); test_checkpointing(s1, s2, problem, soln); } @@ -794,25 +1113,22 @@ mod test { #[test] fn sdirk_test_state_mut_exponential_decay() { let (p, soln) = exponential_decay_problem::(false); - let tableau = Tableau::::tr_bdf2(); - let s = Sdirk::::new(tableau, NalgebraLU::default()); + let s = Sdirk::tr_bdf2(); test_state_mut_on_problem(s, p, soln); } #[test] fn sdirk_test_nalgebra_negative_exponential_decay() { - let tableau = Tableau::::esdirk34(); - let mut s = Sdirk::::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::esdirk34(); let (problem, soln) = negative_exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] fn test_tr_bdf2_nalgebra_exponential_decay() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::tr_bdf2(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 4 @@ -826,37 +1142,37 @@ mod test { number_of_calls: 118 number_of_jac_muls: 2 number_of_matrix_evals: 1 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_tr_bdf2_nalgebra_exponential_decay_sens() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::tr_bdf2_with_sensitivities(); let (problem, soln) = exponential_decay_problem_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 7 - number_of_steps: 55 + number_of_steps: 52 number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 550 + number_of_nonlinear_solver_iterations: 520 number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 222 - number_of_jac_muls: 336 + number_of_calls: 210 + number_of_jac_muls: 318 number_of_matrix_evals: 2 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_esdirk34_nalgebra_exponential_decay() { - let tableau = Tableau::::esdirk34(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::esdirk34(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 3 @@ -870,125 +1186,169 @@ mod test { number_of_calls: 86 number_of_jac_muls: 2 number_of_matrix_evals: 1 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_esdirk34_nalgebra_exponential_decay_sens() { - let tableau = Tableau::::esdirk34(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::esdirk34_with_sensitivities(); let (problem, soln) = exponential_decay_problem_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 5 - number_of_steps: 21 + number_of_steps: 20 
number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 333 + number_of_nonlinear_solver_iterations: 317 number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 128 - number_of_jac_muls: 211 + number_of_calls: 122 + number_of_jac_muls: 201 number_of_matrix_evals: 1 + number_of_jac_adj_muls: 0 + "###); + } + + #[test] + fn sdirk_test_esdirk34_exponential_decay_adjoint() { + let s = Sdirk::esdirk34(); + let (problem, soln) = exponential_decay_problem_adjoint::(); + let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + --- + number_of_calls: 196 + number_of_jac_muls: 6 + number_of_matrix_evals: 3 + number_of_jac_adj_muls: 599 + "###); + insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 18 + number_of_steps: 29 + number_of_error_test_failures: 10 + number_of_nonlinear_solver_iterations: 595 + number_of_nonlinear_solver_fails: 0 + "###); + } + + #[test] + fn sdirk_test_esdirk34_exponential_decay_algebraic_adjoint() { + let s = Sdirk::esdirk34(); + let (problem, soln) = exponential_decay_with_algebraic_adjoint_problem::(); + let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); + insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" + --- + number_of_calls: 171 + number_of_jac_muls: 12 + number_of_matrix_evals: 4 + number_of_jac_adj_muls: 287 + "###); + insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" + --- + number_of_linear_solver_setups: 18 + number_of_steps: 20 + number_of_error_test_failures: 11 + number_of_nonlinear_solver_iterations: 278 + number_of_nonlinear_solver_fails: 0 "###); } #[test] fn test_tr_bdf2_nalgebra_robertson() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::tr_bdf2(); let (problem, soln) = robertson::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 97 - number_of_steps: 234 + number_of_steps: 232 number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 1965 - number_of_nonlinear_solver_fails: 13 + number_of_nonlinear_solver_iterations: 1921 + number_of_nonlinear_solver_fails: 18 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 1968 + number_of_calls: 1924 number_of_jac_muls: 36 number_of_matrix_evals: 12 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_tr_bdf2_nalgebra_robertson_sens() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); - let (problem, soln) = robertson_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + let mut s = Sdirk::tr_bdf2_with_sensitivities(); + let (problem, soln) = robertson_sens::(); + test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 115 - number_of_steps: 245 - number_of_error_test_failures: 3 - number_of_nonlinear_solver_iterations: 4990 + number_of_linear_solver_setups: 112 + number_of_steps: 216 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 4529 number_of_nonlinear_solver_fails: 37 "###); 
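+        // number_of_jac_adj_muls counts adjoint (transposed) Jacobian products;
+        // a forward-sensitivity solve performs none, so it is zero below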
insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 1581 - number_of_jac_muls: 3569 - number_of_matrix_evals: 28 + number_of_calls: 1420 + number_of_jac_muls: 3277 + number_of_matrix_evals: 27 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_esdirk34_nalgebra_robertson() { - let tableau = Tableau::::esdirk34(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::esdirk34(); let (problem, soln) = robertson::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 87 - number_of_steps: 137 + number_of_linear_solver_setups: 100 + number_of_steps: 141 number_of_error_test_failures: 0 - number_of_nonlinear_solver_iterations: 1773 - number_of_nonlinear_solver_fails: 15 + number_of_nonlinear_solver_iterations: 1793 + number_of_nonlinear_solver_fails: 24 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 1776 - number_of_jac_muls: 45 - number_of_matrix_evals: 15 + number_of_calls: 1796 + number_of_jac_muls: 54 + number_of_matrix_evals: 18 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_esdirk34_nalgebra_robertson_sens() { - let tableau = Tableau::::esdirk34(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); - let (problem, soln) = robertson_sens::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + let mut s = Sdirk::esdirk34_with_sensitivities(); + let (problem, soln) = robertson_sens::(); + test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- - number_of_linear_solver_setups: 123 - number_of_steps: 156 - number_of_error_test_failures: 2 - number_of_nonlinear_solver_iterations: 5194 - number_of_nonlinear_solver_fails: 49 + number_of_linear_solver_setups: 114 + number_of_steps: 131 + number_of_error_test_failures: 0 + number_of_nonlinear_solver_iterations: 4442 + number_of_nonlinear_solver_fails: 44 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" --- - number_of_calls: 1695 - number_of_jac_muls: 3719 - number_of_matrix_evals: 38 + number_of_calls: 1492 + number_of_jac_muls: 3136 + number_of_matrix_evals: 33 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_tr_bdf2_nalgebra_robertson_ode() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::tr_bdf2(); let (problem, soln) = robertson_ode::(false, 1); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 113 @@ -1002,31 +1362,29 @@ mod test { number_of_calls: 2603 number_of_jac_muls: 39 number_of_matrix_evals: 13 + number_of_jac_adj_muls: 0 "###); } #[test] fn test_tr_bdf2_faer_sparse_heat2d() { - let tableau = Tableau::>::tr_bdf2(); - let mut s = Sdirk::new(tableau, FaerSparseLU::default()); + let mut s = Sdirk::tr_bdf2(); let (problem, soln) = head2d_problem::, 10>(); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] fn test_tstop_tr_bdf2() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::tr_bdf2(); let (problem, soln) = 
exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, true); + test_ode_solver_no_sens(&mut s, &problem, soln, None, true); } #[test] fn test_root_finder_tr_bdf2() { - let tableau = Tableau::::tr_bdf2(); - let mut s = Sdirk::new(tableau, NalgebraLU::default()); + let mut s = Sdirk::tr_bdf2(); let (problem, soln) = exponential_decay_problem_with_root::(false); - let y = test_ode_solver(&mut s, &problem, soln, None, false); + let y = test_ode_solver_no_sens(&mut s, &problem, soln, None, false); assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]); } } diff --git a/src/ode_solver/sdirk_state.rs b/src/ode_solver/sdirk_state.rs index 320e3456..032644e1 100644 --- a/src/ode_solver/sdirk_state.rs +++ b/src/ode_solver/sdirk_state.rs @@ -1,11 +1,20 @@ -use crate::{error::DiffsolError, OdeEquations, OdeSolverProblem, OdeSolverState, Vector}; +use crate::{ + error::DiffsolError, OdeEquations, OdeSolverProblem, OdeSolverState, StateRef, StateRefMut, + Vector, +}; + +use super::state::StateCommon; #[derive(Clone)] pub struct SdirkState { pub(crate) y: V, pub(crate) dy: V, + pub(crate) g: V, + pub(crate) dg: V, pub(crate) s: Vec, pub(crate) ds: Vec, + pub(crate) sg: Vec, + pub(crate) dsg: Vec, pub(crate) t: V::T, pub(crate) h: V::T, } @@ -23,58 +32,71 @@ where Ok(()) } - fn new_internal_state(y: V, dy: V, s: Vec, ds: Vec, t: ::T, h: ::T) -> Self { - Self { y, dy, s, ds, t, h } - } - - fn s(&self) -> &[V] { - self.s.as_slice() - } - fn s_mut(&mut self) -> &mut [V] { - &mut self.s - } - fn ds_mut(&mut self) -> &mut [V] { - &mut self.ds - } - fn ds(&self) -> &[V] { - self.ds.as_slice() - } - fn s_ds_mut(&mut self) -> (&mut [V], &mut [V]) { - (&mut self.s, &mut self.ds) - } - fn y(&self) -> &V { - &self.y - } - - fn y_mut(&mut self) -> &mut V { - &mut self.y - } - - fn dy(&self) -> &V { - &self.dy - } - - fn dy_mut(&mut self) -> &mut V { - &mut self.dy - } - - fn y_dy_mut(&mut self) -> (&mut V, &mut V) { - (&mut self.y, &mut self.dy) + fn set_augmented_problem>( + &mut self, + _ode_problem: &OdeSolverProblem, + _augmented_eqn: &AugmentedEqn, + ) -> Result<(), DiffsolError> { + Ok(()) } - fn t(&self) -> V::T { - self.t + fn new_from_common(state: StateCommon) -> Self { + Self { + y: state.y, + dy: state.dy, + g: state.g, + dg: state.dg, + s: state.s, + ds: state.ds, + sg: state.sg, + dsg: state.dsg, + t: state.t, + h: state.h, + } } - fn t_mut(&mut self) -> &mut V::T { - &mut self.t + fn into_common(self) -> StateCommon { + StateCommon { + y: self.y, + dy: self.dy, + g: self.g, + dg: self.dg, + s: self.s, + ds: self.ds, + sg: self.sg, + dsg: self.dsg, + t: self.t, + h: self.h, + } } - fn h(&self) -> V::T { - self.h + fn as_mut(&mut self) -> StateRefMut { + StateRefMut { + y: &mut self.y, + dy: &mut self.dy, + g: &mut self.g, + dg: &mut self.dg, + s: &mut self.s, + ds: &mut self.ds, + sg: &mut self.sg, + dsg: &mut self.dsg, + t: &mut self.t, + h: &mut self.h, + } } - fn h_mut(&mut self) -> &mut V::T { - &mut self.h + fn as_ref(&self) -> StateRef { + StateRef { + y: &self.y, + dy: &self.dy, + g: &self.g, + dg: &self.dg, + s: &self.s, + ds: &self.ds, + sg: &self.sg, + dsg: &self.dsg, + t: self.t, + h: self.h, + } } } diff --git a/src/ode_solver/sens_equations.rs b/src/ode_solver/sens_equations.rs index d6705ea1..cffa77df 100644 --- a/src/ode_solver/sens_equations.rs +++ b/src/ode_solver/sens_equations.rs @@ -1,23 +1,24 @@ -use num_traits::{One, Zero}; +use num_traits::Zero; use std::{cell::RefCell, rc::Rc}; use crate::{ - matrix::sparsity::MatrixSparsityRef, ConstantOp, LinearOp, 
Matrix, MatrixSparsity, NonLinearOp, - OdeEquations, Op, Vector, + matrix::sparsity::MatrixSparsityRef, op::nonlinear_op::NonLinearOpJacobian, + AugmentedOdeEquations, ConstantOp, ConstantOpSens, Matrix, NonLinearOp, NonLinearOpSens, + OdeEquations, OdeEquationsSens, OdeSolverProblem, Op, Vector, }; pub struct SensInit<Eqn> where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { eqn: Rc<Eqn>, - init_sens: RefCell<Eqn::M>, - index: RefCell<usize>, + init_sens: Eqn::M, + index: usize, } impl<Eqn> SensInit<Eqn> where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { pub fn new(eqn: &Rc<Eqn>) -> Self { let nstates = eqn.rhs().nstates(); @@ -27,26 +28,24 @@ where nparams, eqn.init().sparsity_sens().map(|s| s.to_owned()), ); - let init_sens = RefCell::new(init_sens); - let index = RefCell::new(0); + let index = 0; Self { eqn: eqn.clone(), init_sens, index, } } - pub fn update_state(&self, t: Eqn::T) { - let mut init_sens = self.init_sens.borrow_mut(); - self.eqn.init().sens_inplace(t, &mut init_sens); + pub fn update_state(&mut self, t: Eqn::T) { + self.eqn.init().sens_inplace(t, &mut self.init_sens); } - pub fn set_param_index(&self, index: usize) { - self.index.replace(index); + pub fn set_param_index(&mut self, index: usize) { + self.index = index; } } impl<Eqn> Op for SensInit<Eqn> where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { type T = Eqn::T; type V = Eqn::V; @@ -65,19 +64,17 @@ where impl<Eqn> ConstantOp for SensInit<Eqn> where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { fn call_inplace(&self, _t: Self::T, y: &mut Self::V) { - let init_sens = self.init_sens.borrow(); - let index = *self.index.borrow(); y.fill(Eqn::T::zero()); - init_sens.add_column_to_vector(index, y); + self.init_sens.add_column_to_vector(self.index, y); } } -/// Right-hand side of the sensitivity equations is: +/// The right-hand side of the sensitivity equations is (we assume M_p = 0): /// -/// F(s, t) = J * s + f_p - M_p * dy/dt +/// F(s, t) = J * s + f_p /// /// f_p is the partial derivative of the right-hand side with respect to the parameters; /// it is constant and can be precomputed. It is a matrix of size nstates x nparams. /// /// M_p * dy/dt is the partial derivative of the mass matrix wrt the parameters, /// multiplied by the derivative of the state wrt time. It is a matrix of size nstates x nparams. /// -/// Strategy is to pre-compute S = f_p - M_p * dy/dt from the state at given time step and store it in a matrix using [Self::update_state]. +/// The strategy is to pre-compute S = f_p from the state at a given time step and store it in a matrix using [Self::update_state]. /// Then the ith column of the function F(s, t) is evaluated as J * s_i + S_i, where s_i is the ith column of the sensitivity matrix /// and S_i is the ith column of the matrix S. The column to evaluate is set using [Self::set_param_index].
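+///
+/// As a concrete illustration (compare `test_rhs_exponential` at the bottom of
+/// this file): for dy/dt = -a * y with p = [a], the Jacobian is J = -a * I and
+/// f_p = -y, so S = -y and each column is evaluated as F(s_i, t) = -a * s_i + S_i.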
pub struct SensRhs @@ -94,17 +91,23 @@ where { eqn: Rc, sens: RefCell, - rhs_sens: Option>, - mass_sens: Option>, y: RefCell, index: RefCell, } impl SensRhs where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { - pub fn new(eqn: &Rc) -> Self { + pub fn new(eqn: &Rc, allocate: bool) -> Self { + if !allocate { + return Self { + eqn: eqn.clone(), + sens: RefCell::new(::zeros(0, 0)), + y: RefCell::new(::zeros(0)), + index: RefCell::new(0), + }; + } let nstates = eqn.rhs().nstates(); let nparams = eqn.rhs().nparams(); let rhs_sens = Eqn::M::new_from_sparsity( @@ -114,57 +117,18 @@ where ); let y = RefCell::new(::zeros(nstates)); let index = RefCell::new(0); - if let Some(mass) = eqn.mass() { - let mass_sens = Eqn::M::new_from_sparsity( - nstates, - nparams, - mass.sparsity_sens().map(|s| s.to_owned()), - ); - let sens = if rhs_sens.sparsity().is_some() && mass_sens.sparsity().is_some() { - // union of sparsity patterns - let sparsity = rhs_sens - .sparsity() - .unwrap() - .to_owned() - .union(mass_sens.sparsity().unwrap()) - .unwrap(); - Eqn::M::new_from_sparsity(nstates, nparams, Some(sparsity)) - } else { - Eqn::M::new_from_sparsity(nstates, nparams, None) - }; - Self { - eqn: eqn.clone(), - sens: RefCell::new(sens), - rhs_sens: Some(RefCell::new(rhs_sens)), - mass_sens: Some(RefCell::new(mass_sens)), - y, - index, - } - } else { - Self { - eqn: eqn.clone(), - sens: RefCell::new(rhs_sens), - rhs_sens: None, - mass_sens: None, - y, - index, - } + Self { + eqn: eqn.clone(), + sens: RefCell::new(rhs_sens), + y, + index, } } - /// pre-compute S = f_p - M_p * dy/dt from the state - pub fn update_state(&self, y: &Eqn::V, dy: &Eqn::V, t: Eqn::T) { - if self.rhs_sens.is_some() { - let mut rhs_sens = self.rhs_sens.as_ref().unwrap().borrow_mut(); - let mut mass_sens = self.mass_sens.as_ref().unwrap().borrow_mut(); - let mut sens = self.sens.borrow_mut(); - self.eqn.rhs().sens_inplace(y, t, &mut rhs_sens); - self.eqn.mass().unwrap().sens_inplace(dy, t, &mut mass_sens); - sens.scale_add_and_assign(&rhs_sens, -Eqn::T::one(), &mass_sens); - } else { - let mut sens = self.sens.borrow_mut(); - self.eqn.rhs().sens_inplace(y, t, &mut sens); - } + /// pre-compute S = f_p from the state + pub fn update_state(&mut self, y: &Eqn::V, _dy: &Eqn::V, t: Eqn::T) { + let mut sens = self.sens.borrow_mut(); + self.eqn.rhs().sens_inplace(y, t, &mut sens); let mut state_y = self.y.borrow_mut(); state_y.copy_from(y); } @@ -175,7 +139,7 @@ where impl Op for SensRhs where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { type T = Eqn::T; type V = Eqn::V; @@ -194,7 +158,7 @@ where impl NonLinearOp for SensRhs where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { /// the ith column of function F(s, t) is evaluated as J * s_i + S_i, where s_i is the ith column of the sensitivity matrix fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) { @@ -204,6 +168,12 @@ where self.eqn.rhs().jac_mul_inplace(&state_y, t, x, y); sens.add_column_to_vector(index, y); } +} + +impl NonLinearOpJacobian for SensRhs +where + Eqn: OdeEquationsSens, +{ fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { let state_y = self.y.borrow(); self.eqn.rhs().jac_mul_inplace(&state_y, t, v, y); @@ -213,10 +183,11 @@ where self.eqn.rhs().jacobian_inplace(&state_y, t, y); } } -/// Sensitivity equations for ODEs + +/// Sensitivity & adjoint equations for ODEs (we assume M_p = 0): /// /// Sensitivity equations are linear: -/// M * ds/dt = J * s + f_p - M_p * dy/dt +/// M * ds/dt = J * s + f_p /// s(0) = dy(0)/dp /// where 
/// M is the mass matrix @@ -229,31 +200,47 @@ where /// pub struct SensEquations where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { eqn: Rc, rhs: Rc>, init: Rc>, + atol: Option>, + rtol: Option, +} + +impl std::fmt::Debug for SensEquations +where + Eqn: OdeEquationsSens, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SensEquations") + } } impl SensEquations where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { - pub fn new(eqn: &Rc) -> Self { - let rhs = Rc::new(SensRhs::new(eqn)); + pub(crate) fn new(problem: &OdeSolverProblem) -> Self { + let eqn = &problem.eqn; + let rtol = problem.sens_rtol; + let atol = problem.sens_atol.clone(); + let rhs = Rc::new(SensRhs::new(eqn, true)); let init = Rc::new(SensInit::new(eqn)); Self { rhs, init, eqn: eqn.clone(), + rtol, + atol, } } } impl Op for SensEquations where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { type T = Eqn::T; type V = Eqn::V; @@ -272,7 +259,7 @@ where impl OdeEquations for SensEquations where - Eqn: OdeEquations, + Eqn: OdeEquationsSens, { type T = Eqn::T; type V = Eqn::V; @@ -290,7 +277,7 @@ where self.eqn.mass() } fn root(&self) -> Option<&Rc> { - self.eqn.root() + None } fn init(&self) -> &Rc { &self.init @@ -299,7 +286,42 @@ where panic!("Not implemented for SensEquations"); } fn out(&self) -> Option<&Rc> { - self.eqn.out() + None + } +} + +impl AugmentedOdeEquations for SensEquations { + fn include_in_error_control(&self) -> bool { + self.rtol.is_some() && self.atol.is_some() + } + fn include_out_in_error_control(&self) -> bool { + false + } + fn rtol(&self) -> Option { + self.rtol + } + fn atol(&self) -> Option<&Rc> { + self.atol.as_ref() + } + fn out_atol(&self) -> Option<&Rc> { + None + } + fn out_rtol(&self) -> Option { + None + } + + fn max_index(&self) -> usize { + self.nparams() + } + fn update_rhs_out_state(&mut self, y: &Eqn::V, dy: &Eqn::V, t: Eqn::T) { + Rc::get_mut(&mut self.rhs).unwrap().update_state(y, dy, t); + } + fn update_init_state(&mut self, t: Eqn::T) { + Rc::get_mut(&mut self.init).unwrap().update_state(t); + } + fn set_index(&mut self, index: usize) { + Rc::get_mut(&mut self.rhs).unwrap().set_param_index(index); + Rc::get_mut(&mut self.init).unwrap().set_param_index(index); } } @@ -309,9 +331,9 @@ mod tests { ode_solver::test_models::{ exponential_decay::exponential_decay_problem_sens, exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem_sens, - robertson_sens::robertson_sens, + robertson::robertson_sens, }, - NonLinearOp, SdirkState, SensEquations, Vector, + AugmentedOdeEquations, NonLinearOp, SdirkState, SensEquations, Vector, }; type Mcpu = nalgebra::DMatrix; type Vcpu = nalgebra::DVector; @@ -320,11 +342,15 @@ mod tests { fn test_rhs_exponential() { // dy/dt = -ay (p = [a]) let (problem, _soln) = exponential_decay_problem_sens::(false); - let sens_eqn = SensEquations::new(&problem.eqn); + let mut sens_eqn = SensEquations::new(&problem); let state = SdirkState { t: 0.0, y: Vcpu::from_vec(vec![1.0, 1.0]), dy: Vcpu::from_vec(vec![1.0, 1.0]), + g: Vcpu::zeros(0), + dg: Vcpu::zeros(0), + sg: Vec::new(), + dsg: Vec::new(), s: Vec::new(), ds: Vec::new(), h: 0.0, @@ -334,7 +360,7 @@ mod tests { // M_p = 0 // so S = |-1.0| // |-1.0| - sens_eqn.rhs.update_state(&state.y, &state.dy, state.t); + sens_eqn.update_rhs_out_state(&state.y, &state.dy, state.t); let sens = sens_eqn.rhs.sens.borrow(); assert_eq!(sens.nrows(), 2); assert_eq!(sens.ncols(), 2); @@ -355,12 +381,16 @@ mod tests { #[test] fn test_rhs_exponential_algebraic() { - let 
(problem, _soln) = exponential_decay_with_algebraic_problem_sens::(false); - let sens_eqn = SensEquations::new(&problem.eqn); + let (problem, _soln) = exponential_decay_with_algebraic_problem_sens::(); + let mut sens_eqn = SensEquations::new(&problem); let state = SdirkState { t: 0.0, y: Vcpu::from_vec(vec![1.0, 1.0, 1.0]), dy: Vcpu::from_vec(vec![1.0, 1.0, 1.0]), + g: Vcpu::zeros(0), + dg: Vcpu::zeros(0), + sg: Vec::new(), + dsg: Vec::new(), s: Vec::new(), ds: Vec::new(), h: 0.0, @@ -374,7 +404,7 @@ mod tests { // so S = |-0.1| // |-0.1| // | 0 | - sens_eqn.rhs.update_state(&state.y, &state.dy, state.t); + sens_eqn.update_rhs_out_state(&state.y, &state.dy, state.t); let sens = sens_eqn.rhs.sens.borrow(); assert_eq!(sens.nrows(), 3); assert_eq!(sens.ncols(), 1); @@ -400,12 +430,16 @@ mod tests { #[test] fn test_rhs_robertson() { - let (problem, _soln) = robertson_sens::(false); - let sens_eqn = SensEquations::new(&problem.eqn); + let (problem, _soln) = robertson_sens::(); + let mut sens_eqn = SensEquations::new(&problem); let state = SdirkState { t: 0.0, y: Vcpu::from_vec(vec![1.0, 2.0, 3.0]), dy: Vcpu::from_vec(vec![1.0, 1.0, 1.0]), + g: Vcpu::zeros(0), + dg: Vcpu::zeros(0), + sg: Vec::new(), + dsg: Vec::new(), s: Vec::new(), ds: Vec::new(), h: 0.0, @@ -417,7 +451,7 @@ mod tests { // | 0 0 0| // M_p = 0 // so S = f_p - sens_eqn.rhs.update_state(&state.y, &state.dy, state.t); + sens_eqn.update_rhs_out_state(&state.y, &state.dy, state.t); let sens = sens_eqn.rhs.sens.borrow(); assert_eq!(sens.nrows(), 3); assert_eq!(sens.ncols(), 3); diff --git a/src/ode_solver/state.rs b/src/ode_solver/state.rs index 773ce321..8801439f 100644 --- a/src/ode_solver/state.rs +++ b/src/ode_solver/state.rs @@ -3,68 +3,137 @@ use num_traits::{One, Pow, Zero}; use std::rc::Rc; use crate::{ - error::DiffsolError, error::OdeSolverError, nonlinear_solver::NonLinearSolver, - ode_solver_error, scale, solver::SolverProblem, ConstantOp, DefaultSolver, InitOp, - NewtonNonlinearSolver, NonLinearOp, OdeEquations, OdeSolverMethod, OdeSolverProblem, Op, - SensEquations, Vector, + error::{DiffsolError, OdeSolverError}, + nonlinear_solver::NonLinearSolver, + ode_solver_error, scale, AugmentedOdeEquations, AugmentedOdeEquationsImplicit, ConstantOp, + DefaultSolver, InitOp, NewtonNonlinearSolver, NonLinearOp, OdeEquations, OdeEquationsImplicit, + OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, Op, SensEquations, Vector, }; +use super::method::SensitivitiesOdeSolverMethod; + +/// A state holding those variables that are common to all ODE solver states, +/// can be used to create a new state for a specific solver. 
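+///
+/// Solver-specific states convert to and from this common form, e.g. (a sketch
+/// using the `SdirkState` impl from this patch):
+///
+/// ```ignore
+/// let common: StateCommon<V> = state.into_common();
+/// let state = SdirkState::new_from_common(common);
+/// ```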
+pub struct StateCommon { + pub y: V, + pub dy: V, + pub g: V, + pub dg: V, + pub s: Vec, + pub ds: Vec, + pub sg: Vec, + pub dsg: Vec, + pub t: V::T, + pub h: V::T, +} + +/// A reference to the state of the ODE solver, containing: +/// - the current solution `y` +/// - the derivative of the solution wrt time `dy` +/// - the current integral of the output function `g` +/// - the current derivative of the integral of the output function wrt time `dg` +/// - the current time `t` +/// - the current step size `h` +/// - the sensitivity vectors `s` +/// - the derivative of the sensitivity vectors wrt time `ds` +/// - the sensitivity vectors of the output function `sg` +/// - the derivative of the sensitivity vectors of the output function wrt time `dsg` +pub struct StateRef<'a, V: Vector> { + pub y: &'a V, + pub dy: &'a V, + pub g: &'a V, + pub dg: &'a V, + pub s: &'a [V], + pub ds: &'a [V], + pub sg: &'a [V], + pub dsg: &'a [V], + pub t: V::T, + pub h: V::T, +} + +/// A mutable reference to the state of the ODE solver, containing: +/// - the current solution `y` +/// - the derivative of the solution wrt time `dy` +/// - the current integral of the output function `g` +/// - the current derivative of the integral of the output function wrt time `dg` +/// - the current time `t` +/// - the current step size `h` +/// - the sensitivity vectors `s` +/// - the derivative of the sensitivity vectors wrt time `ds` +/// - the sensitivity vectors of the output function `sg` +/// - the derivative of the sensitivity vectors of the output function wrt time `dsg` +pub struct StateRefMut<'a, V: Vector> { + pub y: &'a mut V, + pub dy: &'a mut V, + pub g: &'a mut V, + pub dg: &'a mut V, + pub s: &'a mut [V], + pub ds: &'a mut [V], + pub sg: &'a mut [V], + pub dsg: &'a mut [V], + pub t: &'a mut V::T, + pub h: &'a mut V::T, +} + /// State for the ODE solver, containing: /// - the current solution `y` /// - the derivative of the solution wrt time `dy` +/// - the current integral of the output function `g` +/// - the current derivative of the integral of the output function wrt time `dg` /// - the current time `t` /// - the current step size `h`, /// - the sensitivity vectors `s` /// - the derivative of the sensitivity vectors wrt time `ds` /// pub trait OdeSolverState: Clone + Sized { - fn y(&self) -> &V; - fn y_mut(&mut self) -> &mut V; - fn dy(&self) -> &V; - fn dy_mut(&mut self) -> &mut V; - fn y_dy_mut(&mut self) -> (&mut V, &mut V); - fn s(&self) -> &[V]; - fn s_mut(&mut self) -> &mut [V]; - fn ds(&self) -> &[V]; - fn ds_mut(&mut self) -> &mut [V]; - fn s_ds_mut(&mut self) -> (&mut [V], &mut [V]); - fn t(&self) -> V::T; - fn t_mut(&mut self) -> &mut V::T; - fn h(&self) -> V::T; - fn h_mut(&mut self) -> &mut V::T; - fn new_internal_state(y: V, dy: V, s: Vec, ds: Vec, t: ::T, h: ::T) -> Self; + fn as_ref(&self) -> StateRef; + fn as_mut(&mut self) -> StateRefMut; + fn into_common(self) -> StateCommon; + fn new_from_common(state: StateCommon) -> Self; + fn set_problem( &mut self, ode_problem: &OdeSolverProblem, ) -> Result<(), DiffsolError>; + fn set_augmented_problem>( + &mut self, + ode_problem: &OdeSolverProblem, + augmented_eqn: &AugmentedEqn, + ) -> Result<(), DiffsolError>; + fn check_consistent_with_problem( &self, problem: &OdeSolverProblem, ) -> Result<(), DiffsolError> { - if self.y().len() != problem.eqn.rhs().nstates() { + if self.as_ref().y.len() != problem.eqn.rhs().nstates() { return Err(ode_solver_error!(StateProblemMismatch)); } - if self.dy().len() != problem.eqn.rhs().nstates() { + if 
self.as_ref().dy.len() != problem.eqn.rhs().nstates() { return Err(ode_solver_error!(StateProblemMismatch)); } Ok(()) } - fn check_sens_consistent_with_problem( + fn check_sens_consistent_with_problem< + Eqn: OdeEquations, + AugmentedEqn: AugmentedOdeEquations, + >( &self, problem: &OdeSolverProblem, + augmented_eqn: &AugmentedEqn, ) -> Result<(), DiffsolError> { - if self.s().len() != problem.eqn_sens.as_ref().unwrap().rhs().nparams() { + let state = self.as_ref(); + if state.s.len() != augmented_eqn.max_index() { return Err(ode_solver_error!(StateProblemMismatch)); } - if !self.s().is_empty() && self.s()[0].len() != problem.eqn.rhs().nstates() { + if !state.s.is_empty() && state.s[0].len() != problem.eqn.rhs().nstates() { return Err(ode_solver_error!(StateProblemMismatch)); } - if self.ds().len() != problem.eqn_sens.as_ref().unwrap().rhs().nparams() { + if state.ds.len() != augmented_eqn.max_index() { return Err(ode_solver_error!(StateProblemMismatch)); } - if !self.ds().is_empty() && self.ds()[0].len() != problem.eqn.rhs().nstates() { + if !state.ds.is_empty() && state.ds[0].len() != problem.eqn.rhs().nstates() { return Err(ode_solver_error!(StateProblemMismatch)); } Ok(()) @@ -77,26 +146,61 @@ pub trait OdeSolverState: Clone + Sized { /// You can then use [Self::set_consistent] and [Self::set_step_size] to set the state up if you need to. fn new(ode_problem: &OdeSolverProblem, solver: &S) -> Result where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, + Eqn::M: DefaultSolver, + S: OdeSolverMethod, + { + let mut ret = Self::new_without_initialise(ode_problem)?; + let mut root_solver = + NewtonNonlinearSolver::new(::default_solver()); + ret.set_consistent(ode_problem, &mut root_solver)?; + ret.set_step_size(ode_problem, solver.order()); + Ok(ret) + } + + fn new_with_sensitivities( + ode_problem: &OdeSolverProblem, + solver: &S, + ) -> Result + where + Eqn: OdeEquationsSens, + Eqn::M: DefaultSolver, + S: SensitivitiesOdeSolverMethod, + { + let augmented_eqn = SensEquations::new(ode_problem); + Self::new_with_augmented(ode_problem, augmented_eqn, solver).map(|(state, _)| state) + } + + fn new_with_augmented( + ode_problem: &OdeSolverProblem, + mut augmented_eqn: AugmentedEqn, + solver: &S, + ) -> Result<(Self, AugmentedEqn), DiffsolError> + where + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquationsImplicit + std::fmt::Debug, Eqn::M: DefaultSolver, S: OdeSolverMethod, { - let mut ret = Self::new_without_initialise(ode_problem); + let mut ret = Self::new_without_initialise_augmented(ode_problem, &mut augmented_eqn)?; let mut root_solver = NewtonNonlinearSolver::new(::default_solver()); ret.set_consistent(ode_problem, &mut root_solver)?; let mut root_solver_sens = NewtonNonlinearSolver::new(::default_solver()); - ret.set_consistent_sens(ode_problem, &mut root_solver_sens)?; + let augmented_eqn = + ret.set_consistent_augmented(ode_problem, augmented_eqn, &mut root_solver_sens)?; ret.set_step_size(ode_problem, solver.order()); - Ok(ret) + Ok((ret, augmented_eqn)) } /// Create a new solver state from an ODE problem, without any initialisation apart from setting the initial time state vector y, /// and if applicable the sensitivity vectors s. /// This is useful if you want to set up the state yourself, or if you want to use a different nonlinear solver to make the state consistent, /// or if you want to set the step size yourself or based on the exact order of the solver. 
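+    /// A typical manual set-up might look like this (a sketch; `problem`,
+    /// `solver` and `root_solver` are assumed to be constructed already):
+    ///
+    /// ```ignore
+    /// let mut state = SdirkState::new_without_initialise(&problem)?;
+    /// state.set_consistent(&problem, &mut root_solver)?;
+    /// state.set_step_size(&problem, solver.order());
+    /// ```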
- fn new_without_initialise(ode_problem: &OdeSolverProblem) -> Self + fn new_without_initialise( + ode_problem: &OdeSolverProblem, + ) -> Result where Eqn: OdeEquations, { @@ -104,24 +208,74 @@ pub trait OdeSolverState: Clone + Sized { let h = ode_problem.h0; let y = ode_problem.eqn.init().call(t); let dy = V::zeros(y.len()); - let nparams = ode_problem.eqn.rhs().nparams(); - let (s, ds) = if ode_problem.eqn_sens.is_none() { - (vec![], vec![]) + let (s, ds) = (vec![], vec![]); + let (dg, g) = if ode_problem.integrate_out { + let out = ode_problem + .eqn + .out() + .ok_or(ode_solver_error!(StateProblemMismatch))?; + (out.call(&y, t), V::zeros(out.nout())) } else { - let eqn_sens = ode_problem.eqn_sens.as_ref().unwrap(); - eqn_sens.init().update_state(t); - let mut s = Vec::with_capacity(nparams); - let mut ds = Vec::with_capacity(nparams); - for i in 0..nparams { - eqn_sens.init().set_param_index(i); - let si = eqn_sens.init().call(t); - let dsi = V::zeros(y.len()); - s.push(si); - ds.push(dsi); + (V::zeros(0), V::zeros(0)) + }; + let (sg, dsg) = (vec![], vec![]); + let state = StateCommon { + y, + dy, + g, + dg, + s, + ds, + sg, + dsg, + t, + h, + }; + Ok(Self::new_from_common(state)) + } + + fn new_without_initialise_augmented( + ode_problem: &OdeSolverProblem, + augmented_eqn: &mut AugmentedEqn, + ) -> Result + where + Eqn: OdeEquations, + AugmentedEqn: AugmentedOdeEquations, + { + let mut state = Self::new_without_initialise(ode_problem)?.into_common(); + let naug = augmented_eqn.max_index(); + let mut s = Vec::with_capacity(naug); + let mut ds = Vec::with_capacity(naug); + let nstates = augmented_eqn.rhs().nstates(); + for i in 0..naug { + augmented_eqn.set_index(i); + let si = augmented_eqn.init().call(state.t); + let dsi = V::zeros(nstates); + s.push(si); + ds.push(dsi); + } + state.s = s; + state.ds = ds; + let (dsg, sg) = if augmented_eqn.out().is_some() { + let mut sg = Vec::with_capacity(naug); + let mut dsg = Vec::with_capacity(naug); + for i in 0..naug { + augmented_eqn.set_index(i); + let out = augmented_eqn + .out() + .ok_or(ode_solver_error!(StateProblemMismatch))?; + let dsgi = out.call(&state.s[i], state.t); + let sgi = V::zeros(out.nout()); + sg.push(sgi); + dsg.push(dsgi); } - (s, ds) + (dsg, sg) + } else { + (vec![], vec![]) }; - Self::new_internal_state(y, dy, s, ds, t, h) + state.sg = sg; + state.dsg = dsg; + Ok(Self::new_from_common(state)) } /// Calculate a consistent state and time derivative of the state, based on the equations of the problem. 
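+    /// For a problem with a mass matrix (a DAE) the algebraic components of `y`
+    /// are corrected by a Newton iteration on an [InitOp] so that the algebraic
+    /// residual rows are zero at `t0`; without a mass matrix this reduces to
+    /// setting `dy = f(y, t0)`. A minimal sketch, mirroring what [Self::new] does:
+    ///
+    /// ```ignore
+    /// let mut root_solver =
+    ///     NewtonNonlinearSolver::new(<Eqn::M as DefaultSolver>::default_solver());
+    /// state.set_consistent(&problem, &mut root_solver)?;
+    /// ```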
@@ -131,76 +285,73 @@ pub trait OdeSolverState: Clone + Sized { root_solver: &mut S, ) -> Result<(), DiffsolError> where - Eqn: OdeEquations, - S: NonLinearSolver> + ?Sized, + Eqn: OdeEquationsImplicit, + S: NonLinearSolver, { - let t = self.t(); - let (y, dy) = self.y_dy_mut(); - ode_problem.eqn.rhs().call_inplace(y, t, dy); + let state = self.as_mut(); + ode_problem + .eqn + .rhs() + .call_inplace(state.y, *state.t, state.dy); if ode_problem.eqn.mass().is_none() { return Ok(()); } - let f = Rc::new(InitOp::new(&ode_problem.eqn, ode_problem.t0, y)); + let f = InitOp::new(&ode_problem.eqn, ode_problem.t0, state.y); let rtol = ode_problem.rtol; let atol = ode_problem.atol.clone(); - let init_problem = SolverProblem::new(f.clone(), atol, rtol); - root_solver.set_problem(&init_problem); - let mut y_tmp = dy.clone(); - y_tmp.copy_from_indices(y, &init_problem.f.algebraic_indices); + root_solver.set_problem(&f, rtol, atol); + let mut y_tmp = state.dy.clone(); + y_tmp.copy_from_indices(state.y, &f.algebraic_indices); let yerr = y_tmp.clone(); - root_solver.solve_in_place(&mut y_tmp, t, &yerr)?; - f.scatter_soln(&y_tmp, y, dy); + root_solver.reset_jacobian(&f, &y_tmp, *state.t); + root_solver.solve_in_place(&f, &mut y_tmp, *state.t, &yerr)?; + f.scatter_soln(&y_tmp, state.y, state.dy); Ok(()) } /// Calculate the initial sensitivity vectors and their time derivatives, based on the equations of the problem. /// Note that this function assumes that the state is already consistent with the algebraic constraints /// (either via [Self::set_consistent] or by setting the state up manually). - fn set_consistent_sens( + fn set_consistent_augmented( &mut self, ode_problem: &OdeSolverProblem, + mut augmented_eqn: AugmentedEqn, root_solver: &mut S, - ) -> Result<(), DiffsolError> + ) -> Result where - Eqn: OdeEquations, - S: NonLinearSolver>> + ?Sized, + Eqn: OdeEquationsImplicit, + AugmentedEqn: AugmentedOdeEquationsImplicit + std::fmt::Debug, + S: NonLinearSolver, { - if ode_problem.eqn_sens.is_none() { - return Ok(()); - } - - let eqn_sens = ode_problem.eqn_sens.as_ref().unwrap(); - eqn_sens.rhs().update_state(self.y(), self.dy(), self.t()); - let t = self.t(); - let (s, ds) = self.s_ds_mut(); - for i in 0..ode_problem.eqn.rhs().nparams() { - eqn_sens.init().set_param_index(i); - eqn_sens.rhs().set_param_index(i); - eqn_sens.rhs().call_inplace(&s[i], t, &mut ds[i]); + let state = self.as_mut(); + augmented_eqn.update_rhs_out_state(state.y, state.dy, *state.t); + let naug = augmented_eqn.max_index(); + for i in 0..naug { + augmented_eqn.set_index(i); + augmented_eqn + .rhs() + .call_inplace(&state.s[i], *state.t, &mut state.ds[i]); } if ode_problem.eqn.mass().is_none() { - return Ok(()); + return Ok(augmented_eqn); } - for i in 0..ode_problem.eqn.rhs().nparams() { - eqn_sens.init().set_param_index(i); - eqn_sens.rhs().set_param_index(i); - let f = Rc::new(InitOp::new(eqn_sens, ode_problem.t0, &self.s()[i])); - root_solver.set_problem(&SolverProblem::new( - f.clone(), - ode_problem.atol.clone(), - ode_problem.rtol, - )); + let mut augmented_eqn_rc = Rc::new(augmented_eqn); + + for i in 0..naug { + Rc::get_mut(&mut augmented_eqn_rc).unwrap().set_index(i); + let f = InitOp::new(&augmented_eqn_rc, ode_problem.t0, &state.s[i]); + root_solver.set_problem(&f, ode_problem.rtol, ode_problem.atol.clone()); - let mut y = self.ds()[i].clone(); - y.copy_from_indices(self.y(), &f.algebraic_indices); + let mut y = state.ds[i].clone(); + y.copy_from_indices(state.y, &f.algebraic_indices); let yerr = y.clone(); - 
root_solver.solve_in_place(&mut y, self.t(), &yerr)?; - let (s, ds) = self.s_ds_mut(); - f.scatter_soln(&y, &mut s[i], &mut ds[i]); + root_solver.reset_jacobian(&f, &y, *state.t); + root_solver.solve_in_place(&f, &mut y, *state.t, &yerr)?; + f.scatter_soln(&y, &mut state.s[i], &mut state.ds[i]); } - Ok(()) + Ok(Rc::try_unwrap(augmented_eqn_rc).unwrap()) } /// compute size of first step based on alg in Hairer, Norsett, Wanner @@ -212,61 +363,65 @@ pub trait OdeSolverState: Clone + Sized { where Eqn: OdeEquations, { - let y0 = self.y(); - let t0 = self.t(); - let f0 = self.dy(); - - let rtol = ode_problem.rtol; - let atol = ode_problem.atol.as_ref(); - - let d0 = y0.squared_norm(y0, atol, rtol).sqrt(); - let d1 = f0.squared_norm(y0, atol, rtol).sqrt(); - - let h0 = if d0 < Eqn::T::from(1e-5) || d1 < Eqn::T::from(1e-5) { - Eqn::T::from(1e-6) - } else { - Eqn::T::from(0.01) * (d0 / d1) - }; - - // make sure we preserve the sign of h0 let is_neg_h = ode_problem.h0 < Eqn::T::zero(); + let (h0, h1) = { + let state = self.as_ref(); + let y0 = state.y; + let t0 = state.t; + let f0 = state.dy; - let f1 = if is_neg_h { - let y1 = f0.clone() * scale(-h0) + y0; - let t1 = t0 - h0; - ode_problem.eqn.rhs().call(&y1, t1) - } else { - let y1 = f0.clone() * scale(h0) + y0; - let t1 = t0 + h0; - ode_problem.eqn.rhs().call(&y1, t1) - }; + let rtol = ode_problem.rtol; + let atol = ode_problem.atol.as_ref(); - let df = f1 - f0; - let d2 = df.squared_norm(y0, atol, rtol).sqrt() / h0.abs(); + let d0 = y0.squared_norm(y0, atol, rtol).sqrt(); + let d1 = f0.squared_norm(y0, atol, rtol).sqrt(); - let mut max_d = d2; - if max_d < d1 { - max_d = d1; - } - let h1 = if max_d < Eqn::T::from(1e-15) { - let h1 = h0 * Eqn::T::from(1e-3); - if h1 < Eqn::T::from(1e-6) { + let h0 = if d0 < Eqn::T::from(1e-5) || d1 < Eqn::T::from(1e-5) { Eqn::T::from(1e-6) } else { - h1 + Eqn::T::from(0.01) * (d0 / d1) + }; + + // make sure we preserve the sign of h0 + let f1 = if is_neg_h { + let y1 = f0.clone() * scale(-h0) + y0; + let t1 = t0 - h0; + ode_problem.eqn.rhs().call(&y1, t1) + } else { + let y1 = f0.clone() * scale(h0) + y0; + let t1 = t0 + h0; + ode_problem.eqn.rhs().call(&y1, t1) + }; + + let df = f1 - f0; + let d2 = df.squared_norm(y0, atol, rtol).sqrt() / h0.abs(); + + let mut max_d = d2; + if max_d < d1 { + max_d = d1; } - } else { - (Eqn::T::from(0.01) / max_d) - .pow(Eqn::T::one() / Eqn::T::from(1.0 + solver_order as f64)) + let h1 = if max_d < Eqn::T::from(1e-15) { + let h1 = h0 * Eqn::T::from(1e-3); + if h1 < Eqn::T::from(1e-6) { + Eqn::T::from(1e-6) + } else { + h1 + } + } else { + (Eqn::T::from(0.01) / max_d) + .pow(Eqn::T::one() / Eqn::T::from(1.0 + solver_order as f64)) + }; + (h0, h1) }; - *self.h_mut() = Eqn::T::from(100.0) * h0; - if self.h() > h1 { - *self.h_mut() = h1; + let state = self.as_mut(); + *state.h = Eqn::T::from(100.0) * h0; + if *state.h > h1 { + *state.h = h1; } if is_neg_h { - *self.h_mut() = -self.h(); + *state.h = -*state.h; } } } diff --git a/src/ode_solver/sundials.rs b/src/ode_solver/sundials.rs index 4c2e996c..3379a092 100644 --- a/src/ode_solver/sundials.rs +++ b/src/ode_solver/sundials.rs @@ -9,7 +9,7 @@ use crate::{ IDA_RES_FAIL, IDA_ROOT_RETURN, IDA_RTFUNC_FAIL, IDA_SUCCESS, IDA_TOO_MUCH_ACC, IDA_TOO_MUCH_WORK, IDA_TSTOP_RETURN, IDA_YA_YDP_INIT, }, - SdirkState, + SdirkState, StateRef, StateRefMut, }; use num_traits::Zero; use serde::Serialize; @@ -20,8 +20,8 @@ use std::{ use crate::{ error::*, matrix::sparsity::MatrixSparsityRef, ode_solver_error, scale, LinearOp, Matrix, - NonLinearOp, 
OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverStopReason, Op, - SundialsMatrix, SundialsVector, Vector, + NonLinearOp, NonLinearOpJacobian, OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, + OdeSolverState, OdeSolverStopReason, Op, SundialsMatrix, SundialsVector, Vector, }; #[cfg(not(sundials_version_major = "5"))] @@ -102,7 +102,7 @@ impl SundialsStatistics { struct SundialsData where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, { eqn: Rc, rhs_jac: SundialsMatrix, @@ -111,7 +111,7 @@ where impl SundialsData where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, { fn new(eqn: Rc) -> Self { let n = eqn.rhs().nstates(); @@ -131,7 +131,7 @@ where pub struct SundialsIda where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, { ida_mem: *mut c_void, linear_solver: SUNLinearSolver, @@ -146,7 +146,7 @@ where impl SundialsIda where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, { extern "C" fn residual( t: realtype, @@ -263,7 +263,7 @@ where impl Default for SundialsIda where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, { fn default() -> Self { Self::new() @@ -272,7 +272,7 @@ where impl Drop for SundialsIda where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, { fn drop(&mut self) { if !self.linear_solver.is_null() { @@ -284,7 +284,7 @@ where impl OdeSolverMethod for SundialsIda where - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, { type State = SdirkState; @@ -299,17 +299,17 @@ where self.problem.as_ref() } - fn state(&self) -> Option<&Self::State> { - self.state.as_ref() + fn state(&self) -> Option> { + self.state.as_ref().map(|s| s.as_ref()) } fn order(&self) -> usize { 1 } - fn state_mut(&mut self) -> Option<&mut Self::State> { + fn state_mut(&mut self) -> Option> { self.is_state_modified = true; - self.state.as_mut() + self.state.as_mut().map(|s| s.as_mut()) } fn take_state(&mut self) -> Option { @@ -378,10 +378,6 @@ where // set jacobian function Self::check(unsafe { IDASetJacFn(ida_mem, Some(Self::jacobian)) }).unwrap(); - // sensitivities - if self.problem.as_ref().unwrap().eqn_sens.is_some() { - panic!("Sensitivities not implemented for sundials solver"); - } Ok(()) } @@ -455,11 +451,12 @@ where Ok(ret) } - fn interpolate_sens( - &self, - _t: ::T, - ) -> Result::V>, DiffsolError> { - Ok(vec![]) + fn interpolate_out(&self, _t: Eqn::T) -> Result { + unimplemented!() + } + + fn interpolate_sens(&self, _t: Eqn::T) -> Result, DiffsolError> { + unimplemented!() } } @@ -474,7 +471,9 @@ mod test { heat2d::head2d_problem, robertson::robertson, }, - tests::{test_interpolate, test_no_set_problem, test_ode_solver, test_state_mut}, + tests::{ + test_interpolate, test_no_set_problem, test_ode_solver_no_sens, test_state_mut, + }, }, OdeEquations, Op, SundialsIda, SundialsMatrix, }; @@ -497,7 +496,7 @@ mod test { fn test_sundials_exponential_decay() { let mut s = crate::SundialsIda::default(); let (problem, soln) = exponential_decay_problem::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 18 @@ -511,6 +510,7 @@ mod test { number_of_calls: 65 number_of_jac_muls: 36 number_of_matrix_evals: 18 + number_of_jac_adj_muls: 0 "###); } @@ -518,7 +518,7 @@ mod test { fn test_sundials_robertson() { let mut s = crate::SundialsIda::default(); let (problem, soln) = robertson::(false); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); 
insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 59 @@ -532,6 +532,7 @@ mod test { number_of_calls: 509 number_of_jac_muls: 180 number_of_matrix_evals: 60 + number_of_jac_adj_muls: 0 "###); } @@ -540,7 +541,7 @@ mod test { let foodweb_context = FoodWebContext::default(); let mut s = crate::SundialsIda::default(); let (problem, soln) = foodweb_problem::(&foodweb_context); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 42 @@ -554,7 +555,7 @@ mod test { fn test_sundials_heat2d() { let mut s = crate::SundialsIda::default(); let (problem, soln) = head2d_problem::(); - test_ode_solver(&mut s, &problem, soln, None, false); + test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- number_of_linear_solver_setups: 42 @@ -568,6 +569,7 @@ mod test { number_of_calls: 217 number_of_jac_muls: 4300 number_of_matrix_evals: 43 + number_of_jac_adj_muls: 0 "###); } } diff --git a/src/ode_solver/test.rs b/src/ode_solver/test.rs new file mode 100644 index 00000000..4b8c1fc8 --- /dev/null +++ b/src/ode_solver/test.rs @@ -0,0 +1,58 @@ +type V = Vec; + +struct Data { + y: V, + g: V +} + +struct DataThingWrapperMut<'a> { + y: &'a mut V, + g: &'a mut V +} + +trait DataThing { + fn get_mut(&mut self) -> DataThingWrapperMut; +} + +impl DataThing for Data { + fn get_mut(&mut self) -> DataThingWrapperMut { + DataThingWrapperMut { + y: &mut self.y, + g: &mut self.g + } + } +} + +fn test1(a: &V) { +} + +fn test2(b: &mut V) { +} + +fn test3(a: &V, b: &mut V) { +} + +fn make_data() -> Data { + Data{ + y: vec![1, 2, 3], + g: vec![] + } +} + +fn main() { + { + let d1 = make_data(); + test1(&d1.get_mut().y); + } + + { + let mut d2 = make_data(); + test2(&mut d2.get_mut().g); + } + + { + let mut d3 = make_data(); + let mut d3_mut = d3.get_mut(); + test3(&d3_mut.y, &mut d3_mut.g); + } +} \ No newline at end of file diff --git a/src/ode_solver/test_models/dydt_y2.rs b/src/ode_solver/test_models/dydt_y2.rs index 7ed70154..6149ce29 100644 --- a/src/ode_solver/test_models/dydt_y2.rs +++ b/src/ode_solver/test_models/dydt_y2.rs @@ -1,6 +1,6 @@ use crate::{ - ode_solver::problem::OdeSolverSolution, scalar::scale, DenseMatrix, OdeBuilder, OdeEquations, - OdeSolverProblem, Vector, + ode_solver::problem::OdeSolverSolution, scalar::scale, DenseMatrix, OdeBuilder, + OdeEquationsImplicit, OdeSolverProblem, Vector, }; use num_traits::One; use std::ops::MulAssign; @@ -18,11 +18,12 @@ fn rhs_jac(x: &M::V, _p: &M::V, _t: M::T, v: &M::V, y: &mut M::V y.mul_assign(scale(M::T::from(2.))); } +#[allow(clippy::type_complexity)] pub fn dydt_y2_problem( use_coloring: bool, size: usize, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let size2 = size; diff --git a/src/ode_solver/test_models/exponential_decay.rs b/src/ode_solver/test_models/exponential_decay.rs index de0fc660..74cbf59e 100644 --- a/src/ode_solver/test_models/exponential_decay.rs +++ b/src/ode_solver/test_models/exponential_decay.rs @@ -1,46 +1,141 @@ use crate::{ - matrix::Matrix, ode_solver::problem::OdeSolverSolution, scalar::scale, ConstantOp, OdeBuilder, - OdeEquations, OdeSolverProblem, Vector, + matrix::Matrix, ode_solver::problem::OdeSolverSolution, + op::closure_with_adjoint::ClosureWithAdjoint, scalar::scale, ConstantClosureWithAdjoint, + ConstantOp, OdeBuilder, OdeEquations, 
OdeEquationsAdjoint, OdeEquationsImplicit, + OdeEquationsSens, OdeSolverEquations, OdeSolverProblem, UnitCallable, Vector, }; use nalgebra::ComplexField; -use num_traits::Zero; -use std::ops::MulAssign; +use num_traits::{One, Zero}; +use std::{ops::MulAssign, rc::Rc}; // exponential decay problem -// dy/dt = -ay (p = [a]) +// dy/dt = -ay (p = [a, y0]) fn exponential_decay(x: &M::V, p: &M::V, _t: M::T, y: &mut M::V) { y.copy_from(x); y.mul_assign(scale(-p[0])); } -// df/dp v = -yv (p = [a]) +// df/dp v = -yv (p = [a, y0]) +// df/dp = | -y 0 | +// | -y 0 | +// df/dp v = | -y 0 | |v_1| = |-yv_1| +// | -y 0 | |v_2| |-yv_1 | fn exponential_decay_sens(x: &M::V, _p: &M::V, _t: M::T, v: &M::V, y: &mut M::V) { y.copy_from(x); y.mul_assign(scale(-v[0])); } +// df/dp^T v = | -y -y | |v_1| = |-yv_1 - yv_2| +// | 0 0 | |v_2| | 0 | +fn exponential_decay_sens_transpose( + x: &M::V, + _p: &M::V, + _t: M::T, + v: &M::V, + y: &mut M::V, +) { + y[0] = x[0] * v[0] + x[1] * v[1]; + y[1] = M::T::zero(); +} + +// J = | -a 0 | +// | 0 -a | // Jv = -av fn exponential_decay_jacobian(_x: &M::V, p: &M::V, _t: M::T, v: &M::V, y: &mut M::V) { y.copy_from(v); y.mul_assign(scale(-p[0])); } +// -J^Tv = av +fn exponential_decay_jacobian_adjoint( + _x: &M::V, + p: &M::V, + _t: M::T, + v: &M::V, + y: &mut M::V, +) { + y.copy_from(v); + y.mul_assign(scale(p[0])); +} + fn exponential_decay_init(p: &M::V, _t: M::T) -> M::V { M::V::from_vec(vec![p[1], p[1]]) } -fn exponential_decay_init_sens(_p: &M::V, _t: M::T, _v: &M::V, y: &mut M::V) { - y.fill(M::T::zero()); +// dy0/dp = | 0 1 | +// | 0 1 | +// dy0/dp v = | 0 1 | |v_1| = |v_2| +// | 0 1 | |v_2| |v_2| +fn exponential_decay_init_sens(_p: &M::V, _t: M::T, v: &M::V, y: &mut M::V) { + y[0] = v[1]; + y[1] = v[1]; +} + +// dy0/dp^T v = | 0 0 | |v_1| = |0 | +// | 1 1 | |v_2| |v_1 + v_2| +fn exponential_decay_init_sens_adjoint(_p: &M::V, _t: M::T, v: &M::V, y: &mut M::V) { + y[0] = M::T::zero(); + y[1] = -v[0] - v[1]; } fn exponential_decay_root(x: &M::V, _p: &M::V, _t: M::T, y: &mut M::V) { y[0] = x[0] - M::T::from(0.6); } +/// g_1 = 1 * x_1 + 2 * x_2 +/// g_2 = 3 * x_1 + 4 * x_2 +fn exponential_decay_out(x: &M::V, _p: &M::V, _t: M::T, y: &mut M::V) { + y[0] = M::T::from(1.0) * x[0] + M::T::from(2.0) * x[1]; + y[1] = M::T::from(3.0) * x[0] + M::T::from(4.0) * x[1]; +} + +/// J = |1 2| +/// |3 4| +/// J v = |1 2| |v_1| = |v_1 + 2v_2| +/// |3 4| |v_2| |3v_1 + 4v_2| +fn exponential_decay_out_jac_mul( + _x: &M::V, + _p: &M::V, + _t: M::T, + v: &M::V, + y: &mut M::V, +) { + y[0] = v[0] + M::T::from(2.0) * v[1]; + y[1] = M::T::from(3.0) * v[0] + M::T::from(4.0) * v[1]; +} + +/// J = |1 2| +/// |3 4| +/// -J^T v = |-1 -3| |v_1| = |-v_1 - 3v_2| +/// |-2 -4| |v_2| |-2v_1 - 4v_2| +fn exponential_decay_out_adj_mul( + _x: &M::V, + _p: &M::V, + _t: M::T, + v: &M::V, + y: &mut M::V, +) { + y[0] = -v[0] - M::T::from(3.0) * v[1]; + y[1] = -M::T::from(2.0) * v[0] - M::T::from(4.0) * v[1]; +} + +/// J = |0 0| +/// |0 0| +fn exponential_decay_out_sens_adj( + _x: &M::V, + _p: &M::V, + _t: M::T, + _v: &M::V, + y: &mut M::V, +) { + y.fill(M::T::zero()); +} + +#[allow(clippy::type_complexity)] pub fn negative_exponential_decay_problem( use_coloring: bool, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let h = -1.0; @@ -70,10 +165,11 @@ pub fn negative_exponential_decay_problem( (problem, soln) } +#[allow(clippy::type_complexity)] pub fn exponential_decay_problem( use_coloring: bool, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let h = 1.0; @@ -100,10 
+196,11 @@ pub fn exponential_decay_problem( (problem, soln) } +#[allow(clippy::type_complexity)] pub fn exponential_decay_problem_with_root( use_coloring: bool, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let k = 0.1; @@ -130,18 +227,133 @@ pub fn exponential_decay_problem_with_root( (problem, soln) } +#[allow(clippy::type_complexity)] +pub fn exponential_decay_problem_adjoint() -> ( + OdeSolverProblem>, + OdeSolverSolution, +) { + let k = M::T::from(0.1); + let y0 = M::T::from(1.0); + let t0 = M::T::from(0.0); + let h0 = M::T::from(1.0); + let p = Rc::new(M::V::from_vec(vec![k, y0])); + let init = exponential_decay_init::; + let y0 = init(&p, t0); + let nstates = y0.len(); + let rhs = exponential_decay::; + let rhs_jac = exponential_decay_jacobian::; + let rhs_adj_jac = exponential_decay_jacobian_adjoint::; + let rhs_sens_adj = exponential_decay_sens_transpose::; + let mut rhs = ClosureWithAdjoint::new( + rhs, + rhs_jac, + rhs_adj_jac, + rhs_sens_adj, + nstates, + nstates, + p.clone(), + ); + let nout = 2; + let out = exponential_decay_out::; + let out_jac = exponential_decay_out_jac_mul::; + let out_jac_adj = exponential_decay_out_adj_mul::; + let out_sens_adj = exponential_decay_out_sens_adj::; + let out = ClosureWithAdjoint::new( + out, + out_jac, + out_jac_adj, + out_sens_adj, + nstates, + nout, + p.clone(), + ); + let init = ConstantClosureWithAdjoint::new( + exponential_decay_init::, + exponential_decay_init_sens_adjoint::, + p.clone(), + ); + if M::is_sparse() { + rhs.calculate_jacobian_sparsity(&y0, t0); + rhs.calculate_adjoint_sparsity(&y0, t0); + } + let rhs = Rc::new(rhs); + let init = Rc::new(init); + let out = Some(Rc::new(out)); + let mass: Option>> = None; + let root: Option>> = None; + let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); + let rtol = M::T::from(1e-6); + let atol = M::V::from_element(nstates, M::T::from(1e-6)); + let out_rtol = Some(M::T::from(1e-6)); + let out_atol = Some(M::V::from_element(nout, M::T::from(1e-6))); + let param_rtol = Some(M::T::from(1e-6)); + let param_atol = Some(M::V::from_element(p.len(), M::T::from(1e-6))); + let sens_rtol = Some(M::T::from(1e-6)); + let sens_atol = Some(M::V::from_element(nstates, M::T::from(1e-6))); + let integrate_out = true; + let problem = OdeSolverProblem::new( + eqn, + rtol, + atol, + sens_rtol, + sens_atol, + out_rtol, + out_atol, + param_rtol, + param_atol, + t0, + h0, + integrate_out, + ) + .unwrap(); + let mut soln = OdeSolverSolution { + atol: problem.atol.as_ref().clone(), + rtol: problem.rtol, + ..Default::default() + }; + let t0 = M::T::from(0.0); + let t1 = M::T::from(9.0); + for i in 0..10 { + let t = M::T::from(i as f64); + let y0: M::V = problem.eqn.init().call(M::T::zero()); + let g = y0.clone() * scale((M::T::exp(-p[0] * t0) - M::T::exp(-p[0] * t)) / p[0]); + let g = M::V::from_vec(vec![ + g[0] + M::T::from(2.0) * g[1], + M::T::from(3.0) * g[0] + M::T::from(4.0) * g[1], + ]); + let dydk = y0.clone() + * scale( + M::T::exp(-p[0] * (t1 + t0)) + * (M::T::exp(t0 * p[0]) * (p[0] * t1 + M::T::one()) + - M::T::exp(t1 * p[0]) * (p[0] * t0 + M::T::one())) + / (p[0] * p[0]), + ); + let dydy0 = (M::T::exp(-p[0] * t0) - M::T::exp(-p[0] * t1)) / p[0]; + let dg1dk = dydk[0] + M::T::from(2.0) * dydk[1]; + let dg2dk = M::T::from(3.0) * dydk[0] + M::T::from(4.0) * dydk[1]; + let dg1dy0 = dydy0 + M::T::from(2.0) * dydy0; + let dg2dy0 = M::T::from(3.0) * dydy0 + M::T::from(4.0) * dydy0; + let dg1 = M::V::from_vec(vec![dg1dk, dg1dy0]); + let dg2 = 
M::V::from_vec(vec![dg2dk, dg2dy0]); + soln.push_sens(g, t, &[dg1, dg2]); + } + (problem, soln) +} + +#[allow(clippy::type_complexity)] pub fn exponential_decay_problem_sens( use_coloring: bool, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let k = 0.1; let y0 = 1.0; let problem = OdeBuilder::new() .p([k, y0]) + .sens_rtol(Some(1e-6)) + .sens_atol(Some([1e-6, 1e-6])) .use_coloring(use_coloring) - .sensitivities_error_control(true) .build_ode_with_sens( exponential_decay::, exponential_decay_jacobian::, diff --git a/src/ode_solver/test_models/exponential_decay_with_algebraic.rs b/src/ode_solver/test_models/exponential_decay_with_algebraic.rs index 08ddd67b..2e118d67 100644 --- a/src/ode_solver/test_models/exponential_decay_with_algebraic.rs +++ b/src/ode_solver/test_models/exponential_decay_with_algebraic.rs @@ -1,10 +1,18 @@ use crate::{ - matrix::Matrix, ode_solver::problem::OdeSolverSolution, scalar::scale, OdeBuilder, - OdeEquations, OdeSolverProblem, Vector, + matrix::Matrix, + ode_solver::problem::OdeSolverSolution, + op::{ + closure_with_sens::ClosureWithSens, constant_closure_with_sens::ConstantClosureWithSens, + linear_closure_with_adjoint::LinearClosureWithAdjoint, + }, + scalar::scale, + ClosureWithAdjoint, ConstantClosureWithAdjoint, ConstantOp, LinearClosure, OdeBuilder, + OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsSens, OdeSolverEquations, + OdeSolverProblem, UnitCallable, Vector, }; use nalgebra::ComplexField; use num_traits::{One, Zero}; -use std::ops::MulAssign; +use std::{ops::MulAssign, rc::Rc}; // exponential decay problem with algebraic constraint // dy/dt = -ay @@ -18,6 +26,12 @@ fn exponential_decay_with_algebraic(x: &M::V, p: &M::V, _t: M::T, mut y[nstates - 1] = x[nstates - 1] - x[nstates - 2]; } +// J = | -y[0] | +// | -y[1] | +// | 0 | +// Jv = | -y[0]v[0] | +// | -y[1]v[1] | +// | 0 | #[allow(unused_mut)] fn exponential_decay_with_algebraic_sens( x: &M::V, @@ -32,7 +46,24 @@ fn exponential_decay_with_algebraic_sens( y[nstates - 1] = M::T::zero(); } -// Jv = [[-av, 0], [-1, 1]]v = [-av, -v[0] + v[1]] +// -J^Tv = | y[0]v[0] + y[1]v[1] + 0 | +#[allow(unused_mut)] +fn exponential_decay_with_algebraic_sens_adjoint( + x: &M::V, + _p: &M::V, + _t: M::T, + v: &M::V, + mut y: &mut M::V, +) { + y[0] = x[0] * v[0] + x[1] * v[1]; +} + +// J = | -a, 0, 0 | +// | 0, -a, 0 | +// | 0, -1, 1 | +// Jv = | -av[0] | +// | -av[1] | +// | v[2] - v[1] | #[allow(unused_mut)] fn exponential_decay_with_algebraic_jacobian( _x: &M::V, @@ -47,6 +78,21 @@ fn exponential_decay_with_algebraic_jacobian( y[nstates - 1] = v[nstates - 1] - v[nstates - 2]; } +// -J^T v = | av[0] | +// | av[1] + v[2] | +// | -v[2] | +fn exponential_decay_with_algebraic_adjoint( + _x: &M::V, + p: &M::V, + _t: M::T, + v: &M::V, + y: &mut M::V, +) { + y[0] = p[0] * v[0]; + y[1] = p[0] * v[1] + v[2]; + y[2] = -v[2]; +} + // y = Mx + beta * y = | 1 0 | | x[0] | + beta | y[0] | // | 0 0 | | x[1] | | y[1] | fn exponential_decay_with_algebraic_mass( @@ -62,14 +108,19 @@ fn exponential_decay_with_algebraic_mass( y[nstates - 1] = yn; } -fn exponential_decay_with_algebraic_mass_sens( - _x: &M::V, +// y = M^T x + beta * y = | 1 0 | | x[0] | + beta | y[0] | +// | 0 0 | | x[1] | | y[1] | +fn exponential_decay_with_algebraic_mass_transpose( + x: &M::V, _p: &M::V, _t: M::T, - _v: &M::V, + beta: M::T, y: &mut M::V, ) { - y.fill(M::T::zero()); + let nstates = y.len(); + let yn = beta * y[nstates - 1]; + y.axpy(M::T::one(), x, beta); + y[nstates - 1] = yn; } fn 
exponential_decay_with_algebraic_init(_p: &M::V, _t: M::T) -> M::V { @@ -85,10 +136,75 @@ fn exponential_decay_with_algebraic_init_sens( y.fill(M::T::zero()); } +fn exponential_decay_with_algebraic_init_sens_adjoint( + _p: &M::V, + _t: M::T, + _v: &M::V, + y: &mut M::V, +) { + y.fill(M::T::zero()); +} + +// out(x) = | a * x[2] | +fn exponential_decay_with_algebraic_out(x: &M::V, p: &M::V, _t: M::T, y: &mut M::V) { + y[0] = p[0] * x[2]; +} + +// J = | 0 0 a | +// Jv = | a * v[2] | +fn exponential_decay_with_algebraic_out_jac_mul( + _x: &M::V, + p: &M::V, + _t: M::T, + v: &M::V, + y: &mut M::V, +) { + y[0] = p[0] * v[2]; +} + +// J = | 0 0 a | +// -J^T v = | 0 | +// | 0 | +// | -a * v[0] | +fn exponential_decay_with_algebraic_out_jac_adj_mul( + _x: &M::V, + p: &M::V, + _t: M::T, + v: &M::V, + y: &mut M::V, +) { + y.fill(M::T::zero()); + y[2] = -p[0] * v[0]; +} + +// J = | x[2] | +// Jv = | x[2]v[0] | +//fn exponential_decay_with_algebraic_out_sens( +// x: &M::V, +// _p: &M::V, +// _t: M::T, +// v: &M::V, +// y: &mut M::V, +//) { +// y[0] = x[1] * v[1]; +//} + +// -J^T v = | -x[2]v[2] | +fn exponential_decay_with_algebraic_out_sens_adj( + x: &M::V, + _p: &M::V, + _t: M::T, + v: &M::V, + y: &mut M::V, +) { + y[0] = -x[2] * v[0]; +} + +#[allow(clippy::type_complexity)] pub fn exponential_decay_with_algebraic_problem( use_coloring: bool, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let p = M::V::from_vec(vec![0.1.into()]); @@ -113,27 +229,172 @@ pub fn exponential_decay_with_algebraic_problem( (problem, soln) } -pub fn exponential_decay_with_algebraic_problem_sens( - use_coloring: bool, -) -> ( - OdeSolverProblem>, +#[allow(clippy::type_complexity)] +pub fn exponential_decay_with_algebraic_adjoint_problem() -> ( + OdeSolverProblem>, OdeSolverSolution, ) { - let p = M::V::from_vec(vec![0.1.into()]); - let problem = OdeBuilder::new() - .p([0.1]) - .use_coloring(use_coloring) - .sensitivities_error_control(true) - .build_ode_with_mass_and_sens( - exponential_decay_with_algebraic::, - exponential_decay_with_algebraic_jacobian::, - exponential_decay_with_algebraic_sens::, - exponential_decay_with_algebraic_mass::, - exponential_decay_with_algebraic_mass_sens::, - exponential_decay_with_algebraic_init::, - exponential_decay_with_algebraic_init_sens::, - ) - .unwrap(); + let a = M::T::from(0.1); + let t0 = M::T::from(0.0); + let h0 = M::T::from(1.0); + let p = Rc::new(M::V::from_vec(vec![a])); + let init = exponential_decay_with_algebraic_init::; + let y0 = init(&p, t0); + let nstates = y0.len(); + let rhs = exponential_decay_with_algebraic::; + let rhs_jac = exponential_decay_with_algebraic_jacobian::; + let rhs_adj_jac = exponential_decay_with_algebraic_adjoint::; + let rhs_sens_adj = exponential_decay_with_algebraic_sens_adjoint::; + let mut rhs = ClosureWithAdjoint::new( + rhs, + rhs_jac, + rhs_adj_jac, + rhs_sens_adj, + nstates, + nstates, + p.clone(), + ); + let nout = 1; + let out = exponential_decay_with_algebraic_out::; + let out_jac = exponential_decay_with_algebraic_out_jac_mul::; + let out_jac_adj = exponential_decay_with_algebraic_out_jac_adj_mul::; + let out_sens_adj = exponential_decay_with_algebraic_out_sens_adj::; + let out = ClosureWithAdjoint::new( + out, + out_jac, + out_jac_adj, + out_sens_adj, + nstates, + nout, + p.clone(), + ); + let init = ConstantClosureWithAdjoint::new( + exponential_decay_with_algebraic_init::, + exponential_decay_with_algebraic_init_sens_adjoint::, + p.clone(), + ); + let mut mass = LinearClosureWithAdjoint::new( + 
exponential_decay_with_algebraic_mass::, + exponential_decay_with_algebraic_mass_transpose::, + nstates, + nstates, + p.clone(), + ); + if M::is_sparse() { + rhs.calculate_jacobian_sparsity(&y0, t0); + rhs.calculate_adjoint_sparsity(&y0, t0); + mass.calculate_sparsity(t0); + mass.calculate_adjoint_sparsity(t0); + } + let rhs = Rc::new(rhs); + let init = Rc::new(init); + let out = Some(Rc::new(out)); + + let root: Option>> = None; + let mass = Some(Rc::new(mass)); + let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); + let rtol = M::T::from(1e-6); + let atol = M::V::from_element(nstates, M::T::from(1e-6)); + let out_rtol = Some(M::T::from(1e-6)); + let out_atol = Some(M::V::from_element(nout, M::T::from(1e-6))); + let param_rtol = Some(M::T::from(1e-6)); + let param_atol = Some(M::V::from_element(1, M::T::from(1e-6))); + let sens_atol = Some(M::V::from_element(nstates, M::T::from(1e-6))); + let sens_rtol = Some(M::T::from(1e-6)); + let integrate_out = true; + let problem = OdeSolverProblem::new( + eqn, + rtol, + atol, + sens_rtol, + sens_atol, + out_rtol, + out_atol, + param_rtol, + param_atol, + t0, + h0, + integrate_out, + ) + .unwrap(); + let atol_out = M::V::from_element(nout, M::T::from(1e-6)); + let mut soln = OdeSolverSolution { + atol: atol_out, + rtol: problem.rtol, + ..Default::default() + }; + let t0 = M::T::from(0.0); + let t1 = M::T::from(9.0); + for i in 0..10 { + let t = M::T::from(i as f64); + let y0 = M::V::from_vec(vec![1.0.into(), 1.0.into(), 1.0.into()]); + let g = y0.clone() * scale((M::T::exp(-p[0] * t0) - M::T::exp(-p[0] * t)) / p[0]); + let g = M::V::from_vec(vec![p[0] * g[2]]); + let dgdk = t1 * M::T::exp(-p[0] * t1); + let dg = M::V::from_vec(vec![dgdk]); + soln.push_sens(g, t, &[dg]); + } + (problem, soln) +} + +#[allow(clippy::type_complexity)] +pub fn exponential_decay_with_algebraic_problem_sens() -> ( + OdeSolverProblem>, + OdeSolverSolution, +) { + let p = Rc::new(M::V::from_vec(vec![0.1.into()])); + let mut rhs = ClosureWithSens::new( + exponential_decay_with_algebraic::, + exponential_decay_with_algebraic_jacobian::, + exponential_decay_with_algebraic_sens::, + 3, + 3, + p.clone(), + ); + let mut mass = LinearClosure::new(exponential_decay_with_algebraic_mass::, 3, 3, p.clone()); + let init = ConstantClosureWithSens::new( + exponential_decay_with_algebraic_init::, + exponential_decay_with_algebraic_init_sens::, + 3, + 3, + p.clone(), + ); + let t0 = M::T::zero(); + + if M::is_sparse() { + let y0 = init.call(t0); + rhs.calculate_jacobian_sparsity(&y0, t0); + rhs.calculate_sens_sparsity(&y0, t0); + mass.calculate_sparsity(t0); + } + + let out: Option>> = None; + let root: Option>> = None; + let eqn = OdeSolverEquations::new( + Rc::new(rhs), + Some(Rc::new(mass)), + root, + Rc::new(init), + out, + p.clone(), + ); + let sens_rtol = Some(M::T::from(1e-6)); + let sens_atol = Some(M::V::from_element(3, M::T::from(1e-6))); + let problem = OdeSolverProblem::new( + eqn, + M::T::from(1e-6), + M::V::from_element(3, M::T::from(1e-6)), + sens_rtol, + sens_atol, + None, + None, + None, + None, + t0, + M::T::from(1.0), + false, + ) + .unwrap(); let mut soln = OdeSolverSolution::default(); for i in 0..10 { diff --git a/src/ode_solver/test_models/foodweb.rs b/src/ode_solver/test_models/foodweb.rs index 3d8ce8dc..8a5cd09b 100644 --- a/src/ode_solver/test_models/foodweb.rs +++ b/src/ode_solver/test_models/foodweb.rs @@ -1,9 +1,10 @@ use std::rc::Rc; use crate::{ - find_non_zeros_linear, find_non_zeros_nonlinear, ode_solver::problem::OdeSolverSolution, - 
ConstantOp, JacobianColoring, LinearOp, Matrix, MatrixSparsity, NonLinearOp, OdeEquations, - OdeSolverProblem, Op, UnitCallable, Vector, + find_jacobian_non_zeros, find_matrix_non_zeros, ode_solver::problem::OdeSolverSolution, + ConstantOp, JacobianColoring, LinearOp, Matrix, MatrixSparsity, NonLinearOp, + NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit, OdeSolverProblem, Op, UnitCallable, + Vector, }; use num_traits::Zero; @@ -136,11 +137,12 @@ pub fn foodweb_diffsl_compile( diffsl_context.recompile(code.as_str()).unwrap(); } +#[allow(clippy::type_complexity)] #[cfg(feature = "diffsl")] pub fn foodweb_diffsl_problem( diffsl_context: &crate::DiffSlContext, ) -> ( - OdeSolverProblem + '_>, + OdeSolverProblem + '_>, OdeSolverSolution, ) where @@ -341,7 +343,7 @@ macro_rules! impl_op { context_consts!(FoodWebInit); impl_op!(FoodWebInit); -impl<'a, M, const NX: usize> ConstantOp for FoodWebInit<'a, M, NX> +impl ConstantOp for FoodWebInit<'_, M, NX> where M: Matrix, { @@ -392,7 +394,7 @@ where sparsity: None, coloring: None, }; - let non_zeros = find_non_zeros_nonlinear(&ret, y0, t0); + let non_zeros = find_jacobian_non_zeros(&ret, y0, t0); ret.sparsity = Some( MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()).unwrap(), ); @@ -401,7 +403,7 @@ where } } -impl<'a, M, const NX: usize> Op for FoodWebRhs<'a, M, NX> +impl Op for FoodWebRhs<'_, M, NX> where M: Matrix, { @@ -423,7 +425,7 @@ where } } -impl<'a, M, const NX: usize> NonLinearOp for FoodWebRhs<'a, M, NX> +impl NonLinearOp for FoodWebRhs<'_, M, NX> where M: Matrix, { @@ -509,7 +511,12 @@ where } } } +} +impl NonLinearOpJacobian for FoodWebRhs<'_, M, NX> +where + M: Matrix, +{ #[allow(unused_mut)] fn jac_mul_inplace(&self, x: &M::V, _t: M::T, v: &M::V, mut y: &mut M::V) { let nsmx: usize = NUM_SPECIES * NX; @@ -619,7 +626,7 @@ where sparsity: None, coloring: None, }; - let non_zeros = find_non_zeros_linear(&ret, t0); + let non_zeros = find_matrix_non_zeros(&ret, t0); ret.sparsity = Some( MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()).unwrap(), ); @@ -628,7 +635,7 @@ where } } -impl<'a, M, const NX: usize> Op for FoodWebMass<'a, M, NX> +impl Op for FoodWebMass<'_, M, NX> where M: Matrix, { @@ -650,7 +657,7 @@ where } } -impl<'a, M, const NX: usize> LinearOp for FoodWebMass<'a, M, NX> +impl LinearOp for FoodWebMass<'_, M, NX> where M: Matrix, { @@ -684,7 +691,7 @@ where context_consts!(FoodWebOut); -impl<'a, M, const NX: usize> Op for FoodWebOut<'a, M, NX> +impl Op for FoodWebOut<'_, M, NX> where M: Matrix, { @@ -703,7 +710,7 @@ where } } -impl<'a, M, const NX: usize> NonLinearOp for FoodWebOut<'a, M, NX> +impl NonLinearOp for FoodWebOut<'_, M, NX> where M: Matrix, { @@ -721,7 +728,12 @@ where y[2 * is + 1] = x[loc_br + is]; } } +} +impl NonLinearOpJacobian for FoodWebOut<'_, M, NX> +where + M: Matrix, +{ #[allow(unused_mut)] fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, mut y: &mut Self::V) { let nsmx: usize = NUM_SPECIES * NX; @@ -817,7 +829,7 @@ where { pub fn new(y0: &M::V, t0: M::T) -> Self { let mut ret = Self { sparsity: None }; - let non_zeros = find_non_zeros_nonlinear(&ret, y0, t0); + let non_zeros = find_jacobian_non_zeros(&ret, y0, t0); ret.sparsity = Some( MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()).unwrap(), ); @@ -889,7 +901,13 @@ where } } } +} +#[cfg(feature = "diffsl")] +impl NonLinearOpJacobian for FoodWebDiff +where + M: Matrix, +{ #[allow(unused_mut)] fn jac_mul_inplace(&self, _x: &M::V, _t: M::T, v: 
&M::V, mut y: &mut M::V) { let nsmx: usize = NX; @@ -1001,10 +1019,11 @@ fn soln() -> OdeSolverSolution { soln } +#[allow(clippy::type_complexity)] pub fn foodweb_problem( context: &FoodWebContext, ) -> ( - OdeSolverProblem + '_>, + OdeSolverProblem + '_>, OdeSolverSolution, ) where @@ -1015,7 +1034,10 @@ where let t0 = M::T::zero(); let h0 = M::T::from(1.0); let eqn = FoodWeb::new(context, t0); - let problem = OdeSolverProblem::new(eqn, rtol, atol, t0, h0, false, false).unwrap(); + let problem = OdeSolverProblem::new( + eqn, rtol, atol, None, None, None, None, None, None, t0, h0, false, + ) + .unwrap(); let soln = soln::(); (problem, soln) } diff --git a/src/ode_solver/test_models/gaussian_decay.rs b/src/ode_solver/test_models/gaussian_decay.rs index ea55e303..42664ab3 100644 --- a/src/ode_solver/test_models/gaussian_decay.rs +++ b/src/ode_solver/test_models/gaussian_decay.rs @@ -1,6 +1,8 @@ use crate::ode_solver::problem::OdeSolverSolution; use crate::OdeSolverProblem; -use crate::{scalar::scale, ConstantOp, DenseMatrix, OdeBuilder, OdeEquations, Vector}; +use crate::{ + scalar::scale, ConstantOp, DenseMatrix, OdeBuilder, OdeEquations, OdeEquationsImplicit, Vector, +}; use nalgebra::ComplexField; use num_traits::Pow; use num_traits::Zero; @@ -20,11 +22,12 @@ fn gaussian_decay_jacobian(_x: &M::V, p: &M::V, t: M::T, v: &M:: y.mul_assign(scale(-t)); } +#[allow(clippy::type_complexity)] pub fn gaussian_decay_problem( use_coloring: bool, size: usize, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let size2 = size; diff --git a/src/ode_solver/test_models/heat2d.rs b/src/ode_solver/test_models/heat2d.rs index 2606cd56..f8b76578 100644 --- a/src/ode_solver/test_models/heat2d.rs +++ b/src/ode_solver/test_models/heat2d.rs @@ -7,14 +7,14 @@ //while for each boundary point, it is res_i = u_i. 
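+// For an interior grid point the residual follows the standard 5-point stencil
+// (a sketch with uniform spacing dx; the exact indexing is in the code below):
+//   res_ij = du_ij/dt - (u_{i-1,j} + u_{i+1,j} + u_{i,j-1} + u_{i,j+1} - 4*u_ij) / dx^2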
use crate::{ - ode_solver::problem::OdeSolverSolution, scalar::Scalar, Matrix, OdeBuilder, OdeEquations, - OdeSolverProblem, Vector, + ode_solver::problem::OdeSolverSolution, scalar::Scalar, Matrix, OdeBuilder, + OdeEquationsImplicit, OdeSolverProblem, Vector, }; use nalgebra::ComplexField; use num_traits::{One, Zero}; #[cfg(feature = "diffsl")] -use crate::{ConstantOp, LinearOp, NonLinearOp}; +use crate::{ConstantOp, LinearOp, NonLinearOpJacobian, OdeEquations}; #[cfg(feature = "diffsl")] pub fn heat2d_diffsl_compile< @@ -90,6 +90,7 @@ pub fn heat2d_diffsl_compile< context.recompile(code.as_str()).unwrap(); } +#[allow(clippy::type_complexity)] #[cfg(feature = "diffsl")] pub fn heat2d_diffsl_problem< M: Matrix + 'static, @@ -97,7 +98,7 @@ pub fn heat2d_diffsl_problem< >( context: &crate::DiffSlContext, ) -> ( - OdeSolverProblem + '_>, + OdeSolverProblem + '_>, OdeSolverSolution, ) { let problem = OdeBuilder::new() @@ -252,8 +253,9 @@ fn _pde_solution(x: T, y: T, t: T, max_terms: usize) -> T { u } +#[allow(clippy::type_complexity)] pub fn head2d_problem() -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let problem = OdeBuilder::new() @@ -305,7 +307,7 @@ fn soln() -> OdeSolverSolution { #[cfg(test)] mod tests { - use crate::{ConstantOp, LinearOp, NonLinearOp}; + use crate::{ConstantOp, LinearOp, NonLinearOpJacobian, OdeEquations}; use super::*; diff --git a/src/ode_solver/test_models/mod.rs b/src/ode_solver/test_models/mod.rs index 6c69a4a8..1693d99b 100644 --- a/src/ode_solver/test_models/mod.rs +++ b/src/ode_solver/test_models/mod.rs @@ -7,4 +7,3 @@ pub mod heat2d; pub mod robertson; pub mod robertson_ode; pub mod robertson_ode_with_sens; -pub mod robertson_sens; diff --git a/src/ode_solver/test_models/robertson.rs b/src/ode_solver/test_models/robertson.rs index b5d8613e..f11b82ad 100644 --- a/src/ode_solver/test_models/robertson.rs +++ b/src/ode_solver/test_models/robertson.rs @@ -1,7 +1,13 @@ +use std::rc::Rc; + use crate::{ - matrix::Matrix, ode_solver::problem::OdeSolverSolution, OdeBuilder, OdeEquations, - OdeSolverProblem, Vector, + matrix::Matrix, + ode_solver::problem::OdeSolverSolution, + op::{closure_with_sens::ClosureWithSens, constant_closure_with_sens::ConstantClosureWithSens}, + ConstantOp, LinearClosure, OdeBuilder, OdeEquationsImplicit, OdeEquationsSens, + OdeSolverEquations, OdeSolverProblem, UnitCallable, Vector, }; +use num_traits::Zero; #[cfg(feature = "diffsl")] pub fn robertson_diffsl_compile< @@ -44,6 +50,7 @@ pub fn robertson_diffsl_compile< context.recompile(code).unwrap(); } +#[allow(clippy::type_complexity)] #[cfg(feature = "diffsl")] pub fn robertson_diffsl_problem< M: Matrix + 'static, @@ -52,7 +59,7 @@ pub fn robertson_diffsl_problem< context: &crate::DiffSlContext, use_coloring: bool, ) -> ( - OdeSolverProblem + '_>, + OdeSolverProblem + '_>, OdeSolverSolution, ) { let problem = OdeBuilder::new() @@ -68,10 +75,48 @@ pub fn robertson_diffsl_problem< (problem, soln) } +//* dy1/dt = -.04*y1 + 1.e4*y2*y3 +//* dy2/dt = .04*y1 - 1.e4*y2*y3 - 3.e7*y2**2 +//* 0 = y1 + y2 + y3 - 1 +fn robertson_rhs(x: &M::V, p: &M::V, _t: M::T, y: &mut M::V) { + y[0] = -p[0] * x[0] + p[1] * x[1] * x[2]; + y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1]; + y[2] = x[0] + x[1] + x[2] - M::T::from(1.0); +} +fn robertson_jac_mul(x: &M::V, p: &M::V, _t: M::T, v: &M::V, y: &mut M::V) { + y[0] = -p[0] * v[0] + p[1] * v[1] * x[2] + p[1] * x[1] * v[2]; + y[1] = p[0] * v[0] + - p[1] * v[1] * x[2] + - p[1] * x[1] * v[2] + - M::T::from(2.0) * p[2] * x[1] * v[1]; 
+    y[2] = v[0] + v[1] + v[2];
+}
+
+fn robertson_sens_mul<M: Matrix>(x: &M::V, _p: &M::V, _t: M::T, v: &M::V, y: &mut M::V) {
+    y[0] = -v[0] * x[0] + v[1] * x[1] * x[2];
+    y[1] = v[0] * x[0] - v[1] * x[1] * x[2] - v[2] * x[1] * x[1];
+    y[2] = M::T::zero();
+}
+
+fn robertson_mass<M: Matrix>(x: &M::V, _p: &M::V, _t: M::T, beta: M::T, y: &mut M::V) {
+    y[0] = x[0] + beta * y[0];
+    y[1] = x[1] + beta * y[1];
+    y[2] = beta * y[2];
+}
+
+fn robertson_init<M: Matrix>(_p: &M::V, _t: M::T) -> M::V {
+    M::V::from_vec(vec![1.0.into(), 0.0.into(), 0.0.into()])
+}
+
+fn robertson_init_sens<M: Matrix>(_p: &M::V, _t: M::T, _v: &M::V, y: &mut M::V) {
+    y.fill(M::T::zero());
+}
+
+#[allow(clippy::type_complexity)]
 pub fn robertson<M: Matrix + 'static>(
     use_coloring: bool,
 ) -> (
-    OdeSolverProblem<impl OdeEquations<M = M, V = M::V, T = M::T>>,
+    OdeSolverProblem<impl OdeEquationsImplicit<M = M, V = M::V, T = M::T>>,
     OdeSolverSolution<M::V>,
 ) {
     let problem = OdeBuilder::new()
@@ -80,28 +125,10 @@ pub fn robertson<M: Matrix + 'static>(
         .atol([1.0e-8, 1.0e-6, 1.0e-6])
         .use_coloring(use_coloring)
         .build_ode_with_mass(
-            //* dy1/dt = -.04*y1 + 1.e4*y2*y3
-            //* dy2/dt = .04*y1 - 1.e4*y2*y3 - 3.e7*y2**2
-            //* 0 = y1 + y2 + y3 - 1
-            |x: &M::V, p: &M::V, _t: M::T, y: &mut M::V| {
-                y[0] = -p[0] * x[0] + p[1] * x[1] * x[2];
-                y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1];
-                y[2] = x[0] + x[1] + x[2] - M::T::from(1.0);
-            },
-            |x: &M::V, p: &M::V, _t: M::T, v: &M::V, y: &mut M::V| {
-                y[0] = -p[0] * v[0] + p[1] * v[1] * x[2] + p[1] * x[1] * v[2];
-                y[1] = p[0] * v[0]
-                    - p[1] * v[1] * x[2]
-                    - p[1] * x[1] * v[2]
-                    - M::T::from(2.0) * p[2] * x[1] * v[1];
-                y[2] = v[0] + v[1] + v[2];
-            },
-            |x: &M::V, _p: &M::V, _t: M::T, beta: M::T, y: &mut M::V| {
-                y[0] = x[0] + beta * y[0];
-                y[1] = x[1] + beta * y[1];
-                y[2] = beta * y[2];
-            },
-            |_p: &M::V, _t: M::T| M::V::from_vec(vec![1.0.into(), 0.0.into(), 0.0.into()]),
+            robertson_rhs::<M>,
+            robertson_jac_mul::<M>,
+            robertson_mass::<M>,
+            robertson_init::<M>,
         )
         .unwrap();
 
@@ -135,6 +162,96 @@ fn soln<M: Matrix>() -> OdeSolverSolution<M::V> {
     soln
 }
 
+#[allow(clippy::type_complexity)]
+pub fn robertson_sens<M: Matrix + 'static>() -> (
+    OdeSolverProblem<impl OdeEquationsSens<M = M, V = M::V, T = M::T>>,
+    OdeSolverSolution<M::V>,
+) {
+    let p = Rc::new(M::V::from_vec(vec![
+        M::T::from(0.04),
+        M::T::from(1.0e4),
+        M::T::from(3.0e7),
+    ]));
+    let mut rhs = ClosureWithSens::new(
+        robertson_rhs::<M>,
+        robertson_jac_mul::<M>,
+        robertson_sens_mul::<M>,
+        3,
+        3,
+        p.clone(),
+    );
+    let mut mass = LinearClosure::new(robertson_mass::<M>, 3, 3, p.clone());
+    let init = ConstantClosureWithSens::new(
+        robertson_init::<M>,
+        robertson_init_sens::<M>,
+        3,
+        3,
+        p.clone(),
+    );
+    let t0 = M::T::zero();
+
+    if M::is_sparse() {
+        let y0 = init.call(t0);
+        rhs.calculate_jacobian_sparsity(&y0, t0);
+        rhs.calculate_sens_sparsity(&y0, t0);
+        mass.calculate_sparsity(t0);
+    }
+
+    let out: Option<Rc<UnitCallable<M>>> = None;
+    let root: Option<Rc<UnitCallable<M>>> = None;
+    let eqn = OdeSolverEquations::new(
+        Rc::new(rhs),
+        Some(Rc::new(mass)),
+        root,
+        Rc::new(init),
+        out,
+        p.clone(),
+    );
+    let rtol = M::T::from(1e-4);
+    let atol = M::V::from_vec(vec![M::T::from(1e-8), M::T::from(1e-6), M::T::from(1e-6)]);
+    let problem = OdeSolverProblem::new(
+        eqn,
+        rtol,
+        atol,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        t0,
+        M::T::from(1.0),
+        false,
+    )
+    .unwrap();
+
+    let mut soln = OdeSolverSolution::default();
+    let data = vec![
+        (vec![1.0, 0.0, 0.0], 0.0),
+        (vec![9.8517e-01, 3.3864e-05, 1.4794e-02], 0.4),
+        (vec![9.0553e-01, 2.2406e-05, 9.4452e-02], 4.0),
+        (vec![7.1579e-01, 9.1838e-06, 2.8420e-01], 40.0),
+        (vec![4.5044e-01, 3.2218e-06, 5.4956e-01], 400.0),
+        (vec![1.8320e-01, 8.9444e-07, 8.1680e-01], 4000.0),
+        (vec![3.8992e-02, 1.6221e-07, 9.6101e-01], 40000.0),
+        (vec![4.9369e-03, 1.9842e-08,
9.9506e-01], 400000.0), + (vec![5.1674e-04, 2.0684e-09, 9.9948e-01], 4000000.0), + (vec![5.2009e-05, 2.0805e-10, 9.9995e-01], 4.0000e+07), + (vec![5.2012e-06, 2.0805e-11, 9.9999e-01], 4.0000e+08), + (vec![5.1850e-07, 2.0740e-12, 1.0000e+00], 4.0000e+09), + (vec![4.8641e-08, 1.9456e-13, 1.0000e+00], 4.0000e+10), + ]; + + for (values, time) in data { + soln.push( + M::V::from_vec(values.into_iter().map(|v| v.into()).collect()), + time.into(), + ); + } + + (problem, soln) +} + /* ----------------------------------------------------------------- * Programmer(s): Allan Taylor, Alan Hindmarsh and * Radu Serban @ LLNL diff --git a/src/ode_solver/test_models/robertson_ode.rs b/src/ode_solver/test_models/robertson_ode.rs index 51429d9e..b653d215 100644 --- a/src/ode_solver/test_models/robertson_ode.rs +++ b/src/ode_solver/test_models/robertson_ode.rs @@ -1,14 +1,15 @@ use crate::{ - ode_solver::problem::OdeSolverSolution, Matrix, OdeBuilder, OdeEquations, OdeSolverProblem, - Vector, + ode_solver::problem::OdeSolverSolution, Matrix, OdeBuilder, OdeEquationsImplicit, + OdeSolverProblem, Vector, }; use num_traits::{One, Zero}; +#[allow(clippy::type_complexity)] pub fn robertson_ode( use_coloring: bool, ngroups: usize, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { const N: usize = 3; diff --git a/src/ode_solver/test_models/robertson_ode_with_sens.rs b/src/ode_solver/test_models/robertson_ode_with_sens.rs index 57f9cef1..6b0b58e3 100644 --- a/src/ode_solver/test_models/robertson_ode_with_sens.rs +++ b/src/ode_solver/test_models/robertson_ode_with_sens.rs @@ -1,13 +1,14 @@ use crate::{ - ode_solver::problem::OdeSolverSolution, Matrix, OdeBuilder, OdeEquations, OdeSolverProblem, + ode_solver::problem::OdeSolverSolution, Matrix, OdeBuilder, OdeEquationsSens, OdeSolverProblem, Vector, }; use num_traits::Zero; +#[allow(clippy::type_complexity)] pub fn robertson_ode_with_sens( use_coloring: bool, ) -> ( - OdeSolverProblem>, + OdeSolverProblem>, OdeSolverSolution, ) { let problem = OdeBuilder::new() diff --git a/src/ode_solver/test_models/robertson_sens.rs b/src/ode_solver/test_models/robertson_sens.rs deleted file mode 100644 index d55c2752..00000000 --- a/src/ode_solver/test_models/robertson_sens.rs +++ /dev/null @@ -1,174 +0,0 @@ -use crate::{ - matrix::Matrix, ode_solver::problem::OdeSolverSolution, OdeBuilder, OdeEquations, - OdeSolverProblem, Vector, -}; -use num_traits::Zero; - -pub fn robertson_sens( - use_coloring: bool, -) -> ( - OdeSolverProblem>, - OdeSolverSolution, -) { - let problem = OdeBuilder::new() - .p([0.04, 1.0e4, 3.0e7]) - .rtol(1e-4) - .atol([1.0e-8, 1.0e-6, 1.0e-6]) - .use_coloring(use_coloring) - .sensitivities_error_control(true) - .build_ode_with_mass_and_sens( - //* dy1/dt = -.04*y1 + 1.e4*y2*y3 - //* dy2/dt = .04*y1 - 1.e4*y2*y3 - 3.e7*y2**2 - //* 0 = y1 + y2 + y3 - 1 - |x: &M::V, p: &M::V, _t: M::T, y: &mut M::V| { - y[0] = -p[0] * x[0] + p[1] * x[1] * x[2]; - y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1]; - y[2] = x[0] + x[1] + x[2] - M::T::from(1.0); - }, - |x: &M::V, p: &M::V, _t: M::T, v: &M::V, y: &mut M::V| { - y[0] = -p[0] * v[0] + p[1] * v[1] * x[2] + p[1] * x[1] * v[2]; - y[1] = p[0] * v[0] - - p[1] * v[1] * x[2] - - p[1] * x[1] * v[2] - - M::T::from(2.0) * p[2] * x[1] * v[1]; - y[2] = v[0] + v[1] + v[2]; - }, - |x: &M::V, _p: &M::V, _t: M::T, v: &M::V, y: &mut M::V| { - y[0] = -v[0] * x[0] + v[1] * x[1] * x[2]; - y[1] = v[0] * x[0] - v[1] * x[1] * x[2] - v[2] * x[1] * x[1]; - y[2] = M::T::zero(); - }, - |x: &M::V, _p: &M::V, _t: 
M::T, beta: M::T, y: &mut M::V| { - y[0] = x[0] + beta * y[0]; - y[1] = x[1] + beta * y[1]; - y[2] = beta * y[2]; - }, - |_x: &M::V, _p: &M::V, _t: M::T, _v: &M::V, y: &mut M::V| { - y.fill(M::T::zero()); - }, - |_p: &M::V, _t: M::T| M::V::from_vec(vec![1.0.into(), 0.0.into(), 0.0.into()]), - |_p: &M::V, _t: M::T, _v: &M::V, y: &mut M::V| { - y.fill(M::T::zero()); - }, - ) - .unwrap(); - - let mut soln = OdeSolverSolution::default(); - let data = vec![ - (vec![1.0, 0.0, 0.0], 0.0), - (vec![9.8517e-01, 3.3864e-05, 1.4794e-02], 0.4), - (vec![9.0553e-01, 2.2406e-05, 9.4452e-02], 4.0), - (vec![7.1579e-01, 9.1838e-06, 2.8420e-01], 40.0), - (vec![4.5044e-01, 3.2218e-06, 5.4956e-01], 400.0), - (vec![1.8320e-01, 8.9444e-07, 8.1680e-01], 4000.0), - (vec![3.8992e-02, 1.6221e-07, 9.6101e-01], 40000.0), - (vec![4.9369e-03, 1.9842e-08, 9.9506e-01], 400000.0), - (vec![5.1674e-04, 2.0684e-09, 9.9948e-01], 4000000.0), - (vec![5.2009e-05, 2.0805e-10, 9.9995e-01], 4.0000e+07), - (vec![5.2012e-06, 2.0805e-11, 9.9999e-01], 4.0000e+08), - (vec![5.1850e-07, 2.0740e-12, 1.0000e+00], 4.0000e+09), - (vec![4.8641e-08, 1.9456e-13, 1.0000e+00], 4.0000e+10), - ]; - - for (values, time) in data { - soln.push( - M::V::from_vec(values.into_iter().map(|v| v.into()).collect()), - time.into(), - ); - } - - (problem, soln) -} - -/* ----------------------------------------------------------------- - * Programmer(s): Allan Taylor, Alan Hindmarsh and - * Radu Serban @ LLNL - * ----------------------------------------------------------------- - * SUNDIALS Copyright Start - * Copyright (c) 2002-2023, Lawrence Livermore National Security - * and Southern Methodist University. - * All rights reserved. - * - * See the top-level LICENSE and NOTICE files for details. - * - * SPDX-License-Identifier: BSD-3-Clause - * SUNDIALS Copyright End - * ----------------------------------------------------------------- - * This simple example problem for IDA, due to Robertson, - * is from chemical kinetics, and consists of the following three - * equations: - * - * dy1/dt = -.04*y1 + 1.e4*y2*y3 - * dy2/dt = .04*y1 - 1.e4*y2*y3 - 3.e7*y2**2 - * 0 = y1 + y2 + y3 - 1 - * - * on the interval from t = 0.0 to t = 4.e10, with initial - * conditions: y1 = 1, y2 = y3 = 0. - * - * While integrating the system, we also use the rootfinding - * feature to find the points at which y1 = 1e-4 or at which - * y3 = 0.01. - * - * The problem is solved with IDA using the DENSE linear - * solver, with a user-supplied Jacobian. Output is printed at - * t = .4, 4, 40, ..., 4e10. - * -----------------------------------------------------------------*/ - -// Output from Sundials IDA serial example problem for Robertson kinetics: -// -//idaRoberts_dns: Robertson kinetics DAE serial example problem for IDA -// Three equation chemical kinetics problem. -// -//Linear solver: DENSE, with user-supplied Jacobian. -//Tolerance parameters: rtol = 0.0001 atol = 1e-08 1e-06 1e-06 -//Initial conditions y0 = (1 0 0) -//Constraints and id not used. 
-// -//----------------------------------------------------------------------- -// t y1 y2 y3 | nst k h -//----------------------------------------------------------------------- -//2.6402e-01 9.8997e-01 3.4706e-05 1.0000e-02 | 27 2 4.4012e-02 -// rootsfound[] = 0 1 -//4.0000e-01 9.8517e-01 3.3864e-05 1.4794e-02 | 29 3 8.8024e-02 -//4.0000e+00 9.0553e-01 2.2406e-05 9.4452e-02 | 43 4 6.3377e-01 -//4.0000e+01 7.1579e-01 9.1838e-06 2.8420e-01 | 68 4 3.1932e+00 -//4.0000e+02 4.5044e-01 3.2218e-06 5.4956e-01 | 95 4 3.3201e+01 -//4.0000e+03 1.8320e-01 8.9444e-07 8.1680e-01 | 126 3 3.1458e+02 -//4.0000e+04 3.8992e-02 1.6221e-07 9.6101e-01 | 161 5 2.5058e+03 -//4.0000e+05 4.9369e-03 1.9842e-08 9.9506e-01 | 202 3 2.6371e+04 -//4.0000e+06 5.1674e-04 2.0684e-09 9.9948e-01 | 250 3 1.7187e+05 -//2.0788e+07 1.0000e-04 4.0004e-10 9.9990e-01 | 280 5 1.0513e+06 -// rootsfound[] = -1 0 -//4.0000e+07 5.2009e-05 2.0805e-10 9.9995e-01 | 293 4 2.3655e+06 -//4.0000e+08 5.2012e-06 2.0805e-11 9.9999e-01 | 325 4 2.6808e+07 -//4.0000e+09 5.1850e-07 2.0740e-12 1.0000e+00 | 348 3 7.4305e+08 -//4.0000e+10 4.8641e-08 1.9456e-13 1.0000e+00 | 362 2 7.5480e+09 -// -//Final Statistics: -//Current time = 41226212070.53522 -//Steps = 362 -//Error test fails = 15 -//NLS step fails = 0 -//Initial step size = 2.164955286048077e-05 -//Last step size = 7548045540.281308 -//Current step size = 7548045540.281308 -//Last method order = 2 -//Current method order = 2 -//Residual fn evals = 537 -//IC linesearch backtrack ops = 0 -//NLS iters = 537 -//NLS fails = 5 -//NLS iters per step = 1.483425414364641 -//LS setups = 60 -//Jac fn evals = 60 -//LS residual fn evals = 0 -//Prec setup evals = 0 -//Prec solves = 0 -//LS iters = 0 -//LS fails = 0 -//Jac-times setups = 0 -//Jac-times evals = 0 -//LS iters per NLS iter = 0 -//Jac evals per NLS iter = 0.111731843575419 -//Prec evals per NLS iter = 0 -//Root fn evals = 404 diff --git a/src/op/bdf.rs b/src/op/bdf.rs index 951b1cb6..d211d310 100644 --- a/src/op/bdf.rs +++ b/src/op/bdf.rs @@ -1,6 +1,7 @@ use crate::{ - matrix::DenseMatrix, ode_solver::equations::OdeEquations, scale, LinearOp, Matrix, MatrixRef, - MatrixSparsity, MatrixSparsityRef, OdeSolverProblem, Vector, VectorRef, + matrix::DenseMatrix, ode_solver::equations::OdeEquationsImplicit, scale, LinearOp, Matrix, + MatrixRef, MatrixSparsity, MatrixSparsityRef, NonLinearOp, NonLinearOpJacobian, + OdeSolverProblem, Op, Vector, VectorRef, }; use num_traits::{One, Zero}; use std::ops::MulAssign; @@ -10,10 +11,8 @@ use std::{ rc::Rc, }; -use super::{NonLinearOp, Op}; - // callable to solve for F(y) = M (y' + psi) - c * f(y) = 0 -pub struct BdfCallable { +pub struct BdfCallable { eqn: Rc, psi_neg_y0: RefCell, c: RefCell, @@ -25,8 +24,26 @@ pub struct BdfCallable { sparsity: Option<::Sparsity>, } -impl BdfCallable { - pub fn from_eqn(eqn: &Rc) -> Self { +impl BdfCallable { + // F(y) = M (y - y0 + psi) - c * f(y) = 0 + // M = I + // dg = f(y) + // g - y0 + psi = c * dg + // g - y0 = c * dg - psi + pub fn integrate_out>( + &self, + dg: &Eqn::V, + diff: &M, + gamma: &[Eqn::T], + alpha: &[Eqn::T], + order: usize, + d: &mut Eqn::V, + ) { + self.set_psi(diff, gamma, alpha, order, d); + let c = self.c.borrow(); + d.axpy(*c, dg, -Eqn::T::one()); + } + pub fn from_sensitivity_eqn(eqn: &Rc) -> Self { let eqn = eqn.clone(); let n = eqn.rhs().nstates(); let c = RefCell::new(Eqn::T::zero()); @@ -49,6 +66,9 @@ impl BdfCallable { sparsity, } } + pub fn eqn_mut(&mut self) -> &mut Rc { + &mut self.eqn + } pub fn eqn(&self) -> &Rc { &self.eqn } @@ -133,21 
+153,31 @@ impl BdfCallable { { self.c.replace(h * alpha); } - pub fn set_psi_and_y0>( + fn set_psi>( &self, diff: &M, gamma: &[Eqn::T], alpha: &[Eqn::T], order: usize, - y0: &Eqn::V, + psi: &mut Eqn::V, ) { // update psi term as defined in second equation on page 9 of [1] - let mut psi = self.psi_neg_y0.borrow_mut(); psi.axpy_v(gamma[1], &diff.column(1), Eqn::T::zero()); for (i, &gamma_i) in gamma.iter().enumerate().take(order + 1).skip(2) { psi.axpy_v(gamma_i, &diff.column(i), Eqn::T::one()); } psi.mul_assign(scale(alpha[order])); + } + pub fn set_psi_and_y0>( + &self, + diff: &M, + gamma: &[Eqn::T], + alpha: &[Eqn::T], + order: usize, + y0: &Eqn::V, + ) { + let mut psi = self.psi_neg_y0.borrow_mut(); + self.set_psi(diff, gamma, alpha, order, &mut psi); // now negate y0 psi.sub_assign(y0); @@ -157,7 +187,7 @@ impl BdfCallable { } } -impl Op for BdfCallable { +impl Op for BdfCallable { type V = Eqn::V; type T = Eqn::T; type M = Eqn::M; @@ -178,7 +208,7 @@ impl Op for BdfCallable { // dF(y)/dp = dM/dp (y - y0 + psi) + Ms - c * df(y)/dp - c df(y)/dy s = 0 // jac is M - c * df(y)/dy, same // callable to solve for F(y) = M (y' + psi) - f(y) = 0 -impl NonLinearOp for BdfCallable +impl NonLinearOp for BdfCallable where for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, @@ -201,6 +231,13 @@ where y.axpy(Eqn::T::one(), &tmp, -c); } } +} + +impl NonLinearOpJacobian for BdfCallable +where + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ // (M - c * f'(y)) v fn jac_mul_inplace(&self, x: &Eqn::V, t: Eqn::T, v: &Eqn::V, y: &mut Eqn::V) { self.eqn.rhs().jac_mul_inplace(x, t, v, y); @@ -244,8 +281,8 @@ where #[cfg(test)] mod tests { use crate::ode_solver::test_models::exponential_decay::exponential_decay_problem; - use crate::op::NonLinearOp; use crate::vector::Vector; + use crate::{NonLinearOp, NonLinearOpJacobian}; use super::BdfCallable; type Mcpu = nalgebra::DMatrix; diff --git a/src/op/closure.rs b/src/op/closure.rs index c3307102..04a9dcf8 100644 --- a/src/op/closure.rs +++ b/src/op/closure.rs @@ -1,11 +1,11 @@ use std::{cell::RefCell, rc::Rc}; use crate::{ - jacobian::{find_non_zeros_nonlinear, JacobianColoring}, - Matrix, MatrixSparsity, Vector, + find_jacobian_non_zeros, jacobian::JacobianColoring, Matrix, MatrixSparsity, NonLinearOp, + NonLinearOpJacobian, Op, Vector, }; -use super::{NonLinearOp, Op, OpStatistics}; +use super::OpStatistics; pub struct Closure where @@ -46,7 +46,7 @@ where } pub fn calculate_sparsity(&mut self, y0: &M::V, t0: M::T) { - let non_zeros = find_non_zeros_nonlinear(self, y0, t0); + let non_zeros = find_jacobian_non_zeros(self, y0, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -95,6 +95,14 @@ where self.statistics.borrow_mut().increment_call(); (self.func)(x, self.p.as_ref(), t, y) } +} + +impl NonLinearOpJacobian for Closure +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) { self.statistics.borrow_mut().increment_jac_mul(); (self.jacobian_action)(x, self.p.as_ref(), t, v, y) diff --git a/src/op/closure_no_jac.rs b/src/op/closure_no_jac.rs index a75623b4..08cb94b6 100644 --- a/src/op/closure_no_jac.rs +++ b/src/op/closure_no_jac.rs @@ -1,8 +1,8 @@ use std::{cell::RefCell, rc::Rc}; -use crate::{Matrix, Vector}; +use crate::{Matrix, NonLinearOp, Op, Vector}; -use super::{NonLinearOp, Op, OpStatistics}; +use 
super::OpStatistics; pub struct ClosureNoJac where @@ -70,7 +70,4 @@ where self.statistics.borrow_mut().increment_call(); (self.func)(x, self.p.as_ref(), t, y) } - fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, _y: &mut Self::V) { - unimplemented!() - } } diff --git a/src/op/closure_with_adjoint.rs b/src/op/closure_with_adjoint.rs new file mode 100644 index 00000000..e93c811e --- /dev/null +++ b/src/op/closure_with_adjoint.rs @@ -0,0 +1,218 @@ +use std::{cell::RefCell, rc::Rc}; + +use crate::{ + jacobian::{ + find_adjoint_non_zeros, find_jacobian_non_zeros, find_sens_adjoint_non_zeros, + JacobianColoring, + }, + Matrix, MatrixSparsity, NonLinearOp, NonLinearOpAdjoint, NonLinearOpJacobian, + NonLinearOpSensAdjoint, Op, Vector, +}; + +use super::OpStatistics; + +pub struct ClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + func: F, + jacobian_action: G, + jacobian_adjoint_action: H, + sens_adjoint_action: I, + nstates: usize, + nout: usize, + nparams: usize, + p: Rc, + coloring: Option>, + sparsity: Option, + sparsity_adjoint: Option, + coloring_adjoint: Option>, + sens_sparsity: Option, + coloring_sens_adjoint: Option>, + statistics: RefCell, +} + +impl ClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + pub fn new( + func: F, + jacobian_action: G, + jacobian_adjoint_action: H, + sens_adjoint_action: I, + nstates: usize, + nout: usize, + p: Rc, + ) -> Self { + let nparams = p.len(); + Self { + func, + jacobian_action, + jacobian_adjoint_action, + sens_adjoint_action, + nstates, + nout, + nparams, + p, + statistics: RefCell::new(OpStatistics::default()), + coloring: None, + sparsity: None, + sparsity_adjoint: None, + coloring_adjoint: None, + sens_sparsity: None, + coloring_sens_adjoint: None, + } + } + + pub fn calculate_jacobian_sparsity(&mut self, y0: &M::V, t0: M::T) { + let non_zeros = find_jacobian_non_zeros(self, y0, t0); + self.sparsity = Some( + MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) + .expect("invalid sparsity pattern"), + ); + self.coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + } + + pub fn calculate_adjoint_sparsity(&mut self, y0: &M::V, t0: M::T) { + let non_zeros = find_adjoint_non_zeros(self, y0, t0); + self.sparsity_adjoint = Some( + MatrixSparsity::try_from_indices(self.nstates, self.nout, non_zeros.clone()) + .expect("invalid sparsity pattern"), + ); + self.coloring_adjoint = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + } + + pub fn calculate_sens_adjoint_sparsity(&mut self, y0: &M::V, t0: M::T) { + let non_zeros = find_sens_adjoint_non_zeros(self, y0, t0); + self.sens_sparsity = Some( + MatrixSparsity::try_from_indices(self.nstates, self.nparams, non_zeros.clone()) + .expect("invalid sparsity pattern"), + ); + self.coloring_sens_adjoint = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + } +} + +impl Op for ClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + type V = M::V; + type T = M::T; + type M = M; + fn nstates(&self) -> usize { + 
self.nstates + } + fn nout(&self) -> usize { + self.nout + } + fn nparams(&self) -> usize { + self.nparams + } + fn set_params(&mut self, p: Rc) { + assert_eq!(p.len(), self.nparams); + self.p = p; + } + fn sparsity(&self) -> Option<::SparsityRef<'_>> { + self.sparsity.as_ref().map(|s| s.as_ref()) + } + fn sparsity_adjoint(&self) -> Option<::SparsityRef<'_>> { + self.sparsity_adjoint.as_ref().map(|s| s.as_ref()) + } + fn sparsity_sens_adjoint(&self) -> Option<::SparsityRef<'_>> { + self.sens_sparsity.as_ref().map(|s| s.as_ref()) + } + fn statistics(&self) -> OpStatistics { + self.statistics.borrow().clone() + } +} + +impl NonLinearOp for ClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) { + self.statistics.borrow_mut().increment_call(); + (self.func)(x, self.p.as_ref(), t, y) + } +} + +impl NonLinearOpJacobian for ClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) { + self.statistics.borrow_mut().increment_jac_mul(); + (self.jacobian_action)(x, self.p.as_ref(), t, v, y) + } + fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { + self.statistics.borrow_mut().increment_matrix(); + if let Some(coloring) = self.coloring.as_ref() { + coloring.jacobian_inplace(self, x, t, y); + } else { + self._default_jacobian_inplace(x, t, y); + } + } +} + +impl NonLinearOpAdjoint for ClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { + self.statistics.borrow_mut().increment_jac_adj_mul(); + (self.jacobian_adjoint_action)(x, self.p.as_ref(), t, v, y); + } + + fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { + if let Some(coloring) = self.coloring_adjoint.as_ref() { + coloring.adjoint_inplace(self, x, t, y); + } else { + self._default_adjoint_inplace(x, t, y); + } + } +} + +impl NonLinearOpSensAdjoint for ClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + fn sens_transpose_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) { + (self.sens_adjoint_action)(_x, self.p.as_ref(), _t, _v, y); + } + fn sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { + if let Some(coloring) = self.coloring_sens_adjoint.as_ref() { + coloring.sens_adjoint_inplace(self, x, t, y); + } else { + self._default_sens_adjoint_inplace(x, t, y); + } + } +} diff --git a/src/op/closure_with_sens.rs b/src/op/closure_with_sens.rs index 2b4389e0..acd79cb3 100644 --- a/src/op/closure_with_sens.rs +++ b/src/op/closure_with_sens.rs @@ -1,11 +1,11 @@ use std::{cell::RefCell, rc::Rc}; use crate::{ - jacobian::{find_non_zeros_nonlinear, JacobianColoring}, - Matrix, MatrixSparsity, Vector, + jacobian::{find_jacobian_non_zeros, find_sens_non_zeros, JacobianColoring}, + 
Matrix, MatrixSparsity, NonLinearOp, NonLinearOpJacobian, NonLinearOpSens, Op, Vector, }; -use super::{NonLinearOp, Op, OpStatistics}; +use super::OpStatistics; pub struct ClosureWithSens where @@ -22,7 +22,9 @@ where nparams: usize, p: Rc, coloring: Option>, + sens_coloring: Option>, sparsity: Option, + sens_sparsity: Option, statistics: RefCell, } @@ -53,17 +55,27 @@ where statistics: RefCell::new(OpStatistics::default()), coloring: None, sparsity: None, + sens_coloring: None, + sens_sparsity: None, } } - pub fn calculate_sparsity(&mut self, y0: &M::V, t0: M::T) { - let non_zeros = find_non_zeros_nonlinear(self, y0, t0); + pub fn calculate_jacobian_sparsity(&mut self, y0: &M::V, t0: M::T) { + let non_zeros = find_jacobian_non_zeros(self, y0, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), ); self.coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); } + pub fn calculate_sens_sparsity(&mut self, y0: &M::V, t0: M::T) { + let non_zeros = find_sens_non_zeros(self, y0, t0); + self.sens_sparsity = Some( + MatrixSparsity::try_from_indices(self.nout(), self.nparams, non_zeros.clone()) + .expect("invalid sparsity pattern"), + ); + self.sens_coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + } } impl Op for ClosureWithSens @@ -92,6 +104,9 @@ where fn sparsity(&self) -> Option<::SparsityRef<'_>> { self.sparsity.as_ref().map(|x| x.as_ref()) } + fn sparsity_sens(&self) -> Option<::SparsityRef<'_>> { + self.sens_sparsity.as_ref().map(|x| x.as_ref()) + } fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } @@ -108,16 +123,19 @@ where self.statistics.borrow_mut().increment_call(); (self.func)(x, self.p.as_ref(), t, y) } +} + +impl NonLinearOpJacobian for ClosureWithSens +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) { self.statistics.borrow_mut().increment_jac_mul(); (self.jacobian_action)(x, self.p.as_ref(), t, v, y) } - fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { - (self.sens_action)(x, self.p.as_ref(), t, v, y); - } - fn has_sens(&self) -> bool { - true - } fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { self.statistics.borrow_mut().increment_matrix(); if let Some(coloring) = self.coloring.as_ref() { @@ -127,3 +145,23 @@ where } } } + +impl NonLinearOpSens for ClosureWithSens +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), +{ + fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { + (self.sens_action)(x, self.p.as_ref(), t, v, y); + } + + fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { + if let Some(coloring) = self.sens_coloring.as_ref() { + coloring.jacobian_inplace(self, x, t, y); + } else { + self._default_sens_inplace(x, t, y); + } + } +} diff --git a/src/op/constant_closure.rs b/src/op/constant_closure.rs index 1e426c44..2f0dd43a 100644 --- a/src/op/constant_closure.rs +++ b/src/op/constant_closure.rs @@ -1,9 +1,7 @@ use num_traits::Zero; use std::rc::Rc; -use crate::{Matrix, Vector}; - -use super::{ConstantOp, Op}; +use crate::{ConstantOp, Matrix, Op, Vector}; pub struct ConstantClosure where diff --git a/src/op/constant_closure_with_adjoint.rs 
b/src/op/constant_closure_with_adjoint.rs new file mode 100644 index 00000000..0369ad37 --- /dev/null +++ b/src/op/constant_closure_with_adjoint.rs @@ -0,0 +1,89 @@ +use num_traits::Zero; +use std::rc::Rc; + +use crate::{ConstantOp, ConstantOpSensAdjoint, Matrix, Op, Vector}; + +pub struct ConstantClosureWithAdjoint +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, + J: Fn(&M::V, M::T, &M::V, &mut M::V), +{ + func: I, + func_sens_adjoint: J, + nstates: usize, + nout: usize, + nparams: usize, + p: Rc, +} + +impl ConstantClosureWithAdjoint +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, + J: Fn(&M::V, M::T, &M::V, &mut M::V), +{ + pub fn new(func: I, func_sens_adjoint: J, p: Rc) -> Self { + let nparams = p.len(); + let y0 = (func)(p.as_ref(), M::T::zero()); + let nstates = y0.len(); + let nout = nstates; + Self { + func, + func_sens_adjoint, + nstates, + nout, + nparams, + p, + } + } +} + +impl Op for ConstantClosureWithAdjoint +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, + J: Fn(&M::V, M::T, &M::V, &mut M::V), +{ + type V = M::V; + type T = M::T; + type M = M; + fn nstates(&self) -> usize { + self.nstates + } + fn nout(&self) -> usize { + self.nout + } + fn nparams(&self) -> usize { + self.nparams + } + fn set_params(&mut self, p: Rc) { + assert_eq!(p.len(), self.nparams); + self.p = p; + } +} + +impl ConstantOp for ConstantClosureWithAdjoint +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, + J: Fn(&M::V, M::T, &M::V, &mut M::V), +{ + fn call_inplace(&self, t: Self::T, y: &mut Self::V) { + y.copy_from(&(self.func)(self.p.as_ref(), t)); + } + fn call(&self, t: Self::T) -> Self::V { + (self.func)(self.p.as_ref(), t) + } +} + +impl ConstantOpSensAdjoint for ConstantClosureWithAdjoint +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, + J: Fn(&M::V, M::T, &M::V, &mut M::V), +{ + fn sens_mul_transpose_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) { + (self.func_sens_adjoint)(self.p.as_ref(), t, v, y); + } +} diff --git a/src/op/constant_closure_with_sens.rs b/src/op/constant_closure_with_sens.rs index ccce9f23..45fd3c76 100644 --- a/src/op/constant_closure_with_sens.rs +++ b/src/op/constant_closure_with_sens.rs @@ -1,8 +1,6 @@ use std::rc::Rc; -use crate::{Matrix, Vector}; - -use super::{ConstantOp, Op}; +use crate::{ConstantOp, ConstantOpSens, Matrix, Op, Vector}; pub struct ConstantClosureWithSens where @@ -73,10 +71,15 @@ where fn call(&self, t: Self::T) -> Self::V { (self.func)(self.p.as_ref(), t) } +} + +impl ConstantOpSens for ConstantClosureWithSens +where + M: Matrix, + I: Fn(&M::V, M::T) -> M::V, + J: Fn(&M::V, M::T, &M::V, &mut M::V), +{ fn sens_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) { (self.func_sens)(self.p.as_ref(), t, v, y); } - fn has_sens(&self) -> bool { - true - } } diff --git a/src/op/constant_op.rs b/src/op/constant_op.rs new file mode 100644 index 00000000..e3b1513f --- /dev/null +++ b/src/op/constant_op.rs @@ -0,0 +1,54 @@ +use super::Op; +use crate::{Matrix, MatrixSparsityRef, Vector}; +use num_traits::{One, Zero}; + +pub trait ConstantOp: Op { + fn call_inplace(&self, t: Self::T, y: &mut Self::V); + fn call(&self, t: Self::T) -> Self::V { + let mut y = Self::V::zeros(self.nout()); + self.call_inplace(t, &mut y); + y + } +} + +pub trait ConstantOpSens: ConstantOp { + /// Compute the product of the gradient of F wrt a parameter vector p with a given vector `J_p(x, t) * v`. + /// Note that the vector v is of size nparams() and the result is of size nstates(). 
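+    ///
+    /// A minimal sketch (illustrative only, not part of this patch): for an initial
+    /// condition `y0(p, t) = [p_0, 0]`, the product `J_p * v` is
+    ///
+    /// ```ignore
+    /// fn sens_mul_inplace(&self, _t: Self::T, v: &Self::V, y: &mut Self::V) {
+    ///     y[0] = v[0];            // d y0[0] / d p_0 = 1
+    ///     y[1] = Self::T::zero(); // y0[1] does not depend on p
+    /// }
+    /// ```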
+    fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, _y: &mut Self::V);
+
+    /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`.
+    /// `y` should have been previously initialised using the output of [`Op::sparsity`].
+    /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace],
+    /// but it can be overridden for more efficient implementations.
+    fn sens_inplace(&self, t: Self::T, y: &mut Self::M) {
+        self._default_sens_inplace(t, y);
+    }
+
+    /// Default implementation of the gradient computation (this is the default for [Self::sens_inplace]).
+    fn _default_sens_inplace(&self, t: Self::T, y: &mut Self::M) {
+        let mut v = Self::V::zeros(self.nparams());
+        let mut col = Self::V::zeros(self.nout());
+        for j in 0..self.nparams() {
+            v[j] = Self::T::one();
+            self.sens_mul_inplace(t, &v, &mut col);
+            y.set_column(j, &col);
+            v[j] = Self::T::zero();
+        }
+    }
+
+    /// Compute the gradient of the operator wrt a parameter vector p and return it.
+    /// See [Self::sens_inplace] for a non-allocating version.
+    fn sens(&self, t: Self::T) -> Self::M {
+        let n = self.nstates();
+        let m = self.nparams();
+        let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned()));
+        self.sens_inplace(t, &mut y);
+        y
+    }
+}
+
+pub trait ConstantOpSensAdjoint: ConstantOp {
+    /// Compute the product of the negative transpose of the gradient of F wrt a parameter vector p with a given vector `-J_p^T(t) * v`.
+    /// Note that the vector v is of size nstates() and the result is of size nparams().
+    fn sens_mul_transpose_inplace(&self, _t: Self::T, _v: &Self::V, _y: &mut Self::V);
+}
diff --git a/src/op/init.rs b/src/op/init.rs
index 8065b47e..0cf83b61 100644
--- a/src/op/init.rs
+++ b/src/op/init.rs
@@ -1,4 +1,6 @@
-use crate::{ode_solver::equations::OdeEquations, scale, LinearOp, Matrix, Vector, VectorIndex};
+use crate::{
+    scale, LinearOp, Matrix, NonLinearOpJacobian, OdeEquationsImplicit, Vector, VectorIndex,
+};
 use num_traits::{One, Zero};
 use std::{cell::RefCell, rc::Rc};
 
@@ -8,7 +10,7 @@ use super::{NonLinearOp, Op};
 ///
 /// We calculate consistent initial conditions following the approach of
 /// Brown, P. N., Hindmarsh, A. C., & Petzold, L. R. (1998). Consistent initial condition calculation for differential-algebraic systems. SIAM Journal on Scientific Computing, 19(5), 1495-1512.
-pub struct InitOp { +pub struct InitOp { eqn: Rc, jac: Eqn::M, pub y0: RefCell, @@ -16,7 +18,7 @@ pub struct InitOp { neg_mass: Eqn::M, } -impl InitOp { +impl InitOp { pub fn new(eqn: &Rc, t0: Eqn::T, y0: &Eqn::V) -> Self { let eqn = eqn.clone(); let n = eqn.rhs().nstates(); @@ -68,7 +70,7 @@ impl InitOp { } } -impl Op for InitOp { +impl Op for InitOp { type V = Eqn::V; type T = Eqn::T; type M = Eqn::M; @@ -86,7 +88,7 @@ impl Op for InitOp { } } -impl NonLinearOp for InitOp { +impl NonLinearOp for InitOp { // -M_u du + f(u, v) // g(t, u, v) fn call_inplace(&self, x: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) { @@ -101,7 +103,9 @@ impl NonLinearOp for InitOp { // y = -M x + y self.neg_mass.gemv(Eqn::T::one(), x, Eqn::T::one(), y); } +} +impl NonLinearOpJacobian for InitOp { // J v fn jac_mul_inplace(&self, _x: &Eqn::V, _t: Eqn::T, v: &Eqn::V, y: &mut Eqn::V) { self.jac.gemv(Eqn::T::one(), v, Eqn::T::one(), y); @@ -118,8 +122,8 @@ mod tests { use crate::ode_solver::test_models::exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem; use crate::op::init::InitOp; - use crate::op::NonLinearOp; use crate::vector::Vector; + use crate::{NonLinearOp, NonLinearOpJacobian}; type Mcpu = nalgebra::DMatrix; type Vcpu = nalgebra::DVector; diff --git a/src/op/linear_closure.rs b/src/op/linear_closure.rs index 97089117..6b46bcf6 100644 --- a/src/op/linear_closure.rs +++ b/src/op/linear_closure.rs @@ -1,12 +1,11 @@ use std::{cell::RefCell, rc::Rc}; use crate::{ - jacobian::{find_non_zeros_linear, JacobianColoring}, - matrix::sparsity::MatrixSparsity, - Matrix, Vector, + find_matrix_non_zeros, jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity, LinearOp, + Matrix, Op, Vector, }; -use super::{LinearOp, Op, OpStatistics}; +use super::OpStatistics; pub struct LinearClosure where @@ -43,7 +42,7 @@ where } pub fn calculate_sparsity(&mut self, t0: M::T) { - let non_zeros = find_non_zeros_linear(self, t0); + let non_zeros = find_matrix_non_zeros(self, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), @@ -91,6 +90,7 @@ where self.statistics.borrow_mut().increment_call(); (self.func)(x, self.p.as_ref(), t, beta, y) } + fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { self.statistics.borrow_mut().increment_matrix(); if let Some(coloring) = &self.coloring { diff --git a/src/op/linear_closure_with_sens.rs b/src/op/linear_closure_with_adjoint.rs similarity index 51% rename from src/op/linear_closure_with_sens.rs rename to src/op/linear_closure_with_adjoint.rs index 0d9f92dc..594fdbd7 100644 --- a/src/op/linear_closure_with_sens.rs +++ b/src/op/linear_closure_with_adjoint.rs @@ -1,41 +1,42 @@ use std::{cell::RefCell, rc::Rc}; use crate::{ - jacobian::{find_non_zeros_linear, JacobianColoring}, - matrix::sparsity::MatrixSparsity, - Matrix, Vector, + find_matrix_non_zeros, find_transpose_non_zeros, jacobian::JacobianColoring, + matrix::sparsity::MatrixSparsity, LinearOp, LinearOpTranspose, Matrix, Op, Vector, }; -use super::{LinearOp, Op, OpStatistics}; +use super::OpStatistics; -pub struct LinearClosureWithSens +pub struct LinearClosureWithAdjoint where M: Matrix, F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), - H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { func: F, - func_sens: H, + func_adjoint: G, nstates: usize, nout: usize, nparams: usize, p: Rc, coloring: Option>, sparsity: Option, + coloring_adjoint: Option>, + sparsity_adjoint: Option, statistics: 
RefCell, } -impl LinearClosureWithSens +impl LinearClosureWithAdjoint where M: Matrix, F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), - H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { - pub fn new(func: F, func_sens: H, nstates: usize, nout: usize, p: Rc) -> Self { + pub fn new(func: F, func_adjoint: G, nstates: usize, nout: usize, p: Rc) -> Self { let nparams = p.len(); Self { func, - func_sens, + func_adjoint, nstates, statistics: RefCell::new(OpStatistics::default()), nout, @@ -43,24 +44,34 @@ where p, coloring: None, sparsity: None, + coloring_adjoint: None, + sparsity_adjoint: None, } } pub fn calculate_sparsity(&mut self, t0: M::T) { - let non_zeros = find_non_zeros_linear(self, t0); + let non_zeros = find_matrix_non_zeros(self, t0); self.sparsity = Some( MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), ); self.coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); } + pub fn calculate_adjoint_sparsity(&mut self, t0: M::T) { + let non_zeros = find_transpose_non_zeros(self, t0); + self.sparsity_adjoint = Some( + MatrixSparsity::try_from_indices(self.nstates, self.nout, non_zeros.clone()) + .expect("invalid sparsity pattern"), + ); + self.coloring_adjoint = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + } } -impl Op for LinearClosureWithSens +impl Op for LinearClosureWithAdjoint where M: Matrix, F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), - H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { type V = M::V; type T = M::T; @@ -82,21 +93,25 @@ where fn sparsity(&self) -> Option<::SparsityRef<'_>> { self.sparsity.as_ref().map(|s| s.as_ref()) } + fn sparsity_adjoint(&self) -> Option<::SparsityRef<'_>> { + self.sparsity_adjoint.as_ref().map(|s| s.as_ref()) + } fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } } -impl LinearOp for LinearClosureWithSens +impl LinearOp for LinearClosureWithAdjoint where M: Matrix, F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), - H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), { fn gemv_inplace(&self, x: &M::V, t: M::T, beta: M::T, y: &mut M::V) { self.statistics.borrow_mut().increment_call(); (self.func)(x, self.p.as_ref(), t, beta, y) } + fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { self.statistics.borrow_mut().increment_matrix(); if let Some(coloring) = &self.coloring { @@ -105,10 +120,22 @@ where self._default_matrix_inplace(t, y); } } - fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { - (self.func_sens)(self.p.as_ref(), x, t, v, y) +} + +impl LinearOpTranspose for LinearClosureWithAdjoint +where + M: Matrix, + F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V), +{ + fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) { + (self.func_adjoint)(x, self.p.as_ref(), t, beta, y) } - fn has_sens(&self) -> bool { - true + fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) { + if let Some(coloring) = &self.coloring_adjoint { + coloring.matrix_inplace(self, t, y); + } else { + self._default_transpose_inplace(t, y); + } } } diff --git a/src/op/linear_op.rs b/src/op/linear_op.rs new file mode 100644 index 00000000..43cd2709 --- /dev/null +++ b/src/op/linear_op.rs @@ -0,0 +1,124 @@ +use super::Op; +use crate::{Matrix, MatrixSparsityRef, Vector}; +use num_traits::{One, Zero}; + +/// LinearOp is a trait 
for linear operators (i.e. they only depend linearly on the input `x`), see [crate::NonLinearOp] for a non-linear op.
+///
+/// An example of a linear operator is a matrix-vector product `y = A(t) * x`, where `A(t)` is a matrix.
+/// It extends the [Op] trait with methods for calling the operator via a GEMV-like operation (i.e. `y = A(t) * x + beta * y`), and for computing the matrix representation of the operator.
+pub trait LinearOp: Op {
+    /// Compute the operator `y = A(t) * x` at a given state and time, the default implementation uses [Self::gemv_inplace].
+    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
+        let beta = Self::T::zero();
+        self.gemv_inplace(x, t, beta, y);
+    }
+
+    /// Compute the operator via a GEMV operation (i.e. `y = A(t) * x + beta * y`)
+    fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V);
+
+    /// Compute the matrix representation of the operator `A(t)` and return it.
+    /// See [Self::matrix_inplace] for a non-allocating version.
+    fn matrix(&self, t: Self::T) -> Self::M {
+        let mut y = Self::M::new_from_sparsity(
+            self.nstates(),
+            self.nstates(),
+            self.sparsity().map(|s| s.to_owned()),
+        );
+        self.matrix_inplace(t, &mut y);
+        y
+    }
+
+    /// Compute the matrix representation of the operator `A(t)` and store it in the matrix `y`.
+    /// The default implementation of this method computes the matrix using [Self::gemv_inplace],
+    /// but it can be overridden for more efficient implementations.
+    fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
+        self._default_matrix_inplace(t, y);
+    }
+
+    /// Default implementation of the matrix computation, see [Self::matrix_inplace].
+    fn _default_matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
+        let mut v = Self::V::zeros(self.nstates());
+        let mut col = Self::V::zeros(self.nout());
+        for j in 0..self.nstates() {
+            v[j] = Self::T::one();
+            self.call_inplace(&v, t, &mut col);
+            y.set_column(j, &col);
+            v[j] = Self::T::zero();
+        }
+    }
+}
+
+pub trait LinearOpTranspose: LinearOp {
+    /// Compute the transpose of the operator via a GEMV operation (i.e. `y = A(t)^T * x + beta * y`)
+    fn gemv_transpose_inplace(&self, _x: &Self::V, _t: Self::T, _beta: Self::T, _y: &mut Self::V);
+
+    /// Compute the transpose of the operator `y = A(t)^T * x` at a given state and time, the default implementation uses [Self::gemv_transpose_inplace].
+    fn call_transpose_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
+        let beta = Self::T::zero();
+        self.gemv_transpose_inplace(x, t, beta, y);
+    }
+
+    /// Compute the matrix representation of the transpose of the operator `A(t)^T` and store it in the matrix `y`.
+    /// The default implementation of this method computes the matrix using [Self::gemv_transpose_inplace],
+    /// but it can be overridden for more efficient implementations.
+    fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
+        self._default_transpose_inplace(t, y);
+    }
+
+    /// Default implementation of the transpose computation, see [Self::transpose_inplace].
+    fn _default_transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
+        let mut v = Self::V::zeros(self.nstates());
+        let mut col = Self::V::zeros(self.nout());
+        for j in 0..self.nstates() {
+            v[j] = Self::T::one();
+            self.call_transpose_inplace(&v, t, &mut col);
+            y.set_column(j, &col);
+            v[j] = Self::T::zero();
+        }
+    }
+}
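+
+// An illustrative sketch (hypothetical, not part of this patch): for an operator
+// whose matrix A(t) is symmetric, `LinearOpTranspose` can delegate straight to the
+// forward GEMV, since A(t)^T == A(t):
+//
+//     fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
+//         self.gemv_inplace(x, t, beta, y); // transpose of a symmetric matrix is itself
+//     }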
+
+pub trait LinearOpSens: LinearOp {
+    /// Compute the product of the gradient of F wrt a parameter vector p with a given vector `J_p(x, t) * v`.
+    /// Note that the vector v is of size nparams() and the result is of size nstates().
+    fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, _y: &mut Self::V);
+
+    /// Compute the product of the partial gradient of F wrt a parameter vector p with a given vector `\partial F/\partial p(x, t) * v`, and return the result.
+    /// Use `[Self::sens_mul_inplace]` for a non-allocating version.
+    fn sens_mul(&self, x: &Self::V, t: Self::T, v: &Self::V) -> Self::V {
+        let mut y = Self::V::zeros(self.nstates());
+        self.sens_mul_inplace(x, t, v, &mut y);
+        y
+    }
+
+    /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`.
+    /// `y` should have been previously initialised using the output of [`Op::sparsity`].
+    /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace],
+    /// but it can be overridden for more efficient implementations.
+    fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        self._default_sens_inplace(x, t, y);
+    }
+
+    /// Default implementation of the gradient computation (this is the default for [Self::sens_inplace]).
+    fn _default_sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        let mut v = Self::V::zeros(self.nparams());
+        let mut col = Self::V::zeros(self.nout());
+        for j in 0..self.nparams() {
+            v[j] = Self::T::one();
+            self.sens_mul_inplace(x, t, &v, &mut col);
+            y.set_column(j, &col);
+            v[j] = Self::T::zero();
+        }
+    }
+
+    /// Compute the gradient of the operator wrt a parameter vector p and return it.
+    /// See [Self::sens_inplace] for a non-allocating version.
+    fn sens(&self, x: &Self::V, t: Self::T) -> Self::M {
+        let n = self.nstates();
+        let m = self.nparams();
+        let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned()));
+        self.sens_inplace(x, t, &mut y);
+        y
+    }
+}
diff --git a/src/op/linearise.rs b/src/op/linearise.rs
index 7c2f3960..de00db5c 100644
--- a/src/op/linearise.rs
+++ b/src/op/linearise.rs
@@ -1,18 +1,18 @@
 use num_traits::One;
 use std::{cell::RefCell, rc::Rc};
 
-use crate::{Matrix, Vector};
+use crate::{LinearOp, Matrix, Op, Vector};
 
-use super::{LinearOp, NonLinearOp, Op};
+use super::nonlinear_op::NonLinearOpJacobian;
 
-pub struct LinearisedOp<C: NonLinearOp> {
+pub struct LinearisedOp<C: NonLinearOpJacobian> {
     callable: Rc<C>,
     x: C::V,
     tmp: RefCell<C::V>,
     x_is_set: bool,
 }
 
-impl<C: NonLinearOp> LinearisedOp<C> {
+impl<C: NonLinearOpJacobian> LinearisedOp<C> {
     pub fn new(callable: Rc<C>) -> Self {
         let x = C::V::zeros(callable.nstates());
         let tmp = RefCell::new(C::V::zeros(callable.nstates()));
@@ -38,7 +38,7 @@ impl<C: NonLinearOp> LinearisedOp<C> {
     }
 }
 
-impl<C: NonLinearOp> Op for LinearisedOp<C> {
+impl<C: NonLinearOpJacobian> Op for LinearisedOp<C> {
     type V = C::V;
     type T = C::T;
     type M = C::M;
@@ -56,7 +56,7 @@ impl<C: NonLinearOp> Op for LinearisedOp<C> {
     }
 }
 
-impl<C: NonLinearOp> LinearOp for LinearisedOp<C> {
+impl<C: NonLinearOpJacobian> LinearOp for LinearisedOp<C> {
     fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
         self.callable.jac_mul_inplace(&self.x, t, x, y);
     }
diff --git a/src/op/matrix.rs b/src/op/matrix.rs
index 31736a93..92ea40a1 100644
--- a/src/op/matrix.rs
+++ b/src/op/matrix.rs
@@ -1,6 +1,4 @@
-use crate::matrix::Matrix;
-
-use super::{LinearOp, Op};
+use crate::{LinearOp, Matrix, Op};
 
 pub struct MatrixOp<M: Matrix> {
     m: M,
diff --git a/src/op/mod.rs b/src/op/mod.rs
index 7889297e..0c377193 100644
--- a/src/op/mod.rs
+++ b/src/op/mod.rs
@@ -1,21 +1,26 @@
 use std::rc::Rc;
 
-use crate::{Matrix, MatrixSparsityRef, Scalar, Vector};
+use crate::{LinearOp, Matrix, NonLinearOp, Scalar, Vector};
 
-use num_traits::{One, Zero};
+use nonlinear_op::NonLinearOpJacobian;
 use serde::Serialize;
 
 pub mod bdf;
 pub mod closure;
 pub mod closure_no_jac;
+pub mod closure_with_adjoint;
 pub mod closure_with_sens;
 pub mod constant_closure;
+pub mod constant_closure_with_adjoint;
 pub mod constant_closure_with_sens;
+pub mod constant_op;
 pub mod init;
 pub mod linear_closure;
-pub mod linear_closure_with_sens;
+pub mod linear_closure_with_adjoint;
+pub mod linear_op;
 pub mod linearise;
 pub mod matrix;
+pub mod nonlinear_op;
 pub mod sdirk;
 pub mod unit;
 
@@ -50,11 +55,21 @@ pub trait Op {
         None
     }
 
+    /// Return sparsity information for the transpose (adjoint) of the jacobian or matrix (if available)
+    fn sparsity_adjoint(&self) -> Option<<Self::M as Matrix>::SparsityRef<'_>> {
+        None
+    }
+
     /// Return sparsity information for the sensitivity of the operator wrt a parameter vector p (if available)
     fn sparsity_sens(&self) -> Option<<Self::M as Matrix>::SparsityRef<'_>> {
        None
    }
 
+    /// Return sparsity information for the transpose (adjoint) of the sensitivity of the operator wrt a parameter vector p (if available)
+    fn sparsity_sens_adjoint(&self) -> Option<<Self::M as Matrix>::SparsityRef<'_>> {
+        None
+    }
+
     /// Return statistics about the operator (e.g. how many times it was called, how many times the jacobian was computed, etc.)
     fn statistics(&self) -> OpStatistics {
         OpStatistics::default()
@@ -66,6 +81,7 @@ pub struct OpStatistics {
     pub number_of_calls: usize,
     pub number_of_jac_muls: usize,
     pub number_of_matrix_evals: usize,
+    pub number_of_jac_adj_muls: usize,
 }
 
 impl OpStatistics {
@@ -74,6 +90,7 @@ impl OpStatistics {
             number_of_jac_muls: 0,
             number_of_calls: 0,
             number_of_matrix_evals: 0,
+            number_of_jac_adj_muls: 0,
         }
     }
 
@@ -85,269 +102,12 @@ impl OpStatistics {
         self.number_of_jac_muls += 1;
     }
 
-    pub fn increment_matrix(&mut self) {
-        self.number_of_matrix_evals += 1;
-    }
-}
-
-// NonLinearOp is a trait that defines a nonlinear operator or function `F` that maps an input vector `x` to an output vector `y`, (i.e. `y = F(x, t)`).
-// It extends the [Op] trait with methods for computing the operator and its Jacobian.
-//
-// The operator is defined by the [Self::call_inplace] method, which computes the function `F(x, t)` at a given state and time.
-// The Jacobian is defined by the [Self::jac_mul_inplace] method, which computes the product of the Jacobian with a given vector `J(x, t) * v`.
-pub trait NonLinearOp: Op {
-    /// Compute the operator `F(x, t)` at a given state and time.
-    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V);
-
-    /// Compute the product of the Jacobian with a given vector `J(x, t) * v`.
-    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V);
-
-    /// Compute the product of the gradient of F wrt a parameter vector p with a given vector `J_p(x, t) * v`.
-    /// Note that the vector v is of size nparams() and the result is of size nstates().
-    /// Default implementation returns zero and panics if nparams() is not zero.
-    fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
-        if self.nparams() != 0 {
-            panic!("sens_mul_inplace not implemented for non-zero parameters");
-        }
-        y.fill(Self::T::zero());
-    }
-
-    fn has_sens(&self) -> bool {
-        false
-    }
-
-    /// Compute the operator `F(x, t)` at a given state and time, and return the result.
-    /// Use `[Self::call_inplace]` to for a non-allocating version.
-    fn call(&self, x: &Self::V, t: Self::T) -> Self::V {
-        let mut y = Self::V::zeros(self.nout());
-        self.call_inplace(x, t, &mut y);
-        y
-    }
-
-    /// Compute the product of the Jacobian with a given vector `J(x, t) * v`, and return the result.
- /// Use `[Self::jac_mul_inplace]` to for a non-allocating version. - fn jac_mul(&self, x: &Self::V, t: Self::T, v: &Self::V) -> Self::V { - let mut y = Self::V::zeros(self.nstates()); - self.jac_mul_inplace(x, t, v, &mut y); - y - } - - /// Compute the product of the partial gradient of F wrt a parameter vector p with a given vector `\parial F/\partial p(x, t) * v`, and return the result. - /// Use `[Self::sens_mul_inplace]` to for a non-allocating version. - fn sens_mul(&self, x: &Self::V, t: Self::T, v: &Self::V) -> Self::V { - let mut y = Self::V::zeros(self.nstates()); - self.sens_mul_inplace(x, t, v, &mut y); - y - } - - /// Compute the Jacobian matrix `J(x, t)` of the operator and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity`]. - /// The default implementation of this method computes the Jacobian using [Self::jac_mul_inplace], - /// but it can be overriden for more efficient implementations. - fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - self._default_jacobian_inplace(x, t, y); - } - - /// Default implementation of the Jacobian computation (this is the default for [Self::jacobian_inplace]). - fn _default_jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - let mut v = Self::V::zeros(self.nstates()); - let mut col = Self::V::zeros(self.nout()); - for j in 0..self.nstates() { - v[j] = Self::T::one(); - self.jac_mul_inplace(x, t, &v, &mut col); - y.set_column(j, &col); - v[j] = Self::T::zero(); - } - } - - /// Compute the Jacobian matrix `J(x, t)` of the operator and return it. - /// See [Self::jacobian_inplace] for a non-allocating version. - fn jacobian(&self, x: &Self::V, t: Self::T) -> Self::M { - let n = self.nstates(); - let mut y = Self::M::new_from_sparsity(n, n, self.sparsity().map(|s| s.to_owned())); - self.jacobian_inplace(x, t, &mut y); - y + pub fn increment_jac_adj_mul(&mut self) { + self.number_of_jac_adj_muls += 1; } - /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity`]. - /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace], - /// but it can be overriden for more efficient implementations. - fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - self._default_sens_inplace(x, t, y); - } - - /// Default implementation of the gradient computation (this is the default for [Self::sens_inplace]). - fn _default_sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - let mut v = Self::V::zeros(self.nparams()); - let mut col = Self::V::zeros(self.nout()); - for j in 0..self.nparams() { - v[j] = Self::T::one(); - self.sens_mul_inplace(x, t, &v, &mut col); - y.set_column(j, &col); - v[j] = Self::T::zero(); - } - } - - /// Compute the gradient of the operator wrt a parameter vector p and return it. - /// See [Self::sens_inplace] for a non-allocating version. - fn sens(&self, x: &Self::V, t: Self::T) -> Self::M { - let n = self.nstates(); - let m = self.nparams(); - let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned())); - self.sens_inplace(x, t, &mut y); - y - } -} - -/// LinearOp is a trait for linear operators (i.e. they only depend linearly on the input `x`), see [NonLinearOp] for a non-linear op. -/// -/// An example of a linear operator is a matrix-vector product `y = A(t) * x`, where `A(t)` is a matrix. 
-/// It extends the [Op] trait with methods for calling the operator via a GEMV-like operation (i.e. `y = t * A * x + beta * y`), and for computing the matrix representation of the operator. -pub trait LinearOp: Op { - /// Compute the operator `y = A(t) * x` at a given state and time, the default implementation uses [Self::gemv_inplace]. - fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) { - let beta = Self::T::zero(); - self.gemv_inplace(x, t, beta, y); - } - - fn has_sens(&self) -> bool { - false - } - - /// Compute the operator via a GEMV operation (i.e. `y = A(t) * x + beta * y`) - fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V); - - /// Compute the product of the gradient of F wrt a parameter vector p with a given vector `J_p(t) * x * v`. - /// Note that the vector v is of size nparams() and the result is of size nstates(). - /// Default implementation returns zero and panics if nparams() is not zero. - fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) { - if self.nparams() != 0 { - panic!("sens_mul_inplace not implemented for non-zero parameters"); - } - y.fill(Self::T::zero()); - } - - /// Compute the product of the partial gradient of F wrt a parameter vector p with a given vector `\parial F/\partial p(x, t) * v`, and return the result. - /// Use `[Self::sens_mul_inplace]` to for a non-allocating version. - fn sens_mul(&self, x: &Self::V, t: Self::T, v: &Self::V) -> Self::V { - let mut y = Self::V::zeros(self.nstates()); - self.sens_mul_inplace(x, t, v, &mut y); - y - } - - /// Compute the matrix representation of the operator `A(t)` and return it. - /// See [Self::matrix_inplace] for a non-allocating version. - fn matrix(&self, t: Self::T) -> Self::M { - let mut y = Self::M::new_from_sparsity( - self.nstates(), - self.nstates(), - self.sparsity().map(|s| s.to_owned()), - ); - self.matrix_inplace(t, &mut y); - y - } - - /// Compute the matrix representation of the operator `A(t)` and store it in the matrix `y`. - /// The default implementation of this method computes the matrix using [Self::gemv_inplace], - /// but it can be overriden for more efficient implementations. - fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { - self._default_matrix_inplace(t, y); - } - - /// Default implementation of the matrix computation, see [Self::matrix_inplace]. - fn _default_matrix_inplace(&self, t: Self::T, y: &mut Self::M) { - let mut v = Self::V::zeros(self.nstates()); - let mut col = Self::V::zeros(self.nout()); - for j in 0..self.nstates() { - v[j] = Self::T::one(); - self.call_inplace(&v, t, &mut col); - y.set_column(j, &col); - v[j] = Self::T::zero(); - } - } - - /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity`]. - /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace], - /// but it can be overriden for more efficient implementations. - fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - self._default_sens_inplace(x, t, y); - } - - /// Default implementation of the gradient computation (this is the default for [Self::sens_inplace]). 
-    fn _default_sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
-        let mut v = Self::V::zeros(self.nparams());
-        let mut col = Self::V::zeros(self.nout());
-        for j in 0..self.nparams() {
-            v[j] = Self::T::one();
-            self.sens_mul_inplace(x, t, &v, &mut col);
-            y.set_column(j, &col);
-            v[j] = Self::T::zero();
-        }
-    }
-
-    /// Compute the gradient of the operator wrt a parameter vector p and return it.
-    /// See [Self::sens_inplace] for a non-allocating version.
-    fn sens(&self, x: &Self::V, t: Self::T) -> Self::M {
-        let n = self.nstates();
-        let m = self.nparams();
-        let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned()));
-        self.sens_inplace(x, t, &mut y);
-        y
-    }
-}
-
-pub trait ConstantOp: Op {
-    fn call_inplace(&self, t: Self::T, y: &mut Self::V);
-    fn call(&self, t: Self::T) -> Self::V {
-        let mut y = Self::V::zeros(self.nout());
-        self.call_inplace(t, &mut y);
-        y
-    }
-
-    fn has_sens(&self) -> bool {
-        false
-    }
-
-    /// Compute the product of the gradient of F wrt a parameter vector p with a given vector `J_p(x, t) * v`.
-    /// Note that the vector v is of size nparams() and the result is of size nstates().
-    /// Default implementation returns zero and panics if nparams() is not zero.
-    fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
-        if self.nparams() != 0 {
-            panic!("sens_mul_inplace not implemented for non-zero parameters");
-        }
-        y.fill(Self::T::zero());
-    }
-
-    /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`.
-    /// `y` should have been previously initialised using the output of [`Op::sparsity`].
-    /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace],
-    /// but it can be overriden for more efficient implementations.
-    fn sens_inplace(&self, t: Self::T, y: &mut Self::M) {
-        self._default_sens_inplace(t, y);
-    }
-
-    /// Default implementation of the gradient computation (this is the default for [Self::sens_inplace]).
-    fn _default_sens_inplace(&self, t: Self::T, y: &mut Self::M) {
-        let mut v = Self::V::zeros(self.nparams());
-        let mut col = Self::V::zeros(self.nout());
-        for j in 0..self.nparams() {
-            v[j] = Self::T::one();
-            self.sens_mul_inplace(t, &v, &mut col);
-            y.set_column(j, &col);
-            v[j] = Self::T::zero();
-        }
-    }
-
-    /// Compute the gradient of the operator wrt a parameter vector p and return it.
-    /// See [Self::sens_inplace] for a non-allocating version.
-    fn sens(&self, t: Self::T) -> Self::M {
-        let n = self.nstates();
-        let m = self.nparams();
-        let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned()));
-        self.sens_inplace(t, &mut y);
-        y
+    pub fn increment_matrix(&mut self) {
+        self.number_of_matrix_evals += 1;
     }
 }
 
@@ -370,6 +130,9 @@ impl<C: NonLinearOp> NonLinearOp for &C {
     fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
         C::call_inplace(*self, x, t, y)
     }
+}
+
+impl<C: NonLinearOpJacobian> NonLinearOpJacobian for &C {
     fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
         C::jac_mul_inplace(*self, x, t, v, y)
     }
diff --git a/src/op/nonlinear_op.rs b/src/op/nonlinear_op.rs
new file mode 100644
index 00000000..068567bd
--- /dev/null
+++ b/src/op/nonlinear_op.rs
@@ -0,0 +1,175 @@
+use super::Op;
+use crate::{Matrix, MatrixSparsityRef, Vector};
+use num_traits::{One, Zero};
+
+// NonLinearOp is a trait that defines a nonlinear operator or function `F` that maps an input vector `x` to an output vector `y` (i.e. `y = F(x, t)`).
+// It extends the [Op] trait with methods for computing the operator and its Jacobian.
+//
+// The operator is defined by the [Self::call_inplace] method, which computes the function `F(x, t)` at a given state and time.
+// The Jacobian is defined by the [Self::jac_mul_inplace] method, which computes the product of the Jacobian with a given vector `J(x, t) * v`.
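+//
+// As an illustrative sketch only (not part of this module), a scalar logistic
+// right-hand side `F(x, t) = r * x * (1 - x)` could implement these traits
+// roughly as follows, assuming a hypothetical `Logistic` type with a suitable
+// [Op] impl (one state, one output, no parameters):
+//
+//     impl NonLinearOp for Logistic {
+//         fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
+//             // y = r * x * (1 - x)
+//             y[0] = self.r * x[0] * (Self::T::one() - x[0]);
+//         }
+//     }
+//
+//     impl NonLinearOpJacobian for Logistic {
+//         fn jac_mul_inplace(&self, x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
+//             // J(x, t) * v = r * (1 - 2 * x) * v
+//             y[0] = self.r * (Self::T::one() - x[0] - x[0]) * v[0];
+//         }
+//     }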
+pub trait NonLinearOp: Op {
+    /// Compute the operator `F(x, t)` at a given state and time.
+    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V);
+
+    /// Compute the operator `F(x, t)` at a given state and time, and return the result.
+    /// Use [Self::call_inplace] for a non-allocating version.
+    fn call(&self, x: &Self::V, t: Self::T) -> Self::V {
+        let mut y = Self::V::zeros(self.nout());
+        self.call_inplace(x, t, &mut y);
+        y
+    }
+}
+
+pub trait NonLinearOpSens: NonLinearOp {
+    /// Compute the product of the gradient of F wrt a parameter vector p with a given vector `J_p(x, t) * v`.
+    /// Note that the vector v is of size nparams() and the result is of size nstates().
+    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V);
+
+    /// Compute the product of the partial gradient of F wrt a parameter vector p with a given vector `\partial F/\partial p(x, t) * v`, and return the result.
+    /// Use [Self::sens_mul_inplace] for a non-allocating version.
+    fn sens_mul(&self, x: &Self::V, t: Self::T, v: &Self::V) -> Self::V {
+        let mut y = Self::V::zeros(self.nstates());
+        self.sens_mul_inplace(x, t, v, &mut y);
+        y
+    }
+
+    /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`.
+    /// `y` should have been previously initialised using the output of [`Op::sparsity`].
+    /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace],
+    /// but it can be overridden for more efficient implementations.
+    fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        self._default_sens_inplace(x, t, y);
+    }
+
+    /// Default implementation of the gradient computation (this is the default for [Self::sens_inplace]).
+    fn _default_sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        let mut v = Self::V::zeros(self.nparams());
+        let mut col = Self::V::zeros(self.nout());
+        for j in 0..self.nparams() {
+            v[j] = Self::T::one();
+            self.sens_mul_inplace(x, t, &v, &mut col);
+            y.set_column(j, &col);
+            v[j] = Self::T::zero();
+        }
+    }
+
+    /// Compute the gradient of the operator wrt a parameter vector p and return it.
+    /// See [Self::sens_inplace] for a non-allocating version.
+    fn sens(&self, x: &Self::V, t: Self::T) -> Self::M {
+        let n = self.nstates();
+        let m = self.nparams();
+        let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned()));
+        self.sens_inplace(x, t, &mut y);
+        y
+    }
+}
+pub trait NonLinearOpSensAdjoint: NonLinearOp {
+    /// Compute the product of the negative transpose of the gradient of F wrt a parameter vector p with a given vector `-J_p(x, t)^T * v`.
+    fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V);
+
+    /// Compute the negative transpose of the gradient of the operator wrt a parameter vector p and return it.
+    /// See [Self::sens_adjoint_inplace] for a non-allocating version.
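+    ///
+    /// # Example (sketch)
+    ///
+    /// Illustrative only: `op`, `x` and `t` are assumed to be in scope, with
+    /// `op` implementing this trait.
+    ///
+    /// ```ignore
+    /// // dense representation of -dF/dp(x, t)^T
+    /// let neg_grad_transpose = op.sens_adjoint(&x, t);
+    /// ```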
+    fn sens_adjoint(&self, x: &Self::V, t: Self::T) -> Self::M {
+        let n = self.nstates();
+        let mut y =
+            Self::M::new_from_sparsity(n, n, self.sparsity_sens_adjoint().map(|s| s.to_owned()));
+        self.sens_adjoint_inplace(x, t, &mut y);
+        y
+    }
+
+    /// Compute the negative transpose of the gradient of the operator wrt a parameter vector p and store it in the matrix `y`.
+    /// `y` should have been previously initialised using the output of [`Op::sparsity_sens_adjoint`].
+    /// The default implementation of this method computes the gradient using [Self::sens_transpose_mul_inplace],
+    /// but it can be overridden for more efficient implementations.
+    fn sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        self._default_sens_adjoint_inplace(x, t, y);
+    }
+
+    /// Default implementation of the gradient computation (this is the default for [Self::sens_adjoint_inplace]).
+    fn _default_sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        let mut v = Self::V::zeros(self.nstates());
+        let mut col = Self::V::zeros(self.nout());
+        for j in 0..self.nstates() {
+            v[j] = Self::T::one();
+            self.sens_transpose_mul_inplace(x, t, &v, &mut col);
+            y.set_column(j, &col);
+            v[j] = Self::T::zero();
+        }
+    }
+}
+pub trait NonLinearOpAdjoint: NonLinearOp {
+    /// Compute the product of the negative transpose of the Jacobian with a given vector `-J(x, t)^T * v`.
+    /// This method has no default implementation and must be provided by the implementor.
+    fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V);
+
+    /// Compute the adjoint matrix `-J^T(x, t)` of the operator and store it in the matrix `y`.
+    /// `y` should have been previously initialised using the output of [`Op::sparsity`].
+    /// The default implementation of this method computes the adjoint using [Self::jac_transpose_mul_inplace],
+    /// but it can be overridden for more efficient implementations.
+    fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        self._default_adjoint_inplace(x, t, y);
+    }
+
+    /// Default implementation of the adjoint computation (this is the default for [Self::adjoint_inplace]).
+    fn _default_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        let mut v = Self::V::zeros(self.nstates());
+        let mut col = Self::V::zeros(self.nout());
+        for j in 0..self.nstates() {
+            v[j] = Self::T::one();
+            self.jac_transpose_mul_inplace(x, t, &v, &mut col);
+            y.set_column(j, &col);
+            v[j] = Self::T::zero();
+        }
+    }
+
+    /// Compute the adjoint matrix `-J^T(x, t)` of the operator and return it.
+    /// See [Self::adjoint_inplace] for a non-allocating version.
+    fn adjoint(&self, x: &Self::V, t: Self::T) -> Self::M {
+        let n = self.nstates();
+        let mut y = Self::M::new_from_sparsity(n, n, self.sparsity_adjoint().map(|s| s.to_owned()));
+        self.adjoint_inplace(x, t, &mut y);
+        y
+    }
+}
+pub trait NonLinearOpJacobian: NonLinearOp {
+    /// Compute the product of the Jacobian with a given vector `J(x, t) * v`.
+    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V);
+
+    /// Compute the product of the Jacobian with a given vector `J(x, t) * v`, and return the result.
+    /// Use [Self::jac_mul_inplace] for a non-allocating version.
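+    ///
+    /// # Example (sketch)
+    ///
+    /// A finite-difference sanity check, illustrative only: `op`, `x`, `v`, `t`
+    /// and `eps` are assumed to be in scope, with `op` implementing this trait.
+    ///
+    /// ```ignore
+    /// // J(x, t) * v should approximate (F(x + eps * v, t) - F(x, t)) / eps
+    /// let jv = op.jac_mul(&x, t, &v);
+    /// ```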
+    fn jac_mul(&self, x: &Self::V, t: Self::T, v: &Self::V) -> Self::V {
+        let mut y = Self::V::zeros(self.nstates());
+        self.jac_mul_inplace(x, t, v, &mut y);
+        y
+    }
+
+    /// Compute the Jacobian matrix `J(x, t)` of the operator and return it.
+    /// See [Self::jacobian_inplace] for a non-allocating version.
+    fn jacobian(&self, x: &Self::V, t: Self::T) -> Self::M {
+        let n = self.nstates();
+        let mut y = Self::M::new_from_sparsity(n, n, self.sparsity().map(|s| s.to_owned()));
+        self.jacobian_inplace(x, t, &mut y);
+        y
+    }
+
+    /// Compute the Jacobian matrix `J(x, t)` of the operator and store it in the matrix `y`.
+    /// `y` should have been previously initialised using the output of [`Op::sparsity`].
+    /// The default implementation of this method computes the Jacobian using [Self::jac_mul_inplace],
+    /// but it can be overridden for more efficient implementations.
+    fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        self._default_jacobian_inplace(x, t, y);
+    }
+
+    /// Default implementation of the Jacobian computation (this is the default for [Self::jacobian_inplace]).
+    fn _default_jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
+        let mut v = Self::V::zeros(self.nstates());
+        let mut col = Self::V::zeros(self.nout());
+        for j in 0..self.nstates() {
+            v[j] = Self::T::one();
+            self.jac_mul_inplace(x, t, &v, &mut col);
+            y.set_column(j, &col);
+            v[j] = Self::T::zero();
+        }
+    }
+}
diff --git a/src/op/sdirk.rs b/src/op/sdirk.rs
index 726ddadf..5ceec27c 100644
--- a/src/op/sdirk.rs
+++ b/src/op/sdirk.rs
@@ -1,12 +1,14 @@
 use crate::{
     matrix::{MatrixRef, MatrixView},
     ode_solver::equations::OdeEquations,
-    LinearOp, Matrix, MatrixSparsity, MatrixSparsityRef, OdeSolverProblem, Vector, VectorRef,
+    scale, LinearOp, Matrix, MatrixSparsity, MatrixSparsityRef, NonLinearOpJacobian,
+    OdeEquationsImplicit, OdeSolverProblem, Vector, VectorRef,
 };
 use num_traits::{One, Zero};
 use std::{
     cell::{Ref, RefCell},
     ops::Deref,
+    ops::MulAssign,
     rc::Rc,
 };
 
@@ -27,6 +29,11 @@ pub struct SdirkCallable<Eqn: OdeEquations> {
 }
 
 impl<Eqn: OdeEquations> SdirkCallable<Eqn> {
+    // y = h g(phi + c * y_s)
+    pub fn integrate_out(&self, ys: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) {
+        self.eqn.out().unwrap().call_inplace(ys, t, y);
+        y.mul_assign(scale(*(self.h.borrow())));
+    }
     pub fn from_eqn(eqn: Rc<Eqn>, c: Eqn::T) -> Self {
         let n = eqn.rhs().nstates();
         let h = RefCell::new(Eqn::T::zero());
@@ -51,6 +58,10 @@ impl<Eqn: OdeEquations> SdirkCallable<Eqn> {
         }
     }
 
+    pub fn eqn_mut(&mut self) -> &mut Rc<Eqn> {
+        &mut self.eqn
+    }
+
     pub fn new(ode_problem: &OdeSolverProblem<Eqn>, c: Eqn::T) -> Self {
         let eqn = ode_problem.eqn.clone();
         let n = ode_problem.eqn.rhs().nstates();
@@ -193,6 +204,13 @@ where
         y.axpy(Eqn::T::one(), x, -h);
     }
 }
+}
+
+impl<Eqn: OdeEquationsImplicit> NonLinearOpJacobian for SdirkCallable<Eqn>
+where
+    for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
+    for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
+{
     // (M - c * h * f'(phi + c * y)) v
     fn jac_mul_inplace(&self, x: &Eqn::V, t: Eqn::T, v: &Eqn::V, y: &mut Eqn::V) {
         self.set_tmp(x);
@@ -245,9 +263,9 @@ mod tests {
     use crate::ode_solver::test_models::exponential_decay::exponential_decay_problem;
     use crate::ode_solver::test_models::robertson::robertson;
-    use crate::op::NonLinearOp;
     use crate::vector::Vector;
     use crate::Matrix;
+    use crate::{NonLinearOp, NonLinearOpJacobian};
 
     use super::SdirkCallable;
     type Mcpu = nalgebra::DMatrix<f64>;
diff --git a/src/op/unit.rs b/src/op/unit.rs
index 53a64dff..2056cc40 100644
--- a/src/op/unit.rs
+++ b/src/op/unit.rs
@@ -1,9 +1,10 @@
 // unit is a callable that returns returns the input vector
 
-use crate::{Matrix, Vector};
-use num_traits::One;
-
-use super::{LinearOp, NonLinearOp, Op};
+use crate::{
+    LinearOp, LinearOpSens, LinearOpTranspose, Matrix, NonLinearOp, NonLinearOpAdjoint,
+    NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Op, Vector,
+};
+use num_traits::{One, Zero};
 
 /// A dummy operator that returns the input vector. Can be used either as a [NonLinearOp] or [LinearOp].
 pub struct UnitCallable<M: Matrix> {
@@ -51,7 +52,40 @@ impl<M: Matrix> NonLinearOp for UnitCallable<M> {
     fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
         y.copy_from(x);
     }
+}
+
+impl<M: Matrix> NonLinearOpJacobian for UnitCallable<M> {
     fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
         y.copy_from(v);
     }
 }
+
+impl<M: Matrix> NonLinearOpAdjoint for UnitCallable<M> {
+    fn jac_transpose_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
+        y.copy_from(v);
+    }
+}
+
+impl<M: Matrix> NonLinearOpSens for UnitCallable<M> {
+    fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
+        y.fill(Self::T::zero());
+    }
+}
+
+impl<M: Matrix> NonLinearOpSensAdjoint for UnitCallable<M> {
+    fn sens_transpose_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
+        y.fill(Self::T::zero());
+    }
+}
+
+impl<M: Matrix> LinearOpSens for UnitCallable<M> {
+    fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
+        y.fill(Self::T::zero());
+    }
+}
+
+impl<M: Matrix> LinearOpTranspose for UnitCallable<M> {
+    fn gemv_transpose_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) {
+        y.axpy(Self::T::one(), x, beta);
+    }
+}
diff --git a/src/solver/mod.rs b/src/solver/mod.rs
index e77395c6..e8da84d2 100644
--- a/src/solver/mod.rs
+++ b/src/solver/mod.rs
@@ -1,63 +1,6 @@
-use std::rc::Rc;
-
-use crate::{
-    op::{linearise::LinearisedOp, Op},
-    IndexType, NonLinearOp, OdeEquations, OdeSolverProblem,
-};
+use crate::IndexType;
 
 pub struct SolverStatistics {
     pub niter: IndexType,
     pub nmaxiter: IndexType,
 }
-
-/// A generic linear or nonlinear solver problem, containing the function to solve $f(t, y)$, the current time $t$, and the relative and absolute tolerances.
-pub struct SolverProblem<C: Op> {
-    pub f: Rc<C>,
-    pub atol: Rc<C::V>,
-    pub rtol: C::T,
-}
-
-impl<C: Op> Clone for SolverProblem<C> {
-    fn clone(&self) -> Self {
-        Self {
-            f: self.f.clone(),
-            atol: self.atol.clone(),
-            rtol: self.rtol,
-        }
-    }
-}
-
-impl<C: Op> SolverProblem<C> {
-    pub fn new(f: Rc<C>, atol: Rc<C::V>, rtol: C::T) -> Self {
-        Self { f, rtol, atol }
-    }
-    pub fn new_from_ode_problem(
-        f: Rc<C>,
-        other: &OdeSolverProblem<impl OdeEquations<T = C::T, V = C::V, M = C::M>>,
-    ) -> Self {
-        Self {
-            f,
-            rtol: other.rtol,
-            atol: other.atol.clone(),
-        }
-    }
-    pub fn new_from_problem<C2>(f: Rc<C>, other: &SolverProblem<C2>) -> Self
-    where
-        C2: Op<M = C::M, V = C::V, T = C::T>,
-    {
-        Self {
-            f,
-            rtol: other.rtol,
-            atol: other.atol.clone(),
-        }
-    }
-}
-
-impl<C: NonLinearOp> SolverProblem<C> {
-    /// Create a new solver problem from a nonlinear operator that solves for the linearised operator.
-    /// That is, if the original function is $f(t, y)$, this function creates a new problem $f'$ that solves $f' = J(x) v$, where $J(x)$ is the Jacobian of $f$ at $x$.
-    pub fn linearise(&self) -> SolverProblem<LinearisedOp<C>> {
-        let linearised_f = Rc::new(LinearisedOp::new(self.f.clone()));
-        SolverProblem::new_from_problem(linearised_f, self)
-    }
-}
diff --git a/src/vector/faer_serial.rs b/src/vector/faer_serial.rs
index 264dfc24..1ab676d3 100644
--- a/src/vector/faer_serial.rs
+++ b/src/vector/faer_serial.rs
@@ -233,6 +233,10 @@ impl<'a, T: Scalar> VectorViewMut<'a> for ColMut<'a, T> {
     fn copy_from_view(&mut self, other: &Self::View) {
         self.copy_from(other);
     }
+    fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T) {
+        zipped!(self.as_mut(), x.as_view())
+            .for_each(|unzipped!(mut si, xi)| si.write(si.read() * beta + xi.read() * alpha));
+    }
 }
 
 // tests
diff --git a/src/vector/mod.rs b/src/vector/mod.rs
index 69e222e4..2354bc17 100644
--- a/src/vector/mod.rs
+++ b/src/vector/mod.rs
@@ -27,14 +27,14 @@ pub trait VectorCommon: Sized + Debug {
     type T: Scalar;
 }
 
-impl<'a, V> VectorCommon for &'a V
+impl<V> VectorCommon for &V
 where
     V: VectorCommon,
 {
     type T = V::T;
 }
 
-impl<'a, V> VectorCommon for &'a mut V
+impl<V> VectorCommon for &mut V
 where
     V: VectorCommon,
 {
@@ -86,6 +86,7 @@ pub trait VectorViewMut<'a>:
     type View;
     fn copy_from(&mut self, other: &Self::Owned);
     fn copy_from_view(&mut self, other: &Self::View);
+    fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T);
 }
 
 pub trait VectorView<'a>:
diff --git a/src/vector/nalgebra_serial.rs b/src/vector/nalgebra_serial.rs
index d8b45089..2b5c4402 100644
--- a/src/vector/nalgebra_serial.rs
+++ b/src/vector/nalgebra_serial.rs
@@ -115,6 +115,9 @@ impl<'a, T: Scalar> VectorViewMut<'a> for DVectorViewMut<'a, T> {
     fn copy_from_view(&mut self, other: &Self::View) {
         self.copy_from(other);
     }
+    fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T) {
+        self.axpy(alpha, x, beta);
+    }
 }
 
 impl<T: Scalar> Div<Scale<T>> for DVector<T> {
diff --git a/src/vector/sundials.rs b/src/vector/sundials.rs
index 365b1c6c..45fb8e87 100644
--- a/src/vector/sundials.rs
+++ b/src/vector/sundials.rs
@@ -44,7 +44,7 @@ impl SundialsVector {
         #[cfg(not(sundials_version_major = "5"))]
         let nv = {
             let ctx = get_suncontext();
-            unsafe { N_VNew_Serial(len as i64, *ctx) }
+            unsafe { N_VNew_Serial(len as i32, *ctx) }
         };
 
         #[cfg(sundials_version_major = "5")]
@@ -83,7 +83,7 @@ impl Drop for SundialsVector {
 #[derive(Debug)]
 pub struct SundialsVectorViewMut<'a>(&'a mut SundialsVector);
 
-impl<'a> SundialsVectorViewMut<'a> {
+impl SundialsVectorViewMut<'_> {
     fn sundials_vector(&self) -> N_Vector {
         self.0.sundials_vector()
     }
@@ -95,7 +95,7 @@
 #[derive(Debug)]
 pub struct SundialsVectorView<'a>(&'a SundialsVector);
 
-impl<'a> SundialsVectorView<'a> {
+impl SundialsVectorView<'_> {
     fn sundials_vector(&self) -> N_Vector {
         self.0.sundials_vector()
     }
@@ -157,11 +157,11 @@ impl VectorCommon for SundialsVector {
     type T = realtype;
 }
 
-impl<'a> VectorCommon for SundialsVectorView<'a> {
+impl VectorCommon for SundialsVectorView<'_> {
     type T = realtype;
 }
 
-impl<'a> VectorCommon for SundialsVectorViewMut<'a> {
+impl VectorCommon for SundialsVectorViewMut<'_> {
     type T = realtype;
 }
 
@@ -415,6 +415,17 @@ impl<'a> VectorViewMut<'a> for SundialsVectorViewMut<'a> {
     fn copy_from_view(&mut self, other: &Self::View) {
         unsafe { N_VScale(1.0, other.sundials_vector(), self.sundials_vector()) }
     }
+    fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T) {
+        unsafe {
+            N_VLinearSum(
+                alpha,
+                x.sundials_vector(),
+                beta,
+                self.sundials_vector(),
+                self.sundials_vector(),
+            )
+        };
+    }
 }
 
 impl<'a> VectorView<'a> for SundialsVectorView<'a> {