Skip to content

Commit

Permalink
Add for_each transformer
Browse files Browse the repository at this point in the history
- Auto-convert pattern to transformer if it is not one already
- Add more tests
  • Loading branch information
benruijl committed Apr 10, 2024
1 parent eb57ead commit 1393891
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 50 deletions.
44 changes: 41 additions & 3 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,13 @@ macro_rules! append_transformer {
expr: Arc::new(Pattern::Transformer(t)),
})
} else {
return Err(exceptions::PyValueError::new_err(
"Pattern must be a transformer",
));
// pattern is not a transformer yet (but may have subtransformers)
Ok(PythonPattern {
expr: Arc::new(Pattern::Transformer(Box::new((
Some($self.expr.as_ref().clone()),
vec![$t],
)))),
})
}
};
}
Expand Down Expand Up @@ -526,6 +530,40 @@ impl PythonPattern {
return append_transformer!(self, transformer);
}

/// Create a transformer that applies a transformer chain to every argument of the `arg()` function.
/// If the input is not `arg()`, the transformer is applied to the input.
///
/// Examples
/// --------
/// >>> from symbolica import Expression
/// >>> x = Expression.var('x')
/// >>> f = Expression.fun('f')
/// >>> e = (1+x).transform().split().for_each(Transformer().map(f)).execute()
#[pyo3(signature = (*transformers))]
pub fn for_each(&self, transformers: &PyTuple) -> PyResult<PythonPattern> {
let mut rep_chain = vec![];
// fuse all sub-transformers into one chain
for r in transformers {
let p = r.extract::<PythonPattern>()?;

let Pattern::Transformer(t) = p.expr.borrow() else {
return Err(exceptions::PyValueError::new_err(
"Argument must be a transformer",
));
};

if t.0.is_some() {
return Err(exceptions::PyValueError::new_err(
"Transformers in a for_each must be unbound. Use Transformer() to create it.",
));
}

rep_chain.extend_from_slice(&t.1);
}

return append_transformer!(self, Transformer::ForEach(rep_chain));
}

/// Create a transformer that checks for a Python interrupt,
/// such as ctrl-c and aborts the current transformer.
///
Expand Down
37 changes: 37 additions & 0 deletions src/coefficient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,40 @@ impl<'a> AtomView<'a> {
}
}
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use crate::{domains::rational::Rational, representations::Atom, state::State};

#[test]
fn coefficient_ring() {
let expr = Atom::parse("v1*v3+v1*(v2+2)^-1*(v2+v3+1)").unwrap();

let v2 = State::get_symbol("v2");
let expr_yz =
expr.set_coefficient_ring(&Arc::new(vec![v2.into(), State::get_symbol("v3").into()]));

let a = ((&expr_yz + &Atom::new_num(Rational::new(1, 2)))
* &Atom::new_num(Rational::new(3, 4)))
.expand();

let a = (a / &Atom::new_num(Rational::new(3, 4)) - &Atom::new_num(Rational::new(1, 2)))
.expand();

let a = a.set_coefficient_ring(&Arc::new(vec![]));

let expr = Atom::new_var(v2)
.into_pattern()
.replace_all(expr.as_view(), &Atom::new_num(3).into_pattern(), None, None)
.expand();

let a = Atom::new_var(v2)
.into_pattern()
.replace_all(a.as_view(), &Atom::new_num(3).into_pattern(), None, None)
.expand();

assert_eq!(a, expr);
}
}
45 changes: 0 additions & 45 deletions src/poly/gcd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,51 +263,6 @@ impl<F: Field, E: Exponent> MultivariatePolynomial<F, E> {
res
}

/// Replace all variables except `v` in the polynomial by elements from
/// a finite field of size `p`. The exponent of `v` should be small.
pub fn sample_polynomial_small_exponent(
&self,
v: usize,
r: &[(usize, F::Element)],
cache: &mut [Vec<F::Element>],
tm: &mut [F::Element],
) -> MultivariatePolynomial<F, E> {
for mv in self.into_iter() {
let mut c = mv.coefficient.clone();
for (n, vv) in r {
let exp = mv.exponents[*n].to_u32() as usize;
if exp > 0 {
if exp < cache[*n].len() {
if F::is_zero(&cache[*n][exp]) {
cache[*n][exp] = self.field.pow(vv, exp as u64);
}

self.field.mul_assign(&mut c, &cache[*n][exp]);
} else {
self.field
.mul_assign(&mut c, &self.field.pow(vv, exp as u64));
}
}
}

let expv = mv.exponents[v].to_u32() as usize;
self.field.add_assign(&mut tm[expv], &c);
}

// TODO: add bounds estimate
let mut res = self.zero();
let mut e = vec![E::zero(); self.nvars()];
for (k, c) in tm.iter_mut().enumerate() {
if !F::is_zero(c) {
e[v] = E::from_u32(k as u32);
res.append_monomial_back(mem::replace(c, self.field.zero()), &e);
e[v] = E::zero();
}
}

res
}

