From f73e933ace515c46da3dfe0f45e790874fdc8c2b Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Fri, 24 Sep 2021 00:27:48 +0200 Subject: [PATCH] Remove support for continuation of traces --- pymc3/sampling.py | 82 ++++++++++++++++-------------------- pymc3/tests/test_sampling.py | 21 ++++++++- 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index c595655e516..31931d167b1 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -252,9 +252,9 @@ def sample( draws=1000, step=None, init="auto", - n_init=200000, + n_init=200_000, start=None, - trace=None, + trace: Optional[Union[BaseTrace, List[str]]] = None, chain_idx=0, chains=None, cores=None, @@ -296,10 +296,9 @@ def sample( Defaults to ``trace.point(-1))`` if there is a trace provided and model.initial_point if not (defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can overwrite the default. - trace : backend, list, or MultiTrace - This should be a backend instance, a list of variables to track, or a MultiTrace object - with past values. If a MultiTrace object is given, it must contain samples for the chain - number ``chain``. If None or a list of variables, the NDArray backend is used. + trace : backend or list + This should be a backend instance, or a list of variables to track. + If None or a list of variables, the NDArray backend is used. chain_idx : int Chain number used to store sample in backend. If ``chains`` is greater than one, chain numbers will start here. @@ -813,7 +812,7 @@ def _sample( start, draws: int, step=None, - trace=None, + trace: Optional[Union[BaseTrace, List[str]]] = None, tune=None, model: Optional[Model] = None, callback=None, @@ -839,10 +838,9 @@ def _sample( The number of samples to draw step : function Step function - trace : backend, list, or MultiTrace - This should be a backend instance, a list of variables to track, or a MultiTrace object - with past values. If a MultiTrace object is given, it must contain samples for the chain - number ``chain``. If None or a list of variables, the NDArray backend is used. + trace : backend or list + This should be a backend instance, or a list of variables to track. + If None or a list of variables, the NDArray backend is used. tune : int, optional Number of iterations to tune, if applicable (defaults to None) model : Model (optional if in ``with`` context) @@ -899,10 +897,9 @@ def iter_sample( start : dict Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if there is a trace provided and model.initial_point if not (defaults to empty dict) - trace : backend, list, or MultiTrace - This should be a backend instance, a list of variables to track, or a MultiTrace object - with past values. If a MultiTrace object is given, it must contain samples for the chain - number ``chain``. If None or a list of variables, the NDArray backend is used. + trace : backend or list + This should be a backend instance, or a list of variables to track. + If None or a list of variables, the NDArray backend is used. chain : int, optional Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers will start here. @@ -939,7 +936,7 @@ def _iter_sample( draws, step, start=None, - trace=None, + trace: Optional[Union[BaseTrace, List[str]]] = None, chain=0, tune=None, model=None, @@ -955,12 +952,10 @@ def _iter_sample( step : function Step function start : dict, optional - Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if - there is a trace provided and model.initial_point if not (defaults to empty dict) - trace : backend, list, MultiTrace, or None - This should be a backend instance, a list of variables to track, or a MultiTrace object - with past values. If a MultiTrace object is given, it must contain samples for the chain - number ``chain``. If None or a list of variables, the NDArray backend is used. + Starting point in parameter space (or partial point). Defaults to model.initial_point if not (defaults to empty dict) + trace : backend or list + This should be a backend instance, or a list of variables to track. + If None or a list of variables, the NDArray backend is used. chain : int, optional Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers will start here. @@ -986,12 +981,9 @@ def _iter_sample( if start is None: start = {} - strace = _choose_backend(trace, chain, model=model) + strace = _choose_backend(trace, model=model) - if len(strace) > 0: - model.update_start_vals(start, strace.point(-1)) - else: - model.update_start_vals(start, model.initial_point) + model.update_start_vals(start, model.initial_point) try: step = CompoundStep(step) @@ -1258,7 +1250,7 @@ def _prepare_iter_population( # 5. a PopulationStepper is configured for parallelized stepping # 1. prepare a BaseTrace for each chain - traces = [_choose_backend(None, chain, model=model) for chain in chains] + traces = [_choose_backend(None, model=model) for chain in chains] for c, strace in enumerate(traces): # initialize the trace size and variable transforms if len(strace) > 0: @@ -1361,30 +1353,29 @@ def _iter_population(draws, tune, popstep, steppers, traces, points): steppers[c].report._finalize(strace) -def _choose_backend(trace, chain, **kwds) -> Backend: +def _choose_backend(trace: Optional[Union[BaseTrace, List[str]]], **kwds) -> Backend: """Selects or creates a NDArray trace backend for a particular chain. Parameters ---------- - trace : BaseTrace, list, MultiTrace, or None - This should be a BaseTrace, list of variables to track, - or a MultiTrace object with past values. - If a MultiTrace object is given, it must contain samples for the chain number ``chain``. + trace : BaseTrace, list, or None + This should be a BaseTrace, or list of variables to track. If None or a list of variables, the NDArray backend is used. - chain : int - Number of the chain of interest. **kwds : keyword arguments to forward to the backend creation Returns ------- trace : BaseTrace - A trace object for the selected chain + The incoming, or a brand new trace object. """ + if isinstance(trace, BaseTrace) and len(trace) > 0: + raise ValueError("Continuation of traces is no longer supported.") + if isinstance(trace, MultiTrace): + raise ValueError("Starting from existing MultiTrace objects is no longer supported.") + if isinstance(trace, BaseTrace): return trace - if isinstance(trace, MultiTrace): - return trace._straces[chain] if trace is None: return NDArray(**kwds) @@ -1401,7 +1392,7 @@ def _mp_sample( random_seed: list, start: list, progressbar=True, - trace=None, + trace: Optional[Union[BaseTrace, List[str]]] = None, model=None, callback=None, discard_tuned_samples=True, @@ -1430,10 +1421,9 @@ def _mp_sample( Starting points for each chain. progressbar : bool Whether or not to display a progress bar in the command line. - trace : BaseTrace, list, MultiTrace or None - This should be a backend instance, a list of variables to track, or a MultiTrace object - with past values. If a MultiTrace object is given, it must contain samples for the chain - number ``chain``. If None or a list of variables, the NDArray backend is used. + trace : BaseTrace, list, or None + This should be a backend instance, or a list of variables to track + If None or a list of variables, the NDArray backend is used. model : Model (optional if in ``with`` context) callback : Callable A function which gets called for every sample from the trace of a chain. The function is @@ -1455,10 +1445,10 @@ def _mp_sample( traces = [] for idx in range(chain, chain + chains): if trace is not None: - strace = _choose_backend(copy(trace), idx, model=model) + strace = _choose_backend(copy(trace), model=model) else: - strace = _choose_backend(None, idx, model=model) - # for user supply start value, fill-in missing value if the supplied + strace = _choose_backend(None, model=model) + # for user supplied start value, fill-in missing value if the supplied # dict does not contain all parameters model.update_start_vals(start[idx - chain], model.initial_point) if step.generates_stats and strace.supports_sampler_stats: diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 04767ae9b68..f6aa77aaef4 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -32,6 +32,7 @@ import pymc3 as pm from pymc3.aesaraf import compile_rv_inplace +from pymc3.backends.base import MultiTrace from pymc3.backends.ndarray import NDArray from pymc3.exceptions import IncorrectArgumentsError, SamplingError from pymc3.tests.helpers import SeededTest @@ -438,14 +439,30 @@ def test_constant_named(self): class TestChooseBackend: def test_choose_backend_none(self): with mock.patch("pymc3.sampling.NDArray") as nd: - pm.sampling._choose_backend(None, "chain") + pm.sampling._choose_backend(None) assert nd.called def test_choose_backend_list_of_variables(self): with mock.patch("pymc3.sampling.NDArray") as nd: - pm.sampling._choose_backend(["var1", "var2"], "chain") + pm.sampling._choose_backend(["var1", "var2"]) nd.assert_called_with(vars=["var1", "var2"]) + def test_errors_and_warnings(self): + with pm.Model(): + A = pm.Normal("A") + B = pm.Uniform("B") + strace = pm.sampling.NDArray(vars=[A, B]) + strace.setup(10, 0) + + with pytest.raises(ValueError, match="from existing MultiTrace"): + pm.sampling._choose_backend(trace=MultiTrace([strace])) + + strace.record({"A": 2, "B_interval__": 0.1}) + assert len(strace) == 1 + with pytest.raises(ValueError, match="Continuation of traces"): + pm.sampling._choose_backend(trace=strace) + pass + class TestSamplePPC(SeededTest): def test_normal_scalar(self):