Skip to content

Commit

Permalink
Function name wildcards may now be used as variables in the rhs
Browse files Browse the repository at this point in the history
- Add contains to Python API
- Evaluate now returns a Result
- Fix setting normalized flag
  • Loading branch information
benruijl committed Aug 21, 2024
1 parent fa978ac commit b19bcb5
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 104 deletions.
30 changes: 26 additions & 4 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2078,6 +2078,20 @@ impl PythonExpression {
}
}

/// Returns true iff `self` contains `a` literally.
///
/// Examples
/// --------
/// >>> from symbolica import *
/// >>> x, y, z = Expression.symbols('x', 'y', 'z')
/// >>> e = x * y * z
/// >>> e.contains(x) # True
/// >>> e.contains(x*y*z) # True
/// >>> e.contains(x*y) # False
pub fn contains(&self, s: ConvertibleToExpression) -> bool {
self.expr.contains(s.to_expression().expr.as_view())
}

/// Convert all coefficients to floats with a given precision `decimal_prec``.
/// The precision of floating point coefficients in the input will be truncated to `decimal_prec`.
pub fn coefficients_to_float(&self, decimal_prec: u32) -> PythonExpression {
Expand Down Expand Up @@ -3489,9 +3503,11 @@ impl PythonExpression {
})
.collect::<PyResult<_>>()?;

Ok(self
.expr
.evaluate(|x| x.into(), &constants, &functions, &mut cache))
self.expr
.evaluate(|x| x.into(), &constants, &functions, &mut cache)
.map_err(|e| {
exceptions::PyValueError::new_err(format!("Could not evaluate expression: {}", e))
})
}

/// Evaluate the expression, using a map of all the constants and
Expand Down Expand Up @@ -3575,6 +3591,9 @@ impl PythonExpression {
&functions,
&mut cache,
)
.map_err(|e| {
exceptions::PyValueError::new_err(format!("Could not evaluate expression: {}", e))
})?
.into();

Ok(a.to_object(py))
Expand Down Expand Up @@ -3637,7 +3656,10 @@ impl PythonExpression {

let r = self
.expr
.evaluate(|x| x.into(), &constants, &functions, &mut cache);
.evaluate(|x| x.into(), &constants, &functions, &mut cache)
.map_err(|e| {
exceptions::PyValueError::new_err(format!("Could not evaluate expression: {}", e))
})?;
Ok(PyComplex::from_doubles(py, r.re, r.im))
}

Expand Down
13 changes: 12 additions & 1 deletion src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ impl<'a> AtomOrView<'a> {
/// A trait for any type that can be converted into an `AtomView`.
/// To be used for functions that accept any argument that can be
/// converted to an `AtomView`.
pub trait AsAtomView<'a>: Sized {
pub trait AsAtomView<'a>: Copy + Sized {
fn as_atom_view(self) -> AtomView<'a>;
}

Expand Down Expand Up @@ -743,6 +743,17 @@ impl FunctionBuilder {
self
}

/// Add multiple arguments to the function.
pub fn add_args<'b, T: AsAtomView<'b>>(mut self, args: &[T]) -> FunctionBuilder {
if let Atom::Fun(f) = self.handle.deref_mut() {
for a in args {
f.add_arg(a.as_atom_view());
}
}

self
}

/// Finish the function construction and return an `Atom`.
pub fn finish(self) -> Atom {
Workspace::get_local().with(|ws| {
Expand Down
135 changes: 49 additions & 86 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl Atom {
const_map: &HashMap<AtomView<'_>, T>,
function_map: &HashMap<Symbol, EvaluationFn<T>>,
cache: &mut HashMap<AtomView<'b>, T>,
) -> T {
) -> Result<T, String> {
self.as_view()
.evaluate(coeff_map, const_map, function_map, cache)
}
Expand Down Expand Up @@ -1999,42 +1999,6 @@ impl<T: Clone + PartialEq> SplitExpression<T> {
subexpressions: self.subexpressions.iter().map(|x| x.map_coeff(f)).collect(),
}
}

pub fn unnest(&mut self, max_depth: usize) {
// TODO: also unnest subexpressions
for t in &mut self.tree {
Self::unnest_impl(t, &mut self.subexpressions, 0, max_depth);
}
}

fn unnest_impl(
expr: &mut Expression<T>,
subs: &mut Vec<Expression<T>>,
depth: usize,
max_depth: usize,
) {
match expr {
Expression::Add(a) | Expression::Mul(a) => {
if depth == max_depth {
// split off into new subexpression

Self::unnest_impl(expr, subs, 0, max_depth);

let mut r = Expression::SubExpression(subs.len());
std::mem::swap(expr, &mut r);
subs.push(r);
return;
}

for x in a {
Self::unnest_impl(x, subs, depth + 1, max_depth);
}
}
Expression::Eval(_, _) => {} // TODO: count the arg evals! always bring to base level?
Expression::BuiltinFun(_, _) => {}
_ => {} // TODO: count pow levels too?
}
}
}

impl<T: Clone + PartialEq> Expression<T> {
Expand Down Expand Up @@ -2129,14 +2093,6 @@ impl<T: Clone + PartialEq> EvalTree<T> {
param_count: self.param_count,
}
}

pub fn unnest(&mut self, max_depth: usize) {
for (_, _, e) in &mut self.functions {
e.unnest(max_depth);
}

self.expressions.unnest(max_depth);
}
}

