Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solvers #31

Merged
merged 11 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benches/solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ mod robertson_ode {

mod robertson {
use diffsol::{
ode_solver::test_models::robertson::robertson, Bdf, NewtonNonlinearSolver, OdeSolverMethod,
LU,
linear_solver::NalgebraLU, ode_solver::test_models::robertson::robertson, Bdf,
NewtonNonlinearSolver, OdeSolverMethod,
};

#[divan::bench]
fn bdf() {
let mut s = Bdf::default();
let (problem, _soln) = robertson::<nalgebra::DMatrix<f64>>(false);
let mut root = NewtonNonlinearSolver::new(LU::default());
let mut root = NewtonNonlinearSolver::new(NalgebraLU::default());
let _y = s.make_consistent_and_solve(&problem, 4.0000e+10, &mut root);
}

Expand Down
86 changes: 44 additions & 42 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
//! The linear solver trait is [LinearSolver], and the nonlinear solver trait is [NonLinearSolver]. The [SolverProblem] struct is used to define the problem to solve.
//!
//! The provided linear solvers are:
//! - [LU]: a direct solver that uses the LU decomposition implemented in the [nalgebra](https://nalgebra.org) library.
//! - [NalgebraLU]: a direct solver that uses the LU decomposition implemented in the [nalgebra](https://nalgebra.org) library.
//! - [SundialsLinearSolver]: a linear solver that uses the [sundials](https://computation.llnl.gov/projects/sundials) library (requires the `sundials` feature).
//!
//! The provided nonlinear solvers are:
Expand Down Expand Up @@ -135,8 +135,7 @@ pub mod scalar;
pub mod solver;
pub mod vector;

pub use linear_solver::lu::LU;
use linear_solver::LinearSolver;
pub use linear_solver::{LinearSolver, NalgebraLU};

#[cfg(feature = "sundials")]
pub use matrix::sundials::SundialsMatrix;
Expand Down Expand Up @@ -212,43 +211,46 @@ mod tests {

y2.assert_eq(&y, 1e-6);
}
// #[test]
// fn test_readme_faer() {
// type T = f64;
// type V = faer::Col<f64>;
// let problem = OdeBuilder::new()
// .p([0.04, 1.0e4, 3.0e7])
// .rtol(1e-4)
// .atol([1.0e-8, 1.0e-6, 1.0e-6])
// .build_ode(
// |x: &V, p: &V, _t: T, y: &mut V| {
// y[0] = -p[0] * x[0] + p[1] * x[1] * x[2];
// y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1];
// y[2] = p[2] * x[1] * x[1];
// },
// |x: &V, p: &V, _t: T, v: &V, y: &mut V| {
// y[0] = -p[0] * v[0] + p[1] * v[1] * x[2] + p[1] * x[1] * v[2];
// y[1] = p[0] * v[0]
// - p[1] * v[1] * x[2]
// - p[1] * x[1] * v[2]
// - 2.0 * p[2] * x[1] * v[1];
// y[2] = 2.0 * p[2] * x[1] * v[1];
// },
// |_p: &V, _t: T| V::from_vec(vec![1.0, 0.0, 0.0]),
// );

// let mut solver = Bdf::default();

// let t = 0.4;
// let y = solver.solve(&problem.unwrap(), t).unwrap();

// let mut state = OdeSolverState::new(&problem);
// solver.set_problem(&mut state, &problem);
// while state.t <= t {
// solver.step(&mut state).unwrap();
// }
// let y2 = solver.interpolate(&state, t);

// y2.assert_eq(&y, 1e-6);
// }
#[test]
fn test_readme_faer() {
type T = f64;
type V = faer::Col<f64>;
let problem = OdeBuilder::new()
.p([0.04, 1.0e4, 3.0e7])
.rtol(1e-4)
.atol([1.0e-8, 1.0e-6, 1.0e-6])
.build_ode(
|x: &V, p: &V, _t: T, y: &mut V| {
y[0] = -p[0] * x[0] + p[1] * x[1] * x[2];
y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1];
y[2] = p[2] * x[1] * x[1];
},
|x: &V, p: &V, _t: T, v: &V, y: &mut V| {
y[0] = -p[0] * v[0] + p[1] * v[1] * x[2] + p[1] * x[1] * v[2];
y[1] = p[0] * v[0]
- p[1] * v[1] * x[2]
- p[1] * x[1] * v[2]
- 2.0 * p[2] * x[1] * v[1];
y[2] = 2.0 * p[2] * x[1] * v[1];
},
|_p: &V, _t: T| V::from_vec(vec![1.0, 0.0, 0.0]),
)
.unwrap();

let mut solver = Bdf::default();

let t = 0.4;
let y = solver.solve(&problem, t).unwrap();

let state = OdeSolverState::new(&problem);
solver.set_problem(state, &problem);
while solver.state().unwrap().t <= t {
solver.step().unwrap();
}
let y2 = solver.interpolate(t).unwrap();

y2.assert_eq(&y, 1e-6);
}

// y2.assert_eq(&y, 1e-6);
}
52 changes: 52 additions & 0 deletions src/linear_solver/faer/lu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use crate::{linear_solver::LinearSolver, op::LinearOp, solver::SolverProblem, Scalar};
use anyhow::Result;
use faer::{linalg::solvers::FullPivLu, solvers::SpSolver, Col, Mat};
/// A [LinearSolver] that uses the LU decomposition in the [`faer`](https://github.com/sarah-ek/faer-rs) library to solve the linear system.
pub struct LU<T, C>
where
T: Scalar,
C: LinearOp<M = Mat<T>, V = Col<T>, T = T>,
{
lu: Option<FullPivLu<T>>,
problem: Option<SolverProblem<C>>,
}

