Refactored MLDA proposal to not use trace continuation #5095

Merged 4 commits on Oct 23, 2021
118 changes: 53 additions & 65 deletions pymc/step_methods/mlda.py
@@ -15,7 +15,7 @@
import logging
import warnings

-from typing import List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

import aesara
import arviz as az
@@ -25,7 +25,8 @@

import pymc as pm

-from pymc.blocking import DictToArrayBijection
+from pymc.aesaraf import compile_rv_inplace
+from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.model import Model, Point
from pymc.step_methods.arraystep import ArrayStepShared, Competence, metrop_select
from pymc.step_methods.compound import CompoundStep
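
The new `compile_rv_inplace` import replaces direct `aesara.function` calls further down (see the `delta_logp_inverse` hunk), so that any `RandomVariable`s in the compiled graph get in-place RNG updates. A minimal sketch of the call, assuming it accepts the usual `aesara.function`-style `(inputs, outputs)` arguments shown in this diff:

```python
import aesara.tensor as at

from pymc.aesaraf import compile_rv_inplace

# Hypothetical toy graph; the point here is only the call signature.
x = at.dscalar("x")
f = compile_rv_inplace([x], 2 * x)
assert f(3.0) == 6.0
```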
@@ -66,20 +67,20 @@ def __init__(self, *args, **kwargs):
self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

# extract some necessary variables
-value_vars = kwargs.get("vars", None)
-if value_vars is None:
-value_vars = model.value_vars
+vars = kwargs.get("vars", None)
+if vars is None:
+vars = model.value_vars
else:
-value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-value_vars = pm.inputvars(value_vars)
-shared = pm.make_shared_replacements(initial_values, value_vars, model)
+vars = [model.rvs_to_values.get(var, var) for var in vars]
+vars = pm.inputvars(vars)
+shared = pm.make_shared_replacements(initial_values, vars, model)

# call parent class __init__
super().__init__(*args, **kwargs)

# modify the delta function and point to model if VR is used
if self.mlda_variance_reduction:
-self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
self.model = model

def reset_tuning(self):
@@ -136,20 +137,20 @@ def __init__(self, *args, **kwargs):
self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

# extract some necessary variables
-value_vars = kwargs.get("vars", None)
-if value_vars is None:
-value_vars = model.value_vars
+vars = kwargs.get("vars", None)
+if vars is None:
+vars = model.value_vars
else:
-value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-value_vars = pm.inputvars(value_vars)
-shared = pm.make_shared_replacements(initial_values, value_vars, model)
+vars = [model.rvs_to_values.get(var, var) for var in vars]
+vars = pm.inputvars(vars)
+shared = pm.make_shared_replacements(initial_values, vars, model)

# call parent class __init__
super().__init__(*args, **kwargs)

# modify the delta function and point to model if VR is used
if self.mlda_variance_reduction:
-self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
self.model = model

def reset_tuning(self):
@@ -363,7 +364,7 @@ class MLDA(ArrayStepShared):
def __init__(
self,
coarse_models: List[Model],
-value_vars: Optional[list] = None,
+vars: Optional[list] = None,
base_sampler="DEMetropolisZ",
base_S: Optional = None,
base_proposal_dist: Optional[Type[Proposal]] = None,
@@ -386,10 +387,6 @@ def __init__(
# this variable is used to identify MLDA objects which are
# not in the finest level (i.e. child MLDA objects)
self.is_child = kwargs.get("is_child", False)
-if not self.is_child:
-warnings.warn(
-"The MLDA implementation in PyMC is still immature. You should be particularly critical of its results."
-)

if not isinstance(coarse_models, list):
raise ValueError("MLDA step method cannot use coarse_models if it is not a list")
@@ -546,20 +543,20 @@ def __init__(
self.mode = mode

# Process model variables
-if value_vars is None:
-value_vars = model.value_vars
+if vars is None:
+vars = model.value_vars
else:
-value_vars = [model.rvs_to_values.get(var, var) for var in value_vars]
-value_vars = pm.inputvars(value_vars)
-self.vars = value_vars
+vars = [model.rvs_to_values.get(var, var) for var in vars]
+vars = pm.inputvars(vars)
+self.vars = vars
self.var_names = [var.name for var in self.vars]

self.accepted = 0

# Construct Aesara function for current-level model likelihood
# (for use in acceptance)
-shared = pm.make_shared_replacements(initial_values, value_vars, model)
-self.delta_logp = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)
+shared = pm.make_shared_replacements(initial_values, vars, model)
+self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)

# Construct Aesara function for below-level model likelihood
# (for use in acceptance)
@@ -571,7 +568,7 @@
initial_values, model_below.logpt, vars_below, shared_below
)

-super().__init__(value_vars, shared)
+super().__init__(vars, shared)

# initialise complete step method hierarchy
if self.num_levels == 2:
@@ -643,7 +640,7 @@ def __init__(

# MLDA sampler in some intermediate level, targeting self.model_below
self.step_method_below = pm.MLDA(
-value_vars=vars_below,
+vars=vars_below,
base_S=self.base_S,
base_sampler=self.base_sampler,
base_proposal_dist=self.base_proposal_dist,
@@ -715,7 +712,7 @@ def __init__(
if self.store_Q_fine and not self.is_child:
self.stats_dtypes[0][f"Q_{self.num_levels - 1}"] = object

-def astep(self, q0):
+def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
"""One MLDA step, given current sample q0"""
# Check if the tuning flag has been changed and if yes,
# change the proposal's tuning flag and reset self.accepted
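
With the new signature, `astep` receives and returns `RaveledVars`, pymc's flat-array representation of a point, so the per-step dict conversion deleted below is no longer needed. For reference, a minimal sketch (not part of the PR; the `point` values are made up) of the `DictToArrayBijection` round trip:

```python
import numpy as np

from pymc.blocking import DictToArrayBijection, RaveledVars

point = {"a": np.array(0.5), "b": np.array([1.0, 2.0])}

# map: dict of named values -> one flat array plus reconstruction metadata
q: RaveledVars = DictToArrayBijection.map(point)
print(q.data)            # [0.5 1.  2. ]
print(q.point_map_info)  # one (name, shape, dtype) entry per variable

# rmap: flat array + metadata -> dict of named values
assert set(DictToArrayBijection.rmap(q)) == {"a", "b"}
```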
@@ -730,10 +727,6 @@
method.tune = self.tune
self.accepted = 0

-# Convert current sample from numpy array ->
-# dict before feeding to proposal
-q0_dict = DictToArrayBijection.rmap(q0)

# Set subchain_selection (which sample from the coarse chain
# is passed as a proposal to the fine chain). If variance
# reduction is used, a random sample is selected as proposal.
@@ -747,14 +740,13 @@

# Call the recursive DA proposal to get proposed sample
# and convert dict -> numpy array
-pre_q = self.proposal_dist(q0_dict)
-q = DictToArrayBijection.map(pre_q)
+q = self.proposal_dist(q0)

# Evaluate MLDA acceptance log-ratio
# If proposed sample from lower levels is the same as current one,
# do not calculate likelihood, just set accept to 0.0
if (q.data == q0.data).all():
-accept = np.float(0.0)
+accept = np.float64(0.0)
skipped_logp = True
else:
accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
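
For orientation, the two terms on the line above assemble the standard two-level delayed-acceptance log-ratio (notation ours, not from the diff): the first `delta_logp` call gives the current-level term and `delta_logp_below` the below-level correction, with `\pi_\ell` the current-level posterior, `\pi_{\ell-1}` the level below, `q_0` the current sample and `q` the subchain proposal:

```latex
\log \alpha = \log \frac{\pi_\ell(q)}{\pi_\ell(q_0)}
            + \log \frac{\pi_{\ell-1}(q_0)}{\pi_{\ell-1}(q)}
```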
@@ -811,22 +803,22 @@
if isinstance(self.step_method_below, MLDA):
self.base_tuning_stats = self.step_method_below.base_tuning_stats
elif isinstance(self.step_method_below, MetropolisMLDA):
-self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling[0]})
+self.base_tuning_stats.append({"base_scaling": self.step_method_below.scaling})
elif isinstance(self.step_method_below, DEMetropolisZMLDA):
self.base_tuning_stats.append(
{
"base_scaling": self.step_method_below.scaling[0],
"base_scaling": self.step_method_below.scaling,
"base_lambda": self.step_method_below.lamb,
}
)
elif isinstance(self.step_method_below, CompoundStep):
# Below method is CompoundStep
for method in self.step_method_below.methods:
if isinstance(method, MetropolisMLDA):
-self.base_tuning_stats.append({"base_scaling": method.scaling[0]})
+self.base_tuning_stats.append({"base_scaling": method.scaling})
elif isinstance(method, DEMetropolisZMLDA):
self.base_tuning_stats.append(
{"base_scaling": method.scaling[0], "base_lambda": method.lamb}
{"base_scaling": method.scaling, "base_lambda": method.lamb}
)

return q_new, [stats] + self.base_tuning_stats
@@ -970,7 +962,7 @@ def delta_logp_inverse(point, logp, vars, shared):

logp1 = pm.CallableTensor(logp0)(inarray1)

-f = aesara.function([inarray1, inarray0], -logp0 + logp1)
+f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
f.trust_input = True
return f
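
A runnable sketch (ours, not from the PR) of what the compiled function returns: `f(a1, a0)` evaluates the model logp at raveled input `a1` minus the logp at `a0`. The `model.initial_point` property is our assumption about the v4-era API:

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    pm.Normal("x", 0.0, 1.0)

initial_values = model.initial_point  # assumed v4-era API: a {name: value} dict
value_vars = model.value_vars
shared = pm.make_shared_replacements(initial_values, value_vars, model)
f = delta_logp_inverse(initial_values, model.logpt, value_vars, shared)

# trust_input=True means callers must pass correctly typed float arrays.
print(f(np.array([1.0]), np.array([0.0])))  # logp(1) - logp(0) = -0.5 for N(0, 1)
```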

@@ -1015,9 +1007,6 @@ def subsample(
trace=None,
tune=0,
model=None,
-random_seed=None,
-callback=None,
-**kwargs,
):
"""
A stripped down version of sample(), which is called only
@@ -1032,19 +1021,10 @@
model = pm.modelcontext(model)
chain = 0
random_seed = np.random.randint(2 ** 30)

-if start is not None:
-pm.sampling._check_start_shape(model, start)
-else:
-start = {}
+callback = None

draws += tune

-step = pm.sampling.assign_step_methods(model, step, step_kwargs=kwargs)

-if isinstance(step, list):
-step = CompoundStep(step)

sampling = pm.sampling._iter_sample(
draws, step, start, trace, chain, tune, model, random_seed, callback
)
@@ -1086,9 +1066,8 @@ def __init__(
self.subsampling_rate = subsampling_rate
self.subchain_selection = None
self.tuning_end_trigger = True
-self.trace = None

-def __call__(self, q0_dict: dict) -> dict:
+def __call__(self, q0: RaveledVars) -> RaveledVars:
"""Returns proposed sample given the current sample
in dictionary form (q0_dict)."""

@@ -1097,6 +1076,10 @@ def __call__(self, q0_dict: dict) -> dict:
_log = logging.getLogger("pymc")
_log.setLevel(logging.ERROR)

+# Convert current sample from RaveledVars ->
+# dict before feeding to subsample.
+q0_dict = DictToArrayBijection.rmap(q0)

with self.model_below:
# Check if the tuning flag has been set to False
# in which case tuning is stopped. The flag is set
@@ -1106,11 +1089,10 @@

if self.tune:
# Subsample in tuning mode
-self.trace = subsample(
+trace = subsample(
draws=0,
step=self.step_method_below,
start=q0_dict,
-trace=self.trace,
tune=self.subsampling_rate,
)
else:
@@ -1122,11 +1104,11 @@
self.step_method_below.tuning_end_trigger = True
self.tuning_end_trigger = False

-self.trace = subsample(
+trace = subsample(
draws=self.subsampling_rate,
step=self.step_method_below,
start=q0_dict,
-trace=self.trace,
+tune=0,
)

# set logging back to normal
@@ -1135,7 +1117,13 @@
# return sample with index self.subchain_selection from the generated
# sequence of length self.subsampling_rate. The index is set within
# MLDA's astep() function
-new_point = self.trace.point(-self.subsampling_rate + self.subchain_selection)
-new_point = Point(new_point, model=self.model_below, filter_model_vars=True)
+q_dict = trace.point(self.subchain_selection)

+# Make sure output dict is ordered the same way as the input dict.
+q_dict = Point(
+{key: q_dict[key] for key in q0_dict.keys()},
+model=self.model_below,
+filter_model_vars=True,
+)

-return new_point
+return DictToArrayBijection.map(q_dict)
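
Putting the pieces together, a condensed sketch (ours, with hypothetical names) of the proposal path after this refactor; each call now spawns a fresh subchain from the current point instead of continuing a stored trace:

```python
def recursive_da_propose(q0: RaveledVars, step_below, subsampling_rate,
                         subchain_selection, model_below) -> RaveledVars:
    """Condensed, hypothetical restatement of RecursiveDAProposal.__call__."""
    q0_dict = DictToArrayBijection.rmap(q0)          # flat array -> {name: value}
    with model_below:
        trace = subsample(draws=subsampling_rate,    # fresh subchain, no trace kwarg
                          step=step_below, start=q0_dict, tune=0)
    q_dict = trace.point(subchain_selection)         # pick one coarse-level draw
    q_dict = {key: q_dict[key] for key in q0_dict}   # preserve input key order
    return DictToArrayBijection.map(q_dict)          # {name: value} -> flat array
```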