
Commit

Update expected behavior of sample in relation to global seeding.
Together, `test_sample_does_not_set_seed` and `test_parallel_sample_does_not_reuse_seed` covered two unspoken behaviors of `sample`:
1. When no seed is specified, PyMC shall not set NumPy's global seed state in the main process.
2. When no seed is specified, sampling will depend on NumPy's global seed state for reproducible behavior.

Point 1 is due to PyMC's legacy dependency on global seeding for step samplers. PyMC tries to minimize the "damage" by setting global seeds only when it absolutely needs to in order to ensure deterministic sampling. Ideally, calls to `numpy.random.seed` would never be made.

Point 2 goes against NumPy's current best practice of passing `None` when defining new Generators / SeedSequences, which draws fresh entropy from the OS instead of reading the global state (https://numpy.org/doc/stable/reference/random/bit_generators/generated/numpy.random.SeedSequence.html#numpy.random.SeedSequence).
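A minimal NumPy-only sketch of why point 2 cannot hold once Generators are seeded with `None`. It does not use PyMC; the variable names are illustrative, and the only assumption is that `np.random.default_rng()` stands in for whatever Generator the sampler creates internally.

```python
import numpy as np

# Seed the legacy global state identically before creating each Generator.
np.random.seed(1)
rng_a = np.random.default_rng()  # seed=None: fresh entropy from the OS
np.random.seed(1)
rng_b = np.random.default_rng()  # seed=None again: different fresh entropy

draws_a = rng_a.normal(size=5)
draws_b = rng_b.normal(size=5)
# The global seed has no influence on Generators created with seed=None.
same_without_seed = np.array_equal(draws_a, draws_b)

# Passing an explicit seed is what makes the draws reproducible.
seeded_a = np.random.default_rng(123).normal(size=5)
seeded_b = np.random.default_rng(123).normal(size=5)
same_with_seed = np.array_equal(seeded_a, seeded_b)

print(same_without_seed, same_with_seed)
```

The same reasoning applies to `sample`: reproducibility has to come from an explicit `random_seed`, not from a prior call to `np.random.seed`.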

The refactored tests cover point 1 more directly, and assert the opposite of point 2.
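The more direct check for point 1 can be sketched outside PyMC as follows. The two `draw_*` functions are hypothetical stand-ins for a sampler, not PyMC code; only the `mock.patch("numpy.random.seed")` pattern is taken from the refactored test.

```python
import unittest.mock as mock

import numpy as np


def draw_without_global_seed(size, random_seed=None):
    """Hypothetical sampler that never touches the global seed state."""
    rng = np.random.default_rng(random_seed)
    return rng.normal(size=size)


def draw_with_global_seed(size, random_seed=None):
    """Hypothetical legacy sampler that seeds the global state."""
    np.random.seed(random_seed)
    return np.random.normal(size=size)


# Patching `numpy.random.seed` records every call made through the module.
with mock.patch("numpy.random.seed") as mocked_seed:
    draw_without_global_seed(3)
    mocked_seed.assert_not_called()
    clean_calls = mocked_seed.call_count

with mock.patch("numpy.random.seed") as mocked_seed:
    draw_with_global_seed(3, random_seed=42)
    mocked_seed.assert_called_once_with(42)
    legacy_calls = mocked_seed.call_count
```

Asserting on the mock checks the behavior itself, whereas the old test inferred it indirectly by comparing `np.random.random()` values after sampling.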
ricardoV94 committed May 23, 2022
1 parent 6956d9b commit 301912e
Showing 2 changed files with 38 additions and 29 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -25,6 +25,7 @@ Also check out the [milestones](https://github.com/pymc-devs/pymc/milestones) fo

 All of the above apply to:

+⚠ Sampling functions no longer respect user-specified global seeding! Always pass `random_seed` to ensure reproducible behavior (see [#5787](https://github.com/pymc-devs/pymc/pull/5787)),
 Signature and default parameters changed for several distributions:
 - `pm.StudentT` now requires either `sigma` or `lam` as kwarg (see [#5628](https://github.com/pymc-devs/pymc/pull/5628))
 - `pm.StudentT` now requires `nu` to be specified (no longer defaults to 1) (see [#5628](https://github.com/pymc-devs/pymc/pull/5628))
66 changes: 37 additions & 29 deletions pymc/tests/test_sampling.py
@@ -15,7 +15,6 @@
 import unittest.mock as mock

 from contextlib import ExitStack as does_not_raise
-from itertools import combinations
 from typing import Tuple

 import aesara
@@ -108,35 +107,44 @@ def test_random_seed(self, chains, seeds, cores, init):
         else:
             assert allequal

-    def test_sample_does_not_set_seed(self):
-        # This tests that when random_seed is None, the global seed is not affected
-        random_numbers = []
-        for _ in range(2):
+    @mock.patch("numpy.random.seed")
+    def test_default_sample_does_not_set_global_seed(self, mocked_seed):
+        # Test that when random_seed is None, `np.random.seed` is not called in the main
+        # process. Ideally it would never be called, but PyMC step samplers still rely
+        # on global seeding for reproducible behavior.
+        kwargs = dict(tune=2, draws=2, random_seed=None)
+        with self.model:
+            pm.sample(chains=1, **kwargs)
+            pm.sample(chains=2, cores=1, **kwargs)
+            pm.sample(chains=2, cores=2, **kwargs)
+        mocked_seed.assert_not_called()
+
+    def test_sample_does_not_rely_on_external_global_seeding(self):
+        # Tests that sampling does not depend on exertenal global seeding
+        kwargs = dict(
+            tune=2,
+            draws=20,
+            random_seed=None,
+            return_inferencedata=False,
+        )
+        with self.model:
+            np.random.seed(1)
+            idata11 = pm.sample(chains=1, **kwargs)
+            np.random.seed(1)
+            idata12 = pm.sample(chains=2, cores=1, **kwargs)
             np.random.seed(1)
-            with self.model:
-                pm.sample(1, tune=0, chains=1, random_seed=None)
-            random_numbers.append(np.random.random())
-        assert random_numbers[0] == random_numbers[1]
-
-    def test_parallel_sample_does_not_reuse_seed(self):
-        cores = 4
-        random_numbers = []
-        draws = []
-        for _ in range(2):
-            np.random.seed(1)  # seeds in other processes don't effect main process
-            with self.model:
-                idata = pm.sample(100, tune=0, cores=cores)
-            # numpy thread mentioned race condition. might as well check none are equal
-            for first, second in combinations(range(cores), 2):
-                first_chain = idata.posterior["x"].sel(chain=first).values
-                second_chain = idata.posterior["x"].sel(chain=second).values
-                assert not np.allclose(first_chain, second_chain)
-            draws.append(idata.posterior["x"].values)
-            random_numbers.append(np.random.random())
-
-        # Make sure future random processes aren't effected by this
-        assert random_numbers[0] == random_numbers[1]
-        assert (draws[0] == draws[1]).all()
+            idata13 = pm.sample(chains=2, cores=2, **kwargs)
+
+            np.random.seed(1)
+            idata21 = pm.sample(chains=1, **kwargs)
+            np.random.seed(1)
+            idata22 = pm.sample(chains=2, cores=1, **kwargs)
+            np.random.seed(1)
+            idata23 = pm.sample(chains=2, cores=2, **kwargs)
+
+        assert np.all(idata11["x"] != idata21["x"])
+        assert np.all(idata12["x"] != idata22["x"])
+        assert np.all(idata13["x"] != idata23["x"])

     def test_sample(self):
         test_cores = [1]