impl<T, C> Default for LU<T, C>
where
T: Scalar,
C: LinearOp<M = Mat<T>, V = Col<T>, T = T>,
{
fn default() -> Self {
Self {
lu: None,
problem: None,
}
}
}

impl<T: Scalar, C: LinearOp<M = Mat<T>, V = Col<T>, T = T>> LinearSolver<C> for LU<T, C> {
fn problem(&self) -> Option<&SolverProblem<C>> {
self.problem.as_ref()
}
fn problem_mut(&mut self) -> Option<&mut SolverProblem<C>> {
self.problem.as_mut()
}
fn take_problem(&mut self) -> Option<SolverProblem<C>> {
self.lu = None;
Option::take(&mut self.problem)
}

fn solve_in_place(&mut self, state: &mut C::V) -> Result<()> {
if self.lu.is_none() {
return Err(anyhow::anyhow!("LU not initialized"));
}
let lu = self.lu.as_ref().unwrap();
lu.solve_in_place(state);
Ok(())
}

fn set_problem(&mut self, problem: SolverProblem<C>) {
self.lu = Some(problem.f.jacobian(problem.t).full_piv_lu());
self.problem = Some(problem);
}
}
1 change: 1 addition & 0 deletions src/linear_solver/faer/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod lu;
29 changes: 23 additions & 6 deletions src/linear_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@ use crate::{op::Op, solver::SolverProblem};
use anyhow::Result;

pub mod gmres;
pub mod lu;
#[cfg(feature = "nalgebra")]
pub mod nalgebra;

#[cfg(feature = "faer")]
pub mod faer;

#[cfg(feature = "sundials")]
pub mod sundials;

pub use faer::lu::LU as FaerLU;
pub use nalgebra::lu::LU as NalgebraLU;

