Skip to content

Commit

Permalink
Support tagged functions
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jul 3, 2024
1 parent 390c4eb commit 57ce732
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 83 deletions.
82 changes: 42 additions & 40 deletions examples/nested_evaluation.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,75 @@
use std::{process::Command, time::Instant};

use ahash::HashMap;
use symbolica::{
atom::{Atom, AtomView},
evaluate::{ConstOrExpr, ExpressionEvaluator},
domains::rational::Rational,
evaluate::{ExpressionEvaluator, FunctionMap},
state::State,
};

fn main() {
let e1 = Atom::parse("x + cos(x) + f(g(x+1),h(x*2)) + p(1)").unwrap();
let e1 = Atom::parse("x + pi + cos(x) + f(g(x+1),h(x*2)) + p(1,x)").unwrap();
let e2 = Atom::parse("x + h(x*2) + cos(x)").unwrap();
let f = Atom::parse("y^2 + z^2*y^2").unwrap();
let g = Atom::parse("i(y+7)+x*i(y+7)*(y-1)").unwrap();
let h = Atom::parse("y*(1+x*(1+x^2)) + y^2*(1+x*(1+x^2))^2 + 3*(1+x^2)").unwrap();
let i = Atom::parse("y - 1").unwrap();
let k = Atom::parse("3*x^3 + 4*x^2 + 6*x +8").unwrap();
let p1 = Atom::parse("3*z^3 + 4*z^2 + 6*z +8").unwrap();

let mut const_map = HashMap::default();
let mut fn_map = FunctionMap::new();

let p1 = Atom::parse("p(1)").unwrap();
let f_s = Atom::new_var(State::get_symbol("f"));
let g_s = Atom::new_var(State::get_symbol("g"));
let h_s = Atom::new_var(State::get_symbol("h"));
let i_s = Atom::new_var(State::get_symbol("i"));

const_map.insert(
p1.into(),
ConstOrExpr::Expr(State::get_symbol("p1"), vec![], k.as_view()),
fn_map.add_constant(
Atom::new_var(State::get_symbol("pi")).into(),
Rational::from((22, 7)).into(),
);

const_map.insert(
f_s.into(),
ConstOrExpr::Expr(
fn_map
.add_tagged_function(
State::get_symbol("p"),
vec![Atom::new_num(1).into()],
"p1".to_string(),
vec![State::get_symbol("z")],
p1.as_view(),
)
.unwrap();
fn_map
.add_function(
State::get_symbol("f"),
"f".to_string(),
vec![State::get_symbol("y"), State::get_symbol("z")],
f.as_view(),
),
);
const_map.insert(
g_s.into(),
ConstOrExpr::Expr(
)
.unwrap();
fn_map
.add_function(
State::get_symbol("g"),
"g".to_string(),
vec![State::get_symbol("y")],
g.as_view(),
),
);
const_map.insert(
h_s.into(),
ConstOrExpr::Expr(
)
.unwrap();
fn_map
.add_function(
State::get_symbol("h"),
"h".to_string(),
vec![State::get_symbol("y")],
h.as_view(),
),
);
const_map.insert(
i_s.into(),
ConstOrExpr::Expr(
)
.unwrap();
fn_map
.add_function(
State::get_symbol("i"),
"i".to_string(),
vec![State::get_symbol("y")],
i.as_view(),
),
);
)
.unwrap();

let params = vec![Atom::parse("x").unwrap()];

let mut tree = AtomView::to_eval_tree_multiple(
&[e1.as_view(), e2.as_view()],
|r| r.clone(),
&const_map,
&fn_map,
&params,
);

Expand All @@ -87,7 +89,7 @@ fn main() {

std::fs::write("nested_evaluation.cpp", cpp).unwrap();

Command::new("g++")
let r = Command::new("g++")
.arg("-shared")
.arg("-fPIC")
.arg("-O3")
Expand All @@ -97,12 +99,12 @@ fn main() {
.arg("nested_evaluation.cpp")
.output()
.unwrap();
println!("Compilation {}", r.status);

unsafe {
let lib = libloading::Library::new("./libneval.so").unwrap();
let func: libloading::Symbol<
unsafe extern "C" fn(params: *const f64, out: *mut f64) -> f64,
> = lib.get(b"eval_double").unwrap();
let func: libloading::Symbol<unsafe extern "C" fn(params: *const f64, out: *mut f64)> =
lib.get(b"eval_double").unwrap();

let params = vec![5.];
let mut out = vec![0., 0.];
Expand Down
1 change: 1 addition & 0 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ impl<'a> From<AddView<'a>> for AtomView<'a> {
}

/// A copy-on-write structure for `Atom` and `AtomView`.
#[derive(Clone)]
pub enum AtomOrView<'a> {
Atom(Atom),
View(AtomView<'a>),
Expand Down
Loading

0 comments on commit 57ce732

Please sign in to comment.