Skip to content

Commit

Permalink
Add index group to distinguish contracted indices
Browse files Browse the repository at this point in the history
- Make bit operations less error prone
  • Loading branch information
benruijl committed Aug 26, 2024
1 parent 01bef5b commit 35d0d3b
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 37 deletions.
24 changes: 22 additions & 2 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3871,6 +3871,10 @@ impl PythonExpression {
/// The tensors must be written as functions, with its indices are the arguments.
/// The repeated indices should be provided in `contracted_indices`.
///
/// If the contracted indices are distinguishable (for example in their dimension),
/// you can provide an optional group marker for each index using `index_group`.
/// This makes sure that an index will not be renamed to an index from a different group.
///
/// Examples
/// --------
/// g = Expression.symbol('g', is_symmetric=True)
Expand All @@ -3881,7 +3885,11 @@ impl PythonExpression {
/// >>>
/// >>> print(e.canonize_tensors([mu1, mu2, mu3, mu4]))
/// yields `g(mu1,mu2)*fc(mu1,mu3,mu2,k1,mu3,k1)`.
fn canonize_tensors(&self, contracted_indices: Vec<ConvertibleToExpression>) -> PyResult<Self> {
fn canonize_tensors(
&self,
contracted_indices: Vec<ConvertibleToExpression>,
index_group: Option<Vec<ConvertibleToExpression>>,
) -> PyResult<Self> {
let contracted_indices = contracted_indices
.into_iter()
.map(|x| x.to_expression().expr)
Expand All @@ -3891,9 +3899,21 @@ impl PythonExpression {
.map(|x| x.as_view())
.collect::<Vec<_>>();

let index_group = index_group.map(|x| {
x.into_iter()
.map(|x| x.to_expression().expr)
.collect::<Vec<_>>()
});
let index_group = index_group
.as_ref()
.map(|x| x.iter().map(|x| x.as_view()).collect::<Vec<_>>());

let r = self
.expr
.canonize_tensors(&contracted_indices)
.canonize_tensors(
&contracted_indices,
index_group.as_ref().map(|x| x.as_slice()),
)
.map_err(|e| {
exceptions::PyValueError::new_err(format!("Could not canonize tensors: {}", e))
})?;
Expand Down
26 changes: 13 additions & 13 deletions src/atom/representation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,15 +909,15 @@ impl<'a> VarView<'a> {

#[inline(always)]
pub fn get_symbol(&self) -> Symbol {
let is_cyclesymmetric = self.data[0] & VAR_CYCLESYMMETRIC_FLAG != 0;
let is_cyclesymmetric = self.data[0] & VAR_CYCLESYMMETRIC_FLAG == VAR_CYCLESYMMETRIC_FLAG;

Symbol::init_fn(
self.data[1..].get_frac_u64().0 as u32,
self.get_wildcard_level(),
!is_cyclesymmetric && self.data[0] & FUN_SYMMETRIC_FLAG != 0,
!is_cyclesymmetric && self.data[0] & VAR_ANTISYMMETRIC_FLAG != 0,
!is_cyclesymmetric && self.data[0] & FUN_SYMMETRIC_FLAG == FUN_SYMMETRIC_FLAG,
!is_cyclesymmetric && self.data[0] & VAR_ANTISYMMETRIC_FLAG == VAR_ANTISYMMETRIC_FLAG,
is_cyclesymmetric,
self.data[0] & FUN_LINEAR_FLAG != 0,
self.data[0] & FUN_LINEAR_FLAG == FUN_LINEAR_FLAG,
)
}

Expand Down Expand Up @@ -1003,14 +1003,14 @@ impl<'a> FunView<'a> {
pub fn get_symbol(&self) -> Symbol {
let id = self.data[1 + 4..].get_frac_u64().0;

let is_cyclesymmetric =
self.data[0] & FUN_SYMMETRIC_FLAG != 0 && id & FUN_ANTISYMMETRIC_FLAG != 0;
let is_cyclesymmetric = self.data[0] & FUN_SYMMETRIC_FLAG == FUN_SYMMETRIC_FLAG
&& id & FUN_ANTISYMMETRIC_FLAG == FUN_ANTISYMMETRIC_FLAG;

Symbol::init_fn(
id as u32,
self.get_wildcard_level(),
!is_cyclesymmetric && self.data[0] & FUN_SYMMETRIC_FLAG != 0,
!is_cyclesymmetric && id & FUN_ANTISYMMETRIC_FLAG != 0,
!is_cyclesymmetric && self.data[0] & FUN_SYMMETRIC_FLAG == FUN_SYMMETRIC_FLAG,
!is_cyclesymmetric && id & FUN_ANTISYMMETRIC_FLAG == FUN_ANTISYMMETRIC_FLAG,
is_cyclesymmetric,
self.is_linear(),
)
Expand All @@ -1028,12 +1028,12 @@ impl<'a> FunView<'a> {

#[inline(always)]
pub fn is_antisymmetric(&self) -> bool {
if self.data[0] & FUN_SYMMETRIC_FLAG != 0 {
if self.data[0] & FUN_SYMMETRIC_FLAG == FUN_SYMMETRIC_FLAG {
return false;
}

let id = self.data[1 + 4..].get_frac_u64().0;
id & FUN_ANTISYMMETRIC_FLAG != 0
id & FUN_ANTISYMMETRIC_FLAG == FUN_ANTISYMMETRIC_FLAG
}

#[inline(always)]
Expand All @@ -1043,12 +1043,12 @@ impl<'a> FunView<'a> {
}

let id = self.data[1 + 4..].get_frac_u64().0;
id & FUN_ANTISYMMETRIC_FLAG != 0
id & FUN_ANTISYMMETRIC_FLAG == FUN_ANTISYMMETRIC_FLAG
}

#[inline(always)]
pub fn is_linear(&self) -> bool {
self.data[0] & FUN_LINEAR_FLAG != 0
self.data[0] & FUN_LINEAR_FLAG == FUN_LINEAR_FLAG
}

#[inline(always)]
Expand Down Expand Up @@ -1341,7 +1341,7 @@ impl<'a> MulView<'a> {

#[inline]
pub fn has_coefficient(&self) -> bool {
(self.data[0] & MUL_HAS_COEFF_FLAG) != 0
self.data[0] & MUL_HAS_COEFF_FLAG == MUL_HAS_COEFF_FLAG
}

pub fn get_byte_size(&self) -> usize {
Expand Down
Loading

0 comments on commit 35d0d3b

Please sign in to comment.