Skip to content

Commit

Permalink
Add Newton's method for solving systems
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Sep 30, 2024
1 parent 7da76d0 commit b63b960
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 1 deletion.
31 changes: 31 additions & 0 deletions src/coefficient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,37 @@ impl<'a> TryFrom<AtomView<'a>> for Rational {
}
}

impl TryFrom<Atom> for Float {
type Error = &'static str;

fn try_from(value: Atom) -> Result<Self, Self::Error> {
value.as_view().try_into()
}
}

impl TryFrom<&Atom> for Float {
type Error = &'static str;

fn try_from(value: &Atom) -> Result<Self, Self::Error> {
value.as_view().try_into()
}
}

impl<'a> TryFrom<AtomView<'a>> for Float {
type Error = &'static str;

fn try_from(value: AtomView<'a>) -> Result<Self, Self::Error> {
if let AtomView::Num(n) = value {
match n.get_coeff_view() {
CoefficientView::Float(f) => Ok(f.to_float()),
_ => Err("Not a float"),
}
} else {
Err("Not a number")
}
}
}

impl Atom {
/// Set the coefficient ring to the multivariate rational polynomial with `vars` variables.
pub fn set_coefficient_ring(&self, vars: &Arc<Vec<Variable>>) -> Atom {
Expand Down
6 changes: 6 additions & 0 deletions src/domains/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,12 @@ impl Div<Rational> for Float {
}
}

impl From<f64> for Float {
fn from(value: f64) -> Self {
Float::with_val(53, value)
}
}

impl Float {
pub fn new(prec: u32) -> Self {
Float(MultiPrecisionFloat::new(prec))
Expand Down
154 changes: 153 additions & 1 deletion src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,138 @@ use std::{ops::Neg, sync::Arc};
use crate::{
atom::{Atom, AtomView, Symbol},
domains::{
atom::AtomField,
float::{Float, Real, SingleFloat},
integer::{IntegerRing, Z},
rational::Q,
rational::{Rational, Q},
rational_polynomial::{RationalPolynomial, RationalPolynomialField},
},
evaluate::FunctionMap,
poly::{Exponent, Variable},
tensors::matrix::Matrix,
};

impl<'a> AtomView<'a> {
/// Find the root of a function in `x` numerically over the reals using Newton's method.
pub fn nsolve<N: SingleFloat + Real + PartialOrd + From<Rational>>(
&self,
x: Symbol,
init: N,
prec: N,
max_iterations: usize,
) -> Result<N, String> {
let v = Atom::new_var(x);
let f = self
.to_evaluation_tree(&FunctionMap::new(), std::slice::from_ref(&v))
.unwrap()
.optimize(0, 0, None, false);
let df = self
.derivative(x)
.to_evaluation_tree(&FunctionMap::new(), std::slice::from_ref(&v))
.unwrap()
.optimize(0, 0, None, false);

let mut f_e = f.map_coeff(&|x| x.clone().into());
let mut df_e = df.map_coeff(&|x| x.clone().into());

let mut cur = init.clone();

for _ in 0..max_iterations {
let df_val = df_e.evaluate_single(std::slice::from_ref(&cur));
let f_val = f_e.evaluate_single(std::slice::from_ref(&cur));

if !df_val.is_finite() || df_val.is_zero() {
return Err("Derivative is zero".to_owned());
}

cur = cur - f_val.clone() / df_val;
if f_val.norm() < prec {
return Ok(cur);
}
}

Err("Did not converge".to_owned())
}

/// Solve a non-linear system numerically over the reals using Newton's method.
pub fn nsolve_system(
system: &[AtomView],
vars: &[Symbol],
init: &[Float],
prec: Float,
max_iterations: usize,
) -> Result<Vec<Float>, String> {
if system.len() != vars.len() {
Err("System must have same number of equations as there are unknowns".to_owned())?;
}

let avars = vars.iter().map(|v| Atom::new_var(*v)).collect::<Vec<_>>();

let mut fs = system
.iter()
.map(|a| {
a.to_evaluation_tree(&FunctionMap::new(), &avars)
.unwrap()
.optimize(0, 0, None, false)
.map_coeff(&|x| x.to_multi_prec_float(init[0].prec()))
})
.collect::<Vec<_>>();

let mut jacobian = Vec::with_capacity(vars.len() * system.len());
for a in system {
let mut row = Vec::with_capacity(vars.len());
for v in vars {
let deriv = a.derivative(*v);

let a = deriv
.to_evaluation_tree(&FunctionMap::new(), &avars)
.unwrap()
.optimize(0, 0, None, false)
.map_coeff(&|x| x.to_multi_prec_float(init[0].prec()));

row.push(a);
}
jacobian.extend_from_slice(&row);
}

let mut cur = init.to_vec();
let mut ci = Matrix::new_vec(
init.iter().map(|i| Atom::new_num(i.clone())).collect(),
AtomField::new(),
);

for _ in 0..max_iterations {
let f = fs
.iter_mut()
.map(|a| Atom::new_num(a.evaluate_single(&cur)))
.collect::<Vec<_>>();
let f = Matrix::new_vec(f, AtomField::new());

let df = jacobian
.iter_mut()
.map(|a| Atom::new_num(a.evaluate_single(&cur)))
.collect::<Vec<_>>();

let df =
Matrix::from_linear(df, system.len() as u32, vars.len() as u32, AtomField::new())
.unwrap();

let Ok(i) = df.inv() else {
return Err("Could not invert Jacobian".to_owned());
};

ci = &ci - &(&i * &f);

cur = ci.data.iter().map(|x| x.try_into().unwrap()).collect();

if f.data.iter().all(|x| Float::try_from(x).unwrap() < prec) {
return Ok(cur);
}
}

Err("Did not converge".to_owned())
}

/// Solve a system that is linear in `vars`, if possible.
/// Each expression in `system` is understood to yield 0.
pub fn solve_linear_system<E: Exponent>(
Expand Down Expand Up @@ -95,6 +218,7 @@ mod test {
use crate::{
atom::{Atom, AtomView},
domains::{
float::{Float, Real},
integer::Z,
rational::Q,
rational_polynomial::{RationalPolynomial, RationalPolynomialField},
Expand Down Expand Up @@ -194,4 +318,32 @@ mod test {

assert_eq!(sol.data, res);
}

#[test]
fn find_root() {
let x = State::get_symbol("x");
let a = Atom::parse("x^2 - 2").unwrap();
let a = a.as_view();

let root = a.nsolve(x, 1.0, 1e-10, 1000).unwrap();
assert!((root - 2f64.sqrt()).abs() < 1e-10);
}

#[test]
fn solve_system_newton() {
let a = Atom::parse("5x^2+x*y^2+sin(2y)^2 - 2").unwrap();
let b = Atom::parse("exp(2x-y)+4y - 3").unwrap();

let r = AtomView::nsolve_system(
&[a.as_view(), b.as_view()],
&[State::get_symbol("x"), State::get_symbol("y")],
&[Float::with_val(53, 1.), Float::with_val(53, 1.)],
Float::with_val(53, 1e-10),
100,
)
.unwrap();

assert!((r[0].clone() - Float::from(-5.0533137563948738e-1)).norm() < 1e-10.into());
assert!((r[1].clone() - Float::from(7.0504070797868401e-1)).norm() < 1e-10.into());
}
}

0 comments on commit b63b960

Please sign in to comment.