diff --git a/src/api/python.rs b/src/api/python.rs index 93107d5..4050a82 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -53,8 +53,8 @@ use crate::{ graph::Graph, id::{ Condition, ConditionResult, Evaluate, Match, MatchSettings, MatchStack, Pattern, - PatternAtomTreeIterator, PatternOrMap, PatternRestriction, ReplaceIterator, Replacement, - WildcardRestriction, + PatternAtomTreeIterator, PatternOrMap, PatternRestriction, Relation, ReplaceIterator, + Replacement, WildcardRestriction, }, numerical_integration::{ContinuousGrid, DiscreteGrid, Grid, MonteCarloRng, Sample}, parser::Token, @@ -443,6 +443,50 @@ impl PythonTransformer { } } + /// Compare two expressions. If one of the expressions is not a number, an + /// internal ordering will be used. + fn __richcmp__(&self, other: ConvertibleToPattern, op: CompareOp) -> PyResult { + Ok(match op { + CompareOp::Eq => PythonCondition { + condition: Relation::Eq(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Ne => PythonCondition { + condition: Relation::Ne(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Ge => PythonCondition { + condition: Relation::Ge(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Gt => PythonCondition { + condition: Relation::Gt(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Le => PythonCondition { + condition: Relation::Le(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Lt => PythonCondition { + condition: Relation::Lt(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + }) + } + + /// Returns true iff `self` contains `a` literally. + /// + /// Examples + /// -------- + /// >>> from symbolica import * + /// >>> x, y, z = Expression.symbol('x', 'y', 'z') + /// >>> e = x * y * z + /// >>> e.contains(x) # True + /// >>> e.contains(x*y*z) # True + /// >>> e.contains(x*y) # False + pub fn contains(&self, s: ConvertibleToPattern) -> PyResult { + Ok(PythonCondition { + condition: Condition::Yield(Relation::Contains( + self.expr.clone(), + s.to_pattern()?.expr, + )), + }) + } + /// Create a transformer that expands products and powers. /// /// Examples @@ -917,6 +961,119 @@ impl PythonTransformer { return append_transformer!(self, Transformer::Repeat(rep_chain)); } + /// Evaluate the condition and apply the `if_block` if the condition is true, otherwise apply the `else_block`. + /// The expression that is the input of the transformer is the input for the condition, the `if_block` and the `else_block`. + /// + /// Examples + /// -------- + /// >>> t = T.map_terms(T.if_then(T.contains(x), T.print())) + /// >>> t(x + y + 4) + /// + /// prints `x`. + #[pyo3(signature = (condition, if_block, else_block = None))] + pub fn if_then( + &self, + condition: PythonCondition, + if_block: PythonTransformer, + else_block: Option, + ) -> PyResult { + let Pattern::Transformer(t1) = if_block.expr else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + }; + + let t2 = if let Some(e) = else_block { + if let Pattern::Transformer(t2) = e.expr { + t2 + } else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + } + } else { + Box::new((None, vec![])) + }; + + if t1.0.is_some() || t2.0.is_some() { + return Err(exceptions::PyValueError::new_err( + "Transformers in a repeat must be unbound. Use Transformer() to create it.", + )); + } + + return append_transformer!(self, Transformer::IfElse(condition.condition, t1.1, t2.1)); + } + + /// Execute the `condition` transformer. If the result of the `condition` transformer is different from the input expression, + /// apply the `if_block`, otherwise apply the `else_block`. The input expression of the `if_block` is the output + /// of the `condition` transformer. + /// + /// Examples + /// -------- + /// >>> t = T.map_terms(T.if_changed(T.replace_all(x, y), T.print())) + /// >>> print(t(x + y + 4)) + /// + /// prints + /// ```log + /// y + /// 2*y+4 + /// ``` + #[pyo3(signature = (condition, if_block, else_block = None))] + pub fn if_changed( + &self, + condition: PythonTransformer, + if_block: PythonTransformer, + else_block: Option, + ) -> PyResult { + let Pattern::Transformer(t0) = condition.expr else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + }; + + let Pattern::Transformer(t1) = if_block.expr else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + }; + + let t2 = if let Some(e) = else_block { + if let Pattern::Transformer(t2) = e.expr { + t2 + } else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + } + } else { + Box::new((None, vec![])) + }; + + if t0.0.is_some() || t1.0.is_some() || t2.0.is_some() { + return Err(exceptions::PyValueError::new_err( + "Transformers in a repeat must be unbound. Use Transformer() to create it.", + )); + } + + return append_transformer!(self, Transformer::IfChanged(t0.1, t1.1, t2.1)); + } + + /// Break the current chain and all higher-level chains containing `if` transformers. + /// + /// Examples + /// -------- + /// >>> from symbolica import * + /// >>> t = T.map_terms(T.repeat( + /// >>> T.replace_all(y, 4), + /// >>> T.if_changed(T.replace_all(x, y), + /// >>> T.break_chain()), + /// >>> T.print() # print of y is never reached + /// >>> )) + /// >>> print(t(x)) + pub fn break_chain(&self) -> PyResult { + return append_transformer!(self, Transformer::BreakChain); + } + /// Chain several transformers. `chain(A,B,C)` is the same as `A.B.C`, /// where `A`, `B`, `C` are transformers. /// @@ -981,6 +1138,7 @@ impl PythonTransformer { workspace, &mut out, &MatchStack::new(&Condition::default(), &MatchSettings::default()), + None, ) }) .map_err(|e| match e { @@ -1138,7 +1296,7 @@ impl PythonTransformer { /// /// yields /// - /// ``` + /// ```log /// -6*(x-2*y)*(x+y) /// ``` pub fn collect_num(&self) -> PyResult { @@ -1711,58 +1869,6 @@ impl PythonPatternRestriction { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum Relation { - Eq(Atom, Atom), - Ne(Atom, Atom), - Gt(Atom, Atom), - Ge(Atom, Atom), - Lt(Atom, Atom), - Le(Atom, Atom), - Contains(Atom, Atom), - IsType(Atom, AtomType), -} - -impl std::fmt::Display for Relation { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Relation::Eq(a, b) => write!(f, "{} == {}", a, b), - Relation::Ne(a, b) => write!(f, "{} != {}", a, b), - Relation::Gt(a, b) => write!(f, "{} > {}", a, b), - Relation::Ge(a, b) => write!(f, "{} >= {}", a, b), - Relation::Lt(a, b) => write!(f, "{} < {}", a, b), - Relation::Le(a, b) => write!(f, "{} <= {}", a, b), - Relation::Contains(a, b) => write!(f, "{} contains {}", a, b), - Relation::IsType(a, b) => write!(f, "{} is type {:?}", a, b), - } - } -} - -impl Evaluate for Relation { - type State<'a> = (); - - fn evaluate(&self, _state: &()) -> ConditionResult { - match self { - Relation::Eq(a, b) => (a == b).into(), - Relation::Ne(a, b) => (a != b).into(), - Relation::Gt(a, b) => (a > b).into(), - Relation::Ge(a, b) => (a >= b).into(), - Relation::Lt(a, b) => (a < b).into(), - Relation::Le(a, b) => (a <= b).into(), - Relation::Contains(a, b) => (a.contains(b)).into(), - Relation::IsType(a, b) => match a { - Atom::Var(_) => (*b == AtomType::Var).into(), - Atom::Fun(_) => (*b == AtomType::Fun).into(), - Atom::Num(_) => (*b == AtomType::Num).into(), - Atom::Add(_) => (*b == AtomType::Add).into(), - Atom::Mul(_) => (*b == AtomType::Mul).into(), - Atom::Pow(_) => (*b == AtomType::Pow).into(), - Atom::Zero => (*b == AtomType::Num).into(), - }, - } - } -} - /// A restriction on wildcards. #[pyclass(name = "Condition", module = "symbolica")] #[derive(Clone)] @@ -1786,11 +1892,15 @@ impl PythonCondition { format!("{}", self.condition) } - pub fn eval(&self) -> bool { - self.condition.evaluate(&()) == ConditionResult::True + pub fn eval(&self) -> PyResult { + Ok(self + .condition + .evaluate(&None) + .map_err(|e| exceptions::PyValueError::new_err(e))? + == ConditionResult::True) } - pub fn __bool__(&self) -> bool { + pub fn __bool__(&self) -> PyResult { self.eval() } @@ -1821,37 +1931,45 @@ impl PythonCondition { macro_rules! req_cmp_rel { ($self:ident,$num:ident,$cmp_any_atom:ident,$c:ident) => {{ - if !$cmp_any_atom && !matches!($num.as_view(), AtomView::Num(_)) { - return Err("Can only compare to number"); - }; - - match $self.as_view() { - AtomView::Var(v) => { - let name = v.get_symbol(); - if v.get_wildcard_level() == 0 { - return Err("Only wildcards can be restricted."); + let num = if !$cmp_any_atom { + if let Pattern::Literal(a) = $num { + if let AtomView::Num(_) = a.as_view() { + a + } else { + return Err("Can only compare to number"); } + } else { + return Err("Can only compare to number"); + } + } else if let Pattern::Literal(a) = $num { + a + } else { + return Err("Pattern must be literal"); + }; - Ok(PatternRestriction::Wildcard(( - name, - WildcardRestriction::Filter(Box::new(move |v: &Match| { - let k = $num.as_view(); + if let Pattern::Wildcard(name) = $self { + if name.get_wildcard_level() == 0 { + return Err("Only wildcards can be restricted."); + } - if let Match::Single(m) = v { - if !$cmp_any_atom { - if let AtomView::Num(_) = m { - return m.cmp(&k).$c(); - } - } else { - return m.cmp(&k).$c(); + Ok(PatternRestriction::Wildcard(( + name, + WildcardRestriction::Filter(Box::new(move |v: &Match| { + if let Match::Single(m) = v { + if !$cmp_any_atom { + if let AtomView::Num(_) = m { + return m.cmp(&num.as_view()).$c(); } + } else { + return m.cmp(&num.as_view()).$c(); } + } - false - })), - ))) - } - _ => Err("Only wildcards can be restricted."), + false + })), + ))) + } else { + Err("Only wildcards can be restricted.") } }}; } @@ -1880,18 +1998,28 @@ impl TryFrom for PatternRestriction { return req_cmp_rel!(atom, atom1, true, is_le); } Relation::Contains(atom, atom1) => { - if let Atom::Var(v) = atom { - let name = v.get_symbol(); + if let Pattern::Wildcard(name) = atom { if name.get_wildcard_level() == 0 { return Err("Only wildcards can be restricted."); } + if !matches!(&atom1, &Pattern::Literal(_)) { + return Err("Pattern must be literal"); + } + Ok(PatternRestriction::Wildcard(( name, - WildcardRestriction::Filter(Box::new(move |m| match m { - Match::Single(v) => v.contains(atom1.as_view()), - Match::Multiple(_, v) => v.iter().any(|x| x.contains(atom1.as_view())), - Match::FunctionName(_) => false, + WildcardRestriction::Filter(Box::new(move |m| { + let val = if let Pattern::Literal(a) = &atom1 { + a.as_view() + } else { + unreachable!() + }; + match m { + Match::Single(v) => v.contains(val), + Match::Multiple(_, v) => v.iter().any(|x| x.contains(val)), + Match::FunctionName(_) => false, + } })), ))) } else { @@ -1899,9 +2027,9 @@ impl TryFrom for PatternRestriction { } } Relation::IsType(atom, atom_type) => { - if let Atom::Var(v) = atom { + if let Pattern::Wildcard(name) = atom { Ok(PatternRestriction::Wildcard(( - v.get_symbol(), + name, WildcardRestriction::IsAtomType(atom_type), ))) } else { @@ -3056,8 +3184,8 @@ impl PythonExpression { pub fn contains(&self, s: ConvertibleToExpression) -> PythonCondition { PythonCondition { condition: Condition::Yield(Relation::Contains( - self.expr.clone(), - s.to_expression().expr, + self.expr.into_pattern(), + s.to_expression().expr.into_pattern(), )), } } @@ -3235,7 +3363,7 @@ impl PythonExpression { pub fn is_type(&self, atom_type: PythonAtomType) -> PythonCondition { PythonCondition { condition: Condition::Yield(Relation::IsType( - self.expr.clone(), + self.expr.into_pattern(), match atom_type { PythonAtomType::Num => AtomType::Num, PythonAtomType::Var => AtomType::Var, @@ -3245,32 +3373,32 @@ impl PythonExpression { PythonAtomType::Fn => AtomType::Fun, }, )), - } - } + } + } /// Compare two expressions. If one of the expressions is not a number, an /// internal ordering will be used. - fn __richcmp__(&self, other: ConvertibleToExpression, op: CompareOp) -> PythonCondition { - match op { + fn __richcmp__(&self, other: ConvertibleToPattern, op: CompareOp) -> PyResult { + Ok(match op { CompareOp::Eq => PythonCondition { - condition: Relation::Eq(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Eq(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Ne => PythonCondition { - condition: Relation::Ne(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Ne(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Ge => PythonCondition { - condition: Relation::Ge(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Ge(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Gt => PythonCondition { - condition: Relation::Gt(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Gt(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Le => PythonCondition { - condition: Relation::Le(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Le(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Lt => PythonCondition { - condition: Relation::Lt(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Lt(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, - } + }) } /// Create a pattern restriction that passes when the wildcard is smaller than a number `num`. @@ -3719,7 +3847,7 @@ impl PythonExpression { /// /// yields /// - /// ``` + /// ```log /// (3*x+3*y)*(4*x+5*y) /// ``` pub fn expand_num(&self) -> PythonExpression { diff --git a/src/id.rs b/src/id.rs index 41fddf1..180eb2b 100644 --- a/src/id.rs +++ b/src/id.rs @@ -23,6 +23,16 @@ pub enum Pattern { Transformer(Box<(Option, Vec)>), } +impl std::fmt::Display for Pattern { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Ok(a) = self.to_atom() { + a.fmt(f) + } else { + std::fmt::Debug::fmt(self, f) + } + } +} + pub trait MatchMap: Fn(&MatchStack) -> Atom + DynClone + Send + Sync {} dyn_clone::clone_trait_object!(MatchMap); impl Atom> MatchMap for T {} @@ -417,8 +427,13 @@ impl<'a> AtomView<'a> { match r.rhs { PatternOrMap::Pattern(rhs) => { - rhs.substitute_wildcards(workspace, &mut rhs_subs, &match_stack) - .unwrap(); // TODO: escalate? + rhs.substitute_wildcards( + workspace, + &mut rhs_subs, + &match_stack, + None, + ) + .unwrap(); // TODO: escalate? } PatternOrMap::Map(f) => { let mut rhs = f(&match_stack); @@ -938,6 +953,7 @@ impl Pattern { workspace: &Workspace, out: &mut Atom, match_stack: &MatchStack, + transformer_input: Option<&Pattern>, ) -> Result<(), TransformerError> { match self { Pattern::Wildcard(name) => { @@ -1017,7 +1033,12 @@ impl Pattern { } let mut handle = workspace.new_atom(); - arg.substitute_wildcards(workspace, &mut handle, match_stack)?; + arg.substitute_wildcards( + workspace, + &mut handle, + match_stack, + transformer_input, + )?; func.add_arg(handle.as_view()); } @@ -1055,7 +1076,12 @@ impl Pattern { } let mut handle = workspace.new_atom(); - arg.substitute_wildcards(workspace, &mut handle, match_stack)?; + arg.substitute_wildcards( + workspace, + &mut handle, + match_stack, + transformer_input, + )?; out.set_from_view(&handle.as_view()); } @@ -1099,7 +1125,12 @@ impl Pattern { } let mut handle = workspace.new_atom(); - arg.substitute_wildcards(workspace, &mut handle, match_stack)?; + arg.substitute_wildcards( + workspace, + &mut handle, + match_stack, + transformer_input, + )?; mul.extend(handle.as_view()); } mul_h.as_view().normalize(workspace, out); @@ -1140,7 +1171,12 @@ impl Pattern { } let mut handle = workspace.new_atom(); - arg.substitute_wildcards(workspace, &mut handle, match_stack)?; + arg.substitute_wildcards( + workspace, + &mut handle, + match_stack, + transformer_input, + )?; add.extend(handle.as_view()); } add_h.as_view().normalize(workspace, out); @@ -1150,14 +1186,19 @@ impl Pattern { } Pattern::Transformer(p) => { let (pat, ts) = &**p; - let pat = pat.as_ref().ok_or_else(|| { - TransformerError::ValueError( + + let pat = if let Some(p) = pat.as_ref() { + p + } else if let Some(input_p) = transformer_input { + input_p + } else { + Err(TransformerError::ValueError( "Transformer is missing an expression to act on.".to_owned(), - ) - })?; + ))? + }; let mut handle = workspace.new_atom(); - pat.substitute_wildcards(workspace, &mut handle, match_stack)?; + pat.substitute_wildcards(workspace, &mut handle, match_stack, transformer_input)?; Transformer::execute_chain(handle.as_view(), ts, workspace, out)?; } @@ -1362,21 +1403,21 @@ pub trait Evaluate { type State<'a>; /// Evaluate a condition. - fn evaluate<'a>(&self, state: &Self::State<'a>) -> ConditionResult; + fn evaluate<'a>(&self, state: &Self::State<'a>) -> Result; } impl Evaluate for Condition { type State<'a> = T::State<'a>; - fn evaluate(&self, state: &T::State<'_>) -> ConditionResult { - match self { - Condition::And(a) => a.0.evaluate(state) & a.1.evaluate(state), - Condition::Or(o) => o.0.evaluate(state) | o.1.evaluate(state), - Condition::Not(n) => !n.evaluate(state), + fn evaluate(&self, state: &T::State<'_>) -> Result { + Ok(match self { + Condition::And(a) => a.0.evaluate(state)? & a.1.evaluate(state)?, + Condition::Or(o) => o.0.evaluate(state)? | o.1.evaluate(state)?, + Condition::Not(n) => !n.evaluate(state)?, Condition::True => ConditionResult::True, Condition::False => ConditionResult::False, - Condition::Yield(t) => t.evaluate(state), - } + Condition::Yield(t) => t.evaluate(state)?, + }) } } @@ -1465,14 +1506,120 @@ impl From for ConditionResult { } } +impl ConditionResult { + pub fn is_true(&self) -> bool { + matches!(self, ConditionResult::True) + } + + pub fn is_false(&self) -> bool { + matches!(self, ConditionResult::False) + } + + pub fn is_inconclusive(&self) -> bool { + matches!(self, ConditionResult::Inconclusive) + } +} + +#[derive(Clone, Debug)] +pub enum Relation { + Eq(Pattern, Pattern), + Ne(Pattern, Pattern), + Gt(Pattern, Pattern), + Ge(Pattern, Pattern), + Lt(Pattern, Pattern), + Le(Pattern, Pattern), + Contains(Pattern, Pattern), + IsType(Pattern, AtomType), +} + +impl std::fmt::Display for Relation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Relation::Eq(a, b) => write!(f, "{} == {}", a, b), + Relation::Ne(a, b) => write!(f, "{} != {}", a, b), + Relation::Gt(a, b) => write!(f, "{} > {}", a, b), + Relation::Ge(a, b) => write!(f, "{} >= {}", a, b), + Relation::Lt(a, b) => write!(f, "{} < {}", a, b), + Relation::Le(a, b) => write!(f, "{} <= {}", a, b), + Relation::Contains(a, b) => write!(f, "{} contains {}", a, b), + Relation::IsType(a, b) => write!(f, "{} is type {:?}", a, b), + } + } +} + +impl Evaluate for Relation { + type State<'a> = Option>; + + fn evaluate(&self, state: &Option) -> Result { + Workspace::get_local().with(|ws| { + let mut out1 = ws.new_atom(); + let mut out2 = ws.new_atom(); + let c = Condition::default(); + let s = MatchSettings::default(); + let m = MatchStack::new(&c, &s); + let pat = state.map(|x| x.into_pattern()); + + Ok(match self { + Relation::Eq(a, b) + | Relation::Ne(a, b) + | Relation::Gt(a, b) + | Relation::Ge(a, b) + | Relation::Lt(a, b) + | Relation::Le(a, b) + | Relation::Contains(a, b) => { + a.substitute_wildcards(ws, &mut out1, &m, pat.as_ref()) + .map_err(|e| match e { + TransformerError::Interrupt => "Interrupted by user".into(), + TransformerError::ValueError(v) => v, + })?; + b.substitute_wildcards(ws, &mut out2, &m, pat.as_ref()) + .map_err(|e| match e { + TransformerError::Interrupt => "Interrupted by user".into(), + TransformerError::ValueError(v) => v, + })?; + + match self { + Relation::Eq(_, _) => out1 == out2, + Relation::Ne(_, _) => out1 != out2, + Relation::Gt(_, _) => out1.as_view() > out2.as_view(), + Relation::Ge(_, _) => out1.as_view() >= out2.as_view(), + Relation::Lt(_, _) => out1.as_view() < out2.as_view(), + Relation::Le(_, _) => out1.as_view() <= out2.as_view(), + Relation::Contains(_, _) => out1.contains(out2.as_view()), + _ => unreachable!(), + } + } + Relation::IsType(a, b) => { + a.substitute_wildcards(ws, &mut out1, &m, pat.as_ref()) + .map_err(|e| match e { + TransformerError::Interrupt => "Interrupted by user".into(), + TransformerError::ValueError(v) => v, + })?; + + match out1.as_ref() { + Atom::Var(_) => (*b == AtomType::Var).into(), + Atom::Fun(_) => (*b == AtomType::Fun).into(), + Atom::Num(_) => (*b == AtomType::Num).into(), + Atom::Add(_) => (*b == AtomType::Add).into(), + Atom::Mul(_) => (*b == AtomType::Mul).into(), + Atom::Pow(_) => (*b == AtomType::Pow).into(), + Atom::Zero => (*b == AtomType::Num).into(), + } + } + } + .into()) + }) + } +} + impl Evaluate for Condition { type State<'a> = MatchStack<'a, 'a>; - fn evaluate(&self, state: &MatchStack) -> ConditionResult { - match self { - Condition::And(a) => a.0.evaluate(state) & a.1.evaluate(state), - Condition::Or(o) => o.0.evaluate(state) | o.1.evaluate(state), - Condition::Not(n) => !n.evaluate(state), + fn evaluate(&self, state: &MatchStack) -> Result { + Ok(match self { + Condition::And(a) => a.0.evaluate(state)? & a.1.evaluate(state)?, + Condition::Or(o) => o.0.evaluate(state)? | o.1.evaluate(state)?, + Condition::Not(n) => !n.evaluate(state)?, Condition::True => ConditionResult::True, Condition::False => ConditionResult::False, Condition::Yield(t) => match t { @@ -1507,7 +1654,7 @@ impl Evaluate for Condition { { f(value, value2) } else { - return ConditionResult::Inconclusive; + return Ok(ConditionResult::Inconclusive); } } WildcardRestriction::NotGreedy => true, @@ -1519,7 +1666,7 @@ impl Evaluate for Condition { } PatternRestriction::MatchStack(mf) => mf(state), }, - } + }) } } @@ -3125,7 +3272,7 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> { match self.rhs { PatternOrMap::Pattern(p) => { - p.substitute_wildcards(ws, &mut new_rhs, pattern_match.match_stack) + p.substitute_wildcards(ws, &mut new_rhs, pattern_match.match_stack, None) .unwrap(); // TODO: escalate? } PatternOrMap::Map(f) => { diff --git a/src/transformer.rs b/src/transformer.rs index e0a152c..e6c2979 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -1,11 +1,14 @@ -use std::{sync::Arc, time::Instant}; +use std::{ops::ControlFlow, sync::Arc, time::Instant}; use crate::{ atom::{representation::FunView, Atom, AtomOrView, AtomView, Fun, Symbol}, coefficient::{Coefficient, CoefficientView}, combinatorics::{partitions, unique_permutations}, domains::rational::Rational, - id::{Condition, MatchSettings, Pattern, PatternOrMap, PatternRestriction, Replacement}, + id::{ + Condition, Evaluate, MatchSettings, Pattern, PatternOrMap, PatternRestriction, Relation, + Replacement, + }, printer::{AtomPrinter, PrintOptions}, state::{RecycledAtom, State, Workspace}, }; @@ -106,6 +109,9 @@ pub enum TransformerError { /// Operations that take a pattern as the input and produce an expression #[derive(Clone)] pub enum Transformer { + IfElse(Condition, Vec, Vec), + IfChanged(Vec, Vec, Vec), + BreakChain, /// Expand the rhs. Expand(Option, bool), /// Distribute numbers. @@ -168,6 +174,9 @@ pub enum Transformer { impl std::fmt::Debug for Transformer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Transformer::IfElse(_, _, _) => f.debug_tuple("IfElse").finish(), + Transformer::IfChanged(_, _, _) => f.debug_tuple("IfChanged").finish(), + Transformer::BreakChain => f.debug_tuple("BreakChain").finish(), Transformer::Expand(s, _) => f.debug_tuple("Expand").field(s).finish(), Transformer::ExpandNum => f.debug_tuple("ExpandNum").finish(), Transformer::Derivative(x) => f.debug_tuple("Derivative").field(x).finish(), @@ -415,17 +424,17 @@ impl Transformer { input: AtomView<'_>, workspace: &Workspace, out: &mut Atom, - ) -> Result<(), TransformerError> { + ) -> Result, TransformerError> { Transformer::execute_chain(input, std::slice::from_ref(self), workspace, out) } - /// Apply a chain of transformers to `orig_input`. + /// Apply a chain of transformers to `input`. pub fn execute_chain( input: AtomView<'_>, chain: &[Transformer], workspace: &Workspace, out: &mut Atom, - ) -> Result<(), TransformerError> { + ) -> Result, TransformerError> { out.set_from_view(&input); let mut tmp = workspace.new_atom(); for t in chain { @@ -433,6 +442,39 @@ impl Transformer { let cur_input = tmp.as_view(); match t { + Transformer::IfElse(cond, t1, t2) => { + if cond + .evaluate(&Some(cur_input)) + .map_err(|e| TransformerError::ValueError(e))? + .is_true() + { + if Transformer::execute_chain(cur_input, t1, workspace, out)?.is_break() { + return Ok(ControlFlow::Break(())); + } + } else if Transformer::execute_chain(cur_input, t2, workspace, out)?.is_break() + { + return Ok(ControlFlow::Break(())); + } + } + Transformer::IfChanged(cond, t1, t2) => { + Transformer::execute_chain(cur_input, cond, workspace, out)?; + std::mem::swap(out, &mut tmp); + + if tmp.as_view() != out.as_view() { + if Transformer::execute_chain(tmp.as_view(), t1, workspace, out)?.is_break() + { + return Ok(ControlFlow::Break(())); + } + } else if Transformer::execute_chain(tmp.as_view(), t2, workspace, out)? + .is_break() + { + return Ok(ControlFlow::Break(())); + } + } + Transformer::BreakChain => { + std::mem::swap(out, &mut tmp); + return Ok(ControlFlow::Break(())); + } Transformer::Map(f) => { f(cur_input, out)?; } @@ -504,7 +546,7 @@ impl Transformer { let key_map = key_map.clone(); Some(Box::new(move |i, o| { Workspace::get_local() - .with(|ws| Self::execute_chain(i, &key_map, ws, o).unwrap()) + .with(|ws| Self::execute_chain(i, &key_map, ws, o).unwrap()); })) }, if coeff_map.is_empty() { @@ -513,7 +555,7 @@ impl Transformer { let coeff_map = coeff_map.clone(); Some(Box::new(move |i, o| { Workspace::get_local() - .with(|ws| Self::execute_chain(i, &coeff_map, ws, o).unwrap()) + .with(|ws| Self::execute_chain(i, &coeff_map, ws, o).unwrap()); })) }, out, @@ -832,7 +874,7 @@ impl Transformer { } } - Ok(()) + Ok(ControlFlow::Continue(())) } } diff --git a/symbolica.pyi b/symbolica.pyi index 43ce8da..8036805 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -1576,6 +1576,85 @@ class Transformer: >>> e = Transformer().expand()((1+x)**2) """ + def __eq__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. + """ + + def __neq__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. + """ + + def __lt__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. + """ + + def __le__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. + """ + + def __gt__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. + """ + + def __ge__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. + """ + + def contains(self, element: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Create a transformer that checks if the expression contains the given `element`. + """ + + def if_then(self, condition: Condition, if_block: Transformer, else_block: Optional[Transformer] = None) -> Transformer: + """Evaluate the condition and apply the `if_block` if the condition is true, otherwise apply the `else_block`. + The expression that is the input of the transformer is the input for the condition, the `if_block` and the `else_block`. + + Examples + -------- + >>> t = T.map_terms(T.if_then(T.contains(x), T.print())) + >>> t(x + y + 4) + + prints `x`. + """ + + def if_changed(self, condition: Transformer, if_block: Transformer, else_block: Optional[Transformer] = None) -> Transformer: + """Execute the `condition` transformer. If the result of the `condition` transformer is different from the input expression, + apply the `if_block`, otherwise apply the `else_block`. The input expression of the `if_block` is the output + of the `condition` transformer. + + Examples + -------- + >>> t = T.map_terms(T.if_changed(T.replace_all(x, y), T.print())) + >>> print(t(x + y + 4)) + + prints + ``` + y + 2*y+4 + ``` + """ + + def break_chain(self) -> Transformer: + """Break the current chain and all higher-level chains containing `if` transformers. + + Examples + -------- + >>> from symbolica import * + >>> t = T.map_terms(T.repeat( + >>> T.replace_all(y, 4), + >>> T.if_changed(T.replace_all(x, y), + >>> T.break_chain()), + >>> T.print() # print of y is never reached + >>> )) + >>> print(t(x)) + """ + def expand(self, var: Optional[Expression] = None, via_poly: Optional[bool] = None) -> Transformer: """Create a transformer that expands products and powers. Optionally, expand in `var` only.