Skip to content

Commit

Permalink
Add option to expand in variable
Browse files Browse the repository at this point in the history
- Expansion in variable is always performed now when collecting
- Improve Python API documentation generation
  • Loading branch information
benruijl committed Apr 20, 2024
1 parent 7f16108 commit d74b329
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 69 deletions.
123 changes: 92 additions & 31 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ fn symbolica(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PythonAtomTree>()?;
m.add_class::<PythonInstructionEvaluator>()?;
m.add_class::<PythonRandomNumberGenerator>()?;
m.add_class::<PythonPatternRestriction>()?;

m.add_function(wrap_pyfunction!(get_version, m)?)?;
m.add_function(wrap_pyfunction!(is_licensed, m)?)?;
Expand Down Expand Up @@ -148,7 +149,7 @@ fn get_offline_license_key() -> PyResult<String> {

/// Specifies the type of the atom.
#[derive(Clone, Copy)]
#[pyclass(name = "AtomType")]
#[pyclass(name = "AtomType", module = "symbolica")]
pub enum PythonAtomType {
Num,
Var,
Expand All @@ -173,7 +174,7 @@ pub enum PythonAtomType {
/// - the base and exponent for type `Pow`
/// - the function arguments for type `Fn`
#[derive(Clone)]
#[pyclass(name = "AtomTree")]
#[pyclass(name = "AtomTree", module = "symbolica")]
pub struct PythonAtomTree {
/// The type of this atom.
#[pyo3(get)]
Expand Down Expand Up @@ -249,7 +250,7 @@ impl ConvertibleToPattern {
}

/// Operations that transform an expression.
#[pyclass(name = "Transformer")]
#[pyclass(name = "Transformer", module = "symbolica")]
#[derive(Clone)]
pub struct PythonPattern {
pub expr: Arc<Pattern>,
Expand Down Expand Up @@ -294,8 +295,20 @@ impl PythonPattern {
/// >>> f = Expression.fun('f')
/// >>> e = f((x+1)**2).replace_all(f(x_), x_.transform().expand())
/// >>> print(e)
pub fn expand(&self) -> PyResult<PythonPattern> {
return append_transformer!(self, Transformer::Expand);
pub fn expand(&self, var: Option<ConvertibleToExpression>) -> PyResult<PythonPattern> {
if let Some(var) = var {
let id = if let AtomView::Var(x) = var.to_expression().expr.as_view() {
x.get_symbol()
} else {
return Err(exceptions::PyValueError::new_err(
"Expansion must be done wrt a variable or function name",
));
};

return append_transformer!(self, Transformer::Expand(Some(id)));
} else {
return append_transformer!(self, Transformer::Expand(None));
}
}

/// Create a transformer that computes the product of a list of arguments.
Expand Down Expand Up @@ -1029,14 +1042,23 @@ impl PythonPattern {
/// >>> x = Expression.var('x')
/// >>> e = x**2 + 2 - x + 1 / x**4
/// >>> print(e)
#[pyclass(name = "Expression")]
///
/// Attributes
/// ----------
/// E: Expression
/// Euler's number `e`.
/// PI: Expression
/// The mathematical constant `π`.
/// I: Expression
/// The mathematical constant `i`, where `i^2 = -1`.
#[pyclass(name = "Expression", module = "symbolica")]
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PythonExpression {
pub expr: Arc<Atom>,
}

/// A restriction on wildcards.
#[pyclass(name = "PatternRestriction")]
#[pyclass(name = "PatternRestriction", module = "symbolica")]
#[derive(Clone)]
pub struct PythonPatternRestriction {
pub condition: Arc<Condition<WildcardAndRestriction>>,
Expand Down Expand Up @@ -1419,8 +1441,7 @@ impl PythonExpression {
///
/// Parameters
/// ----------
/// input: str
/// An input string. UTF-8 character are allowed.
/// input: str An input string. UTF-8 character are allowed.
///
/// Examples
/// --------
Expand All @@ -1434,8 +1455,8 @@ impl PythonExpression {
/// If the input is not a valid Symbolica expression.
///
#[classmethod]
pub fn parse(_cls: &PyType, arg: &str) -> PyResult<PythonExpression> {
let e = Atom::parse(arg).map_err(exceptions::PyValueError::new_err)?;
pub fn parse(_cls: &PyType, input: &str) -> PyResult<PythonExpression> {
let e = Atom::parse(input).map_err(exceptions::PyValueError::new_err)?;

Ok(PythonExpression { expr: Arc::new(e) })
}
Expand Down Expand Up @@ -2238,8 +2259,7 @@ impl PythonExpression {
/// Parameters
/// ----------
/// vars : List[Expression]
/// A list of variables
/// vars: List[Expression] A list of variables
pub fn set_coefficient_ring(&self, vars: Vec<PythonExpression>) -> PyResult<PythonExpression> {
let mut var_map = vec![];
for v in vars {
Expand All @@ -2259,10 +2279,23 @@ impl PythonExpression {
Ok(PythonExpression { expr: Arc::new(b) })
}

/// Expand the expression.
pub fn expand(&self) -> PyResult<PythonExpression> {
/// Expand the expression. Optionally, expand in `var` only.
pub fn expand(&self, var: Option<ConvertibleToExpression>) -> PyResult<PythonExpression> {
if let Some(var) = var {
let id = if let AtomView::Var(x) = var.to_expression().expr.as_view() {
x.get_symbol()
} else {
return Err(exceptions::PyValueError::new_err(
"Expansion must be done wrt a variable or function name",
));
};

let b = self.expr.as_view().expand_in(id);
Ok(PythonExpression { expr: Arc::new(b) })
} else {
let b = self.expr.as_view().expand();
Ok(PythonExpression { expr: Arc::new(b) })
}
}

/// Collect terms involving the same power of `x`, where `x` is a variable or function name.
Expand Down Expand Up @@ -2458,7 +2491,7 @@ impl PythonExpression {

/// Taylor expand in `x` around `expansion_point` to depth `depth`.
///
/// Example
/// Examples
/// -------
/// >>> from symbolica import Expression
/// >>> x, y = Expression.vars('x', 'y')
Expand Down Expand Up @@ -2758,6 +2791,21 @@ impl PythonExpression {
/// >>> e = f(3,x)
/// >>> r = e.replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), (w1_ >= 1) & w2_.is_var())
/// >>> print(r)
///
/// Parameters
/// ----------
/// pattern: Transformer | Expression | int
/// The pattern to match.
/// rhs: Transformer | Expression | int
/// The right-hand side to replace the matched subexpression with.
/// cond: Optional[PatternRestriction]
/// Conditions on the pattern.
/// level_range: (int, int), optional
/// Specifies the `[min,max]` level at which the pattern is allowed to match. The first level is 0 and the level is increased when going into a function or one level deeper in the expression tree, depending on `level_is_tree_depth`.
/// level_is_tree_depth: bool, optional
/// If set to `True`, the level is increased when going one level deeper in the expression tree.
/// repeat: bool, optional
/// If set to `True`, the entire operation will be repeated until there are no more matches.
pub fn replace_all(
&self,
pattern: ConvertibleToPattern,
Expand Down Expand Up @@ -2988,7 +3036,20 @@ impl PythonExpression {
/// f = Expression.fun("f")
/// e = f(1,2,3)
/// ```
#[pyclass(name = "Function")]
///
/// Attributes
/// ----------
/// COEFF: Function
/// The built-in function that converts a rational polynomial to a coefficient.
/// COS: Function
/// The built-in cosine function.
/// SIN: Function
/// The built-in sine function.
/// EXP: Function
/// The built-in exponential function.
/// LOG: Function
/// The built-in logarithm function.
#[pyclass(name = "Function", module = "symbolica")]
#[derive(Clone)]
pub struct PythonFunction {
id: Symbol,
Expand Down Expand Up @@ -3153,7 +3214,7 @@ impl PythonFunction {
}

self_cell!(
#[pyclass]
#[pyclass(module = "symbolica")]
pub struct PythonAtomIterator {
owner: Arc<Atom>,
#[covariant]
Expand Down Expand Up @@ -3198,7 +3259,7 @@ type MatchIterator<'a> = PatternAtomTreeIterator<'a, 'a>;

self_cell!(
/// An iterator over matches.
#[pyclass]
#[pyclass(module = "symbolica")]
pub struct PythonMatchIterator {
owner: OwnedMatch,
#[not_covariant]
Expand Down Expand Up @@ -3250,7 +3311,7 @@ type ReplaceIteratorOne<'a> = ReplaceIterator<'a, 'a>;

self_cell!(
/// An iterator over all single replacements.
#[pyclass]
#[pyclass(module = "symbolica")]
pub struct PythonReplaceIterator {
owner: OwnedReplace,
#[not_covariant]
Expand Down Expand Up @@ -3281,7 +3342,7 @@ impl PythonReplaceIterator {
}
}

#[pyclass(name = "Polynomial")]
#[pyclass(name = "Polynomial", module = "symbolica")]
#[derive(Clone)]
pub struct PythonPolynomial {
pub poly: Arc<MultivariatePolynomial<RationalField, u16>>,
Expand Down Expand Up @@ -3497,7 +3558,7 @@ impl PythonPolynomial {
}
}

#[pyclass(name = "Evaluator")]
#[pyclass(name = "Evaluator", module = "symbolica")]
#[derive(Clone)]
pub struct PythonInstructionEvaluator {
pub instr: Arc<InstructionEvaluator<f64>>,
Expand All @@ -3516,7 +3577,7 @@ impl PythonInstructionEvaluator {
}
}

#[pyclass(name = "IntegerPolynomial")]
#[pyclass(name = "IntegerPolynomial", module = "symbolica")]
#[derive(Clone)]
pub struct PythonIntegerPolynomial {
pub poly: Arc<MultivariatePolynomial<IntegerRing, u8>>,
Expand Down Expand Up @@ -3575,7 +3636,7 @@ impl PythonIntegerPolynomial {
}

/// A Symbolica polynomial over finite fields.
#[pyclass(name = "FiniteFieldPolynomial")]
#[pyclass(name = "FiniteFieldPolynomial", module = "symbolica")]
#[derive(Clone)]
pub struct PythonFiniteFieldPolynomial {
pub poly: Arc<MultivariatePolynomial<Zp, u16>>,
Expand Down Expand Up @@ -4116,7 +4177,7 @@ generate_methods!(PythonIntegerPolynomial, u8);
generate_methods!(PythonFiniteFieldPolynomial, u16);

/// A Symbolica rational polynomial.
#[pyclass(name = "RationalPolynomial")]
#[pyclass(name = "RationalPolynomial", module = "symbolica")]
#[derive(Clone)]
pub struct PythonRationalPolynomial {
pub poly: Arc<RationalPolynomial<IntegerRing, u16>>,
Expand Down Expand Up @@ -4160,7 +4221,7 @@ impl PythonRationalPolynomial {
}

/// A Symbolica rational polynomial with variable powers limited to 255.
#[pyclass(name = "RationalPolynomialSmallExponent")]
#[pyclass(name = "RationalPolynomialSmallExponent", module = "symbolica")]
#[derive(Clone)]
pub struct PythonRationalPolynomialSmallExponent {
pub poly: Arc<RationalPolynomial<IntegerRing, u8>>,
Expand Down Expand Up @@ -4225,7 +4286,7 @@ generate_rat_parse!(PythonRationalPolynomial);
generate_rat_parse!(PythonRationalPolynomialSmallExponent);

/// A Symbolica rational polynomial over finite fields.
#[pyclass(name = "FiniteFieldRationalPolynomial")]
#[pyclass(name = "FiniteFieldRationalPolynomial", module = "symbolica")]
#[derive(Clone)]
pub struct PythonFiniteFieldRationalPolynomial {
pub poly: Arc<RationalPolynomial<Zp, u16>>,
Expand Down Expand Up @@ -4512,7 +4573,7 @@ pub enum ScalarOrMatrix {
}

/// A Symbolica matrix with rational polynomial coefficients.
#[pyclass(name = "Matrix")]
#[pyclass(name = "Matrix", module = "symbolica")]
#[derive(Clone)]
pub struct PythonMatrix {
pub matrix: Arc<Matrix<RationalPolynomialField<IntegerRing, u16>>>,
Expand Down Expand Up @@ -4976,7 +5037,7 @@ impl PythonMatrix {

/// A sample from the Symbolica integrator. It could consist of discrete layers,
/// accessible with `d` (empty when there are not discrete layers), and the final continous layer `c` if it is present.
#[pyclass(name = "Sample")]
#[pyclass(name = "Sample", module = "symbolica")]
#[derive(Clone)]
pub struct PythonSample {
#[pyo3(get)]
Expand Down Expand Up @@ -5052,7 +5113,7 @@ impl PythonSample {
///
/// Each thread or instance generating samples should use the same `seed` but a different `stream_id`,
/// which is an instance counter starting at 0.
#[pyclass(name = "RandomNumberGenerator")]
#[pyclass(name = "RandomNumberGenerator", module = "symbolica")]
struct PythonRandomNumberGenerator {
state: MonteCarloRng,
}
Expand All @@ -5069,7 +5130,7 @@ impl PythonRandomNumberGenerator {
}
}

#[pyclass(name = "NumericalIntegrator")]
#[pyclass(name = "NumericalIntegrator", module = "symbolica")]
#[derive(Clone)]
struct PythonNumericalIntegrator {
grid: Grid<f64>,
Expand Down
11 changes: 11 additions & 0 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ impl std::fmt::Debug for Symbol {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AtomType {
Num,
Var,
Add,
Mul,
Pow,
Fun,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SliceType {
Add,
Expand Down Expand Up @@ -356,6 +366,7 @@ pub enum Atom {
Pow(Pow),
Mul(Mul),
Add(Add),
#[doc(hidden)]
Empty, // for internal use
}

Expand Down
Loading

0 comments on commit d74b329

Please sign in to comment.