diff --git a/pymc3/step_methods/arraystep.py b/pymc3/step_methods/arraystep.py index 4a367424023..4673b19234e 100644 --- a/pymc3/step_methods/arraystep.py +++ b/pymc3/step_methods/arraystep.py @@ -134,7 +134,7 @@ class ArrayStep(BlockedStep): Parameters ---------- vars: list - List of variables for sampler. + List of value variables for sampler. fs: list of logp Aesara functions allvars: Boolean (default False) blocked: Boolean (default True) @@ -190,7 +190,7 @@ def __init__(self, vars, shared, blocked=True): """ Parameters ---------- - vars: list of sampling variables + vars: list of sampling value variables shared: dict of Aesara variable -> shared variable blocked: Boolean (default True) """ @@ -235,7 +235,7 @@ def __init__(self, vars, shared, blocked=True): """ Parameters ---------- - vars: list of sampling variables + vars: list of sampling value variables shared: dict of Aesara variable -> shared variable blocked: Boolean (default True) """ diff --git a/pymc3/step_methods/elliptical_slice.py b/pymc3/step_methods/elliptical_slice.py index ea88d716598..dfc24631cfc 100644 --- a/pymc3/step_methods/elliptical_slice.py +++ b/pymc3/step_methods/elliptical_slice.py @@ -16,6 +16,7 @@ import numpy as np import numpy.random as nr +from pymc3.aesaraf import inputvars from pymc3.model import modelcontext from pymc3.step_methods.arraystep import ArrayStep, Competence @@ -61,7 +62,7 @@ class EllipticalSlice(ArrayStep): Parameters ---------- vars: list - List of variables for sampler. + List of value variables for sampler. prior_cov: array, optional Covariance matrix of the multivariate Gaussian prior. prior_chol: array, optional @@ -88,6 +89,8 @@ def __init__(self, vars=None, prior_cov=None, prior_chol=None, model=None, **kwa if vars is None: vars = self.model.cont_vars + else: + vars = [self.model.rvs_to_values.get(var, var) for var in vars] vars = inputvars(vars) super().__init__(vars, [self.model.fastlogp], **kwargs) diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py index aaaaa9f4b2c..074d7e3721f 100644 --- a/pymc3/step_methods/hmc/base_hmc.py +++ b/pymc3/step_methods/hmc/base_hmc.py @@ -89,6 +89,8 @@ def __init__( if vars is None: vars = self._model.cont_vars + else: + vars = [self._model.rvs_to_values.get(var, var) for var in vars] super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **aesara_kwargs) diff --git a/pymc3/step_methods/hmc/hmc.py b/pymc3/step_methods/hmc/hmc.py index 1cc8ef335a2..4f3c42c7cab 100644 --- a/pymc3/step_methods/hmc/hmc.py +++ b/pymc3/step_methods/hmc/hmc.py @@ -60,7 +60,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs): Parameters ---------- vars: list, default=None - List of Aesara variables. If None, all continuous RVs from the + List of value variables. If None, all continuous RVs from the model are included. path_length: float, default=2 Total length to travel diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py index 267c20659fa..210608e2b0f 100644 --- a/pymc3/step_methods/hmc/nuts.py +++ b/pymc3/step_methods/hmc/nuts.py @@ -115,7 +115,7 @@ def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs) Parameters ---------- vars: list, default=None - List of Aesara variables. If None, all continuous RVs from the + List of value variables. If None, all continuous RVs from the model are included. Emax: float, default 1000 Maximum energy change allowed during leapfrog steps. Larger diff --git a/pymc3/step_methods/metropolis.py b/pymc3/step_methods/metropolis.py index 24b88f7ee89..0d5b9c4b81b 100644 --- a/pymc3/step_methods/metropolis.py +++ b/pymc3/step_methods/metropolis.py @@ -130,7 +130,7 @@ def __init__( Parameters ---------- vars: list - List of variables for sampler + List of value variables for sampler S: standard deviation or covariance matrix Some measure of variance to parameterize proposal distribution proposal_dist: function @@ -153,6 +153,8 @@ def __init__( if vars is None: vars = model.value_vars + else: + vars = [model.rvs_to_values.get(var, var) for var in vars] vars = pm.inputvars(vars) if S is None: @@ -288,7 +290,7 @@ class BinaryMetropolis(ArrayStep): Parameters ---------- vars: list - List of variables for sampler + List of value variables for sampler scaling: scalar or array Initial scale factor for proposal. Defaults to 1. tune: bool @@ -321,6 +323,8 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None): self.steps_until_tune = tune_interval self.accepted = 0 + vars = [model.rvs_to_values.get(var, var) for var in vars] + if not all([v.dtype in pm.discrete_types for v in vars]): raise ValueError("All variables must be Bernoulli for BinaryMetropolis") @@ -388,7 +392,7 @@ class BinaryGibbsMetropolis(ArrayStep): Parameters ---------- vars: list - List of variables for sampler + List of value variables for sampler order: list or 'random' List of integers indicating the Gibbs update order e.g., [0, 2, 1, ...]. Default is random @@ -410,6 +414,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None): self.transit_p = transit_p initial_point = model.initial_point + vars = [model.rvs_to_values.get(var, var) for var in vars] self.dim = sum(initial_point[v.name].size for v in vars) if order == "random": @@ -490,6 +495,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None): model = pm.modelcontext(model) + vars = [model.rvs_to_values.get(var, var) for var in vars] vars = pm.inputvars(vars) initial_point = model.initial_point @@ -697,6 +703,8 @@ def __init__( if vars is None: vars = model.cont_vars + else: + vars = [model.rvs_to_values.get(var, var) for var in vars] vars = pm.inputvars(vars) if S is None: @@ -846,6 +854,8 @@ def __init__( if vars is None: vars = model.cont_vars + else: + vars = [model.rvs_to_values.get(var, var) for var in vars] vars = pm.inputvars(vars) if S is None: diff --git a/pymc3/step_methods/mlda.py b/pymc3/step_methods/mlda.py index 0b9dcdc8f28..c2cf6493370 100644 --- a/pymc3/step_methods/mlda.py +++ b/pymc3/step_methods/mlda.py @@ -74,6 +74,8 @@ def __init__(self, *args, **kwargs): value_vars = kwargs.get("vars", None) if value_vars is None: value_vars = model.value_vars + else: + value_vars = [model.rvs_to_values.get(var, var) for var in value_vars] value_vars = pm.inputvars(value_vars) shared = pm.make_shared_replacements(initial_values, value_vars, model) @@ -142,6 +144,8 @@ def __init__(self, *args, **kwargs): value_vars = kwargs.get("vars", None) if value_vars is None: value_vars = model.value_vars + else: + value_vars = [model.rvs_to_values.get(var, var) for var in value_vars] value_vars = pm.inputvars(value_vars) shared = pm.make_shared_replacements(initial_values, value_vars, model) @@ -218,7 +222,7 @@ class MLDA(ArrayStepShared): Note this list excludes the model passed to the model argument above, which is the finest available. vars : list - List of variables for sampler + List of value variables for sampler base_sampler : string Sampler used in the base (coarsest) chain. Can be 'Metropolis' or 'DEMetropolisZ'. Defaults to 'DEMetropolisZ'. @@ -549,6 +553,8 @@ def __init__( # Process model variables if value_vars is None: value_vars = model.value_vars + else: + value_vars = [model.rvs_to_values.get(var, var) for var in value_vars] value_vars = pm.inputvars(value_vars) self.vars = value_vars self.var_names = [var.name for var in self.vars] diff --git a/pymc3/step_methods/pgbart.py b/pymc3/step_methods/pgbart.py index 351f1ae8a26..6c556be95b0 100644 --- a/pymc3/step_methods/pgbart.py +++ b/pymc3/step_methods/pgbart.py @@ -39,7 +39,7 @@ class PGBART(ArrayStepShared): Parameters ---------- vars: list - List of variables for sampler + List of value variables for sampler num_particles : int Number of particles for the conditional SMC sampler. Defaults to 10 max_stages : int diff --git a/pymc3/step_methods/sgmcmc.py b/pymc3/step_methods/sgmcmc.py index 800c2da540c..19308eb6693 100644 --- a/pymc3/step_methods/sgmcmc.py +++ b/pymc3/step_methods/sgmcmc.py @@ -87,7 +87,7 @@ class BaseStochasticGradient(ArrayStepShared): Parameters ---------- vars: list - List of variables for sampler + List of value variables for sampler batch_size`: int Batch Size for each step total_size: int @@ -132,6 +132,8 @@ def __init__( if vars is None: vars = model.value_vars + else: + vars = [model.rvs_to_values.get(var, var) for var in vars] vars = inputvars(vars) diff --git a/pymc3/step_methods/slicer.py b/pymc3/step_methods/slicer.py index 5651d6e78ac..6074c9cc421 100644 --- a/pymc3/step_methods/slicer.py +++ b/pymc3/step_methods/slicer.py @@ -35,7 +35,7 @@ class Slice(ArrayStep): Parameters ---------- vars: list - List of variables for sampler. + List of value variables for sampler. w: float Initial width of slice (Defaults to 1). tune: bool @@ -57,6 +57,8 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, * if vars is None: vars = self.model.cont_vars + else: + vars = [self.model.rvs_to_values.get(var, var) for var in vars] vars = inputvars(vars) super().__init__(vars, [self.model.fastlogp], **kwargs) diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index c3893b939fd..38567eb23aa 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -1767,3 +1767,46 @@ def perform(self, node, inputs, outputs): ) assert Q_1_0.mean(axis=1) == 0.0 assert Q_2_1.mean(axis=1) == 0.0 + + +def test_rvs_assignment(): + cont_steps = ( + # MLDA, # TODO + NUTS, + DEMetropolis, + DEMetropolisZ, + HamiltonianMC, + Metropolis, + (EllipticalSlice, {"prior_cov": np.eye(1)}), + Slice, + ) + + disc_steps = ( + BinaryGibbsMetropolis, + CategoricalGibbsMetropolis, + ) + + with Model() as m: + c1 = HalfNormal("c1") + c2 = HalfNormal("c2") + d1 = Bernoulli("d1", p=0.5) + d2 = Bernoulli("d2", p=0.5) + + for step in cont_steps: + with m: + if isinstance(step, tuple): + step, step_kwargs = step + else: + step_kwargs = {} + + assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars + assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set( + step([c1, c2], **step_kwargs).vars + ) + + for step in disc_steps: + with m: + assert [m.rvs_to_values[d1]] == step([d1]).vars + assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set( + step([d1, d2], **step_kwargs).vars + )