diff --git a/pymc/sampling.py b/pymc/sampling.py index 90b78fbf99a..9cb058a15a2 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -2531,7 +2531,9 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - approx_sample = approx.sample(draws=chains, return_inferencedata=False) + approx_sample = approx.sample( + draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + ) initial_points = [approx_sample[i] for i in range(chains)] std_apoint = approx.std.eval() cov = std_apoint**2 @@ -2549,7 +2551,9 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - approx_sample = approx.sample(draws=chains, return_inferencedata=False) + approx_sample = approx.sample( + draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) @@ -2564,7 +2568,9 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - approx_sample = approx.sample(draws=chains, return_inferencedata=False) + approx_sample = approx.sample( + draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 93eb573a723..cdc2d05192b 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -100,11 +100,6 @@ def test_random_seed(self, chains, seeds, cores, init): allequal = np.all(tr1["x"] == tr2["x"]) if seeds is None: assert not allequal - # TODO: ADVI init methods are not correctly seeded, as they rely on the state of - # the model RandomState/Generators which is updated in place when the function - # is compiled and evaluated. This elif branch must be removed once this is fixed - elif init == "advi": - assert not allequal else: assert allequal diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index d0c89b8cf58..eae2c6a8490 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -57,11 +57,20 @@ import pymc as pm -from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars +from pymc.aesaraf import ( + SeedSequenceSeed, + at_rng, + compile_pymc, + find_rng_nodes, + identity, + reseed_rngs, + rvs_to_value_vars, +) from pymc.backends import NDArray from pymc.blocking import DictToArrayBijection from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext +from pymc.sampling import RandomState, _get_seeds_per_chain from pymc.util import WithMemoization, locally_cachedmethod from pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types @@ -1641,22 +1650,30 @@ def sample_dict_fn(self): sampled = [self.rslice(name) for name in names] sampled = self.set_size_and_deterministic(sampled, s, 0) sample_fn = compile_pymc([s], sampled) + rng_nodes = find_rng_nodes(sampled) - def inner(draws=100): + def inner(draws=100, *, random_seed: SeedSequenceSeed = None): + if random_seed is not None: + reseed_rngs(rng_nodes, random_seed) _samples = sample_fn(draws) + return {v_: s_ for v_, s_ in zip(names, _samples)} return inner - def sample(self, draws=500, return_inferencedata=True, **kwargs): + def sample( + self, draws=500, *, random_seed: RandomState = None, return_inferencedata=True, **kwargs + ): """Draw samples from variational posterior. Parameters ---------- - draws: `int` + draws : int Number of random samples. - return_inferencedata: `bool` - Return trace in Arviz format + random_seed : int, RandomState or Generator, optional + Seed for the random number generator. + return_inferencedata : bool + Return trace in Arviz format. Returns ------- @@ -1666,7 +1683,9 @@ def sample(self, draws=500, return_inferencedata=True, **kwargs): # TODO: add tests for include_transformed case kwargs["log_likelihood"] = False - samples = self.sample_dict_fn(draws) # type: dict + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + samples = self.sample_dict_fn(draws, random_seed=random_seed) # type: dict points = ({name: records[i] for name, records in samples.items()} for i in range(draws)) trace = NDArray(