Skip to content

Commit

Permalink
Add QMC marginalization
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Seyboldt <[email protected]>
Co-authored-by: larryshamalama <[email protected]>
Co-authored-by: Rob Zinkov <[email protected]>
Co-authored-by: theorashid <[email protected]>
  • Loading branch information
5 people committed Jul 5, 2024
1 parent 87d4aea commit d8971c4
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 16 deletions.
156 changes: 140 additions & 16 deletions pymc_experimental/model/marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
import numpy as np
import pymc
import pytensor.tensor as pt
import scipy
from arviz import InferenceData, dict_to_dataset
from pymc import SymbolicRandomVariable
from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
from pymc.distributions import MvNormal, SymbolicRandomVariable
from pymc.distributions.continuous import Continuous
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.transforms import Chain
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.basic import conditional_logp, icdf, logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
from pymc.pytensorf import compile_pymc, constant_fold
from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold
from pymc.util import RandomState, _get_seeds_per_chain, treedict
from pytensor import Mode, scan
from pytensor.compile import SharedVariable
Expand Down Expand Up @@ -159,17 +161,17 @@ def _marginalize(self, user_warnings=False):
f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
)

old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
if isinstance(rv_to_marginalize.owner.op, Continuous):
subgraph_builder_fn = replace_continuous_marginal_subgraph
else:
subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
old_rvs, new_rvs = subgraph_builder_fn(
fg,
rv_to_marginalize,
self.basic_RVs + rvs_left_to_marginalize,
user_warnings=user_warnings,
)

if user_warnings and len(new_rvs) > 2:
warnings.warn(
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}",
UserWarning,
)

rvs_left_to_marginalize.remove(rv_to_marginalize)

for old_rv, new_rv in zip(old_rvs, new_rvs):
Expand Down Expand Up @@ -267,7 +269,11 @@ def marginalize(
)

rv_op = rv_to_marginalize.owner.op
if isinstance(rv_op, DiscreteMarkovChain):

if isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
pass

elif isinstance(rv_op, DiscreteMarkovChain):
if rv_op.n_lags > 1:
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
Expand All @@ -276,7 +282,11 @@ def marginalize(
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
)
elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):

elif isinstance(rv_op, Continuous):
pass

else:
raise NotImplementedError(
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
)
Expand Down Expand Up @@ -449,7 +459,7 @@ def transform_input(inputs):
rv_loglike_fn = None
joint_logps_norm = log_softmax(joint_logps, axis=-1)
if return_samples:
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
sample_rv_outs = Categorical.dist(logit_p=joint_logps)
if isinstance(marginalized_rv.owner.op, DiscreteUniform):
sample_rv_outs += rv_domain[0]

Expand Down Expand Up @@ -549,6 +559,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
"""Base class for Discrete Marginal Markov Chain RVs"""


class QMCMarginalNormalRV(MarginalRV):
"""Basec class for QMC Marginalized RVs"""

__props__ = ("qmc_order",)

def __init__(self, *args, qmc_order: int, **kwargs):
self.qmc_order = qmc_order
super().__init__(*args, **kwargs)


def static_shape_ancestors(vars):
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
return [
Expand Down Expand Up @@ -646,7 +666,9 @@ def collect_shared_vars(outputs, blockers):
]


def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
def replace_finite_discrete_marginal_subgraph(
fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False
):
# TODO: This should eventually be integrated in a more general routine that can
# identify other types of supported marginalization, of which finite discrete
# RVs is just one
Expand All @@ -655,6 +677,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
if not dependent_rvs:
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

if user_warnings and len(dependent_rvs) > 1:
warnings.warn(
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
f"Their joint logp terms will be assigned to the first RV: {dependent_rvs[0]}",
UserWarning,
)

ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
if len(ndim_supp) != 1:
raise NotImplementedError(
Expand Down Expand Up @@ -707,6 +736,39 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
return rvs_to_marginalize, marginalized_rvs


def replace_continuous_marginal_subgraph(
fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False
):
dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
if not dependent_rvs:
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
dependent_rvs_input_rvs = [
rv
for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
if rv is not rv_to_marginalize
]

input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]

outputs = rvs_to_marginalize
# We are strict about shared variables in SymbolicRandomVariables
inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)

# TODO: Assert no non-marginalized variables depend on the rng output of the marginalized variables!!!
marginalized_rvs = QMCMarginalNormalRV(
inputs=inputs,
outputs=[*outputs, *collect_default_updates(inputs=inputs, outputs=outputs).values()],
ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
qmc_order=13,
)(*inputs)[: len(outputs)]

fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
return rvs_to_marginalize, marginalized_rvs


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
op = rv.owner.op
dist_params = rv.owner.op.dist_params(rv.owner)
Expand Down Expand Up @@ -870,3 +932,65 @@ def step_alpha(logp_emission, log_alpha, log_P):
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
dummy_logps = (pt.constant(0),) * (len(values) - 1)
return joint_logp, *dummy_logps


