Skip to content

Commit

Permalink
refactor: improvements to data ownership and lifetimes (#109)
Browse files Browse the repository at this point in the history
* solvers now created from a problem, solvers have a lifetime linked to the problem
* data ownership now much clearer across the structs: 
   * equation struct owned by problem
   * tolerances owned by problem
   * parameter vector owned by equation
* equation trait refactored to allow data to be shared across op structs (owned by struct implementing the equation trait)
* builder pattern extended to generic parameters, rhs, root, mass etc can be set independently and adjoint and sensitivitity equations supported
  • Loading branch information
martinjrobins authored Nov 30, 2024
1 parent 76636c7 commit 8bcfc60
Show file tree
Hide file tree
Showing 78 changed files with 4,361 additions and 4,454 deletions.
74 changes: 32 additions & 42 deletions benches/ode_solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ fn criterion_benchmark(c: &mut Criterion) {
($name:ident, $solver:ident, $linear_solver:ident, $model:ident, $model_problem:ident, $matrix:ty) => {
c.bench_function(stringify!($name), |b| {
b.iter(|| {
let ls = $linear_solver::default();
let (problem, soln) = $model_problem::<$matrix>(false);
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls);
benchmarks::$solver::<_, $linear_solver<_>>(
&problem,
soln.solution_points.last().unwrap().t,
);
})
});
};
Expand Down Expand Up @@ -126,9 +128,8 @@ fn criterion_benchmark(c: &mut Criterion) {
($name:ident, $solver:ident, $linear_solver:ident, $model:ident, $model_problem:ident, $matrix:ty, $($N:expr),+) => {
$(c.bench_function(concat!(stringify!($name), "_", $N), |b| {
b.iter(|| {
let ls = $linear_solver::default();
let (problem, soln) = $model_problem::<$matrix>(false, $N);
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls);
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t);
})
});)+
};
Expand Down Expand Up @@ -219,12 +220,14 @@ fn criterion_benchmark(c: &mut Criterion) {
($name:ident, $solver:ident, $linear_solver:ident, $matrix:ty) => {
#[cfg(feature = "diffsl-llvm")]
c.bench_function(stringify!($name), |b| {
use diffsol::diffsl::LlvmModule;
use diffsol::ode_solver::test_models::robertson::*;
use diffsol::LlvmModule;
b.iter(|| {
let (problem, soln) = robertson_diffsl_problem::<$matrix, LlvmModule>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
benchmarks::$solver::<_, $linear_solver<_>>(
&problem,
soln.solution_points.last().unwrap().t,
)
})
});
};
Expand All @@ -242,8 +245,7 @@ fn criterion_benchmark(c: &mut Criterion) {
$(c.bench_function(concat!(stringify!($name), "_", $N), |b| {
b.iter(|| {
let (problem, soln) = $model_problem::<$matrix, $N>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
})
});)+
};
Expand Down Expand Up @@ -334,8 +336,7 @@ fn criterion_benchmark(c: &mut Criterion) {
$(c.bench_function(concat!(stringify!($name), "_", $N), |b| {
b.iter(|| {
let (problem, soln) = $model_problem::<$matrix, $N>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
})
});)+
};
Expand Down Expand Up @@ -424,11 +425,10 @@ fn criterion_benchmark(c: &mut Criterion) {
$(#[cfg(feature = "diffsl-llvm")]
c.bench_function(concat!(stringify!($name), "_", $N), |b| {
use diffsol::ode_solver::test_models::heat2d::*;
use diffsol::diffsl::LlvmModule;
use diffsol::LlvmModule;
b.iter(|| {
let (problem, soln) = heat2d_diffsl_problem::<$matrix, LlvmModule, $N>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
})
});)+
};
Expand Down Expand Up @@ -499,11 +499,10 @@ fn criterion_benchmark(c: &mut Criterion) {
$(#[cfg(feature = "diffsl-llvm")]
c.bench_function(concat!(stringify!($name), "_", $N), |b| {
use diffsol::ode_solver::test_models::foodweb::*;
use diffsol::diffsl::LlvmModule;
use diffsol::LlvmModule;
b.iter(|| {
let (problem, soln) = foodweb_diffsl_problem::<$matrix, LlvmModule, $N>();
let ls = $linear_solver::default();
benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls)
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
})
});)+

Expand Down Expand Up @@ -542,56 +541,47 @@ mod benchmarks {
use diffsol::vector::VectorRef;
use diffsol::LinearSolver;
use diffsol::{
Bdf, DefaultDenseMatrix, DefaultSolver, Matrix, NewtonNonlinearSolver,
OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Sdirk, Tableau,
DefaultDenseMatrix, DefaultSolver, Matrix, OdeEquationsImplicit, OdeSolverMethod,
OdeSolverProblem,
};

// bdf
pub fn bdf<Eqn>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T, ls: impl LinearSolver<Eqn::M>)
pub fn bdf<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T)
where
Eqn: OdeEquationsImplicit,
Eqn::M: Matrix + DefaultSolver,
Eqn::V: DefaultDenseMatrix,
LS: LinearSolver<Eqn::M>,
for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
{
let nls = NewtonNonlinearSolver::new(ls);
let mut s = Bdf::<<Eqn::V as DefaultDenseMatrix>::M, _, _>::new(nls);
let state = OdeSolverState::new(problem, &s).unwrap();
let _y = s.solve(problem, state, t);
let mut s = problem.bdf::<LS>().unwrap();
let _y = s.solve(t);
}

