Skip to content

Commit

Permalink
Convert RVs to value vars in step methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 22, 2021
1 parent d926746 commit 1b2afa0
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 13 deletions.
6 changes: 3 additions & 3 deletions pymc3/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class ArrayStep(BlockedStep):
Parameters
----------
vars: list
List of variables for sampler.
List of value variables for sampler.
fs: list of logp Aesara functions
allvars: Boolean (default False)
blocked: Boolean (default True)
Expand Down Expand Up @@ -190,7 +190,7 @@ def __init__(self, vars, shared, blocked=True):
"""
Parameters
----------
vars: list of sampling variables
vars: list of sampling value variables
shared: dict of Aesara variable -> shared variable
blocked: Boolean (default True)
"""
Expand Down Expand Up @@ -235,7 +235,7 @@ def __init__(self, vars, shared, blocked=True):
"""
Parameters
----------
vars: list of sampling variables
vars: list of sampling value variables
shared: dict of Aesara variable -> shared variable
blocked: Boolean (default True)
"""
Expand Down
5 changes: 4 additions & 1 deletion pymc3/step_methods/elliptical_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import numpy.random as nr

from pymc3.aesaraf import inputvars
from pymc3.model import modelcontext
from pymc3.step_methods.arraystep import ArrayStep, Competence