@_logprob.register(QMCMarginalNormalRV)
def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
# Clone the inner RV graph of the Marginalized RV
marginalized_rvs_node = op.make_node(*inputs)
# The MarginalizedRV contains the following outputs:
# 1. The variable we marginalized
# 2. The dependent variables
# 3. The updates for the marginalized and dependent variables
marginalized_rv, *inner_rvs_and_updates = clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
)
inner_rvs = inner_rvs_and_updates[: (len(inner_rvs_and_updates) - 1) // 2]

marginalized_rv_node = marginalized_rv.owner
marginalized_rv_op = marginalized_rv_node.op

# GET QMC draws from the marginalized RV
# TODO: Make this an Op
rng = marginalized_rv_op.rng_param(marginalized_rv_node)
shape = constant_fold(tuple(marginalized_rv.shape))
size = np.prod(shape).astype(int)
n_draws = 2**op.qmc_order

# TODO: Wrap Sobol in an Op so we can control the RNG and change whenever
qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))

if isinstance(marginalized_rv_op, MvNormal):
# Adapted from https://github.com/scipy/scipy/blob/87c46641a8b3b5b47b81de44c07b840468f7ebe7/scipy/stats/_qmc.py#L2211-L2298
mean, cov = marginalized_rv_op.dist_params(marginalized_rv_node)
corr_matrix = pt.linalg.cholesky(cov).mT
base_draws = pt.as_tensor(scipy.stats.norm.ppf(0.5 + (1 - 1e-10) * (uniform_draws - 0.5)))
qmc_draws = base_draws @ corr_matrix + mean
else:
qmc_draws = icdf(marginalized_rv, uniform_draws)

qmc_draws.name = f"QMC_{marginalized_rv_op.name}_draws"

# Obtain the logp of the dependent variables
# We need to include the marginalized RV for correctness, we remove it later.
inner_rv_values = dict(zip(inner_rvs, values))
marginalized_vv = marginalized_rv.clone()
rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
# Pop the logp term corresponding to the marginalized RV
# (it already got accounted for in the bias of the QMC draws)
logps_dict.pop(marginalized_vv)

# Vectorize across QMC draws and take the mean on log scale
core_marginalized_logps = list(logps_dict.values())
batched_marginalized_logps = vectorize_graph(
core_marginalized_logps, replace={marginalized_vv: qmc_draws}
)

# Take the mean in log scale
return tuple(
pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
for batched_marginalized_logp in batched_marginalized_logps
)
63 changes: 63 additions & 0 deletions pymc_experimental/tests/model/test_marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import pymc as pm
import pytensor.tensor as pt
import pytest
import scipy
from arviz import InferenceData, dict_to_dataset
from pymc.distributions import transforms
from pymc.logprob.abstract import _logprob
from pymc.model.fgraph import fgraph_from_model
from pymc.pytensorf import inputvars
from pymc.util import UNSET
from pytensor.graph import FunctionGraph
from scipy.special import log_softmax, logsumexp
from scipy.stats import halfnorm, norm

Expand All @@ -21,6 +23,7 @@
MarginalModel,
is_conditional_dependent,
marginalize,
replace_continuous_marginal_subgraph,
)
from pymc_experimental.tests.utils import equal_computations_up_to_root

Expand Down Expand Up @@ -803,3 +806,63 @@ def create_model(model_class):
marginal_m.compile_logp()(ip),
reference_m.compile_logp()(ip),
)


@pytest.mark.parametrize("univariate", (True, False), ids=["univariate", "multivariate"])
@pytest.mark.parametrize(
"multiple_dependent", (False, True), ids=["single-dependent", "multiple-dependent"]
)
def test_marginalize_normal_qmc(univariate, multiple_dependent):
with MarginalModel() as m:
SD = pm.HalfNormal("SD", default_transform=None)
if univariate:
X = pm.Normal("X", sigma=SD, shape=(3,))
else:
X = pm.MvNormal("X", mu=[0, 0, 0], cov=np.eye(3) * SD**2)

if multiple_dependent:
Y = [
pm.Normal("Y[0]", mu=(2 * X[0] + 1), sigma=1, observed=1),
pm.Normal("Y[1:]", mu=(2 * X[1:] + 1), sigma=1, observed=[2, 3]),
]
else:
Y = [pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3])]

m.marginalize([X]) # ideally method="qmc"

logp_eval = np.hstack(m.compile_logp(vars=Y, sum=False)({"SD": 2.0}))

np.testing.assert_allclose(
logp_eval,
scipy.stats.norm.logpdf([1, 2, 3], 1, np.sqrt(17)),
rtol=1e-5,
)


def test_marginalize_non_trivial_mvnormal_qmc():
with MarginalModel() as m:
SD = pm.HalfNormal("SD", default_transform=None)
X = pm.MvNormal("X", cov=[[1.0, 0.5], [0.5, 1.0]] * SD**2)
Y = pm.MvNormal("Y", mu=2 * X + 1, cov=np.eye(2), observed=[1, 2])

m.marginalize([X])

[logp_eval] = m.compile_logp(vars=Y, sum=False)({"SD": 1})

np.testing.assert_allclose(
logp_eval,
scipy.stats.multivariate_normal.logpdf([1, 2], [1, 1], [[5, 2], [2, 5]]),
rtol=1e-5,
)


def test_marginalize_sample():
with pm.Model() as m:
SD = pm.HalfNormal("SD")
X = pm.Normal.dist(sigma=SD, name="X")
Y = pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3])

fg = FunctionGraph(outputs=[SD, Y, X], clone=False)
old_rvs, new_rvs = replace_continuous_marginal_subgraph(fg, X, [Y, SD, X])
res1, res2 = pm.draw(new_rvs, draws=2)
assert not np.allclose(res1, res2)

0 comments on commit d8971c4

Please sign in to comment.