diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 63507b66..57882776 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -33,6 +33,7 @@ Distributions
    GeneralizedPoisson
    GenExtreme
    R2D2M2CP
+   Skellam
    histogram_approximation
 
 
diff --git a/pymc_experimental/distributions/__init__.py b/pymc_experimental/distributions/__init__.py
index 69592ca2..2a89e7d0 100644
--- a/pymc_experimental/distributions/__init__.py
+++ b/pymc_experimental/distributions/__init__.py
@@ -18,7 +18,7 @@
 """
 
 from pymc_experimental.distributions.continuous import Chi, GenExtreme
-from pymc_experimental.distributions.discrete import GeneralizedPoisson
+from pymc_experimental.distributions.discrete import GeneralizedPoisson, Skellam
 from pymc_experimental.distributions.histogram_utils import histogram_approximation
 from pymc_experimental.distributions.multivariate import R2D2M2CP
 from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
diff --git a/pymc_experimental/distributions/discrete.py b/pymc_experimental/distributions/discrete.py
index 9d8cb649..fc9ddee0 100644
--- a/pymc_experimental/distributions/discrete.py
+++ b/pymc_experimental/distributions/discrete.py
@@ -171,3 +171,98 @@ def logp(value, mu, lam):
             (-mu / 4) <= lam,
             msg="0 < mu, max(-1, -mu/4)) <= lam <= 1",
         )
+
+
+class Skellam:
+    R"""
+    Skellam distribution.
+
+    The Skellam distribution is the distribution of the difference of two
+    independent Poisson random variables.
+
+    The pmf of this distribution is
+
+    .. math::
+
+        f(x \mid \mu_1, \mu_2) = e^{-(\mu_1 + \mu_2)} \left( \frac{\mu_1}{\mu_2} \right)^{x/2} I_x(2 \sqrt{\mu_1 \mu_2})
+
+    where :math:`I_x` is the modified Bessel function of the first kind of order :math:`x`.
+
+    Read more about the Skellam distribution at https://en.wikipedia.org/wiki/Skellam_distribution
+
+    .. plot::
+        :context: close-figs
+
+        import matplotlib.pyplot as plt
+        import numpy as np
+        import scipy.stats as st
+        import arviz as az
+        plt.style.use('arviz-darkgrid')
+        x = np.arange(-15, 15)
+        params = [
+            (1, 1),
+            (5, 5),
+            (5, 1),
+        ]
+        for mu1, mu2 in params:
+            pmf = st.skellam.pmf(x, mu1, mu2)
+            plt.plot(x, pmf, "-o", label=r'$\mu_1$ = {}, $\mu_2$ = {}'.format(mu1, mu2))
+        plt.xlabel('x', fontsize=12)
+        plt.ylabel('f(x)', fontsize=12)
+        plt.legend(loc=1)
+        plt.show()
+
+    ========  ======================================
+    Support   :math:`x \in \mathbb{Z}`
+    Mean      :math:`\mu_{1} - \mu_{2}`
+    Variance  :math:`\mu_{1} + \mu_{2}`
+    ========  ======================================
+
+    Parameters
+    ----------
+    mu1 : tensor_like of float
+        Mean parameter of the first Poisson distribution (mu1 >= 0).
+    mu2 : tensor_like of float
+        Mean parameter of the second Poisson distribution (mu2 >= 0).
+ """ + + @staticmethod + def skellam_dist(mu1, mu2, size): + return pm.Poisson.dist(mu=mu1, size=size) - pm.Poisson.dist(mu=mu2, size=size) + + @staticmethod + def skellam_logp(value, mu1, mu2): + res = ( + -mu1 + - mu2 + + 0.5 * value * (pt.log(mu1) - pt.log(mu2)) + + pt.log(pt.iv(value, 2 * pt.sqrt(mu1 * mu2))) + ) + return check_parameters( + res, + mu1 >= 0, + mu2 >= 0, + msg="mu1 >= 0, mu2 >= 0", + ) + + def __new__(cls, name, mu1, mu2, **kwargs): + return pm.CustomDist( + name, + mu1, + mu2, + dist=cls.skellam_dist, + logp=cls.skellam_logp, + class_name="Skellam", + **kwargs, + ) + + @classmethod + def dist(cls, mu1, mu2, **kwargs): + return pm.CustomDist.dist( + mu1, + mu2, + dist=cls.skellam_dist, + logp=cls.skellam_logp, + class_name="Skellam", + **kwargs, + ) diff --git a/pymc_experimental/tests/distributions/test_discrete.py b/pymc_experimental/tests/distributions/test_discrete.py index eb0a1d77..c94e5f36 100644 --- a/pymc_experimental/tests/distributions/test_discrete.py +++ b/pymc_experimental/tests/distributions/test_discrete.py @@ -21,13 +21,15 @@ from pymc.testing import ( BaseTestDistributionRandom, Domain, + I, Rplus, assert_moment_is_expected, + check_logp, discrete_random_tester, ) from pytensor import config -from pymc_experimental.distributions import GeneralizedPoisson +from pymc_experimental.distributions import GeneralizedPoisson, Skellam class TestGeneralizedPoisson: @@ -118,3 +120,13 @@ def test_moment(self, mu, lam, size, expected): with pm.Model() as model: GeneralizedPoisson("x", mu=mu, lam=lam, size=size) assert_moment_is_expected(model, expected) + + +class TestSkellam: + def test_logp(self): + check_logp( + Skellam, + I, + {"mu1": Rplus, "mu2": Rplus}, + lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2), + )