Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove support for continuation of traces #5019

Merged
merged 2 commits into from
Sep 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 36 additions & 46 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ def sample(
draws=1000,
step=None,
init="auto",
n_init=200000,
n_init=200_000,
start=None,
trace=None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
chain_idx=0,
chains=None,
cores=None,
Expand Down Expand Up @@ -296,10 +296,9 @@ def sample(
Defaults to ``trace.point(-1)`` if there is a trace provided and model.initial_point if not
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
overwrite the default.
trace : backend, list, or MultiTrace
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain_idx : int
Chain number used to store sample in backend. If ``chains`` is greater than one, chain
numbers will start here.
Expand Down Expand Up @@ -813,7 +812,7 @@ def _sample(
start,
draws: int,
step=None,
trace=None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
tune=None,
model: Optional[Model] = None,
callback=None,
Expand All @@ -839,10 +838,9 @@ def _sample(
The number of samples to draw
step : function
Step function
trace : backend, list, or MultiTrace
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
tune : int, optional
Number of iterations to tune, if applicable (defaults to None)
model : Model (optional if in ``with`` context)
Expand Down Expand Up @@ -899,10 +897,9 @@ def iter_sample(
start : dict
Starting point in parameter space (or partial point). Defaults to trace.point(-1) if
there is a trace provided and model.initial_point if not (defaults to empty dict)
trace : backend, list, or MultiTrace
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain : int, optional
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
will start here.
Expand Down Expand Up @@ -939,7 +936,7 @@ def _iter_sample(
draws,
step,
start=None,
trace=None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
chain=0,
tune=None,
model=None,
Expand All @@ -955,12 +952,10 @@ def _iter_sample(
step : function
Step function
start : dict, optional
Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
there is a trace provided and model.initial_point if not (defaults to empty dict)
trace : backend, list, MultiTrace, or None
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
Starting point in parameter space (or partial point). If not provided, an empty dict is used and missing values are filled in from ``model.initial_point``.
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain : int, optional
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
will start here.
Expand All @@ -986,12 +981,9 @@ def _iter_sample(
if start is None:
start = {}

strace = _choose_backend(trace, chain, model=model)
strace = _choose_backend(trace, model=model)

if len(strace) > 0:
model.update_start_vals(start, strace.point(-1))
else:
model.update_start_vals(start, model.initial_point)
model.update_start_vals(start, model.initial_point)

try:
step = CompoundStep(step)
Expand Down Expand Up @@ -1258,7 +1250,7 @@ def _prepare_iter_population(
# 5. a PopulationStepper is configured for parallelized stepping

# 1. prepare a BaseTrace for each chain
traces = [_choose_backend(None, chain, model=model) for chain in chains]
traces = [_choose_backend(None, model=model) for chain in chains]
for c, strace in enumerate(traces):
# initialize the trace size and variable transforms
if len(strace) > 0:
Expand Down Expand Up @@ -1361,30 +1353,29 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
steppers[c].report._finalize(strace)


def _choose_backend(trace, chain, **kwds) -> Backend:
def _choose_backend(trace: Optional[Union[BaseTrace, List[str]]], **kwds) -> Backend:
"""Selects or creates an NDArray trace backend.

Parameters
----------
trace : BaseTrace, list, MultiTrace, or None
This should be a BaseTrace, list of variables to track,
or a MultiTrace object with past values.
If a MultiTrace object is given, it must contain samples for the chain number ``chain``.
trace : BaseTrace, list, or None
This should be a BaseTrace, or list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain : int
Number of the chain of interest.
**kwds :
keyword arguments to forward to the backend creation

Returns
-------
trace : BaseTrace
A trace object for the selected chain
The incoming trace, or a brand new trace object.
"""
if isinstance(trace, BaseTrace) and len(trace) > 0:
raise ValueError("Continuation of traces is no longer supported.")
if isinstance(trace, MultiTrace):
raise ValueError("Starting from existing MultiTrace objects is no longer supported.")

if isinstance(trace, BaseTrace):
return trace
if isinstance(trace, MultiTrace):
return trace._straces[chain]
if trace is None:
return NDArray(**kwds)

Expand All @@ -1401,7 +1392,7 @@ def _mp_sample(
random_seed: list,
start: list,
progressbar=True,
trace=None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
model=None,
callback=None,
discard_tuned_samples=True,
Expand Down Expand Up @@ -1430,10 +1421,9 @@ def _mp_sample(
Starting points for each chain.
progressbar : bool
Whether or not to display a progress bar in the command line.
trace : BaseTrace, list, MultiTrace or None
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
trace : BaseTrace, list, or None
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
model : Model (optional if in ``with`` context)
callback : Callable
A function which gets called for every sample from the trace of a chain. The function is
Expand All @@ -1455,10 +1445,10 @@ def _mp_sample(
traces = []
for idx in range(chain, chain + chains):
if trace is not None:
strace = _choose_backend(copy(trace), idx, model=model)
strace = _choose_backend(copy(trace), model=model)
else:
strace = _choose_backend(None, idx, model=model)
# for user supply start value, fill-in missing value if the supplied
strace = _choose_backend(None, model=model)
# for user supplied start value, fill-in missing value if the supplied
# dict does not contain all parameters
model.update_start_vals(start[idx - chain], model.initial_point)
if step.generates_stats and strace.supports_sampler_stats:
Expand Down
3 changes: 3 additions & 0 deletions pymc3/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ def build_models(self):

return model, coarse_models

@pytest.mark.skip(
reason="MLDA needs to be refactored to no longer depend on trace continuation. See #5021."
)
def test_run(self):
model, coarse_models = self.build_models()

Expand Down
21 changes: 19 additions & 2 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import pymc3 as pm

from pymc3.aesaraf import compile_rv_inplace
from pymc3.backends.base import MultiTrace
from pymc3.backends.ndarray import NDArray
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
from pymc3.tests.helpers import SeededTest
Expand Down Expand Up @@ -438,14 +439,30 @@ def test_constant_named(self):
class TestChooseBackend:
def test_choose_backend_none(self):
with mock.patch("pymc3.sampling.NDArray") as nd:
pm.sampling._choose_backend(None, "chain")
pm.sampling._choose_backend(None)
assert nd.called

def test_choose_backend_list_of_variables(self):
with mock.patch("pymc3.sampling.NDArray") as nd:
pm.sampling._choose_backend(["var1", "var2"], "chain")
pm.sampling._choose_backend(["var1", "var2"])
nd.assert_called_with(vars=["var1", "var2"])

def test_errors_and_warnings(self):
with pm.Model():
A = pm.Normal("A")
B = pm.Uniform("B")
strace = pm.sampling.NDArray(vars=[A, B])
strace.setup(10, 0)

with pytest.raises(ValueError, match="from existing MultiTrace"):
pm.sampling._choose_backend(trace=MultiTrace([strace]))

strace.record({"A": 2, "B_interval__": 0.1})
assert len(strace) == 1
with pytest.raises(ValueError, match="Continuation of traces"):
pm.sampling._choose_backend(trace=strace)
pass


class TestSamplePPC(SeededTest):
def test_normal_scalar(self):
Expand Down
14 changes: 9 additions & 5 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,11 +579,12 @@ def test_step_continuous(self):
HamiltonianMC(scaling=C, is_cov=True, blocked=False),
]
),
MLDA(
coarse_models=[model_coarse],
base_S=C,
base_proposal_dist=MultivariateNormalProposal,
),
# NOTE: The MLDA uses the trace continuation which was removed.
# MLDA(
# coarse_models=[model_coarse],
# base_S=C,
# base_proposal_dist=MultivariateNormalProposal,
# ),
)
for step in steps:
idata = sample(
Expand Down Expand Up @@ -1038,6 +1039,9 @@ def test_sampler_stats(self):
assert (trace.model_logp == model_logp_).all()


@pytest.mark.skip(
reason="MLDA needs to be refactored to no longer depend on trace continuation. See #5021."
)
class TestMLDA:
steppers = [MLDA]

Expand Down