diff --git a/src/lib.rs b/src/lib.rs index 981e1f00..c12599e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,7 +137,7 @@ pub mod solver; pub mod vector; use linear_solver::LinearSolver; -pub use linear_solver::NalgebraLU; +pub use linear_solver::{FaerLU, NalgebraLU}; #[cfg(feature = "sundials")] pub use matrix::sundials::SundialsMatrix; @@ -181,7 +181,7 @@ mod tests { .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) - .build_ode( + .build_ode_dense( |x: &V, p: &V, _t: T, y: &mut V| { y[0] = -p[0] * x[0] + p[1] * x[1] * x[2]; y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1]; @@ -217,11 +217,12 @@ mod tests { fn test_readme_faer() { type T = f64; type V = faer::Col; + type M = faer::Mat; let problem = OdeBuilder::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) - .build_ode( + .build_ode_dense( |x: &V, p: &V, _t: T, y: &mut V| { y[0] = -p[0] * x[0] + p[1] * x[1] * x[2]; y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1]; @@ -239,7 +240,7 @@ mod tests { ) .unwrap(); - let mut solver = Bdf::default(); + let mut solver = Bdf::::default(); let t = 0.4; let y = solver.solve(&problem, t).unwrap(); diff --git a/src/linear_solver/sundials.rs b/src/linear_solver/sundials.rs index a964999d..175af9e9 100644 --- a/src/linear_solver/sundials.rs +++ b/src/linear_solver/sundials.rs @@ -23,6 +23,15 @@ where matrix: SundialsMatrix, } +impl Default for SundialsLinearSolver +where + Op: LinearOp, +{ + fn default() -> Self { + Self::new_dense() + } +} + impl SundialsLinearSolver where Op: LinearOp, diff --git a/src/matrix/default_solver.rs b/src/matrix/default_solver.rs new file mode 100644 index 00000000..3855ae2f --- /dev/null +++ b/src/matrix/default_solver.rs @@ -0,0 +1,10 @@ +use crate::{linear_solver::LinearSolver, op::LinearOp}; + +use super::Matrix; + +pub trait DefaultSolver: Matrix { + type LS>: LinearSolver + Default; + fn default_solver>() -> Self::LS { + Self::LS::default() + } +} diff --git a/src/matrix/dense_faer_serial.rs b/src/matrix/dense_faer_serial.rs index a450da26..fa816988 100644 --- a/src/matrix/dense_faer_serial.rs +++ b/src/matrix/dense_faer_serial.rs @@ -1,10 +1,16 @@ use std::ops::{Mul, MulAssign}; +use super::default_solver::DefaultSolver; use super::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut}; use crate::scalar::{IndexType, Scalar, Scale}; +use crate::{op::LinearOp, FaerLU}; use anyhow::Result; use faer::{linalg::matmul::matmul, Col, ColMut, ColRef, Mat, MatMut, MatRef, Parallelism}; +impl DefaultSolver for Mat { + type LS, V = Col, T = T>> = FaerLU; +} + macro_rules! impl_matrix_common { ($mat_type:ty) => { impl<'a, T: Scalar> MatrixCommon for $mat_type { diff --git a/src/matrix/dense_nalgebra_serial.rs b/src/matrix/dense_nalgebra_serial.rs index ddc881f2..029f37da 100644 --- a/src/matrix/dense_nalgebra_serial.rs +++ b/src/matrix/dense_nalgebra_serial.rs @@ -3,9 +3,16 @@ use std::ops::{Mul, MulAssign}; use anyhow::Result; use nalgebra::{DMatrix, DMatrixView, DMatrixViewMut, DVector, DVectorView, DVectorViewMut}; +use crate::op::LinearOp; use crate::{scalar::Scale, IndexType, Scalar}; -use crate::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut}; +use crate::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut, NalgebraLU}; + +use super::default_solver::DefaultSolver; + +impl DefaultSolver for DMatrix { + type LS, V = DVector, T = T>> = NalgebraLU; +} macro_rules! impl_matrix_common { ($matrix_type:ty) => { diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs index 280bb3cd..6f2fbc3a 100644 --- a/src/matrix/mod.rs +++ b/src/matrix/mod.rs @@ -12,7 +12,9 @@ mod dense_nalgebra_serial; #[cfg(feature = "faer")] mod dense_faer_serial; +pub mod default_solver; mod sparse_serial; + #[cfg(feature = "sundials")] pub mod sundials; diff --git a/src/matrix/sundials.rs b/src/matrix/sundials.rs index 54b8253a..e55fb3c3 100644 --- a/src/matrix/sundials.rs +++ b/src/matrix/sundials.rs @@ -12,12 +12,13 @@ use sundials_sys::{ use crate::{ ode_solver::sundials::sundials_check, + op::LinearOp, scalar::scale, vector::sundials::{get_suncontext, SundialsVector}, - IndexType, Scale, Vector, + IndexType, Scale, SundialsLinearSolver, Vector, }; -use super::{Matrix, MatrixCommon}; +use super::{default_solver::DefaultSolver, Matrix, MatrixCommon}; use anyhow::anyhow; #[derive(Debug)] @@ -79,6 +80,11 @@ impl Display for SundialsMatrix { } } +impl DefaultSolver for SundialsMatrix { + type LS> = + SundialsLinearSolver; +} + impl MatrixCommon for SundialsMatrix { type V = SundialsVector; type T = realtype; diff --git a/src/ode_solver/bdf/faer.rs b/src/ode_solver/bdf/faer.rs index 57beac46..83531a10 100644 --- a/src/ode_solver/bdf/faer.rs +++ b/src/ode_solver/bdf/faer.rs @@ -1,67 +1,3 @@ -use faer::{Col, Mat}; - -use crate::{ - linear_solver::FaerLU, op::bdf::BdfCallable, Bdf, NewtonNonlinearSolver, NonLinearSolver, - OdeEquations, Scalar, VectorRef, -}; - -impl, M = Mat> + 'static> Default - for Bdf, Eqn> -{ - fn default() -> Self { - let n = 1; - let linear_solver = FaerLU::default(); - let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( - linear_solver, - )); - nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); - Self { - ode_problem: None, - nonlinear_solver, - order: 1, - n_equal_steps: 0, - diff: Mat::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3), - gamma: vec![T::from(1.0); Self::MAX_ORDER + 1], - alpha: vec![T::from(1.0); Self::MAX_ORDER + 1], - error_const: vec![T::from(1.0); Self::MAX_ORDER + 1], - u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), - statistics: super::BdfStatistics::default(), - state: None, - } - } -} - -// implement clone for bdf -impl, M = Mat> + 'static> Clone - for Bdf, Eqn> -where - for<'b> &'b Col: VectorRef>, -{ - fn clone(&self) -> Self { - let n = self.diff.nrows(); - let linear_solver = FaerLU::default(); - let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( - linear_solver, - )); - nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); - Self { - ode_problem: self.ode_problem.clone(), - nonlinear_solver, - order: self.order, - n_equal_steps: self.n_equal_steps, - diff: Mat::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3), - gamma: self.gamma.clone(), - alpha: self.alpha.clone(), - error_const: self.error_const.clone(), - u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), - statistics: self.statistics.clone(), - state: self.state.clone(), - } - } -} - #[cfg(test)] mod test { use crate::{ @@ -72,14 +8,14 @@ mod test { type M = faer::Mat; #[test] fn bdf_no_set_problem() { - test_no_set_problem::(Bdf::default()) + test_no_set_problem::(Bdf::::default()) } #[test] fn bdf_take_state() { - test_take_state::(Bdf::default()) + test_take_state::(Bdf::::default()) } #[test] fn bdf_test_interpolate() { - test_interpolate::(Bdf::default()) + test_interpolate::(Bdf::::default()) } } diff --git a/src/ode_solver/bdf/mod.rs b/src/ode_solver/bdf/mod.rs index f41aac11..0e287fca 100644 --- a/src/ode_solver/bdf/mod.rs +++ b/src/ode_solver/bdf/mod.rs @@ -7,9 +7,13 @@ use num_traits::{One, Pow, Zero}; use serde::Serialize; use crate::{ - matrix::MatrixRef, op::bdf::BdfCallable, scalar::scale, DenseMatrix, IndexType, MatrixViewMut, - NonLinearSolver, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Scalar, SolverProblem, - Vector, VectorRef, VectorView, VectorViewMut, + matrix::{default_solver::DefaultSolver, Matrix, MatrixRef}, + op::bdf::BdfCallable, + scalar::scale, + vector::DefaultDenseMatrix, + DenseMatrix, IndexType, MatrixViewMut, NewtonNonlinearSolver, NonLinearSolver, OdeSolverMethod, + OdeSolverProblem, OdeSolverState, Scalar, SolverProblem, Vector, VectorRef, VectorView, + VectorViewMut, }; pub mod faer; @@ -69,32 +73,73 @@ pub struct Bdf, Eqn: OdeEquations> { gamma: Vec, error_const: Vec, statistics: BdfStatistics, - state: Option>, + state: Option>, } -// impl, M = faer::Mat>> Default for Bdf, Eqn> { -// fn default() -> Self { -// let n = 1; -// let linear_solver = LU::default(); -// let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( -// linear_solver, -// )); -// nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); -// Self { -// ode_problem: None, -// nonlinear_solver, -// order: 1, -// n_equal_steps: 0, -// diff: Mat::zeros(n, Self::MAX_ORDER + 3), //DMatrix::::zeros(n, Self::MAX_ORDER + 3), -// diff_tmp: Mat::zeros(n, Self::MAX_ORDER + 3), -// gamma: vec![f64::from(1.0); Self::MAX_ORDER + 1], -// alpha: vec![f64::from(1.0); Self::MAX_ORDER + 1], -// error_const: vec![f64::from(1.0); Self::MAX_ORDER + 1], -// u: Mat::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), -// statistics: BdfStatistics::default(), -// } -// } -// } +impl Default for Bdf<::M, Eqn> +where + Eqn: OdeEquations + 'static, + Eqn::M: DefaultSolver, + Eqn::V: DefaultDenseMatrix, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + fn default() -> Self { + let n = 1; + let linear_solver = Eqn::M::default_solver(); + let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( + linear_solver, + )); + nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); + type M = ::M; + Self { + ode_problem: None, + nonlinear_solver, + order: 1, + n_equal_steps: 0, + diff: as Matrix>::zeros(n, Self::MAX_ORDER + 3), //DMatrix::::zeros(n, Self::MAX_ORDER + 3), + diff_tmp: as Matrix>::zeros(n, Self::MAX_ORDER + 3), + gamma: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], + alpha: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], + error_const: vec![Eqn::T::from(1.0); Self::MAX_ORDER + 1], + u: as Matrix>::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), + statistics: BdfStatistics::default(), + state: None, + } + } +} + +impl Clone for Bdf +where + M: DenseMatrix, + Eqn: OdeEquations + 'static, + Eqn::M: DefaultSolver, + for<'b> &'b Eqn::V: VectorRef, + for<'b> &'b Eqn::M: MatrixRef, +{ + fn clone(&self) -> Self { + let n = self.diff.nrows(); + let linear_solver = Eqn::M::default_solver(); + let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( + linear_solver, + )); + nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); + Self { + ode_problem: self.ode_problem.clone(), + nonlinear_solver, + order: self.order, + n_equal_steps: self.n_equal_steps, + diff: M::zeros(n, Self::MAX_ORDER + 3), + diff_tmp: M::zeros(n, Self::MAX_ORDER + 3), + gamma: self.gamma.clone(), + alpha: self.alpha.clone(), + error_const: self.error_const.clone(), + u: M::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), + statistics: self.statistics.clone(), + state: self.state.clone(), + } + } +} impl, Eqn: OdeEquations> Bdf where @@ -261,15 +306,15 @@ where self.ode_problem.as_ref() } - fn state(&self) -> Option<&OdeSolverState> { + fn state(&self) -> Option<&OdeSolverState> { self.state.as_ref() } - fn take_state(&mut self) -> Option::M>> { + fn take_state(&mut self) -> Option::V>> { Option::take(&mut self.state) } - fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { + fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { let mut state = state; self.ode_problem = Some(problem.clone()); let nstates = problem.eqn.nstates(); diff --git a/src/ode_solver/bdf/nalgebra.rs b/src/ode_solver/bdf/nalgebra.rs index 32647597..ed53a9b1 100644 --- a/src/ode_solver/bdf/nalgebra.rs +++ b/src/ode_solver/bdf/nalgebra.rs @@ -1,67 +1,3 @@ -use crate::{ - linear_solver::NalgebraLU, op::bdf::BdfCallable, Bdf, NewtonNonlinearSolver, NonLinearSolver, - OdeEquations, Scalar, VectorRef, -}; - -use nalgebra::{DMatrix, DVector}; - -impl, M = DMatrix> + 'static> Default - for Bdf, Eqn> -{ - fn default() -> Self { - let n = 1; - let linear_solver = NalgebraLU::default(); - let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( - linear_solver, - )); - nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); - Self { - ode_problem: None, - nonlinear_solver, - order: 1, - n_equal_steps: 0, - diff: DMatrix::::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: DMatrix::::zeros(n, Self::MAX_ORDER + 3), - gamma: vec![T::from(1.0); Self::MAX_ORDER + 1], - alpha: vec![T::from(1.0); Self::MAX_ORDER + 1], - error_const: vec![T::from(1.0); Self::MAX_ORDER + 1], - u: DMatrix::::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), - statistics: super::BdfStatistics::default(), - state: None, - } - } -} - -// implement clone for bdf -impl, M = DMatrix> + 'static> Clone - for Bdf, Eqn> -where - for<'b> &'b DVector: VectorRef>, -{ - fn clone(&self) -> Self { - let n = self.diff.nrows(); - let linear_solver = NalgebraLU::default(); - let mut nonlinear_solver = Box::new(NewtonNonlinearSolver::>::new( - linear_solver, - )); - nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER); - Self { - ode_problem: self.ode_problem.clone(), - nonlinear_solver, - order: self.order, - n_equal_steps: self.n_equal_steps, - diff: DMatrix::zeros(n, Self::MAX_ORDER + 3), - diff_tmp: DMatrix::zeros(n, Self::MAX_ORDER + 3), - gamma: self.gamma.clone(), - alpha: self.alpha.clone(), - error_const: self.error_const.clone(), - u: DMatrix::zeros(Self::MAX_ORDER + 1, Self::MAX_ORDER + 1), - statistics: self.statistics.clone(), - state: self.state.clone(), - } - } -} - #[cfg(test)] mod test { use crate::{ diff --git a/src/ode_solver/builder.rs b/src/ode_solver/builder.rs index 89855260..775405d2 100644 --- a/src/ode_solver/builder.rs +++ b/src/ode_solver/builder.rs @@ -1,4 +1,4 @@ -use crate::{op::Op, Matrix, OdeSolverProblem, Vector}; +use crate::{op::Op, vector::DefaultDenseMatrix, Matrix, OdeSolverProblem, Vector}; use anyhow::Result; use super::equations::{OdeSolverEquations, OdeSolverEquationsMassI}; @@ -245,6 +245,23 @@ impl OdeBuilder { )) } + /// Build an ODE problem using the default dense matrix (see [Self::build_ode]). + #[allow(clippy::type_complexity)] + pub fn build_ode_dense( + self, + rhs: F, + rhs_jac: G, + init: I, + ) -> Result>> + where + V: Vector + DefaultDenseMatrix, + F: Fn(&V, &V, V::T, &mut V), + G: Fn(&V, &V, V::T, &V, &mut V), + I: Fn(&V, V::T) -> V, + { + self.build_ode(rhs, rhs_jac, init) + } + /// Build an ODE problem using the DiffSL language (requires the `diffsl` feature). /// The source code is provided as a string, please see the [DiffSL documentation](https://martinjrobins.github.io/diffsl/) for more information. #[cfg(feature = "diffsl")] diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index d9946cf8..d8af5603 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use crate::{ op::{filter::FilterCallable, ode_rhs::OdeRhs}, - Matrix, NonLinearSolver, OdeEquations, OdeSolverProblem, SolverProblem, Vector, VectorIndex, + NonLinearSolver, OdeEquations, OdeSolverProblem, SolverProblem, Vector, VectorIndex, }; /// Trait for ODE solver methods. This is the main user interface for the ODE solvers. @@ -31,7 +31,7 @@ pub trait OdeSolverMethod { /// Set the problem to solve, this performs any initialisation required by the solver. Call this before calling `step` or `solve`. /// The solver takes ownership of the initial state given by `state`, this is assumed to be consistent with any algebraic constraints. - fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem); + fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem); /// Step the solution forward by one step, altering the internal state of the solver. fn step(&mut self) -> Result<()>; @@ -40,12 +40,12 @@ pub trait OdeSolverMethod { fn interpolate(&self, t: Eqn::T) -> Result; /// Get the current state of the solver, if it exists - fn state(&self) -> Option<&OdeSolverState>; + fn state(&self) -> Option<&OdeSolverState>; /// Take the current state of the solver, if it exists, returning it to the user. This is useful if you want to use this /// state in another solver or problem. Note that this will unset the current problem and solver state, so you will need to call /// `set_problem` again before calling `step` or `solve`. - fn take_state(&mut self) -> Option>; + fn take_state(&mut self) -> Option>; /// Reinitialise the solver state and solve the problem up to time `t` fn solve(&mut self, problem: &OdeSolverProblem, t: Eqn::T) -> Result { @@ -75,29 +75,23 @@ pub trait OdeSolverMethod { /// State for the ODE solver, containing the current solution `y`, the current time `t`, and the current step size `h`. #[derive(Clone)] -pub struct OdeSolverState { - pub y: M::V, - pub t: M::T, - pub h: M::T, - _phantom: std::marker::PhantomData, +pub struct OdeSolverState { + pub y: V, + pub t: V::T, + pub h: V::T, } -impl OdeSolverState { +impl OdeSolverState { /// Create a new solver state from an ODE problem. Note that this does not make the state consistent with the algebraic constraints. /// If you need to make the state consistent, use `new_consistent` instead. pub fn new(ode_problem: &OdeSolverProblem) -> Self where - Eqn: OdeEquations, + Eqn: OdeEquations, { let t = ode_problem.t0; let h = ode_problem.h0; let y = ode_problem.eqn.init(t); - Self { - y, - t, - h, - _phantom: std::marker::PhantomData, - } + Self { y, t, h } } /// Create a new solver state from an ODE problem, making the state consistent with the algebraic constraints. @@ -106,7 +100,7 @@ impl OdeSolverState { root_solver: &mut S, ) -> Result where - Eqn: OdeEquations, + Eqn: OdeEquations, S: NonLinearSolver>> + ?Sized, { let t = ode_problem.t0; @@ -114,12 +108,7 @@ impl OdeSolverState { let indices = ode_problem.eqn.algebraic_indices(); let mut y = ode_problem.eqn.init(t); if indices.len() == 0 { - return Ok(Self { - y, - t, - h, - _phantom: std::marker::PhantomData, - }); + return Ok(Self { y, t, h }); } let mut y_filtered = y.filter(&indices); let atol = Rc::new(ode_problem.atol.as_ref().filter(&indices)); @@ -132,11 +121,6 @@ impl OdeSolverState { let init_problem = root_solver.problem().unwrap(); let indices = init_problem.f.indices(); y.scatter_from(&y_filtered, indices); - Ok(Self { - y, - t, - h, - _phantom: std::marker::PhantomData, - }) + Ok(Self { y, t, h }) } } diff --git a/src/ode_solver/sdirk.rs b/src/ode_solver/sdirk.rs index 56f2dd66..533ec7c9 100644 --- a/src/ode_solver/sdirk.rs +++ b/src/ode_solver/sdirk.rs @@ -36,7 +36,7 @@ where tableau: Tableau, problem: Option>, nonlinear_solver: NewtonNonlinearSolver>, - state: Option>, + state: Option>, diff: M, gamma: Eqn::T, is_sdirk: bool, @@ -175,7 +175,7 @@ where fn set_problem( &mut self, - mut state: OdeSolverState<::M>, + mut state: OdeSolverState<::V>, problem: &OdeSolverProblem, ) { // update initial step size based on function @@ -490,11 +490,11 @@ where } } - fn state(&self) -> Option<&OdeSolverState<::M>> { + fn state(&self) -> Option<&OdeSolverState<::V>> { self.state.as_ref() } - fn take_state(&mut self) -> Option::M>> { + fn take_state(&mut self) -> Option::V>> { Option::take(&mut self.state) } } diff --git a/src/ode_solver/sundials.rs b/src/ode_solver/sundials.rs index 2aad73ce..8f7f43e7 100644 --- a/src/ode_solver/sundials.rs +++ b/src/ode_solver/sundials.rs @@ -123,7 +123,7 @@ where yp: SundialsVector, jacobian: SundialsMatrix, statistics: SundialsStatistics, - state: Option>, + state: Option>, } impl SundialsIda @@ -251,15 +251,15 @@ where self.problem.as_ref() } - fn state(&self) -> Option<&OdeSolverState<::M>> { + fn state(&self) -> Option<&OdeSolverState<::V>> { self.state.as_ref() } - fn take_state(&mut self) -> Option::M>> { + fn take_state(&mut self) -> Option::V>> { Option::take(&mut self.state) } - fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { + fn set_problem(&mut self, state: OdeSolverState, problem: &OdeSolverProblem) { self.state = Some(state); let state = self.state.as_ref().unwrap(); self.problem = Some(problem.clone()); diff --git a/src/vector/faer_serial.rs b/src/vector/faer_serial.rs index 6739c1b6..b5d00b23 100644 --- a/src/vector/faer_serial.rs +++ b/src/vector/faer_serial.rs @@ -1,11 +1,17 @@ use std::ops::{Div, Mul, MulAssign}; -use faer::{unzipped, zipped, Col, ColMut, ColRef}; +use faer::{unzipped, zipped, Col, ColMut, ColRef, Mat}; use crate::{scalar::Scale, IndexType, Scalar}; use crate::{Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut}; +use super::DefaultDenseMatrix; + +impl DefaultDenseMatrix for Col { + type M = Mat; +} + macro_rules! impl_op_for_faer_struct { ($struct:ident, $trait_name:ident, $func_name:ident) => { impl<'a, T: Scalar> $trait_name> for $struct<'a, T> { diff --git a/src/vector/mod.rs b/src/vector/mod.rs index a2a09b6e..51219703 100644 --- a/src/vector/mod.rs +++ b/src/vector/mod.rs @@ -1,3 +1,4 @@ +use crate::matrix::DenseMatrix; use crate::scalar::Scale; use crate::{IndexType, Scalar}; use num_traits::Zero; @@ -206,3 +207,7 @@ pub trait Vector: } } } + +pub trait DefaultDenseMatrix: Vector { + type M: DenseMatrix; +} diff --git a/src/vector/nalgebra_serial.rs b/src/vector/nalgebra_serial.rs index de5de1ca..681fd78c 100644 --- a/src/vector/nalgebra_serial.rs +++ b/src/vector/nalgebra_serial.rs @@ -1,10 +1,14 @@ use std::ops::{Div, Mul, MulAssign}; -use nalgebra::{DVector, DVectorView, DVectorViewMut}; +use nalgebra::{DMatrix, DVector, DVectorView, DVectorViewMut}; use crate::{IndexType, Scalar, Scale}; -use super::{Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut}; +use super::{DefaultDenseMatrix, Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut}; + +impl DefaultDenseMatrix for DVector { + type M = DMatrix; +} macro_rules! impl_op_for_dvector_struct { ($struct:ident, $trait_name:ident, $func_name:ident) => {