diff --git a/examples/evaluate.rs b/examples/evaluate.rs index 082e6203..280c3b3f 100644 --- a/examples/evaluate.rs +++ b/examples/evaluate.rs @@ -37,5 +37,6 @@ fn main() { println!( "Result for x = 6.: {}", a.evaluate::(|x| x.into(), &const_map, &fn_map, &mut cache) + .unwrap() ); } diff --git a/src/api/python.rs b/src/api/python.rs index 79eefcc0..9770f87c 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -3866,6 +3866,40 @@ impl PythonExpression { Ok(PythonExpressionEvaluator { eval: eval_f64 }) } + + /// Canonize (products of) tensors in the expression by relabeling repeated indices. + /// The tensors must be written as functions, with its indices are the arguments. + /// The repeated indices should be provided in `contracted_indices`. + /// + /// Examples + /// -------- + /// g = Expression.symbol('g', is_symmetric=True) + /// >>> fc = Expression.symbol('fc', is_cyclesymmetric=True) + /// >>> mu1, mu2, mu3, mu4, k1 = Expression.symbols('mu1', 'mu2', 'mu3', 'mu4', 'k1') + /// >>> + /// >>> e = g(mu2, mu3)*fc(mu4, mu2, k1, mu4, k1, mu3) + /// >>> + /// >>> print(e.canonize_tensors([mu1, mu2, mu3, mu4])) + /// yields `g(mu1,mu2)*fc(mu1,mu3,mu2,k1,mu3,k1)`. + fn canonize_tensors(&self, contracted_indices: Vec) -> PyResult { + let contracted_indices = contracted_indices + .into_iter() + .map(|x| x.to_expression().expr) + .collect::>(); + let contracted_indices = contracted_indices + .iter() + .map(|x| x.as_view()) + .collect::>(); + + let r = self + .expr + .canonize_tensors(&contracted_indices) + .map_err(|e| { + exceptions::PyValueError::new_err(format!("Could not canonize tensors: {}", e)) + })?; + + Ok(r.into()) + } } /// A raplacement, which is a pattern and a right-hand side, with optional conditions and settings. diff --git a/src/atom.rs b/src/atom.rs index 5cc12ccf..310fe0e0 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -201,12 +201,21 @@ impl<'a> From> for AtomView<'a> { } /// A copy-on-write structure for `Atom` and `AtomView`. -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum AtomOrView<'a> { Atom(Atom), View(AtomView<'a>), } +impl<'a> std::fmt::Display for AtomOrView<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AtomOrView::Atom(a) => a.fmt(f), + AtomOrView::View(a) => a.fmt(f), + } + } +} + impl<'a> PartialEq for AtomOrView<'a> { #[inline] fn eq(&self, other: &Self) -> bool { @@ -220,6 +229,23 @@ impl<'a> PartialEq for AtomOrView<'a> { impl Eq for AtomOrView<'_> {} +impl<'a> PartialOrd for AtomOrView<'a> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl<'a> Ord for AtomOrView<'a> { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (AtomOrView::Atom(a1), AtomOrView::Atom(a2)) => a1.as_view().cmp(&a2.as_view()), + (AtomOrView::Atom(a1), AtomOrView::View(a2)) => a1.as_view().cmp(a2), + (AtomOrView::View(a1), AtomOrView::Atom(a2)) => a1.cmp(&a2.as_view()), + (AtomOrView::View(a1), AtomOrView::View(a2)) => a1.cmp(a2), + } + } +} + impl Hash for AtomOrView<'_> { #[inline] fn hash(&self, state: &mut H) { diff --git a/src/atom/representation.rs b/src/atom/representation.rs index 706ff545..dbcce4dc 100644 --- a/src/atom/representation.rs +++ b/src/atom/representation.rs @@ -1018,19 +1018,32 @@ impl<'a> FunView<'a> { #[inline(always)] pub fn is_symmetric(&self) -> bool { + if self.data[0] & FUN_SYMMETRIC_FLAG == 0 { + return false; + } + let id = self.data[1 + 4..].get_frac_u64().0; - self.data[0] & FUN_SYMMETRIC_FLAG != 0 && id & FUN_ANTISYMMETRIC_FLAG == 0 + id & FUN_ANTISYMMETRIC_FLAG == 0 } #[inline(always)] pub fn is_antisymmetric(&self) -> bool { + if self.data[0] & FUN_SYMMETRIC_FLAG != 0 { + return false; + } + let id = self.data[1 + 4..].get_frac_u64().0; - !self.is_symmetric() && id & FUN_ANTISYMMETRIC_FLAG != 0 + id & FUN_ANTISYMMETRIC_FLAG != 0 } #[inline(always)] pub fn is_cyclesymmetric(&self) -> bool { - self.is_symmetric() && self.is_antisymmetric() + if self.data[0] & FUN_SYMMETRIC_FLAG == 0 { + return false; + } + + let id = self.data[1 + 4..].get_frac_u64().0; + id & FUN_ANTISYMMETRIC_FLAG != 0 } #[inline(always)] diff --git a/src/graph.rs b/src/graph.rs index dae44ccf..c85f16fd 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -26,7 +26,7 @@ pub struct Edge { } /// Empty data type. -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Empty; impl Display for Empty { @@ -83,7 +83,11 @@ impl Graph { pub fn add_edge(&mut self, source: usize, target: usize, directed: bool, data: E) { let index = self.edges.len(); self.edges.push(Edge { - vertices: (source, target), + vertices: if !directed && source > target { + (target, source) + } else { + (source, target) + }, directed, data, }); @@ -517,9 +521,9 @@ impl SearchTreeNode { }; if j.contains(&k) { if e.directed { - edge_data.push((&e.data, is_source)); + edge_data.push((&e.data, e.directed, is_source)); } else { - edge_data.push((&e.data, true)); + edge_data.push((&e.data, false, false)); } } } diff --git a/src/state.rs b/src/state.rs index 189138d4..09ec7d73 100644 --- a/src/state.rs +++ b/src/state.rs @@ -138,6 +138,12 @@ impl State { &[FunctionAttribute::Symmetric], ); } + for i in 0..5 { + let _ = self.get_symbol_with_attributes_impl( + &format!("fc{}", i), + &[FunctionAttribute::Cyclesymmetric], + ); + } for i in 0..5 { let _ = self.get_symbol_with_attributes_impl( &format!("fa{}", i), @@ -374,6 +380,7 @@ impl State { dest.write_u8(s.get_wildcard_level())?; dest.write_u8(s.is_symmetric() as u8)?; dest.write_u8(s.is_antisymmetric() as u8)?; + dest.write_u8(s.is_cyclesymmetric() as u8)?; dest.write_u8(s.is_linear() as u8)?; } @@ -446,6 +453,7 @@ impl State { let wildcard_level = source.read_u8()?; let is_symmetric = source.read_u8()? != 0; let is_antisymmetric = source.read_u8()? != 0; + let is_cyclesymmetric = source.read_u8()? != 0; let is_linear = source.read_u8()? != 0; attributes.clear(); @@ -455,6 +463,9 @@ impl State { if is_symmetric { attributes.push(FunctionAttribute::Symmetric); } + if is_cyclesymmetric { + attributes.push(FunctionAttribute::Cyclesymmetric); + } if is_linear { attributes.push(FunctionAttribute::Linear); } diff --git a/src/tensors.rs b/src/tensors.rs index 3bb3506e..6abfc3d2 100644 --- a/src/tensors.rs +++ b/src/tensors.rs @@ -1 +1,349 @@ +use crate::{ + atom::{Atom, AtomOrView, AtomView, Mul}, + graph::Graph, + state::{RecycledAtom, Workspace}, +}; + pub mod matrix; + +impl Atom { + /// Canonize (products of) tensors in the expression by relabeling repeated indices. + /// The tensors must be written as functions, with its indices are the arguments. + /// The repeated indices should be provided in `contracted_indices`. + /// + /// Example + /// ------- + /// ``` + /// # use symbolica::{atom::Atom, state::{FunctionAttribute, State}}; + /// # + /// # fn main() { + /// let _ = State::get_symbol_with_attributes("fs", &[FunctionAttribute::Symmetric]).unwrap(); + /// let _ = State::get_symbol_with_attributes("fc", &[FunctionAttribute::Cyclesymmetric]).unwrap(); + /// let a = Atom::parse("fs(mu2,mu3)*fc(mu4,mu2,k1,mu4,k1,mu3)").unwrap(); + /// + /// let mu1 = Atom::parse("mu1").unwrap(); + /// let mu2 = Atom::parse("mu2").unwrap(); + /// let mu3 = Atom::parse("mu3").unwrap(); + /// let mu4 = Atom::parse("mu4").unwrap(); + /// + /// let r = a.canonize_tensors(&[mu1.as_view(), mu2.as_view(), mu3.as_view(), mu4.as_view()]).unwrap(); + /// println!("{}", r); + /// # } + /// ``` + /// yields `fs(mu1,mu2)*fc(mu1,k1,mu3,k1,mu2,mu3)`. + pub fn canonize_tensors(&self, contracted_indices: &[AtomView]) -> Result { + self.as_view().canonize_tensors(contracted_indices) + } +} + +impl<'a> AtomView<'a> { + /// Canonize (products of) tensors in the expression by relabeling repeated indices. + /// The tensors must be written as functions, with its indices are the arguments. + /// The repeated indices should be provided in `contracted_indices`. + pub fn canonize_tensors(&self, contracted_indices: &[AtomView]) -> Result { + Workspace::get_local().with(|ws| { + if let AtomView::Add(a) = self { + let mut aa = ws.new_atom(); + let add = aa.to_add(); + + for a in a.iter() { + add.extend(a.canonize_tensor_product(contracted_indices, ws)?.as_view()); + } + + let mut out = Atom::new(); + aa.as_view().normalize(ws, &mut out); + Ok(out) + } else { + Ok(self + .canonize_tensor_product(contracted_indices, ws)? + .into_inner()) + } + }) + } + + /// Canonize a tensor product by relabeling repeated indices. + fn canonize_tensor_product( + &self, + contracted_indices: &[AtomView], + ws: &Workspace, + ) -> Result { + let mut g = Graph::new(); + let mut connections = vec![None; contracted_indices.len()]; + + let mut t = ws.new_atom(); + let mul = t.to_mul(); + + self.tensor_to_graph_impl(contracted_indices, &mut connections, &mut g, mul)?; + + for (i, f) in contracted_indices.iter().zip(&connections) { + if f.is_some() { + return Err(format!("Index {} is not contracted", i)); + } + } + + let (_, gc) = g.canonize(); + + let mut funcs = vec![]; + for n in gc.nodes() { + funcs.push((true, n.data.clone().as_mut().clone())); + } + + // connect dummy indices + let mut index_count = 0; + for e in gc.edges() { + if e.directed { + continue; + } + + if let Atom::Fun(f) = &mut funcs[e.vertices.0].1 { + f.add_arg(contracted_indices[index_count]); + } else { + unreachable!("Only functions should be left"); + } + + if let Atom::Fun(f) = &mut funcs[e.vertices.1].1 { + f.add_arg(contracted_indices[index_count]); + } else { + unreachable!("Only functions should be left"); + } + + index_count += 1; + } + + // now join all regular and cyclesymmetric functions + // the start of the cyclesymmetric function is determined by its + // first encountered node + for fi in 0..funcs.len() { + if !funcs[fi].0 { + continue; + } + + if let AtomView::Fun(ff) = funcs[fi].1.as_view() { + if ff.get_symbol().is_symmetric() { + funcs[fi].0 = false; + mul.extend(funcs[fi].1.as_view()); + continue; + } else if !ff.get_symbol().is_cyclesymmetric() { + // check if the current index is the start of a regular function + if gc.node(fi).edges.iter().any(|ei| { + let e = gc.edge(*ei); + e.directed && e.vertices.0 == fi && e.data != 1 + }) { + continue; + } + } + } + + let mut ff = funcs[fi].1.clone(); + let mut cur_pos = fi; + 'next: loop { + funcs[cur_pos].0 = false; + + for ei in &gc.node(cur_pos).edges { + let e = gc.edge(*ei); + + if e.directed && e.vertices.0 == cur_pos { + debug_assert!(e.vertices.0 != e.vertices.1); + + if e.vertices.1 == fi { + // cycle completed + break 'next; + } + + debug_assert!(funcs[e.vertices.1].0); + + if let Atom::Fun(ff) = &mut ff { + if let AtomView::Fun(f) = funcs[e.vertices.1].1.as_view() { + for a in f.iter() { + ff.add_arg(a); + } + } + } + + cur_pos = e.vertices.1; + + continue 'next; + } + } + + break; + } + + mul.extend(ff.as_view()); + } + + debug_assert!(funcs.iter().all(|f| !f.0)); + + let mut out = ws.new_atom(); + t.as_view().normalize(ws, &mut out); + + Ok(out) + } + + fn tensor_to_graph_impl( + &self, + contracted_indices: &[AtomView], + connections: &mut [Option], + g: &mut Graph, usize>, + remainder: &mut Mul, + ) -> Result<(), String> { + match self { + AtomView::Num(_) | AtomView::Var(_) => { + remainder.extend(*self); + Ok(()) + } + AtomView::Pow(p) => { + let (b, e) = p.get_base_exp(); + + if !contracted_indices.iter().any(|a| b.contains(*a)) { + remainder.extend(*self); + Ok(()) + } else { + if let Ok(n) = e.try_into() { + if n > 0 { + for _ in 0..n { + b.tensor_to_graph_impl( + contracted_indices, + connections, + g, + remainder, + )?; + } + Ok(()) + } else { + Err("Only tensors raised to positive powers are supported".to_owned()) + } + } else { + Err("Only tensors raised to positive powers are supported".to_owned()) + } + } + } + AtomView::Fun(f) => { + if !f.iter().any(|a| contracted_indices.contains(&a)) { + remainder.extend(*self); + return Ok(()); + } + + let nargs = f.get_nargs(); + if f.is_symmetric() || nargs == 1 { + // collect all non-dummy arguments + let mut ff = Atom::new(); + let fff = ff.to_fun(f.get_symbol()); + + for a in f.iter() { + if !contracted_indices.contains(&a) { + fff.add_arg(a); + } + } + fff.set_normalized(true); + + let n = g.add_node(ff.into()); + for a in f.iter() { + if let Some(p) = contracted_indices.iter().position(|x| x == &a) { + if let Some(n2) = connections[p] { + g.add_edge(n, n2, false, Default::default()); + connections[p] = None; + } else { + connections[p] = Some(n); + } + } + } + + Ok(()) + } else if f.is_antisymmetric() { + Err("Antisymmetric functions are not supported yet".to_owned()) + } else { + let is_cyclesymmetric = f.is_cyclesymmetric(); + + // add a node for every slot + let start = g.nodes().len(); + for (i, a) in f.iter().enumerate() { + let mut ff = Atom::new(); + let fff = ff.to_fun(f.get_symbol()); + + if let Some(p) = contracted_indices.iter().position(|x| x == &a) { + ff.set_normalized(true); + g.add_node(ff.into()); + + if let Some(n2) = connections[p] { + g.add_edge(start + i, n2, false, 0); + connections[p] = None; + } else { + connections[p] = Some(start + i); + } + } else { + fff.add_arg(a); + fff.set_normalized(true); + g.add_node(ff.into()); + } + + if i != 0 { + g.add_edge( + start + i - 1, + start + i, + true, + if is_cyclesymmetric { 0 } else { i }, + ); + } + } + + if is_cyclesymmetric { + g.add_edge(start + nargs - 1, start, true, 0); + } + + Ok(()) + } + } + AtomView::Mul(m) => { + for a in m.iter() { + a.tensor_to_graph_impl(contracted_indices, connections, g, remainder)?; + } + Ok(()) + } + AtomView::Add(_) => { + if !contracted_indices.iter().any(|a| self.contains(*a)) { + remainder.extend(*self); + Ok(()) + } else { + Err( + "Nested additions containing contracted indices is not supported" + .to_owned(), + ) + } + } + } + } +} + +#[cfg(test)] +mod test { + use crate::{ + atom::{representation::InlineVar, Atom}, + state::State, + }; + + #[test] + fn canonize_tensors() { + // fs1 is symmetric and fc1 is cyclesymmetric + let a1 = Atom::parse( + "fs1(k2,mu1,mu2)*fs1(mu1,mu3)*fc1(mu4,mu2,k1,mu4,k1,mu3)*(1+x)*f(k)*fs1(mu5,mu6)^2*f(mu7,mu9,k3,mu9,mu7)*h(mu8)*i(mu8)+fc1(mu4,mu5,mu6)*fc1(mu5,mu4,mu6)", + ) + .unwrap(); + + let mus: Vec<_> = (0..10) + .map(|i| InlineVar::new(State::get_symbol(format!("mu{}", i + 1)))) + .collect(); + let mu_ref = mus.iter().map(|x| x.as_view()).collect::>(); + + let r1 = a1.canonize_tensors(&mu_ref).unwrap(); + + let a2 = Atom::parse( + "fs1(k2,mu2,mu9)*fs1(mu2,mu5)*fc1(k1,mu8,k1,mu5,mu8,mu9)*(1+x)*f(k)*fs1(mu3,mu6)^2*f(mu7,mu1,k3,mu1,mu7)*h(mu4)*i(mu4)+fc1(mu1,mu4,mu6)*fc1(mu4,mu1,mu6)", + ) + .unwrap(); + + let r2 = a2.canonize_tensors(&mu_ref).unwrap(); + + assert_eq!(r1, r2); + } +} diff --git a/symbolica.pyi b/symbolica.pyi index 6198289e..3d949a71 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -1154,6 +1154,22 @@ class Expression: will recycle the `x^2` """ + def canonize_tensors(self, contracted_indices: Sequence[Expression | int]) -> Expression: + """Canonize (products of) tensors in the expression by relabeling repeated indices. + The tensors must be written as functions, with its indices are the arguments. + The repeated indices should be provided in `contracted_indices`. + + Examples + -------- + >>> g = Expression.symbol('g', is_symmetric=True) + >>> fc = Expression.symbol('fc', is_cyclesymmetric=True) + >>> mu1, mu2, mu3, mu4, k1 = Expression.symbols('mu1', 'mu2', 'mu3', 'mu4', 'k1') + >>> e = g(mu2, mu3)*fc(mu4, mu2, k1, mu4, k1, mu3) + >>> print(e.canonize_tensors([mu1, mu2, mu3, mu4])) + + yields `g(mu1,mu2)*fc(mu1,mu3,mu2,k1,mu3,k1)`. + """ + class Replacement: """A replacement of a pattern by a right-hand side."""