diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst
index 67ae726a2..06b51c929 100644
--- a/docs/source/distributions.rst
+++ b/docs/source/distributions.rst
@@ -380,6 +380,13 @@ Weibull
     :show-inheritance:
     :member-order: bysource
 
+ZeroSumNormal
+^^^^^^^^^^^^^
+.. autoclass:: numpyro.distributions.continuous.ZeroSumNormal
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
 
 Discrete Distributions
 ----------------------
@@ -820,6 +827,9 @@ unit_interval
 ^^^^^^^^^^^^^
 .. autodata:: numpyro.distributions.constraints.unit_interval
 
+zero_sum
+^^^^^^^^
+.. autodata:: numpyro.distributions.constraints.zero_sum
 
 Transforms
 ----------
@@ -1014,6 +1024,15 @@ StickBreakingTransform
     :show-inheritance:
     :member-order: bysource
 
+ZeroSumTransform
+^^^^^^^^^^^^^^^^
+
+.. autoclass:: numpyro.distributions.transforms.ZeroSumTransform
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
 Flows
 -----
 
diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py
index 49554097e..d05376573 100644
--- a/numpyro/distributions/__init__.py
+++ b/numpyro/distributions/__init__.py
@@ -47,6 +47,7 @@
     StudentT,
     Uniform,
     Weibull,
+    ZeroSumNormal,
 )
 from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta
 from numpyro.distributions.directional import (
@@ -196,4 +197,5 @@
     "ZeroInflatedDistribution",
     "ZeroInflatedPoisson",
     "ZeroInflatedNegativeBinomial2",
+    "ZeroSumNormal",
 ]
diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py
index 7b8b59340..21dac6669 100644
--- a/numpyro/distributions/constraints.py
+++ b/numpyro/distributions/constraints.py
@@ -55,6 +55,7 @@
     "softplus_lower_cholesky",
     "softplus_positive",
     "unit_interval",
+    "zero_sum",
     "Constraint",
 ]
 
@@ -697,6 +698,37 @@ def feasible_like(self, prototype):
         return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5))
 
 
+class _ZeroSum(Constraint):
+    """Constraint that each of the trailing ``event_dim`` axes sums to zero."""
+
+    def __init__(self, event_dim=1):
+        self.event_dim = event_dim
+        super().__init__()
+
+    def __call__(self, x):
+        jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
+        tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10
+        event_axes = tuple(range(-self.event_dim, 0))
+        zerosum_true = True
+        for dim in event_axes:
+            # keepdims lets the per-axis checks broadcast against each other
+            axis_sum = x.sum(dim, keepdims=True)
+            zerosum_true = zerosum_true & (jnp.abs(axis_sum) <= tol)
+        # reduce over the event axes only: one flag per batch element
+        return jnp.all(zerosum_true, axis=event_axes)
+
+    def __eq__(self, other):
+        return type(self) is type(other) and self.event_dim == other.event_dim
+
+    def feasible_like(self, prototype):
+        return jax.numpy.zeros_like(prototype)
+
+    def tree_flatten(self):
+        # event_dim drives Python-level loops, so it is carried as static aux
+        # data instead of a leaf that jit/vmap would turn into a tracer
+        return (), ((), {"event_dim": self.event_dim})
+
+
 # TODO: Make types consistent
 # See https://github.com/pytorch/pytorch/issues/50616
 
@@ -731,3 +763,4 @@ def feasible_like(self, prototype):
 sphere = _Sphere()
 unit_interval = _UnitInterval()
 open_interval = _OpenInterval
+zero_sum = _ZeroSum
diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 861eddf1b..0e895c827 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -58,6 +58,7 @@
     ExpTransform,
     PowerTransform,
     SigmoidTransform,
