Skip to content

Commit

Permalink
Add nested expression evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jun 26, 2024
1 parent 72bafce commit 248650d
Show file tree
Hide file tree
Showing 3 changed files with 528 additions and 11 deletions.
47 changes: 47 additions & 0 deletions examples/nested_evaluation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use ahash::HashMap;
use symbolica::{atom::Atom, evaluate::ConstOrExpr, state::State};

fn main() {
let e = Atom::parse("x + cos(x) + f(g(x+1),h(x*2)) + p(1)").unwrap();
let f = Atom::parse("y^2 + z^2").unwrap(); // f(y,z) = y^2+z^2
let g = Atom::parse("i(y+7)").unwrap(); // g(y) = i(y+7)
let h = Atom::parse("y + 3").unwrap(); // h(y) = y+3
let i = Atom::parse("y * 2").unwrap(); // i(y) = y*2
let k = Atom::parse("x+8").unwrap(); // p(1) = x + 8

let mut const_map = HashMap::default();

let p1 = Atom::parse("p(1)").unwrap();
let f_s = Atom::new_var(State::get_symbol("f"));
let g_s = Atom::new_var(State::get_symbol("g"));
let h_s = Atom::new_var(State::get_symbol("h"));
let i_s = Atom::new_var(State::get_symbol("i"));

const_map.insert(p1.into(), ConstOrExpr::Expr(vec![], k.as_view()));

const_map.insert(
f_s.into(),
ConstOrExpr::Expr(
vec![State::get_symbol("y"), State::get_symbol("z")],
f.as_view(),
),
);
const_map.insert(
g_s.into(),
ConstOrExpr::Expr(vec![State::get_symbol("y")], g.as_view()),
);
const_map.insert(
h_s.into(),
ConstOrExpr::Expr(vec![State::get_symbol("y")], h.as_view()),
);
const_map.insert(
i_s.into(),
ConstOrExpr::Expr(vec![State::get_symbol("y")], i.as_view()),
);

let params = vec![Atom::parse("x").unwrap()];

let mut evaluator = e.as_view().evaluator(|r| r.into(), &const_map, &params);

println!("{}", evaluator.evaluate(&[5.]));
}
91 changes: 81 additions & 10 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ pub struct Symbol {
is_linear: bool,
}

impl std::fmt::Debug for Symbol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", self.id))?;
for _ in 0..self.wildcard_level {
f.write_str("_")?;
}
Ok(())
}
}

impl Symbol {
/// Create a new variable symbol. This constructor should be used with care as there are no checks
/// about the validity of the identifier.
Expand Down Expand Up @@ -81,16 +91,6 @@ impl Symbol {
}
}

impl std::fmt::Debug for Symbol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", self.id))?;
for _ in 0..self.wildcard_level {
f.write_str("_")?;
}
Ok(())
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AtomType {
Num,
Expand Down Expand Up @@ -192,6 +192,77 @@ impl<'a> From<AddView<'a>> for AtomView<'a> {
}
}

/// A copy-on-write structure for `Atom` and `AtomView`.
pub enum AtomOrView<'a> {
Atom(Atom),
View(AtomView<'a>),
}

impl<'a> PartialEq for AtomOrView<'a> {
#[inline]
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(AtomOrView::Atom(a), AtomOrView::Atom(b)) => a == b,
(AtomOrView::View(a), AtomOrView::View(b)) => a == b,
_ => self.as_view() == other.as_view(),
}
}
}

impl Eq for AtomOrView<'_> {}

impl Hash for AtomOrView<'_> {
#[inline]
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
AtomOrView::Atom(a) => a.as_view().hash(state),
AtomOrView::View(a) => a.hash(state),
}
}
}

impl<'a> From<Atom> for AtomOrView<'a> {
fn from(a: Atom) -> AtomOrView<'a> {
AtomOrView::Atom(a)
}
}

impl<'a> From<AtomView<'a>> for AtomOrView<'a> {
fn from(a: AtomView<'a>) -> AtomOrView<'a> {
AtomOrView::View(a)
}
}

impl<'a> From<&AtomView<'a>> for AtomOrView<'a> {
fn from(a: &AtomView<'a>) -> AtomOrView<'a> {
AtomOrView::View(*a)
}
}

impl<'a> AtomOrView<'a> {
pub fn as_view(&'a self) -> AtomView<'a> {
match self {
AtomOrView::Atom(a) => a.as_view(),
AtomOrView::View(a) => *a,
}
}

pub fn as_mut(&mut self) -> &mut Atom {
match self {
AtomOrView::Atom(a) => a,
AtomOrView::View(a) => {
let mut oa = Atom::default();
oa.set_from_view(a);
*self = AtomOrView::Atom(oa);
match self {
AtomOrView::Atom(a) => a,
_ => unreachable!(),
}
}
}
}
}

/// A trait for any type that can be converted into an `AtomView`.
/// To be used for functions that accept any argument that can be
/// converted to an `AtomView`.
Expand Down
Loading

0 comments on commit 248650d

Please sign in to comment.