Skip to content

Commit

Permalink
Add arbitrary precision evaluation to Python API
Browse files Browse the repository at this point in the history
- The output is given as a Decimal
  • Loading branch information
benruijl committed Jun 6, 2024
1 parent 09ed8ac commit 2f28123
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 2 deletions.
166 changes: 164 additions & 2 deletions src/api/python.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Borrow,
f64::consts::LOG2_10,
fs::File,
hash::{Hash, Hasher},
io::BufWriter,
Expand All @@ -14,10 +15,12 @@ use pyo3::{
pyclass,
pyclass::CompareOp,
pyfunction, pymethods, pymodule,
sync::GILOnceCell,
types::{PyBytes, PyComplex, PyLong, PyModule, PyTuple, PyType},
wrap_pyfunction, FromPyObject, IntoPy, PyErr, PyObject, PyRef, PyResult, Python,
wrap_pyfunction, FromPyObject, IntoPy, Py, PyErr, PyObject, PyRef, PyResult, Python,
ToPyObject,
};
use rug::Complete;
use rug::{ops::CompleteRound, Complete};
use self_cell::self_cell;
use smallvec::SmallVec;
use smartstring::{LazyCompact, SmartString};
Expand Down Expand Up @@ -1191,6 +1194,61 @@ impl ConvertibleToExpression {
}
}

pub struct PythonMultiPrecisionFloat(rug::Float);

impl From<rug::Float> for PythonMultiPrecisionFloat {
fn from(f: rug::Float) -> Self {
PythonMultiPrecisionFloat(f)
}
}

static PYDECIMAL: GILOnceCell<Py<PyType>> = GILOnceCell::new();

fn get_decimal(py: Python) -> &Py<PyType> {
PYDECIMAL.get_or_init(py, || {
py.import("decimal")
.unwrap()
.getattr("Decimal")
.unwrap()
.extract()
.unwrap()
})
}

impl ToPyObject for PythonMultiPrecisionFloat {
fn to_object(&self, py: Python) -> PyObject {
get_decimal(py)
.as_ref(py)
.call1((self.0.to_string(),))
.expect("failed to call decimal.Decimal(value)")
.to_object(py)
}
}

impl<'a> FromPyObject<'a> for PythonMultiPrecisionFloat {
fn extract(ob: &'a pyo3::PyAny) -> PyResult<Self> {
if ob.is_instance(get_decimal(ob.py()).as_ref(ob.py()))? {
let a = ob.call_method0("__str__").unwrap().extract::<&str>()?;
Ok(rug::Float::parse(a)
.unwrap()
.complete((a.len() as f64 * LOG2_10).ceil() as u32)
.into())
} else if let Ok(a) = ob.extract::<&str>() {
// convert without loss of precision by setting the precision to the string length
Ok(rug::Float::parse(a)
.unwrap()
.complete((a.len() as f64 * LOG2_10).ceil() as u32)
.into())
} else if let Ok(a) = ob.extract::<f64>() {
Ok(rug::Float::with_val(53, a).into())
} else {
Err(exceptions::PyValueError::new_err(
"Not a valid multi-precision float",
))
}
}
}

