Skip to content

Commit

Permalink
Extend support for automatic imputation
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 27, 2023
1 parent 7b08fc1 commit 48458a1
Show file tree
Hide file tree
Showing 6 changed files with 439 additions and 96 deletions.
5 changes: 5 additions & 0 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,11 @@ def Data(
# `convert_observed_data` takes care of parameter `value` and
# transforms it to something digestible for PyTensor.
arr = convert_observed_data(value)
if isinstance(arr, np.ma.MaskedArray):
raise NotImplementedError(
"Masked arrays or arrays with `nan` entries are not supported. "
"Pass them directly to `observed` if you want to trigger auto-imputation"
)

if mutable is None:
warnings.warn(
Expand Down
146 changes: 145 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@

from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import node_rewriter
from pytensor.graph import FunctionGraph, node_rewriter
from pytensor.graph.basic import Node, Variable
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.utils import MetaType
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.var import TensorVariable
from typing_extensions import TypeAlias
Expand All @@ -49,6 +50,7 @@
)
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import BlockModelAccess
from pymc.printing import str_for_dist
Expand Down Expand Up @@ -1148,3 +1150,145 @@ def logcdf(value, c):
-np.inf,
0,
)


class PartialObservedRV(SymbolicRandomVariable):
"""RandomVariable with partially observed subspace, as indicated by a boolean mask.
See `create_partial_observed_rv` for more details.
"""


def create_partial_observed_rv(
rv: TensorVariable,
mask: Union[np.ndarray, TensorVariable],
) -> Tuple[
Tuple[TensorVariable, TensorVariable], Tuple[TensorVariable, TensorVariable], TensorVariable
]:
"""Separate observed and unobserved components of a RandomVariable.
This function may return two independent RandomVariables or, if not possible,
two variables from a common `PartialObservedRV` node
Parameters
----------
rv : TensorVariable
mask : tensor_like
Constant or variable boolean mask. True entries correspond to components of the variable that are not observed.
Returns
-------
observed_rv and mask : Tuple of TensorVariable
The observed component of the RV and respective indexing mask
unobserved_rv and mask: Tuple of TensorVariable
The unobserved component of the RV and respective indexing mask
joined_rv:
The symbolic join of the observed and unobserved components.
"""
if not mask.dtype == "bool":
raise ValueError(
f"mask must be an array or tensor of boolean dtype, got dtype: {mask.dtype}"
)

if mask.ndim > rv.ndim:
raise ValueError(f"mast can't have more dims than rv, got ndim: {mask.ndim}")

antimask = ~mask

can_rewrite = False
# Only pure RVs can be rewritten
if isinstance(rv.owner.op, RandomVariable):
ndim_supp = rv.owner.op.ndim_supp

# All univariate RVs can be rewritten
if ndim_supp == 0:
can_rewrite = True

# Multivariate RVs can be rewritten if masking does not split within support dimensions
else:
batch_dims = rv.type.ndim - ndim_supp
constant_mask = getattr(as_tensor_variable(mask), "data", None)

# Indexing does not overlap with core dimensions
if mask.ndim <= batch_dims:
can_rewrite = True

# Try to handle special case where mask is constant across support dimensions,
# TODO: This could be done by the rewrite itself
elif constant_mask is not None:
# We check if a constant_mask that only keeps the first entry of each support dim
# is equivalent to the original one after re-expanding.
trimmed_mask = constant_mask[(...,) + (0,) * ndim_supp]
expanded_mask = np.broadcast_to(
np.expand_dims(trimmed_mask, axis=tuple(range(-ndim_supp, 0))),
shape=constant_mask.shape,
)
if np.array_equal(constant_mask, expanded_mask):
mask = trimmed_mask
antimask = ~trimmed_mask
can_rewrite = True

if can_rewrite:
# Rewrite doesn't work with boolean masks. Should be fixed after https://github.com/pymc-devs/pytensor/pull/329
mask, antimask = mask.nonzero(), antimask.nonzero()

masked_rv = rv[mask]
fgraph = FunctionGraph(outputs=[masked_rv], clone=False)
[unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)

antimasked_rv = rv[antimask]
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False)
[observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)

# Make a clone of the observedRV, with a distinct rng so that observed and
# unobserved are never treated as equivalent (and mergeable) nodes by pytensor.
_, size, _, *inps = observed_rv.owner.inputs
observed_rv = observed_rv.owner.op(*inps, size=size)

# For all other cases use the more general PartialObservedRV
else:
# The symbolic graph simply splits the observed and unobserved components,
# so they can be given separate values.
dist_, mask_ = rv.type(), as_tensor_variable(mask).type()
observed_rv_, unobserved_rv_ = dist_[~mask_], dist_[mask_]

observed_rv, unobserved_rv = PartialObservedRV(
inputs=[dist_, mask_],
outputs=[observed_rv_, unobserved_rv_],
ndim_supp=rv.owner.op.ndim_supp,
)(rv, mask)

joined_rv = pt.empty(rv.shape, dtype=rv.type.dtype)
joined_rv = pt.set_subtensor(joined_rv[mask], unobserved_rv)
joined_rv = pt.set_subtensor(joined_rv[antimask], observed_rv)

return (observed_rv, antimask), (unobserved_rv, mask), joined_rv


@_logprob.register(PartialObservedRV)
def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
# For the logp, simply join the values
[obs_value, unobs_value] = values
antimask = ~mask
joined_value = pt.empty_like(dist)
joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
joined_logp = logp(dist, joined_value)

