Skip to content

Commit

Permalink
Fix non-deterministic NUTS initialization when using ADVI
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 24, 2022
1 parent cee2b30 commit 15431d9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 15 deletions.
12 changes: 9 additions & 3 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2531,7 +2531,9 @@ def init_nuts(
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
)
initial_points = [approx_sample[i] for i in range(chains)]
std_apoint = approx.std.eval()
cov = std_apoint**2
Expand All @@ -2549,7 +2551,9 @@ def init_nuts(
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
Expand All @@ -2564,7 +2568,9 @@ def init_nuts(
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
)
initial_points = [approx_sample[i] for i in range(chains)]
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
Expand Down
5 changes: 0 additions & 5 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ def test_random_seed(self, chains, seeds, cores, init):
allequal = np.all(tr1["x"] == tr2["x"])
if seeds is None:
assert not allequal
# TODO: ADVI init methods are not correctly seeded, as they rely on the state of
# the model RandomState/Generators which is updated in place when the function
# is compiled and evaluated. This elif branch must be removed once this is fixed
elif init == "advi":
assert not allequal
else:
assert allequal

Expand Down
33 changes: 26 additions & 7 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,20 @@

import pymc as pm

from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars
from pymc.aesaraf import (
SeedSequenceSeed,
at_rng,
compile_pymc,
find_rng_nodes,
identity,
reseed_rngs,
rvs_to_value_vars,
)
from pymc.backends import NDArray
from pymc.blocking import DictToArrayBijection
from pymc.initial_point import make_initial_point_fn
from pymc.model import modelcontext
from pymc.sampling import RandomState, _get_seeds_per_chain
from pymc.util import WithMemoization, locally_cachedmethod
from pymc.variational.updates import adagrad_window
from pymc.vartypes import discrete_types
Expand Down Expand Up @@ -1641,22 +1650,30 @@ def sample_dict_fn(self):
sampled = [self.rslice(name) for name in names]
sampled = self.set_size_and_deterministic(sampled, s, 0)
sample_fn = compile_pymc([s], sampled)
rng_nodes = find_rng_nodes(sampled)

def inner(draws=100):
def inner(draws=100, *, random_seed: SeedSequenceSeed = None):
if random_seed is not None:
reseed_rngs(rng_nodes, random_seed)
_samples = sample_fn(draws)

return {v_: s_ for v_, s_ in zip(names, _samples)}

return inner

def sample(self, draws=500, return_inferencedata=True, **kwargs):
def sample(
self, draws=500, *, random_seed: RandomState = None, return_inferencedata=True, **kwargs
):
"""Draw samples from variational posterior.
Parameters
----------
draws: `int`
draws : int
Number of random samples.
return_inferencedata: `bool`
Return trace in Arviz format
random_seed : int, RandomState or Generator, optional
Seed for the random number generator.
return_inferencedata : bool
Return trace in Arviz format.
Returns
-------
Expand All @@ -1666,7 +1683,9 @@ def sample(self, draws=500, return_inferencedata=True, **kwargs):
# TODO: add tests for include_transformed case
kwargs["log_likelihood"] = False

samples = self.sample_dict_fn(draws) # type: dict
if random_seed is not None:
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
samples = self.sample_dict_fn(draws, random_seed=random_seed) # type: dict
points = ({name: records[i] for name, records in samples.items()} for i in range(draws))

trace = NDArray(
Expand Down

0 comments on commit 15431d9

Please sign in to comment.