Skip to content

Commit

Permalink
Add canonization of products of tensors
Browse files Browse the repository at this point in the history
- Flag directed edges during refinement
- Sort undirected edge node indices
- Fix is_cyclesymmetric
- Support cyclesymmetric functions in export
  • Loading branch information
benruijl committed Aug 26, 2024
1 parent 2eb20aa commit 01bef5b
Show file tree
Hide file tree
Showing 8 changed files with 461 additions and 8 deletions.
1 change: 1 addition & 0 deletions examples/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@ fn main() {
println!(
"Result for x = 6.: {}",
a.evaluate::<f64, _>(|x| x.into(), &const_map, &fn_map, &mut cache)
.unwrap()
);
}
34 changes: 34 additions & 0 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3866,6 +3866,40 @@ impl PythonExpression {

Ok(PythonExpressionEvaluator { eval: eval_f64 })
}

/// Canonize (products of) tensors in the expression by relabeling repeated indices.
/// The tensors must be written as functions, with its indices are the arguments.
/// The repeated indices should be provided in `contracted_indices`.
///
/// Examples
/// --------
/// g = Expression.symbol('g', is_symmetric=True)
/// >>> fc = Expression.symbol('fc', is_cyclesymmetric=True)
/// >>> mu1, mu2, mu3, mu4, k1 = Expression.symbols('mu1', 'mu2', 'mu3', 'mu4', 'k1')
/// >>>
/// >>> e = g(mu2, mu3)*fc(mu4, mu2, k1, mu4, k1, mu3)
/// >>>
/// >>> print(e.canonize_tensors([mu1, mu2, mu3, mu4]))
/// yields `g(mu1,mu2)*fc(mu1,mu3,mu2,k1,mu3,k1)`.
fn canonize_tensors(&self, contracted_indices: Vec<ConvertibleToExpression>) -> PyResult<Self> {
let contracted_indices = contracted_indices
.into_iter()
.map(|x| x.to_expression().expr)
.collect::<Vec<_>>();
let contracted_indices = contracted_indices
.iter()
.map(|x| x.as_view())
.collect::<Vec<_>>();

let r = self
.expr
.canonize_tensors(&contracted_indices)
.map_err(|e| {
exceptions::PyValueError::new_err(format!("Could not canonize tensors: {}", e))
})?;

Ok(r.into())
}
}

/// A raplacement, which is a pattern and a right-hand side, with optional conditions and settings.
Expand Down
28 changes: 27 additions & 1 deletion src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,21 @@ impl<'a> From<AddView<'a>> for AtomView<'a> {
}

/// A copy-on-write structure for `Atom` and `AtomView`.
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum AtomOrView<'a> {
Atom(Atom),
View(AtomView<'a>),
}

impl<'a> std::fmt::Display for AtomOrView<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AtomOrView::Atom(a) => a.fmt(f),
AtomOrView::View(a) => a.fmt(f),
}
}
}

impl<'a> PartialEq for AtomOrView<'a> {
#[inline]
fn eq(&self, other: &Self) -> bool {
Expand All @@ -220,6 +229,23 @@ impl<'a> PartialEq for AtomOrView<'a> {

impl Eq for AtomOrView<'_> {}

impl<'a> PartialOrd for AtomOrView<'a> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl<'a> Ord for AtomOrView<'a> {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(AtomOrView::Atom(a1), AtomOrView::Atom(a2)) => a1.as_view().cmp(&a2.as_view()),
(AtomOrView::Atom(a1), AtomOrView::View(a2)) => a1.as_view().cmp(a2),
(AtomOrView::View(a1), AtomOrView::Atom(a2)) => a1.cmp(&a2.as_view()),
(AtomOrView::View(a1), AtomOrView::View(a2)) => a1.cmp(a2),
}
}
}

impl Hash for AtomOrView<'_> {
#[inline]
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
Expand Down
19 changes: 16 additions & 3 deletions src/atom/representation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1018,19 +1018,32 @@ impl<'a> FunView<'a> {

#[inline(always)]
pub fn is_symmetric(&self) -> bool {
if self.data[0] & FUN_SYMMETRIC_FLAG == 0 {
return false;
}

let id = self.data[1 + 4..].get_frac_u64().0;
self.data[0] & FUN_SYMMETRIC_FLAG != 0 && id & FUN_ANTISYMMETRIC_FLAG == 0
id & FUN_ANTISYMMETRIC_FLAG == 0
}

#[inline(always)]
pub fn is_antisymmetric(&self) -> bool {
if self.data[0] & FUN_SYMMETRIC_FLAG != 0 {
return false;
}

let id = self.data[1 + 4..].get_frac_u64().0;
!self.is_symmetric() && id & FUN_ANTISYMMETRIC_FLAG != 0
id & FUN_ANTISYMMETRIC_FLAG != 0
}

#[inline(always)]
pub fn is_cyclesymmetric(&self) -> bool {
self.is_symmetric() && self.is_antisymmetric()
if self.data[0] & FUN_SYMMETRIC_FLAG == 0 {
return false;
}

let id = self.data[1 + 4..].get_frac_u64().0;
id & FUN_ANTISYMMETRIC_FLAG != 0
}

#[inline(always)]
Expand Down
12 changes: 8 additions & 4 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct Edge<EdgeData = Empty> {
}

/// Empty data type.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Empty;

impl Display for Empty {
Expand Down Expand Up @@ -83,7 +83,11 @@ impl<N, E> Graph<N, E> {
pub fn add_edge(&mut self, source: usize, target: usize, directed: bool, data: E) {
let index = self.edges.len();
self.edges.push(Edge {
vertices: (source, target),
vertices: if !directed && source > target {
(target, source)
} else {
(source, target)
},
directed,
data,
});
Expand Down Expand Up @@ -517,9 +521,9 @@ impl<I: NodeIndex> SearchTreeNode<I> {
};
if j.contains(&k) {
if e.directed {
edge_data.push((&e.data, is_source));
edge_data.push((&e.data, e.directed, is_source));
} else {
edge_data.push((&e.data, true));
edge_data.push((&e.data, false, false));
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ impl State {
&[FunctionAttribute::Symmetric],
);
}
for i in 0..5 {
let _ = self.get_symbol_with_attributes_impl(
&format!("fc{}", i),
&[FunctionAttribute::Cyclesymmetric],
);
}
for i in 0..5 {
let _ = self.get_symbol_with_attributes_impl(
&format!("fa{}", i),
Expand Down Expand Up @@ -374,6 +380,7 @@ impl State {
dest.write_u8(s.get_wildcard_level())?;
dest.write_u8(s.is_symmetric() as u8)?;
dest.write_u8(s.is_antisymmetric() as u8)?;
dest.write_u8(s.is_cyclesymmetric() as u8)?;
dest.write_u8(s.is_linear() as u8)?;
}

Expand Down Expand Up @@ -446,6 +453,7 @@ impl State {
let wildcard_level = source.read_u8()?;
let is_symmetric = source.read_u8()? != 0;
let is_antisymmetric = source.read_u8()? != 0;
let is_cyclesymmetric = source.read_u8()? != 0;
let is_linear = source.read_u8()? != 0;

attributes.clear();
Expand All @@ -455,6 +463,9 @@ impl State {
if is_symmetric {
attributes.push(FunctionAttribute::Symmetric);
}
if is_cyclesymmetric {
attributes.push(FunctionAttribute::Cyclesymmetric);
}
if is_linear {
attributes.push(FunctionAttribute::Linear);
}
Expand Down
Loading

0 comments on commit 01bef5b

Please sign in to comment.