impl<T: Clone + Default + PartialEq> EvalTree<T> {
Expand Down Expand Up @@ -3935,104 +3891,107 @@ impl<'a> AtomView<'a> {
const_map: &HashMap<AtomView<'_>, T>,
function_map: &HashMap<Symbol, EvaluationFn<T>>,
cache: &mut HashMap<AtomView<'a>, T>,
) -> T {
) -> Result<T, String> {
if let Some(c) = const_map.get(self) {
return c.clone();
return Ok(c.clone());
}

match self {
AtomView::Num(n) => match n.get_coeff_view() {
CoefficientView::Natural(n, d) => coeff_map(&Rational::from_unchecked(n, d)),
CoefficientView::Large(l) => coeff_map(&l.to_rat()),
CoefficientView::Natural(n, d) => Ok(coeff_map(&Rational::from_unchecked(n, d))),
CoefficientView::Large(l) => Ok(coeff_map(&l.to_rat())),
CoefficientView::Float(f) => {
// TODO: converting back to rational is slow
coeff_map(&f.to_float().to_rational())
Ok(coeff_map(&f.to_float().to_rational()))
}
CoefficientView::FiniteField(_, _) => {
unimplemented!("Finite field not yet supported for evaluation")
Err("Finite field not yet supported for evaluation".to_string())
}
CoefficientView::RationalPolynomial(_) => unimplemented!(
"Rational polynomial coefficient not yet supported for evaluation"
CoefficientView::RationalPolynomial(_) => Err(
"Rational polynomial coefficient not yet supported for evaluation".to_string(),
),
},
AtomView::Var(v) => panic!(
AtomView::Var(v) => Err(format!(
"Variable {} not in constant map",
State::get_name(v.get_symbol())
),
)),
AtomView::Fun(f) => {
let name = f.get_symbol();
if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) {
assert!(f.get_nargs() == 1);
let arg = f.iter().next().unwrap();
let arg_eval = arg.evaluate(coeff_map, const_map, function_map, cache);
let arg_eval = arg.evaluate(coeff_map, const_map, function_map, cache)?;

return match f.get_symbol() {
return Ok(match f.get_symbol() {
State::EXP => arg_eval.exp(),
State::LOG => arg_eval.log(),
State::SIN => arg_eval.sin(),
State::COS => arg_eval.cos(),
State::SQRT => arg_eval.sqrt(),
_ => unreachable!(),
};
});
}

if let Some(eval) = cache.get(self) {
return eval.clone();
return Ok(eval.clone());
}

let mut args = Vec::with_capacity(f.get_nargs());
for arg in f {
args.push(arg.evaluate(coeff_map, const_map, function_map, cache));
args.push(arg.evaluate(coeff_map, const_map, function_map, cache)?);
}

let Some(fun) = function_map.get(&f.get_symbol()) else {
panic!("Missing function {}", State::get_name(f.get_symbol()));
Err(format!(
"Missing function {}",
State::get_name(f.get_symbol())
))?
};
let eval = fun.get()(&args, const_map, function_map, cache);

cache.insert(*self, eval.clone());
eval
Ok(eval)
}
AtomView::Pow(p) => {
let (b, e) = p.get_base_exp();
let b_eval = b.evaluate(coeff_map, const_map, function_map, cache);
let b_eval = b.evaluate(coeff_map, const_map, function_map, cache)?;

if let AtomView::Num(n) = e {
if let CoefficientView::Natural(num, den) = n.get_coeff_view() {
if den == 1 {
if num >= 0 {
return b_eval.pow(num as u64);
return Ok(b_eval.pow(num as u64));
} else {
return b_eval.pow(num.unsigned_abs()).inv();
return Ok(b_eval.pow(num.unsigned_abs()).inv());
}
}
}
}

let e_eval = e.evaluate(coeff_map, const_map, function_map, cache);
b_eval.powf(&e_eval)
let e_eval = e.evaluate(coeff_map, const_map, function_map, cache)?;
Ok(b_eval.powf(&e_eval))
}
AtomView::Mul(m) => {
let mut it = m.iter();
let mut r = it
.next()
.unwrap()
.evaluate(coeff_map, const_map, function_map, cache);
let mut r =
it.next()
.unwrap()
.evaluate(coeff_map, const_map, function_map, cache)?;
for arg in it {
r *= arg.evaluate(coeff_map, const_map, function_map, cache);
r *= arg.evaluate(coeff_map, const_map, function_map, cache)?;
}
r
Ok(r)
}
AtomView::Add(a) => {
let mut it = a.iter();
let mut r = it
.next()
.unwrap()
.evaluate(coeff_map, const_map, function_map, cache);
let mut r =
it.next()
.unwrap()
.evaluate(coeff_map, const_map, function_map, cache)?;
for arg in it {
r += arg.evaluate(coeff_map, const_map, function_map, cache);
r += arg.evaluate(coeff_map, const_map, function_map, cache)?;
}
r
Ok(r)
}
}
}
Expand Down Expand Up @@ -4082,7 +4041,9 @@ mod test {
})),
);