# If we have a univariate RV we can split apart the logp terms
if op.ndim_supp == 0:
return joined_logp[antimask], joined_logp[mask]
# Otherwise, we can't (always/ easily) split apart logp terms.
# We return the full logp for the observed value, and a 0-nd array for the unobserved value
else:
return joined_logp.ravel(), pt.zeros((0,), dtype=joined_logp.type.dtype)


@_moment.register(PartialObservedRV)
def partial_observed_rv_moment(op, partial_obs_rv, rv, mask):
# Unobserved output
if partial_obs_rv.owner.outputs.index(partial_obs_rv) == 1:
return moment(rv)[mask]
# Observed output
else:
return moment(rv)[~mask]
74 changes: 16 additions & 58 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@
from pytensor.compile import DeepCopyOp, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.sharedvar import ScalarSharedVariable
from pytensor.tensor.var import TensorConstant, TensorVariable
Expand Down Expand Up @@ -1409,67 +1407,27 @@ def make_obs_var(
if total_size is not None:
raise ValueError("total_size is not compatible with imputed variables")

if not isinstance(rv_var.owner.op, RandomVariable):
raise NotImplementedError(
"Automatic inputation is only supported for univariate RandomVariables."
f" {rv_var} of type {type(rv_var.owner.op)} is not supported."
)

if rv_var.owner.op.ndim_supp > 0:
raise NotImplementedError(
f"Automatic inputation is only supported for univariate "
f"RandomVariables, but {rv_var} is multivariate"
)
from pymc.distributions.distribution import create_partial_observed_rv

# We can get a random variable comprised of only the unobserved
# entries by lifting the indices through the `RandomVariable` `Op`.
(
(observed_rv, observed_mask),
(unobserved_rv, _),
joined_rv,
) = create_partial_observed_rv(rv_var, mask)
observed_data = pt.as_tensor(data.data[observed_mask])

masked_rv_var = rv_var[mask.nonzero()]

fgraph = FunctionGraph(
[i for i in graph_inputs((masked_rv_var,)) if not isinstance(i, Constant)],
[masked_rv_var],
clone=False,
)
# Register ObservedRV corresponding to observed component
observed_rv.name = f"{name}_observed"
self.create_value_var(observed_rv, transform=None, value_var=observed_data)
self.add_named_variable(observed_rv)
self.observed_RVs.append(observed_rv)

(missing_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
# Register FreeRV corresponding to unobserved components
self.register_rv(unobserved_rv, f"{name}_missing", transform=transform)

self.register_rv(missing_rv_var, f"{name}_missing", transform=transform)

# Now, we lift the non-missing observed values and produce a new
# `rv_var` that contains only those.
#
# The end result is two disjoint distributions: one for the missing
# values, and another for the non-missing values.

antimask_idx = (~mask).nonzero()
nonmissing_data = pt.as_tensor_variable(data[antimask_idx].data)
unmasked_rv_var = rv_var[antimask_idx]
unmasked_rv_var = unmasked_rv_var.owner.clone().default_output()

fgraph = FunctionGraph(
[i for i in graph_inputs((unmasked_rv_var,)) if not isinstance(i, Constant)],
[unmasked_rv_var],
clone=False,
)
(observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
# Make a clone of the RV, but let it create a new rng so that observed and
# missing are not treated as equivalent nodes by pytensor. This would happen
# if the size of the masked and unmasked array happened to coincide
_, size, _, *inps = observed_rv_var.owner.inputs
observed_rv_var = observed_rv_var.owner.op(*inps, size=size, name=f"{name}_observed")
observed_rv_var.tag.observations = nonmissing_data

self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
self.add_named_variable(observed_rv_var)
self.observed_RVs.append(observed_rv_var)

# Create deterministic that combines observed and missing
# Register Deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
rv_var = pt.empty(data.shape, dtype=observed_rv_var.type.dtype)
rv_var = pt.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
rv_var = pt.set_subtensor(rv_var[antimask_idx], observed_rv_var)
rv_var = Deterministic(name, rv_var, self, dims)
rv_var = Deterministic(name, joined_rv, self, dims)

else:
if sps.issparse(data):
Expand Down
17 changes: 11 additions & 6 deletions tests/backends/test_arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ def test_missing_data_model(self):
# See https://github.com/pymc-devs/pymc/issues/5255
assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)

@pytest.mark.xfail(reason="Multivariate partial observed RVs not implemented for V4")
def test_mv_missing_data_model(self):
data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)

Expand All @@ -361,19 +360,25 @@ def test_mv_missing_data_model(self):
mu = pm.Normal("mu", 0, 1, size=2)
sd_dist = pm.HalfNormal.dist(1.0, size=2)
# pylint: disable=unpacking-non-sequence
chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True)
chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist)
# pylint: enable=unpacking-non-sequence
with pytest.warns(ImputationWarning):
y = pm.MvNormal("y", mu=mu, chol=chol, observed=data)
inference_data = pm.sample(100, chains=2, return_inferencedata=True)
inference_data = pm.sample(
tune=100,
draws=100,
chains=2,
step=pm.Metropolis(),
idata_kwargs=dict(log_likelihood=True),
)

# make sure that data is really missing
assert isinstance(y.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
assert isinstance(y.owner.inputs[0].owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))

test_dict = {
"posterior": ["mu", "chol_cov"],
"observed_data": ["y"],
"log_likelihood": ["y"],
"observed_data": ["y_observed"],
"log_likelihood": ["y_observed"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
Expand Down
Loading

0 comments on commit 48458a1

Please sign in to comment.