From cee01252d8e21d28dc6fd64f3e26c11c24ce8075 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 19 May 2022 19:54:09 +0200 Subject: [PATCH] Remove model.rng_seeder and allow compile_pymc to reseed RNG variables in compiled function --- pymc/aesaraf.py | 24 ++++++- pymc/distributions/censored.py | 19 +----- pymc/distributions/distribution.py | 22 +------ pymc/distributions/mixture.py | 34 ++-------- pymc/distributions/timeseries.py | 24 +------ pymc/model.py | 45 ++------------ pymc/sampling.py | 83 +++++++++++-------------- pymc/sampling_jax.py | 17 +++-- pymc/tests/models.py | 2 +- pymc/tests/test_aesaraf.py | 32 +++++++++- pymc/tests/test_distributions.py | 21 ++++--- pymc/tests/test_distributions_random.py | 8 +-- pymc/tests/test_mixture.py | 20 +++--- pymc/tests/test_model.py | 2 +- pymc/tests/test_sampling.py | 50 +++++++++------ pymc/tests/test_sampling_jax.py | 4 +- pymc/tests/test_step.py | 6 +- pymc/tuning/starting.py | 2 - pymc/variational/opvi.py | 2 +- 19 files changed, 175 insertions(+), 242 deletions(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 6cf11ed7411..1f2994f5ae0 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -922,10 +922,25 @@ def reseed_rngs(rngs: Iterable[SharedVariable], seed: Optional[int]) -> None: def compile_pymc( - inputs, outputs, mode=None, **kwargs + inputs, + outputs, + random_seed=None, + mode=None, + **kwargs, ) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]: """Use ``aesara.function`` with specialized pymc rewrites always enabled. + Parameters + ---------- + inputs: list of TensorVariables, optional + Inputs of the compiled Aesara function + outputs: list of TensorVariables, optional + Outputs of the compiled Aesara function + random_seed: int, optional + Seed used to override any RandomState/Generator variables in the graph + mode: optional + Aesara mode used to compile the function + Included rewrites ----------------- random_make_inplace @@ -945,7 +960,6 @@ def compile_pymc( """ # Create an update mapping of RandomVariable's RNG so that it is automatically # updated after every function call - # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph) rng_updates = {} output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs] for random_var in ( @@ -959,11 +973,17 @@ def compile_pymc( rng = random_var.owner.inputs[0] if not hasattr(rng, "default_update"): rng_updates[rng] = random_var.owner.outputs[0] + else: + rng_updates[rng] = rng.default_update else: update_fn = getattr(random_var.owner.op, "update", None) if update_fn is not None: rng_updates.update(update_fn(random_var.owner)) + # We always reseed random variables as this provides RNGs with no chance of collision + if rng_updates: + reseed_rngs(rng_updates.keys(), random_seed) + # If called inside a model context, see if check_bounds flag is set to False try: from pymc.model import modelcontext diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index a59965c8576..d7bd343b5e4 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -76,16 +76,12 @@ def dist(cls, dist, lower, upper, **kwargs): check_dist_not_registered(dist) return super().dist([dist, lower, upper], **kwargs) - @classmethod - def num_rngs(cls, *args, **kwargs): - return 1 - @classmethod def ndim_supp(cls, *dist_params): return 0 @classmethod - def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): + def rv_op(cls, dist, lower=None, upper=None, size=None): lower = at.constant(-np.inf) if lower is None else
at.as_tensor_variable(lower) upper = at.constant(np.inf) if upper is None else at.as_tensor_variable(upper) @@ -103,21 +99,8 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): rv_out.tag.lower = lower rv_out.tag.upper = upper - if rngs is not None: - rv_out = cls._change_rngs(rv_out, rngs) - return rv_out - @classmethod - def _change_rngs(cls, rv, new_rngs): - (new_rng,) = new_rngs - dist_node = rv.tag.dist.owner - lower = rv.tag.lower - upper = rv.tag.upper - old_rng, size, dtype, *dist_params = dist_node.inputs - new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output() - return cls.rv_op(new_dist, lower, upper) - @classmethod def change_size(cls, rv, new_size, expand=False): dist = rv.tag.dist diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 138b3cd2532..98b661ca231 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -19,7 +19,7 @@ from abc import ABCMeta from functools import singledispatch -from typing import Callable, Iterable, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Optional, Sequence, Tuple, Union, cast import aesara import numpy as np @@ -258,13 +258,10 @@ def __new__( if not isinstance(name, string_types): raise TypeError(f"Name needs to be a string but got: {name}") - if rng is None: - rng = model.next_rng() - # Create the RV and process dims and observed to determine # a shape by which the created RV may need to be resized. rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( - cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs + cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs ) if resize_shape: @@ -383,9 +380,6 @@ class SymbolicDistribution: to a canonical parametrization. It should call `super().dist()`, passing a list with the default parameters as the first and only non keyword argument, - followed by other keyword arguments like size and rngs, and return the result + followed by other keyword arguments like size, and return the result - cls.num_rngs - Returns the number of rngs given the same arguments passed by the user when - calling the distribution cls.ndim_supp Returns the support of the symbolic distribution, given the default set of parameters. This may not always be constant, for instance if the symbolic @@ -402,7 +396,6 @@ def __new__( cls, name: str, *args, - rngs: Optional[Iterable] = None, dims: Optional[Dims] = None, initval=None, observed=None, @@ -419,8 +412,6 @@ def __new__( A distribution class that inherits from SymbolicDistribution. name : str Name for the new model variable. - rngs : optional - Random number generator to use for the RandomVariable(s) in the graph. dims : tuple, optional A tuple of dimension names known to the model. initval : optional @@ -468,17 +459,10 @@ def __new__( if not isinstance(name, string_types): raise TypeError(f"Name needs to be a string but got: {name}") - if rngs is None: - # Instead of passing individual RNG variables we could pass a RandomStream - # and let the classes create as many RNGs as they need - rngs = [model.next_rng() for _ in range(cls.num_rngs(*args, **kwargs))] - elif not isinstance(rngs, (list, tuple)): - rngs = [rngs] - # Create the RV and process dims and observed to determine # a shape by which the created RV may need to be resized.
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( - cls=cls, dims=dims, model=model, observed=observed, args=args, rngs=rngs, **kwargs + cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs ) if resize_shape: diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index a1c6129ebea..952016d0b7e 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -205,27 +205,15 @@ def dist(cls, w, comp_dists, **kwargs): w = at.as_tensor_variable(w) return super().dist([w, *comp_dists], **kwargs) - @classmethod - def num_rngs(cls, w, comp_dists, **kwargs): - if not isinstance(comp_dists, (tuple, list)): - # comp_dists is a single component - comp_dists = [comp_dists] - return len(comp_dists) + 1 - @classmethod def ndim_supp(cls, weights, *components): # We already checked that all components have the same support dimensionality return components[0].owner.op.ndim_supp @classmethod - def rv_op(cls, weights, *components, size=None, rngs=None): - # Update rngs if provided - if rngs is not None: - components = cls._reseed_components(rngs, *components) - *_, mix_indexes_rng = rngs - else: - # Create new rng for the mix_indexes internal RV - mix_indexes_rng = aesara.shared(np.random.default_rng()) + def rv_op(cls, weights, *components, size=None): + # Create new rng for the mix_indexes internal RV + mix_indexes_rng = aesara.shared(np.random.default_rng()) single_component = len(components) == 1 ndim_supp = components[0].owner.op.ndim_supp @@ -317,19 +305,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None): return mix_out - @classmethod - def _reseed_components(cls, rngs, *components): - *components_rngs, mix_indexes_rng = rngs - assert len(components) == len(components_rngs) - new_components = [] - for component, component_rng in zip(components, components_rngs): - component_node = component.owner - old_rng, *inputs = component_node.inputs - new_components.append( - component_node.op.make_node(component_rng, *inputs).default_output() - ) - return new_components - @classmethod def _resize_components(cls, size, *components): if len(components) == 1: @@ -345,7 +320,6 @@ def _resize_components(cls, size, *components): def change_size(cls, rv, new_size, expand=False): weights = rv.tag.weights components = rv.tag.components - rngs = [component.owner.inputs[0] for component in components] + [rv.tag.choices_rng] if expand: component = rv.tag.components[0] @@ -360,7 +334,7 @@ def change_size(cls, rv, new_size, expand=False): components = cls._resize_components(new_size, *components) - return cls.rv_op(weights, *components, rngs=rngs, size=None) + return cls.rv_op(weights, *components, size=None) @_get_measurable_outputs.register(MarginalMixtureRV) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 5e1bb1e1a24..9c0b547a67f 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -494,28 +494,12 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: return ar_order - @classmethod - def num_rngs(cls, *args, **kwargs): - return 2 - @classmethod def ndim_supp(cls, *args): return 1 @classmethod - def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None, rngs=None): - - if rngs is None: - rngs = [ - aesara.shared(np.random.default_rng(seed)) - for seed in np.random.SeedSequence().spawn(2) - ] - (init_dist_rng, noise_rng) = rngs - # Re-seed init_dist - if init_dist.owner.inputs[0] is not init_dist_rng: - _, *inputs = 
init_dist.owner.inputs - init_dist = init_dist.owner.op.make_node(init_dist_rng, *inputs).default_output() - + def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None): # Init dist should have shape (*size, ar_order) if size is not None: batch_size = size @@ -543,6 +527,8 @@ def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None rhos_bcast_shape_ = (*rhos_bcast_shape_[:-1], rhos_bcast_shape_[-1] + 1) rhos_bcast_ = at.broadcast_to(rhos_, rhos_bcast_shape_) + noise_rng = aesara.shared(np.random.default_rng()) + def step(*args): *prev_xs, reversed_rhos, sigma, rng = args if constant_term: @@ -581,16 +567,12 @@ def change_size(cls, rv, new_size, expand=False): old_size = rv.shape[:-1] new_size = at.concatenate([new_size, old_size]) - init_dist_rng = rv.owner.inputs[2].owner.inputs[0] - noise_rng = rv.owner.inputs[-1] - op = rv.owner.op return cls.rv_op( *rv.owner.inputs, ar_order=op.ar_order, constant_term=op.constant_term, size=new_size, - rngs=(init_dist_rng, noise_rng), ) diff --git a/pymc/model.py b/pymc/model.py index a3c20bb81c4..468fd3aa163 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -44,7 +44,6 @@ from aesara.graph.basic import Constant, Variable, graph_inputs from aesara.graph.fg import FunctionGraph from aesara.tensor.random.opt import local_subtensor_rv_lift -from aesara.tensor.random.var import RandomStateSharedVariable from aesara.tensor.sharedvar import ScalarSharedVariable from aesara.tensor.var import TensorConstant, TensorVariable @@ -445,13 +444,6 @@ class Model(WithMemoization, metaclass=ContextMeta): parameters can only take on valid values you can set this to False for increased speed. This should not be used if your model contains discrete variables. - rng_seeder: int or numpy.random.RandomState - The ``numpy.random.RandomState`` used to seed the - ``RandomStateSharedVariable`` sequence used by a model - ``RandomVariable``s, or an int used to seed a new - ``numpy.random.RandomState``. If ``None``, a - ``RandomStateSharedVariable`` will be generated and used. Incremental - access to the state sequence is provided by ``Model.next_rng``. Examples -------- @@ -549,20 +541,10 @@ def __init__( name="", coords=None, check_bounds=True, - rng_seeder: Optional[Union[int, np.random.RandomState]] = None, ): self.name = self._validate_name(name) self.check_bounds = check_bounds - if rng_seeder is None: - self.rng_seeder = np.random.RandomState() - elif isinstance(rng_seeder, int): - self.rng_seeder = np.random.RandomState(rng_seeder) - else: - self.rng_seeder = rng_seeder - - # The sequence of model-generated RNGs - self.rng_seq: List[SharedVariable] = [] self._initial_values: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]] = {} if self.parent is not None: @@ -1016,8 +998,6 @@ def initial_point(self, seed=None) -> Dict[str, np.ndarray]: ip : dict Maps names of transformed variables to numeric initial values in the transformed space. """ - if seed is None: - seed = self.rng_seeder.randint(2**30, dtype=np.int64) fn = make_initial_point_fn(model=self, return_transformed=True) return Point(fn(seed), model=self) @@ -1038,20 +1018,6 @@ def set_initval(self, rv_var, initval): self.initial_values[rv_var] = initval - def next_rng(self) -> RandomStateSharedVariable: - """Generate a new ``RandomStateSharedVariable``. - - The new ``RandomStateSharedVariable`` is also added to - ``Model.rng_seq``. 
- """ - new_seed = self.rng_seeder.randint(2**30, dtype=np.int64) - next_rng = aesara.shared(np.random.RandomState(new_seed), borrow=True) - next_rng.tag.is_rng = True - - self.rng_seq.append(next_rng) - - return next_rng - def shape_from_dims(self, dims): shape = [] if len(set(dims)) != len(dims): @@ -1381,14 +1347,11 @@ def make_obs_var( clone=False, ) (observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner) - # Make a clone of the RV, but change the rng so that observed and missing - # are not treated as equivalent nodes by aesara. This would happen if the - # size of the masked and unmasked array happened to coincide + # Make a clone of the RV, but let it create a new rng so that observed and + # missing are not treated as equivalent nodes by aesara. This would happen + # if the size of the masked and unmasked array happened to coincide _, size, _, *inps = observed_rv_var.owner.inputs - rng = self.model.next_rng() - observed_rv_var = observed_rv_var.owner.op( - *inps, size=size, rng=rng, name=f"{name}_observed" - ) + observed_rv_var = observed_rv_var.owner.op(*inps, size=size, name=f"{name}_observed") observed_rv_var.tag.observations = nonmissing_data self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data) diff --git a/pymc/sampling.py b/pymc/sampling.py index cb439ad0008..6a59cac2803 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -14,7 +14,6 @@ """Functions for MCMC sampling.""" -import collections.abc as abc import logging import pickle import sys @@ -253,6 +252,32 @@ def all_continuous(vars): return True +def _get_seeds_per_chain( + random_seed: Optional[Union[int, Sequence[int], np.random.BitGenerator, np.random.RandomState]], + chains: int, +) -> Sequence[int]: + def _get_unique_sees_per_chain(rng): + seeds = [] + while len(set(seeds)) != chains: + seeds = rng.randint(2**30, dtype=np.int64, size=chains) + return seeds + + if random_seed is None: + return _get_unique_sees_per_chain(np.random.default_rng()) + if isinstance(random_seed, (np.random.Generator, np.random.RandomState)): + return _get_unique_sees_per_chain(random_seed) + + if not isinstance(random_seed, (list, tuple, np.ndarray)): + raise ValueError(f"The `seeds` must be array-like. Got {type(random_seed)} instead.") + + if len(random_seed) != chains: + raise ValueError( + f"Number of seeds ({len(random_seed)}) does not match the number of chains ({chains})." + ) + + return random_seed + + def sample( draws: int = 1000, step=None, @@ -434,18 +459,10 @@ def sample( if chains is None: chains = max(2, cores) + if random_seed == -1: random_seed = None - if chains == 1 and isinstance(random_seed, int): - random_seed = [random_seed] - - if random_seed is None or isinstance(random_seed, int): - if random_seed is not None: - np.random.seed(random_seed) - random_seed = [np.random.randint(2**30) for _ in range(chains)] - - if not isinstance(random_seed, abc.Iterable): - raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int") + random_seed = _get_seeds_per_chain(random_seed, chains) if not discard_tuned_samples and not return_inferencedata: warnings.warn( @@ -1838,14 +1855,6 @@ def sample_posterior_predictive( else: vars_ = model.observed_RVs + model.auto_deterministics - if random_seed is not None: - warnings.warn( - "In this version, RNG seeding is managed by the Model objects. 
" - "See the `rng_seeder` argument in Model's constructor.", - FutureWarning, - stacklevel=2, - ) - indices = np.arange(samples) if progressbar: @@ -1873,6 +1882,7 @@ def sample_posterior_predictive( vars_in_trace=vars_in_trace, basic_rvs=model.basic_RVs, givens_dict=None, + random_seed=random_seed, **compile_kwargs, ) ) @@ -1992,14 +2002,6 @@ def sample_posterior_predictive_w( if models is None: models = [modelcontext(models)] * len(traces) - if random_seed: - warnings.warn( - "In this version, RNG seeding is managed by the Model objects. " - "See the `rng_seeder` argument in Model's constructor.", - FutureWarning, - stacklevel=2, - ) - for model in models: if model.potentials: warnings.warn( @@ -2159,14 +2161,6 @@ def sample_prior_predictive( else: vars_ = set(var_names) - if random_seed is not None: - warnings.warn( - "In this version, RNG seeding is managed by the Model objects. " - "See the `rng_seeder` argument in Model's constructor.", - FutureWarning, - stacklevel=2, - ) - names = get_default_varnames(vars_, include_transformed=False) vars_to_sample = [model[name] for name in names] @@ -2198,6 +2192,7 @@ def sample_prior_predictive( vars_in_trace=[], basic_rvs=model.basic_RVs, givens_dict=None, + random_seed=random_seed, **compile_kwargs, ) @@ -2223,6 +2218,7 @@ def sample_prior_predictive( def draw( vars: Union[Variable, Sequence[Variable]], draws: int = 1, + random_seed=None, **kwargs, ) -> Union[np.ndarray, List[np.ndarray]]: """Draw samples for one variable or a list of variables @@ -2233,6 +2229,7 @@ def draw( A variable or a list of variables for which to draw samples. draws : int, default 1 Number of samples needed to draw. + random_seed: int, Optional **kwargs : dict, optional Keyword arguments for :func:`pymc.aesara.compile_pymc`. @@ -2265,8 +2262,7 @@ def draw( assert draws[1].shape == (num_draws, 10) assert draws[2].shape == (num_draws, 5) """ - - draw_fn = compile_pymc(inputs=[], outputs=vars, **kwargs) + draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs) if draws == 1: return draw_fn() @@ -2285,7 +2281,7 @@ def draw( def _init_jitter( model: Model, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], - seeds: Union[List[Any], Tuple[Any, ...], np.ndarray], + seeds: Sequence[int], jitter: bool, jitter_max_retries: int, ) -> List[PointType]: @@ -2342,7 +2338,7 @@ def init_nuts( chains: int = 1, n_init: int = 500_000, model=None, - seeds: Iterable[Any] = None, + seeds: Sequence[Any] = None, progressbar=True, jitter_max_retries: int = 10, tune: Optional[int] = None, @@ -2423,14 +2419,7 @@ def init_nuts( if init == "auto": init = "jitter+adapt_diag" - if seeds is None: - seeds = model.rng_seeder.randint(2**30, dtype=np.int64, size=chains) - if not isinstance(seeds, (list, tuple, np.ndarray)): - raise ValueError(f"The `seeds` must be array-like. Got {type(seeds)} instead.") - if len(seeds) != chains: - raise ValueError( - f"Number of seeds ({len(seeds)}) does not match the number of chains ({chains})." 
- ) + seeds = _get_seeds_per_chain(seeds, chains) _log.info(f"Initializing NUTS using {init}...") diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index e3d2e1034d9..1da980212c0 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -7,7 +7,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union from pymc.initial_point import StartDict -from pymc.sampling import _init_jitter +from pymc.sampling import _get_seeds_per_chain, _init_jitter xla_flags = os.getenv("XLA_FLAGS", "") xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split() @@ -160,14 +160,10 @@ def _get_batched_jittered_initial_points( Each item has shape `(chains, *var.shape)` """ - random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains) - - assert len(random_seed) == chains - initial_points = _init_jitter( model, initvals, - seeds=random_seed, + seeds=_get_seeds_per_chain(random_seed, chains), jitter=jitter, jitter_max_retries=jitter_max_retries, ) @@ -220,7 +216,7 @@ def sample_blackjax_nuts( tune=1000, chains=4, target_accept=0.8, - random_seed=10, + random_seed=None, initvals=None, model=None, var_names=None, @@ -245,7 +241,7 @@ def sample_blackjax_nuts( target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. - random_seed : int, default 10 + random_seed : int, optional Random seed used by the sampling steps. model : Model, optional Model to sample from. The model needs to have free random variables. When inside a ``with`` model @@ -292,6 +288,9 @@ def sample_blackjax_nuts( else: dims = {} + if random_seed is None: + random_seed = np.random.randint(2**30, dtype=np.int64) + tic1 = datetime.now() print("Compiling...", file=sys.stdout) @@ -471,7 +470,7 @@ def sample_numpyro_nuts( dims = {} if random_seed is None: - random_seed = model.rng_seeder.randint(2**30, dtype=np.int64) + random_seed = np.random.randint(2**30, dtype=np.int64) tic1 = datetime.now() print("Compiling...", file=sys.stdout) diff --git a/pymc/tests/models.py b/pymc/tests/models.py index 0c1f176a754..8b4d6b319ea 100644 --- a/pymc/tests/models.py +++ b/pymc/tests/models.py @@ -189,7 +189,7 @@ def simple_normal(bounded_prior=False): sigma = 1.0 a, b = (9, 12) # bounds for uniform RV, need non-symmetric to reproduce issue - with pm.Model(rng_seeder=2482) as model: + with pm.Model() as model: if bounded_prior: mu_i = pm.Uniform("mu_i", a, b) else: diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index e5b288a6d3f..291ae649eef 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest import mock import aesara import aesara.tensor as at @@ -487,7 +488,9 @@ def test_compile_pymc_updates_inputs(self): # Each RV adds a shared output for its rng assert len(fn_fgraph.outputs) == 1 + rvs_in_graph - def test_compile_pymc_custom_update_op(self): + # Disable `reseed_rngs` so that we can test with a simpler update rule + @mock.patch("pymc.aesaraf.reseed_rngs") + def test_compile_pymc_custom_update_op(self, _): """Test that custom MeasurableVariable Op updates are used by compile_pymc""" class UnmeasurableOp(OpFromGraph): @@ -507,3 +510,30 @@ def update(self, node): fn = compile_pymc(inputs=[], outputs=dummy_x) assert fn() == 2.0 assert fn() == 3.0 + + def test_random_seed(self): + seedx = aesara.shared(np.random.default_rng(1)) + seedy = aesara.shared(np.random.default_rng(1)) + x = at.random.normal(rng=seedx) + y = at.random.normal(rng=seedy) + + # Both shared variables hold the same RNG state, so outputs will be identical + f0 = aesara.function([], [x, y]) + x0_eval, y0_eval = f0() + assert x0_eval == y0_eval + + # The variables will be reseeded with new seeds by default + f1 = compile_pymc([], [x, y]) + x1_eval, y1_eval = f1() + assert x1_eval != y1_eval + + # Check that seeding works + f2 = compile_pymc([], [x, y], random_seed=1) + x2_eval, y2_eval = f2() + assert x2_eval != x1_eval + assert y2_eval != y1_eval + + f3 = compile_pymc([], [x, y], random_seed=1) + x3_eval, y3_eval = f3() + assert x3_eval == x2_eval + assert y3_eval == y2_eval diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index f58b6d16b7c..37df4a42045 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -3222,21 +3222,25 @@ def random(rng, size): def test_distinct_rvs(): """Make sure `RandomVariable`s generated using a `Model`'s default RNG state all have distinct states.""" - with pm.Model(rng_seeder=np.random.RandomState(2023532)) as model: + with pm.Model() as model: X_rv = pm.Normal("x") Y_rv = pm.Normal("y") - pp_samples = pm.sample_prior_predictive(samples=2, return_inferencedata=False) + pp_samples = pm.sample_prior_predictive( + samples=2, return_inferencedata=False, random_seed=np.random.RandomState(2023532) + ) assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0] - assert len(model.rng_seq) == 2 - with pm.Model(rng_seeder=np.random.RandomState(2023532)): + with pm.Model(): X_rv = pm.Normal("x") Y_rv = pm.Normal("y") - pp_samples_2 = pm.sample_prior_predictive(samples=2, return_inferencedata=False) + pp_samples_2 = pm.sample_prior_predictive( + samples=2, return_inferencedata=False, random_seed=np.random.RandomState(2023532) + ) assert np.array_equal(pp_samples["y"], pp_samples_2["y"]) @@ -3312,7 +3316,8 @@ def test_censored_workflow(self, censored): data[data <= low] = low data[data >= high] = high - with pm.Model(rng_seeder=17092021) as m: + seed = 17092021 + with pm.Model() as m: mu = pm.Normal( "mu", mu=((high - low) / 2) + low, @@ -3328,9 +3333,9 @@ def test_censored_workflow(self, censored): observed=data, ) - prior_pred = pm.sample_prior_predictive() - posterior = pm.sample(tune=500, draws=500) - posterior_pred = pm.sample_posterior_predictive(posterior) + prior_pred = pm.sample_prior_predictive(random_seed=seed) + posterior = pm.sample(tune=500, draws=500, random_seed=seed) + posterior_pred = pm.sample_posterior_predictive(posterior, random_seed=seed) expected = True if censored else False assert (9 < prior_pred.prior_predictive.mean() < 10) == expected diff --git a/pymc/tests/test_distributions_random.py
b/pymc/tests/test_distributions_random.py index 9fd7c2ce812..5da279f85b7 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1682,14 +1682,14 @@ def check_draws(self): def ref_rand(mu, rowcov, colcov): return st.matrix_normal.rvs(mean=mu, rowcov=rowcov, colcov=colcov) - with pm.Model(rng_seeder=1): + with pm.Model(): matrixnormal = pm.MatrixNormal( "matnormal", mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3), ) - check = pm.sample_prior_predictive(n_fails, return_inferencedata=False) + check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=1) ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3)) @@ -2328,10 +2328,10 @@ def test_car_rng_fn(sparse): if sparse: W = aesara.sparse.csr_from_dense(W) - with pm.Model(rng_seeder=1): + with pm.Model(): car = pm.CAR("car", mu, W, alpha, tau, size=size) mn = pm.MvNormal("mn", mu, cov, size=size) - check = pm.sample_prior_predictive(n_fails, return_inferencedata=False) + check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=1) p, f = delta, n_fails while p <= delta and f > 0: diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py index 4562911f958..6b890a3bd6d 100644 --- a/pymc/tests/test_mixture.py +++ b/pymc/tests/test_mixture.py @@ -232,12 +232,8 @@ def test_list_univariate_components_deterministic_weights(self, weights, compone @pytest.mark.parametrize("size", [None, (4,), (5, 4)]) def test_single_multivariate_component_deterministic_weights(self, weights, component, size): # This test needs seeding to avoid repetitions - rngs = [ - aesara.shared(np.random.default_rng(seed)) - for seed in self.get_random_state().randint(2**30, size=2) - ] - mix = Mixture.dist(weights, component, size=size, rngs=rngs) - mix_eval = mix.eval() + mix = Mixture.dist(weights, component, size=size) + mix_eval = draw(mix, random_seed=self.get_random_state()) # Test shape # component shape is either (4, 2, 3), (2, 3) @@ -853,7 +849,7 @@ def test_scalar_components(self): # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]] mus = at.constant(np.full((nd, npop), np.arange(npop))) - with Model(rng_seeder=self.get_random_state()) as model: + with Model() as model: m = NormalMixture( "m", w=np.ones(npop) / npop, @@ -867,8 +863,8 @@ def test_scalar_components(self): latent_m = Normal("latent_m", mu=mu, sigma=1e-5, shape=nd) size = 100 - m_val = draw(m, draws=size) - latent_m_val = draw(latent_m, draws=size) + m_val = draw(m, draws=size, random_seed=self.get_random_state()) + latent_m_val = draw(latent_m, draws=size, random_seed=self.get_random_state()) assert m_val.shape == latent_m_val.shape # Test that each element in axis = -1 can come from independent @@ -888,7 +884,7 @@ def test_vector_components(self): # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]] mus = at.constant(np.full((nd, npop), np.arange(npop))) - with Model(rng_seeder=self.get_random_state()) as model: + with Model() as model: m = Mixture( "m", w=np.ones(npop) / npop, @@ -900,8 +896,8 @@ def test_vector_components(self): latent_m = Normal("latent_m", mu=mus[..., z], sigma=1e-5, shape=nd) size = 100 - m_val = draw(m, draws=size) - latent_m_val = draw(latent_m, draws=size) + m_val = draw(m, draws=size, random_seed=self.get_random_state()) + latent_m_val = draw(latent_m, draws=size, random_seed=self.get_random_state()) assert m_val.shape == latent_m_val.shape # Test that each element in axis = -1 comes from the same mixture # component diff --git a/pymc/tests/test_model.py
b/pymc/tests/test_model.py index 50b06e24b0b..a84f79c5f98 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -655,7 +655,7 @@ def test_set_initval(): # generating initial values rng = np.random.RandomState(392) - with pm.Model(rng_seeder=rng) as model: + with pm.Model() as model: eta = pm.Uniform("eta", 1.0, 2.0, size=(1, 1)) mu = pm.Normal("mu", sigma=eta, initval=[[100]]) alpha = pm.HalfNormal("alpha", initval=100) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 634ce109867..26e6da81247 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -679,18 +679,18 @@ def test_model_not_drawable_prior(self): assert samples["foo"].shape == (40, 200) def test_model_shared_variable(self): rng = np.random.RandomState(9832) x = rng.randn(100) y = x > 0 x_shared = aesara.shared(x) y_shared = aesara.shared(y) - with pm.Model(rng_seeder=rng) as model: + with pm.Model() as model: coeff = pm.Normal("x", mu=0, sigma=1) logistic = pm.Deterministic("p", pm.math.sigmoid(coeff * x_shared)) obs = pm.Bernoulli("obs", p=logistic, observed=y_shared) - trace = pm.sample(100, return_inferencedata=False, compute_convergence_checks=False) + trace = pm.sample( + 100, return_inferencedata=False, compute_convergence_checks=False, random_seed=rng + ) x_shared.set_value([-1, 0, 1.0]) y_shared.set_value([0, 0, 0]) @@ -711,7 +711,7 @@ def test_deterministic_of_observed(self): meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(10)) meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(10)) nchains = 2 - with pm.Model(rng_seeder=rng) as model: + with pm.Model() as model: mu_in_1 = pm.Normal("mu_in_1", 0, 2) sigma_in_1 = pm.HalfNormal("sd_in_1", 1) mu_in_2 = pm.Normal("mu_in_2", 0, 2) @@ -730,6 +730,7 @@ step=pm.Metropolis(), return_inferencedata=False, compute_convergence_checks=False, + random_seed=rng, ) rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-4 @@ -750,7 +751,7 @@ def test_deterministic_of_observed_modified_interface(self): meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(100)) meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(100)) - with pm.Model(rng_seeder=rng) as model: + with pm.Model() as model: mu_in_1 = pm.Normal("mu_in_1", 0, 1, initval=0) sigma_in_1 = pm.HalfNormal("sd_in_1", 1, initval=1) mu_in_2 = pm.Normal("mu_in_2", 0, 1, initval=0) @@ -768,6 +769,7 @@ step=pm.Metropolis(), return_inferencedata=False, compute_convergence_checks=False, + random_seed=rng, ) varnames = [v for v in trace.varnames if v != "out"] ppc_trace = [ @@ -1085,11 +1087,11 @@ def test_multivariate2(self): assert sim_ppc["obs"].shape == (20,) + mn_data.shape def test_layers(self): - with pm.Model(rng_seeder=232093) as model: + with pm.Model() as model: a = pm.Uniform("a", lower=0, upper=1, size=10) b = pm.Binomial("b", n=1, p=a, size=10) - b_sampler = compile_pymc([], b, mode="FAST_RUN") + b_sampler = compile_pymc([], b, mode="FAST_RUN", random_seed=232093) avg = np.stack([b_sampler() for i in range(10000)]).mean(0) npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2) @@ -1190,13 +1192,14 @@ def ub_interval_forward(x, ub): # Interval transform assuming lower bound is zero return np.log(x - 0) - np.log(ub - x) - with pm.Model(rng_seeder=123) as model: + with pm.Model() as model: ub = pm.HalfNormal("ub", 10) x = pm.Uniform("x", 0, ub) prior = pm.sample_prior_predictive( var_names=["ub", "ub_log__", "x", "x_interval__"], samples=10, +
random_seed=123, ) # Check values are correct @@ -1207,13 +1210,14 @@ def ub_interval_forward(x, ub): ) # Check that it works when the original RVs are not mentioned in var_names - with pm.Model(rng_seeder=123) as model_transformed_only: + with pm.Model() as model_transformed_only: ub = pm.HalfNormal("ub", 10) x = pm.Uniform("x", 0, ub) prior_transformed_only = pm.sample_prior_predictive( var_names=["ub_log__", "x_interval__"], samples=10, + random_seed=123, ) assert ( "ub" not in prior_transformed_only.prior.data_vars @@ -1230,19 +1234,23 @@ def test_issue_4490(self): # Test that samples do not depend on var_name order or, more fundamentally, # that they do not depend on the set order used inside `sample_prior_predictive` seed = 4490 - with pm.Model(rng_seeder=seed) as m1: + with pm.Model() as m1: a = pm.Normal("a") b = pm.Normal("b") c = pm.Normal("c") d = pm.Normal("d") - prior1 = pm.sample_prior_predictive(samples=1, var_names=["a", "b", "c", "d"]) + prior1 = pm.sample_prior_predictive( + samples=1, var_names=["a", "b", "c", "d"], random_seed=seed + ) - with pm.Model(rng_seeder=seed) as m2: + with pm.Model() as m2: a = pm.Normal("a") b = pm.Normal("b") c = pm.Normal("c") d = pm.Normal("d") - prior2 = pm.sample_prior_predictive(samples=1, var_names=["b", "a", "d", "c"]) + prior2 = pm.sample_prior_predictive( + samples=1, var_names=["b", "a", "d", "c"], random_seed=seed + ) assert prior1.prior["a"] == prior2.prior["a"] assert prior1.prior["b"] == prior2.prior["b"] @@ -1384,19 +1392,21 @@ def test_draw_aesara_function_kwargs(self): def test_step_args(): - with pm.Model(rng_seeder=1410) as model: + with pm.Model() as model: a = pm.Normal("a") - idata0 = pm.sample(target_accept=0.5) - idata1 = pm.sample(nuts={"target_accept": 0.5}) + idata0 = pm.sample(target_accept=0.5, random_seed=1410) + idata1 = pm.sample(nuts={"target_accept": 0.5}, random_seed=1410 * 2) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) - with pm.Model(rng_seeder=1418) as model: + with pm.Model() as model: a = pm.Normal("a") b = pm.Poisson("b", 1) - idata0 = pm.sample(target_accept=0.5) - idata1 = pm.sample(nuts={"target_accept": 0.5}, metropolis={"scaling": 0}) + idata0 = pm.sample(target_accept=0.5, random_seed=1418) + idata1 = pm.sample( + nuts={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2 + ) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 488db8c5e4a..795a5ac5b5b 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -209,11 +209,11 @@ def test_seeding(chains, random_seed): random_seed=random_seed, ) - with pm.Model(rng_seeder=456) as m: + with pm.Model() as m: pm.Normal("x", mu=0, sigma=1) result1 = sample_numpyro_nuts(**sample_kwargs) - with pm.Model(rng_seeder=456) as m: + with pm.Model() as m: pm.Normal("x", mu=0, sigma=1) result2 = sample_numpyro_nuts(**sample_kwargs) result3 = sample_numpyro_nuts(**sample_kwargs) diff --git a/pymc/tests/test_step.py b/pymc/tests/test_step.py index 0a9b923017b..7bb12d9b634 100644 --- a/pymc/tests/test_step.py +++ b/pymc/tests/test_step.py @@ -1183,7 +1183,7 @@ def perform(self, node, inputs, outputs): rng = np.random.RandomState(seed) - with Model(rng_seeder=rng) as coarse_model_0: + with Model() as 
coarse_model_0: if aesara.config.floatX == "float32": Q = Data("Q", np.float32(0.0)) else: @@ -1202,7 +1202,7 @@ def perform(self, node, inputs, outputs): rng = np.random.RandomState(seed) - with Model(rng_seeder=rng) as coarse_model_1: + with Model() as coarse_model_1: if aesara.config.floatX == "float32": Q = Data("Q", np.float32(0.0)) else: @@ -1221,7 +1221,7 @@ def perform(self, node, inputs, outputs): rng = np.random.RandomState(seed) - with Model(rng_seeder=rng) as model: + with Model() as model: if aesara.config.floatX == "float32": Q = Data("Q", np.float32(0.0)) else: diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index f2f94951ddd..ccca3aa2f7c 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -107,8 +107,6 @@ def find_MAP( return_transformed=True, overrides=start, ) - if seed is None: - seed = model.rng_seeder.randint(2**30, dtype=np.int64) start = ipfn(seed) model.check_start_vals(start) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 0b6d6dddc65..d2839d3dfe3 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -867,7 +867,7 @@ def _prepare_start(self, start=None): jitter_rvs={}, return_transformed=True, ) - start = ipfn(self.model.rng_seeder.randint(2**30, dtype=np.int64)) + start = ipfn(np.random.randint(2**30, dtype=np.int64)) group_vars = {self.model.rvs_to_values[v].name for v in self.group} start = {k: v for k, v in start.items() if k in group_vars} if self.batched:
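A minimal usage sketch of the seeding workflow this patch enables (illustrative only, not part of the diff; the toy model and seed values are made up): instead of seeding a whole model up front via Model(rng_seeder=...), a seed is now passed at compile/sampling time, and compile_pymc reseeds every RNG shared variable in the compiled graph.

import numpy as np
import pymc as pm

from pymc.aesaraf import compile_pymc

with pm.Model() as model:  # Model() no longer accepts an rng_seeder argument
    x = pm.Normal("x")

# Two functions compiled with the same random_seed produce identical draws
fn1 = compile_pymc([], x, random_seed=42)
fn2 = compile_pymc([], x, random_seed=42)
assert fn1() == fn2()

# pm.draw forwards random_seed to compile_pymc
np.testing.assert_array_equal(
    pm.draw(x, draws=10, random_seed=1),
    pm.draw(x, draws=10, random_seed=1),
)

# Sampling functions take seeds directly; a sequence must provide one seed
# per chain (see _get_seeds_per_chain in pymc/sampling.py)
with model:
    prior = pm.sample_prior_predictive(samples=100, random_seed=1)
    idata = pm.sample(draws=100, tune=100, chains=2, random_seed=[1, 2])
    post_pred = pm.sample_posterior_predictive(idata, random_seed=1)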