diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 58fb8a9c..06babe19 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -1,6 +1,6 @@ use symbolica::{ atom::{Atom, AtomView}, - id::{Match, Pattern, PatternRestriction}, + id::{Match, Pattern, WildcardRestriction}, state::{RecycledAtom, State}, }; @@ -15,7 +15,7 @@ fn main() { // prepare the pattern restriction `x_ > 1` let restrictions = ( State::get_symbol("x_"), - PatternRestriction::Filter(Box::new(|v: &Match| match v { + WildcardRestriction::Filter(Box::new(|v: &Match| match v { Match::Single(AtomView::Num(n)) => !n.is_one() && !n.is_zero(), _ => false, })), diff --git a/examples/pattern_restrictions.rs b/examples/pattern_restrictions.rs index ad4a5287..243f44f8 100644 --- a/examples/pattern_restrictions.rs +++ b/examples/pattern_restrictions.rs @@ -2,7 +2,7 @@ use symbolica::{ atom::{Atom, AtomView}, coefficient::CoefficientView, domains::finite_field, - id::{Condition, Match, MatchSettings, PatternRestriction}, + id::{Condition, Match, MatchSettings, WildcardRestriction}, state::State, }; fn main() { @@ -16,11 +16,11 @@ fn main() { let z = State::get_symbol("z__"); let w = State::get_symbol("w__"); - let conditions = Condition::from((x, PatternRestriction::Length(0, Some(2)))) - & (y, PatternRestriction::Length(0, Some(4))) + let conditions = Condition::from((x, WildcardRestriction::Length(0, Some(2)))) + & (y, WildcardRestriction::Length(0, Some(4))) & ( y, - PatternRestriction::Cmp( + WildcardRestriction::Cmp( x, Box::new(|y, x| { let len_x = match x { @@ -37,7 +37,7 @@ fn main() { ) & ( z, - PatternRestriction::Filter(Box::new(|x: &Match| { + WildcardRestriction::Filter(Box::new(|x: &Match| { if let Match::Single(AtomView::Num(num)) = x { if let CoefficientView::Natural(x, y) = num.get_coeff_view() { y == 1 && x > 0 && finite_field::is_prime_u64(x as u64) @@ -49,7 +49,7 @@ fn main() { } })), ) - & (w, PatternRestriction::Length(0, None)); + & (w, WildcardRestriction::Length(0, None)); let settings = MatchSettings::default(); println!( diff --git a/src/api/python.rs b/src/api/python.rs index 6e62d655..8069e0e8 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -47,8 +47,9 @@ use crate::{ }, graph::Graph, id::{ - Condition, Match, MatchSettings, MatchStack, Pattern, PatternAtomTreeIterator, - PatternOrMap, PatternRestriction, ReplaceIterator, Replacement, WildcardAndRestriction, + Condition, ConditionResult, Match, MatchSettings, MatchStack, Pattern, + PatternAtomTreeIterator, PatternOrMap, PatternRestriction, ReplaceIterator, Replacement, + WildcardRestriction, }, numerical_integration::{ContinuousGrid, DiscreteGrid, Grid, MonteCarloRng, Sample}, parser::Token, @@ -1493,11 +1494,11 @@ impl Deref for PythonExpression { #[pyclass(name = "PatternRestriction", module = "symbolica")] #[derive(Clone)] pub struct PythonPatternRestriction { - pub condition: Condition, + pub condition: Condition, } -impl From> for PythonPatternRestriction { - fn from(condition: Condition) -> Self { +impl From> for PythonPatternRestriction { + fn from(condition: Condition) -> Self { PythonPatternRestriction { condition } } } @@ -1518,6 +1519,44 @@ impl PythonPatternRestriction { pub fn __invert__(&self) -> PythonPatternRestriction { (!self.condition.clone()).into() } + + /// Create a pattern restriction based on the current matched variables. + /// `match_fn` is a Python function that takes a dictionary of wildcards and their matched values + /// and should return an integer. If the integer is less than 0, the restriction is false. + /// If the integer is 0, the restriction is inconclusive. + /// If the integer is greater than 0, the restriction is true. + /// + /// If your pattern restriction cannot decide if it holds since not all the required variables + /// have been matched, it should return inclusive (0). + #[classmethod] + pub fn req_matches(_cls: &PyType, match_fn: PyObject) -> PyResult { + Ok(PythonPatternRestriction { + condition: PatternRestriction::MatchStack(Box::new(move |m| { + let matches: HashMap = m + .get_matches() + .iter() + .map(|(s, t)| (Atom::new_var(*s).into(), t.to_atom().into())) + .collect(); + + let r = Python::with_gil(|py| { + match_fn + .call(py, (matches,), None) + .expect("Bad callback function") + .extract::(py) + .expect("Pattern comparison does not return an integer") + }); + + if r < 0 { + false.into() + } else if r == 0 { + ConditionResult::Inconclusive + } else { + true.into() + } + })) + .into(), + }) + } } impl<'a> FromPyObject<'a> for ConvertibleToExpression { @@ -1711,7 +1750,7 @@ macro_rules! req_cmp { Ok(PythonPatternRestriction { condition: ( name, - PatternRestriction::Filter(Box::new(move |v: &Match| { + WildcardRestriction::Filter(Box::new(move |v: &Match| { let k = num.expr.as_view(); if let Match::Single(m) = v { @@ -1776,7 +1815,7 @@ macro_rules! req_wc_cmp { Ok(PythonPatternRestriction { condition: ( id, - PatternRestriction::Cmp( + WildcardRestriction::Cmp( other_id, Box::new(move |m1: &Match, m2: &Match| { if let Match::Single(a1) = m1 { @@ -2506,7 +2545,7 @@ impl PythonExpression { } Ok(PythonPatternRestriction { - condition: (name, PatternRestriction::Length(min_length, max_length)).into(), + condition: (name, WildcardRestriction::Length(min_length, max_length)).into(), }) } _ => Err(exceptions::PyTypeError::new_err( @@ -2540,7 +2579,7 @@ impl PythonExpression { Ok(PythonPatternRestriction { condition: ( name, - PatternRestriction::IsAtomType(match atom_type { + WildcardRestriction::IsAtomType(match atom_type { PythonAtomType::Num => AtomType::Num, PythonAtomType::Var => AtomType::Var, PythonAtomType::Add => AtomType::Add, @@ -2571,7 +2610,7 @@ impl PythonExpression { } Ok(PythonPatternRestriction { - condition: (name, PatternRestriction::IsLiteralWildcard(name)).into(), + condition: (name, WildcardRestriction::IsLiteralWildcard(name)).into(), }) } _ => Err(exceptions::PyTypeError::new_err( @@ -2741,7 +2780,7 @@ impl PythonExpression { Ok(PythonPatternRestriction { condition: ( id, - PatternRestriction::Filter(Box::new(move |m| { + WildcardRestriction::Filter(Box::new(move |m| { let data: PythonExpression = m.to_atom().into(); Python::with_gil(|py| { @@ -2901,7 +2940,7 @@ impl PythonExpression { Ok(PythonPatternRestriction { condition: ( id, - PatternRestriction::Cmp( + WildcardRestriction::Cmp( other_id, Box::new(move |m1, m2| { let data1: PythonExpression = m1.to_atom().into(); @@ -4284,7 +4323,7 @@ impl PythonExpression { pub struct PythonReplacement { pattern: Pattern, rhs: PatternOrMap, - cond: Condition, + cond: Condition, settings: MatchSettings, } @@ -4756,12 +4795,7 @@ impl PythonAtomIterator { } } -type OwnedMatch = ( - Pattern, - Atom, - Condition, - MatchSettings, -); +type OwnedMatch = (Pattern, Atom, Condition, MatchSettings); type MatchIterator<'a> = PatternAtomTreeIterator<'a, 'a>; self_cell!( @@ -4798,7 +4832,7 @@ type OwnedReplace = ( Pattern, Atom, PatternOrMap, - Condition, + Condition, MatchSettings, ); type ReplaceIteratorOne<'a> = ReplaceIterator<'a, 'a>; diff --git a/src/atom.rs b/src/atom.rs index aeb8934b..59f0c43d 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -616,6 +616,18 @@ impl Atom { self.as_view().is_one() } + /// Repeatedly apply an operation on the atom until the atom no longer changes. + pub fn repeat_map Atom>(&mut self, op: F) { + let mut res; + loop { + res = op(self.as_view()); + if res == *self { + break; + } + std::mem::swap(self, &mut res); + } + } + #[inline] pub fn to_num(&mut self, coeff: Coefficient) -> &mut Num { let buffer = std::mem::replace(self, Atom::Zero).into_raw(); diff --git a/src/id.rs b/src/id.rs index 6439461c..64d56b9e 100644 --- a/src/id.rs +++ b/src/id.rs @@ -56,7 +56,7 @@ impl std::fmt::Debug for PatternOrMap { pub struct Replacement<'a> { pat: &'a Pattern, rhs: &'a PatternOrMap, - conditions: Option<&'a Condition>, + conditions: Option<&'a Condition>, settings: Option<&'a MatchSettings>, } @@ -70,7 +70,7 @@ impl<'a> Replacement<'a> { } } - pub fn with_conditions(mut self, conditions: &'a Condition) -> Self { + pub fn with_conditions(mut self, conditions: &'a Condition) -> Self { self.conditions = Some(conditions); self } @@ -109,7 +109,7 @@ impl Atom { &self, pattern: &Pattern, rhs: &PatternOrMap, - conditions: Option<&Condition>, + conditions: Option<&Condition>, settings: Option<&MatchSettings>, ) -> Atom { self.as_view() @@ -121,7 +121,7 @@ impl Atom { &self, pattern: &Pattern, rhs: &PatternOrMap, - conditions: Option<&Condition>, + conditions: Option<&Condition>, settings: Option<&MatchSettings>, out: &mut Atom, ) -> bool { @@ -273,7 +273,7 @@ impl<'a> AtomView<'a> { &self, pattern: &Pattern, rhs: &PatternOrMap, - conditions: Option<&Condition>, + conditions: Option<&Condition>, settings: Option<&MatchSettings>, ) -> Atom { pattern.replace_all(*self, rhs, conditions, settings) @@ -284,7 +284,7 @@ impl<'a> AtomView<'a> { &self, pattern: &Pattern, rhs: &PatternOrMap, - conditions: Option<&Condition>, + conditions: Option<&Condition>, settings: Option<&MatchSettings>, out: &mut Atom, ) -> bool { @@ -1098,7 +1098,7 @@ impl Pattern { &'a self, target: AtomView<'a>, rhs: &'a PatternOrMap, - conditions: &'a Condition, + conditions: &'a Condition, settings: &'a MatchSettings, ) -> ReplaceIterator<'a, 'a> { ReplaceIterator::new(self, target, rhs, conditions, settings) @@ -1110,7 +1110,7 @@ impl Pattern { &self, target: AtomView<'_>, rhs: &PatternOrMap, - conditions: Option<&Condition>, + conditions: Option<&Condition>, settings: Option<&MatchSettings>, ) -> Atom { Workspace::get_local().with(|ws| { @@ -1126,7 +1126,7 @@ impl Pattern { &self, target: AtomView<'_>, rhs: &PatternOrMap, - conditions: Option<&Condition>, + conditions: Option<&Condition>, settings: Option<&MatchSettings>, out: &mut Atom, ) -> bool { @@ -1141,7 +1141,7 @@ impl Pattern { target: AtomView<'_>, rhs: &PatternOrMap, workspace: &Workspace, - conditions: Option<&Condition>, + conditions: Option<&Condition>, settings: Option<&MatchSettings>, out: &mut Atom, ) -> bool { @@ -1167,7 +1167,7 @@ impl Pattern { pub fn pattern_match<'a: 'b, 'b>( &'b self, target: AtomView<'a>, - conditions: &'b Condition, + conditions: &'b Condition, settings: &'b MatchSettings, ) -> PatternAtomTreeIterator<'a, 'b> { PatternAtomTreeIterator::new(self, target, conditions, settings) @@ -1188,20 +1188,24 @@ impl std::fmt::Debug for Pattern { } } -pub trait FilterFn: for<'a, 'b> Fn(&'a Match<'b>) -> bool + DynClone + Send + Sync {} +pub trait FilterFn: Fn(&Match) -> bool + DynClone + Send + Sync {} dyn_clone::clone_trait_object!(FilterFn); -impl Fn(&'a Match<'b>) -> bool> FilterFn for T {} +impl bool> FilterFn for T {} -pub trait CmpFn: for<'a, 'b> Fn(&Match<'_>, &Match<'_>) -> bool + DynClone + Send + Sync {} +pub trait CmpFn: Fn(&Match, &Match) -> bool + DynClone + Send + Sync {} dyn_clone::clone_trait_object!(CmpFn); -impl Fn(&Match<'_>, &Match<'_>) -> bool> CmpFn for T {} +impl bool> CmpFn for T {} + +pub trait MatchStackFn: Fn(&MatchStack) -> ConditionResult + DynClone + Send + Sync {} +dyn_clone::clone_trait_object!(MatchStackFn); +impl ConditionResult> MatchStackFn for T {} /// Restrictions for a wildcard. Note that a length restriction /// applies at any level and therefore /// `x_*f(x_) : length(x) == 2` /// does not match to `x*y*f(x*y)`, since the pattern `x_` has length /// 1 inside the function argument. -pub enum PatternRestriction { +pub enum WildcardRestriction { Length(usize, Option), // min-max range IsAtomType(AtomType), IsLiteralWildcard(Symbol), @@ -1210,7 +1214,38 @@ pub enum PatternRestriction { NotGreedy, } -pub type WildcardAndRestriction = (Symbol, PatternRestriction); +pub type WildcardAndRestriction = (Symbol, WildcardRestriction); + +pub enum PatternRestriction { + /// A restriction for a wildcard. + Wildcard(WildcardAndRestriction), + /// A function that checks if the restriction is met based on the currently matched wildcards. + /// If more information is needed to test the restriction, the function should return `Inconclusive`. + MatchStack(Box), +} + +impl Clone for PatternRestriction { + fn clone(&self) -> Self { + match self { + PatternRestriction::Wildcard(w) => PatternRestriction::Wildcard(w.clone()), + PatternRestriction::MatchStack(f) => { + PatternRestriction::MatchStack(dyn_clone::clone_box(f)) + } + } + } +} + +impl From for PatternRestriction { + fn from(value: WildcardAndRestriction) -> Self { + PatternRestriction::Wildcard(value) + } +} + +impl From for Condition { + fn from(value: WildcardAndRestriction) -> Self { + PatternRestriction::Wildcard(value).into() + } +} /// A logical expression. #[derive(Clone, Debug, Default)] @@ -1309,7 +1344,7 @@ impl From for ConditionResult { } } -impl Condition { +impl Condition { /// Check if the conditions on `var` are met fn check_possible(&self, var: Symbol, value: &Match, stack: &MatchStack) -> ConditionResult { match self { @@ -1322,10 +1357,17 @@ impl Condition { Condition::Not(n) => !n.check_possible(var, value, stack), Condition::True => ConditionResult::True, Condition::False => ConditionResult::False, - Condition::Yield((v, r)) => { + Condition::Yield(restriction) => { + let (v, r) = match restriction { + PatternRestriction::Wildcard((v, r)) => (v, r), + PatternRestriction::MatchStack(mf) => { + return mf(stack); + } + }; + if *v != var { match r { - PatternRestriction::Cmp(v, _) if *v == var => {} + WildcardRestriction::Cmp(v, _) if *v == var => {} _ => { return ConditionResult::Inconclusive; } @@ -1333,7 +1375,7 @@ impl Condition { } match r { - PatternRestriction::IsAtomType(t) => { + WildcardRestriction::IsAtomType(t) => { let is_type = match t { AtomType::Num => matches!(value, Match::Single(AtomView::Num(_))), AtomType::Var => matches!(value, Match::Single(AtomView::Var(_))), @@ -1355,16 +1397,16 @@ impl Condition { AtomType::Fun => matches!(value, Match::Single(AtomView::Fun(_))), }; - (is_type == matches!(r, PatternRestriction::IsAtomType(_))).into() + (is_type == matches!(r, WildcardRestriction::IsAtomType(_))).into() } - PatternRestriction::IsLiteralWildcard(wc) => { + WildcardRestriction::IsLiteralWildcard(wc) => { if let Match::Single(AtomView::Var(v)) = value { (wc == &v.get_symbol()).into() } else { false.into() } } - PatternRestriction::Length(min, max) => match &value { + WildcardRestriction::Length(min, max) => match &value { Match::Single(_) | Match::FunctionName(_) => { (*min <= 1 && max.map(|m| m >= 1).unwrap_or(true)).into() } @@ -1372,8 +1414,8 @@ impl Condition { && max.map(|m| m >= slice.len()).unwrap_or(true)) .into(), }, - PatternRestriction::Filter(f) => f(value).into(), - PatternRestriction::Cmp(v2, f) => { + WildcardRestriction::Filter(f) => f(value).into(), + WildcardRestriction::Cmp(v2, f) => { if *v == var { if let Some((_, value2)) = stack.stack.iter().find(|(k, _)| k == v2) { f(value, value2).into() @@ -1381,12 +1423,13 @@ impl Condition { ConditionResult::Inconclusive } } else if let Some((_, value2)) = stack.stack.iter().find(|(k, _)| k == v) { + // var == v2 at this point f(value2, value).into() } else { ConditionResult::Inconclusive } } - PatternRestriction::NotGreedy => true.into(), + WildcardRestriction::NotGreedy => true.into(), } } } @@ -1437,17 +1480,24 @@ impl Condition { (None, None) } Condition::True | Condition::False => (None, None), - Condition::Yield((v, r)) => { + Condition::Yield(restriction) => { + let (v, r) = match restriction { + PatternRestriction::Wildcard((v, r)) => (v, r), + PatternRestriction::MatchStack(_) => { + return (None, None); + } + }; + if *v != var { return (None, None); } match r { - PatternRestriction::Length(min, max) => (Some(*min), *max), - PatternRestriction::IsAtomType( + WildcardRestriction::Length(min, max) => (Some(*min), *max), + WildcardRestriction::IsAtomType( AtomType::Var | AtomType::Num | AtomType::Fun, ) - | PatternRestriction::IsLiteralWildcard(_) => (Some(1), Some(1)), + | WildcardRestriction::IsLiteralWildcard(_) => (Some(1), Some(1)), _ => (None, None), } } @@ -1455,7 +1505,7 @@ impl Condition { } } -impl Clone for PatternRestriction { +impl Clone for WildcardRestriction { fn clone(&self) -> Self { match self { Self::Length(min, max) => Self::Length(*min, *max), @@ -1468,7 +1518,7 @@ impl Clone for PatternRestriction { } } -impl std::fmt::Debug for PatternRestriction { +impl std::fmt::Debug for WildcardRestriction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Length(arg0, arg1) => f.debug_tuple("Length").field(arg0).field(arg1).finish(), @@ -1483,6 +1533,15 @@ impl std::fmt::Debug for PatternRestriction { } } +impl std::fmt::Debug for PatternRestriction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PatternRestriction::Wildcard(arg0) => f.debug_tuple("Wildcard").field(arg0).finish(), + PatternRestriction::MatchStack(_) => f.debug_tuple("Match").finish(), + } + } +} + /// A part of an expression that was matched to a wildcard. #[derive(Clone, PartialEq)] pub enum Match<'a> { @@ -1622,7 +1681,7 @@ pub struct MatchSettings { /// before inserting. pub struct MatchStack<'a, 'b> { stack: Vec<(Symbol, Match<'a>)>, - conditions: &'b Condition, + conditions: &'b Condition, settings: &'b MatchSettings, } @@ -1651,7 +1710,7 @@ impl<'a, 'b> std::fmt::Debug for MatchStack<'a, 'b> { impl<'a, 'b> MatchStack<'a, 'b> { /// Create a new match stack. pub fn new( - conditions: &'b Condition, + conditions: &'b Condition, settings: &'b MatchSettings, ) -> MatchStack<'a, 'b> { MatchStack { @@ -1677,12 +1736,17 @@ impl<'a, 'b> MatchStack<'a, 'b> { // test whether the current value passes all conditions // or returns an inconclusive result - if self.conditions.check_possible(key, &value, self) == ConditionResult::False { - return None; - } - self.stack.push((key, value)); - Some(self.stack.len() - 1) + if self + .conditions + .check_possible(key, &self.stack.last().unwrap().1, self) + == ConditionResult::False + { + self.stack.pop(); + None + } else { + Some(self.stack.len() - 1) + } } /// Get the mapped value for the wildcard `key`. @@ -2548,7 +2612,7 @@ impl<'a: 'b, 'b> PatternAtomTreeIterator<'a, 'b> { pub fn new( pattern: &'b Pattern, target: AtomView<'a>, - conditions: &'b Condition, + conditions: &'b Condition, settings: &'b MatchSettings, ) -> PatternAtomTreeIterator<'a, 'b> { PatternAtomTreeIterator { @@ -2615,7 +2679,7 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> { pattern: &'b Pattern, target: AtomView<'a>, rhs: &'b PatternOrMap, - conditions: &'a Condition, + conditions: &'a Condition, settings: &'a MatchSettings, ) -> ReplaceIterator<'a, 'b> { ReplaceIterator { @@ -2785,7 +2849,10 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> { mod test { use crate::{ atom::Atom, - id::{MatchSettings, PatternOrMap, Replacement}, + id::{ + ConditionResult, MatchSettings, PatternOrMap, PatternRestriction, Replacement, + WildcardRestriction, + }, state::State, }; @@ -2862,4 +2929,59 @@ mod test { let res = Atom::parse("v1(mu2)*v2(mu3)").unwrap(); assert_eq!(r, res); } + + #[test] + fn repeat_replace() { + let mut a = Atom::parse("f(10)").unwrap(); + let p1 = Pattern::parse("f(v1_)").unwrap(); + let rhs1 = Pattern::parse("f(v1_ - 1)").unwrap().into(); + + let rest = ( + State::get_symbol("v1_"), + WildcardRestriction::Filter(Box::new(|x| { + let n: Result = x.to_atom().try_into(); + if let Ok(y) = n { + y > 0i64 + } else { + false + } + })), + ) + .into(); + + a.repeat_map(|e| e.replace_all(&p1, &rhs1, Some(&rest), None)); + + let res = Atom::parse("f(0)").unwrap(); + assert_eq!(a, res); + } + + #[test] + fn match_stack_filter() { + let a = Atom::parse("f(1,2,3,4)").unwrap(); + let p1 = Pattern::parse("f(v1_,v2_,v3_,v4_)").unwrap(); + let rhs1 = Pattern::parse("f(v4_,v3_,v2_,v1_)").unwrap().into(); + + let rest = PatternRestriction::MatchStack(Box::new(|m| { + for x in m.get_matches().windows(2) { + if x[0].1.to_atom() >= x[1].1.to_atom() { + return false.into(); + } + } + + if m.get_matches().len() == 4 { + true.into() + } else { + ConditionResult::Inconclusive + } + })) + .into(); + + let r = a.replace_all(&p1, &rhs1, Some(&rest), None); + let res = Atom::parse("f(4,3,2,1)").unwrap(); + assert_eq!(r, res); + + let b = Atom::parse("f(1,2,4,3)").unwrap(); + let r = b.replace_all(&p1, &rhs1, Some(&rest), None); + assert_eq!(r, b); + } } diff --git a/src/streaming.rs b/src/streaming.rs index 69ed8742..0bb6cd36 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -589,7 +589,7 @@ mod test { use crate::{ atom::{Atom, AtomType}, - id::{Pattern, PatternRestriction}, + id::{Pattern, WildcardRestriction}, state::State, streaming::{TermStreamer, TermStreamerConfig}, }; @@ -648,7 +648,7 @@ mod test { Some( &( State::get_symbol("v1_"), - PatternRestriction::IsAtomType(AtomType::Var), + WildcardRestriction::IsAtomType(AtomType::Var), ) .into(), ), diff --git a/src/transformer.rs b/src/transformer.rs index 233e9dfa..5207fd25 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -5,7 +5,7 @@ use crate::{ coefficient::{Coefficient, CoefficientView}, combinatorics::{partitions, unique_permutations}, domains::rational::Rational, - id::{Condition, MatchSettings, Pattern, PatternOrMap, Replacement, WildcardAndRestriction}, + id::{Condition, MatchSettings, Pattern, PatternOrMap, PatternRestriction, Replacement}, printer::{AtomPrinter, PrintOptions}, state::{RecycledAtom, State, Workspace}, }; @@ -73,7 +73,7 @@ pub enum Transformer { ReplaceAll( Pattern, PatternOrMap, - Condition, + Condition, MatchSettings, ), /// Apply multiple find-and-replace on the lhs. @@ -81,7 +81,7 @@ pub enum Transformer { Vec<( Pattern, PatternOrMap, - Condition, + Condition, MatchSettings, )>, ), @@ -290,6 +290,17 @@ impl FunView<'_> { } else { mul.extend(a); } + } else if let AtomView::Pow(p) = a { + if let AtomView::Var(v) = p.get_base() { + let s = v.get_symbol(); + if symbols.map(|x| x.contains(&s)).unwrap_or(false) { + c.extend(a); + } else { + mul.extend(a); + } + } else { + mul.extend(a); + } } else { mul.extend(a); } @@ -777,7 +788,7 @@ impl Transformer { mod test { use crate::{ atom::{Atom, FunctionBuilder}, - id::{Condition, Match, MatchSettings, Pattern, PatternRestriction}, + id::{Condition, Match, MatchSettings, Pattern, WildcardRestriction}, printer::PrintOptions, state::{State, Workspace}, transformer::StatsOptions, @@ -903,7 +914,7 @@ mod test { Pattern::parse("x_-1").unwrap().into(), ( State::get_symbol("x_"), - PatternRestriction::Filter(Box::new(|x| { + WildcardRestriction::Filter(Box::new(|x| { x != &Match::Single(Atom::new_num(0).as_view()) })), ) @@ -924,13 +935,14 @@ mod test { #[test] fn linearize() { - let p = Atom::parse("f1(v1+v2,4*v3*v4+3)").unwrap(); + let p = Atom::parse("f1(v1+v2,4*v3*v4+3*v4/v3)").unwrap(); let out = Transformer::Linearize(Some(vec![State::get_symbol("v3")])) .execute(p.as_view()) .unwrap(); - let r = Atom::parse("f1(v1,3)+f1(v2,3)+4*v3*f1(v1,v4)+4*v3*f1(v2,v4)").unwrap(); + let r = Atom::parse("4*v3*f1(v1,v4)+4*v3*f1(v2,v4)+3*v3^-1*f1(v1,v4)+3*v3^-1*f1(v2,v4)") + .unwrap(); assert_eq!(out, r); } diff --git a/symbolica.pyi b/symbolica.pyi index 78407297..71d60ce6 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -1319,6 +1319,38 @@ class PatternRestriction: def __invert__(self) -> PatternRestriction: """Create a new pattern restriction that takes the logical 'not' of the current restriction.""" + @classmethod + def req_matches(_cls, match_fn: Callable[[dict[Expression, Expression]], int]) -> PatternRestriction: + """Create a pattern restriction based on the current matched variables. + `match_fn` is a Python function that takes a dictionary of wildcards and their matched values + and should return an integer. If the integer is less than 0, the restriction is false. + If the integer is 0, the restriction is inconclusive. + If the integer is greater than 0, the restriction is true. + + If your pattern restriction cannot decide if it holds since not all the required variables + have been matched, it should return inclusive (0). + + Examples + -------- + >>> from symbolica import * + >>> f, x_, y_, z_ = S('f', 'x_', 'y_', 'z_') + >>> + >>> def filter(m: dict[Expression, Expression]) -> int: + >>> if x_ in m and y_ in m: + >>> if m[x_] > m[y_]: + >>> return -1 # no match + >>> if z_ in m: + >>> if m[y_] > m[z_]: + >>> return -1 + >>> return 1 # match + >>> + >>> return 0 # inconclusive + >>> + >>> + >>> e = f(1, 2, 3).replace_all(f(x_, y_, z_), 1, + >>> PatternRestriction.req_matches(filter)) + """ + class CompareOp: """One of the following comparison operators: `<`,`>`,`<=`,`>=`,`==`,`!=`.""" diff --git a/tests/pattern_matching.rs b/tests/pattern_matching.rs index b5efaa3e..b419506e 100644 --- a/tests/pattern_matching.rs +++ b/tests/pattern_matching.rs @@ -1,6 +1,6 @@ use symbolica::{ atom::{Atom, AtomView}, - id::{Condition, Match, MatchSettings, Pattern, PatternRestriction}, + id::{Condition, Match, MatchSettings, Pattern, WildcardRestriction}, state::{RecycledAtom, State}, }; @@ -16,7 +16,7 @@ fn fibonacci() { // prepare the pattern restriction `x_ > 1` let restrictions = ( State::get_symbol("x_"), - PatternRestriction::Filter(Box::new(|v: &Match| match v { + WildcardRestriction::Filter(Box::new(|v: &Match| match v { Match::Single(AtomView::Num(n)) => !n.is_one() && !n.is_zero(), _ => false, })),