Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup logcdf tests #4734

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 49 additions & 34 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
logpt,
logpt_sum,
)
from pymc3.math import kronecker, logsumexp
from pymc3.math import kronecker
from pymc3.model import Deterministic, Model, Point
from pymc3.tests.helpers import select_by_precision
from pymc3.vartypes import continuous_types
Expand Down Expand Up @@ -750,6 +750,10 @@ def check_logcdf(
if not skip_paramdomain_inside_edge_test:
domains = paramdomains.copy()
domains["value"] = domain

model, param_vars = build_model(pymc3_dist, domain, paramdomains)
pymc3_logcdf = model.fastfn(logpt(model["value"], cdf=True))

if decimal is None:
decimal = select_by_precision(float64=6, float32=3)

Expand All @@ -758,17 +762,23 @@ def check_logcdf(
if skip_params_fn(params):
continue
scipy_cdf = scipy_logcdf(**params)

scipy_eval = scipy_logcdf(**params)
value = params.pop("value")
with Model() as m:
dist = pymc3_dist("y", **params)

# Update shared parameter variables in pymc3_logcdf function
for param_name, param_value in params.items():
param_vars[param_name].set_value(param_value)

pymc3_eval = pymc3_logcdf({"value": value})

params["value"] = value # for displaying in err_msg
with aesara.config.change_flags(on_opt_error="raise", mode=Mode("py")):
assert_almost_equal(
logcdf(dist, value).eval(),
scipy_cdf,
decimal=decimal,
err_msg=str(params),
)
assert_almost_equal(
pymc3_eval,
scipy_eval,
decimal=decimal,
err_msg=str(params),
)

valid_value = domain.vals[0]
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
Expand Down Expand Up @@ -848,24 +858,33 @@ def check_selfconsistency_discrete_logcdf(
"""
Check that the logcdf of a discrete distribution at value matches the logsumexp of its logps over all integers from the domain's lower bound up to and including value
"""
# This test only works for scalar random variables
assert distribution.rv_op.ndim_supp == 0

domains = paramdomains.copy()
domains["value"] = domain
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)

model, param_vars = build_model(distribution, domain, paramdomains)
dist_logcdf = model.fastfn(logpt(model["value"], cdf=True))
dist_logp = model.fastfn(logpt(model["value"]))

for pt in product(domains, n_samples=n_samples):
params = dict(pt)
if skip_params_fn(params):
continue
value = params.pop("value")
values = np.arange(domain.lower, value + 1)
dist = distribution.dist(**params)
# This only works for scalar random variables
assert dist.owner.op.ndim_supp == 0
values_dist = change_rv_size(dist, values.shape)

# Update shared parameter variables in logp/logcdf function
for param_name, param_value in params.items():
param_vars[param_name].set_value(param_value)

with aesara.config.change_flags(mode=Mode("py")):
assert_almost_equal(
logcdf(dist, value).eval(),
logsumexp(logpt(values_dist, values), keepdims=False).eval(),
dist_logcdf({"value": value}),
scipy.special.logsumexp([dist_logp({"value": value}) for value in values]),
decimal=decimal,
err_msg=str(pt),
)
Expand Down Expand Up @@ -1118,14 +1137,18 @@ def test_beta(self):
{"alpha": Rplus, "beta": Rplus},
lambda value, alpha, beta: sp.beta.logpdf(value, alpha, beta),
)
self.check_logp(Beta, Unit, {"mu": Unit, "sigma": Rplus}, beta_mu_sigma)
self.check_logp(
Beta,
Unit,
{"mu": Unit, "sigma": Rplus},
beta_mu_sigma,
)
self.check_logcdf(
Beta,
Unit,
{"alpha": Rplus, "beta": Rplus},
lambda value, alpha, beta: sp.beta.logcdf(value, alpha, beta),
n_samples=10,
decimal=select_by_precision(float64=5, float32=3),
decimal=select_by_precision(float64=5, float32=1),
)

def test_kumaraswamy(self):
Expand Down Expand Up @@ -1247,20 +1270,17 @@ def scipy_mu_alpha_logcdf(value, mu, alpha):
Nat,
{"mu": Rplus, "alpha": Rplus},
scipy_mu_alpha_logcdf,
n_samples=5,
)
self.check_logcdf(
NegativeBinomial,
Nat,
{"p": Unit, "n": Rplus},
lambda value, p, n: sp.nbinom.logcdf(value, n, p),
n_samples=5,
)
self.check_selfconsistency_discrete_logcdf(
NegativeBinomial,
Nat,
{"mu": Rplus, "alpha": Rplus},
n_samples=10,
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
Expand Down Expand Up @@ -1319,7 +1339,6 @@ def test_lognormal(self):
Rplus,
{"mu": R, "sigma": Rplusbig},
lambda value, mu, sigma: floatX(sp.lognorm.logpdf(value, sigma, 0, np.exp(mu))),
n_samples=5, # Just testing alternative parametrization
)
self.check_logcdf(
Lognormal,
Expand All @@ -1332,10 +1351,9 @@ def test_lognormal(self):
Rplus,
{"mu": R, "sigma": Rplusbig},
lambda value, mu, sigma: sp.lognorm.logcdf(value, sigma, 0, np.exp(mu)),
n_samples=5, # Just testing alternative parametrization
)

