diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 516dda542c1..8372df9ea9f 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2334,23 +2334,19 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, *args, **kwar alpha = at.as_tensor_variable(floatX(alpha)) beta = at.as_tensor_variable(floatX(beta)) - # m = beta / (alpha - 1.0) - # try: - # mean = (alpha > 1) * m or np.inf - # except ValueError: # alpha is an array - # m[alpha <= 1] = np.inf - # mean = m - - # mode = beta / (alpha + 1.0) - # variance = at.switch( - # at.gt(alpha, 2), (beta ** 2) / ((alpha - 2) * (alpha - 1.0) ** 2), np.inf - # ) - assert_negative_support(alpha, "alpha", "InverseGamma") assert_negative_support(beta, "beta", "InverseGamma") return super().dist([alpha, beta], **kwargs) + def get_moment(rv, size, alpha, beta): + mean = beta / (alpha - 1.0) + mode = beta / (alpha + 1.0) + moment = at.switch(alpha > 1, mean, mode) + if not rv_size_is_none(size): + moment = at.full(size, moment) + return moment + @classmethod def _get_alpha_beta(cls, alpha, beta, mu, sigma): if alpha is not None: diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 777be4a8bce..e700985a31c 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -31,6 +31,7 @@ HalfNormal, HalfStudentT, HyperGeometric, + InverseGamma, Kumaraswamy, Laplace, Logistic, @@ -396,6 +397,21 @@ def test_gamma_moment(alpha, beta, size, expected): assert_moment_is_expected(model, expected) +@pytest.mark.parametrize( + "alpha, beta, size, expected", + [ + (5, 1, None, 1 / 4), + (0.5, 1, None, 1 / 1.5), + (5, 1, 5, np.full(5, 1 / (5 - 1))), + (np.arange(1, 6), 1, None, np.array([0.5, 1, 1 / 2, 1 / 3, 1 / 4])), + ], +) +def test_inverse_gamma_moment(alpha, beta, size, expected): + with Model() as model: + InverseGamma("x", alpha=alpha, beta=beta, size=size) + assert_moment_is_expected(model, expected) + + @pytest.mark.parametrize( "alpha, m, size, expected", [