Skip to content

Commit

Permalink
Remove support for continuation of traces
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege authored and twiecki committed Sep 24, 2021
1 parent 53e572c commit f73e933
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 48 deletions.
82 changes: 36 additions & 46 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ def sample(
draws=1000,
step=None,
init="auto",
n_init=200000,
n_init=200_000,
start=None,
trace=None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
chain_idx=0,
chains=None,
cores=None,
Expand Down Expand Up @@ -296,10 +296,9 @@ def sample(
Defaults to ``trace.point(-1))`` if there is a trace provided and model.initial_point if not
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
overwrite the default.
trace : backend, list, or MultiTrace
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain_idx : int
Chain number used to store sample in backend. If ``chains`` is greater than one, chain
numbers will start here.
Expand Down Expand Up @@ -813,7 +812,7 @@ def _sample(
start,
draws: int,
step=None,
trace=None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
tune=None,
model: Optional[Model] = None,
callback=None,
Expand All @@ -839,10 +838,9 @@ def _sample(
The number of samples to draw
step : function
Step function
trace : backend, list, or MultiTrace
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
tune : int, optional
Number of iterations to tune, if applicable (defaults to None)
model : Model (optional if in ``with`` context)
Expand Down Expand Up @@ -899,10 +897,9 @@ def iter_sample(
start : dict
Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
there is a trace provided and model.initial_point if not (defaults to empty dict)
trace : backend, list, or MultiTrace
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain : int, optional
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
will start here.
Expand Down Expand Up @@ -939,7 +936,7 @@ def _iter_sample(
draws,
step,
start=None,
trace=None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
chain=0,
tune=None,
model=None,
Expand All @@ -955,12 +952,10 @@ def _iter_sample(
step : function
Step function
start : dict, optional
Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if
there is a trace provided and model.initial_point if not (defaults to empty dict)
trace : backend, list, MultiTrace, or None
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
Starting point in parameter space (or partial point). Defaults to model.initial_point if not (defaults to empty dict)
trace : backend or list
This should be a backend instance, or a list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain : int, optional
Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers
will start here.
Expand All @@ -986,12 +981,9 @@ def _iter_sample(
if start is None:
start = {}

strace = _choose_backend(trace, chain, model=model)
strace = _choose_backend(trace, model=model)

if len(strace) > 0:
model.update_start_vals(start, strace.point(-1))
else:
model.update_start_vals(start, model.initial_point)
model.update_start_vals(start, model.initial_point)

try:
step = CompoundStep(step)
Expand Down Expand Up @@ -1258,7 +1250,7 @@ def _prepare_iter_population(
# 5. a PopulationStepper is configured for parallelized stepping

# 1. prepare a BaseTrace for each chain
traces = [_choose_backend(None, chain, model=model) for chain in chains]
traces = [_choose_backend(None, model=model) for chain in chains]
for c, strace in enumerate(traces):
# initialize the trace size and variable transforms
if len(strace) > 0:
Expand Down Expand Up @@ -1361,30 +1353,29 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
steppers[c].report._finalize(strace)


def _choose_backend(trace, chain, **kwds) -> Backend:
def _choose_backend(trace: Optional[Union[BaseTrace, List[str]]], **kwds) -> Backend:
"""Selects or creates a NDArray trace backend for a particular chain.
Parameters
----------
trace : BaseTrace, list, MultiTrace, or None
This should be a BaseTrace, list of variables to track,
or a MultiTrace object with past values.
If a MultiTrace object is given, it must contain samples for the chain number ``chain``.
trace : BaseTrace, list, or None
This should be a BaseTrace, or list of variables to track.
If None or a list of variables, the NDArray backend is used.
chain : int
Number of the chain of interest.
**kwds :
keyword arguments to forward to the backend creation
Returns
-------
trace : BaseTrace
A trace object for the selected chain
The incoming, or a brand new trace object.
"""
if isinstance(trace, BaseTrace) and len(trace) > 0:
raise ValueError("Continuation of traces is no longer supported.")
if isinstance(trace, MultiTrace):
raise ValueError("Starting from existing MultiTrace objects is no longer supported.")

if isinstance(trace, BaseTrace):
return trace
if isinstance(trace, MultiTrace):
return trace._straces[chain]
if trace is None:
return NDArray(**kwds)

Expand All @@ -1401,7 +1392,7 @@ def _mp_sample(
random_seed: list,
start: list,
progressbar=True,
trace=None,
trace: Optional[Union[BaseTrace, List[str]]] = None,
model=None,
callback=None,
discard_tuned_samples=True,
Expand Down Expand Up @@ -1430,10 +1421,9 @@ def _mp_sample(
Starting points for each chain.
progressbar : bool
Whether or not to display a progress bar in the command line.
trace : BaseTrace, list, MultiTrace or None
This should be a backend instance, a list of variables to track, or a MultiTrace object
with past values. If a MultiTrace object is given, it must contain samples for the chain
number ``chain``. If None or a list of variables, the NDArray backend is used.
trace : BaseTrace, list, or None
This should be a backend instance, or a list of variables to track
If None or a list of variables, the NDArray backend is used.
model : Model (optional if in ``with`` context)
callback : Callable
A function which gets called for every sample from the trace of a chain. The function is
Expand All @@ -1455,10 +1445,10 @@ def _mp_sample(
traces = []
for idx in range(chain, chain + chains):
if trace is not None:
strace = _choose_backend(copy(trace), idx, model=model)
strace = _choose_backend(copy(trace), model=model)
else:
strace = _choose_backend(None, idx, model=model)
# for user supply start value, fill-in missing value if the supplied
strace = _choose_backend(None, model=model)
# for user supplied start value, fill-in missing value if the supplied
# dict does not contain all parameters
model.update_start_vals(start[idx - chain], model.initial_point)
if step.generates_stats and strace.supports_sampler_stats:
Expand Down
21 changes: 19 additions & 2 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import pymc3 as pm

from pymc3.aesaraf import compile_rv_inplace
from pymc3.backends.base import MultiTrace
from pymc3.backends.ndarray import NDArray
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
from pymc3.tests.helpers import SeededTest
Expand Down Expand Up @@ -438,14 +439,30 @@ def test_constant_named(self):
class TestChooseBackend:
def test_choose_backend_none(self):
with mock.patch("pymc3.sampling.NDArray") as nd:
pm.sampling._choose_backend(None, "chain")
pm.sampling._choose_backend(None)
assert nd.called

def test_choose_backend_list_of_variables(self):
with mock.patch("pymc3.sampling.NDArray") as nd:
pm.sampling._choose_backend(["var1", "var2"], "chain")
pm.sampling._choose_backend(["var1", "var2"])
nd.assert_called_with(vars=["var1", "var2"])

def test_errors_and_warnings(self):
with pm.Model():
A = pm.Normal("A")
B = pm.Uniform("B")
strace = pm.sampling.NDArray(vars=[A, B])
strace.setup(10, 0)

with pytest.raises(ValueError, match="from existing MultiTrace"):
pm.sampling._choose_backend(trace=MultiTrace([strace]))

strace.record({"A": 2, "B_interval__": 0.1})
assert len(strace) == 1
with pytest.raises(ValueError, match="Continuation of traces"):
pm.sampling._choose_backend(trace=strace)
pass


class TestSamplePPC(SeededTest):
def test_normal_scalar(self):
Expand Down

0 comments on commit f73e933

Please sign in to comment.