let r = a.evaluate(|x| x.into(), &const_map, &fn_map, &mut cache);
let r = a
.evaluate(|x| x.into(), &const_map, &fn_map, &mut cache)
.unwrap();
assert_eq!(r, 2905.761021719902);
}

Expand All @@ -4096,12 +4057,14 @@ mod test {
let v = Atom::new_var(x);
const_map.insert(v.as_view(), Float::with_val(200, 6));

let r = a.evaluate(
|r| r.to_multi_prec_float(200),
&const_map,
&HashMap::default(),
&mut HashMap::default(),
);
let r = a
.evaluate(
|r| r.to_multi_prec_float(200),
&const_map,
&HashMap::default(),
&mut HashMap::default(),
)
.unwrap();

assert_eq!(
format!("{}", r),
Expand Down
19 changes: 8 additions & 11 deletions src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use dyn_clone::DynClone;

use crate::{
atom::{
representation::ListSlice, AsAtomView, Atom, AtomType, AtomView, Num, SliceType, Symbol,
representation::{InlineVar, ListSlice},
AsAtomView, Atom, AtomType, AtomView, Num, SliceType, Symbol,
},
state::{State, Workspace},
transformer::{Transformer, TransformerError},
Expand Down Expand Up @@ -867,8 +868,8 @@ impl Pattern {
func.add_arg(handle.as_view())
}
},
Match::FunctionName(_) => {
unreachable!("Wildcard cannot be function name")
Match::FunctionName(s) => {
func.add_arg(InlineVar::new(*s).as_view())
}
}
} else if match_stack.settings.allow_new_wildcards_on_rhs {
Expand Down Expand Up @@ -905,8 +906,8 @@ impl Pattern {
w.to_atom_into(&mut handle);
out.set_from_view(&handle.as_view())
}
Match::FunctionName(_) => {
unreachable!("Wildcard cannot be function name")
Match::FunctionName(s) => {
out.set_from_view(&InlineVar::new(*s).as_view())
}
}
} else if match_stack.settings.allow_new_wildcards_on_rhs {
Expand Down Expand Up @@ -951,9 +952,7 @@ impl Pattern {
mul.extend(handle.as_view())
}
},
Match::FunctionName(_) => {
unreachable!("Wildcard cannot be function name")
}
Match::FunctionName(s) => mul.extend(InlineVar::new(*s).as_view()),
}
} else if match_stack.settings.allow_new_wildcards_on_rhs {
mul.extend(workspace.new_var(*w).as_view());
Expand Down Expand Up @@ -994,9 +993,7 @@ impl Pattern {
add.extend(handle.as_view())
}
},
Match::FunctionName(_) => {
unreachable!("Wildcard cannot be function name")
}
Match::FunctionName(s) => add.extend(InlineVar::new(*s).as_view()),
}
} else if match_stack.settings.allow_new_wildcards_on_rhs {
add.extend(workspace.new_var(*w).as_view());
Expand Down
Loading

0 comments on commit b19bcb5

Please sign in to comment.