Skip to content

Commit

Permalink
Improve ergonomics of fun! macro
Browse files Browse the repository at this point in the history
- Normalize rational polynomial after unifying variables
- Improve lifetime bounds on PatternAtomTreeIterator
  • Loading branch information
benruijl committed Jul 25, 2024
1 parent 6623a94 commit f1111c6
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 38 deletions.
6 changes: 1 addition & 5 deletions examples/builder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use symbolica::{
atom::{Atom, FunctionBuilder},
fun,
state::State,
};
use symbolica::{atom::Atom, fun, state::State};

fn main() {
let x = Atom::parse("x").unwrap();
Expand Down
6 changes: 1 addition & 5 deletions examples/collect.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use symbolica::{
atom::{Atom, FunctionBuilder},
fun,
state::State,
};
use symbolica::{atom::Atom, fun, state::State};

fn main() {
let input = Atom::parse("x*(1+a)+x*5*y+f(5,x)+2+y^2+x^2 + x^3").unwrap();
Expand Down
17 changes: 10 additions & 7 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,11 +683,18 @@ impl FunctionBuilder {

/// Create a new function by providing its name as the first argument,
/// followed by the list of arguments. This macro uses [`FunctionBuilder`].
///
/// For example:
/// ```
/// use symbolica::{atom::Atom, fun, state::State};
/// let f_id = State::get_symbol("f");
/// let f = fun!(f_id, Atom::new_num(3), &Atom::parse("x").unwrap());
/// ```
#[macro_export]
macro_rules! fun {
($name:ident, $($id:expr),*) => {
($name: expr, $($id: expr),*) => {
{
let mut f = FunctionBuilder::new($name);
let mut f = $crate::atom::FunctionBuilder::new($name);
$(
f = f.add_arg(&$id);
)+
Expand Down Expand Up @@ -1199,11 +1206,7 @@ impl<T: Into<Coefficient>> std::ops::Div<T> for Atom {

#[cfg(test)]
mod test {
use crate::{
atom::{Atom, FunctionBuilder},
fun,
state::State,
};
use crate::{atom::Atom, fun, state::State};

#[test]
fn debug() {
Expand Down
6 changes: 1 addition & 5 deletions src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,11 +631,7 @@ impl<'a> AtomView<'a> {

#[cfg(test)]
mod test {
use crate::{
atom::{Atom, FunctionBuilder},
fun,
state::State,
};
use crate::{atom::Atom, fun, state::State};

#[test]
fn coefficient_list() {
Expand Down
36 changes: 28 additions & 8 deletions src/domains/rational_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,34 @@ where
}
}

impl<R: Ring, E: Exponent> RationalPolynomial<R, E>
where
Self: FromNumeratorAndDenominator<R, R, E>,
{
pub fn unify_variables(&mut self, other: &mut Self) {
assert_eq!(self.numerator.variables, self.denominator.variables);
assert_eq!(other.numerator.variables, other.denominator.variables);

// this may require a new normalization of the denominator
self.numerator.unify_variables(&mut other.numerator);
self.denominator.unify_variables(&mut other.denominator);

*self = Self::from_num_den(
self.numerator.clone(),
self.denominator.clone(),
&self.numerator.field,
false,
);

*other = Self::from_num_den(
other.numerator.clone(),
other.denominator.clone(),
&other.numerator.field,
false,
);
}
}

impl<R: Ring, E: Exponent> RationalPolynomial<R, E> {
pub fn new(field: &R, var_map: Arc<Vec<Variable>>) -> RationalPolynomial<R, E> {
let num = MultivariatePolynomial::new(field, None, var_map);
Expand All @@ -110,14 +138,6 @@ impl<R: Ring, E: Exponent> RationalPolynomial<R, E> {
&self.numerator.variables
}

pub fn unify_variables(&mut self, other: &mut Self) {
assert_eq!(self.numerator.variables, self.denominator.variables);
assert_eq!(other.numerator.variables, other.denominator.variables);

self.numerator.unify_variables(&mut other.numerator);
self.denominator.unify_variables(&mut other.denominator);
}

pub fn is_zero(&self) -> bool {
self.numerator.is_zero()
}
Expand Down
24 changes: 17 additions & 7 deletions src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,12 +1080,12 @@ impl Pattern {
matched
}

pub fn pattern_match<'a>(
&'a self,
pub fn pattern_match<'a: 'b, 'b>(
&'b self,
target: AtomView<'a>,
conditions: &'a Condition<WildcardAndRestriction>,
settings: &'a MatchSettings,
) -> PatternAtomTreeIterator<'a, 'a> {
conditions: &'b Condition<WildcardAndRestriction>,
settings: &'b MatchSettings,
) -> PatternAtomTreeIterator<'a, 'b> {
PatternAtomTreeIterator::new(self, target, conditions, settings)
}
}
Expand Down Expand Up @@ -1601,6 +1601,16 @@ impl<'a, 'b> MatchStack<'a, 'b> {
None
}

/// Get a reference to all matches.
pub fn get_matches(&self) -> &[(Symbol, Match<'a>)] {
&self.stack
}

/// Get the underlying matches `Vec`.
pub fn into_matches(self) -> Vec<(Symbol, Match<'a>)> {
self.stack
}

/// Return the length of the stack.
#[inline]
pub fn len(&self) -> usize {
Expand Down Expand Up @@ -2444,8 +2454,8 @@ impl<'a: 'b, 'b> PatternAtomTreeIterator<'a, 'b> {
pub fn new(
pattern: &'b Pattern,
target: AtomView<'a>,
conditions: &'a Condition<WildcardAndRestriction>,
settings: &'a MatchSettings,
conditions: &'b Condition<WildcardAndRestriction>,
settings: &'b MatchSettings,
) -> PatternAtomTreeIterator<'a, 'b> {
PatternAtomTreeIterator {
pattern,
Expand Down
2 changes: 1 addition & 1 deletion src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ mod test {
let r = streamer.to_expression();

let res = Atom::parse("coeff((v3+2*v2*v3+v2*v3^2+v2^2)/(v2*v3))*v1").unwrap();
assert_eq!(r, res);
assert_eq!(r - res, Atom::Zero);
}

#[test]
Expand Down

0 comments on commit f1111c6

Please sign in to comment.