Skip to content

Commit

Permalink
feat: mass matrix optional (#54)
Browse files Browse the repository at this point in the history
* update diffsl to detect if mass matrix is provided, start to make mass optional

* finish setting mass optional in solvers, fix tests

* cargo fmt
  • Loading branch information
martinjrobins authored May 15, 2024
1 parent 2c78ba0 commit 3e89fb2
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 74 deletions.
28 changes: 14 additions & 14 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,20 @@ anyhow = ">=1.0.77"
num-traits = "0.2.17"
ouroboros = "0.18.2"
serde = { version = "1.0.196", features = ["derive"] }
diffsl4-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm4-0"], optional = true }
diffsl5-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm5-0"], optional = true }
diffsl6-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm6-0"], optional = true }
diffsl7-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm7-0"], optional = true }
diffsl8-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm8-0"], optional = true }
diffsl9-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm9-0"], optional = true }
diffsl10-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm10-0"], optional = true }
diffsl11-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm11-0"], optional = true }
diffsl12-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm12-0"], optional = true }
diffsl13-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm13-0"], optional = true }
diffsl14-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm14-0"], optional = true }
diffsl15-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm15-0"], optional = true }
diffsl16-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm16-0"], optional = true }
diffsl17-0 = { package = "diffsl", version = ">=0.1.2", features = ["llvm17-0"], optional = true }
diffsl4-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm4-0"], optional = true }
diffsl5-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm5-0"], optional = true }
diffsl6-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm6-0"], optional = true }
diffsl7-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm7-0"], optional = true }
diffsl8-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm8-0"], optional = true }
diffsl9-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm9-0"], optional = true }
diffsl10-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm10-0"], optional = true }
diffsl11-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm11-0"], optional = true }
diffsl12-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm12-0"], optional = true }
diffsl13-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm13-0"], optional = true }
diffsl14-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm14-0"], optional = true }
diffsl15-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm15-0"], optional = true }
diffsl16-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm16-0"], optional = true }
diffsl17-0 = { package = "diffsl", version = ">=0.1.4", features = ["llvm17-0"], optional = true }
petgraph = "0.6.4"
faer = "0.18.2"
sundials-sys = { version = "0.4.0", features = ["ida", "static_libraries"], optional = true }
Expand Down
8 changes: 3 additions & 5 deletions src/ode_solver/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ impl OdeBuilder {
rhs.calculate_sparsity(&y0, t0);
mass.calculate_sparsity(t0);
}
let mass = Rc::new(mass);
let mass = Some(Rc::new(mass));
let rhs = Rc::new(rhs);
let eqn = OdeSolverEquations::new(rhs, mass, None, init, p, self.constant_mass);
let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?;
Expand Down Expand Up @@ -288,12 +288,11 @@ impl OdeBuilder {
let y0 = init(&p, t0);
let nstates = y0.len();
let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone());
let mass = Rc::new(UnitCallable::new(nstates));
if self.use_coloring {
rhs.calculate_sparsity(&y0, t0);
}
let rhs = Rc::new(rhs);
let eqn = OdeSolverEquations::new(rhs, mass, None, init, p, self.use_coloring);
let eqn = OdeSolverEquations::new(rhs, None, None, init, p, self.use_coloring);
let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?;
Ok(OdeSolverProblem::new(
eqn,
Expand Down Expand Up @@ -366,13 +365,12 @@ impl OdeBuilder {
let y0 = init(&p, t0);
let nstates = y0.len();
let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone());
let mass = Rc::new(UnitCallable::new(nstates));
let root = Rc::new(ClosureNoJac::new(root, nstates, nroots, p.clone()));
if self.use_coloring {
rhs.calculate_sparsity(&y0, t0);
}
let rhs = Rc::new(rhs);
let eqn = OdeSolverEquations::new(rhs, mass, Some(root), init, p, self.use_coloring);
let eqn = OdeSolverEquations::new(rhs, None, Some(root), init, p, self.use_coloring);
let atol = Self::build_atol(self.atol, eqn.rhs().nstates())?;
Ok(OdeSolverProblem::new(
eqn,
Expand Down
17 changes: 10 additions & 7 deletions src/ode_solver/diffsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ impl DiffSlContext {
pub struct DiffSl<'a> {
context: &'a DiffSlContext,
rhs: Rc<DiffSlRhs<'a>>,
mass: Rc<DiffSlMass<'a>>,
mass: Option<Rc<DiffSlMass<'a>>>,
root: Rc<DiffSlRoot<'a>>,
}

impl<'a> DiffSl<'a> {
pub fn new(context: &'a DiffSlContext, use_coloring: bool) -> Self {
let rhs = Rc::new(DiffSlRhs::new(context, use_coloring));
let mass = Rc::new(DiffSlMass::new(context, use_coloring));
let mass = DiffSlMass::new(context, use_coloring).map(Rc::new);
let root = Rc::new(DiffSlRoot::new(context));
Self {
context,
Expand Down Expand Up @@ -140,7 +140,10 @@ impl<'a> DiffSlRhs<'a> {
}

impl<'a> DiffSlMass<'a> {
pub fn new(context: &'a DiffSlContext, use_coloring: bool) -> Self {
pub fn new(context: &'a DiffSlContext, use_coloring: bool) -> Option<Self> {
if !context.compiler.has_mass() {
return None;
}
let mut ret = Self {
context,
coloring: None,
Expand All @@ -151,7 +154,7 @@ impl<'a> DiffSlMass<'a> {
let non_zeros = find_non_zeros_linear(&ret, t0);
ret.coloring = Some(JacobianColoring::new_from_non_zeros(&ret, non_zeros));
}
ret
Some(ret)
}
}

Expand Down Expand Up @@ -276,8 +279,8 @@ impl<'a> OdeEquations for DiffSl<'a> {
&self.rhs
}

fn mass(&self) -> &Rc<Self::Mass> {
&self.mass
fn mass(&self) -> Option<&Rc<Self::Mass>> {
self.mass.as_ref()
}

fn root(&self) -> Option<&Rc<Self::Root>> {
Expand Down Expand Up @@ -374,7 +377,7 @@ mod tests {
rhs_jac.assert_eq_st(&rhs_jac_expect, 1e-10);
let mut mass_y = DVector::from_vec(vec![0.0, 0.0]);
let v = DVector::from_vec(vec![1.0, 1.0]);
eqn.mass().call_inplace(&v, 0.0, &mut mass_y);
eqn.mass().unwrap().call_inplace(&v, 0.0, &mut mass_y);
let mass_y_expect = DVector::from_vec(vec![1.0, 0.0]);
mass_y.assert_eq_st(&mass_y_expect, 1e-10);

Expand Down
28 changes: 12 additions & 16 deletions src/ode_solver/equations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub trait OdeEquations {
fn rhs(&self) -> &Rc<Self::Rhs>;

/// returns the mass matrix `M` as a [LinearOp]
fn mass(&self) -> &Rc<Self::Mass>;
fn mass(&self) -> Option<&Rc<Self::Mass>>;

fn root(&self) -> Option<&Rc<Self::Root>> {
None
Expand Down Expand Up @@ -110,8 +110,8 @@ pub trait OdeEquations {
///
/// let rhs = Rc::new(MyProblem);
///
/// // we don't have a mass matrix, so we can use a unit operator which does nothing
/// let mass = Rc::new(UnitCallable::new(1));
/// // we don't have a mass matrix or root function, so we can set to None
/// let mass: Option<Rc<UnitCallable<M>>> = None;
/// let root: Option<Rc<UnitCallable<M>>> = None;
/// let init = |p: &V, _t: f64| V::from_vec(vec![1.0]);
/// let p = Rc::new(V::from_vec(vec![]));
Expand Down Expand Up @@ -143,7 +143,7 @@ where
I: Fn(&M::V, M::T) -> M::V,
{
rhs: Rc<Rhs>,
mass: Rc<Mass>,
mass: Option<Rc<Mass>>,
root: Option<Rc<Root>>,
init: I,
p: Rc<M::V>,
Expand All @@ -161,7 +161,7 @@ where
#[allow(clippy::too_many_arguments)]
pub fn new(
rhs: Rc<Rhs>,
mass: Rc<Mass>,
mass: Option<Rc<Mass>>,
root: Option<Rc<Root>>,
init: I,
p: Rc<M::V>,
Expand Down Expand Up @@ -196,8 +196,8 @@ where
fn rhs(&self) -> &Rc<Self::Rhs> {
&self.rhs
}
fn mass(&self) -> &Rc<Self::Mass> {
&self.mass
fn mass(&self) -> Option<&Rc<Self::Mass>> {
self.mass.as_ref()
}
fn root(&self) -> Option<&Rc<Self::Root>> {
self.root.as_ref()
Expand All @@ -215,9 +215,9 @@ where
Rc::<Rhs>::get_mut(&mut self.rhs)
.unwrap()
.set_params(self.p.clone());
Rc::<Mass>::get_mut(&mut self.mass)
.unwrap()
.set_params(self.p.clone());
if let Some(m) = self.mass.as_mut() {
Rc::<Mass>::get_mut(m).unwrap().set_params(self.p.clone());
}
if let Some(r) = self.root.as_mut() {
Rc::<Root>::get_mut(r).unwrap().set_params(self.p.clone())
}
Expand Down Expand Up @@ -248,11 +248,7 @@ mod tests {
let jac_rhs_y = problem.eqn.rhs().jac_mul(&y, 0.0, &y);
let expect_jac_rhs_y = Vcpu::from_vec(vec![-0.1, -0.1]);
jac_rhs_y.assert_eq_st(&expect_jac_rhs_y, 1e-10);
let mass = problem.eqn.mass().matrix(0.0);
assert_eq!(mass[(0, 0)], 1.0);
assert_eq!(mass[(1, 1)], 1.0);
assert_eq!(mass[(0, 1)], 0.);
assert_eq!(mass[(1, 0)], 0.);
assert!(problem.eqn.mass().is_none());
let jac = problem.eqn.rhs().jacobian(&y, 0.0);
assert_eq!(jac[(0, 0)], -0.1);
assert_eq!(jac[(1, 1)], -0.1);
Expand All @@ -270,7 +266,7 @@ mod tests {
let jac_rhs_y = problem.eqn.rhs().jac_mul(&y, 0.0, &y);
let expect_jac_rhs_y = Vcpu::from_vec(vec![-0.1, -0.1, 0.0]);
jac_rhs_y.assert_eq_st(&expect_jac_rhs_y, 1e-10);
let mass = problem.eqn.mass().matrix(0.0);
let mass = problem.eqn.mass().unwrap().matrix(0.0);
assert_eq!(mass[(0, 0)], 1.);
assert_eq!(mass[(1, 1)], 1.);
assert_eq!(mass[(2, 2)], 0.);
Expand Down
5 changes: 4 additions & 1 deletion src/ode_solver/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ impl<V: Vector> OdeSolverState<V> {
Eqn: OdeEquations<T = V::T, V = V>,
S: NonLinearSolver<FilterCallable<Eqn::Rhs>> + ?Sized,
{
let mass_diagonal = ode_problem.eqn.mass().matrix(self.t).diagonal();
if ode_problem.eqn.mass().is_none() {
return Ok(());
}
let mass_diagonal = ode_problem.eqn.mass().unwrap().matrix(self.t).diagonal();
let indices = mass_diagonal.filter_indices(|x| x == Eqn::T::zero());
if indices.len() == 0 {
return Ok(());
Expand Down
6 changes: 2 additions & 4 deletions src/ode_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ mod tests {

pub struct TestEqn<M: Matrix> {
rhs: Rc<TestEqnRhs<M>>,
mass: Rc<UnitCallable<M>>,
}

impl<M: Matrix> TestEqn<M> {
Expand All @@ -139,7 +138,6 @@ mod tests {
rhs: Rc::new(TestEqnRhs {
_m: std::marker::PhantomData,
}),
mass: Rc::new(UnitCallable::new(1)),
}
}
}
Expand All @@ -158,8 +156,8 @@ mod tests {
&self.rhs
}

fn mass(&self) -> &Rc<Self::Mass> {
&self.mass
fn mass(&self) -> Option<&Rc<Self::Mass>> {
None
}

fn root(&self) -> Option<&Rc<Self::Root>> {
Expand Down
23 changes: 19 additions & 4 deletions src/ode_solver/sundials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,13 @@ where
let rhs = eqn.rhs();
let rhs_jac_sparsity = rhs.sparsity();
let rhs_jac = SundialsMatrix::new_from_sparsity(n, n, rhs_jac_sparsity);
let mass_sparsity = eqn.mass().sparsity();
let mass = SundialsMatrix::new_from_sparsity(n, n, mass_sparsity);
let mass = if let Some(mass) = eqn.mass() {
let mass_sparsity = mass.sparsity();
SundialsMatrix::new_from_sparsity(n, n, mass_sparsity)
} else {
let ones = SundialsVector::from_element(n, 1.0);
SundialsMatrix::from_diagonal(&ones)
};
Self { eqn, rhs_jac, mass }
}
}
Expand Down Expand Up @@ -157,7 +162,11 @@ where
// rr = f(t, y)
data.eqn.rhs().call_inplace(&y, t, &mut rr);
// rr = M y' - rr
data.eqn.mass().gemv_inplace(&yp, t, -1.0, &mut rr);
if let Some(mass) = data.eqn.mass() {
mass.gemv_inplace(&yp, t, -1.0, &mut rr);
} else {
rr.axpy(1.0, &yp, -1.0);
}
0
}

Expand All @@ -179,7 +188,9 @@ where
// jac = c_j * M - rhs_jac
let y = SundialsVector::new_not_owned(y);
let mut jac = SundialsMatrix::new_not_owned(jac);
eqn.mass().matrix_inplace(t, &mut data.mass);
if let Some(mass) = eqn.mass() {
mass.matrix_inplace(t, &mut data.mass);
}
eqn.rhs().jacobian_inplace(&y, t, &mut data.rhs_jac);
data.rhs_jac *= scale(-1.0);
jac.scale_add_and_assign(&data.rhs_jac, c_j, &data.mass);
Expand Down Expand Up @@ -217,12 +228,16 @@ where
if self.problem.is_none() {
return Err(anyhow!("Problem not set"));
}
if self.problem.as_ref().unwrap().eqn.mass().is_none() {
return Ok(());
}
let diag = self
.problem
.as_ref()
.unwrap()
.eqn
.mass()
.unwrap()
.matrix(t)
.diagonal();
let id = diag.filter_indices(|x| x == Eqn::T::zero());
Expand Down
Loading

0 comments on commit 3e89fb2

Please sign in to comment.