Skip to content

Commit

Permalink
Add several transformer methods
Browse files Browse the repository at this point in the history
- Check for symbol name validity in Python API
  • Loading branch information
benruijl committed Sep 4, 2024
1 parent 42a5d2c commit aca4408
Show file tree
Hide file tree
Showing 3 changed files with 324 additions and 5 deletions.
227 changes: 223 additions & 4 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,208 @@ impl PythonTransformer {
Ok(out.into())
}

/// Set the coefficient ring to contain the variables in the `vars` list.
/// This will move all variables into a rational polynomial function.
///
/// Parameters
/// ----------
/// vars: List[Expression]
/// A list of variables
pub fn set_coefficient_ring(&self, vars: Vec<PythonExpression>) -> PyResult<PythonTransformer> {
let mut var_map = vec![];
for v in vars {
match v.expr.as_view() {
AtomView::Var(v) => var_map.push(v.get_symbol().into()),
e => {
Err(exceptions::PyValueError::new_err(format!(
"Expected variable instead of {}",
e
)))?;
}
}
}

let a = Arc::new(var_map);

return append_transformer!(
self,
Transformer::Map(Box::new(move |i, o| {
*o = i.set_coefficient_ring(&a);
Ok(())
}))
);
}

/// Create a transformer that collects terms involving the same power of `x`,
/// where `x` is a variable or function name.
/// Return the list of key-coefficient pairs and the remainder that matched no key.
///
/// Both the key (the quantity collected in) and its coefficient can be mapped using
/// `key_map` and `coeff_map` transformers respectively.
///
/// Examples
/// --------
/// >>> from symbolica import Expression
/// >>> x, y = Expression.symbol('x', 'y')
/// >>> e = 5*x + x * y + x**2 + 5
/// >>>
/// >>> print(e.transform().collect(x).execute())
///
/// yields `x^2+x*(y+5)+5`.
///
/// >>> from symbolica import Expression
/// >>> x, y, x_, var, coeff = Expression.symbol('x', 'y', 'x_', 'var', 'coeff')
/// >>> e = 5*x + x * y + x**2 + 5
/// >>> print(e.collect(x, key_map=Transformer().replace_all(x_, var(x_)),
/// coeff_map=Transformer().replace_all(x_, coeff(x_))))
///
/// yields `var(1)*coeff(5)+var(x)*coeff(y+5)+var(x^2)*coeff(1)`.
///
/// Parameters
/// ----------
/// x: Expression
/// The variable to collect terms in
/// key_map: Transformer
/// A transformer to be applied to the quantity collected in
/// coeff_map: Transformer
/// A transformer to be applied to the coefficient
pub fn collect(
&self,
x: ConvertibleToExpression,
key_map: Option<PythonTransformer>,
coeff_map: Option<PythonTransformer>,
) -> PyResult<PythonTransformer> {
let id = if let AtomView::Var(x) = x.to_expression().expr.as_view() {
x.get_symbol()
} else {
return Err(exceptions::PyValueError::new_err(
"Collect must be done wrt a variable or function name",
));
};

let key_map = if let Some(key_map) = key_map {
let Pattern::Transformer(p) = key_map.expr else {
return Err(exceptions::PyValueError::new_err(
"Key map must be a transformer",
));
};

if p.0.is_some() {
Err(exceptions::PyValueError::new_err(
"Key map must be an unbound transformer",
))?;
}

p.1.clone()
} else {
vec![]
};

let coeff_map = if let Some(coeff_map) = coeff_map {
let Pattern::Transformer(p) = coeff_map.expr else {
return Err(exceptions::PyValueError::new_err(
"Key map must be a transformer",
));
};

if p.0.is_some() {
Err(exceptions::PyValueError::new_err(
"Key map must be an unbound transformer",
))?;
}

p.1.clone()
} else {
vec![]
};

return append_transformer!(self, Transformer::Collect(id, key_map, coeff_map));
}

/// Create a transformer that collects terms involving the literal occurrence of `x`.
pub fn coefficient(&self, x: ConvertibleToExpression) -> PyResult<PythonTransformer> {
let a = x.to_expression().expr;
return append_transformer!(
self,
Transformer::Map(Box::new(move |i, o| {
*o = i.coefficient(a.as_view());
Ok(())
}))
);
}