/// Find the upper bound of a variable `var` in the gcd.
/// This is done by computing the univariate gcd by
/// substituting all variables except `var`. This
Expand Down
24 changes: 23 additions & 1 deletion src/transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ pub enum Transformer {
ArgCount(bool),
/// Map the rhs with a user-specified function.
Map(Box<dyn Map>),
/// Apply a transformation to each argument of the `arg()` function.
/// If the input is not `arg()`, map the current input.
ForEach(Vec<Transformer>),
/// Split a `Mul` or `Add` into a list of arguments.
Split,
Partition(Vec<(Symbol, usize)>, bool, bool),
Expand All @@ -108,6 +111,7 @@ impl std::fmt::Debug for Transformer {
Transformer::Sum => f.debug_tuple("Sum").finish(),
Transformer::ArgCount(p) => f.debug_tuple("ArgCount").field(p).finish(),
Transformer::Map(_) => f.debug_tuple("Map").finish(),
Transformer::ForEach(t) => f.debug_tuple("ForEach").field(t).finish(),
Transformer::Split => f.debug_tuple("Split").finish(),
Transformer::Partition(g, b1, b2) => f
.debug_tuple("Partition")
Expand Down Expand Up @@ -157,7 +161,6 @@ impl Transformer {
pub fn execute(
orig_input: AtomView<'_>,
chain: &[Transformer],

workspace: &Workspace,
out: &mut Atom,
) -> Result<(), TransformerError> {
Expand All @@ -171,6 +174,25 @@ impl Transformer {
Transformer::Map(f) => {
f(input, out)?;
}
Transformer::ForEach(t) => {
if let AtomView::Fun(f) = input {
if f.get_symbol() == State::ARG {
let mut ff = workspace.new_atom();
let ff = ff.to_fun(State::ARG);

let mut a = workspace.new_atom();
for arg in f.iter() {
Self::execute(arg, t, workspace, &mut a)?;
ff.add_arg(a.as_view());
}

ff.as_view().normalize(workspace, out);
continue;
}
}

Self::execute(input, t, workspace, out)?;
}
Transformer::Expand => {
input.expand_with_ws_into(workspace, out);
}
Expand Down
12 changes: 12 additions & 0 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,18 @@ class Transformer:
>>> print(e)
"""

def for_each(self, *transformers: Transformer) -> Transformer:
"""Create a transformer that applies a transformer chain to every argument of the `arg()` function.
If the input is not `arg()`, the transformer is applied to the input.
Examples
--------
>>> from symbolica import Expression
>>> x = Expression.var('x')
>>> f = Expression.fun('f')
>>> e = (1+x).transform().split().for_each(Transformer().map(f)).execute()
"""

def check_interrupt(self) -> Transformer:
"""Create a transformer that checks for a Python interrupt,
such as ctrl-c and aborts the current transformer.
Expand Down
29 changes: 28 additions & 1 deletion tests/rational_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,34 @@ use symbolica::{
};

#[test]
fn large_gcd() {
fn large_gcd_single_scale() {
let order = Arc::new(vec![
Variable::Symbol(State::get_symbol("x0")),
Variable::Symbol(State::get_symbol("x1")),
Variable::Symbol(State::get_symbol("x2")),
Variable::Symbol(State::get_symbol("x3")),
Variable::Symbol(State::get_symbol("x4")),
]);

let a = Atom::parse("(x0+2*x1+x2+x3-x4^2)^10")
.unwrap()
.to_polynomial::<_, u8>(&Z, Some(order.clone()));

let b = Atom::parse("(x0+2*x1+5+x2+x3-x4^2)^10")
.unwrap()
.to_polynomial::<_, u8>(&Z, Some(order.clone()));

let g = Atom::parse("(-x0+3*x1+x2+5*x3-x4^2)^4")
.unwrap()
.to_polynomial::<_, u8>(&Z, Some(order.clone()));

let gg = (&a * &g).gcd(&(&b * &g));

assert_eq!(gg, g);
}

#[test]
fn large_gcd_multiple_scales() {
let order = Arc::new(vec![
Variable::Symbol(State::get_symbol("x0")),
Variable::Symbol(State::get_symbol("x1")),
Expand Down

0 comments on commit 1393891

Please sign in to comment.