diff --git a/benches/ode_solvers.rs b/benches/ode_solvers.rs index 5f21027..b6c9525 100644 --- a/benches/ode_solvers.rs +++ b/benches/ode_solvers.rs @@ -10,9 +10,6 @@ use diffsol::{ FaerLU, FaerSparseLU, NalgebraLU, SparseColMat, }; -#[cfg(feature = "sundials")] -use diffsol::SundialsLinearSolver; - #[cfg(feature = "suitesparse")] use diffsol::KLU; @@ -55,15 +52,6 @@ fn criterion_benchmark(c: &mut Criterion) { exponential_decay_problem, nalgebra::DMatrix ); - #[cfg(feature = "sundials")] - bench!( - sundials_exponential_decay, - sundials, - SundialsLinearSolver, - exponential_decay, - exponential_decay_problem, - diffsol::SundialsMatrix - ); bench!( nalgebra_bdf_robertson, bdf, @@ -88,15 +76,6 @@ fn criterion_benchmark(c: &mut Criterion) { robertson, nalgebra::DMatrix ); - #[cfg(feature = "sundials")] - bench!( - sundials_robertson, - sundials, - SundialsLinearSolver, - robertson, - robertson, - diffsol::SundialsMatrix - ); bench!( faer_bdf_exponential_decay, bdf, @@ -627,18 +606,4 @@ mod benchmarks { let mut s = Sdirk::new(tableau, linear_solver); let _y = s.solve(problem, t); } - - #[cfg(feature = "sundials")] - pub fn sundials( - problem: &OdeSolverProblem, - t: Eqn::T, - _ls: impl LinearSolver, - ) where - Eqn: OdeEquations, - { - use diffsol::SundialsIda; - - let mut s = SundialsIda::default(); - let _y = s.solve(problem, t); - } } diff --git a/src/lib.rs b/src/lib.rs index 0a0e7c0..d2d898b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,8 +91,6 @@ //! - For vectors: [Vector], [VectorIndex], [VectorView], [VectorViewMut], and [VectorCommon]. //! -#[cfg(feature = "diffsl-cranelift")] -pub extern crate diffsl; #[cfg(feature = "diffsl-llvm13")] pub extern crate diffsl13_0 as diffsl; #[cfg(feature = "diffsl-llvm14")] @@ -103,6 +101,8 @@ pub extern crate diffsl15_0 as diffsl; pub extern crate diffsl16_0 as diffsl; #[cfg(feature = "diffsl-llvm17")] pub extern crate diffsl17_0 as diffsl; +#[cfg(feature = "diffsl-cranelift")] +pub extern crate diffsl_no_llvm as diffsl; pub mod jacobian; pub mod linear_solver; diff --git a/src/ode_solver/diffsl.rs b/src/ode_solver/diffsl.rs index e2ab51e..75166ed 100644 --- a/src/ode_solver/diffsl.rs +++ b/src/ode_solver/diffsl.rs @@ -532,21 +532,21 @@ mod tests { let mut solver = Bdf::default(); let t = 1.0; let (ys, ts) = solver.solve(&problem, t).unwrap(); - for (y, t) in ys.iter().zip(ts.iter()) { + for (i, t) in ts.iter().enumerate() { let y_expect = k / (1.0 + (k - y0) * (-r * t).exp() / y0); let z_expect = 2.0 * y_expect; let expected_out = DVector::from_vec(vec![3.0 * y_expect, 4.0 * z_expect]); - y.assert_eq_st(&expected_out, 1e-4); + ys.column(i).into_owned().assert_eq_st(&expected_out, 1e-4); } // do it again with some explicit t_evals let t_evals = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0]; let ys = solver.solve_dense(&problem, &t_evals).unwrap(); - for (y, t) in ys.iter().zip(t_evals.iter()) { + for (i, t) in t_evals.iter().enumerate() { let y_expect = k / (1.0 + (k - y0) * (-r * t).exp() / y0); let z_expect = 2.0 * y_expect; let expected_out = DVector::from_vec(vec![3.0 * y_expect, 4.0 * z_expect]); - y.assert_eq_st(&expected_out, 1e-4); + ys.column(i).into_owned().assert_eq_st(&expected_out, 1e-4); } } } diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index ebbbe02..f314fa7 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -7,8 +7,9 @@ use crate::{ matrix::default_solver::DefaultSolver, ode_solver_error, scalar::Scalar, - scale, ConstantOp, InitOp, NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquations, - OdeSolverProblem, Op, SensEquations, SolverProblem, Vector, + scale, ConstantOp, DefaultDenseMatrix, DenseMatrix, InitOp, Matrix, MatrixCommon, + NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquations, OdeSolverProblem, Op, + SensEquations, SolverProblem, Vector, VectorViewMut, }; #[derive(Debug, PartialEq)] @@ -92,35 +93,69 @@ pub trait OdeSolverMethod { &mut self, problem: &OdeSolverProblem, final_time: Eqn::T, - ) -> Result<(Vec, Vec), DiffsolError> + ) -> Result<(::M, Vec), DiffsolError> where Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, Self: Sized, { let state = OdeSolverState::new(problem, self)?; self.set_problem(state, problem); let mut t = vec![self.state().unwrap().t]; - let mut y = vec![]; - match problem.eqn.out() { - Some(out) => y.push(out.call(&self.state().unwrap().y, self.state().unwrap().t)), - None => y.push(self.state().unwrap().y.clone()), + let nstates = problem.eqn.rhs().nstates(); + let ntimes_guess = std::cmp::max( + 10, + ((final_time - self.state().unwrap().t).abs() / self.state().unwrap().h) + .into() + .ceil() as usize, + ); + let mut y = <::M as Matrix>::zeros(nstates, ntimes_guess); + { + let mut y_i = y.column_mut(0); + match problem.eqn.out() { + Some(out) => { + y_i.copy_from(&out.call(&self.state().unwrap().y, self.state().unwrap().t)) + } + None => y_i.copy_from(&self.state().unwrap().y), + } } self.set_stop_time(final_time)?; while self.step()? != OdeSolverStopReason::TstopReached { t.push(self.state().unwrap().t); + let mut y_i = { + let max_i = y.ncols(); + let curr_i = t.len() - 1; + if curr_i >= max_i { + y = <::M as Matrix>::zeros(nstates, max_i * 2); + } + y.column_mut(curr_i) + }; match problem.eqn.out() { - Some(out) => y.push(out.call(&self.state().unwrap().y, self.state().unwrap().t)), - None => y.push(self.state().unwrap().y.clone()), + Some(out) => { + y_i.copy_from(&out.call(&self.state().unwrap().y, self.state().unwrap().t)) + } + None => y_i.copy_from(&self.state().unwrap().y), } } // store the final step t.push(self.state().unwrap().t); - match problem.eqn.out() { - Some(out) => y.push(out.call(&self.state().unwrap().y, self.state().unwrap().t)), - None => y.push(self.state().unwrap().y.clone()), + { + let mut y_i = { + let max_i = y.ncols(); + let curr_i = t.len() - 1; + if curr_i >= max_i { + y = <::M as Matrix>::zeros(nstates, max_i + 1); + } + y.column_mut(curr_i) + }; + match problem.eqn.out() { + Some(out) => { + y_i.copy_from(&out.call(&self.state().unwrap().y, self.state().unwrap().t)) + } + None => y_i.copy_from(&self.state().unwrap().y), + } } - Ok((y, t)) } @@ -131,14 +166,16 @@ pub trait OdeSolverMethod { &mut self, problem: &OdeSolverProblem, t_eval: &[Eqn::T], - ) -> Result, DiffsolError> + ) -> Result<::M, DiffsolError> where Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, Self: Sized, { let state = OdeSolverState::new(problem, self)?; self.set_problem(state, problem); - let mut ret = vec![]; + let nstates = problem.eqn.rhs().nstates(); + let mut ret = <::M as Matrix>::zeros(nstates, t_eval.len()); // check t_eval is increasing and all values are greater than or equal to the current time let t0 = self.state().unwrap().t; @@ -149,14 +186,15 @@ pub trait OdeSolverMethod { // do loop self.set_stop_time(t_eval[t_eval.len() - 1])?; let mut step_reason = OdeSolverStopReason::InternalTimestep; - for t in t_eval.iter().take(t_eval.len() - 1) { + for (i, t) in t_eval.iter().take(t_eval.len() - 1).enumerate() { while self.state().unwrap().t < *t { step_reason = self.step()?; } let y = self.interpolate(*t)?; + let mut y_out = ret.column_mut(i); match problem.eqn.out() { - Some(out) => ret.push(out.call(&y, *t)), - None => ret.push(y), + Some(out) => y_out.copy_from(&out.call(&y, *t)), + None => y_out.copy_from(&y), } } @@ -164,9 +202,14 @@ pub trait OdeSolverMethod { while step_reason != OdeSolverStopReason::TstopReached { step_reason = self.step()?; } - match problem.eqn.out() { - Some(out) => ret.push(out.call(&self.state().unwrap().y, self.state().unwrap().t)), - None => ret.push(self.state().unwrap().y.clone()), + { + let mut y_out = ret.column_mut(t_eval.len() - 1); + match problem.eqn.out() { + Some(out) => { + y_out.copy_from(&out.call(&self.state().unwrap().y, self.state().unwrap().t)) + } + None => y_out.copy_from(&self.state().unwrap().y), + } } Ok(ret) } diff --git a/src/ode_solver/mod.rs b/src/ode_solver/mod.rs index 9db1876..88f6e51 100644 --- a/src/ode_solver/mod.rs +++ b/src/ode_solver/mod.rs @@ -26,7 +26,7 @@ mod tests { use crate::matrix::Matrix; use crate::op::unit::UnitCallable; use crate::op::{NonLinearOp, Op}; - use crate::{ConstantOp, DefaultSolver, Vector}; + use crate::{ConstantOp, DefaultDenseMatrix, DefaultSolver, Vector}; use crate::{ OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, }; @@ -296,6 +296,7 @@ mod tests { Eqn: OdeEquations, Method: OdeSolverMethod, Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, { // solve for a little bit s.solve(&problem, Eqn::T::from(1.0)).unwrap();