diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 47cbcac381d..10066c18908 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -14,6 +14,7 @@ """Functions for MCMC sampling.""" +import collections.abc as abc import logging import pickle import sys @@ -21,11 +22,8 @@ import warnings from collections import defaultdict -from collections.abc import Iterable from copy import copy -from typing import Any, Dict -from typing import Iterable as TIterable -from typing import List, Optional, Union, cast +from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast import arviz import numpy as np @@ -57,8 +55,8 @@ HamiltonianMC, Metropolis, Slice, - arraystep, ) +from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc3.step_methods.hmc import quadpotential from pymc3.util import ( chains_and_samples, @@ -93,15 +91,19 @@ CategoricalGibbsMetropolis, PGBART, ) +Step = Union[BlockedStep, CompoundStep] ArrayLike = Union[np.ndarray, List[float]] PointType = Dict[str, np.ndarray] PointList = List[PointType] +Backend = Union[BaseTrace, MultiTrace, NDArray] _log = logging.getLogger("pymc3") -def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None): +def instantiate_steppers( + _model, steps: List[Step], selected_steps, step_kwargs=None +) -> Union[Step, List[Step]]: """Instantiate steppers assigned to the model variables. This function is intended to be called automatically from ``sample()``, but @@ -142,7 +144,7 @@ def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None): raise ValueError("Unused step method arguments: %s" % unused_args) if len(steps) == 1: - steps = steps[0] + return steps[0] return steps @@ -216,7 +218,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None return instantiate_steppers(model, steps, selected_steps, step_kwargs) -def _print_step_hierarchy(s, level=0): +def _print_step_hierarchy(s: Step, level=0) -> None: if isinstance(s, CompoundStep): _log.info(">" * level + "CompoundStep") for i in s.methods: @@ -447,7 +449,7 @@ def sample( if random_seed is not None: np.random.seed(random_seed) random_seed = [np.random.randint(2 ** 30) for _ in range(chains)] - if not isinstance(random_seed, Iterable): + if not isinstance(random_seed, abc.Iterable): raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int") if not discard_tuned_samples and not return_inferencedata: @@ -542,7 +544,7 @@ def sample( has_population_samplers = np.any( [ - isinstance(m, arraystep.PopulationArrayStepShared) + isinstance(m, PopulationArrayStepShared) for m in (step.methods if isinstance(step, CompoundStep) else [step]) ] ) @@ -706,7 +708,7 @@ def _sample_many( trace: MultiTrace Contains samples of all chains """ - traces = [] + traces: List[Backend] = [] for i in range(chains): trace = _sample( draws=draws, @@ -1140,7 +1142,7 @@ def _run_secondary(c, stepper_dumps, secondary_end): # has to be updated, therefore we identify the substeppers first. population_steppers = [] for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]: - if isinstance(sm, arraystep.PopulationArrayStepShared): + if isinstance(sm, PopulationArrayStepShared): population_steppers.append(sm) while True: incoming = secondary_end.recv() @@ -1259,7 +1261,7 @@ def _prepare_iter_population( population = [Point(start[c], model=model) for c in range(nchains)] # 3. Set up the steppers - steppers = [None] * nchains + steppers: List[Step] = [] for c in range(nchains): # need indepenent samplers for each chain # it is important to copy the actual steppers (but not the delta_logp) @@ -1269,9 +1271,9 @@ def _prepare_iter_population( chainstep = copy(step) # link population samplers to the shared population state for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]: - if isinstance(sm, arraystep.PopulationArrayStepShared): + if isinstance(sm, PopulationArrayStepShared): sm.link_population(population, c) - steppers[c] = chainstep + steppers.append(chainstep) # 4. configure tracking of sampler stats for c in range(nchains): @@ -1349,7 +1351,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points): steppers[c].report._finalize(strace) -def _choose_backend(trace, chain, **kwds): +def _choose_backend(trace, chain, **kwds) -> Backend: """Selects or creates a NDArray trace backend for a particular chain. Parameters @@ -1562,8 +1564,8 @@ class _DefaultTrace: `insert()` method """ - trace_dict = {} # type: Dict[str, np.ndarray] - _len = None # type: int + trace_dict: Dict[str, np.ndarray] = {} + _len: Optional[int] = None def __init__(self, samples: int): self._len = samples @@ -1600,7 +1602,7 @@ def sample_posterior_predictive( trace, samples: Optional[int] = None, model: Optional[Model] = None, - vars: Optional[TIterable[Tensor]] = None, + vars: Optional[Iterable[Tensor]] = None, var_names: Optional[List[str]] = None, size: Optional[int] = None, keep_size: Optional[bool] = False, @@ -1885,8 +1887,7 @@ def sample_posterior_predictive_w( def sample_prior_predictive( samples=500, model: Optional[Model] = None, - vars: Optional[TIterable[str]] = None, - var_names: Optional[TIterable[str]] = None, + var_names: Optional[Iterable[str]] = None, random_seed=None, ) -> Dict[str, np.ndarray]: """Generate samples from the prior predictive distribution. @@ -1896,9 +1897,6 @@ def sample_prior_predictive( samples : int Number of samples from the prior predictive to generate. Defaults to 500. model : Model (optional if in ``with`` context) - vars : Iterable[str] - A list of names of variables for which to compute the posterior predictive - samples. *DEPRECATED* - Use ``var_names`` argument instead. var_names : Iterable[str] A list of names of variables for which to compute the posterior predictive samples. Defaults to both observed and unobserved RVs. @@ -1913,22 +1911,14 @@ def sample_prior_predictive( """ model = modelcontext(model) - if vars is None and var_names is None: + if var_names is None: prior_pred_vars = model.observed_RVs prior_vars = ( get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials ) - vars_ = [var.name for var in prior_vars + prior_pred_vars] - vars = set(vars_) - elif vars is None: - vars = var_names - vars_ = vars - elif vars is not None: - warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning) - vars_ = vars + vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars} else: - raise ValueError("Cannot supply both vars and var_names arguments.") - vars = cast(TIterable[str], vars) # tell mypy that vars cannot be None here. + vars_ = set(var_names) if random_seed is not None: np.random.seed(random_seed) @@ -1940,8 +1930,8 @@ def sample_prior_predictive( if data is None: raise AssertionError("No variables sampled: attempting to sample %s" % names) - prior = {} # type: Dict[str, np.ndarray] - for var_name in vars: + prior: Dict[str, np.ndarray] = {} + for var_name in vars_: if var_name in data: prior[var_name] = data[var_name] elif is_transformed_name(var_name): @@ -2093,7 +2083,7 @@ def init_nuts( var = np.ones_like(mean) potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10) elif init == "advi+adapt_diag_grad": - approx = pm.fit( + approx: pm.MeanField = pm.fit( random_seed=random_seed, n=n_init, method="advi", @@ -2101,7 +2091,7 @@ def init_nuts( callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, - ) # type: pm.MeanField + ) start = approx.sample(draws=chains) start = list(start) stds = approx.bij.rmap(approx.std.eval()) @@ -2119,7 +2109,7 @@ def init_nuts( callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, - ) # type: pm.MeanField + ) start = approx.sample(draws=chains) start = list(start) stds = approx.bij.rmap(approx.std.eval()) @@ -2137,7 +2127,7 @@ def init_nuts( callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, - ) # type: pm.MeanField + ) start = approx.sample(draws=chains) start = list(start) stds = approx.bij.rmap(approx.std.eval()) diff --git a/pymc3/step_methods/arraystep.py b/pymc3/step_methods/arraystep.py index 91eaa7d5435..c3e1cf6f8bb 100644 --- a/pymc3/step_methods/arraystep.py +++ b/pymc3/step_methods/arraystep.py @@ -13,13 +13,14 @@ # limitations under the License. from enum import IntEnum, unique +from typing import Dict, List import numpy as np from numpy.random import uniform from pymc3.blocking import ArrayOrdering, DictToArrayBijection -from pymc3.model import modelcontext +from pymc3.model import PyMC3Variable, modelcontext from pymc3.step_methods.compound import CompoundStep from pymc3.theanof import inputvars from pymc3.util import get_var_name @@ -46,6 +47,8 @@ class Competence(IntEnum): class BlockedStep: generates_stats = False + stats_dtypes: List[Dict[str, np.dtype]] = [] + vars: List[PyMC3Variable] = [] def __new__(cls, *args, **kwargs): blocked = kwargs.get("blocked") diff --git a/pymc3/step_methods/mlda.py b/pymc3/step_methods/mlda.py index d5810eec77b..f99af282524 100644 --- a/pymc3/step_methods/mlda.py +++ b/pymc3/step_methods/mlda.py @@ -356,10 +356,6 @@ class MLDA(ArrayStepShared): default_blocked = True generates_stats = True - # stat data types are different, depending on the base sampler. - # these are assigned in the init method. - stats_dtypes = None - def __init__( self, coarse_models: List[Model], diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index cb5f9806ace..8b932545c32 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -903,9 +903,8 @@ def test_respects_shape(self): with pm.Model(): mu = pm.Gamma("mu", 3, 1, shape=1) goals = pm.Poisson("goals", mu, shape=shape) - with pytest.warns(DeprecationWarning): - trace1 = pm.sample_prior_predictive(10, vars=["mu", "goals"]) - trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"]) + trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"]) + trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"]) if shape == 2: # want to test shape as an int shape = (2,) assert trace1["goals"].shape == (10,) + shape diff --git a/setup.cfg b/setup.cfg index aafb77f3b34..f380b0c4731 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,3 +11,6 @@ convention = numpy [isort] lines_between_types = 1 profile = black + +[mypy] +ignore_missing_imports = True