+    ZeroSumTransform,
 )
 from numpyro.distributions.util import (
     add_diag,
@@ -2438,3 +2439,99 @@ def cdf(self, value):
 
     def icdf(self, value):
         return self._ald.icdf(value)
+
+
+class ZeroSumNormal(TransformedDistribution):
+    r"""
+    Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or
+    more trailing axes (those listed in ``event_shape``) are constrained to sum to zero.
+
+    .. math::
+        \begin{align*}
+        ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\
+        \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
+        n = \text{size of the zero-sum axis}
+        \end{align*}
+
+    :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is
+        enforced.
+    :param tuple event_shape: The event shape of the distribution, the axes of which get constrained to sum to zero.
+
+    **Example:**
+
+    .. doctest::
+
+        >>> from numpy.testing import assert_allclose
+        >>> from jax import random
+        >>> import jax.numpy as jnp
+        >>> import numpyro
+        >>> import numpyro.distributions as dist
+        >>> from numpyro.infer import MCMC, NUTS
+
+        >>> N = 1000
+        >>> n_categories = 20
+        >>> rng_key = random.PRNGKey(0)
+        >>> key1, key2, key3 = random.split(rng_key, 3)
+        >>> category_ind = random.choice(key1, jnp.arange(n_categories), shape=(N,))
+        >>> beta = random.normal(key2, shape=(n_categories,))
+        >>> beta -= beta.mean(-1)
+        >>> y = 5 + beta[category_ind] + random.normal(key3, shape=(N,))
+
+        >>> def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories
+        ...     N = len(category_ind)
+        ...     alpha = numpyro.sample("alpha", dist.Normal(0, 2.5))
+        ...     beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(n_categories,)))
+        ...     sigma =  numpyro.sample("sigma", dist.Exponential(1))
+        ...     with numpyro.plate("observations", N):
+        ...         mu = alpha + beta[category_ind]
+        ...         obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
+        ...     return obs
+
+        >>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9)
+        >>> mcmc = MCMC(
+        ...     sampler=nuts_kernel,
+        ...     num_samples=1_000, num_warmup=1_000, num_chains=4
+        ... )
+        >>> mcmc.run(random.PRNGKey(0), category_ind=category_ind, y=y)
+        >>> posterior_samples = mcmc.get_samples()
+        >>> # Confirm everything along last axis sums to zero
+        >>> assert_allclose(posterior_samples['beta'].sum(-1), 0, atol=1e-3)
+
+    **References**
+    [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637
+    [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
+    [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
+    """
+
+    arg_constraints = {"scale": constraints.positive}
+    reparametrized_params = ["scale"]
+
+    def __init__(self, scale, event_shape, *, validate_args=None):
+        event_ndim = len(event_shape)
+        # the unconstrained base distribution has one entry fewer along every
+        # zero-sum axis; ZeroSumTransform adds it back
+        transformed_shape = tuple(size - 1 for size in event_shape)
+        self.scale = scale
+        super().__init__(
+            Normal(0, scale).expand(transformed_shape).to_event(event_ndim),
+            ZeroSumTransform(event_ndim),
+            validate_args=validate_args,
+        )
+
+    @constraints.dependent_property(is_discrete=False)
+    def support(self):
+        return constraints.zero_sum(len(self.event_shape))
+
+    @property
+    def mean(self):
+        return jnp.zeros(self.batch_shape + self.event_shape)
+
+    @property
+    def variance(self):
+        # every zero-sum axis of size n scales the marginal variance by
+        # a factor of (1 - 1/n)
+        theoretical_var = jnp.square(self.scale)
+        for size in self.event_shape:
+            theoretical_var = theoretical_var * (1 - 1 / size)
+
+        return jnp.broadcast_to(theoretical_var, self.batch_shape + self.event_shape)
diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py
index 3b29575f0..a057d86d2 100644
--- a/numpyro/distributions/transforms.py
+++ b/numpyro/distributions/transforms.py
@@ -50,6 +50,7 @@
     "StickBreakingTransform",
     "Transform",
     "UnpackTransform",
+    "ZeroSumTransform",
 ]
 
 
@@ -1380,6 +1381,100 @@ def __eq__(self, other):
         return jnp.array_equal(self.transition_matrix, other.transition_matrix)
 
 
