Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I32-clone #36

Merged
merged 2 commits into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ pub mod solver;
pub mod vector;

use linear_solver::LinearSolver;
pub use linear_solver::NalgebraLU;
pub use linear_solver::{FaerLU, NalgebraLU};

#[cfg(feature = "sundials")]
pub use matrix::sundials::SundialsMatrix;
Expand Down Expand Up @@ -181,7 +181,7 @@ mod tests {
.p([0.04, 1.0e4, 3.0e7])
.rtol(1e-4)
.atol([1.0e-8, 1.0e-6, 1.0e-6])
.build_ode(
.build_ode_dense(
|x: &V, p: &V, _t: T, y: &mut V| {
y[0] = -p[0] * x[0] + p[1] * x[1] * x[2];
y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1];
Expand Down Expand Up @@ -217,11 +217,12 @@ mod tests {
fn test_readme_faer() {
type T = f64;
type V = faer::Col<f64>;
type M = faer::Mat<f64>;
let problem = OdeBuilder::new()
.p([0.04, 1.0e4, 3.0e7])
.rtol(1e-4)
.atol([1.0e-8, 1.0e-6, 1.0e-6])
.build_ode(
.build_ode_dense(
|x: &V, p: &V, _t: T, y: &mut V| {
y[0] = -p[0] * x[0] + p[1] * x[1] * x[2];
y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1];
Expand All @@ -239,7 +240,7 @@ mod tests {
)
.unwrap();

let mut solver = Bdf::default();
let mut solver = Bdf::<M, _>::default();

let t = 0.4;
let y = solver.solve(&problem, t).unwrap();
Expand Down
9 changes: 9 additions & 0 deletions src/linear_solver/sundials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ where
matrix: SundialsMatrix,
}

impl<Op> Default for SundialsLinearSolver<Op>
where
Op: LinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
fn default() -> Self {
Self::new_dense()
}
}

impl<Op> SundialsLinearSolver<Op>
where
Op: LinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
Expand Down
10 changes: 10 additions & 0 deletions src/matrix/default_solver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use crate::{linear_solver::LinearSolver, op::LinearOp};

use super::Matrix;

pub trait DefaultSolver: Matrix {
type LS<C: LinearOp<M = Self, V = Self::V, T = Self::T>>: LinearSolver<C> + Default;
fn default_solver<C: LinearOp<M = Self, V = Self::V, T = Self::T>>() -> Self::LS<C> {
Self::LS::default()
}
}
6 changes: 6 additions & 0 deletions src/matrix/dense_faer_serial.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use std::ops::{Mul, MulAssign};

use super::default_solver::DefaultSolver;
use super::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut};
use crate::scalar::{IndexType, Scalar, Scale};
use crate::{op::LinearOp, FaerLU};
use anyhow::Result;
use faer::{linalg::matmul::matmul, Col, ColMut, ColRef, Mat, MatMut, MatRef, Parallelism};

impl<T: Scalar> DefaultSolver for Mat<T> {
type LS<C: LinearOp<M = Mat<T>, V = Col<T>, T = T>> = FaerLU<T, C>;
}

macro_rules! impl_matrix_common {
($mat_type:ty) => {
impl<'a, T: Scalar> MatrixCommon for $mat_type {
Expand Down
9 changes: 8 additions & 1 deletion src/matrix/dense_nalgebra_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@ use std::ops::{Mul, MulAssign};
use anyhow::Result;
use nalgebra::{DMatrix, DMatrixView, DMatrixViewMut, DVector, DVectorView, DVectorViewMut};

use crate::op::LinearOp;
use crate::{scalar::Scale, IndexType, Scalar};

use crate::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut};
use crate::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut, NalgebraLU};

use super::default_solver::DefaultSolver;

impl<T: Scalar> DefaultSolver for DMatrix<T> {
type LS<C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>> = NalgebraLU<T, C>;
}

macro_rules! impl_matrix_common {
($matrix_type:ty) => {
Expand Down
2 changes: 2 additions & 0 deletions src/matrix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ mod dense_nalgebra_serial;
#[cfg(feature = "faer")]
mod dense_faer_serial;

pub mod default_solver;
mod sparse_serial;

#[cfg(feature = "sundials")]
pub mod sundials;

Expand Down
10 changes: 8 additions & 2 deletions src/matrix/sundials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ use sundials_sys::{

use crate::{
ode_solver::sundials::sundials_check,
op::LinearOp,
scalar::scale,
vector::sundials::{get_suncontext, SundialsVector},
IndexType, Scale, Vector,
IndexType, Scale, SundialsLinearSolver, Vector,
};

use super::{Matrix, MatrixCommon};
use super::{default_solver::DefaultSolver, Matrix, MatrixCommon};
use anyhow::anyhow;

#[derive(Debug)]
Expand Down Expand Up @@ -79,6 +80,11 @@ impl Display for SundialsMatrix {
}
}

impl DefaultSolver for SundialsMatrix {
type LS<C: LinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>> =
SundialsLinearSolver<C>;
}

impl MatrixCommon for SundialsMatrix {
type V = SundialsVector;
type T = realtype;
Expand Down
70 changes: 3 additions & 67 deletions src/ode_solver/bdf/faer.rs
Original file line number Diff line number Diff line change
@@ -1,67 +1,3 @@
use faer::{Col, Mat};

use crate::{
linear_solver::FaerLU, op::bdf::BdfCallable, Bdf, NewtonNonlinearSolver, NonLinearSolver,
OdeEquations, Scalar, VectorRef,
};

impl<T: Scalar, Eqn: OdeEquations<T = T, V = Col<T>, M = Mat<T>> + 'static> Default
for Bdf<Mat<T>, Eqn>
{
fn default() -> Self {
let n = 1;
let linear_solver = FaerLU::default();
let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::<BdfCallable<Eqn>>::new(
linear_solver,
));
nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER);
Self {
ode_problem: None,
nonlinear_solver,
order: 1,
n_equal_steps: 0,
diff: Mat::zeros(n, Self::MAX_ORDER + 3),
diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3),
gamma: vec![T::from(1.0); Self::MAX_ORDER + 1],
alpha: vec![T::from(1.0); Self::MAX_ORDER + 1],
error_const: vec![T::from(1.0); Self::MAX_ORDER + 1],
u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: super::BdfStatistics::default(),
state: None,
}
}
}

// implement clone for bdf
impl<T: Scalar, Eqn: OdeEquations<T = T, V = Col<T>, M = Mat<T>> + 'static> Clone
for Bdf<Mat<T>, Eqn>
where
for<'b> &'b Col<T>: VectorRef<Col<T>>,
{
fn clone(&self) -> Self {
let n = self.diff.nrows();
let linear_solver = FaerLU::default();
let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::<BdfCallable<Eqn>>::new(
linear_solver,
));
nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER);
Self {
ode_problem: self.ode_problem.clone(),
nonlinear_solver,
order: self.order,
n_equal_steps: self.n_equal_steps,
diff: Mat::zeros(n, Self::MAX_ORDER + 3),
diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3),
gamma: self.gamma.clone(),
alpha: self.alpha.clone(),
error_const: self.error_const.clone(),
u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: self.statistics.clone(),
state: self.state.clone(),
}
}
}

#[cfg(test)]
mod test {
use crate::{
Expand All @@ -72,14 +8,14 @@ mod test {
type M = faer::Mat<f64>;
#[test]
fn bdf_no_set_problem() {
test_no_set_problem::<M, _>(Bdf::default())
test_no_set_problem::<M, _>(Bdf::<M, _>::default())
}
#[test]
fn bdf_take_state() {
test_take_state::<M, _>(Bdf::default())
test_take_state::<M, _>(Bdf::<M, _>::default())
}
#[test]
fn bdf_test_interpolate() {
test_interpolate::<M, _>(Bdf::default())
test_interpolate::<M, _>(Bdf::<M, _>::default())
}
}
105 changes: 75 additions & 30 deletions src/ode_solver/bdf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ use num_traits::{One, Pow, Zero};
use serde::Serialize;

use crate::{
matrix::MatrixRef, op::bdf::BdfCallable, scalar::scale, DenseMatrix, IndexType, MatrixViewMut,
NonLinearSolver, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Scalar, SolverProblem,
Vector, VectorRef, VectorView, VectorViewMut,
matrix::{default_solver::DefaultSolver, Matrix, MatrixRef},
op::bdf::BdfCallable,
scalar::scale,
vector::DefaultDenseMatrix,
DenseMatrix, IndexType, MatrixViewMut, NewtonNonlinearSolver, NonLinearSolver, OdeSolverMethod,
OdeSolverProblem, OdeSolverState, Scalar, SolverProblem, Vector, VectorRef, VectorView,
VectorViewMut,
};

pub mod faer;
Expand Down Expand Up @@ -69,32 +73,73 @@ pub struct Bdf<M: DenseMatrix<T = Eqn::T, V = Eqn::V>, Eqn: OdeEquations> {
gamma: Vec<Eqn::T>,
error_const: Vec<Eqn::T>,
statistics: BdfStatistics<Eqn::T>,
state: Option<OdeSolverState<Eqn::M>>,
state: Option<OdeSolverState<Eqn::V>>,
}

// impl<Eqn: OdeEquations<T = f64, V = Col<f64>, M = faer::Mat<f64>>> Default for Bdf<Mat<f64>, Eqn> {
// fn default() -> Self {
// let n = 1;
// let linear_solver = LU::default();
// let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::<BdfCallable<Eqn>>::new(
// linear_solver,
// ));
// nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER);
// Self {
// ode_problem: None,
// nonlinear_solver,
// order: 1,
// n_equal_steps: 0,
// diff: Mat::zeros(n, Self::MAX_ORDER + 3), //DMatrix::<T>::zeros(n, Self::MAX_ORDER + 3),
// diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3),
// gamma: vec![f64::from(1.0); Self::MAX_ORDER + 1],
// alpha: vec![f64::from(1.0); Self::MAX_ORDER + 1],
// error_const: vec![f64::from(1.0); Self::MAX_ORDER + 1],
// u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
// statistics: BdfStatistics::default(),
// }
// }
// }
impl<Eqn> Default for Bdf<<Eqn::V as DefaultDenseMatrix>::M, Eqn>
where
Eqn: OdeEquations + 'static,
Eqn::M: DefaultSolver,
Eqn::V: DefaultDenseMatrix,
for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
{
fn default() -> Self {
let n = 1;
let linear_solver = Eqn::M::default_solver();
let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::<BdfCallable<Eqn>>::new(
linear_solver,
));
nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER);
type M<V> = <V as DefaultDenseMatrix>::M;
Self {
ode_problem: None,
nonlinear_solver,
order: 1,
n_equal_steps: 0,
diff: <M<Eqn::V> as Matrix>::zeros(n, Self::MAX_ORDER + 3), //DMatrix::<T>::zeros(n, Self::MAX_ORDER + 3),
diff_tmp: <M<Eqn::V> as Matrix>::zeros(n, Self::MAX_ORDER + 3),
gamma: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1],
alpha: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1],
error_const: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1],
u: <M<Eqn::V> as Matrix>::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: BdfStatistics::default(),
state: None,
}
}
}

