Skip to content

Commit

Permalink
refactor: solve and solve_dense use a dense matrix type for state ret…
Browse files Browse the repository at this point in the history
…urn (#90)
  • Loading branch information
martinjrobins authored Sep 8, 2024
1 parent 8d19b87 commit b7f0589
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 63 deletions.
35 changes: 0 additions & 35 deletions benches/ode_solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ use diffsol::{
FaerLU, FaerSparseLU, NalgebraLU, SparseColMat,
};

#[cfg(feature = "sundials")]
use diffsol::SundialsLinearSolver;

#[cfg(feature = "suitesparse")]
use diffsol::KLU;

Expand Down Expand Up @@ -55,15 +52,6 @@ fn criterion_benchmark(c: &mut Criterion) {
exponential_decay_problem,
nalgebra::DMatrix<f64>
);
#[cfg(feature = "sundials")]
bench!(
sundials_exponential_decay,
sundials,
SundialsLinearSolver,
exponential_decay,
exponential_decay_problem,
diffsol::SundialsMatrix
);
bench!(
nalgebra_bdf_robertson,
bdf,
Expand All @@ -88,15 +76,6 @@ fn criterion_benchmark(c: &mut Criterion) {
robertson,
nalgebra::DMatrix<f64>
);
#[cfg(feature = "sundials")]
bench!(
sundials_robertson,
sundials,
SundialsLinearSolver,
robertson,
robertson,
diffsol::SundialsMatrix
);
bench!(
faer_bdf_exponential_decay,
bdf,
Expand Down Expand Up @@ -627,18 +606,4 @@ mod benchmarks {
let mut s = Sdirk::new(tableau, linear_solver);
let _y = s.solve(problem, t);
}

#[cfg(feature = "sundials")]
pub fn sundials<Eqn>(
problem: &OdeSolverProblem<Eqn>,
t: Eqn::T,
_ls: impl LinearSolver<Eqn::Rhs>,
) where
Eqn: OdeEquations<M = diffsol::SundialsMatrix, V = diffsol::SundialsVector, T = f64>,
{
use diffsol::SundialsIda;

let mut s = SundialsIda::default();
let _y = s.solve(problem, t);
}
}
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@
//! - For vectors: [Vector], [VectorIndex], [VectorView], [VectorViewMut], and [VectorCommon].
//!
#[cfg(feature = "diffsl-cranelift")]
pub extern crate diffsl;
#[cfg(feature = "diffsl-llvm13")]
pub extern crate diffsl13_0 as diffsl;
#[cfg(feature = "diffsl-llvm14")]
Expand All @@ -103,6 +101,8 @@ pub extern crate diffsl15_0 as diffsl;
pub extern crate diffsl16_0 as diffsl;
#[cfg(feature = "diffsl-llvm17")]
pub extern crate diffsl17_0 as diffsl;
#[cfg(feature = "diffsl-cranelift")]
pub extern crate diffsl_no_llvm as diffsl;

pub mod jacobian;
pub mod linear_solver;
Expand Down
8 changes: 4 additions & 4 deletions src/ode_solver/diffsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,21 +532,21 @@ mod tests {
let mut solver = Bdf::default();
let t = 1.0;
let (ys, ts) = solver.solve(&problem, t).unwrap();
for (y, t) in ys.iter().zip(ts.iter()) {
for (i, t) in ts.iter().enumerate() {
let y_expect = k / (1.0 + (k - y0) * (-r * t).exp() / y0);
let z_expect = 2.0 * y_expect;
let expected_out = DVector::from_vec(vec![3.0 * y_expect, 4.0 * z_expect]);
y.assert_eq_st(&expected_out, 1e-4);
ys.column(i).into_owned().assert_eq_st(&expected_out, 1e-4);
}

// do it again with some explicit t_evals
let t_evals = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0];
let ys = solver.solve_dense(&problem, &t_evals).unwrap();
for (y, t) in ys.iter().zip(t_evals.iter()) {
for (i, t) in t_evals.iter().enumerate() {
let y_expect = k / (1.0 + (k - y0) * (-r * t).exp() / y0);
let z_expect = 2.0 * y_expect;
let expected_out = DVector::from_vec(vec![3.0 * y_expect, 4.0 * z_expect]);
y.assert_eq_st(&expected_out, 1e-4);
ys.column(i).into_owned().assert_eq_st(&expected_out, 1e-4);
}
}
}
85 changes: 64 additions & 21 deletions src/ode_solver/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use crate::{
matrix::default_solver::DefaultSolver,
ode_solver_error,
scalar::Scalar,
scale, ConstantOp, InitOp, NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquations,
OdeSolverProblem, Op, SensEquations, SolverProblem, Vector,
scale, ConstantOp, DefaultDenseMatrix, DenseMatrix, InitOp, Matrix, MatrixCommon,
NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquations, OdeSolverProblem, Op,
SensEquations, SolverProblem, Vector, VectorViewMut,
};

#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -92,35 +93,69 @@ pub trait OdeSolverMethod<Eqn: OdeEquations> {
&mut self,
problem: &OdeSolverProblem<Eqn>,
final_time: Eqn::T,
) -> Result<(Vec<Eqn::V>, Vec<Eqn::T>), DiffsolError>
) -> Result<(<Eqn::V as DefaultDenseMatrix>::M, Vec<Eqn::T>), DiffsolError>
where
Eqn::M: DefaultSolver,
Eqn::V: DefaultDenseMatrix,
Self: Sized,
{
let state = OdeSolverState::new(problem, self)?;
self.set_problem(state, problem);
let mut t = vec![self.state().unwrap().t];
let mut y = vec![];
match problem.eqn.out() {
Some(out) => y.push(out.call(&self.state().unwrap().y, self.state().unwrap().t)),
None => y.push(self.state().unwrap().y.clone()),
let nstates = problem.eqn.rhs().nstates();
let ntimes_guess = std::cmp::max(
10,
((final_time - self.state().unwrap().t).abs() / self.state().unwrap().h)
.into()
.ceil() as usize,
);
let mut y = <<Eqn::V as DefaultDenseMatrix>::M as Matrix>::zeros(nstates, ntimes_guess);
{
let mut y_i = y.column_mut(0);
match problem.eqn.out() {
Some(out) => {
y_i.copy_from(&out.call(&self.state().unwrap().y, self.state().unwrap().t))
}
None => y_i.copy_from(&self.state().unwrap().y),
}
}
self.set_stop_time(final_time)?;
while self.step()? != OdeSolverStopReason::TstopReached {
t.push(self.state().unwrap().t);
let mut y_i = {
let max_i = y.ncols();
let curr_i = t.len() - 1;
if curr_i >= max_i {
y = <<Eqn::V as DefaultDenseMatrix>::M as Matrix>::zeros(nstates, max_i * 2);
}
y.column_mut(curr_i)
};
match problem.eqn.out() {
Some(out) => y.push(out.call(&self.state().unwrap().y, self.state().unwrap().t)),
None => y.push(self.state().unwrap().y.clone()),
Some(out) => {
y_i.copy_from(&out.call(&self.state().unwrap().y, self.state().unwrap().t))
}
None => y_i.copy_from(&self.state().unwrap().y),
}
}

// store the final step
t.push(self.state().unwrap().t);
match problem.eqn.out() {
Some(out) => y.push(out.call(&self.state().unwrap().y, self.state().unwrap().t)),
None => y.push(self.state().unwrap().y.clone()),
{
let mut y_i = {
let max_i = y.ncols();
let curr_i = t.len() - 1;
if curr_i >= max_i {
y = <<Eqn::V as DefaultDenseMatrix>::M as Matrix>::zeros(nstates, max_i + 1);
}
y.column_mut(curr_i)
};
match problem.eqn.out() {
Some(out) => {
y_i.copy_from(&out.call(&self.state().unwrap().y, self.state().unwrap().t))
}
None => y_i.copy_from(&self.state().unwrap().y),
}
}

Ok((y, t))
}

Expand All @@ -131,14 +166,16 @@ pub trait OdeSolverMethod<Eqn: OdeEquations> {
&mut self,
problem: &OdeSolverProblem<Eqn>,
t_eval: &[Eqn::T],
) -> Result<Vec<Eqn::V>, DiffsolError>
) -> Result<<Eqn::V as DefaultDenseMatrix>::M, DiffsolError>
where
Eqn::M: DefaultSolver,
Eqn::V: DefaultDenseMatrix,
Self: Sized,
{
let state = OdeSolverState::new(problem, self)?;
self.set_problem(state, problem);
let mut ret = vec![];
let nstates = problem.eqn.rhs().nstates();
let mut ret = <<Eqn::V as DefaultDenseMatrix>::M as Matrix>::zeros(nstates, t_eval.len());

// check t_eval is increasing and all values are greater than or equal to the current time
let t0 = self.state().unwrap().t;
Expand All @@ -149,24 +186,30 @@ pub trait OdeSolverMethod<Eqn: OdeEquations> {
// do loop
self.set_stop_time(t_eval[t_eval.len() - 1])?;
let mut step_reason = OdeSolverStopReason::InternalTimestep;
for t in t_eval.iter().take(t_eval.len() - 1) {
for (i, t) in t_eval.iter().take(t_eval.len() - 1).enumerate() {
while self.state().unwrap().t < *t {
step_reason = self.step()?;
}
let y = self.interpolate(*t)?;
let mut y_out = ret.column_mut(i);
match problem.eqn.out() {
Some(out) => ret.push(out.call(&y, *t)),
None => ret.push(y),
Some(out) => y_out.copy_from(&out.call(&y, *t)),
None => y_out.copy_from(&y),
}
}

// do final step
while step_reason != OdeSolverStopReason::TstopReached {
step_reason = self.step()?;
}
match problem.eqn.out() {
Some(out) => ret.push(out.call(&self.state().unwrap().y, self.state().unwrap().t)),
None => ret.push(self.state().unwrap().y.clone()),
{
let mut y_out = ret.column_mut(t_eval.len() - 1);
match problem.eqn.out() {
Some(out) => {
y_out.copy_from(&out.call(&self.state().unwrap().y, self.state().unwrap().t))
}
None => y_out.copy_from(&self.state().unwrap().y),
}
}
Ok(ret)
}
Expand Down
3 changes: 2 additions & 1 deletion src/ode_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ mod tests {
use crate::matrix::Matrix;
use crate::op::unit::UnitCallable;
use crate::op::{NonLinearOp, Op};
use crate::{ConstantOp, DefaultSolver, Vector};
use crate::{ConstantOp, DefaultDenseMatrix, DefaultSolver, Vector};
use crate::{
OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason,
};
Expand Down Expand Up @@ -296,6 +296,7 @@ mod tests {
Eqn: OdeEquations,
Method: OdeSolverMethod<Eqn>,
Eqn::M: DefaultSolver,
Eqn::V: DefaultDenseMatrix,
{
// solve for a little bit
s.solve(&problem, Eqn::T::from(1.0)).unwrap();
Expand Down

0 comments on commit b7f0589

Please sign in to comment.