+class ZeroSumTransform(Transform):
+    """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]
+
+    :param transform_ndims: Number of trailing dimensions to transform (at least 1).
+
+    **References**
+    [1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266
+    [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
+    [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
+    """
+
+    def __init__(self, transform_ndims: int = 1) -> None:
+        if transform_ndims < 1:
+            # with 0, `shape[:-transform_ndims]` below would become `shape[:0]`
+            raise ValueError("transform_ndims must be a positive integer.")
+        self.transform_ndims = transform_ndims
+
+    @property
+    def domain(self) -> constraints.Constraint:
+        return constraints.independent(constraints.real, self.transform_ndims)
+
+    @property
+    def codomain(self) -> constraints.Constraint:
+        return constraints.zero_sum(self.transform_ndims)
+
+    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        zero_sum_axes = tuple(range(-self.transform_ndims, 0))
+        for axis in zero_sum_axes:
+            x = self.extend_axis(x, axis=axis)
+        return x
+
+    def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
+        zero_sum_axes = tuple(range(-self.transform_ndims, 0))
+        for axis in zero_sum_axes:
+            y = self.extend_axis_rev(y, axis=axis)
+        return y
+
+    def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
+        """Drop the last entry along ``axis``, undoing :meth:`extend_axis`."""
+        # range(...)[axis] maps a negative axis to its non-negative position and
+        # raises IndexError when it is out of bounds, without private NumPy helpers
+        normalized_axis = range(array.ndim)[axis]
+
+        n = array.shape[normalized_axis]
+        last = jnp.take(array, jnp.array([-1]), axis=normalized_axis)
+
+        sum_vals = -last * jnp.sqrt(n)
+        norm = sum_vals / (jnp.sqrt(n) + n)
+        slice_before = (slice(None, None),) * normalized_axis
+        return array[(*slice_before, slice(None, -1))] + norm
+
+    def extend_axis(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
+        """Append one entry along ``axis`` so that the result sums to zero there."""
+        n = array.shape[axis] + 1
+
+        sum_vals = array.sum(axis, keepdims=True)
+        norm = sum_vals / (jnp.sqrt(n) + n)
+        fill_val = norm - sum_vals / jnp.sqrt(n)
+
+        out = jnp.concatenate([array, fill_val], axis=axis)
+        return out - norm
+
+    def log_abs_det_jacobian(
+        self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
+    ) -> jnp.ndarray:
+        # always zero; only the broadcast batch shape has to be worked out
+        shape = jnp.broadcast_shapes(
+            x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
+        )
+        return jnp.zeros_like(x, shape=shape)
+
+    def forward_shape(self, shape: tuple) -> tuple:
+        return shape[: -self.transform_ndims] + tuple(
+            s + 1 for s in shape[-self.transform_ndims :]
+        )
+
+    def inverse_shape(self, shape: tuple) -> tuple:
+        return shape[: -self.transform_ndims] + tuple(
+            s - 1 for s in shape[-self.transform_ndims :]
+        )
+
+    def tree_flatten(self):
+        aux_data = {
+            "transform_ndims": self.transform_ndims,
+        }
+        return (), ((), aux_data)
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, ZeroSumTransform)
+            and self.transform_ndims == other.transform_ndims
+        )
+
+
 ##########################################################
 # CONSTRAINT_REGISTRY
 ##########################################################
@@ -1530,3 +1625,8 @@ def _transform_to_softplus_lower_cholesky(constraint):
 @biject_to.register(constraints.simplex)
 def _transform_to_simplex(constraint):
     return StickBreakingTransform()
+
+
+@biject_to.register(constraints.zero_sum)
+def _transform_to_zero_sum(constraint):
+    return ZeroSumTransform(constraint.event_dim)
diff --git a/test/test_constraints.py b/test/test_constraints.py
index 735969fa6..fb34bf6b8 100644
--- a/test/test_constraints.py
+++ b/test/test_constraints.py
@@ -62,6 +62,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])):
         dict(),
     ),
     "open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()),
+    "zero_sum": T(constraints.zero_sum, (), dict(event_dim=1)),
 }
 
 # TODO: BijectorConstraint