pub fn esdirk34<Eqn>(
problem: &OdeSolverProblem<Eqn>,
t: Eqn::T,
linear_solver: impl LinearSolver<Eqn::M>,
) where
pub fn esdirk34<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T)
where
Eqn: OdeEquationsImplicit,
Eqn::M: Matrix + DefaultSolver,
Eqn::V: DefaultDenseMatrix,
LS: LinearSolver<Eqn::M>,
for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
{
let tableau = Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::esdirk34();
let mut s = Sdirk::new(tableau, linear_solver);
let state = OdeSolverState::new(problem, &s).unwrap();
let _y = s.solve(problem, state, t);
let mut s = problem.esdirk34::<LS>().unwrap();
let _y = s.solve(t);
}

pub fn tr_bdf2<Eqn>(
problem: &OdeSolverProblem<Eqn>,
t: Eqn::T,
linear_solver: impl LinearSolver<Eqn::M>,
) where
pub fn tr_bdf2<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T)
where
Eqn: OdeEquationsImplicit,
Eqn::M: Matrix + DefaultSolver,
Eqn::V: DefaultDenseMatrix,
LS: LinearSolver<Eqn::M>,
for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
{
let tableau = Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::tr_bdf2();
let mut s = Sdirk::new(tableau, linear_solver);
let state = OdeSolverState::new(problem, &s).unwrap();
let _y = s.solve(problem, state, t);
let mut s = problem.tr_bdf2::<LS>().unwrap();
let _y = s.solve(t);
}
}
3 changes: 1 addition & 2 deletions book/src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
- [Non-linear functions](./specify/custom/non_linear_functions.md)
- [Constant functions](./specify/custom/constant_functions.md)
- [Linear functions](./specify/custom/linear_functions.md)
- [Putting it all together](./specify/custom/putting_it_all_together.md)
- [ODE systems](./specify/custom/ode_systems.md)
- [DiffSL](./specify/diffsl.md)
- [Sparse problems](./specify/sparse_problems.md)
- [Choosing a solver](./choosing_a_solver.md)
- [Initialisation](./initialisation.md)
- [Solving the problem](./solving_the_problem.md)
- [Benchmarks](./benchmarks.md)
105 changes: 59 additions & 46 deletions book/src/choosing_a_solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,75 +3,88 @@
Once you have defined the problem, you need to create a solver to solve the problem. The available solvers are:
- [`diffsol::Bdf`](https://docs.rs/diffsol/latest/diffsol/ode_solver/bdf/struct.Bdf.html): A Backwards Difference Formulae solver, suitable for stiff problems and singular mass matrices.
- [`diffsol::Sdirk`](https://docs.rs/diffsol/latest/diffsol/ode_solver/sdirk/struct.Sdirk.html) A Singly Diagonally Implicit Runge-Kutta (SDIRK or ESDIRK) solver. You can define your own butcher tableau using [`Tableau`](https://docs.rs/diffsol/latest/diffsol/ode_solver/tableau/struct.Tableau.html) or use one of the pre-defined tableaues.


Each of these solvers has a number of generic arguments, for example the `Bdf` solver has three generic arguments:
- `M`: The matrix type used to define the problem.
- `Eqn`: The type of the equations struct that defines the problem.
- `Nls`: The type of the non-linear solver used to solve the implicit equations in the solver.
For each solver, you will need to specify the linear solver type to use. The available linear solvers are:
- [`diffsol::NalgebraLU`](https://docs.rs/diffsol/latest/diffsol/linear_solver/nalgebra_lu/struct.NalgebraLU.html): A LU decomposition solver using the [nalgebra](https://nalgebra.org) crate.
- [`diffsol::FaerLU`](https://docs.rs/diffsol/latest/diffsol/linear_solver/faer_lu/struct.FaerLU.html): A LU decomposition solver using the [faer](https://github.com/sarah-ek/faer-rs) crate.
- [`diffsol::FaerSparseLU`](https://docs.rs/diffsol/latest/diffsol/linear_solver/faer_sparse_lu/struct.FaerSparseLU.html): A sparse LU decomposition solver using the `faer` crate.

In normal use cases, Rust can infer these from your code so you don't need to specify these explicitly. The `Bdf` solver implements the `Default` trait so can be easily created using:
Each solver can be created directly, but it generally easier to use the methods on the [`OdeSolverProblem`](https://docs.rs/diffsol/latest/diffsol/ode_solver/problem/struct.OdeSolverProblem.html) struct to create the solver.
For example:

```rust
# use diffsol::OdeBuilder;
# use nalgebra::DVector;
use diffsol::{OdeSolverState, NalgebraLU, BdfState, Tableau, SdirkState};
# type M = nalgebra::DMatrix<f64>;
use diffsol::{Bdf, OdeSolverState, OdeSolverMethod};
type LS = NalgebraLU<f64>;
# fn main() {
#
# let problem = OdeBuilder::new()
# let problem = OdeBuilder::<M>::new()
# .p(vec![1.0, 10.0])
# .build_ode::<M, _, _, _>(
# .rhs_implicit(
# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]),
# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]),
# |_p, _t| DVector::from_element(1, 0.1),
# ).unwrap();
let mut solver = Bdf::default();
let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem);
# }
```
# )
# .init(|_p, _t| DVector::from_element(1, 0.1))
# .build()
# .unwrap();
// Create a BDF solver with an initial state
let solver = problem.bdf::<LS>();

The `Sdirk` solver requires a tableu to be specified so you can use its `new` method to create a new solver, for example using the `tr_bdf2` tableau:
// Create a non-initialised state and manually set the values before
// creating the solver
let state = BdfState::new_without_initialise(&problem).unwrap();
// ... set the state values manually
let solver = problem.bdf_solver::<LS>(state);

```rust
# use diffsol::{OdeBuilder};
# use nalgebra::DVector;
# type M = nalgebra::DMatrix<f64>;
use diffsol::{Sdirk, Tableau, NalgebraLU, OdeSolverState, OdeSolverMethod};
# fn main() {
# let problem = OdeBuilder::new()
# .p(vec![1.0, 10.0])
# .build_ode::<M, _, _, _>(
# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]),
# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]),
# |_p, _t| DVector::from_element(1, 0.1),
# ).unwrap();
let mut solver = Sdirk::new(Tableau::<M>::tr_bdf2(), NalgebraLU::default());
let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem);
// Create a SDIRK solver with a pre-defined tableau
let tableau = Tableau::<M>::tr_bdf2();
let state = problem.sdirk_state::<LS, _>(&tableau).unwrap();
let solver = problem.sdirk_solver::<LS, _>(state, tableau);

// Create a tr_bdf2 or esdirk34 solvers directly (both are SDIRK solvers with different tableaus)
let solver = problem.tr_bdf2::<LS>();
let solver = problem.esdirk34::<LS>();

// Create a non-initialised state and manually set the values before
// creating the solver
let state = SdirkState::new_without_initialise(&problem).unwrap();
// ... set the state values manually
let solver = problem.tr_bdf2_solver::<LS>(state);
# }
```

You can also use one of the helper functions to create a SDIRK solver with a pre-defined tableau, which will create it with the default linear solver:
# Initialisation

Each solver has an internal state that holds information like the current state vector, the gradient of the state vector, the current time, and the current step size. When you create a solver using the `bdf` or `sdirk` methods on the `OdeSolverProblem` struct, the solver will be initialised with an initial state based on the initial conditions of the problem as well as satisfying any algebraic constraints. An initial time step will also be chosen based on your provided equations.

Each solver's state struct implements the [`OdeSolverState`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/trait.OdeSolverState.html) trait, and if you wish to manually create and setup a state, you can use the methods on this trait to do so.

For example, say that you wish to bypass the initialisation of the state as you already have the algebraic constraints and so don't need to solve for them. You can use the `new_without_initialise` method on the `OdeSolverState` trait to create a new state without initialising it. You can then use the `as_mut` method to get a mutable reference to the state and set the values manually.

Note that each state struct has a [`as_ref`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/trait.OdeSolverState.html#tymethod.as_ref) and [`as_mut`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/trait.OdeSolverState.html#tymethod.as_mut) methods that return a [`StateRef`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/struct.StateRef.html) or ['StateRefMut`](https://docs.rs/diffsol/latest/diffsol/ode_solver/state/struct.StateRefMut.html) struct respectively. These structs provide a solver-independent way to access the state values so you can use the same code with different solvers.

```rust
# use diffsol::{OdeBuilder};
# use diffsol::OdeBuilder;
# use nalgebra::DVector;
# type M = nalgebra::DMatrix<f64>;
use diffsol::{Sdirk, Tableau, NalgebraLU, OdeSolverState, OdeSolverMethod};
use diffsol::{OdeSolverState, NalgebraLU, BdfState};
type LS = NalgebraLU<f64>;

# fn main() {
# let problem = OdeBuilder::new()
#
# let problem = OdeBuilder::<M>::new()
# .p(vec![1.0, 10.0])
# .build_ode::<M, _, _, _>(
# .rhs_implicit(
# |x, p, _t, y| y[0] = p[0] * x[0] * (1.0 - x[0] / p[1]),
# |x, p, _t, v , y| y[0] = p[0] * v[0] * (1.0 - 2.0 * x[0] / p[1]),
# |_p, _t| DVector::from_element(1, 0.1),
# ).unwrap();
let mut solver = Sdirk::tr_bdf2();
let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem);
# )
# .init(|_p, _t| DVector::from_element(1, 0.1))
# .build()
# .unwrap();
let mut state = BdfState::new_without_initialise(&problem).unwrap();
state.as_mut().y[0] = 0.1;
let mut solver = problem.bdf_solver::<LS>(state);
# }
```


```
34 changes: 0 additions & 34 deletions book/src/initialisation.md

This file was deleted.

1 change: 1 addition & 0 deletions book/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Loading

0 comments on commit 8bcfc60

Please sign in to comment.