Skip to content

Commit

Permalink
is_simple()
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Oct 31, 2024
1 parent 8274ad4 commit 09201ba
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 37 deletions.
4 changes: 2 additions & 2 deletions cuequivariance/cuequivariance/representation/rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def continuous_generators(self) -> np.ndarray:

@property
def X(self) -> np.ndarray:
"""Generators of the representation"""
"""Generators of the representation, (lie_dim, dim, dim)"""
return self.continuous_generators()

def discrete_generators(self) -> np.ndarray:
Expand All @@ -130,7 +130,7 @@ def discrete_generators(self) -> np.ndarray:

@property
def H(self) -> np.ndarray:
"""Discrete generators of the representation"""
"""Discrete generators of the representation, (len(H), dim, dim)"""
return self.discrete_generators()

def trivial(self) -> Rep:
Expand Down
6 changes: 2 additions & 4 deletions cuequivariance_jax/cuequivariance_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
from .primitives.equivariant_tensor_product import equivariant_tensor_product

from .operations.activation import (
soft_odd,
normalspace,
normalize_function,
parity_function,
function_parity,
scalar_activation,
)
from .operations.spherical_harmonics import spherical_harmonics, normalize, norm
Expand All @@ -47,10 +46,9 @@
"tensor_product",
"symmetric_tensor_product",
"equivariant_tensor_product",
"soft_odd",
"normalspace",
"normalize_function",
"parity_function",
"function_parity",
"scalar_activation",
"spherical_harmonics",
"normalize",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class LayerNorm(nn.Module):

@nn.compact
def __call__(self, input: cuex.IrrepsArray) -> cuex.IrrepsArray:
assert input.is_simple
assert input.is_simple()

def rms(v: jax.Array) -> jax.Array:
# v [..., ir, mul] or [..., mul, ir]
Expand Down
2 changes: 1 addition & 1 deletion cuequivariance_jax/cuequivariance_jax/flax_linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __call__(
if not isinstance(input, cuex.IrrepsArray):
raise ValueError(f"input must be of type IrrepsArray, got {type(input)}")

assert input.is_simple
assert input.is_simple()

irreps_out = cue.Irreps(self.irreps_out)
layout_out = cue.IrrepsLayout.as_layout(self.layout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def ndim(self) -> int:
def dtype(self) -> jax.numpy.dtype:
return self.array.dtype

@property
def is_simple(self) -> bool:
if len(self.dirreps) != 1:
return False
Expand Down Expand Up @@ -257,7 +256,7 @@ def simplify(self, axis: int = -1) -> IrrepsArray:
if self.layout == cue.mul_ir:
return IrrepsArray(dirreps, self.array, self.layout)

assert self.is_simple
assert self.is_simple()
segments = []
last_ir = None
for x, (mul, ir) in zip(self.segments(), self.irreps()):
Expand Down Expand Up @@ -288,7 +287,7 @@ def segments(self, axis: int = -1) -> list[jax.Array]:
]

def change_layout(self, layout: cue.IrrepsLayout | None = None) -> IrrepsArray:
assert self.is_simple
assert self.is_simple()

if layout is None:
layout = cue.get_layout_scope()
Expand All @@ -305,7 +304,7 @@ def change_layout(self, layout: cue.IrrepsLayout | None = None) -> IrrepsArray:
)

def move_axis_to_mul(self, axis: int) -> IrrepsArray:
assert self.is_simple
assert self.is_simple()
assert self.layout == cue.ir_mul
if axis < 0:
axis += self.ndim
Expand All @@ -316,7 +315,7 @@ def move_axis_to_mul(self, axis: int) -> IrrepsArray:
return IrrepsArray(mul * self.irreps(), array, cue.ir_mul)

def transform(self, v: jax.Array) -> IrrepsArray:
assert self.is_simple
assert self.is_simple()

def f(segment: jax.Array, mul: int, ir: cue.Irrep) -> jax.Array:
X = ir.X
Expand Down
26 changes: 6 additions & 20 deletions cuequivariance_jax/cuequivariance_jax/operations/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,6 @@
ActFn = Callable[[float], float] | Callable[[jax.Array], jax.Array]


def soft_odd(x: jax.Array) -> jax.Array:
"""Smooth odd function that can be used as activation function for odd scalars.
.. math::
x (1 - e^{-x^2})
Note:
Odd scalars (l=0 and p=-1) has to be activated by functions with well defined parity:
* even (:math:`f(-x)=f(x)`)
* odd (:math:`f(-x)=-f(x)`).
"""
return (1 - jnp.exp(-(x**2))) * x


def normalspace(n: int) -> jax.Array:
r"""Sequence of normally distributed numbers :math:`x_i` for :math:`i=1, \ldots, n` such that
Expand Down Expand Up @@ -95,7 +79,7 @@ def rho(x):
return rho


def parity_function(phi: ActFn) -> int:
def function_parity(phi: ActFn) -> int:
with jax.ensure_compile_time_eval():
x = jnp.linspace(0.0, 10.0, 256)

Expand All @@ -119,7 +103,7 @@ def scalar_activation(
"""
input = cuex.as_irreps_array(input)
assert isinstance(input, cuex.IrrepsArray)
assert input.is_simple
assert input.is_simple()

if isinstance(acts, dict):
acts = [acts.get(ir, None) for mul, ir in input.irreps()]
Expand All @@ -137,18 +121,20 @@ def scalar_activation(
x: jax.Array

if act is not None:
assert np.all(np.imag(ir.H) == 0), "TODO: support complex scalars"
assert ir.dim == 1, "Only scalars are supported"
assert np.allclose(ir.X, 0), "Only scalars are supported"
assert np.allclose(np.imag(ir.H), 0), "Only real scalars are supported"

if normalize_act:
act = normalize_function(act)

p_act = parity_function(act)
p_act = function_parity(act)

if np.allclose(ir.H, 1):
# if the input is even, we can apply any activation function
ir_out = ir
else:
assert np.allclose(ir.H * ir.H, 1), "H should be -1 or 1"
if p_act == 0:
raise ValueError(
"Activation: the parity is violated! The input scalar is odd but the activation is neither even nor odd."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def spherical_harmonics(
) -> cuex.IrrepsArray:
ls = list(ls)
assert isinstance(vector, cuex.IrrepsArray)
assert vector.is_simple
assert vector.is_simple()
irreps = vector.irreps()
assert len(irreps) == 1
mul, ir = irreps[0]
Expand All @@ -43,7 +43,7 @@ def spherical_harmonics(


def normalize(array: cuex.IrrepsArray) -> cuex.IrrepsArray:
assert array.is_simple
assert array.is_simple()

match array.layout:
case cue.ir_mul:
Expand Down Expand Up @@ -71,7 +71,7 @@ def f(x: jax.Array) -> jax.Array:

def norm(array: cuex.IrrepsArray, *, squared: bool = False) -> cuex.IrrepsArray:
"""Norm of IrrepsArray."""
assert array.is_simple
assert array.is_simple()

match array.layout:
case cue.ir_mul:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def equivariant_tensor_product(

for x, ope in zip(inputs, e.inputs):
if isinstance(x, cuex.IrrepsArray):
assert x.is_simple
assert x.is_simple()
assert x.irreps() == ope.irreps
assert x.layout == ope.layout
else:
Expand Down

0 comments on commit 09201ba

Please sign in to comment.