diff --git a/test/test_distributions.py
b/test/test_distributions.py
index 34c37379a..6ebb990f2 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -774,6 +774,9 @@ def get_sp_dist(jax_dist):
     T(dist.Weibull, 0.2, 1.1),
     T(dist.Weibull, 2.8, np.array([2.0, 2.0])),
     T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])),
+    T(dist.ZeroSumNormal, 1.0, (5,)),
+    T(dist.ZeroSumNormal, np.array([2.0]), (5,)),
+    T(dist.ZeroSumNormal, 1.0, (4, 5)),
     T(
         _GaussianMixture,
         np.ones(3) / 3.0,
@@ -1018,6 +1021,12 @@ def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
         sign = random.bernoulli(key1)
         bounds = [0, (-1) ** sign * 0.5]
         return random.uniform(key, size, float, *sorted(bounds))
+    elif isinstance(constraint, constraints.zero_sum):
+        x = random.normal(key, size)
+        for axis in range(-constraint.event_dim, 0):
+            # keepdims: the mean must broadcast back against x when x.ndim > 1
+            x -= x.mean(axis, keepdims=True)
+        return x
     else:
         raise NotImplementedError("{} not implemented.".format(constraint))
 
@@ -1085,6 +1094,9 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
         sign = random.bernoulli(key1)
         bounds = [(-1) ** sign * 1.1, (-1) ** sign * 2]
         return random.uniform(key, size, float, *sorted(bounds))
+    elif isinstance(constraint, constraints.zero_sum):
+        x = random.normal(key, size)
+        return x
     else:
         raise NotImplementedError("{} not implemented.".format(constraint))
 
@@ -1297,6 +1309,7 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params):
         "LKJ",
         "LKJCholesky",
         "_SparseCAR",
+        "ZeroSumNormal",
     ):
         pytest.xfail(reason="non-jittable params")
 
@@ -1442,6 +1455,8 @@ def test_gof(jax_dist, sp_dist, params):
         d = jax_dist(*params)
         if d.event_dim > 1:
             pytest.skip("EulerMaruyama skip test when event shape is non-trivial.")
+    if jax_dist is dist.ZeroSumNormal:
+        pytest.skip("skip gof test for ZeroSumNormal")
 
     num_samples = 10000
     if "BetaProportion" in jax_dist.__name__:
@@ -1672,6 +1687,9 @@ def fn(*args):
         if jax_dist is _SparseCAR and i == 3:
             # skip taking grad w.r.t. adj_matrix
             continue
+        if jax_dist is dist.ZeroSumNormal and i != 0:
+            # skip taking grad w.r.t. event_shape
+            continue
         if isinstance(
             params[i], dist.Distribution
         ):  # skip taking grad w.r.t. base_dist
@@ -1858,7 +1876,7 @@ def get_min_shape(ix, batch_shape):
         if isinstance(d_jax, dist.Gompertz):
             pytest.skip("Gompertz distribution does not have `variance` implemented.")
         if jnp.all(jnp.isfinite(d_jax.variance)):
-            assert_allclose(
+            assert jnp.allclose(
                 jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.05, atol=1e-2
             )
 
@@ -1899,6 +1917,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
             continue
         if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps":
             continue
+        if jax_dist is dist.ZeroSumNormal and dist_args[i] == "event_shape":
+            continue
         if (
             jax_dist is dist.SineBivariateVonMises
             and dist_args[i] == "weighted_correlation"
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 1a706bbc6..261818429 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -40,6 +40,7 @@
     SoftplusTransform,
     StickBreakingTransform,
     UnpackTransform,
+    ZeroSumTransform,
     biject_to,
 )
 
@@ -134,6 +135,7 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
     "reshape": T(
         ReshapeTransform, (), {"forward_shape": (3, 4), "inverse_shape": (4, 3)}
     ),
+    "zero_sum": T(ZeroSumTransform, (), dict(transform_ndims=1)),
 }
 
 
@@ -296,6 +298,7 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims):
         (SoftplusLowerCholeskyTransform(), (10,)),
         (SoftplusTransform(), ()),
         (StickBreakingTransform(), (11,)),
+        (ZeroSumTransform(1), (5,)),
     ],
 )
 def test_bijective_transforms(transform, shape):