Skip to content

Commit

Permalink
Fix for #3310 and some more (#3319)
Browse files Browse the repository at this point in the history
* Fixed #3310. Added broadcast_distribution_samples, which helps broadcasting multiple rvs calls with different size and distribution parameter shapes. Added shape guards to other continuous distributions.

* Fixed broken continuous distributions. Did not notice that _random got a parameter shape aware size input thanks to generate_samples.

* Fixed lint error

* Addressed comments
  • Loading branch information
lucianopaz authored and twiecki committed Dec 22, 2018
1 parent 01f2444 commit e67c476
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 5 deletions.
3 changes: 3 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

### Maintenance

- Added the `broadcast_distribution_samples` function that helps broadcasting arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This sometimes is needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (Fix issue #3310).
- The `Wald`, `Kumaraswamy`, `LogNormal`, `Pareto`, `Cauchy`, `HalfCauchy`, `Weibull` and `ExGaussian` distributions `random` method used a hidden `_random` function that was written with scalars in mind. This could potentially lead to artificial correlations between random draws. Added shape guards and broadcasting of the distribution samples to prevent this (Similar to issue #3310).

### Deprecations

## PyMC3 3.6 (Dec 21 2018)
Expand Down
12 changes: 11 additions & 1 deletion pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
alltrue_elemwise, betaln, bound, gammaln, i0e, incomplete_beta, logpow,
normal_lccdf, normal_lcdf, SplineWrapper, std_cdf, zvalue,
)
from .distribution import Continuous, draw_values, generate_samples
from .distribution import (Continuous, draw_values, generate_samples,
broadcast_distribution_samples)

__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'TruncatedNormal', 'Beta',
'Kumaraswamy', 'Exponential', 'Laplace', 'StudentT', 'Cauchy',
Expand Down Expand Up @@ -957,6 +958,8 @@ def random(self, point=None, size=None):
"""
mu, lam, alpha = draw_values([self.mu, self.lam, self.alpha],
point=point, size=size)
mu, lam, alpha = broadcast_distribution_samples([mu, lam, alpha],
size=size)
return generate_samples(self._random,
mu, lam, alpha,
dist_shape=self.shape,
Expand Down Expand Up @@ -1285,6 +1288,7 @@ def random(self, point=None, size=None):
"""
a, b = draw_values([self.a, self.b],
point=point, size=size)
a, b = broadcast_distribution_samples([a, b], size=size)
return generate_samples(self._random, a, b,
dist_shape=self.shape,
size=size)
Expand Down Expand Up @@ -1658,6 +1662,7 @@ def random(self, point=None, size=None):
array
"""
mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
mu, tau = broadcast_distribution_samples([mu, tau], size=size)
return generate_samples(self._random, mu, tau,
dist_shape=self.shape,
size=size)
Expand Down Expand Up @@ -1945,6 +1950,7 @@ def random(self, point=None, size=None):
"""
alpha, m = draw_values([self.alpha, self.m],
point=point, size=size)
alpha, m = broadcast_distribution_samples([alpha, m], size=size)
return generate_samples(self._random, alpha, m,
dist_shape=self.shape,
size=size)
Expand Down Expand Up @@ -2069,6 +2075,7 @@ def random(self, point=None, size=None):
"""
alpha, beta = draw_values([self.alpha, self.beta],
point=point, size=size)
alpha, beta = broadcast_distribution_samples([alpha, beta], size=size)
return generate_samples(self._random, alpha, beta,
dist_shape=self.shape,
size=size)
Expand Down Expand Up @@ -2629,6 +2636,7 @@ def random(self, point=None, size=None):
"""
alpha, beta = draw_values([self.alpha, self.beta],
point=point, size=size)
alpha, beta = broadcast_distribution_samples([alpha, beta], size=size)

def _random(a, b, size=None):
return b * (-np.log(np.random.uniform(size=size)))**(1 / a)
Expand Down Expand Up @@ -2913,6 +2921,8 @@ def random(self, point=None, size=None):
"""
mu, sigma, nu = draw_values([self.mu, self.sigma, self.nu],
point=point, size=size)
mu, sigma, nu = broadcast_distribution_samples([mu, sigma, nu],
size=size)

def _random(mu, sigma, nu, size=None):
return (np.random.normal(mu, sigma, size=size)
Expand Down
13 changes: 9 additions & 4 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from pymc3.util import get_variable_name
from .dist_math import bound, factln, binomln, betaln, logpow, random_choice
from .distribution import Discrete, draw_values, generate_samples
from .distribution import (Discrete, draw_values, generate_samples,
broadcast_distribution_samples)
from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp


Expand Down Expand Up @@ -345,6 +346,7 @@ def _ppf(self, p):

def _random(self, q, beta, size=None):
p = np.random.uniform(size=size)
p, q, beta = broadcast_distribution_samples([p, q, beta], size=size)

return np.ceil(np.power(np.log(1 - p) / np.log(q), 1. / beta)) - 1

Expand Down Expand Up @@ -847,7 +849,8 @@ def random(self, point=None, size=None):
g = generate_samples(stats.poisson.rvs, theta,
dist_shape=self.shape,
size=size)
return g * (np.random.random(np.squeeze(g.shape)) < psi)
g, psi = broadcast_distribution_samples([g, psi], size=size)
return g * (np.random.random(g.shape) < psi)

def logp(self, value):
psi = self.psi
Expand Down Expand Up @@ -939,7 +942,8 @@ def random(self, point=None, size=None):
g = generate_samples(stats.binom.rvs, n, p,
dist_shape=self.shape,
size=size)
return g * (np.random.random(np.squeeze(g.shape)) < psi)
g, psi = broadcast_distribution_samples([g, psi], size=size)
return g * (np.random.random(g.shape) < psi)

def logp(self, value):
psi = self.psi
Expand Down Expand Up @@ -1057,7 +1061,8 @@ def random(self, point=None, size=None):
dist_shape=self.shape,
size=size)
g[g == 0] = np.finfo(float).eps # Just in case
return stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
g, psi = broadcast_distribution_samples([g, psi], size=size)
return stats.poisson.rvs(g) * (np.random.random(g.shape) < psi)

def logp(self, value):
alpha = self.alpha
Expand Down
27 changes: 27 additions & 0 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,30 @@ def generate_samples(generator, *args, **kwargs):
if one_d and samples.shape[-1] == 1:
samples = samples.reshape(samples.shape[:-1])
return np.asarray(samples)


def broadcast_distribution_samples(samples, size=None):
if size is None:
return np.broadcast_arrays(*samples)
_size = to_tuple(size)
try:
broadcasted_samples = np.broadcast_arrays(*samples)
except ValueError:
# Raw samples shapes
p_shapes = [p.shape for p in samples]
# samples shapes without the size prepend
sp_shapes = [s[len(_size):] if _size == s[:len(_size)] else s
for s in p_shapes]
broadcast_shape = np.broadcast(*[np.empty(s) for s in sp_shapes]).shape
broadcasted_samples = []
for param, p_shape, sp_shape in zip(samples, p_shapes, sp_shapes):
if _size == p_shape[:len(_size)]:
slicer_head = [slice(None)] * len(_size)
else:
slicer_head = [np.newaxis] * len(_size)
slicer_tail = ([np.newaxis] * (len(broadcast_shape) -
len(sp_shape)) +
[slice(None)] * len(sp_shape))
broadcasted_samples.append(param[tuple(slicer_head + slicer_tail)])
broadcasted_samples = np.broadcast_arrays(*broadcasted_samples)
return broadcasted_samples
10 changes: 10 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,13 @@ def test_shape_edgecase(self):
x = pm.Normal('x', mu=mu, sd=sd, shape=5)
prior = pm.sample_prior_predictive(10)
assert prior['mu'].shape == (10, 5)

def test_zeroinflatedpoisson(self):
with pm.Model():
theta = pm.Beta('theta', alpha=1, beta=1)
psi = pm.HalfNormal('psi', sd=1)
pm.ZeroInflatedPoisson('suppliers', psi=psi, theta=theta, shape=20)
gen_data = pm.sample_prior_predictive(samples=5000)
assert gen_data['theta'].shape == (5000,)
assert gen_data['psi'].shape == (5000,)
assert gen_data['suppliers'].shape == (5000, 20)

0 comments on commit e67c476

Please sign in to comment.