feat: forward sensitivity analysis (#53)
* feat: add forward sensitivity analysis for the BDF and SDIRK solvers (*not* for the sundials solver)
* fix: add a new initial condition solver so that a differential mass matrix is taken into account
* optimisation: remove some memory allocations
martinjrobins authored May 28, 2024
1 parent 3e89fb2 commit 99f7fa0
Showing 42 changed files with 3,732 additions and 997 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -19,4 +19,5 @@ Cargo.lock
/target
/Cargo.lock

.vscode
.vscode
*pending-snap
13 changes: 3 additions & 10 deletions benches/solvers.rs
@@ -42,27 +42,20 @@ mod robertson_ode {
}

mod robertson {
use diffsol::{
linear_solver::NalgebraLU, ode_solver::test_models::robertson::robertson, Bdf,
NewtonNonlinearSolver, OdeSolverMethod,
};
use diffsol::{ode_solver::test_models::robertson::robertson, Bdf, OdeSolverMethod};

#[divan::bench]
fn bdf() {
let mut s = Bdf::default();
let (problem, _soln) = robertson::<nalgebra::DMatrix<f64>>(false);
let mut root = NewtonNonlinearSolver::new(NalgebraLU::default());
let _y = s.make_consistent_and_solve(&problem, 4.0000e+10, &mut root);
let _y = s.solve(&problem, 4.0000e+10);
}

#[cfg(feature = "sundials")]
#[divan::bench]
fn sundials() {
use diffsol::SundialsLinearSolver;

let mut s = diffsol::SundialsIda::default();
let (problem, _soln) = robertson::<diffsol::SundialsMatrix>(false);
let mut root = NewtonNonlinearSolver::new(SundialsLinearSolver::new_dense());
let _y = s.make_consistent_and_solve(&problem, 4.0000e+10, &mut root);
let _y = s.solve(&problem, 4.0000e+10);
}
}
121 changes: 23 additions & 98 deletions src/lib.rs
@@ -27,7 +27,7 @@
//! - Use the [OdeSolverMethod::step] method to step the solution forward in time with an internal time step chosen by the solver to meet the error tolerances.
//! - Use the [OdeSolverMethod::interpolate] method to interpolate the solution between the last two time steps.
//! - Use the [OdeSolverMethod::set_stop_time] method to stop the solver at a specific time (i.e. this will override the internal time step so that the solver stops at the specified time).
//! - Alternatively, use the convenience functions [OdeSolverMethod::solve] and [OdeSolverMethod::make_consistent_and_solve] that will both initialise the problem and solve the problem up to a specific time.
//! - Alternatively, use the convenience function [OdeSolverMethod::solve], which both initialises the problem and solves it up to a specific time (see the sketch below).
//!
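For illustration (an editor's sketch, not part of this diff): the workflow listed above, adapted from the readme test that this commit removes from src/lib.rs. The scalar decay equation, parameter value, and reliance on default tolerances are illustrative choices, and the signatures shown are those used by the pre-commit test, which this commit may have adjusted slightly.

use diffsol::{Bdf, OdeBuilder, OdeSolverMethod, OdeSolverState};

type V = nalgebra::DVector<f64>;

fn main() {
    // Scalar decay equation dy/dt = -p[0] * y with y(0) = 1.
    let problem = OdeBuilder::new()
        .p([0.1])
        .build_ode_dense(
            // right-hand side f(y, p, t)
            |x: &V, p: &V, _t: f64, y: &mut V| y[0] = -p[0] * x[0],
            // action of the Jacobian df/dy on a vector v
            |_x: &V, p: &V, _t: f64, v: &V, y: &mut V| y[0] = -p[0] * v[0],
            // initial condition y0(p)
            |_p: &V, _t: f64| V::from_vec(vec![1.0]),
        )
        .unwrap();

    let mut solver = Bdf::default();
    let t = 1.0;

    // Convenience path: initialise the problem and solve it up to t in one call.
    let y = solver.solve(&problem, t).unwrap();

    // Step-by-step path: create a state, hand it to the solver, step past t,
    // then interpolate back to exactly t.
    let state = OdeSolverState::new(&problem, &solver).unwrap();
    solver.set_problem(state, &problem);
    while solver.state().unwrap().t <= t {
        solver.step().unwrap();
    }
    let y2 = solver.interpolate(t).unwrap();

    println!("solve: {}, step + interpolate: {}", y[0], y2[0]);
}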
//! ## DiffSL
//!
@@ -63,6 +63,13 @@
//! DiffSol provides a simple way to detect user-provided events during the integration of the ODEs. You can use this by providing a closure that has a zero-crossing at the event you want to detect, using the [OdeBuilder::build_ode_with_root] builder,
//! or by providing a [NonLinearOp] with a zero-crossing at that event. While integrating, check the return value of [OdeSolverMethod::step] to see whether an event has been detected.
//!
//! ## Forward Sensitivity Analysis
//!
//! DiffSol provides a way to compute the forward sensitivity of the solution with respect to the parameters. Enable it with the [OdeBuilder::build_ode_with_sens] or [OdeBuilder::build_ode_with_mass_and_sens] builder functions.
//! Note that by default the sensitivity equations are not included in the solvers' error control; you can change this with the [OdeBuilder::sensitivities_error_control] method.
//!
//! To obtain the sensitivity solution via interpolation, use the [OdeSolverMethod::interpolate_sens] method. Otherwise, the sensitivity vectors are available in the [OdeSolverState] struct.
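For context (an editor's note, not part of this diff): forward sensitivity analysis augments the ODE dy/dt = f(y, p, t), y(0) = y0(p) with one additional linear ODE per parameter, dS_i/dt = (∂f/∂y) S_i + ∂f/∂p_i with S_i(0) = ∂y0/∂p_i, where S_i = ∂y/∂p_i. The solver integrates these sensitivity vectors alongside the state, which is why they can be error-controlled and interpolated just like the solution itself.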
//!
//! ## Nonlinear and linear solvers
//!
//! DiffSol provides generic nonlinear and linear solvers that are used internally by the ODE solver. You can use the solvers provided by DiffSol, or implement your own following the provided traits.
@@ -148,109 +155,27 @@ pub use ode_solver::diffsl::DiffSlContext;
pub use matrix::default_solver::DefaultSolver;
use matrix::{DenseMatrix, Matrix, MatrixCommon, MatrixSparsity, MatrixView, MatrixViewMut};
pub use nonlinear_solver::newton::NewtonNonlinearSolver;
use nonlinear_solver::{root::RootFinder, NonLinearSolver};
use nonlinear_solver::{
convergence::Convergence, convergence::ConvergenceStatus, newton::newton_iteration,
root::RootFinder, NonLinearSolver,
};
pub use ode_solver::{
bdf::Bdf, builder::OdeBuilder, equations::OdeEquations, equations::OdeSolverEquations,
method::OdeSolverMethod, method::OdeSolverState, method::OdeSolverStopReason,
problem::OdeSolverProblem, sdirk::Sdirk, tableau::Tableau,
problem::OdeSolverProblem, sdirk::Sdirk, sens_equations::SensEquations,
sens_equations::SensInit, sens_equations::SensRhs, tableau::Tableau,
};
pub use op::{
closure::Closure, constant_closure::ConstantClosure, linear_closure::LinearClosure,
unit::UnitCallable, ConstantOp, LinearOp, NonLinearOp, Op,
};
use op::{
closure_no_jac::ClosureNoJac, closure_with_sens::ClosureWithSens,
constant_closure_with_sens::ConstantClosureWithSens, init::InitOp,
linear_closure_with_sens::LinearClosureWithSens,
};
use op::{closure::Closure, closure_no_jac::ClosureNoJac, linear_closure::LinearClosure};
pub use op::{unit::UnitCallable, LinearOp, NonLinearOp, Op};
use scalar::{IndexType, Scalar, Scale};
use solver::SolverProblem;
use vector::{Vector, VectorCommon, VectorIndex, VectorRef, VectorView, VectorViewMut};

pub use scalar::scale;

#[cfg(test)]
mod tests {

use crate::{
ode_solver::builder::OdeBuilder, vector::Vector, Bdf, OdeSolverMethod, OdeSolverState,
};

// WARNING: if this test fails and you make a change to the code, you should update the README.md file as well!!!
#[test]
fn test_readme() {
type T = f64;
type V = nalgebra::DVector<T>;
let problem = OdeBuilder::new()
.p([0.04, 1.0e4, 3.0e7])
.rtol(1e-4)
.atol([1.0e-8, 1.0e-6, 1.0e-6])
.build_ode_dense(
|x: &V, p: &V, _t: T, y: &mut V| {
y[0] = -p[0] * x[0] + p[1] * x[1] * x[2];
y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1];
y[2] = p[2] * x[1] * x[1];
},
|x: &V, p: &V, _t: T, v: &V, y: &mut V| {
y[0] = -p[0] * v[0] + p[1] * v[1] * x[2] + p[1] * x[1] * v[2];
y[1] = p[0] * v[0]
- p[1] * v[1] * x[2]
- p[1] * x[1] * v[2]
- 2.0 * p[2] * x[1] * v[1];
y[2] = 2.0 * p[2] * x[1] * v[1];
},
|_p: &V, _t: T| V::from_vec(vec![1.0, 0.0, 0.0]),
)
.unwrap();

let mut solver = Bdf::default();

let t = 0.4;
let y = solver.solve(&problem, t).unwrap();

let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem);
while solver.state().unwrap().t <= t {
solver.step().unwrap();
}
let y2 = solver.interpolate(t).unwrap();

y2.assert_eq_st(&y, 1e-6);
}
#[test]
fn test_readme_faer() {
type T = f64;
type V = faer::Col<f64>;
type M = faer::Mat<f64>;
let problem = OdeBuilder::new()
.p([0.04, 1.0e4, 3.0e7])
.rtol(1e-4)
.atol([1.0e-8, 1.0e-6, 1.0e-6])
.build_ode_dense(
|x: &V, p: &V, _t: T, y: &mut V| {
y[0] = -p[0] * x[0] + p[1] * x[1] * x[2];
y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1];
y[2] = p[2] * x[1] * x[1];
},
|x: &V, p: &V, _t: T, v: &V, y: &mut V| {
y[0] = -p[0] * v[0] + p[1] * v[1] * x[2] + p[1] * x[1] * v[2];
y[1] = p[0] * v[0]
- p[1] * v[1] * x[2]
- p[1] * x[1] * v[2]
- 2.0 * p[2] * x[1] * v[1];
y[2] = 2.0 * p[2] * x[1] * v[1];
},
|_p: &V, _t: T| V::from_vec(vec![1.0, 0.0, 0.0]),
)
.unwrap();

let mut solver = Bdf::<M, _, _>::default();

let t = 0.4;
let y = solver.solve(&problem, t).unwrap();

let state = OdeSolverState::new(&problem, &solver).unwrap();
solver.set_problem(state, &problem);
while solver.state().unwrap().t <= t {
solver.step().unwrap();
}
let y2 = solver.interpolate(t).unwrap();

y2.assert_eq_st(&y, 1e-6);
}

// y2.assert_eq(&y, 1e-6);
}
10 changes: 9 additions & 1 deletion src/matrix/dense_faer_serial.rs
@@ -1,4 +1,4 @@
use std::ops::{Mul, MulAssign};
use std::ops::{AddAssign, Mul, MulAssign};

use super::default_solver::DefaultSolver;
use super::{Dense, DenseMatrix, Matrix, MatrixCommon, MatrixSparsity, MatrixView, MatrixViewMut};
@@ -144,6 +144,14 @@ impl<T: Scalar> Matrix for Mat<T> {
}
}

fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V) {
v.add_assign(&self.column(j));
}

fn triplet_iter(&self) -> impl Iterator<Item = (IndexType, IndexType, &Self::T)> {
(0..self.nrows()).flat_map(move |i| (0..self.ncols()).map(move |j| (i, j, &self[(i, j)])))
}

fn try_from_triplets(
nrows: IndexType,
ncols: IndexType,
10 changes: 10 additions & 0 deletions src/matrix/dense_nalgebra_serial.rs
@@ -106,6 +106,16 @@ impl<T: Scalar> Matrix for DMatrix<T> {
}
}

fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V) {
v.add_assign(&self.column(j));
}

fn triplet_iter(&self) -> impl Iterator<Item = (IndexType, IndexType, &Self::T)> {
let n = self.ncols();
let m = self.nrows();
(0..m).flat_map(move |i| (0..n).map(move |j| (i, j, &self[(i, j)])))
}

fn try_from_triplets(
nrows: IndexType,
ncols: IndexType,
91 changes: 90 additions & 1 deletion src/matrix/mod.rs
@@ -2,7 +2,7 @@ use std::fmt::Debug;
use std::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};

use crate::scalar::Scale;
use crate::{IndexType, Scalar, Vector};
use crate::{IndexType, Scalar, Vector, VectorIndex};
use anyhow::Result;
use num_traits::{One, Zero};

@@ -206,6 +206,91 @@ pub trait Matrix:
None
}

/// Split the current square matrix into four submatrices (UL, UR, LL, LR) at the given indices: UL holds the rows and columns not in `indices`, LR the rows and columns in `indices`, and UR/LL the coupling blocks between the two sets
fn split_at_indices(
&self,
indices: &<Self::V as crate::vector::Vector>::Index,
) -> (Self, Self, Self, Self) {
let n = self.nrows();
if n != self.ncols() {
panic!("Matrix must be square");
}
let ni = indices.len();
let nni = n - ni;
let mut indices = indices.clone_as_vec();
indices.sort();
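// cat[k] is true when row/column k appears in `indices`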
let cat = (0..n)
.map(|i| indices.as_slice().binary_search(&i).is_ok())
.collect::<Vec<_>>();
let mut ur_triplets = Vec::new();
let mut ul_triplets = Vec::new();
let mut lr_triplets = Vec::new();
let mut ll_triplets = Vec::new();
for (i, j, &v) in self.triplet_iter() {
if !cat[i] && !cat[j] {
ul_triplets.push((i, j, v));
} else if !cat[i] && cat[j] {
ur_triplets.push((i, j - nni, v));
} else if cat[i] && !cat[j] {
ll_triplets.push((i - nni, j, v));
} else {
lr_triplets.push((i - nni, j - nni, v));
}
}
(
Self::try_from_triplets(nni, nni, ul_triplets).unwrap(),
Self::try_from_triplets(nni, ni, ur_triplets).unwrap(),
Self::try_from_triplets(ni, nni, ll_triplets).unwrap(),
Self::try_from_triplets(ni, ni, lr_triplets).unwrap(),
)
}

/// Combine four submatrices into a single matrix at the given indices; this is the inverse of `split_at_indices`
fn combine_at_indices(
ul: &Self,
ur: &Self,
ll: &Self,
lr: &Self,
indices: &<Self::V as Vector>::Index,
) -> Self {
let n = ul.nrows() + ll.nrows();
let m = ul.ncols() + ur.ncols();
if ul.ncols() != ll.ncols()
|| ur.ncols() != lr.ncols()
|| ul.nrows() != ur.nrows()
|| ll.nrows() != lr.nrows()
{
panic!("Matrices must have the same shape");
}
let mut triplets = Vec::new();
let mut indices = indices.clone_as_vec();
indices.sort();
let cat = (0..n)
.map(|i| indices.as_slice().binary_search(&i).is_ok())
.collect::<Vec<_>>();
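// cat[k] is true when row/column k of the combined (square) matrix appears in `indices`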
for (i, j, &v) in ul.triplet_iter() {
if !cat[i] && !cat[j] {
triplets.push((i, j, v));
}
}
for (i, j, &v) in ur.triplet_iter() {
if !cat[i] && cat[j + ul.ncols()] {
triplets.push((i, j + ul.ncols(), v));
}
}
for (i, j, &v) in ll.triplet_iter() {
if cat[i + ul.nrows()] && !cat[j] {
triplets.push((i + ul.nrows(), j, v));
}
}
for (i, j, &v) in lr.triplet_iter() {
if cat[i + ul.nrows()] && cat[j + ul.ncols()] {
triplets.push((i + ul.nrows(), j + ul.ncols(), v));
}
}
Self::try_from_triplets(n, m, triplets).unwrap()
}

/// Extract the diagonal of the matrix as an owned vector
fn diagonal(&self) -> Self::V;

@@ -230,6 +315,8 @@

fn set_column(&mut self, j: IndexType, v: &Self::V);

/// Add column `j` of the matrix to the vector `v` in place (i.e. `v += self[:, j]`)
fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V);

fn set_data_with_indices(
&mut self,
dst_indices: &<Self::Sparsity as MatrixSparsity>::Index,
@@ -241,6 +328,8 @@
/// Panics if the sparsity of self, x, and y do not match (i.e. sparsity of self must be the union of the sparsity of x and y)
fn scale_add_and_assign(&mut self, x: &Self, beta: Self::T, y: &Self);

/// Iterate over the entries of the matrix as `(row, column, &value)` triplets; dense implementations yield every entry, sparse implementations yield only the explicitly stored entries
fn triplet_iter(&self) -> impl Iterator<Item = (IndexType, IndexType, &Self::T)>;

/// Create a new matrix from a vector of triplets (i, j, value) where i and j are the row and column indices of the value
fn try_from_triplets(
nrows: IndexType,
11 changes: 11 additions & 0 deletions src/matrix/sparse_serial.rs
@@ -163,6 +163,17 @@ impl<T: Scalar> Matrix for CscMatrix<T> {
}
}

fn triplet_iter(&self) -> impl Iterator<Item = (IndexType, IndexType, &Self::T)> {
self.triplet_iter()
}

fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V) {
let col = self.col(j);
for (&i, &val) in col.row_indices().iter().zip(col.values().iter()) {
v[i] += val;
}
}

fn try_from_triplets(
nrows: IndexType,
ncols: IndexType,
Expand Down
13 changes: 13 additions & 0 deletions src/matrix/sundials.rs
@@ -240,6 +240,19 @@ impl Matrix for SundialsMatrix {
}
}

fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V) {
let n = self.nrows();
for i in 0..n {
v[i] += self[(i, j)];
}
}

fn triplet_iter(&self) -> impl Iterator<Item = (IndexType, IndexType, &Self::T)> {
let n = self.ncols();
let m = self.nrows();
(0..m).flat_map(move |i| (0..n).map(move |j| (i, j, &self[(i, j)])))
}

fn diagonal(&self) -> Self::V {
let n = min(self.nrows(), self.ncols());
let mut v = SundialsVector::new_serial(n);