Skip to content

Commit

Permalink
Added point parameter to rand call
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz authored and rpgoldman committed Aug 16, 2019
1 parent 8f74ea9 commit 16a1d76
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,15 @@ def __init__(

def random(self, point=None, size=None, **kwargs):
if self.rand is not None:
not_broadcast_kwargs = dict(point=point)
not_broadcast_kwargs.update(**kwargs)
if self.wrap_random_with_dist_shape:
size = to_tuple(size)
with _DrawValuesContextBlocker():
test_draw = generate_samples(
self.rand,
size=None,
not_broadcast_kwargs=kwargs,
not_broadcast_kwargs=not_broadcast_kwargs,
)
test_shape = test_draw.shape
if self.shape[:len(size)] == size:
Expand All @@ -406,10 +408,10 @@ def random(self, point=None, size=None, **kwargs):
self.rand,
broadcast_shape=broadcast_shape,
size=size,
not_broadcast_kwargs=kwargs,
not_broadcast_kwargs=not_broadcast_kwargs,
)
else:
samples = self.rand(size=size, **kwargs)
samples = self.rand(point=point, size=size, **kwargs)
if self.check_shape_in_random:
expected_shape = (
self.shape
Expand Down

0 comments on commit 16a1d76

Please sign in to comment.