Expand Down Expand Up @@ -61,7 +62,7 @@ class EllipticalSlice(ArrayStep):
Parameters
----------
vars: list
List of variables for sampler.
List of value variables for sampler.
prior_cov: array, optional
Covariance matrix of the multivariate Gaussian prior.
prior_chol: array, optional
Expand All @@ -88,6 +89,8 @@ def __init__(self, vars=None, prior_cov=None, prior_chol=None, model=None, **kwa

if vars is None:
vars = self.model.cont_vars
else:
vars = [self.model.rvs_to_values.get(var, var) for var in vars]
vars = inputvars(vars)

super().__init__(vars, [self.model.fastlogp], **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions pymc3/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def __init__(

if vars is None:
vars = self._model.cont_vars
else:
vars = [self._model.rvs_to_values.get(var, var) for var in vars]

super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **aesara_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion pymc3/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
Parameters
----------
vars: list, default=None
List of Aesara variables. If None, all continuous RVs from the
List of value variables. If None, all continuous RVs from the
model are included.
path_length: float, default=2
Total length to travel
Expand Down
2 changes: 1 addition & 1 deletion pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs)
Parameters
----------
vars: list, default=None
List of Aesara variables. If None, all continuous RVs from the
List of value variables. If None, all continuous RVs from the
model are included.
Emax: float, default 1000
Maximum energy change allowed during leapfrog steps. Larger
Expand Down
16 changes: 13 additions & 3 deletions pymc3/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
Parameters
----------
vars: list
List of variables for sampler
List of value variables for sampler
S: standard deviation or covariance matrix
Some measure of variance to parameterize proposal distribution
proposal_dist: function
Expand All @@ -153,6 +153,8 @@ def __init__(

if vars is None:
vars = model.value_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = pm.inputvars(vars)

if S is None:
Expand Down Expand Up @@ -288,7 +290,7 @@ class BinaryMetropolis(ArrayStep):
Parameters
----------
vars: list
List of variables for sampler
List of value variables for sampler
scaling: scalar or array
Initial scale factor for proposal. Defaults to 1.
tune: bool
Expand Down Expand Up @@ -321,6 +323,8 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
self.steps_until_tune = tune_interval
self.accepted = 0

vars = [model.rvs_to_values.get(var, var) for var in vars]

if not all([v.dtype in pm.discrete_types for v in vars]):
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")

Expand Down Expand Up @@ -388,7 +392,7 @@ class BinaryGibbsMetropolis(ArrayStep):
Parameters
----------
vars: list
List of variables for sampler
List of value variables for sampler
order: list or 'random'
List of integers indicating the Gibbs update order
e.g., [0, 2, 1, ...]. Default is random
Expand All @@ -410,6 +414,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
self.transit_p = transit_p

initial_point = model.initial_point
vars = [model.rvs_to_values.get(var, var) for var in vars]
self.dim = sum(initial_point[v.name].size for v in vars)

if order == "random":
Expand Down Expand Up @@ -490,6 +495,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):

model = pm.modelcontext(model)

vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = pm.inputvars(vars)

initial_point = model.initial_point
Expand Down Expand Up @@ -697,6 +703,8 @@ def __init__(

if vars is None:
vars = model.cont_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = pm.inputvars(vars)

if S is None:
Expand Down Expand Up @@ -846,6 +854,8 @@ def __init__(

if vars is None:
vars = model.cont_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = pm.inputvars(vars)

if S is None:
Expand Down
8 changes: 7 additions & 1 deletion pymc3/step_methods/mlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(self, *args, **kwargs):
value_vars = kwargs.get("vars", None)
if value_vars is None:
value_vars = model.value_vars
else:
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
value_vars = pm.inputvars(value_vars)
shared = pm.make_shared_replacements(initial_values, value_vars, model)

Expand Down Expand Up @@ -142,6 +144,8 @@ def __init__(self, *args, **kwargs):
value_vars = kwargs.get("vars", None)
if value_vars is None:
value_vars = model.value_vars
else:
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
value_vars = pm.inputvars(value_vars)
shared = pm.make_shared_replacements(initial_values, value_vars, model)

Expand Down Expand Up @@ -218,7 +222,7 @@ class MLDA(ArrayStepShared):
Note this list excludes the model passed to the model
argument above, which is the finest available.
vars : list
List of variables for sampler
List of value variables for sampler
base_sampler : string
Sampler used in the base (coarsest) chain. Can be 'Metropolis' or
'DEMetropolisZ'. Defaults to 'DEMetropolisZ'.
Expand Down Expand Up @@ -549,6 +553,8 @@ def __init__(
# Process model variables
if value_vars is None:
value_vars = model.value_vars
else:
value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
value_vars = pm.inputvars(value_vars)
self.vars = value_vars
self.var_names = [var.name for var in self.vars]
Expand Down
2 changes: 1 addition & 1 deletion pymc3/step_methods/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class PGBART(ArrayStepShared):
Parameters
----------
vars: list
List of variables for sampler
List of value variables for sampler
num_particles : int
Number of particles for the conditional SMC sampler. Defaults to 10
max_stages : int
Expand Down
4 changes: 3 additions & 1 deletion pymc3/step_methods/sgmcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class BaseStochasticGradient(ArrayStepShared):
Parameters
----------
vars: list
List of variables for sampler
List of value variables for sampler
batch_size`: int
Batch Size for each step
total_size: int
Expand Down Expand Up @@ -132,6 +132,8 @@ def __init__(

if vars is None:
vars = model.value_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]

vars = inputvars(vars)

Expand Down
4 changes: 3 additions & 1 deletion pymc3/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Slice(ArrayStep):
Parameters
----------
vars: list
List of variables for sampler.
List of value variables for sampler.
w: float
Initial width of slice (Defaults to 1).
tune: bool
Expand All @@ -57,6 +57,8 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, *

if vars is None:
vars = self.model.cont_vars
else:
vars = [self.model.rvs_to_values.get(var, var) for var in vars]
vars = inputvars(vars)

super().__init__(vars, [self.model.fastlogp], **kwargs)
Expand Down
43 changes: 43 additions & 0 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,3 +1767,46 @@ def perform(self, node, inputs, outputs):
)
assert Q_1_0.mean(axis=1) == 0.0
assert Q_2_1.mean(axis=1) == 0.0


def test_rvs_assignment():
cont_steps = (
# MLDA, # TODO
NUTS,
DEMetropolis,
DEMetropolisZ,
HamiltonianMC,
Metropolis,
(EllipticalSlice, {"prior_cov": np.eye(1)}),
Slice,
)

disc_steps = (
BinaryGibbsMetropolis,
CategoricalGibbsMetropolis,
)

with Model() as m:
c1 = HalfNormal("c1")
c2 = HalfNormal("c2")
d1 = Bernoulli("d1", p=0.5)
d2 = Bernoulli("d2", p=0.5)

for step in cont_steps:
with m:
if isinstance(step, tuple):
step, step_kwargs = step
else:
step_kwargs = {}

assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
step([c1, c2], **step_kwargs).vars
)

for step in disc_steps:
with m:
assert [m.rvs_to_values[d1]] == step([d1]).vars
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(
step([d1, d2], **step_kwargs).vars
)

0 comments on commit 1b2afa0

Please sign in to comment.