def test_t(self):
def test_studentt_logp(self):
self.check_logp(
StudentT,
R,
Expand All @@ -1347,21 +1365,24 @@ def test_t(self):
R,
{"nu": Rplus, "mu": R, "sigma": Rplus},
lambda value, nu, mu, sigma: sp.t.logpdf(value, nu, mu, sigma),
n_samples=5, # Just testing alternative parametrization
)
self.check_logcdf(
StudentT,
R,
{"nu": Rplus, "mu": R, "lam": Rplus},
lambda value, nu, mu, lam: sp.t.logcdf(value, nu, mu, lam ** -0.5),
n_samples=10, # relies on slow incomplete beta
)

@pytest.mark.xfail(
condition=(aesara.config.floatX == "float32"),
reason="Fails on float32 due to numerical issues",
)
def test_studentt_logcdf(self):
self.check_logcdf(
StudentT,
R,
{"nu": Rplus, "mu": R, "sigma": Rplus},
lambda value, nu, mu, sigma: sp.t.logcdf(value, nu, mu, sigma),
n_samples=5, # Just testing alternative parametrization
)

def test_cauchy(self):
Expand Down Expand Up @@ -1538,13 +1559,11 @@ def test_binomial(self):
Nat,
{"n": NatSmall, "p": Unit},
lambda value, n, p: sp.binom.logcdf(value, n, p),
n_samples=10,
)
self.check_selfconsistency_discrete_logcdf(
Binomial,
Nat,
{"n": NatSmall, "p": Unit},
n_samples=10,
)

@pytest.mark.xfail(reason="checkd tests has not been refactored")
Expand Down Expand Up @@ -1747,14 +1766,12 @@ def logcdf_fn(value, psi, mu, alpha):
Nat,
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
logcdf_fn,
n_samples=10,
)

self.check_selfconsistency_discrete_logcdf(
ZeroInflatedNegativeBinomial,
Nat,
{"psi": Unit, "mu": Rplusbig, "alpha": Rplusbig},
n_samples=10,
)

@pytest.mark.xfail(reason="Test not refactored yet")
Expand Down Expand Up @@ -1787,14 +1804,12 @@ def logcdf_fn(value, psi, n, p):
Nat,
{"psi": Unit, "n": NatSmall, "p": Unit},
logcdf_fn,
n_samples=10,
)

self.check_selfconsistency_discrete_logcdf(
ZeroInflatedBinomial,
Nat,
{"n": NatSmall, "p": Unit, "psi": Unit},
n_samples=10,
)

@pytest.mark.parametrize("n", [1, 2, 3])
Expand Down