/// A solver for the linear problem `Ax = b`.
/// The solver is parameterised by the type `C` which is the type of the linear operator `A` (see the [Op] trait for more details).
pub trait LinearSolver<C: Op> {
Expand All @@ -22,6 +29,7 @@ pub trait LinearSolver<C: Op> {

/// Take the current problem, if any, and return it.
fn take_problem(&mut self) -> Option<SolverProblem<C>>;

fn reset(&mut self) {
if let Some(problem) = self.take_problem() {
self.set_problem(problem);
Expand Down Expand Up @@ -54,10 +62,12 @@ pub mod tests {
use std::rc::Rc;

use crate::{
linear_solver::FaerLU,
linear_solver::NalgebraLU,
op::{linear_closure::LinearClosure, LinearOp},
scalar::scale,
vector::VectorRef,
DenseMatrix, LinearSolver, SolverProblem, Vector, LU,
DenseMatrix, LinearSolver, SolverProblem, Vector,
};
use num_traits::{One, Zero};

Expand Down Expand Up @@ -107,12 +117,19 @@ pub mod tests {
}
}

type MCpu = nalgebra::DMatrix<f64>;
type MCpuNalgebra = nalgebra::DMatrix<f64>;
type MCpuFaer = faer::Mat<f64>;

#[test]
fn test_lu() {
let (p, solns) = linear_problem::<MCpu>();
let s = LU::default();
fn test_lu_nalgebra() {
let (p, solns) = linear_problem::<MCpuNalgebra>();
let s = NalgebraLU::default();
test_linear_solver(s, p, solns);
}
#[test]
fn test_lu_faer() {
let (p, solns) = linear_problem::<MCpuFaer>();
let s = FaerLU::default();
test_linear_solver(s, p, solns);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<T: Scalar, C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>> LinearSolver
}
fn take_problem(&mut self) -> Option<SolverProblem<C>> {
self.lu = None;
self.problem.take()
Option::take(&mut self.problem)
}

fn solve_in_place(&mut self, state: &mut C::V) -> Result<()> {
Expand Down
1 change: 1 addition & 0 deletions src/linear_solver/nalgebra/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod lu;
2 changes: 1 addition & 1 deletion src/linear_solver/sundials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ where

fn take_problem(&mut self) -> Option<SolverProblem<Op>> {
self.is_setup = false;
self.problem.take()
Option::take(&mut self.problem)
}

fn solve_in_place(&mut self, b: &mut Op::V) -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion src/nonlinear_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ pub mod newton;
pub mod tests {
use self::newton::NewtonNonlinearSolver;
use crate::{
linear_solver::lu::LU,
linear_solver::nalgebra::lu::LU,
matrix::MatrixCommon,
op::{closure::Closure, NonLinearOp},
DenseMatrix,
Expand Down
2 changes: 1 addition & 1 deletion src/nonlinear_solver/newton.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<C: NonLinearOp> NonLinearSolver<C> for NewtonNonlinearSolver<C> {
}

fn take_problem(&mut self) -> Option<SolverProblem<C>> {
self.problem.take()
Option::take(&mut self.problem)
}

fn solve_in_place(&mut self, xn: &mut C::V) -> Result<()> {
Expand Down
85 changes: 85 additions & 0 deletions src/ode_solver/bdf/faer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use faer::{Col, Mat};

use crate::{
linear_solver::FaerLU, op::ode::BdfCallable, Bdf, NewtonNonlinearSolver, NonLinearSolver,
OdeEquations, Scalar, VectorRef,
};

impl<T: Scalar, Eqn: OdeEquations<T = T, V = Col<T>, M = Mat<T>> + 'static> Default
for Bdf<Mat<T>, Eqn>
{
fn default() -> Self {
let n = 1;
let linear_solver = FaerLU::default();
let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::<BdfCallable<Eqn>>::new(
linear_solver,
));
nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER);
Self {
ode_problem: None,
nonlinear_solver,
order: 1,
n_equal_steps: 0,
diff: Mat::zeros(n, Self::MAX_ORDER + 3),
diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3),
gamma: vec![T::from(1.0); Self::MAX_ORDER + 1],
alpha: vec![T::from(1.0); Self::MAX_ORDER + 1],
error_const: vec![T::from(1.0); Self::MAX_ORDER + 1],
u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: super::BdfStatistics::default(),
state: None,
}
}
}

// implement clone for bdf
impl<T: Scalar, Eqn: OdeEquations<T = T, V = Col<T>, M = Mat<T>> + 'static> Clone
for Bdf<Mat<T>, Eqn>
where
for<'b> &'b Col<T>: VectorRef<Col<T>>,
{
fn clone(&self) -> Self {
let n = self.diff.nrows();
let linear_solver = FaerLU::default();
let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::<BdfCallable<Eqn>>::new(
linear_solver,
));
nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER);
Self {
ode_problem: self.ode_problem.clone(),
nonlinear_solver,
order: self.order,
n_equal_steps: self.n_equal_steps,
diff: Mat::zeros(n, Self::MAX_ORDER + 3),
diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3),
gamma: self.gamma.clone(),
alpha: self.alpha.clone(),
error_const: self.error_const.clone(),
u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: self.statistics.clone(),
state: self.state.clone(),
}
}
}

#[cfg(test)]
mod test {
use crate::{
ode_solver::tests::{test_interpolate, test_no_set_problem, test_take_state},
Bdf,
};

type M = faer::Mat<f64>;
#[test]
fn bdf_no_set_problem() {
test_no_set_problem::<M, _>(Bdf::default())
}
#[test]
fn bdf_take_state() {
test_take_state::<M, _>(Bdf::default())
}
#[test]
fn bdf_test_interpolate() {
test_interpolate::<M, _>(Bdf::default())
}
}
Loading
Loading