Skip to content

Commit

Permalink
Repair parallelized population sampling (#3559)
Browse files Browse the repository at this point in the history
* test with and without parallelization across cores
demonstrates issue #3555

* replace parallelize kwarg by reliance on cores setting
closes #3555

* add the changes from pull 3559

* use more general suggestion in the log message

Co-Authored-By: Colin <[email protected]>
  • Loading branch information
2 people authored and junpenglao committed Jul 26, 2019
1 parent c0edddd commit 21eb865
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

### Maintenance
- Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508)
- Parallelization of population steppers (`DEMetropolis`) is now set via the `cores` argument. ([#3559](https://github.com/pymc-devs/pymc3/pull/3559))

## PyMC3 3.7 (May 29 2019)

Expand Down
6 changes: 3 additions & 3 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
if has_population_samplers:
_log.info('Population sampling ({} chains)'.format(chains))
_print_step_hierarchy(step)
trace = _sample_population(**sample_args)
trace = _sample_population(**sample_args, parallelize=cores > 1)
else:
_log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
_print_step_hierarchy(step)
Expand Down Expand Up @@ -690,7 +690,7 @@ def __init__(self, steppers, parallelize):
if parallelize:
try:
# configure a child process for each stepper
_log.info('Attempting to parallelize chains.')
_log.info('Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`.')
import multiprocessing
for c, stepper in enumerate(tqdm(steppers)):
slave_end, master_end = multiprocessing.Pipe()
Expand All @@ -715,7 +715,7 @@ def __init__(self, steppers, parallelize):
_log.debug('Error was: ', exec_info=True)
else:
_log.info('Chains are not parallelized. You can enable this by passing '
'pm.sample(parallelize=True).')
'`pm.sample(cores=n)`, where n > 1.')
return super().__init__()

def __enter__(self):
Expand Down
15 changes: 14 additions & 1 deletion pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,12 +915,25 @@ def test_checks_population_size(self):
trace = sample(draws=100, chains=4, step=step)
pass

def test_nonparallelized_chains_are_random(self):
with Model() as model:
x = Normal("x", 0, 1)
for stepper in TestPopulationSamplers.steppers:
step = stepper()
trace = sample(chains=4, cores=1, draws=20, tune=0, step=DEMetropolis())
samples = np.array(trace.get_values("x", combine=False))[:, 5]

assert len(set(samples)) == 4, "Parallelized {} " "chains are identical.".format(
stepper
)
pass

def test_parallelized_chains_are_random(self):
with Model() as model:
x = Normal("x", 0, 1)
for stepper in TestPopulationSamplers.steppers:
step = stepper()
trace = sample(chains=4, draws=20, tune=0, step=DEMetropolis())
trace = sample(chains=4, cores=4, draws=20, tune=0, step=DEMetropolis())
samples = np.array(trace.get_values("x", combine=False))[:, 5]

assert len(set(samples)) == 4, "Parallelized {} " "chains are identical.".format(
Expand Down

0 comments on commit 21eb865

Please sign in to comment.