impl<M, Eqn> Clone for Bdf<M, Eqn>
where
M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
Eqn: OdeEquations + 'static,
Eqn::M: DefaultSolver,
for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
{
fn clone(&self) -> Self {
let n = self.diff.nrows();
let linear_solver = Eqn::M::default_solver();
let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::<BdfCallable<Eqn>>::new(
linear_solver,
));
nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER);
Self {
ode_problem: self.ode_problem.clone(),
nonlinear_solver,
order: self.order,
n_equal_steps: self.n_equal_steps,
diff: M::zeros(n, Self::MAX_ORDER + 3),
diff_tmp: M::zeros(n, Self::MAX_ORDER + 3),
gamma: self.gamma.clone(),
alpha: self.alpha.clone(),
error_const: self.error_const.clone(),
u: M::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1),
statistics: self.statistics.clone(),
state: self.state.clone(),
}
}
}

impl<M: DenseMatrix<T = Eqn::T, V = Eqn::V>, Eqn: OdeEquations> Bdf<M, Eqn>
where
Expand Down Expand Up @@ -261,15 +306,15 @@ where
self.ode_problem.as_ref()
}

fn state(&self) -> Option<&OdeSolverState<Eqn::M>> {
fn state(&self) -> Option<&OdeSolverState<Eqn::V>> {
self.state.as_ref()
}

fn take_state(&mut self) -> Option<OdeSolverState<<Eqn>::M>> {
fn take_state(&mut self) -> Option<OdeSolverState<<Eqn>::V>> {
Option::take(&mut self.state)
}

fn set_problem(&mut self, state: OdeSolverState<Eqn::M>, problem: &OdeSolverProblem<Eqn>) {
fn set_problem(&mut self, state: OdeSolverState<Eqn::V>, problem: &OdeSolverProblem<Eqn>) {
let mut state = state;
self.ode_problem = Some(problem.clone());
let nstates = problem.eqn.nstates();
Expand Down
Loading
Loading