impl<'a> FromPyObject<'a> for Complex<f64> {
fn extract(ob: &'a pyo3::PyAny) -> PyResult<Self> {
if let Ok(a) = ob.extract::<f64>() {
Expand All @@ -1205,6 +1263,24 @@ impl<'a> FromPyObject<'a> for Complex<f64> {
}
}

impl<'a> FromPyObject<'a> for Complex<rug::Float> {
fn extract(ob: &'a pyo3::PyAny) -> PyResult<Self> {
if let Ok(a) = ob.extract::<PythonMultiPrecisionFloat>() {
let zero = rug::Float::new(a.0.prec());
Ok(Complex::new(a.0, zero))
} else if let Ok(a) = ob.extract::<&PyComplex>() {
Ok(Complex::new(
rug::Float::with_val(53, a.real()).into(),
rug::Float::with_val(53, a.imag()).into(),
))
} else {
Err(exceptions::PyValueError::new_err(
"Not a valid complex number",
))
}
}
}

macro_rules! req_cmp {
($self:ident,$num:ident,$cmp_any_atom:ident,$c:ident) => {{
let num = $num.to_expression();
Expand Down Expand Up @@ -3156,6 +3232,92 @@ impl PythonExpression {
.evaluate(|x| x.into(), &constants, &functions, &mut cache))
}

/// Evaluate the expression, using a map of all the constants and
/// user functions using arbitrary precision arithmetic.
/// The user has to specify the number of decimal digits of precision
/// and provide all input numbers as floats, strings or `decimal`.
///
/// Examples
/// --------
/// >>> from symbolica import *
/// >>> from decimal import Decimal, getcontext
/// >>> x = Expression.symbols('x', 'f')
/// >>> e = Expression.parse('cos(x)')*3 + f(x, 2)
/// >>> getcontext().prec = 100
/// >>> a = e.evaluate_with_prec({x: Decimal('1.123456789')}, {
/// >>> f: lambda args: args[0] + args[1]}, 100)
pub fn evaluate_with_prec(
&self,
constants: HashMap<PythonExpression, PythonMultiPrecisionFloat>,
functions: HashMap<Variable, PyObject>,
decimal_digit_precision: u32,
py: Python,
) -> PyResult<PyObject> {
let prec = (decimal_digit_precision as f64 * std::f64::consts::LOG2_10).ceil() as u32;

let mut cache = HashMap::default();

let constants: HashMap<AtomView, rug::Float> = constants
.iter()
.map(|(k, v)| {
Ok((k.expr.as_view(), {
let mut vv = v.0.clone();
vv.set_prec(prec);
vv
}))
})
.collect::<PyResult<_>>()?;

let functions = functions
.into_iter()
.map(|(k, v)| {
let id = if let Variable::Symbol(v) = k {
v
} else {
Err(exceptions::PyValueError::new_err(format!(
"Expected function name instead of {}",
k
)))?
};

Ok((
id,
EvaluationFn::new(Box::new(move |args: &[rug::Float], _, _, _| {
Python::with_gil(|py| {
let mut vv = v
.call(
py,
(args
.iter()
.map(|x| PythonMultiPrecisionFloat(x.clone()).to_object(py))
.collect::<Vec<_>>(),),
None,
)
.expect("Bad callback function")
.extract::<PythonMultiPrecisionFloat>(py)
.expect("Function does not return a string")
.0;
vv.set_prec(prec);
vv
})
})),
))
})
.collect::<PyResult<_>>()?;

let a: PythonMultiPrecisionFloat = self
.expr
.evaluate(
|x| x.to_multi_prec_float(prec),
&constants,
&functions,
&mut cache,
)
.into();

Ok(a.to_object(py))
}

/// Evaluate the expression, using a map of all the variables and
/// user functions to a complex number.
///
Expand Down
23 changes: 23 additions & 0 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Symbolica Python API.
from __future__ import annotations
from enum import Enum
from typing import Any, Callable, overload, Iterator, Optional, Sequence, Tuple, List
from decimal import Decimal


def get_version() -> str:
Expand Down Expand Up @@ -986,6 +987,28 @@ class Expression:
>>> print(e.evaluate({x: 1}, {f: lambda args: args[0]+args[1]}))
"""

def evaluate_with_prec(
self,
constants: dict[Expression, float | str | Decimal],
funs: dict[Expression, Callable[[Sequence[Decimal]], float | str | Decimal]],
decimal_digit_precision: int
) -> Decimal:
"""Evaluate the expression, using a map of all the constants and
user functions using arbitrary precision arithmetic.
The user has to specify the number of decimal digits of precision
and provide all input numbers as floats, strings or `Decimal`.
Examples
--------
>>> from symbolica import *
>>> from decimal import Decimal, getcontext
>>> x = Expression.symbols('x', 'f')
>>> e = Expression.parse('cos(x)')*3 + f(x, 2)
>>> getcontext().prec = 100
>>> a = e.evaluate_with_prec({x: Decimal('1.123456789')}, {
>>> f: lambda args: args[0] + args[1]}, 100)
"""

def evaluate_complex(
self, constants: dict[Expression, float | complex], funs: dict[Expression, Callable[[Sequence[complex]], float | complex]]
) -> complex:
Expand Down

0 comments on commit 2f28123

Please sign in to comment.