diff --git a/Makefile b/Makefile index 642351434..f9cce9207 100644 --- a/Makefile +++ b/Makefile @@ -10,8 +10,8 @@ license: FORCE python scripts/update_headers.py format: license FORCE - ruff check --fix . ruff format . + ruff check --fix . install: FORCE pip install -e '.[dev,doc,test,examples]' diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 988389ef3..db3b43084 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2426,7 +2426,10 @@ def entropy(self): class Uniform(Distribution): - arg_constraints = {"low": constraints.dependent, "high": constraints.dependent} + arg_constraints = { + "low": constraints.dependent(is_discrete=False, event_dim=0), + "high": constraints.dependent(is_discrete=False, event_dim=0), + } reparametrized_params = ["low", "high"] pytree_data_fields = ("low", "high", "_support") @@ -2727,7 +2730,7 @@ class Wishart(TransformedDistribution): """ arg_constraints = { - "concentration": constraints.dependent(is_discrete=False), + "concentration": constraints.dependent(is_discrete=False, event_dim=0), "scale_matrix": constraints.positive_definite, "rate_matrix": constraints.positive_definite, "scale_tril": constraints.lower_cholesky, @@ -2820,7 +2823,7 @@ class WishartCholesky(Distribution): """ arg_constraints = { - "concentration": constraints.dependent(is_discrete=False), + "concentration": constraints.dependent(is_discrete=False, event_dim=0), "scale_matrix": constraints.positive_definite, "rate_matrix": constraints.positive_definite, "scale_tril": constraints.lower_cholesky, diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 4f0fea56a..de0eaf0ab 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -421,7 +421,10 @@ def Categorical(probs=None, logits=None, *, validate_args=None): class DiscreteUniform(Distribution): - arg_constraints = {"low": constraints.dependent, "high": constraints.dependent} + arg_constraints = { + "low": constraints.dependent(is_discrete=True, event_dim=0), + "high": constraints.dependent(is_discrete=True, event_dim=0), + } has_enumerate_support = True pytree_data_fields = ("low", "high", "_support") diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index d9b86a465..5cfa130c1 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -179,8 +179,8 @@ def var(self): class TwoSidedTruncatedDistribution(Distribution): arg_constraints = { - "low": constraints.dependent, - "high": constraints.dependent, + "low": constraints.dependent(is_discrete=False, event_dim=0), + "high": constraints.dependent(is_discrete=False, event_dim=0), } reparametrized_params = ["low", "high"] supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT) diff --git a/test/test_distributions.py b/test/test_distributions.py index 003c20b9c..fe214f2f0 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1327,6 +1327,17 @@ def test_has_rsample(jax_dist, sp_dist, params): transf_dist.rsample(random.PRNGKey(0)) +@pytest.mark.parametrize( + "jax_dist_cls, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL +) +def test_args_attributes(jax_dist_cls, sp_dist, params): + jax_dist = jax_dist_cls(*params) + for constraint in jax_dist.arg_constraints.values(): + if jax_dist_cls != dist.Delta: + constraint.event_dim + constraint.is_discrete + + @pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)]) def test_unit(batch_shape): log_factor = random.normal(random.PRNGKey(0), batch_shape)