Skip to content

Commit

Permalink
Allow user-defined normalization function for new symbols
Browse files Browse the repository at this point in the history
- Fix edge labeling in mermaid output
- Remove `symbols` from `pyi` file
  • Loading branch information
benruijl committed Sep 13, 2024
1 parent 080a6b0 commit 9a424cc
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 34 deletions.
80 changes: 73 additions & 7 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,14 @@ fn get_license_key(email: String) -> PyResult<()> {
}

/// Shorthand notation for :func:`Expression.symbol`.
#[pyfunction(name = "S", signature = (*names,is_symmetric=None,is_antisymmetric=None,is_cyclesymmetric=None,is_linear=None))]
#[pyfunction(name = "S", signature = (*names,is_symmetric=None,is_antisymmetric=None,is_cyclesymmetric=None,is_linear=None,custom_normalization=None))]
fn symbol_shorthand(
names: &PyTuple,
is_symmetric: Option<bool>,
is_antisymmetric: Option<bool>,
is_cyclesymmetric: Option<bool>,
is_linear: Option<bool>,
custom_normalization: Option<PythonTransformer>,
py: Python<'_>,
) -> PyResult<PyObject> {
PythonExpression::symbol(
Expand All @@ -190,6 +191,7 @@ fn symbol_shorthand(
is_antisymmetric,
is_cyclesymmetric,
is_linear,
custom_normalization,
)
}

Expand Down Expand Up @@ -1848,7 +1850,8 @@ impl PythonExpression {
/// cyclesymmetric using `is_cyclesymmetric=True` and
/// multilinear using `is_linear=True`. If no attributes
/// are specified, the attributes are inherited from the symbol if it was already defined,
/// otherwise all attributes are set to `false`.
/// otherwise all attributes are set to `false`. A transformer that is executed
/// after normalization can be defined with `custom_normalization`.
///
/// Once attributes are defined on a symbol, they cannot be redefined later.
///
Expand Down Expand Up @@ -1877,7 +1880,12 @@ impl PythonExpression {
/// >>> dot = Expression.symbol('dot', is_symmetric=True, is_linear=True)
/// >>> e = dot(p2+2*p3,p1+3*p2-p3)
/// dot(p1,p2)+2*dot(p1,p3)+3*dot(p2,p2)-dot(p2,p3)+6*dot(p2,p3)-2*dot(p3,p3)
#[pyo3(signature = (*names,is_symmetric=None,is_antisymmetric=None,is_cyclesymmetric=None,is_linear=None))]
///
///
/// Define a custom normalization function:
/// >>> e = S('real_log', custom_normalization=Transformer().replace_all(E("x_(exp(x1_))"), E("x1_")))
/// >>> E("real_log(exp(x)) + real_log(5)")
#[pyo3(signature = (*names,is_symmetric=None,is_antisymmetric=None,is_cyclesymmetric=None,is_linear=None,custom_normalization=None))]
#[classmethod]
pub fn symbol(
_cls: &PyType,
Expand All @@ -1887,6 +1895,7 @@ impl PythonExpression {
is_antisymmetric: Option<bool>,
is_cyclesymmetric: Option<bool>,
is_linear: Option<bool>,
custom_normalization: Option<PythonTransformer>,
) -> PyResult<PyObject> {
if names.is_empty() {
return Err(exceptions::PyValueError::new_err(
Expand Down Expand Up @@ -1915,6 +1924,7 @@ impl PythonExpression {
&& is_antisymmetric.is_none()
&& is_cyclesymmetric.is_none()
&& is_linear.is_none()
&& custom_normalization.is_none()
{
if names.len() == 1 {
let name = names[0].extract::<&str>()?;
Expand Down Expand Up @@ -1965,17 +1975,73 @@ impl PythonExpression {

if names.len() == 1 {
let name = names[0].extract::<&str>()?;
let name = name_check(name)?;

let id = if let Some(f) = custom_normalization {
if let Pattern::Transformer(t) = f.expr {
if !t.0.is_none() {
Err(exceptions::PyValueError::new_err(
"Transformer must be unbound",
))?;
}

State::get_symbol_with_attributes_and_function(
name,
&opts,
Box::new(move |input, out| {
Workspace::get_local()
.with(|ws| {
Transformer::execute_chain(input, &t.1, ws, out).map_err(|e| e)
})
.unwrap();
true
}),
)
} else {
return Err(exceptions::PyValueError::new_err("Transformer expected"));
}
} else {
State::get_symbol_with_attributes(name, &opts)
}
.map_err(|e| exceptions::PyTypeError::new_err(e.to_string()))?;

let id = State::get_symbol_with_attributes(name_check(name)?, &opts)
.map_err(|e| exceptions::PyTypeError::new_err(e.to_string()))?;
let r = PythonExpression::from(Atom::new_var(id));
Ok(r.into_py(py))
} else {
let mut result = vec![];
for a in names {
let name = a.extract::<&str>()?;
let id = State::get_symbol_with_attributes(name_check(name)?, &opts)
.map_err(|e| exceptions::PyTypeError::new_err(e.to_string()))?;
let name = name_check(name)?;

let id = if let Some(f) = &custom_normalization {
if let Pattern::Transformer(t) = &f.expr {
if !t.0.is_none() {
Err(exceptions::PyValueError::new_err(
"Transformer must be unbound",
))?;
}

let t = t.1.clone();
State::get_symbol_with_attributes_and_function(
name,
&opts,
Box::new(move |input, out| {
Workspace::get_local()
.with(|ws| {
Transformer::execute_chain(input, &t, ws, out)
.map_err(|e| e)
})
.unwrap();
true
}),
)
} else {
return Err(exceptions::PyValueError::new_err("Transformer expected"));
}
} else {
State::get_symbol_with_attributes(name, &opts)
}
.map_err(|e| exceptions::PyTypeError::new_err(e.to_string()))?;
let r = PythonExpression::from(Atom::new_var(id));
result.push(r);
}
Expand Down
8 changes: 4 additions & 4 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ impl<N: Display, E: Display> Graph<N, E> {
for x in &self.edges {
if x.directed {
out.push_str(&format!(
" {} --> {}[{}];\n",
x.vertices.0, x.vertices.1, x.data
" {} -->|{}| {};\n",
x.vertices.0, x.data, x.vertices.1,
));
} else {
out.push_str(&format!(
" {} --- {}[{}];\n",
x.vertices.0, x.vertices.1, x.data
" {} ---|{}| {};\n",
x.vertices.0, x.data, x.vertices.1,
));
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,14 @@ impl<'a> AtomView<'a> {
ff.set_normalized(true);
std::mem::swap(ff, out_f);
}

if let Some(f) = State::get_normalization_function(id) {
let mut fs = workspace.new_atom();
if f(out.as_view(), &mut fs) {
std::mem::swap(out, fs.deref_mut());
}
debug_assert!(!out.as_view().needs_normalization());
}
}
AtomView::Pow(p) => {
let (base, exp) = p.get_base_exp();
Expand Down
143 changes: 137 additions & 6 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use byteorder::LittleEndian;
use once_cell::sync::Lazy;
use smartstring::alias::String;

use crate::atom::AtomView;
use crate::domains::finite_field::Zp64;
use crate::poly::Variable;
use crate::{
Expand Down Expand Up @@ -54,8 +55,18 @@ impl StateMap {
}
}

/// A user-supplied hook that is called on a function after its arguments have been normalized.
///
/// The first argument is the input to inspect and the second argument is the atom
/// that receives the result.
///
/// If the input is already normalized, the hook should return `false`.
/// Otherwise, the hook must return `true` and set the second argument to the
/// normalized value; that value should not require any further normalization.
pub type NormalizationFunction = Box<dyn Fn(AtomView, &mut Atom) -> bool + Send + Sync>;

/// The data that is stored for every registered symbol, next to the `Symbol` itself.
struct SymbolData {
/// The name of the symbol.
name: String,
/// The user-defined normalization function of the symbol, or `None` if it has none.
function: Option<NormalizationFunction>,
}

static STATE: Lazy<RwLock<State>> = Lazy::new(|| RwLock::new(State::new()));
static ID_TO_STR: AppendOnlyVec<(Symbol, String)> = AppendOnlyVec::new();
static ID_TO_STR: AppendOnlyVec<(Symbol, SymbolData)> = AppendOnlyVec::new();
static FINITE_FIELDS: AppendOnlyVec<Zp64> = AppendOnlyVec::new();
static VARIABLE_LISTS: AppendOnlyVec<Arc<Vec<Variable>>> = AppendOnlyVec::new();
static SYMBOL_OFFSET: AtomicUsize = AtomicUsize::new(0);
Expand Down Expand Up @@ -197,7 +208,7 @@ impl State {
ID_TO_STR
.iter()
.skip(SYMBOL_OFFSET.load(Ordering::Relaxed))
.map(|s| (s.0, s.1.as_str()))
.map(|s| (s.0, s.1.name.as_str()))
}

/// Returns `true` iff this identifier is defined by Symbolica.
Expand Down Expand Up @@ -234,7 +245,13 @@ impl State {
// as the state itself is behind a mutex
let id = ID_TO_STR.len() - offset;
let new_symbol = Symbol::init_var(id as u32, wildcard_level);
let id_ret = ID_TO_STR.push((new_symbol, name.into())) - offset;
let id_ret = ID_TO_STR.push((
new_symbol,
SymbolData {
name: name.into(),
function: None,
},
)) - offset;
assert_eq!(id, id_ret);

v.insert(new_symbol);
Expand Down Expand Up @@ -279,7 +296,7 @@ impl State {
if r == new_id {
Ok(r)
} else {
Err(format!("Function {} redefined with new attributes", name).into())
Err(format!("Symbol {} redefined with new attributes", name).into())
}
}
Entry::Vacant(v) => {
Expand Down Expand Up @@ -309,7 +326,13 @@ impl State {
attributes.contains(&FunctionAttribute::Linear),
);

let id_ret = ID_TO_STR.push((new_symbol, name.into())) - offset;
let id_ret = ID_TO_STR.push((
new_symbol,
SymbolData {
name: name.into(),
function: None,
},
)) - offset;
assert_eq!(id, id_ret);

v.insert(new_symbol);
Expand All @@ -319,9 +342,86 @@ impl State {
}
}

/// Create a new symbol called `name` that carries the given `attributes` and
/// the custom normalization function `f`, which is called after the arguments
/// of the symbol have been normalized.
///
/// The normalization function cannot be exported. Symbols that use one must
/// therefore be registered explicitly before a state is imported.
///
/// An error is returned if a symbol with this name already exists.
pub fn get_symbol_with_attributes_and_function<S: AsRef<str>>(
    name: S,
    attributes: &[FunctionAttribute],
    f: NormalizationFunction,
) -> Result<Symbol, String> {
    // registration mutates the global state, so take the write lock for the duration of the call
    let mut state = STATE.write().unwrap();
    state.get_symbol_with_attributes_and_function_impl(name.as_ref(), attributes, f)
}

/// Register `name` as a new function symbol with the given `attributes` and
/// store `f` as its normalization function.
///
/// Returns an error if `name` is already registered. Panics if the symbol
/// limit has been reached.
pub(crate) fn get_symbol_with_attributes_and_function_impl(
    &mut self,
    name: &str,
    attributes: &[FunctionAttribute],
    f: NormalizationFunction,
) -> Result<Symbol, String> {
    // a symbol with a normalization function must be new
    if self.str_to_id.contains_key(name) {
        return Err(format!("Symbol {} already defined", name).into());
    }

    let offset = SYMBOL_OFFSET.load(Ordering::Relaxed);
    if ID_TO_STR.len() - offset == u32::MAX as usize - 1 {
        panic!("Too many variables defined");
    }

    // there is no synchronization issue since only one thread can insert at a time
    // as the state itself is behind a mutex
    let id = ID_TO_STR.len() - offset;

    // the wildcard level is the number of trailing underscores in the name
    let wildcard_level = name
        .chars()
        .rev()
        .take_while(|&c| c == '_')
        .fold(0, |level, _| level + 1);

    let has = |attribute: FunctionAttribute| attributes.contains(&attribute);
    let new_symbol = Symbol::init_fn(
        id as u32,
        wildcard_level,
        has(FunctionAttribute::Symmetric),
        has(FunctionAttribute::Antisymmetric),
        has(FunctionAttribute::Cyclesymmetric),
        has(FunctionAttribute::Linear),
    );

    let data = SymbolData {
        name: name.into(),
        function: Some(f),
    };

    // the position of the pushed entry must agree with the id given to the symbol
    let id_ret = ID_TO_STR.push((new_symbol, data)) - offset;
    assert_eq!(id, id_ret);

    self.str_to_id.insert(name.into(), new_symbol);

    Ok(new_symbol)
}

/// Get the name for a given symbol.
pub fn get_name(id: Symbol) -> &'static str {
&ID_TO_STR[id.get_id() as usize + SYMBOL_OFFSET.load(Ordering::Relaxed)].1
&ID_TO_STR[id.get_id() as usize + SYMBOL_OFFSET.load(Ordering::Relaxed)]
.1
.name
}

/// Look up the normalization function that the user registered for the symbol `id`.
///
/// Returns `None` if no function is stored for the symbol.
pub fn get_normalization_function(id: Symbol) -> Option<&'static NormalizationFunction> {
    // symbol ids are relative to the current symbol offset
    let index = id.get_id() as usize + SYMBOL_OFFSET.load(Ordering::Relaxed);
    let entry = &ID_TO_STR[index];
    entry.1.function.as_ref()
}

pub fn get_finite_field(fi: FiniteFieldIndex) -> &'static Zp64 {
Expand Down Expand Up @@ -734,6 +834,8 @@ impl Drop for RecycledAtom {
mod tests {
use std::io::Cursor;

use crate::atom::{Atom, AtomView};

use super::State;

#[test]
Expand All @@ -744,4 +846,33 @@ mod tests {
let i = State::import(Cursor::new(&export), None).unwrap();
assert!(i.is_empty());
}

#[test]
fn custom_normalization() {
    // register a symbol whose normalization function rewrites `f(exp(a))` to `a`
    let _real_log = State::get_symbol_with_attributes_and_function(
        "custom_normalization_real_log",
        &[],
        Box::new(|input, out| {
            // only a function with exactly one argument can match
            let outer = match input {
                AtomView::Fun(outer) if outer.get_nargs() == 1 => outer,
                _ => return false,
            };

            // the single argument must be `exp` with one argument
            match outer.iter().next().unwrap() {
                AtomView::Fun(inner)
                    if inner.get_symbol() == State::EXP && inner.get_nargs() == 1 =>
                {
                    out.set_from_view(&inner.iter().next().unwrap());
                    true
                }
                _ => false,
            }
        }),
    )
    .unwrap();

    let parsed = Atom::parse("custom_normalization_real_log(exp(x))").unwrap();
    let expected = Atom::parse("x").unwrap();
    assert_eq!(parsed, expected);
}
}
Loading

0 comments on commit 9a424cc

Please sign in to comment.