BUG: as_op not pickled, making parallel SMC crash #7078

Open
jucor opened this issue Dec 27, 2023 · 11 comments
@jucor
Contributor

jucor commented Dec 27, 2023

Describe the issue:

As it stands, the SMC sampler cannot be parallelized with custom ops.

When using the SMC sampler with more than one core (i.e. parallel sampling) and an as_op custom op, the op is not pickled properly by the "manual" pickling step in pymc/smc/sampling.py, marked by the comment

# "manually" (de)serialize params before/after multiprocessing

and the run therefore fails.

Reproducible code example:

import pymc as pm
import pytensor.tensor as pt

from pytensor.compile.ops import as_op

@as_op(itypes=[pt.dvector], otypes=[pt.dvector])
def twice(x):
    return 2*x

with pm.Model() as model:
    x = pm.Normal('x', mu=[0, 0], sigma=1)
    y = twice(x)
    z = pm.Normal(name='z', mu=y, observed=[1, 1])

    # Using cores=1 would work, but cores=2 throws an error
    trace = pm.sample_smc(10,cores=2)

Error message:

<details>
{
	"name": "AttributeError",
	"message": "module '__main__' has no attribute 'twice'",
	"stack": "---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
\"\"\"
Traceback (most recent call last):
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 419, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 320, in _sample_smc_int
    (draws, kernel, start, model) = map(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/ops.py\", line 221, in load_back
    obj = getattr(module, name)
          ^^^^^^^^^^^^^^^^^^^^^
AttributeError: module '__main__' has no attribute 'twice'
\"\"\"

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
Cell In[14], line 2
      1 with model:
----> 2     trace = pm.sample_smc(10,cores=2)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:213, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
    210 t1 = time.time()
    212 if cores > 1:
--> 213     results = run_chains_parallel(
    214         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
    215     )
    216 else:
    217     results = run_chains_sequential(
    218         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
    219     )

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:388, in run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores)
    386 params = tuple(cloudpickle.dumps(p) for p in params)
    387 kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
--> 388 results = _starmap_with_kwargs(
    389     pool,
    390     to_run,
    391     [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
    392     repeat(kernel_kwargs),
    393 )
    394 results = tuple(cloudpickle.loads(r) for r in results)
    395 pool.close()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:415, in _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter)
    411 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    412     # Helper function to allow kwargs with Pool.starmap
    413     # Copied from https://stackoverflow.com/a/53173433/13311693
    414     args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
--> 415     return pool.starmap(_apply_args_and_kwargs, args_for_starmap)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:375, in Pool.starmap(self, func, iterable, chunksize)
    369 def starmap(self, func, iterable, chunksize=None):
    370     '''
    371     Like `map()` method but the elements of the `iterable` are expected to
    372     be iterables as well and will be unpacked as arguments. Hence
    373     `func` and (a, b) becomes func(a, b).
    374     '''
--> 375     return self._map_async(func, iterable, starmapstar, chunksize).get()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:774, in ApplyResult.get(self, timeout)
    772     return self._value
    773 else:
--> 774     raise self._value

AttributeError: module '__main__' has no attribute 'twice'"
}
</details>

PyMC version information:

pymc: 5.10.3
pytensor: 2.18.4
python: 3.11.7

Installed in a fresh conda environment with
conda create -c conda-forge -n pymc_env "pymc>=5"

Context for the issue:

As it stands, the SMC sampler cannot run the official PyMC example from https://www.pymc.io/projects/examples/en/latest/ode_models/ODE_Lotka_Volterra_multiple_ways.html on more than one core.
Any simple ODE for which sunode is overkill will crash in the same way, because it requires a custom op, and that op is not pickled.

The workaround of using a single core makes the method much slower than needed.

Is there a way to serialize the custom operation please?
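
For anyone landing here: the remote traceback ends in load_back calling getattr(module, name), which suggests the op travels to the worker by reference (a module name plus a function name) rather than by value. Below is a minimal, stdlib-only sketch of that failure mode. It is an analogy, not PyTensor's actual code, and the module name fake_main is made up: it stands in for a notebook's __main__, which the worker process does not share.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the notebook's __main__ module.
mod = types.ModuleType("fake_main")
exec("def twice(x):\n    return 2 * x\n", mod.__dict__)
sys.modules["fake_main"] = mod

# Plain pickle stores a function as (module name, function name) only.
payload = pickle.dumps(mod.twice)

# Simulate a worker whose copy of the module never defined the function.
del mod.twice

try:
    pickle.loads(payload)
    load_error = None
except (AttributeError, pickle.UnpicklingError) as err:
    load_error = err

print(type(load_error).__name__, load_error)

# Clean up the fake module registration.
sys.modules.pop("fake_main", None)
```

The payload carries only names, so the lookup fails wherever the function is not defined under that module, which matches the shape of the AttributeError reported above.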

@jucor jucor added the bug label Dec 27, 2023

welcome bot commented Dec 27, 2023

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

@jucor
Contributor Author

jucor commented Dec 28, 2023

@aloctavodia Given all your work on SMC and its parallelization (in particular #3981), would you have any idea what's going on, and how to add those ops to what's being pickled? Thanks a lot for any idea :)

@ricardoV94
Member

You may want to try defining the Op in a Python script (a separate file that you import) instead of at runtime.
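
A rough stdlib-only sketch of why this can help, assuming the op is looked up by module and name in the worker as the traceback suggests. The file name custom_ops.py is hypothetical; in the real case it would hold the @as_op-decorated function, and the notebook would import it from there. Here a plain function stands in for the op, and a fresh interpreter stands in for a worker process.

```python
import os
import pickle
import subprocess
import sys
import tempfile
from pathlib import Path

# Code run by the stand-in worker: unpickle the function and call it.
WORKER_CODE = (
    "import pickle, sys\n"
    "fn = pickle.loads(sys.stdin.buffer.read())\n"
    "print(fn(21))\n"
)

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical module file holding the function on disk.
    Path(tmp, "custom_ops.py").write_text("def twice(x):\n    return 2 * x\n")
    sys.path.insert(0, tmp)
    import custom_ops

    # Pickled by reference: just "custom_ops" and "twice".
    payload = pickle.dumps(custom_ops.twice)

    # The worker can re-import custom_ops because the file is on its path.
    result = subprocess.run(
        [sys.executable, "-c", WORKER_CODE],
        input=payload,
        capture_output=True,
        env={**os.environ, "PYTHONPATH": tmp},
        check=True,
    )
    worker_output = result.stdout.decode().strip()
    sys.path.remove(tmp)

print(worker_output)
```

The point is only that a name living in an importable file can be resolved again in another process, whereas a name defined interactively cannot.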

@jucor
Contributor Author

jucor commented Dec 28, 2023 via email

@jucor
Contributor Author

jucor commented Dec 28, 2023

@ricardoV94 That workaround works!! 🎉 Awesome, that'll be perfect until a longer-term fix lands :)

Now the SMC sampler is hitting another issue further down, which could also be related to pickling but seems linked to the progress bar: it complains that HTML is not defined. Any idea if an extra import somewhere could help?

{
	"name": "NameError",
	"message": "name 'HTML' is not defined",
	"stack": "---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
\"\"\"
Traceback (most recent call last):
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 419, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 342, in _sample_smc_int
    progressbar.update_bar(getattr(progressbar, \"offset\", 0) + 0)
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/fastprogress/fastprogress.py\", line 81, in update_bar
    self.on_update(val, f'{pct}[{val}/{tot} {elapsed_t}{self.lt}{remaining_t}{end}]')
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/fastprogress/fastprogress.py\", line 133, in on_update
    if self.display: self.out.update(HTML(self.progress))
                                     ^^^^
NameError: name 'HTML' is not defined
\"\"\"

The above exception was the direct cause of the following exception:

NameError                                 Traceback (most recent call last)
Cell In[421], line 4
      2 draws = 2000
      3 with model:
----> 4     trace_SMC_like = pm.sample_smc(draws,cores=5)
      5 trace = trace_SMC_like
      6 az.summary(trace)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:213, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
    210 t1 = time.time()
    212 if cores > 1:
--> 213     results = run_chains_parallel(
    214         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
    215     )
    216 else:
    217     results = run_chains_sequential(
    218         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
    219     )

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:388, in run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores)
    386 params = tuple(cloudpickle.dumps(p) for p in params)
    387 kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
--> 388 results = _starmap_with_kwargs(
    389     pool,
    390     to_run,
    391     [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
    392     repeat(kernel_kwargs),
    393 )
    394 results = tuple(cloudpickle.loads(r) for r in results)
    395 pool.close()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:415, in _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter)
    411 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    412     # Helper function to allow kwargs with Pool.starmap
    413     # Copied from https://stackoverflow.com/a/53173433/13311693
    414     args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
--> 415     return pool.starmap(_apply_args_and_kwargs, args_for_starmap)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:375, in Pool.starmap(self, func, iterable, chunksize)
    369 def starmap(self, func, iterable, chunksize=None):
    370     '''
    371     Like `map()` method but the elements of the `iterable` are expected to
    372     be iterables as well and will be unpacked as arguments. Hence
    373     `func` and (a, b) becomes func(a, b).
    374     '''
--> 375     return self._map_async(func, iterable, starmapstar, chunksize).get()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:774, in ApplyResult.get(self, timeout)
    772     return self._value
    773 else:
--> 774     raise self._value

NameError: name 'HTML' is not defined"
}

@jucor
Contributor Author

jucor commented Dec 28, 2023

Dang, the latter seems to be related to AnswerDotAI/fastprogress#32, #5855 and #5980, none of which seems to have reached an actual resolution :/

@jucor
Contributor Author

jucor commented Dec 28, 2023

A really ugly workaround is to call pm.sample_smc(..., progressbar=False), which does not try to render the progress bar in the notebook and thus skips the error. But that means the user is flying completely blind while the sampler is running, which is not ideal.

@jucor
Contributor Author

jucor commented Dec 28, 2023

I confirm that the problem with fastprogress only occurs with cores > 1, so it's definitely tied to the parallelism.

fastprogress also works fine standalone in a notebook.

Inspecting its code in https://github.com/fastai/fastprogress/blob/master/fastprogress/fastprogress.py#L104 confirms that it checks its import of HTML to make sure the widget works. So the way we serialize/unserialize, or parallelize, must screw it up somehow.

@jucor
Contributor Author

jucor commented Dec 28, 2023

The way I understand it, both problems come down to the fact that the "manual" serialization is missing some symbols: the local op in one case, or the HTML object imported by fastprogress when it is itself imported in the other.

I'm not familiar enough with closures and namespaces in Python to pinpoint exactly how to spot these missing symbols, capture them, and reserialize them, but I would bet good money it should be done at this point in the SMC sampler code:

# "manually" (de)serialize params before/after multiprocessing

I'd be happy to work with anyone here to enrich this serialization!

@jucor
Contributor Author

jucor commented Dec 29, 2023

OK, progress bar error figured out: it's not quite due to pickling, it's due to how fastprogress autodetects whether it is running in a notebook and conditionally imports HTML.
Upon instantiation, the fastprogress module is loaded from a function called from a notebook, so it imports the HTML object. However, when the job runs and tries to update the progress bar, it is in a separate process that is not a notebook, so the module there has not imported HTML.
I have tried forcing the import, but the link to the notebook is broken and the pretty HTML bar is not updated.
However, I have a fix for that issue: using the non-HTML progress bars in the multicore setup. I'll make a PR for that.

So that'll be one of two problems sorted :)
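
To illustrate the mechanism described above, here is a toy sketch. It is not fastprogress's code: the module name toy_progress, the IN_NOTEBOOK flag and the HTML stand-in are all made up. The only idea it shows is a name that gets defined at import time only when a notebook is detected, and is then used unconditionally later.

```python
import types

# Toy module source: HTML exists only if a notebook was detected at import time.
TOY_SOURCE = '''
if IN_NOTEBOOK:
    def HTML(text):  # stand-in for the real HTML display object
        return "<div>" + text + "</div>"

def render(progress):
    return HTML(progress)
'''


def load_toy_module(in_notebook):
    """Import the toy module as if environment detection returned in_notebook."""
    module = types.ModuleType("toy_progress")
    module.IN_NOTEBOOK = in_notebook
    exec(TOY_SOURCE, module.__dict__)
    return module


# Parent process: imported from the notebook, so HTML is defined.
parent = load_toy_module(in_notebook=True)
parent_result = parent.render("50%")

# Worker process: not a notebook, so the same call hits an undefined name.
worker = load_toy_module(in_notebook=False)
try:
    worker.render("50%")
    worker_error = None
except NameError as err:
    worker_error = str(err)

print(parent_result)
print(worker_error)
```

This is why falling back to a plain-text bar in the worker sidesteps the problem: that code path never touches the conditionally imported name.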

jucor added a commit to jucor/pymc that referenced this issue Dec 29, 2023
@jucor
Contributor Author

jucor commented Dec 29, 2023

PR opened to fix the progress bar bug :)
That doesn't help with the requirement to put the as_op in a separate file, but at least it's one thing cleaner :)
