Skip to content

Commit

Permalink
SegmentedOperand compatible with dataclasses.replace
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Mar 3, 2025
1 parent 2cde3a0 commit cb23e6f
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions cuequivariance/cuequivariance/segmented_operand.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
class SegmentedOperand:
"""A tensor product operand. It is a list of segments and subscripts."""

_segments: list[tuple[int, ...]]
subscripts: stp.Subscripts
segments: tuple[tuple[int, ...]]
_dims: dict[str, set[int]]

def __init__(
Expand All @@ -42,7 +42,7 @@ def __init__(

if segments is None:
segments = []
object.__setattr__(self, "_segments", segments)
object.__setattr__(self, "segments", tuple(segments))

if _dims is None:
_dims = dict()
Expand Down Expand Up @@ -88,8 +88,20 @@ def insert_segment(
f"segment has {len(segment)} dimensions, expected {len(self.subscripts)} for subscripts {self.subscripts}."
)

if index < 0:
index = len(self.segments) + index

if index < 0 or index > len(self.segments):
raise ValueError(
f"index {index} is out of bounds for segments {self.segments}."
)

segment = tuple(int(d) for d in segment)
self._segments.insert(index, segment)
object.__setattr__(
self,
"segments",
self.segments[:index] + (segment,) + self.segments[index:],
)

for m, d in zip(self.subscripts, segment):
self._dims.setdefault(m, set()).add(d)
Expand All @@ -100,12 +112,14 @@ def add_segment(self, segment: Union[tuple[int, ...], dict[str, int]]) -> int:
return len(self.segments) - 1

def __hash__(self) -> int:
return hash((tuple(self.segments), self.subscripts))
return hash((self.segments, self.subscripts))

def __eq__(self, other: SegmentedOperand) -> bool:
assert isinstance(other, SegmentedOperand)
return self.subscripts == other.subscripts and self.segments == other.segments

def __lt__(self, other: SegmentedOperand) -> bool:
assert isinstance(other, SegmentedOperand)
return (self.subscripts, self.segments) < (other.subscripts, other.segments)

def __repr__(self) -> str:
Expand All @@ -121,11 +135,6 @@ def __len__(self) -> int:
def __iter__(self):
return iter(self.segments)

@property
def segments(self) -> tuple[tuple[int, ...], ...]:
"""The segments of the operand."""
return tuple(self._segments)

@property
def num_segments(self) -> int:
"""The number of segments in the operand."""
Expand Down Expand Up @@ -161,12 +170,6 @@ def get_dims(self, m: str) -> set[int]:
"""Return the dimensions for a given channel."""
return self._dims.get(m, set()).copy()

def get_segment_shape(
self, dims: dict[str, int], *, default: int = -1
) -> tuple[int, ...]:
"""Return the shape of a potential segment."""
return tuple(dims.get(ch, default) for ch in self.subscripts)

def transpose_modes(
self, subscripts: Union[str, Sequence[str], Sequence[int]]
) -> SegmentedOperand:
Expand Down

0 comments on commit cb23e6f

Please sign in to comment.