Skip to content

Commit

Permalink
Add ComplexTransform. (#1964)
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann authored Jan 30, 2025
1 parent 5f3bdd1 commit a9c9fc6
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
8 changes: 8 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,14 @@ CholeskyTransform
:show-inheritance:
:member-order: bysource

ComplexTransform
^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.transforms.ComplexTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

ComposeTransform
^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.transforms.ComposeTransform
Expand Down
31 changes: 30 additions & 1 deletion numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"AbsTransform",
"AffineTransform",
"CholeskyTransform",
"ComplexTransform",
"ComposeTransform",
"CorrCholeskyTransform",
"CorrMatrixCholeskyTransform",
Expand Down Expand Up @@ -1516,6 +1517,34 @@ def __eq__(self, other):
)


class ComplexTransform(ParameterFreeTransform):
"""
Transforms a pair of real numbers to a complex number.
"""

domain = constraints.real_vector
codomain = constraints.complex

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2."
return lax.complex(x[..., 0], x[..., 1])

def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
return jnp.stack([y.real, y.imag], axis=-1)

def log_abs_det_jacobian(
self, x: jnp.ndarray, y: jnp.ndarray, intermediates=None
) -> jnp.ndarray:
return jnp.zeros_like(y)

def forward_shape(self, shape: tuple[int]) -> tuple[int]:
assert shape[-1] == 2, "Input must have a trailing dimension of size 2."
return shape[:-1]

def inverse_shape(self, shape: tuple[int]) -> tuple[int]:
return shape + (2,)


##########################################################
# CONSTRAINT_REGISTRY
##########################################################
Expand Down Expand Up @@ -1649,7 +1678,7 @@ def _transform_to_positive_ordered_vector(constraint):

@biject_to.register(constraints.complex)
def _transform_to_complex(constraint):
return IdentityTransform()
return ComplexTransform()


@biject_to.register(constraints.real)
Expand Down
5 changes: 4 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AbsTransform,
AffineTransform,
CholeskyTransform,
ComplexTransform,
ComposeTransform,
CorrCholeskyTransform,
CorrMatrixCholeskyTransform,
Expand Down Expand Up @@ -105,6 +106,7 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
# unparametrized transforms
"abs": T(AbsTransform, (), dict()),
"cholesky": T(CholeskyTransform, (), dict()),
"complex": T(ComplexTransform, (), dict()),
"corr_chol": T(CorrCholeskyTransform, (), dict()),
"corr_matrix_chol": T(CorrMatrixCholeskyTransform, (), dict()),
"exp": T(ExpTransform, (), dict()),
Expand Down Expand Up @@ -270,6 +272,7 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims):
[
(AffineTransform(3, 2.5), ()),
(CholeskyTransform(), (10,)),
(ComplexTransform(), (2,)),
(ComposeTransform([SoftplusTransform(), SigmoidTransform()]), ()),
(CorrCholeskyTransform(), (15,)),
(CorrMatrixCholeskyTransform(), (15,)),
Expand Down Expand Up @@ -361,7 +364,7 @@ def test_batched_recursive_linear_transform():
"constraint, shape",
[
(constraints.circular, (3,)),
(constraints.complex, (3,)),
(constraints.complex, (3, 2)),
(constraints.corr_cholesky, (10, 10)),
(constraints.corr_matrix, (15,)),
(constraints.greater_than(3), ()),
Expand Down

0 comments on commit a9c9fc6

Please sign in to comment.