From e73ea498dcca4a59010c335ccf4a3a22f1ef30e8 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Mon, 2 Dec 2024 12:45:38 +0100 Subject: [PATCH] Boolean queries on Python expressions now yield Conditions - Add conversion from conditions to pattern restrictions - Add contains condition --- src/api/python.rs | 395 ++++++++++++++++++++++++++++++++++++++-------- src/atom.rs | 2 +- src/id.rs | 93 +++++++++++ symbolica.pyi | 76 ++++++--- 4 files changed, 481 insertions(+), 85 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index b99396b..93107d5 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -52,7 +52,7 @@ use crate::{ }, graph::Graph, id::{ - Condition, ConditionResult, Match, MatchSettings, MatchStack, Pattern, + Condition, ConditionResult, Evaluate, Match, MatchSettings, MatchStack, Pattern, PatternAtomTreeIterator, PatternOrMap, PatternRestriction, ReplaceIterator, Replacement, WildcardRestriction, }, @@ -1306,7 +1306,7 @@ impl PythonTransformer { &self, lhs: ConvertibleToPattern, rhs: ConvertibleToPatternOrMap, - cond: Option, + cond: Option, non_greedy_wildcards: Option>, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, @@ -1352,7 +1352,7 @@ impl PythonTransformer { Transformer::ReplaceAll( lhs.to_pattern()?.expr, rhs.to_pattern_or_map()?, - cond.map(|r| r.condition.clone()).unwrap_or_default(), + cond.map(|r| r.0).unwrap_or_default(), settings, ) ); @@ -1711,6 +1711,242 @@ impl PythonPatternRestriction { } } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Relation { + Eq(Atom, Atom), + Ne(Atom, Atom), + Gt(Atom, Atom), + Ge(Atom, Atom), + Lt(Atom, Atom), + Le(Atom, Atom), + Contains(Atom, Atom), + IsType(Atom, AtomType), +} + +impl std::fmt::Display for Relation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Relation::Eq(a, b) => write!(f, "{} == {}", a, b), + Relation::Ne(a, b) => write!(f, "{} != {}", a, b), + Relation::Gt(a, b) => write!(f, "{} > {}", a, b), + Relation::Ge(a, b) => write!(f, "{} >= {}", a, b), + Relation::Lt(a, b) => write!(f, "{} < {}", a, b), + Relation::Le(a, b) => write!(f, "{} <= {}", a, b), + Relation::Contains(a, b) => write!(f, "{} contains {}", a, b), + Relation::IsType(a, b) => write!(f, "{} is type {:?}", a, b), + } + } +} + +impl Evaluate for Relation { + type State<'a> = (); + + fn evaluate(&self, _state: &()) -> ConditionResult { + match self { + Relation::Eq(a, b) => (a == b).into(), + Relation::Ne(a, b) => (a != b).into(), + Relation::Gt(a, b) => (a > b).into(), + Relation::Ge(a, b) => (a >= b).into(), + Relation::Lt(a, b) => (a < b).into(), + Relation::Le(a, b) => (a <= b).into(), + Relation::Contains(a, b) => (a.contains(b)).into(), + Relation::IsType(a, b) => match a { + Atom::Var(_) => (*b == AtomType::Var).into(), + Atom::Fun(_) => (*b == AtomType::Fun).into(), + Atom::Num(_) => (*b == AtomType::Num).into(), + Atom::Add(_) => (*b == AtomType::Add).into(), + Atom::Mul(_) => (*b == AtomType::Mul).into(), + Atom::Pow(_) => (*b == AtomType::Pow).into(), + Atom::Zero => (*b == AtomType::Num).into(), + }, + } + } +} + +/// A restriction on wildcards. +#[pyclass(name = "Condition", module = "symbolica")] +#[derive(Clone)] +pub struct PythonCondition { + pub condition: Condition, +} + +impl From> for PythonCondition { + fn from(condition: Condition) -> Self { + PythonCondition { condition } + } +} + +#[pymethods] +impl PythonCondition { + pub fn __repr__(&self) -> String { + format!("{:?}", self.condition) + } + + pub fn __str__(&self) -> String { + format!("{}", self.condition) + } + + pub fn eval(&self) -> bool { + self.condition.evaluate(&()) == ConditionResult::True + } + + pub fn __bool__(&self) -> bool { + self.eval() + } + + /// Create a new pattern restriction that is the logical 'and' operation between two restrictions (i.e., both should hold). + pub fn __and__(&self, other: Self) -> PythonCondition { + (self.condition.clone() & other.condition.clone()).into() + } + + /// Create a new pattern restriction that is the logical 'or' operation between two restrictions (i.e., one of the two should hold). + pub fn __or__(&self, other: Self) -> PythonCondition { + (self.condition.clone() | other.condition.clone()).into() + } + + /// Create a new pattern restriction that takes the logical 'not' of the current restriction. + pub fn __invert__(&self) -> PythonCondition { + (!self.condition.clone()).into() + } + + /// Convert the condition to a pattern restriction. + pub fn to_req(&self) -> PyResult { + self.condition + .clone() + .try_into() + .map(|e| PythonPatternRestriction { condition: e }) + .map_err(|e| exceptions::PyValueError::new_err(e)) + } +} + +macro_rules! req_cmp_rel { + ($self:ident,$num:ident,$cmp_any_atom:ident,$c:ident) => {{ + if !$cmp_any_atom && !matches!($num.as_view(), AtomView::Num(_)) { + return Err("Can only compare to number"); + }; + + match $self.as_view() { + AtomView::Var(v) => { + let name = v.get_symbol(); + if v.get_wildcard_level() == 0 { + return Err("Only wildcards can be restricted."); + } + + Ok(PatternRestriction::Wildcard(( + name, + WildcardRestriction::Filter(Box::new(move |v: &Match| { + let k = $num.as_view(); + + if let Match::Single(m) = v { + if !$cmp_any_atom { + if let AtomView::Num(_) = m { + return m.cmp(&k).$c(); + } + } else { + return m.cmp(&k).$c(); + } + } + + false + })), + ))) + } + _ => Err("Only wildcards can be restricted."), + } + }}; +} + +impl TryFrom for PatternRestriction { + type Error = &'static str; + + fn try_from(value: Relation) -> Result { + match value { + Relation::Eq(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_eq); + } + Relation::Ne(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_ne); + } + Relation::Gt(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_gt); + } + Relation::Ge(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_ge); + } + Relation::Lt(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_lt); + } + Relation::Le(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_le); + } + Relation::Contains(atom, atom1) => { + if let Atom::Var(v) = atom { + let name = v.get_symbol(); + if name.get_wildcard_level() == 0 { + return Err("Only wildcards can be restricted."); + } + + Ok(PatternRestriction::Wildcard(( + name, + WildcardRestriction::Filter(Box::new(move |m| match m { + Match::Single(v) => v.contains(atom1.as_view()), + Match::Multiple(_, v) => v.iter().any(|x| x.contains(atom1.as_view())), + Match::FunctionName(_) => false, + })), + ))) + } else { + Err("LHS must be wildcard") + } + } + Relation::IsType(atom, atom_type) => { + if let Atom::Var(v) = atom { + Ok(PatternRestriction::Wildcard(( + v.get_symbol(), + WildcardRestriction::IsAtomType(atom_type), + ))) + } else { + Err("LHS must be wildcard") + } + } + } + } +} + +impl TryFrom> for Condition { + type Error = &'static str; + + fn try_from(value: Condition) -> Result { + Ok(match value { + Condition::True => Condition::True, + Condition::False => Condition::False, + Condition::Yield(r) => Condition::Yield(r.try_into()?), + Condition::And(a) => Condition::And(Box::new((a.0.try_into()?, a.1.try_into()?))), + Condition::Or(a) => Condition::Or(Box::new((a.0.try_into()?, a.1.try_into()?))), + Condition::Not(a) => Condition::Not(Box::new((*a).try_into()?)), + }) + } +} + +pub struct ConvertibleToPatternRestriction(Condition); + +impl<'a> FromPyObject<'a> for ConvertibleToPatternRestriction { + fn extract_bound(ob: &Bound<'a, pyo3::PyAny>) -> PyResult { + if let Ok(a) = ob.extract::() { + Ok(ConvertibleToPatternRestriction(a.condition)) + } else if let Ok(a) = ob.extract::() { + Ok(ConvertibleToPatternRestriction( + a.condition + .try_into() + .map_err(|e| exceptions::PyValueError::new_err(e))?, + )) + } else { + Err(exceptions::PyTypeError::new_err( + "Cannot convert to pattern restriction", + )) + } + } +} + impl<'a> FromPyObject<'a> for ConvertibleToExpression { fn extract_bound(ob: &Bound<'a, pyo3::PyAny>) -> PyResult { if let Ok(a) = ob.extract::() { @@ -1723,13 +1959,13 @@ impl<'a> FromPyObject<'a> for ConvertibleToExpression { Ok(ConvertibleToExpression(Atom::new_num(i).into())) } else if let Ok(_) = ob.extract::() { // disallow direct string conversion - Err(exceptions::PyValueError::new_err( + Err(exceptions::PyTypeError::new_err( "Cannot convert to expression", )) } else if let Ok(f) = ob.extract::() { Ok(ConvertibleToExpression(Atom::new_num(f.0).into())) } else { - Err(exceptions::PyValueError::new_err( + Err(exceptions::PyTypeError::new_err( "Cannot convert to expression", )) } @@ -1741,13 +1977,13 @@ impl<'a> FromPyObject<'a> for Symbol { if let Ok(a) = ob.extract::() { match a.expr.as_view() { AtomView::Var(v) => Ok(v.get_symbol()), - e => Err(exceptions::PyValueError::new_err(format!( + e => Err(exceptions::PyTypeError::new_err(format!( "Expected variable instead of {}", e ))), } } else { - Err(exceptions::PyValueError::new_err("Not a valid variable")) + Err(exceptions::PyTypeError::new_err("Not a valid variable")) } } } @@ -2088,6 +2324,10 @@ impl PythonExpression { Err(exceptions::PyValueError::new_err( "Illegal character in name", )) + } else if name.chars().next().unwrap().is_numeric() { + Err(exceptions::PyValueError::new_err( + "Name cannot start with a number", + )) } else { Ok(name) } @@ -2813,8 +3053,13 @@ impl PythonExpression { /// >>> e.contains(x) # True /// >>> e.contains(x*y*z) # True /// >>> e.contains(x*y) # False - pub fn contains(&self, s: ConvertibleToExpression) -> bool { - self.expr.contains(s.to_expression().expr.as_view()) + pub fn contains(&self, s: ConvertibleToExpression) -> PythonCondition { + PythonCondition { + condition: Condition::Yield(Relation::Contains( + self.expr.clone(), + s.to_expression().expr, + )), + } } /// Get all symbols in the current expression, optionally including function symbols. @@ -2935,6 +3180,35 @@ impl PythonExpression { } } + /// Create a pattern restriction that filters for expressions that contain `a`. + pub fn req_contains(&self, a: PythonExpression) -> PyResult { + match self.expr.as_view() { + AtomView::Var(v) => { + let name = v.get_symbol(); + if v.get_wildcard_level() == 0 { + return Err(exceptions::PyTypeError::new_err( + "Only wildcards can be restricted.", + )); + } + + Ok(PythonPatternRestriction { + condition: ( + name, + WildcardRestriction::Filter(Box::new(move |m| match m { + Match::Single(v) => v.contains(a.expr.as_view()), + Match::Multiple(_, v) => v.iter().any(|x| x.contains(a.expr.as_view())), + Match::FunctionName(_) => false, + })), + ) + .into(), + }) + } + _ => Err(exceptions::PyTypeError::new_err( + "Only wildcards can be restricted.", + )), + } + } + /// Create a pattern restriction that treats the wildcard as a literal variable, /// so that it only matches to itself. pub fn req_lit(&self) -> PyResult { @@ -2957,41 +3231,45 @@ impl PythonExpression { } } - /// Compare two expressions. - fn __richcmp__(&self, other: ConvertibleToExpression, op: CompareOp) -> PyResult { - match op { - CompareOp::Eq => Ok(self.expr == other.to_expression().expr), - CompareOp::Ne => Ok(self.expr != other.to_expression().expr), - _ => { - let other = other.to_expression(); - if let n1 @ AtomView::Num(_) = self.expr.as_view() { - if let n2 @ AtomView::Num(_) = other.expr.as_view() { - return Ok(match op { - CompareOp::Eq => n1 == n2, - CompareOp::Ge => n1 >= n2, - CompareOp::Gt => n1 > n2, - CompareOp::Le => n1 <= n2, - CompareOp::Lt => n1 < n2, - CompareOp::Ne => n1 != n2, - }); + /// Test if the expression is of a certain type. + pub fn is_type(&self, atom_type: PythonAtomType) -> PythonCondition { + PythonCondition { + condition: Condition::Yield(Relation::IsType( + self.expr.clone(), + match atom_type { + PythonAtomType::Num => AtomType::Num, + PythonAtomType::Var => AtomType::Var, + PythonAtomType::Add => AtomType::Add, + PythonAtomType::Mul => AtomType::Mul, + PythonAtomType::Pow => AtomType::Pow, + PythonAtomType::Fn => AtomType::Fun, + }, + )), } } - Err(exceptions::PyTypeError::new_err(format!( - "Inequalities between expression that are not numbers are not allowed in {} {} {}", - self.__str__()?, + /// Compare two expressions. If one of the expressions is not a number, an + /// internal ordering will be used. + fn __richcmp__(&self, other: ConvertibleToExpression, op: CompareOp) -> PythonCondition { match op { - CompareOp::Eq => "==", - CompareOp::Ge => ">=", - CompareOp::Gt => ">", - CompareOp::Le => "<=", - CompareOp::Lt => "<", - CompareOp::Ne => "!=", - }, - other.__str__()?, - ) - )) - } + CompareOp::Eq => PythonCondition { + condition: Relation::Eq(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Ne => PythonCondition { + condition: Relation::Ne(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Ge => PythonCondition { + condition: Relation::Ge(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Gt => PythonCondition { + condition: Relation::Gt(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Le => PythonCondition { + condition: Relation::Le(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Lt => PythonCondition { + condition: Relation::Lt(self.expr.clone(), other.to_expression().expr).into(), + }, } } @@ -3976,14 +4254,12 @@ impl PythonExpression { pub fn pattern_match( &self, lhs: ConvertibleToPattern, - cond: Option, + cond: Option, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, allow_new_wildcards_on_rhs: Option, ) -> PyResult { - let conditions = cond - .map(|r| r.condition.clone()) - .unwrap_or(Condition::default()); + let conditions = cond.map(|r| r.0).unwrap_or(Condition::default()); let settings = MatchSettings { level_range: level_range.unwrap_or((0, None)), level_is_tree_depth: level_is_tree_depth.unwrap_or(false), @@ -4016,15 +4292,13 @@ impl PythonExpression { pub fn matches( &self, lhs: ConvertibleToPattern, - cond: Option, + cond: Option, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, allow_new_wildcards_on_rhs: Option, ) -> PyResult { let pat = lhs.to_pattern()?.expr; - let conditions = cond - .map(|r| r.condition.clone()) - .unwrap_or(Condition::default()); + let conditions = cond.map(|r| r.0).unwrap_or(Condition::default()); let settings = MatchSettings { level_range: level_range.unwrap_or((0, None)), level_is_tree_depth: level_is_tree_depth.unwrap_or(false), @@ -4067,14 +4341,12 @@ impl PythonExpression { &self, lhs: ConvertibleToPattern, rhs: ConvertibleToPatternOrMap, - cond: Option, + cond: Option, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, allow_new_wildcards_on_rhs: Option, ) -> PyResult { - let conditions = cond - .map(|r| r.condition.clone()) - .unwrap_or(Condition::default()); + let conditions = cond.map(|r| r.0.clone()).unwrap_or(Condition::default()); let settings = MatchSettings { level_range: level_range.unwrap_or((0, None)), level_is_tree_depth: level_is_tree_depth.unwrap_or(false), @@ -4138,7 +4410,7 @@ impl PythonExpression { &self, pattern: ConvertibleToPattern, rhs: ConvertibleToPatternOrMap, - cond: Option, + cond: Option, non_greedy_wildcards: Option>, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, @@ -4185,15 +4457,11 @@ impl PythonExpression { let mut expr_ref = self.expr.as_view(); + let cond = cond.map(|r| r.0); + let mut out = RecycledAtom::new(); let mut out2 = RecycledAtom::new(); - while pattern.replace_all_into( - expr_ref, - rhs, - cond.as_ref().map(|r| &r.condition), - Some(&settings), - &mut out, - ) { + while pattern.replace_all_into(expr_ref, rhs, cond.as_ref(), Some(&settings), &mut out) { if !repeat.unwrap_or(false) { break; } @@ -4881,7 +5149,7 @@ impl PythonReplacement { pub fn new( pattern: ConvertibleToPattern, rhs: ConvertibleToPatternOrMap, - cond: Option, + cond: Option, non_greedy_wildcards: Option>, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, @@ -4925,10 +5193,7 @@ impl PythonReplacement { settings.rhs_cache_size = rhs_cache_size; } - let cond = cond - .as_ref() - .map(|r| r.condition.clone()) - .unwrap_or(Condition::default()); + let cond = cond.map(|r| r.0).unwrap_or(Condition::default()); Ok(Self { pattern, diff --git a/src/atom.rs b/src/atom.rs index 8e89250..ddd4515 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -100,7 +100,7 @@ impl Symbol { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum AtomType { Num, Var, diff --git a/src/id.rs b/src/id.rs index 9716c8f..41fddf1 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1345,6 +1345,41 @@ pub enum Condition { False, } +impl std::fmt::Display for Condition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Condition::And(a) => write!(f, "({}) & ({})", a.0, a.1), + Condition::Or(o) => write!(f, "{} | {}", o.0, o.1), + Condition::Not(n) => write!(f, "!({})", n), + Condition::True => write!(f, "True"), + Condition::False => write!(f, "False"), + Condition::Yield(t) => write!(f, "{}", t), + } + } +} + +pub trait Evaluate { + type State<'a>; + + /// Evaluate a condition. + fn evaluate<'a>(&self, state: &Self::State<'a>) -> ConditionResult; +} + +impl Evaluate for Condition { + type State<'a> = T::State<'a>; + + fn evaluate(&self, state: &T::State<'_>) -> ConditionResult { + match self { + Condition::And(a) => a.0.evaluate(state) & a.1.evaluate(state), + Condition::Or(o) => o.0.evaluate(state) | o.1.evaluate(state), + Condition::Not(n) => !n.evaluate(state), + Condition::True => ConditionResult::True, + Condition::False => ConditionResult::False, + Condition::Yield(t) => t.evaluate(state), + } + } +} + impl From for Condition { fn from(value: T) -> Self { Condition::Yield(value) @@ -1430,6 +1465,64 @@ impl From for ConditionResult { } } +impl Evaluate for Condition { + type State<'a> = MatchStack<'a, 'a>; + + fn evaluate(&self, state: &MatchStack) -> ConditionResult { + match self { + Condition::And(a) => a.0.evaluate(state) & a.1.evaluate(state), + Condition::Or(o) => o.0.evaluate(state) | o.1.evaluate(state), + Condition::Not(n) => !n.evaluate(state), + Condition::True => ConditionResult::True, + Condition::False => ConditionResult::False, + Condition::Yield(t) => match t { + PatternRestriction::Wildcard((v, r)) => { + if let Some((_, value)) = state.stack.iter().find(|(k, _)| k == v) { + match r { + WildcardRestriction::IsAtomType(t) => match value { + Match::Single(AtomView::Num(_)) => *t == AtomType::Num, + Match::Single(AtomView::Var(_)) => *t == AtomType::Var, + Match::Single(AtomView::Add(_)) => *t == AtomType::Add, + Match::Single(AtomView::Mul(_)) => *t == AtomType::Mul, + Match::Single(AtomView::Pow(_)) => *t == AtomType::Pow, + Match::Single(AtomView::Fun(_)) => *t == AtomType::Fun, + _ => false, + }, + WildcardRestriction::IsLiteralWildcard(wc) => match value { + Match::Single(AtomView::Var(v)) => wc == &v.get_symbol(), + _ => false, + }, + WildcardRestriction::Length(min, max) => match value { + Match::Single(_) | Match::FunctionName(_) => { + *min <= 1 && max.map(|m| m >= 1).unwrap_or(true) + } + Match::Multiple(_, slice) => { + *min <= slice.len() + && max.map(|m| m >= slice.len()).unwrap_or(true) + } + }, + WildcardRestriction::Filter(f) => f(value), + WildcardRestriction::Cmp(v2, f) => { + if let Some((_, value2)) = state.stack.iter().find(|(k, _)| k == v2) + { + f(value, value2) + } else { + return ConditionResult::Inconclusive; + } + } + WildcardRestriction::NotGreedy => true, + } + .into() + } else { + ConditionResult::Inconclusive + } + } + PatternRestriction::MatchStack(mf) => mf(state), + }, + } + } +} + impl Condition { /// Check if the conditions on `var` are met fn check_possible(&self, var: Symbol, value: &Match, stack: &MatchStack) -> ConditionResult { diff --git a/symbolica.pyi b/symbolica.pyi index adb7611..43ce8da 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -517,7 +517,7 @@ class Expression: transformations can be applied. """ - def contains(self, a: Expression | int | float | Decimal) -> bool: + def contains(self, a: Expression | int | float | Decimal) -> Condition: """Returns true iff `self` contains `a` literally. Examples @@ -569,6 +569,16 @@ class Expression: Yields `f(x)*f(1)`. """ + def is_type(self, atom_type: AtomType) -> Condition: + """ + Test if the expression is of a certain type. + """ + + def req_contains(self, a: Expression) -> PatternRestriction: + """ + Create a pattern restriction that filters for expressions that contain `a`. + """ + def req_lit(self) -> PatternRestriction: """ Create a pattern restriction that treats the wildcard as a literal variable, @@ -746,34 +756,34 @@ class Expression: >>> e = e.replace_all(f(x_,y_), 1, x_.req_cmp_ge(y_)) """ - def __eq__(self, other: Expression | int | float | Decimal) -> bool: + def __eq__(self, other: Expression | int | float | Decimal) -> Condition: """ Compare two expressions. """ - def __neq__(self, other: Expression | int | float | Decimal) -> bool: + def __neq__(self, other: Expression | int | float | Decimal) -> Condition: """ Compare two expressions. """ - def __lt__(self, other: Expression | int | float | Decimal) -> bool: + def __lt__(self, other: Expression | int | float | Decimal) -> Condition: """ - Compare two expressions. Both expressions must be a number. + Compare two expressions. If any of the two expressions is not a rational number, an interal ordering is used. """ - def __le__(self, other: Expression | int | float | Decimal) -> bool: + def __le__(self, other: Expression | int | float | Decimal) -> Condition: """ - Compare two expressions. Both expressions must be a number. + Compare two expressions. If any of the two expressions is not a rational number, an interal ordering is used. """ - def __gt__(self, other: Expression | int | float | Decimal) -> bool: + def __gt__(self, other: Expression | int | float | Decimal) -> Condition: """ - Compare two expressions. Both expressions must be a number. + Compare two expressions. If any of the two expressions is not a rational number, an interal ordering is used. """ - def __ge__(self, other: Expression | int | float | Decimal) -> bool: + def __ge__(self, other: Expression | int | float | Decimal) -> Condition: """ - Compare two expressions. Both expressions must be a number. + Compare two expressions. If any of the two expressions is not a rational number, an interal ordering is used. """ def __iter__(self) -> Iterator[Expression]: @@ -1056,7 +1066,7 @@ class Expression: def match( self, lhs: Transformer | Expression | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, allow_new_wildcards_on_rhs: Optional[bool] = False, @@ -1083,7 +1093,7 @@ class Expression: def matches( self, lhs: Transformer | Expression | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, allow_new_wildcards_on_rhs: Optional[bool] = False, @@ -1104,7 +1114,7 @@ class Expression: self, lhs: Transformer | Expression | int | float | Decimal, rhs: Transformer | Expression | Callable[[dict[Expression, Expression]], Expression] | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, allow_new_wildcards_on_rhs: Optional[bool] = False, @@ -1150,7 +1160,7 @@ class Expression: self, pattern: Transformer | Expression | int | float | Decimal, rhs: Transformer | Expression | Callable[[dict[Expression, Expression]], Expression] | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, non_greedy_wildcards: Optional[Sequence[Expression]] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, @@ -1167,7 +1177,7 @@ class Expression: >>> x, w1_, w2_ = Expression.symbol('x','w1_','w2_') >>> f = Expression.symbol('f') >>> e = f(3,x) - >>> r = e.replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), (w1_ >= 1) & w2_.is_var()) + >>> r = e.replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), w1_ >= 1) >>> print(r) Parameters @@ -1464,7 +1474,7 @@ class Replacement: cls, pattern: Transformer | Expression | int | float | Decimal, rhs: Transformer | Expression | Callable[[dict[Expression, Expression]], Expression] | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, non_greedy_wildcards: Optional[Sequence[Expression]] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, @@ -1518,6 +1528,34 @@ class PatternRestriction: """ +class Condition: + """Relations that evaluate to booleans""" + + def eval(self) -> bool: + """Evaluate the condition.""" + + def __repr__(self) -> str: + """Return a string representation of the condition.""" + + def __str__(self) -> str: + """Return a string representation of the condition.""" + + def __bool__(self) -> bool: + """Return the boolean value of the condition.""" + + def __and__(self, other: Condition) -> Condition: + """Create a condition that is the logical and operation between two conditions (i.e., both should hold).""" + + def __or__(self, other: Condition) -> Condition: + """Create a condition that is the logical 'or' operation between two conditions (i.e., at least one of the two should hold).""" + + def __invert__(self) -> Condition: + """Create a condition that takes the logical 'not' of the current condition.""" + + def to_req(self) -> PatternRestriction: + """Convert the condition to a pattern restriction.""" + + class CompareOp: """One of the following comparison operators: `<`,`>`,`<=`,`>=`,`==`,`!=`.""" @@ -1955,7 +1993,7 @@ class Transformer: self, pat: Transformer | Expression | int | float | Decimal, rhs: Transformer | Expression | Callable[[dict[Expression, Expression]], Expression] | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, non_greedy_wildcards: Optional[Sequence[Expression]] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, @@ -1970,7 +2008,7 @@ class Transformer: >>> x, w1_, w2_ = Expression.symbol('x','w1_','w2_') >>> f = Expression.symbol('f') >>> e = f(3,x) - >>> r = e.transform().replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), (w1_ >= 1) & w2_.is_var()) + >>> r = e.transform().replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), w1_ >= 1) >>> print(r) Parameters