/// Create a transformer that computes the partial fraction decomposition in `x`.
pub fn apart(&self, x: PythonExpression) -> PyResult<PythonTransformer> {
return append_transformer!(
self,
Transformer::Map(Box::new(move |i, o| {
let poly = i.to_rational_polynomial::<_, _, u32>(&Q, &Z, None);

let x = poly
.get_variables()
.iter()
.position(|v| match (v, x.expr.as_view()) {
(Variable::Symbol(y), AtomView::Var(vv)) => *y == vv.get_symbol(),
(Variable::Function(_, f) | Variable::Other(f), a) => f.as_view() == a,
_ => false,
})
.ok_or(TransformerError::ValueError(format!(
"Variable {} not found in polynomial",
x.expr
)))?;

let fs = poly.apart(x);

Workspace::get_local().with(|ws| {
let mut res = ws.new_atom();
let a = res.to_add();
for f in fs {
a.extend(f.to_expression().as_view());
}

res.as_view().normalize(ws, o);
});

Ok(())
}))
);
}

/// Create a transformer that writes the expression over a common denominator.
pub fn together(&self) -> PyResult<PythonTransformer> {
return append_transformer!(
self,
Transformer::Map(Box::new(|i, o| {
let poly = i.to_rational_polynomial::<_, _, u32>(&Q, &Z, None);
*o = poly.to_expression();
Ok(())
}))
);
}

/// Create a transformer that cancels common factors between numerators and denominators.
/// Any non-canceling parts of the expression will not be rewritten.
pub fn cancel(&self) -> PyResult<PythonTransformer> {
return append_transformer!(
self,
Transformer::Map(Box::new(|i, o| {
*o = i.cancel();
Ok(())
}))
);
}

/// Create a transformer that factors the expression over the rationals.
pub fn factor(&self) -> PyResult<PythonTransformer> {
return append_transformer!(
self,
Transformer::Map(Box::new(|i, o| {
*o = i.factor();
Ok(())
}))
);
}

/// Create a transformer that derives `self` w.r.t the variable `x`.
pub fn derivative(&self, x: ConvertibleToPattern) -> PyResult<PythonTransformer> {
let id = match &x.to_pattern()?.expr {
Expand Down Expand Up @@ -1622,6 +1824,23 @@ impl PythonExpression {
));
}

fn name_check(name: &str) -> PyResult<&str> {
let illegal_chars = [
'\0', '^', '+', '*', '-', '(', ')', '/', ',', '[', ']', ' ', '\t', '\n', '\r',
'\\', ';', ':', '&', '!', '%', '.',
];

if name.is_empty() {
Err(exceptions::PyValueError::new_err("Name cannot be empty"))
} else if name.chars().any(|x| illegal_chars.contains(&x)) {
Err(exceptions::PyValueError::new_err(
"Illegal character in name",
))
} else {
Ok(name)
}
}

