From 3323324f4a4cf017d83caac5c7f972cc6c08bf84 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Sun, 25 Aug 2024 11:45:52 +0200 Subject: [PATCH] Add cyclesymmetric symbols --- src/api/python.rs | 34 ++++++++++++++++++++++++++++------ src/atom.rs | 8 ++++++++ src/atom/representation.rs | 36 ++++++++++++++++++++++++++++-------- src/normalize.rs | 33 +++++++++++++++++++++++++++++++++ src/state.rs | 19 +++++++++++-------- symbolica.pyi | 12 +++++++++--- 6 files changed, 117 insertions(+), 25 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index 01d45ea1..68abb64c 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -1485,7 +1485,8 @@ macro_rules! req_wc_cmp { impl PythonExpression { /// Create a new symbol from a `name`. Symbols carry information about their attributes. /// The symbol can signal that it is symmetric if it is used as a function - /// using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`, and + /// using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`, + /// cyclesymmetric using `is_cyclesymmetric=True` and /// multilinear using `is_linear=True`. If no attributes /// are specified, the attributes are inherited from the symbol if it was already defined, /// otherwise all attributes are set to `false`. @@ -1523,15 +1524,24 @@ impl PythonExpression { name: &str, is_symmetric: Option, is_antisymmetric: Option, + is_cyclesymmetric: Option, is_linear: Option, ) -> PyResult { - if is_symmetric.is_none() && is_antisymmetric.is_none() && is_linear.is_none() { + if is_symmetric.is_none() + && is_antisymmetric.is_none() + && is_cyclesymmetric.is_none() + && is_linear.is_none() + { return Ok(Atom::new_var(State::get_symbol(name)).into()); } - if is_symmetric == Some(true) && is_antisymmetric == Some(true) { + let count = (is_symmetric == Some(true)) as u8 + + (is_antisymmetric == Some(true)) as u8 + + (is_cyclesymmetric == Some(true)) as u8; + + if count > 1 { Err(exceptions::PyValueError::new_err( - "Function cannot be both symmetric and antisymmetric", + "Function cannot be both symmetric, antisymmetric or cyclesymmetric", ))?; } @@ -1545,6 +1555,10 @@ impl PythonExpression { opts.push(FunctionAttribute::Antisymmetric); } + if let Some(true) = is_cyclesymmetric { + opts.push(FunctionAttribute::CycleSymmetric); + } + if let Some(true) = is_linear { opts.push(FunctionAttribute::Linear); } @@ -1563,20 +1577,28 @@ impl PythonExpression { /// >>> e = f(1,x) /// >>> print(e) /// f(1,x) - #[pyo3(signature = (*args,is_symmetric=None,is_antisymmetric=None,is_linear=None))] + #[pyo3(signature = (*args,is_symmetric=None,is_antisymmetric=None,is_cyclesymmetric=None,is_linear=None))] #[classmethod] pub fn symbols( cls: &PyType, args: &PyTuple, is_symmetric: Option, is_antisymmetric: Option, + is_cyclesymmetric: Option, is_linear: Option, ) -> PyResult> { let mut result = Vec::with_capacity(args.len()); for a in args { let name = a.extract::<&str>()?; - let s = Self::symbol(cls, name, is_symmetric, is_antisymmetric, is_linear)?; + let s = Self::symbol( + cls, + name, + is_symmetric, + is_antisymmetric, + is_cyclesymmetric, + is_linear, + )?; result.push(s); } diff --git a/src/atom.rs b/src/atom.rs index b4c10e65..5cc12ccf 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -26,6 +26,7 @@ pub struct Symbol { wildcard_level: u8, is_symmetric: bool, is_antisymmetric: bool, + is_cyclesymmetric: bool, is_linear: bool, } @@ -48,6 +49,7 @@ impl Symbol { wildcard_level, is_symmetric: false, is_antisymmetric: false, + is_cyclesymmetric: false, is_linear: false, } } @@ -59,6 +61,7 @@ impl Symbol { wildcard_level: u8, is_symmetric: bool, is_antisymmetric: bool, + is_cyclesymmetric: bool, is_linear: bool, ) -> Self { Symbol { @@ -66,6 +69,7 @@ impl Symbol { wildcard_level, is_symmetric, is_antisymmetric, + is_cyclesymmetric, is_linear, } } @@ -86,6 +90,10 @@ impl Symbol { self.is_antisymmetric } + pub fn is_cyclesymmetric(&self) -> bool { + self.is_cyclesymmetric + } + pub fn is_linear(&self) -> bool { self.is_linear } diff --git a/src/atom/representation.rs b/src/atom/representation.rs index 6778e65c..706ff545 100644 --- a/src/atom/representation.rs +++ b/src/atom/representation.rs @@ -30,6 +30,7 @@ const VAR_WILDCARD_LEVEL_3: u8 = 0b00011000; const FUN_SYMMETRIC_FLAG: u8 = 0b00100000; const FUN_LINEAR_FLAG: u8 = 0b01000000; const VAR_ANTISYMMETRIC_FLAG: u8 = 0b10000000; +const VAR_CYCLESYMMETRIC_FLAG: u8 = 0b10100000; // coded as symmetric | antisymmetric const FUN_ANTISYMMETRIC_FLAG: u64 = 1 << 32; // stored in the function id const MUL_HAS_COEFF_FLAG: u8 = 0b01000000; @@ -64,6 +65,9 @@ impl InlineVar { if symbol.is_antisymmetric { flags |= VAR_ANTISYMMETRIC_FLAG; } + if symbol.is_cyclesymmetric { + flags |= VAR_CYCLESYMMETRIC_FLAG; + } data[0] = flags; @@ -321,6 +325,9 @@ impl Var { if symbol.is_antisymmetric { flags |= VAR_ANTISYMMETRIC_FLAG; } + if symbol.is_cyclesymmetric { + flags |= VAR_CYCLESYMMETRIC_FLAG; + } self.data.put_u8(flags); @@ -391,7 +398,7 @@ impl Fun { _ => flags |= VAR_WILDCARD_LEVEL_3, } - if symbol.is_symmetric { + if symbol.is_symmetric || symbol.is_cyclesymmetric { flags |= FUN_SYMMETRIC_FLAG; } if symbol.is_linear { @@ -404,7 +411,7 @@ impl Fun { let buf_pos = self.data.len(); - let id = if symbol.is_antisymmetric { + let id = if symbol.is_antisymmetric || symbol.is_cyclesymmetric { symbol.id as u64 | FUN_ANTISYMMETRIC_FLAG } else { symbol.id as u64 @@ -902,11 +909,14 @@ impl<'a> VarView<'a> { #[inline(always)] pub fn get_symbol(&self) -> Symbol { + let is_cyclesymmetric = self.data[0] & VAR_CYCLESYMMETRIC_FLAG != 0; + Symbol::init_fn( self.data[1..].get_frac_u64().0 as u32, self.get_wildcard_level(), - self.data[0] & FUN_SYMMETRIC_FLAG != 0, - self.data[0] & VAR_ANTISYMMETRIC_FLAG != 0, + !is_cyclesymmetric && self.data[0] & FUN_SYMMETRIC_FLAG != 0, + !is_cyclesymmetric && self.data[0] & VAR_ANTISYMMETRIC_FLAG != 0, + is_cyclesymmetric, self.data[0] & FUN_LINEAR_FLAG != 0, ) } @@ -993,24 +1003,34 @@ impl<'a> FunView<'a> { pub fn get_symbol(&self) -> Symbol { let id = self.data[1 + 4..].get_frac_u64().0; + let is_cyclesymmetric = + self.data[0] & FUN_SYMMETRIC_FLAG != 0 && id & FUN_ANTISYMMETRIC_FLAG != 0; + Symbol::init_fn( id as u32, self.get_wildcard_level(), - self.is_symmetric(), - id & FUN_ANTISYMMETRIC_FLAG != 0, + !is_cyclesymmetric && self.data[0] & FUN_SYMMETRIC_FLAG != 0, + !is_cyclesymmetric && id & FUN_ANTISYMMETRIC_FLAG != 0, + is_cyclesymmetric, self.is_linear(), ) } #[inline(always)] pub fn is_symmetric(&self) -> bool { - self.data[0] & FUN_SYMMETRIC_FLAG != 0 + let id = self.data[1 + 4..].get_frac_u64().0; + self.data[0] & FUN_SYMMETRIC_FLAG != 0 && id & FUN_ANTISYMMETRIC_FLAG == 0 } #[inline(always)] pub fn is_antisymmetric(&self) -> bool { let id = self.data[1 + 4..].get_frac_u64().0; - id & FUN_ANTISYMMETRIC_FLAG != 0 + !self.is_symmetric() && id & FUN_ANTISYMMETRIC_FLAG != 0 + } + + #[inline(always)] + pub fn is_cyclesymmetric(&self) -> bool { + self.is_symmetric() && self.is_antisymmetric() } #[inline(always)] diff --git a/src/normalize.rs b/src/normalize.rs index a0852986..e1200ccc 100644 --- a/src/normalize.rs +++ b/src/normalize.rs @@ -1087,6 +1087,39 @@ impl<'a> AtomView<'a> { } out_f.set_normalized(true); + } else if id.is_cyclesymmetric() { + let mut args: SmallVec<[_; 20]> = SmallVec::new(); + for a in out_f.to_fun_view().iter() { + args.push(a); + } + + let mut best_shift = 0; + 'shift: for shift in 1..args.len() { + for i in 0..args.len() { + match args[(i + best_shift) % args.len()] + .cmp(&args[(i + shift) % args.len()]) + { + std::cmp::Ordering::Equal => {} + std::cmp::Ordering::Less => { + continue 'shift; + } + std::cmp::Ordering::Greater => break, + } + } + + best_shift = shift; + } + + let mut f = workspace.new_atom(); + let ff = f.to_fun(id); + for arg in args[best_shift..].iter().chain(&args[..best_shift]) { + ff.add_arg(*arg); + } + + drop(args); + + ff.set_normalized(true); + std::mem::swap(ff, out_f); } } AtomView::Pow(p) => { diff --git a/src/state.rs b/src/state.rs index 5d5b3425..189138d4 100644 --- a/src/state.rs +++ b/src/state.rs @@ -38,6 +38,7 @@ pub struct VariableListIndex(pub(crate) usize); pub enum FunctionAttribute { Symmetric, Antisymmetric, + Cyclesymmetric, Linear, } @@ -79,14 +80,14 @@ impl Default for State { } impl State { - pub const ARG: Symbol = Symbol::init_fn(0, 0, false, false, false); - pub const COEFF: Symbol = Symbol::init_fn(1, 0, false, false, false); - pub const EXP: Symbol = Symbol::init_fn(2, 0, false, false, false); - pub const LOG: Symbol = Symbol::init_fn(3, 0, false, false, false); - pub const SIN: Symbol = Symbol::init_fn(4, 0, false, false, false); - pub const COS: Symbol = Symbol::init_fn(5, 0, false, false, false); - pub const SQRT: Symbol = Symbol::init_fn(6, 0, false, false, false); - pub const DERIVATIVE: Symbol = Symbol::init_fn(7, 0, false, false, false); + pub const ARG: Symbol = Symbol::init_fn(0, 0, false, false, false, false); + pub const COEFF: Symbol = Symbol::init_fn(1, 0, false, false, false, false); + pub const EXP: Symbol = Symbol::init_fn(2, 0, false, false, false, false); + pub const LOG: Symbol = Symbol::init_fn(3, 0, false, false, false, false); + pub const SIN: Symbol = Symbol::init_fn(4, 0, false, false, false, false); + pub const COS: Symbol = Symbol::init_fn(5, 0, false, false, false, false); + pub const SQRT: Symbol = Symbol::init_fn(6, 0, false, false, false, false); + pub const DERIVATIVE: Symbol = Symbol::init_fn(7, 0, false, false, false, false); pub const E: Symbol = Symbol::init_var(8, 0); pub const I: Symbol = Symbol::init_var(9, 0); pub const PI: Symbol = Symbol::init_var(10, 0); @@ -265,6 +266,7 @@ impl State { r.get_wildcard_level(), attributes.contains(&FunctionAttribute::Symmetric), attributes.contains(&FunctionAttribute::Antisymmetric), + attributes.contains(&FunctionAttribute::Cyclesymmetric), attributes.contains(&FunctionAttribute::Linear), ); @@ -297,6 +299,7 @@ impl State { wildcard_level, attributes.contains(&FunctionAttribute::Symmetric), attributes.contains(&FunctionAttribute::Antisymmetric), + attributes.contains(&FunctionAttribute::Cyclesymmetric), attributes.contains(&FunctionAttribute::Linear), ); diff --git a/symbolica.pyi b/symbolica.pyi index 2022c7c2..6198289e 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -113,11 +113,17 @@ class Expression: """The built-in logarithm function.""" @classmethod - def symbol(_cls, name: str, is_symmetric: Optional[bool] = None, is_antisymmetric: Optional[bool] = None, is_linear: Optional[bool] = None) -> Expression: + def symbol(_cls, + name: str, + is_symmetric: Optional[bool] = None, + is_antisymmetric: Optional[bool] = None, + is_cyclesymmetric: Optional[bool] = None, + is_linear: Optional[bool] = None) -> Expression: """ Create a new symbol from a `name`. Symbols carry information about their attributes. The symbol can signal that it is symmetric if it is used as a function - using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`, and + using `is_symmetric=True`, antisymmetric using `is_antisymmetric=True`, + cyclesymmetric using `is_cyclesymmetric=True`, and multilinear using `is_linear=True`. If no attributes are specified, the attributes are inherited from the symbol if it was already defined, otherwise all attributes are set to `false`. @@ -154,7 +160,7 @@ class Expression: """ @classmethod - def symbols(_cls, *names: str, is_symmetric: Optional[bool] = None, is_antisymmetric: Optional[bool] = None, is_linear: Optional[bool] = None) -> Sequence[Expression]: + def symbols(_cls, *names: str, is_symmetric: Optional[bool] = None, is_antisymmetric: Optional[bool] = None, is_cyclesymmetric: Optional[bool] = None, is_linear: Optional[bool] = None) -> Sequence[Expression]: """ Create a Symbolica symbol for every name in `*names`. See `Expression.symbol` for more information.