Fixup mypy errors in sampling.py (#4327)
* 🏷️ type sampling

* 🔥 remove deprecated vars from sample_prior_predictive
MarcoGorelli authored Dec 13, 2020
1 parent 70fdcf9 commit b386d94
Showing 5 changed files with 40 additions and 49 deletions.
72 changes: 31 additions & 41 deletions pymc3/sampling.py
@@ -14,18 +14,16 @@

"""Functions for MCMC sampling."""

import collections.abc as abc
import logging
import pickle
import sys
import time
import warnings

from collections import defaultdict
from collections.abc import Iterable
from copy import copy
from typing import Any, Dict
from typing import Iterable as TIterable
from typing import List, Optional, Union, cast
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

import arviz
import numpy as np
@@ -57,8 +55,8 @@
HamiltonianMC,
Metropolis,
Slice,
arraystep,
)
from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc3.step_methods.hmc import quadpotential
from pymc3.util import (
chains_and_samples,
@@ -93,15 +91,19 @@
CategoricalGibbsMetropolis,
PGBART,
)
Step = Union[BlockedStep, CompoundStep]

ArrayLike = Union[np.ndarray, List[float]]
PointType = Dict[str, np.ndarray]
PointList = List[PointType]
Backend = Union[BaseTrace, MultiTrace, NDArray]

_log = logging.getLogger("pymc3")


def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
def instantiate_steppers(
_model, steps: List[Step], selected_steps, step_kwargs=None
) -> Union[Step, List[Step]]:
"""Instantiate steppers assigned to the model variables.
This function is intended to be called automatically from ``sample()``, but
@@ -142,7 +144,7 @@ def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
raise ValueError("Unused step method arguments: %s" % unused_args)

if len(steps) == 1:
steps = steps[0]
return steps[0]

return steps

@@ -216,7 +218,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
return instantiate_steppers(model, steps, selected_steps, step_kwargs)


def _print_step_hierarchy(s, level=0):
def _print_step_hierarchy(s: Step, level=0) -> None:
if isinstance(s, CompoundStep):
_log.info(">" * level + "CompoundStep")
for i in s.methods:
@@ -447,7 +449,7 @@ def sample(
if random_seed is not None:
np.random.seed(random_seed)
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
if not isinstance(random_seed, Iterable):
if not isinstance(random_seed, abc.Iterable):
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")

if not discard_tuned_samples and not return_inferencedata:
@@ -542,7 +544,7 @@ def sample(

has_population_samplers = np.any(
[
isinstance(m, arraystep.PopulationArrayStepShared)
isinstance(m, PopulationArrayStepShared)
for m in (step.methods if isinstance(step, CompoundStep) else [step])
]
)
@@ -706,7 +708,7 @@ def _sample_many(
trace: MultiTrace
Contains samples of all chains
"""
traces = []
traces: List[Backend] = []
for i in range(chains):
trace = _sample(
draws=draws,
@@ -1140,7 +1142,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
# has to be updated, therefore we identify the substeppers first.
population_steppers = []
for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]:
if isinstance(sm, arraystep.PopulationArrayStepShared):
if isinstance(sm, PopulationArrayStepShared):
population_steppers.append(sm)
while True:
incoming = secondary_end.recv()
@@ -1259,7 +1261,7 @@ def _prepare_iter_population(
population = [Point(start[c], model=model) for c in range(nchains)]

# 3. Set up the steppers
steppers = [None] * nchains
steppers: List[Step] = []
for c in range(nchains):
# need independent samplers for each chain
# it is important to copy the actual steppers (but not the delta_logp)
@@ -1269,9 +1271,9 @@
chainstep = copy(step)
# link population samplers to the shared population state
for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
if isinstance(sm, arraystep.PopulationArrayStepShared):
if isinstance(sm, PopulationArrayStepShared):
sm.link_population(population, c)
steppers[c] = chainstep
steppers.append(chainstep)

# 4. configure tracking of sampler stats
for c in range(nchains):
@@ -1349,7 +1351,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
steppers[c].report._finalize(strace)


def _choose_backend(trace, chain, **kwds):
def _choose_backend(trace, chain, **kwds) -> Backend:
"""Selects or creates a NDArray trace backend for a particular chain.
Parameters
@@ -1562,8 +1564,8 @@ class _DefaultTrace:
`insert()` method
"""

trace_dict = {} # type: Dict[str, np.ndarray]
_len = None # type: int
trace_dict: Dict[str, np.ndarray] = {}
_len: Optional[int] = None

def __init__(self, samples: int):
self._len = samples
@@ -1600,7 +1602,7 @@ def sample_posterior_predictive(
trace,
samples: Optional[int] = None,
model: Optional[Model] = None,
vars: Optional[TIterable[Tensor]] = None,
vars: Optional[Iterable[Tensor]] = None,
var_names: Optional[List[str]] = None,
size: Optional[int] = None,
keep_size: Optional[bool] = False,
@@ -1885,8 +1887,7 @@ def sample_posterior_predictive_w(
def sample_prior_predictive(
samples=500,
model: Optional[Model] = None,
vars: Optional[TIterable[str]] = None,
var_names: Optional[TIterable[str]] = None,
var_names: Optional[Iterable[str]] = None,
random_seed=None,
) -> Dict[str, np.ndarray]:
"""Generate samples from the prior predictive distribution.
@@ -1896,9 +1897,6 @@ def sample_prior_predictive(
samples : int
Number of samples from the prior predictive to generate. Defaults to 500.
model : Model (optional if in ``with`` context)
vars : Iterable[str]
A list of names of variables for which to compute the posterior predictive
samples. *DEPRECATED* - Use ``var_names`` argument instead.
var_names : Iterable[str]
A list of names of variables for which to compute the posterior predictive
samples. Defaults to both observed and unobserved RVs.
Expand All @@ -1913,22 +1911,14 @@ def sample_prior_predictive(
"""
model = modelcontext(model)

if vars is None and var_names is None:
if var_names is None:
prior_pred_vars = model.observed_RVs
prior_vars = (
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
)
vars_ = [var.name for var in prior_vars + prior_pred_vars]
vars = set(vars_)
elif vars is None:
vars = var_names
vars_ = vars
elif vars is not None:
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
vars_ = vars
vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
else:
raise ValueError("Cannot supply both vars and var_names arguments.")
vars = cast(TIterable[str], vars) # tell mypy that vars cannot be None here.
vars_ = set(var_names)

if random_seed is not None:
np.random.seed(random_seed)
@@ -1940,8 +1930,8 @@
if data is None:
raise AssertionError("No variables sampled: attempting to sample %s" % names)

prior = {} # type: Dict[str, np.ndarray]
for var_name in vars:
prior: Dict[str, np.ndarray] = {}
for var_name in vars_:
if var_name in data:
prior[var_name] = data[var_name]
elif is_transformed_name(var_name):
@@ -2093,15 +2083,15 @@ def init_nuts(
var = np.ones_like(mean)
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
elif init == "advi+adapt_diag_grad":
approx = pm.fit(
approx: pm.MeanField = pm.fit(
random_seed=random_seed,
n=n_init,
method="advi",
model=model,
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
) # type: pm.MeanField
)
start = approx.sample(draws=chains)
start = list(start)
stds = approx.bij.rmap(approx.std.eval())
@@ -2119,7 +2109,7 @@
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
) # type: pm.MeanField
)
start = approx.sample(draws=chains)
start = list(start)
stds = approx.bij.rmap(approx.std.eval())
@@ -2137,7 +2127,7 @@
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
) # type: pm.MeanField
)
start = approx.sample(draws=chains)
start = list(start)
stds = approx.bij.rmap(approx.std.eval())
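For quick reference, a minimal sketch of the resulting `sample_prior_predictive` API (a toy model invented for illustration, mirroring the test change below): the deprecated `vars` keyword is gone, so variable selection goes through `var_names` only.

```python
import pymc3 as pm

with pm.Model():
    mu = pm.Gamma("mu", 3.0, 1.0)
    goals = pm.Poisson("goals", mu)
    # Only ``var_names`` selects variables now; passing the removed
    # ``vars`` keyword raises a TypeError after this commit.
    prior = pm.sample_prior_predictive(100, var_names=["mu", "goals"])

assert prior["mu"].shape == (100,)  # one draw per prior predictive sample
```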
5 changes: 4 additions & 1 deletion pymc3/step_methods/arraystep.py
@@ -13,13 +13,14 @@
# limitations under the License.

from enum import IntEnum, unique
from typing import Dict, List

import numpy as np

from numpy.random import uniform

from pymc3.blocking import ArrayOrdering, DictToArrayBijection
from pymc3.model import modelcontext
from pymc3.model import PyMC3Variable, modelcontext
from pymc3.step_methods.compound import CompoundStep
from pymc3.theanof import inputvars
from pymc3.util import get_var_name
@@ -46,6 +47,8 @@ class Competence(IntEnum):
class BlockedStep:

generates_stats = False
stats_dtypes: List[Dict[str, np.dtype]] = []
vars: List[PyMC3Variable] = []

def __new__(cls, *args, **kwargs):
blocked = kwargs.get("blocked")
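To illustrate what the new class-level annotations enable (a hedged sketch with a hypothetical subclass, not code from this commit): mypy can now check that subclasses assign values compatible with `stats_dtypes: List[Dict[str, np.dtype]]` and `vars: List[PyMC3Variable]`.

```python
import numpy as np

from pymc3.step_methods.arraystep import BlockedStep


class DemoStep(BlockedStep):
    """Hypothetical stepper, for illustration only."""

    generates_stats = True
    # mypy checks this assignment against the base-class annotation
    # ``stats_dtypes: List[Dict[str, np.dtype]]``; a bare ``None`` here
    # would now be flagged as an incompatible override.
    stats_dtypes = [{"accept": np.dtype(np.float64)}]
```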
4 changes: 0 additions & 4 deletions pymc3/step_methods/mlda.py
@@ -356,10 +356,6 @@ class MLDA(ArrayStepShared):
default_blocked = True
generates_stats = True

# stat data types are different, depending on the base sampler.
# these are assigned in the init method.
stats_dtypes = None

def __init__(
self,
coarse_models: List[Model],
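The `stats_dtypes = None` class attribute is dropped because `None` is incompatible with the base-class annotation added above; MLDA's dtypes depend on the base sampler, so they are assigned per instance in `__init__`. A hypothetical sketch of that pattern (invented names, not MLDA's actual code):

```python
import numpy as np

from pymc3.step_methods.arraystep import BlockedStep


class BaseDependentStep(BlockedStep):
    """Hypothetical stepper mirroring MLDA's per-instance assignment."""

    def __init__(self, base_sampler: str = "Metropolis"):
        # the class-level default ([]) satisfies mypy until this runs
        if base_sampler == "Metropolis":
            self.stats_dtypes = [{"accept": np.dtype(np.float64)}]
        else:
            self.stats_dtypes = [{"mean_tree_accept": np.dtype(np.float64)}]
```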
5 changes: 2 additions & 3 deletions pymc3/tests/test_sampling.py
@@ -903,9 +903,8 @@ def test_respects_shape(self):
with pm.Model():
mu = pm.Gamma("mu", 3, 1, shape=1)
goals = pm.Poisson("goals", mu, shape=shape)
with pytest.warns(DeprecationWarning):
trace1 = pm.sample_prior_predictive(10, vars=["mu", "goals"])
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"])
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
if shape == 2: # want to test shape as an int
shape = (2,)
assert trace1["goals"].shape == (10,) + shape
3 changes: 3 additions & 0 deletions setup.cfg
@@ -11,3 +11,6 @@ convention = numpy
[isort]
lines_between_types = 1
profile = black

[mypy]
ignore_missing_imports = True
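The new `[mypy]` section keeps the check usable on a codebase with untyped third-party dependencies; without it, mypy stops on those imports. Roughly what it suppresses (exact wording varies by mypy version):

```python
# Errors silenced by ignore_missing_imports = True, approximately:
import theano  # error: Skipping analysis of "theano": found module but no type hints or library stubs
import arviz   # error: Skipping analysis of "arviz": found module but no type hints or library stubs
```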