if is_symmetric.is_none()
&& is_antisymmetric.is_none()
&& is_cyclesymmetric.is_none()
Expand All @@ -1630,14 +1849,14 @@ impl PythonExpression {
if names.len() == 1 {
let name = names[0].extract::<&str>()?;

let id = State::get_symbol(name);
let id = State::get_symbol(name_check(name)?);
let r = PythonExpression::from(Atom::new_var(id));
return Ok(r.into_py(py));
} else {
let mut result = vec![];
for a in names {
let name = a.extract::<&str>()?;
let id = State::get_symbol(name);
let id = State::get_symbol(name_check(name)?);
let r = PythonExpression::from(Atom::new_var(id));
result.push(r);
}
Expand Down Expand Up @@ -1677,15 +1896,15 @@ impl PythonExpression {
if names.len() == 1 {
let name = names[0].extract::<&str>()?;

let id = State::get_symbol_with_attributes(name, &opts)
let id = State::get_symbol_with_attributes(name_check(name)?, &opts)
.map_err(|e| exceptions::PyTypeError::new_err(e.to_string()))?;
let r = PythonExpression::from(Atom::new_var(id));
Ok(r.into_py(py))
} else {
let mut result = vec![];
for a in names {
let name = a.extract::<&str>()?;
let id = State::get_symbol_with_attributes(name, &opts)
let id = State::get_symbol_with_attributes(name_check(name)?, &opts)
.map_err(|e| exceptions::PyTypeError::new_err(e.to_string()))?;
let r = PythonExpression::from(Atom::new_var(id));
result.push(r);
Expand Down
29 changes: 28 additions & 1 deletion src/transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ pub enum Transformer {
Expand(Option<Symbol>),
/// Derive the rhs w.r.t a variable.
Derivative(Symbol),
/// Derive the rhs w.r.t a variable.
/// Perform a series expansion.
Series(Symbol, Atom, Rational, bool),
///Collect all terms in powers of a variable.
Collect(Symbol, Vec<Transformer>, Vec<Transformer>),
/// Apply find-and-replace on the lhs.
ReplaceAll(
Pattern,
Expand Down Expand Up @@ -117,6 +119,9 @@ impl std::fmt::Debug for Transformer {
match self {
Transformer::Expand(s) => f.debug_tuple("Expand").field(s).finish(),
Transformer::Derivative(x) => f.debug_tuple("Derivative").field(x).finish(),
Transformer::Collect(x, a, b) => {
f.debug_tuple("Collect").field(x).field(a).field(b).finish()
}
Transformer::ReplaceAll(pat, rhs, ..) => {
f.debug_tuple("ReplaceAll").field(pat).field(rhs).finish()
}
Expand Down Expand Up @@ -391,6 +396,28 @@ impl Transformer {
Transformer::Derivative(x) => {
cur_input.derivative_with_ws_into(*x, workspace, out);
}
Transformer::Collect(x, key_map, coeff_map) => cur_input.collect_into(
*x,
if key_map.is_empty() {
None
} else {
let key_map = key_map.clone();
Some(Box::new(move |i, o| {
Workspace::get_local()
.with(|ws| Self::execute_chain(i, &key_map, ws, o).unwrap())
}))
},
if coeff_map.is_empty() {
None
} else {
let coeff_map = coeff_map.clone();
Some(Box::new(move |i, o| {
Workspace::get_local()
.with(|ws| Self::execute_chain(i, &coeff_map, ws, o).unwrap())
}))
},
out,
),
Transformer::Series(x, expansion_point, depth, depth_is_absolute) => {
if let Ok(s) = cur_input.series(
*x,
Expand Down
73 changes: 73 additions & 0 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,79 @@ class Transformer:
def derivative(self, x: Transformer | Expression) -> Transformer:
"""Create a transformer that derives `self` w.r.t the variable `x`."""

def set_coefficient_ring(self, vars: Sequence[Expression]) -> Transformer:
"""
Create a transformer that sets the coefficient ring to contain the variables in the `vars` list.
This will move all variables into a rational polynomial function.
Parameters
----------
vars : Sequence[Expression]
A list of variables
"""

def collect(
self,
x: Expression,
key_map: Optional[Transformer] = None,
coeff_map: Optional[Transformer] = None,
) -> Transformer:
"""
Create a transformer that collect terms involving the same power of `x`,
where `x` is a variable or function name.
Return the list of key-coefficient pairs and the remainder that matched no key.
Both the key (the quantity collected in) and its coefficient can be mapped using
`key_map` and `coeff_map` transformers respectively.
Examples
--------
>>> from symbolica import Expression
>>> x, y = Expression.symbol('x', 'y')
>>> e = 5*x + x * y + x**2 + 5
>>>
>>> print(e.transform().collect(x).execute())
yields `x^2+x*(y+5)+5`.
>>> from symbolica import Expression
>>> x, y, x_, var, coeff = Expression.symbol('x', 'y', 'x_', 'var', 'coeff')
>>> e = 5*x + x * y + x**2 + 5
>>> print(e.collect(x, key_map=Transformer().replace_all(x_, var(x_)),
coeff_map=Transformer().replace_all(x_, coeff(x_))))
yields `var(1)*coeff(5)+var(x)*coeff(y+5)+var(x^2)*coeff(1)`.
Parameters
----------
x: Expression
The variable to collect terms in
key_map: Transformer
A transformer to be applied to the quantity collected in
coeff_map: Transformer
A transformer to be applied to the coefficient
"""

def coefficient(self, x: Expression) -> Transformer:
"""Create a transformer that collects terms involving the literal occurrence of `x`.
"""

def apart(self, x: Expression) -> Transformer:
"""Create a transformer that computes the partial fraction decomposition in `x`.
"""

def together(self) -> Transformer:
"""Create a transformer that writes the expression over a common denominator.
"""

def cancel(self) -> Transformer:
"""Create a transformer that cancels common factors between numerators and denominators.
Any non-canceling parts of the expression will not be rewritten.
"""

def factor(self) -> Transformer:
"""Create a transformer that factors the expression over the rationals."""

def series(
self,
x: Expression,
Expand Down

0 comments on commit aca4408

Please sign in to comment.