From 1af20c1911d94ab4b3bd287a5a93e2378632dac8 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Wed, 29 Jan 2025 17:38:53 -0500 Subject: [PATCH] Add `ComplexTransform`. --- docs/source/distributions.rst | 8 ++++++++ numpyro/distributions/transforms.py | 31 ++++++++++++++++++++++++++++- test/test_transforms.py | 5 ++++- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index e02da4d00..dc134d90e 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -910,6 +910,14 @@ CholeskyTransform :show-inheritance: :member-order: bysource +ComplexTransform +^^^^^^^^^^^^^^^^ +.. autoclass:: numpyro.distributions.transforms.ComplexTransform + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + ComposeTransform ^^^^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.transforms.ComposeTransform diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index ad60589bb..0cbb0aa4b 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -30,6 +30,7 @@ "AbsTransform", "AffineTransform", "CholeskyTransform", + "ComplexTransform", "ComposeTransform", "CorrCholeskyTransform", "CorrMatrixCholeskyTransform", @@ -1516,6 +1517,34 @@ def __eq__(self, other): ) +class ComplexTransform(ParameterFreeTransform): + """ + Transforms a pair of real numbers to a complex number. + """ + + domain = constraints.real_vector + codomain = constraints.complex + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2." + return lax.complex(x[..., 0], x[..., 1]) + + def _inverse(self, y: jnp.ndarray) -> jnp.ndarray: + return jnp.stack([y.real, y.imag], axis=-1) + + def log_abs_det_jacobian( + self, x: jnp.ndarray, y: jnp.ndarray, intermediates=None + ) -> jnp.ndarray: + return jnp.zeros_like(y) + + def forward_shape(self, shape: tuple[int]) -> tuple[int]: + assert shape[-1] == 2, "Input must have a trailing dimension of size 2." + return shape[:-1] + + def inverse_shape(self, shape: tuple[int]) -> tuple[int]: + return shape + (2,) + + ########################################################## # CONSTRAINT_REGISTRY ########################################################## @@ -1649,7 +1678,7 @@ def _transform_to_positive_ordered_vector(constraint): @biject_to.register(constraints.complex) def _transform_to_complex(constraint): - return IdentityTransform() + return ComplexTransform() @biject_to.register(constraints.real) diff --git a/test/test_transforms.py b/test/test_transforms.py index beff83b8c..4a1dc3a42 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -21,6 +21,7 @@ AbsTransform, AffineTransform, CholeskyTransform, + ComplexTransform, ComposeTransform, CorrCholeskyTransform, CorrMatrixCholeskyTransform, @@ -105,6 +106,7 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): # unparametrized transforms "abs": T(AbsTransform, (), dict()), "cholesky": T(CholeskyTransform, (), dict()), + "complex": T(ComplexTransform, (), dict()), "corr_chol": T(CorrCholeskyTransform, (), dict()), "corr_matrix_chol": T(CorrMatrixCholeskyTransform, (), dict()), "exp": T(ExpTransform, (), dict()), @@ -270,6 +272,7 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims): [ (AffineTransform(3, 2.5), ()), (CholeskyTransform(), (10,)), + (ComplexTransform(), (2,)), (ComposeTransform([SoftplusTransform(), SigmoidTransform()]), ()), (CorrCholeskyTransform(), (15,)), (CorrMatrixCholeskyTransform(), (15,)), @@ -361,7 +364,7 @@ def test_batched_recursive_linear_transform(): "constraint, shape", [ (constraints.circular, (3,)), - (constraints.complex, (3,)), + (constraints.complex, (3, 2)), (constraints.corr_cholesky, (10, 10)), (constraints.corr_matrix, (15,)), (constraints.greater_than(3), ()),