Skip to content

Commit

Permalink
Do not create temporary SymbolicDistribution just to retrieve number …
Browse files Browse the repository at this point in the history
…of RNGs needed

Reordered methods for consistency
  • Loading branch information
ricardoV94 committed May 4, 2022
1 parent a3bd083 commit d510e9b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 38 deletions.
32 changes: 16 additions & 16 deletions pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def dist(cls, dist, lower, upper, **kwargs):
check_dist_not_registered(dist)
return super().dist([dist, lower, upper], **kwargs)

@classmethod
def num_rngs(cls, *args, **kwargs):
    """Return the number of RNG variables this distribution needs (always 1).

    A censored distribution wraps exactly one base distribution (see
    ``rv_op``), so a single RNG suffices regardless of the arguments the
    user passed.
    """
    return 1

@classmethod
def ndim_supp(cls, *dist_params):
    """Return the support dimensionality of the distribution (always 0,
    i.e. scalar support), independently of the given parameters."""
    return 0

@classmethod
def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):

Expand All @@ -96,24 +104,12 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
rv_out.tag.upper = upper

if rngs is not None:
rv_out = cls.change_rngs(rv_out, rngs)
rv_out = cls._change_rngs(rv_out, rngs)

return rv_out

@classmethod
def ndim_supp(cls, *dist_params):
return 0

@classmethod
def change_size(cls, rv, new_size, expand=False):
    """Return an equivalent censored RV whose base distribution is resized.

    The wrapped base distribution stored on ``rv.tag`` is resized with
    ``change_rv_size`` and the original censoring bounds are reapplied
    through ``cls.rv_op``.
    """
    resized_dist = change_rv_size(rv.tag.dist, new_size, expand=expand)
    return cls.rv_op(resized_dist, rv.tag.lower, rv.tag.upper)

@classmethod
def change_rngs(cls, rv, new_rngs):
def _change_rngs(cls, rv, new_rngs):
(new_rng,) = new_rngs
dist_node = rv.tag.dist.owner
lower = rv.tag.lower
Expand All @@ -123,8 +119,12 @@ def change_rngs(cls, rv, new_rngs):
return cls.rv_op(new_dist, lower, upper)

@classmethod
def graph_rvs(cls, rv):
    """Return the base RVs contained in this censored RV's graph.

    There is exactly one: the wrapped base distribution stored on
    ``rv.tag``.
    """
    base_dist = rv.tag.dist
    return (base_dist,)
def change_size(cls, rv, new_size, expand=False):
dist = rv.tag.dist
lower = rv.tag.lower
upper = rv.tag.upper
new_dist = change_rv_size(dist, new_size, expand=expand)
return cls.rv_op(new_dist, lower, upper)


@_moment.register(Clip)
Expand Down
19 changes: 10 additions & 9 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,18 +396,19 @@ def __new__(
to a canonical parametrization. It should call `super().dist()`, passing a
list with the default parameters as the first and only non keyword argument,
followed by other keyword arguments like size and rngs, and return the result
cls.num_rngs
Returns the number of rngs given the same arguments passed by the user when
calling the distribution
cls.ndim_supp
Returns the support of the symbolic distribution, given the default set of
parameters. This may not always be constant, for instance if the symbolic
distribution can be defined based on an arbitrary base distribution.
cls.rv_op
Returns a TensorVariable that represents the symbolic distribution
parametrized by a default set of parameters and a size and rngs arguments
cls.ndim_supp
Returns the support of the symbolic distribution, given the default
parameters. This may not always be constant, for instance if the symbolic
distribution can be defined based on an arbitrary base distribution.
cls.change_size
Returns an equivalent symbolic distribution with a different size. This is
analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
cls.graph_rvs
Returns base RVs in a symbolic distribution.
Parameters
----------
Expand Down Expand Up @@ -465,9 +466,9 @@ def __new__(
raise TypeError(f"Name needs to be a string but got: {name}")

if rngs is None:
# Create a temporary rv to obtain number of rngs needed
temp_graph = cls.dist(*args, rngs=None, **kwargs)
rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)]
# Instead of passing individual RNG variables we could pass a RandomStream
# and let the classes create as many RNGs as they need
rngs = [model.next_rng() for _ in range(cls.num_rngs(*args, **kwargs))]
elif not isinstance(rngs, (list, tuple)):
rngs = [rngs]

Expand Down
25 changes: 12 additions & 13 deletions pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,18 @@ def dist(cls, w, comp_dists, **kwargs):
w = at.as_tensor_variable(w)
return super().dist([w, *comp_dists], **kwargs)

@classmethod
def num_rngs(cls, w, comp_dists, **kwargs):
    """Return the number of RNG variables this mixture needs.

    One RNG per component distribution plus one extra (for the mixture
    indexes RV).
    """
    if isinstance(comp_dists, (tuple, list)):
        n_components = len(comp_dists)
    else:
        # A single component distribution was given.
        n_components = 1
    return n_components + 1

@classmethod
def ndim_supp(cls, weights, *components):
    """Return the support dimensionality of the mixture.

    All components were already checked to share the same support
    dimensionality, so the first component is representative.
    """
    representative = components[0]
    return representative.owner.op.ndim_supp

@classmethod
def rv_op(cls, weights, *components, size=None, rngs=None):
# Update rngs if provided
Expand Down Expand Up @@ -329,11 +341,6 @@ def _resize_components(cls, size, *components):

return [change_rv_size(component, size) for component in components]

@classmethod
def ndim_supp(cls, weights, *components):
    """Return the support dimensionality of the mixture distribution."""
    # We already checked that all components have the same support dimensionality
    return components[0].owner.op.ndim_supp

@classmethod
def change_size(cls, rv, new_size, expand=False):
weights = rv.tag.weights
Expand All @@ -355,14 +362,6 @@ def change_size(cls, rv, new_size, expand=False):

return cls.rv_op(weights, *components, rngs=rngs, size=None)

@classmethod
def graph_rvs(cls, rv):
    """Return the RVs whose count tells super().dist() how many RNGs to make.

    One entry per mixture component plus the marginal mixture RV itself —
    rv is a pseudo RandomVariable that contains a mix_indexes_ RV in its
    inner graph — for (components + 1) RNGs in total.
    """
    return tuple(rv.tag.components) + (rv,)


@_get_measurable_outputs.register(MarginalMixtureRV)
def _get_measurable_outputs_MarginalMixtureRV(op, node):
Expand Down

0 comments on commit d510e9b

Please sign in to comment.