From a523350c28ab483d50eef4435d9af38f344f10d8 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sat, 20 Apr 2024 08:07:40 +0000 Subject: [PATCH 1/2] #32 make clone and default for generic bdf, implement some default traits --- src/lib.rs | 12 +- src/linear_solver/sundials.rs | 9 ++ src/matrix/default_solver.rs | 10 ++ src/matrix/dense_faer_serial.rs | 6 + src/matrix/dense_nalgebra_serial.rs | 9 +- src/matrix/mod.rs | 2 + src/matrix/sundials.rs | 10 +- src/ode_solver/bdf/faer.rs | 70 +----------- src/ode_solver/bdf/mod.rs | 103 +++++++++++++----- src/ode_solver/bdf/nalgebra.rs | 64 ----------- src/ode_solver/builder.rs | 22 ++-- src/ode_solver/method.rs | 42 +++---- src/ode_solver/sdirk.rs | 8 +- src/ode_solver/sundials.rs | 8 +- src/ode_solver/test_models/dydt_y2.rs | 2 +- .../test_models/exponential_decay.rs | 2 +- src/ode_solver/test_models/gaussian_decay.rs | 2 +- src/ode_solver/test_models/robertson_ode.rs | 2 +- src/vector/mod.rs | 5 + 19 files changed, 168 insertions(+), 220 deletions(-) create mode 100644 src/matrix/default_solver.rs diff --git a/src/lib.rs b/src/lib.rs index 981e1f00..c2e76287 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,7 +137,7 @@ pub mod solver; pub mod vector; use linear_solver::LinearSolver; -pub use linear_solver::NalgebraLU; +pub use linear_solver::{FaerLU, NalgebraLU}; #[cfg(feature = "sundials")] pub use matrix::sundials::SundialsMatrix; @@ -177,11 +177,12 @@ mod tests { fn test_readme() { type T = f64; type V = nalgebra::DVector; + type M = nalgebra::DMatrix; let problem = OdeBuilder::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) - .build_ode( + .build_ode_dense::( |x: &V, p: &V, _t: T, y: &mut V| { y[0] = -p[0] * x[0] + p[1] * x[1] * x[2]; y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1]; @@ -199,7 +200,7 @@ mod tests { ) .unwrap(); - let mut solver = Bdf::default(); + let mut solver = Bdf::::default(); let t = 0.4; let y = solver.solve(&problem, t).unwrap(); @@ -217,11 +218,12 @@ mod tests { fn test_readme_faer() { type T = f64; type V = faer::Col; + type M = faer::Mat; let problem = OdeBuilder::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) - .build_ode( + .build_ode_dense( |x: &V, p: &V, _t: T, y: &mut V| { y[0] = -p[0] * x[0] + p[1] * x[1] * x[2]; y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1]; @@ -239,7 +241,7 @@ mod tests { ) .unwrap(); - let mut solver = Bdf::default(); + let mut solver = Bdf::::default(); let t = 0.4; let y = solver.solve(&problem, t).unwrap(); diff --git a/src/linear_solver/sundials.rs b/src/linear_solver/sundials.rs index a964999d..175af9e9 100644 --- a/src/linear_solver/sundials.rs +++ b/src/linear_solver/sundials.rs @@ -23,6 +23,15 @@ where matrix: SundialsMatrix, } +impl Default for SundialsLinearSolver +where + Op: LinearOp, +{ + fn default() -> Self { + Self::new_dense() + } +} + impl SundialsLinearSolver where Op: LinearOp, diff --git a/src/matrix/default_solver.rs b/src/matrix/default_solver.rs new file mode 100644 index 00000000..3855ae2f --- /dev/null +++ b/src/matrix/default_solver.rs @@ -0,0 +1,10 @@ +use crate::{linear_solver::LinearSolver, op::LinearOp}; + +use super::Matrix; + +pub trait DefaultSolver: Matrix { + type LS>: LinearSolver + Default; + fn default_solver>() -> Self::LS { + Self::LS::default() + } +} diff --git a/src/matrix/dense_faer_serial.rs b/src/matrix/dense_faer_serial.rs index a450da26..fa816988 100644 --- a/src/matrix/dense_faer_serial.rs +++ b/src/matrix/dense_faer_serial.rs @@ -1,10 +1,16 @@ use std::ops::{Mul, MulAssign}; +use super::default_solver::DefaultSolver; use super::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut}; use crate::scalar::{IndexType, Scalar, Scale}; +use crate::{op::LinearOp, FaerLU}; use anyhow::Result; use faer::{linalg::matmul::matmul, Col, ColMut, ColRef, Mat, MatMut, MatRef, Parallelism}; +impl DefaultSolver for Mat { + type LS, V = Col, T = T>> = FaerLU; +} + macro_rules! impl_matrix_common { ($mat_type:ty) => { impl<'a, T: Scalar> MatrixCommon for $mat_type { diff --git a/src/matrix/dense_nalgebra_serial.rs b/src/matrix/dense_nalgebra_serial.rs index ddc881f2..029f37da 100644 --- a/src/matrix/dense_nalgebra_serial.rs +++ b/src/matrix/dense_nalgebra_serial.rs @@ -3,9 +3,16 @@ use std::ops::{Mul, MulAssign}; use anyhow::Result; use nalgebra::{DMatrix, DMatrixView, DMatrixViewMut, DVector, DVectorView, DVectorViewMut}; +use crate::op::LinearOp; use crate::{scalar::Scale, IndexType, Scalar}; -use crate::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut}; +use crate::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut, NalgebraLU}; + +use super::default_solver::DefaultSolver; + +impl DefaultSolver for DMatrix { + type LS, V = DVector, T = T>> = NalgebraLU; +} macro_rules! impl_matrix_common { ($matrix_type:ty) => { diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs index 280bb3cd..6f2fbc3a 100644 --- a/src/matrix/mod.rs +++ b/src/matrix/mod.rs @@ -12,7 +12,9 @@ mod dense_nalgebra_serial; #[cfg(feature = "faer")] mod dense_faer_serial; +pub mod default_solver; mod sparse_serial; + #[cfg(feature = "sundials")] pub mod sundials; diff --git a/src/matrix/sundials.rs b/src/matrix/sundials.rs index 54b8253a..e55fb3c3 100644 --- a/src/matrix/sundials.rs +++ b/src/matrix/sundials.rs @@ -12,12 +12,13 @@ use sundials_sys::{ use crate::{ ode_solver::sundials::sundials_check, + op::LinearOp, scalar::scale, vector::sundials::{get_suncontext, SundialsVector}, - IndexType, Scale, Vector, + IndexType, Scale, SundialsLinearSolver, Vector, }; -use super::{Matrix, MatrixCommon}; +use super::{default_solver::DefaultSolver, Matrix, MatrixCommon}; use anyhow::anyhow; #[derive(Debug)] @@ -79,6 +80,11 @@ impl Display for SundialsMatrix { } } +impl DefaultSolver for SundialsMatrix { + type LS> = + SundialsLinearSolver; +} + impl MatrixCommon for SundialsMatrix { type V = SundialsVector; type T = realtype; diff --git a/src/ode_solver/bdf/faer.rs b/src/ode_solver/bdf/faer.rs index 57beac46..83531a10 100644 --- a/src/ode_solver/bdf/faer.rs +++ b/src/ode_solver/bdf/faer.rs @@ -1,67 +1,3 @@ -use faer::{Col, Mat}; - -use crate::{ - linear_solver::FaerLU, op::bdf::BdfCallable, Bdf, NewtonNonlinearSolver, NonLinearSolver, - OdeEquations, Scalar, VectorRef, -}; - -impl, M = Mat> + 'static> Default - for Bdf, Eqn> -{ - fn default() -> Self { - let n = 1; - let linear_solver = FaerLU::default(); - let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( - linear_solver, - )); - nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); - Self { - ode_problem: None, - nonlinear_solver, - order: 1, - n_equal_steps: 0, - diff: Mat::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3), - gamma: vec![T::from(1.0); Self::MAX_ORDER + 1], - alpha: vec![T::from(1.0); Self::MAX_ORDER + 1], - error_const: vec![T::from(1.0); Self::MAX_ORDER + 1], - u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), - statistics: super::BdfStatistics::default(), - state: None, - } - } -} - -// implement clone for bdf -impl, M = Mat> + 'static> Clone - for Bdf, Eqn> -where - for<'b> &'b Col: VectorRef>, -{ - fn clone(&self) -> Self { - let n = self.diff.nrows(); - let linear_solver = FaerLU::default(); - let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( - linear_solver, - )); - nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); - Self { - ode_problem: self.ode_problem.clone(), - nonlinear_solver, - order: self.order, - n_equal_steps: self.n_equal_steps, - diff: Mat::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3), - gamma: self.gamma.clone(), - alpha: self.alpha.clone(), - error_const: self.error_const.clone(), - u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), - statistics: self.statistics.clone(), - state: self.state.clone(), - } - } -} - #[cfg(test)] mod test { use crate::{ @@ -72,14 +8,14 @@ mod test { type M = faer::Mat; #[test] fn bdf_no_set_problem() { - test_no_set_problem::(Bdf::default()) + test_no_set_problem::(Bdf::::default()) } #[test] fn bdf_take_state() { - test_take_state::(Bdf::default()) + test_take_state::(Bdf::::default()) } #[test] fn bdf_test_interpolate() { - test_interpolate::(Bdf::default()) + test_interpolate::(Bdf::::default()) } } diff --git a/src/ode_solver/bdf/mod.rs b/src/ode_solver/bdf/mod.rs index f41aac11..27d514f3 100644 --- a/src/ode_solver/bdf/mod.rs +++ b/src/ode_solver/bdf/mod.rs @@ -7,9 +7,12 @@ use num_traits::{One, Pow, Zero}; use serde::Serialize; use crate::{ - matrix::MatrixRef, op::bdf::BdfCallable, scalar::scale, DenseMatrix, IndexType, MatrixViewMut, - NonLinearSolver, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Scalar, SolverProblem, - Vector, VectorRef, VectorView, VectorViewMut, + matrix::{default_solver::DefaultSolver, MatrixCommon, MatrixRef}, + op::{bdf::BdfCallable, linearise::LinearisedOp, Op}, + scalar::scale, + DenseMatrix, IndexType, MatrixViewMut, NewtonNonlinearSolver, NonLinearSolver, OdeSolverMethod, + OdeSolverProblem, OdeSolverState, Scalar, SolverProblem, Vector, VectorRef, VectorView, + VectorViewMut, }; pub mod faer; @@ -69,32 +72,72 @@ pub struct Bdf, Eqn: OdeEquations> { gamma: Vec, error_const: Vec, statistics: BdfStatistics, - state: Option>, + state: Option>, } -// impl, M = faer::Mat>> Default for Bdf, Eqn> { -// fn default() -> Self { -// let n = 1; -// let linear_solver = LU::default(); -// let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( -// linear_solver, -// )); -// nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); -// Self { -// ode_problem: None, -// nonlinear_solver, -// order: 1, -// n_equal_steps: 0, -// diff: Mat::zeros(n, Self::MAX_ORDER + 3), //DMatrix::::zeros(n, Self::MAX_ORDER + 3), -// diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3), -// gamma: vec![f64::from(1.0); Self::MAX_ORDER + 1], -// alpha: vec![f64::from(1.0); Self::MAX_ORDER + 1], -// error_const: vec![f64::from(1.0); Self::MAX_ORDER + 1], -// u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), -// statistics: BdfStatistics::default(), -// } -// } -// } +impl Default for Bdf +where + M: DenseMatrix, + Eqn: OdeEquations + 'static, + Eqn::M: DefaultSolver, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + fn default() -> Self { + let n = 1; + let linear_solver = Eqn::M::default_solver(); + let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( + linear_solver, + )); + nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); + Self { + ode_problem: None, + nonlinear_solver, + order: 1, + n_equal_steps: 0, + diff: M::zeros(n, Self::MAX_ORDER + 3), //DMatrix::::zeros(n, Self::MAX_ORDER + 3), + diff_tmp: M::zeros(n, Self::MAX_ORDER + 3), + gamma: vec![M::T::from(1.0); Self::MAX_ORDER + 1], + alpha: vec![M::T::from(1.0); Self::MAX_ORDER + 1], + error_const: vec![M::T::from(1.0); Self::MAX_ORDER + 1], + u: M::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), + statistics: BdfStatistics::default(), + state: None, + } + } +} + +impl Clone for Bdf +where + M: DenseMatrix, + Eqn: OdeEquations + 'static, + Eqn::M: DefaultSolver, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + fn clone(&self) -> Self { + let n = self.diff.nrows(); + let linear_solver = Eqn::M::default_solver(); + let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( + linear_solver, + )); + nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); + Self { + ode_problem: self.ode_problem.clone(), + nonlinear_solver, + order: self.order, + n_equal_steps: self.n_equal_steps, + diff: M::zeros(n, Self::MAX_ORDER + 3), + diff_tmp: M::zeros(n, Self::MAX_ORDER + 3), + gamma: self.gamma.clone(), + alpha: self.alpha.clone(), + error_const: self.error_const.clone(), + u: M::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), + statistics: self.statistics.clone(), + state: self.state.clone(), + } + } +} impl, Eqn: OdeEquations> Bdf where @@ -261,15 +304,15 @@ where self.ode_problem.as_ref() } - fn state(&self) -> Option<&OdeSolverState> { + fn state(&self) -> Option<&OdeSolverState> { self.state.as_ref() } - fn take_state(&mut self) -> Option::M>> { + fn take_state(&mut self) -> Option::V>> { Option::take(&mut self.state) } - fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { + fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { let mut state = state; self.ode_problem = Some(problem.clone()); let nstates = problem.eqn.nstates(); diff --git a/src/ode_solver/bdf/nalgebra.rs b/src/ode_solver/bdf/nalgebra.rs index 32647597..ed53a9b1 100644 --- a/src/ode_solver/bdf/nalgebra.rs +++ b/src/ode_solver/bdf/nalgebra.rs @@ -1,67 +1,3 @@ -use crate::{ - linear_solver::NalgebraLU, op::bdf::BdfCallable, Bdf, NewtonNonlinearSolver, NonLinearSolver, - OdeEquations, Scalar, VectorRef, -}; - -use nalgebra::{DMatrix, DVector}; - -impl, M = DMatrix> + 'static> Default - for Bdf, Eqn> -{ - fn default() -> Self { - let n = 1; - let linear_solver = NalgebraLU::default(); - let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( - linear_solver, - )); - nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); - Self { - ode_problem: None, - nonlinear_solver, - order: 1, - n_equal_steps: 0, - diff: DMatrix::::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: DMatrix::::zeros(n, Self::MAX_ORDER + 3), - gamma: vec![T::from(1.0); Self::MAX_ORDER + 1], - alpha: vec![T::from(1.0); Self::MAX_ORDER + 1], - error_const: vec![T::from(1.0); Self::MAX_ORDER + 1], - u: DMatrix::::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), - statistics: super::BdfStatistics::default(), - state: None, - } - } -} - -// implement clone for bdf -impl, M = DMatrix> + 'static> Clone - for Bdf, Eqn> -where - for<'b> &'b DVector: VectorRef>, -{ - fn clone(&self) -> Self { - let n = self.diff.nrows(); - let linear_solver = NalgebraLU::default(); - let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( - linear_solver, - )); - nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); - Self { - ode_problem: self.ode_problem.clone(), - nonlinear_solver, - order: self.order, - n_equal_steps: self.n_equal_steps, - diff: DMatrix::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: DMatrix::zeros(n, Self::MAX_ORDER + 3), - gamma: self.gamma.clone(), - alpha: self.alpha.clone(), - error_const: self.error_const.clone(), - u: DMatrix::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), - statistics: self.statistics.clone(), - state: self.state.clone(), - } - } -} - #[cfg(test)] mod test { use crate::{ diff --git a/src/ode_solver/builder.rs b/src/ode_solver/builder.rs index 89855260..9ba0ea9e 100644 --- a/src/ode_solver/builder.rs +++ b/src/ode_solver/builder.rs @@ -1,4 +1,6 @@ -use crate::{op::Op, Matrix, OdeSolverProblem, Vector}; +use crate::{ + matrix::DenseMatrix, op::Op, vector::DefaultDenseMatrix, Matrix, OdeSolverProblem, Vector, +}; use anyhow::Result; use super::equations::{OdeSolverEquations, OdeSolverEquationsMassI}; @@ -214,17 +216,17 @@ impl OdeBuilder { /// |p, _t| DVector::from_element(1, 0.1), /// ); /// ``` - pub fn build_ode( + pub fn build_ode_dense( self, rhs: F, rhs_jac: G, init: I, - ) -> Result>> + ) -> Result::M, F, G, I>>> where - M: Matrix, - F: Fn(&M::V, &M::V, M::T, &mut M::V), - G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), - I: Fn(&M::V, M::T) -> M::V, + V: Vector + DefaultDenseMatrix, + F: Fn(&V, &V, V::T, &mut V), + G: Fn(&V, &V, V::T, &V, &mut V), + I: Fn(&V, V::T) -> V, { let p = Self::build_p(self.p); let eqn = OdeSolverEquationsMassI::new_ode( @@ -238,10 +240,10 @@ impl OdeBuilder { let atol = Self::build_atol(self.atol, eqn.nstates())?; Ok(OdeSolverProblem::new( eqn, - M::T::from(self.rtol), + V::T::from(self.rtol), atol, - M::T::from(self.t0), - M::T::from(self.h0), + V::T::from(self.t0), + V::T::from(self.h0), )) } diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index d9946cf8..a79a3353 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -31,7 +31,7 @@ pub trait OdeSolverMethod { /// Set the problem to solve, this performs any initialisation required by the solver. Call this before calling `step` or `solve`. /// The solver takes ownership of the initial state given by `state`, this is assumed to be consistent with any algebraic constraints. - fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem); + fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem); /// Step the solution forward by one step, altering the internal state of the solver. fn step(&mut self) -> Result<()>; @@ -40,12 +40,12 @@ pub trait OdeSolverMethod { fn interpolate(&self, t: Eqn::T) -> Result; /// Get the current state of the solver, if it exists - fn state(&self) -> Option<&OdeSolverState>; + fn state(&self) -> Option<&OdeSolverState>; /// Take the current state of the solver, if it exists, returning it to the user. This is useful if you want to use this /// state in another solver or problem. Note that this will unset the current problem and solver state, so you will need to call /// `set_problem` again before calling `step` or `solve`. - fn take_state(&mut self) -> Option>; + fn take_state(&mut self) -> Option>; /// Reinitialise the solver state and solve the problem up to time `t` fn solve(&mut self, problem: &OdeSolverProblem, t: Eqn::T) -> Result { @@ -75,29 +75,23 @@ pub trait OdeSolverMethod { /// State for the ODE solver, containing the current solution `y`, the current time `t`, and the current step size `h`. #[derive(Clone)] -pub struct OdeSolverState { - pub y: M::V, - pub t: M::T, - pub h: M::T, - _phantom: std::marker::PhantomData, +pub struct OdeSolverState { + pub y: V, + pub t: V::T, + pub h: V::T, } -impl OdeSolverState { +impl OdeSolverState { /// Create a new solver state from an ODE problem. Note that this does not make the state consistent with the algebraic constraints. /// If you need to make the state consistent, use `new_consistent` instead. pub fn new(ode_problem: &OdeSolverProblem) -> Self where - Eqn: OdeEquations, + Eqn: OdeEquations, { let t = ode_problem.t0; let h = ode_problem.h0; let y = ode_problem.eqn.init(t); - Self { - y, - t, - h, - _phantom: std::marker::PhantomData, - } + Self { y, t, h } } /// Create a new solver state from an ODE problem, making the state consistent with the algebraic constraints. @@ -106,7 +100,7 @@ impl OdeSolverState { root_solver: &mut S, ) -> Result where - Eqn: OdeEquations, + Eqn: OdeEquations, S: NonLinearSolver>> + ?Sized, { let t = ode_problem.t0; @@ -114,12 +108,7 @@ impl OdeSolverState { let indices = ode_problem.eqn.algebraic_indices(); let mut y = ode_problem.eqn.init(t); if indices.len() == 0 { - return Ok(Self { - y, - t, - h, - _phantom: std::marker::PhantomData, - }); + return Ok(Self { y, t, h }); } let mut y_filtered = y.filter(&indices); let atol = Rc::new(ode_problem.atol.as_ref().filter(&indices)); @@ -132,11 +121,6 @@ impl OdeSolverState { let init_problem = root_solver.problem().unwrap(); let indices = init_problem.f.indices(); y.scatter_from(&y_filtered, indices); - Ok(Self { - y, - t, - h, - _phantom: std::marker::PhantomData, - }) + Ok(Self { y, t, h }) } } diff --git a/src/ode_solver/sdirk.rs b/src/ode_solver/sdirk.rs index 56f2dd66..533ec7c9 100644 --- a/src/ode_solver/sdirk.rs +++ b/src/ode_solver/sdirk.rs @@ -36,7 +36,7 @@ where tableau: Tableau, problem: Option>, nonlinear_solver: NewtonNonlinearSolver>, - state: Option>, + state: Option>, diff: M, gamma: Eqn::T, is_sdirk: bool, @@ -175,7 +175,7 @@ where fn set_problem( &mut self, - mut state: OdeSolverState<::M>, + mut state: OdeSolverState<::V>, problem: &OdeSolverProblem, ) { // update initial step size based on function @@ -490,11 +490,11 @@ where } } - fn state(&self) -> Option<&OdeSolverState<::M>> { + fn state(&self) -> Option<&OdeSolverState<::V>> { self.state.as_ref() } - fn take_state(&mut self) -> Option::M>> { + fn take_state(&mut self) -> Option::V>> { Option::take(&mut self.state) } } diff --git a/src/ode_solver/sundials.rs b/src/ode_solver/sundials.rs index 2aad73ce..8f7f43e7 100644 --- a/src/ode_solver/sundials.rs +++ b/src/ode_solver/sundials.rs @@ -123,7 +123,7 @@ where yp: SundialsVector, jacobian: SundialsMatrix, statistics: SundialsStatistics, - state: Option>, + state: Option>, } impl SundialsIda @@ -251,15 +251,15 @@ where self.problem.as_ref() } - fn state(&self) -> Option<&OdeSolverState<::M>> { + fn state(&self) -> Option<&OdeSolverState<::V>> { self.state.as_ref() } - fn take_state(&mut self) -> Option::M>> { + fn take_state(&mut self) -> Option::V>> { Option::take(&mut self.state) } - fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { + fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { self.state = Some(state); let state = self.state.as_ref().unwrap(); self.problem = Some(problem.clone()); diff --git a/src/ode_solver/test_models/dydt_y2.rs b/src/ode_solver/test_models/dydt_y2.rs index 7ed70154..3000a35c 100644 --- a/src/ode_solver/test_models/dydt_y2.rs +++ b/src/ode_solver/test_models/dydt_y2.rs @@ -31,7 +31,7 @@ pub fn dydt_y2_problem( let problem = OdeBuilder::new() .use_coloring(use_coloring) .rtol(1e-4) - .build_ode(rhs::, rhs_jac::, move |_p, _t| { + .build_ode_dense(rhs::, rhs_jac::, move |_p, _t| { M::V::from_vec([y0.into()].repeat(size2)) }) .unwrap(); diff --git a/src/ode_solver/test_models/exponential_decay.rs b/src/ode_solver/test_models/exponential_decay.rs index 333a4720..979c83fe 100644 --- a/src/ode_solver/test_models/exponential_decay.rs +++ b/src/ode_solver/test_models/exponential_decay.rs @@ -32,7 +32,7 @@ pub fn exponential_decay_problem( let problem = OdeBuilder::new() .p([0.1]) .use_coloring(use_coloring) - .build_ode( + .build_ode_dense( exponential_decay::, exponential_decay_jacobian::, exponential_decay_init::, diff --git a/src/ode_solver/test_models/gaussian_decay.rs b/src/ode_solver/test_models/gaussian_decay.rs index b2313641..a316e48b 100644 --- a/src/ode_solver/test_models/gaussian_decay.rs +++ b/src/ode_solver/test_models/gaussian_decay.rs @@ -30,7 +30,7 @@ pub fn gaussian_decay_problem( let problem = OdeBuilder::new() .p([0.1].repeat(size)) .use_coloring(use_coloring) - .build_ode( + .build_ode_dense( gaussian_decay::, gaussian_decay_jacobian::, move |_p, _t| M::V::from_vec([1.0.into()].repeat(size2)), diff --git a/src/ode_solver/test_models/robertson_ode.rs b/src/ode_solver/test_models/robertson_ode.rs index eeef3646..0f64dffb 100644 --- a/src/ode_solver/test_models/robertson_ode.rs +++ b/src/ode_solver/test_models/robertson_ode.rs @@ -14,7 +14,7 @@ pub fn robertson_ode( .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) .use_coloring(use_coloring) - .build_ode( + .build_ode_dense( // dy1/dt = -.04*y1 + 1.e4*y2*y3 //* dy2/dt = .04*y1 - 1.e4*y2*y3 - 3.e7*(y2)^2 //* dy3/dt = 3.e7*(y2)^2 diff --git a/src/vector/mod.rs b/src/vector/mod.rs index a2a09b6e..51219703 100644 --- a/src/vector/mod.rs +++ b/src/vector/mod.rs @@ -1,3 +1,4 @@ +use crate::matrix::DenseMatrix; use crate::scalar::Scale; use crate::{IndexType, Scalar}; use num_traits::Zero; @@ -206,3 +207,7 @@ pub trait Vector: } } } + +pub trait DefaultDenseMatrix: Vector { + type M: DenseMatrix; +} From 72551f66685ee628c7c396dd8b819098f9fd3604 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sat, 20 Apr 2024 11:29:55 +0000 Subject: [PATCH 2/2] #32 fix some default type issues, works now --- src/lib.rs | 5 +-- src/ode_solver/bdf/mod.rs | 22 ++++++----- src/ode_solver/builder.rs | 39 +++++++++++++------ src/ode_solver/method.rs | 2 +- src/ode_solver/test_models/dydt_y2.rs | 2 +- .../test_models/exponential_decay.rs | 2 +- src/ode_solver/test_models/gaussian_decay.rs | 2 +- src/ode_solver/test_models/robertson_ode.rs | 2 +- src/vector/faer_serial.rs | 8 +++- src/vector/nalgebra_serial.rs | 8 +++- 10 files changed, 59 insertions(+), 33 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c2e76287..c12599e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -177,12 +177,11 @@ mod tests { fn test_readme() { type T = f64; type V = nalgebra::DVector; - type M = nalgebra::DMatrix; let problem = OdeBuilder::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) - .build_ode_dense::( + .build_ode_dense( |x: &V, p: &V, _t: T, y: &mut V| { y[0] = -p[0] * x[0] + p[1] * x[1] * x[2]; y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1]; @@ -200,7 +199,7 @@ mod tests { ) .unwrap(); - let mut solver = Bdf::::default(); + let mut solver = Bdf::default(); let t = 0.4; let y = solver.solve(&problem, t).unwrap(); diff --git a/src/ode_solver/bdf/mod.rs b/src/ode_solver/bdf/mod.rs index 27d514f3..0e287fca 100644 --- a/src/ode_solver/bdf/mod.rs +++ b/src/ode_solver/bdf/mod.rs @@ -7,9 +7,10 @@ use num_traits::{One, Pow, Zero}; use serde::Serialize; use crate::{ - matrix::{default_solver::DefaultSolver, MatrixCommon, MatrixRef}, - op::{bdf::BdfCallable, linearise::LinearisedOp, Op}, + matrix::{default_solver::DefaultSolver, Matrix, MatrixRef}, + op::bdf::BdfCallable, scalar::scale, + vector::DefaultDenseMatrix, DenseMatrix, IndexType, MatrixViewMut, NewtonNonlinearSolver, NonLinearSolver, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Scalar, SolverProblem, Vector, VectorRef, VectorView, VectorViewMut, @@ -75,11 +76,11 @@ pub struct Bdf, Eqn: OdeEquations> { state: Option>, } -impl Default for Bdf +impl Default for Bdf<::M, Eqn> where - M: DenseMatrix, Eqn: OdeEquations + 'static, Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, { @@ -90,17 +91,18 @@ where linear_solver, )); nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); + type M = ::M; Self { ode_problem: None, nonlinear_solver, order: 1, n_equal_steps: 0, - diff: M::zeros(n, Self::MAX_ORDER + 3), //DMatrix::::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: M::zeros(n, Self::MAX_ORDER + 3), - gamma: vec![M::T::from(1.0); Self::MAX_ORDER + 1], - alpha: vec![M::T::from(1.0); Self::MAX_ORDER + 1], - error_const: vec![M::T::from(1.0); Self::MAX_ORDER + 1], - u: M::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), + diff: as Matrix>::zeros(n, Self::MAX_ORDER + 3), //DMatrix::::zeros(n, Self::MAX_ORDER + 3), + diff_tmp: as Matrix>::zeros(n, Self::MAX_ORDER + 3), + gamma: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], + alpha: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], + error_const: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], + u: as Matrix>::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), statistics: BdfStatistics::default(), state: None, } diff --git a/src/ode_solver/builder.rs b/src/ode_solver/builder.rs index 9ba0ea9e..775405d2 100644 --- a/src/ode_solver/builder.rs +++ b/src/ode_solver/builder.rs @@ -1,6 +1,4 @@ -use crate::{ - matrix::DenseMatrix, op::Op, vector::DefaultDenseMatrix, Matrix, OdeSolverProblem, Vector, -}; +use crate::{op::Op, vector::DefaultDenseMatrix, Matrix, OdeSolverProblem, Vector}; use anyhow::Result; use super::equations::{OdeSolverEquations, OdeSolverEquationsMassI}; @@ -216,17 +214,17 @@ impl OdeBuilder { /// |p, _t| DVector::from_element(1, 0.1), /// ); /// ``` - pub fn build_ode_dense( + pub fn build_ode( self, rhs: F, rhs_jac: G, init: I, - ) -> Result::M, F, G, I>>> + ) -> Result>> where - V: Vector + DefaultDenseMatrix, - F: Fn(&V, &V, V::T, &mut V), - G: Fn(&V, &V, V::T, &V, &mut V), - I: Fn(&V, V::T) -> V, + M: Matrix, + F: Fn(&M::V, &M::V, M::T, &mut M::V), + G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V), + I: Fn(&M::V, M::T) -> M::V, { let p = Self::build_p(self.p); let eqn = OdeSolverEquationsMassI::new_ode( @@ -240,13 +238,30 @@ impl OdeBuilder { let atol = Self::build_atol(self.atol, eqn.nstates())?; Ok(OdeSolverProblem::new( eqn, - V::T::from(self.rtol), + M::T::from(self.rtol), atol, - V::T::from(self.t0), - V::T::from(self.h0), + M::T::from(self.t0), + M::T::from(self.h0), )) } + /// Build an ODE problem using the default dense matrix (see [Self::build_ode]). + #[allow(clippy::type_complexity)] + pub fn build_ode_dense( + self, + rhs: F, + rhs_jac: G, + init: I, + ) -> Result>> + where + V: Vector + DefaultDenseMatrix, + F: Fn(&V, &V, V::T, &mut V), + G: Fn(&V, &V, V::T, &V, &mut V), + I: Fn(&V, V::T) -> V, + { + self.build_ode(rhs, rhs_jac, init) + } + /// Build an ODE problem using the DiffSL language (requires the `diffsl` feature). /// The source code is provided as a string, please see the [DiffSL documentation](https://martinjrobins.github.io/diffsl/) for more information. #[cfg(feature = "diffsl")] diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index a79a3353..d8af5603 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use crate::{ op::{filter::FilterCallable, ode_rhs::OdeRhs}, - Matrix, NonLinearSolver, OdeEquations, OdeSolverProblem, SolverProblem, Vector, VectorIndex, + NonLinearSolver, OdeEquations, OdeSolverProblem, SolverProblem, Vector, VectorIndex, }; /// Trait for ODE solver methods. This is the main user interface for the ODE solvers. diff --git a/src/ode_solver/test_models/dydt_y2.rs b/src/ode_solver/test_models/dydt_y2.rs index 3000a35c..7ed70154 100644 --- a/src/ode_solver/test_models/dydt_y2.rs +++ b/src/ode_solver/test_models/dydt_y2.rs @@ -31,7 +31,7 @@ pub fn dydt_y2_problem( let problem = OdeBuilder::new() .use_coloring(use_coloring) .rtol(1e-4) - .build_ode_dense(rhs::, rhs_jac::, move |_p, _t| { + .build_ode(rhs::, rhs_jac::, move |_p, _t| { M::V::from_vec([y0.into()].repeat(size2)) }) .unwrap(); diff --git a/src/ode_solver/test_models/exponential_decay.rs b/src/ode_solver/test_models/exponential_decay.rs index 979c83fe..333a4720 100644 --- a/src/ode_solver/test_models/exponential_decay.rs +++ b/src/ode_solver/test_models/exponential_decay.rs @@ -32,7 +32,7 @@ pub fn exponential_decay_problem( let problem = OdeBuilder::new() .p([0.1]) .use_coloring(use_coloring) - .build_ode_dense( + .build_ode( exponential_decay::, exponential_decay_jacobian::, exponential_decay_init::, diff --git a/src/ode_solver/test_models/gaussian_decay.rs b/src/ode_solver/test_models/gaussian_decay.rs index a316e48b..b2313641 100644 --- a/src/ode_solver/test_models/gaussian_decay.rs +++ b/src/ode_solver/test_models/gaussian_decay.rs @@ -30,7 +30,7 @@ pub fn gaussian_decay_problem( let problem = OdeBuilder::new() .p([0.1].repeat(size)) .use_coloring(use_coloring) - .build_ode_dense( + .build_ode( gaussian_decay::, gaussian_decay_jacobian::, move |_p, _t| M::V::from_vec([1.0.into()].repeat(size2)), diff --git a/src/ode_solver/test_models/robertson_ode.rs b/src/ode_solver/test_models/robertson_ode.rs index 0f64dffb..eeef3646 100644 --- a/src/ode_solver/test_models/robertson_ode.rs +++ b/src/ode_solver/test_models/robertson_ode.rs @@ -14,7 +14,7 @@ pub fn robertson_ode( .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) .use_coloring(use_coloring) - .build_ode_dense( + .build_ode( // dy1/dt = -.04*y1 + 1.e4*y2*y3 //* dy2/dt = .04*y1 - 1.e4*y2*y3 - 3.e7*(y2)^2 //* dy3/dt = 3.e7*(y2)^2 diff --git a/src/vector/faer_serial.rs b/src/vector/faer_serial.rs index 6739c1b6..b5d00b23 100644 --- a/src/vector/faer_serial.rs +++ b/src/vector/faer_serial.rs @@ -1,11 +1,17 @@ use std::ops::{Div, Mul, MulAssign}; -use faer::{unzipped, zipped, Col, ColMut, ColRef}; +use faer::{unzipped, zipped, Col, ColMut, ColRef, Mat}; use crate::{scalar::Scale, IndexType, Scalar}; use crate::{Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut}; +use super::DefaultDenseMatrix; + +impl DefaultDenseMatrix for Col { + type M = Mat; +} + macro_rules! impl_op_for_faer_struct { ($struct:ident, $trait_name:ident, $func_name:ident) => { impl<'a, T: Scalar> $trait_name> for $struct<'a, T> { diff --git a/src/vector/nalgebra_serial.rs b/src/vector/nalgebra_serial.rs index de5de1ca..681fd78c 100644 --- a/src/vector/nalgebra_serial.rs +++ b/src/vector/nalgebra_serial.rs @@ -1,10 +1,14 @@ use std::ops::{Div, Mul, MulAssign}; -use nalgebra::{DVector, DVectorView, DVectorViewMut}; +use nalgebra::{DMatrix, DVector, DVectorView, DVectorViewMut}; use crate::{IndexType, Scalar, Scale}; -use super::{Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut}; +use super::{DefaultDenseMatrix, Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut}; + +impl DefaultDenseMatrix for DVector { + type M = DMatrix; +} macro_rules! impl_op_for_dvector_struct { ($struct:ident, $trait_name:ident, $func_name:ident) => {