Remove rng seeder #5787

Merged · 6 commits · May 24, 2022
5 changes: 4 additions & 1 deletion RELEASE-NOTES.md
@@ -25,7 +25,10 @@ Also check out the [milestones](https://github.com/pymc-devs/pymc/milestones) fo

All of the above apply to:

Signature and default parameters changed for several distributions:
⚠ Random seeding behavior changed!
- Sampling results will differ from those of V3 when passing the same random_state as before. They will be consistent across subsequent V4 releases unless mentioned otherwise.
- Sampling functions no longer respect user-specified global seeding! Always pass `random_seed` to ensure reproducible behavior.
- Signature and default parameters changed for several distributions:
- `pm.StudentT` now requires either `sigma` or `lam` as kwarg (see [#5628](https://github.com/pymc-devs/pymc/pull/5628))
- `pm.StudentT` now requires `nu` to be specified (no longer defaults to 1) (see [#5628](https://github.com/pymc-devs/pymc/pull/5628))
- `pm.AsymmetricLaplace` positional arguments re-ordered (see [#5628](https://github.com/pymc-devs/pymc/pull/5628))
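
For the seeding note above, a minimal sketch of the pattern the release notes now recommend (the toy model is illustrative, not part of this PR):

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0, 1)
    # Global seeding (e.g. np.random.seed) is no longer respected by sampling;
    # reproducibility comes from passing random_seed explicitly.
    idata = pm.sample(chains=2, random_seed=123, progressbar=False)
```
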
9 changes: 6 additions & 3 deletions benchmarks/benchmarks/benchmarks.py
@@ -174,13 +174,16 @@ def time_glm_hierarchical_init(self, init):
"""How long does it take to run the initialization."""
with glm_hierarchical_model():
pm.init_nuts(
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
init=init,
chains=self.chains,
progressbar=False,
random_seed=np.arange(self.chains),
)

def track_glm_hierarchical_ess(self, init):
with glm_hierarchical_model():
start, step = pm.init_nuts(
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
init=init, chains=self.chains, progressbar=False, random_seed=np.arange(self.chains)
)
t0 = time.time()
idata = pm.sample(
@@ -201,7 +204,7 @@ def track_marginal_mixture_model_ess(self, init):
model, start = mixture_model()
with model:
_, step = pm.init_nuts(
init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
init=init, chains=self.chains, progressbar=False, random_seed=np.arange(self.chains)
)
start = [{k: v for k, v in start.items()} for _ in range(self.chains)]
t0 = time.time()
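
A usage sketch mirroring the benchmark updates above, assuming only the rename shown in this diff (`seeds` → `random_seed`); the toy model is illustrative:

```python
import numpy as np
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, 1, observed=np.random.normal(size=10))
    # init_nuts now takes random_seed (here one seed per chain) instead of seeds.
    initial_points, step = pm.init_nuts(init="adapt_diag", chains=2, random_seed=np.arange(2))
```
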
63 changes: 61 additions & 2 deletions pymc/aesaraf.py
@@ -20,6 +20,7 @@
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
@@ -893,11 +894,64 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
)


def find_rng_nodes(variables: Iterable[TensorVariable]):
"""Return RNG variables in a graph"""
return [
node
for node in graph_inputs(variables)
if isinstance(
node,
(
at.random.var.RandomStateSharedVariable,
at.random.var.RandomGeneratorSharedVariable,
),
)
]


SeedSequenceSeed = Optional[Union[int, Sequence[int], np.ndarray, np.random.SeedSequence]]


def reseed_rngs(
rngs: Sequence[SharedVariable],
seed: SeedSequenceSeed,
) -> None:
"""Create a new set of RandomState/Generator for each rng based on a seed"""
bit_generators = [
np.random.PCG64(sub_seed) for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
]
for rng, bit_generator in zip(rngs, bit_generators):
if isinstance(rng, at.random.var.RandomStateSharedVariable):
new_rng = np.random.RandomState(bit_generator)
else:
new_rng = np.random.Generator(bit_generator)
rng.set_value(new_rng, borrow=True)
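
A pure-NumPy sketch of the scheme `reseed_rngs` relies on: `SeedSequence(seed).spawn(n)` deterministically derives n independent child sequences, so one user-facing seed yields a distinct, collision-free stream for every shared RNG.

```python
import numpy as np

def spawn_generators(seed, n):
    # One top-level seed -> n independent child streams, as in reseed_rngs above.
    return [
        np.random.Generator(np.random.PCG64(child))
        for child in np.random.SeedSequence(seed).spawn(n)
    ]

a = spawn_generators(123, 3)
b = spawn_generators(123, 3)
# Same seed -> identical streams; different children -> independent streams.
assert all(x.integers(0, 100) == y.integers(0, 100) for x, y in zip(a, b))
```
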


def compile_pymc(
inputs, outputs, mode=None, **kwargs
inputs,
outputs,
random_seed: SeedSequenceSeed = None,
mode=None,
**kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
"""Use ``aesara.function`` with specialized pymc rewrites always enabled.

This function also ensures shared RandomState/Generator used by RandomVariables
in the graph are updated across calls, to ensure independent draws.

Parameters
----------
inputs: list of TensorVariables, optional
Inputs of the compiled Aesara function
outputs: list of TensorVariables, optional
Outputs of the compiled Aesara function
random_seed: int, array-like of int or SeedSequence, optional
Seed used to override any RandomState/Generator shared variables in the graph.
If not specified, the value of original shared variables will still be overwritten.
mode: optional
Aesara mode used to compile the function

Included rewrites
-----------------
random_make_inplace
@@ -917,7 +971,6 @@ def compile_pymc(
"""
# Create an update mapping of RandomVariable's RNG so that it is automatically
# updated after every function call
# TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
rng_updates = {}
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
for random_var in (
@@ -931,11 +984,17 @@ def compile_pymc(
rng = random_var.owner.inputs[0]
if not hasattr(rng, "default_update"):
rng_updates[rng] = random_var.owner.outputs[0]
else:
rng_updates[rng] = rng.default_update
else:
update_fn = getattr(random_var.owner.op, "update", None)
if update_fn is not None:
rng_updates.update(update_fn(random_var.owner))

# We always reseed random variables as this provides RNGs with no chances of collision
if rng_updates:
reseed_rngs(rng_updates.keys(), random_seed)

# If called inside a model context, see if check_bounds flag is set to False
try:
from pymc.model import modelcontext
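
A sketch of what the new `random_seed` argument to `compile_pymc` is meant to guarantee; the tiny Aesara graphs are illustrative, and the assertions follow from the reseeding and update logic added above:

```python
import numpy as np
import aesara.tensor as at
from pymc.aesaraf import compile_pymc

# Two structurally identical graphs, each with its own shared Generator.
x = at.random.normal(size=3)
y = at.random.normal(size=3)

fn_x = compile_pymc([], x, random_seed=42)
fn_y = compile_pymc([], y, random_seed=42)

assert np.array_equal(fn_x(), fn_y())      # same seed -> same draws
assert not np.array_equal(fn_x(), fn_x())  # RNG updates -> fresh draws on each call
```
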
19 changes: 1 addition & 18 deletions pymc/distributions/censored.py
@@ -85,16 +85,12 @@ def dist(cls, dist, lower, upper, **kwargs):
check_dist_not_registered(dist)
return super().dist([dist, lower, upper], **kwargs)

@classmethod
def num_rngs(cls, *args, **kwargs):
return 1

@classmethod
def ndim_supp(cls, *dist_params):
return 0

@classmethod
def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
def rv_op(cls, dist, lower=None, upper=None, size=None):

lower = at.constant(-np.inf) if lower is None else at.as_tensor_variable(lower)
upper = at.constant(np.inf) if upper is None else at.as_tensor_variable(upper)
@@ -112,21 +108,8 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
rv_out.tag.lower = lower
rv_out.tag.upper = upper

if rngs is not None:
rv_out = cls._change_rngs(rv_out, rngs)

return rv_out

@classmethod
def _change_rngs(cls, rv, new_rngs):
(new_rng,) = new_rngs
dist_node = rv.tag.dist.owner
lower = rv.tag.lower
upper = rv.tag.upper
olg_rng, size, dtype, *dist_params = dist_node.inputs
new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output()
return cls.rv_op(new_dist, lower, upper)

@classmethod
def change_size(cls, rv, new_size, expand=False):
dist = rv.tag.dist
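
With the `rngs` plumbing removed, callers no longer do any RNG bookkeeping when building a censored distribution; a small sketch (parameter values are illustrative):

```python
import pymc as pm

# The wrapped dist brings its own shared RNG, which is reseeded centrally at
# compile/sampling time, so no rngs= argument is needed (or accepted) here.
censored_normal = pm.Censored.dist(pm.Normal.dist(mu=0, sigma=1), lower=-1, upper=1)
print(censored_normal.eval())  # one draw, censored to the interval [-1, 1]
```
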
22 changes: 3 additions & 19 deletions pymc/distributions/distribution.py
@@ -19,7 +19,7 @@

from abc import ABCMeta
from functools import singledispatch
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union, cast
from typing import Callable, Optional, Sequence, Tuple, Union, cast

import aesara
import numpy as np
@@ -258,13 +258,10 @@ def __new__(
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

if rng is None:
rng = model.next_rng()

# Create the RV and process dims and observed to determine
# a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
)

if resize_shape:
@@ -383,9 +380,6 @@ class SymbolicDistribution:
to a canonical parametrization. It should call `super().dist()`, passing a
list with the default parameters as the first and only non keyword argument,
followed by other keyword arguments like size and rngs, and return the result
cls.num_rngs
Returns the number of rngs given the same arguments passed by the user when
calling the distribution
cls.ndim_supp
Returns the support of the symbolic distribution, given the default set of
parameters. This may not always be constant, for instance if the symbolic
@@ -402,7 +396,6 @@ def __new__(
cls,
name: str,
*args,
rngs: Optional[Iterable] = None,
dims: Optional[Dims] = None,
initval=None,
observed=None,
@@ -419,8 +412,6 @@ def __new__(
A distribution class that inherits from SymbolicDistribution.
name : str
Name for the new model variable.
rngs : optional
Random number generator to use for the RandomVariable(s) in the graph.
dims : tuple, optional
A tuple of dimension names known to the model.
initval : optional
@@ -468,17 +459,10 @@ def __new__(
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

if rngs is None:
# Instead of passing individual RNG variables we could pass a RandomStream
# and let the classes create as many RNGs as they need
rngs = [model.next_rng() for _ in range(cls.num_rngs(*args, **kwargs))]
elif not isinstance(rngs, (list, tuple)):
rngs = [rngs]

# Create the RV and process dims and observed to determine
# a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
cls=cls, dims=dims, model=model, observed=observed, args=args, rngs=rngs, **kwargs
cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
)

if resize_shape:
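
Since `model.next_rng()` is no longer consulted, each RV simply keeps the shared Generator that Aesara creates for it; a quick inspection sketch (not part of the PR):

```python
import numpy as np
import pymc as pm

x = pm.Normal.dist(mu=0, sigma=1)
rng = x.owner.inputs[0]  # the RV's shared RNG input
assert isinstance(rng.get_value(), np.random.Generator)
```
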
34 changes: 4 additions & 30 deletions pymc/distributions/mixture.py
@@ -205,27 +205,15 @@ def dist(cls, w, comp_dists, **kwargs):
w = at.as_tensor_variable(w)
return super().dist([w, *comp_dists], **kwargs)

@classmethod
def num_rngs(cls, w, comp_dists, **kwargs):
if not isinstance(comp_dists, (tuple, list)):
# comp_dists is a single component
comp_dists = [comp_dists]
return len(comp_dists) + 1

@classmethod
def ndim_supp(cls, weights, *components):
# We already checked that all components have the same support dimensionality
return components[0].owner.op.ndim_supp

@classmethod
def rv_op(cls, weights, *components, size=None, rngs=None):
# Update rngs if provided
if rngs is not None:
components = cls._reseed_components(rngs, *components)
*_, mix_indexes_rng = rngs
else:
# Create new rng for the mix_indexes internal RV
mix_indexes_rng = aesara.shared(np.random.default_rng())
def rv_op(cls, weights, *components, size=None):
# Create new rng for the mix_indexes internal RV
mix_indexes_rng = aesara.shared(np.random.default_rng())

single_component = len(components) == 1
ndim_supp = components[0].owner.op.ndim_supp
@@ -317,19 +305,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None):

return mix_out

@classmethod
def _reseed_components(cls, rngs, *components):
*components_rngs, mix_indexes_rng = rngs
assert len(components) == len(components_rngs)
new_components = []
for component, component_rng in zip(components, components_rngs):
component_node = component.owner
old_rng, *inputs = component_node.inputs
new_components.append(
component_node.op.make_node(component_rng, *inputs).default_output()
)
return new_components

@classmethod
def _resize_components(cls, size, *components):
if len(components) == 1:
@@ -345,7 +320,6 @@ def _resize_components(cls, size, *components):
def change_size(cls, rv, new_size, expand=False):
weights = rv.tag.weights
components = rv.tag.components
rngs = [component.owner.inputs[0] for component in components] + [rv.tag.choices_rng]

if expand:
component = rv.tag.components[0]
@@ -360,7 +334,7 @@

components = cls._resize_components(new_size, *components)

return cls.rv_op(weights, *components, rngs=rngs, size=None)
return cls.rv_op(weights, *components, size=None)


@_get_measurable_outputs.register(MarginalMixtureRV)
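
The freshly created `mix_indexes_rng` above is just a shared Generator; reproducibility comes from reseeding it in place later, e.g. via `reseed_rngs` from this PR. A sketch of that mechanism in isolation:

```python
import aesara
import numpy as np
from pymc.aesaraf import reseed_rngs

rng = aesara.shared(np.random.default_rng())  # same pattern as mix_indexes_rng above

reseed_rngs([rng], 123)
first = rng.get_value().integers(0, 100, size=5)
reseed_rngs([rng], 123)
second = rng.get_value().integers(0, 100, size=5)
assert np.array_equal(first, second)  # same seed restores the same stream
```
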
24 changes: 3 additions & 21 deletions pymc/distributions/timeseries.py
@@ -494,28 +494,12 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant:

return ar_order

@classmethod
def num_rngs(cls, *args, **kwargs):
return 2

@classmethod
def ndim_supp(cls, *args):
return 1

@classmethod
def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None, rngs=None):

if rngs is None:
rngs = [
aesara.shared(np.random.default_rng(seed))
for seed in np.random.SeedSequence().spawn(2)
]
(init_dist_rng, noise_rng) = rngs
# Re-seed init_dist
if init_dist.owner.inputs[0] is not init_dist_rng:
_, *inputs = init_dist.owner.inputs
init_dist = init_dist.owner.op.make_node(init_dist_rng, *inputs).default_output()

def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None):
# Init dist should have shape (*size, ar_order)
if size is not None:
batch_size = size
@@ -543,6 +527,8 @@ def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None
rhos_bcast_shape_ = (*rhos_bcast_shape_[:-1], rhos_bcast_shape_[-1] + 1)
rhos_bcast_ = at.broadcast_to(rhos_, rhos_bcast_shape_)

noise_rng = aesara.shared(np.random.default_rng())

def step(*args):
*prev_xs, reversed_rhos, sigma, rng = args
if constant_term:
@@ -581,16 +567,12 @@ def change_size(cls, rv, new_size, expand=False):
old_size = rv.shape[:-1]
new_size = at.concatenate([new_size, old_size])

init_dist_rng = rv.owner.inputs[2].owner.inputs[0]
noise_rng = rv.owner.inputs[-1]

op = rv.owner.op
return cls.rv_op(
*rv.owner.inputs,
ar_order=op.ar_order,
constant_term=op.constant_term,
size=new_size,
rngs=(init_dist_rng, noise_rng),
)


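
For the AR changes, a hedged construction sketch; parameter names follow the v4 `AR` distribution and the values are illustrative:

```python
import pymc as pm

ar = pm.AR.dist(
    rho=[0.9],                       # AR(1) coefficient
    sigma=1.0,                       # innovation scale
    init_dist=pm.Normal.dist(0, 1),  # distribution of the initial value(s)
    steps=10,                        # number of steps to unroll
)
# No rngs= argument: the Scan's noise RNG is now a fresh shared Generator created in rv_op.
```
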