Partially observed multivariate distributions not implemented in V4 #5260
Here's a small example illustrating how AePPL's support for …:

```python
import aesara.tensor as at
import numpy as np
import scipy.stats
from aeppl.joint_logprob import joint_logprob

# Observed data; NaN marks the missing entries
data = np.array([[-11, 0, np.nan], [np.nan, 1, 5]])
missing_idx = np.isnan(data)
size = data.shape[0]

mu = np.r_[-10, 0, 10]
sigma = np.eye(3)
Y_base_rv = at.random.multivariate_normal(mu, sigma, size=size, name="Y")

# Replace the missing entries of the base RV with a free vector of imputed values
missing_vals_var = at.vector("y_missing")
Y_rv = at.set_subtensor(Y_base_rv[missing_idx], missing_vals_var)

Y_rv_logp = joint_logprob({Y_rv: at.as_tensor(data)}, sum=False)

missing_vals = np.r_[-2, 2]
Y_rv_logp.eval({missing_vals_var: missing_vals})
# array([-75.2568156, -87.7568156])

# Compare with SciPy after filling in the missing values by hand
full_obs = data.copy()
full_obs[missing_idx] = missing_vals
np.fromiter(
    (scipy.stats.multivariate_normal(mu, sigma).logpdf(obs) for obs in full_obs),
    dtype=np.float64,
)
# array([-75.2568156, -87.7568156])
```
That seems perfect for partially observed sub-values. On the other hand, it is nice to have transforms working for imputed variables, as we do now in V4. Perhaps we could do a forward transform pass on the observed data, so that all the input values are on the transformed scale? I'm not sure whether this would already work on the AePPL side, but it would give us the best of both worlds. The idea, if defensible, would be to allow:

```python
import numpy as np
import pymc as pm

with pm.Model() as m:
    x = pm.Dirichlet("x", a=np.ones(3), observed=[0.1, np.nan, np.nan])
```

while keeping the sampler happy by proposing the missing value(s) on the simplex space.
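To make the Dirichlet idea a bit more concrete, here is a minimal NumPy sketch of one way the missing entries of `[0.1, nan, nan]` could be proposed from an unconstrained vector and still land on the simplex. The helper `impute_simplex` is hypothetical (not PyMC or AePPL API), it uses a softmax with a fixed reference logit rather than PyMC's actual simplex transform, and it does not compute the log-Jacobian term a sampler would need.

```python
import numpy as np


def impute_simplex(observed, z):
    """Hypothetical helper: fill the NaN entries of a partially observed simplex vector.

    The missing entries share the probability mass left over by the observed
    ones, split according to a softmax of the unconstrained vector [z, 0].
    `z` must have one entry fewer than the number of missing values.
    No log-Jacobian term is computed here.
    """
    observed = np.asarray(observed, dtype=float)
    missing = np.isnan(observed)
    z = np.asarray(z, dtype=float)
    if z.shape != (missing.sum() - 1,):
        raise ValueError("z must have exactly one entry fewer than the number of missing values")
    remaining = 1.0 - observed[~missing].sum()
    if remaining <= 0:
        raise ValueError("the observed entries leave no mass for the missing ones")
    # Softmax with a fixed reference logit of 0 maps R^(m-1) onto the (m-1)-simplex
    logits = np.append(z, 0.0)
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    filled = observed.copy()
    filled[missing] = remaining * weights
    return filled


# Any real-valued proposal for z yields a vector on the simplex
print(impute_simplex([0.1, np.nan, np.nan], [0.0]))  # equal split of the remaining 0.9
print(impute_simplex([0.1, np.nan, np.nan], [3.0]).sum())  # 1 up to floating-point error
```

Note that the missing entries here are constrained jointly (they must sum to 0.9), which is exactly the kind of "sub-value" constraint that a per-variable transform does not know about.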
Here is a snippet showing how to do it manually until we provide a proper API: https://discourse.pymc.io/t/automatic-imputation-of-multivariate-models/11029/3?u=ricardov94
I added an explicit `NotImplementedError` and a future test in #5245 - 022cefa

Partially observed variables now rely on Aesara's `local_subtensor_rv_lift` rewrite to separate the observed and unobserved portions of an RV into two separate RVs: https://github.com/aesara-devs/aesara/blob/4e1077210721deb007b4f556f08702cec5a74dfb/aesara/tensor/random/opt.py#L259
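The identity behind this rewrite can be checked numerically for an elementwise distribution: indexing a batch of independent Normal draws gives the same log-density as a Normal whose parameters were indexed the same way. The sketch below uses SciPy rather than Aesara, so it only illustrates the math behind `local_subtensor_rv_lift`, not the rewrite itself; the parameter values and the `imputed` entries are arbitrary.

```python
import numpy as np
import scipy.stats

mu = np.array([-1.0, 0.0, 1.0, 2.0])
sigma = np.array([1.0, 0.5, 2.0, 1.5])
value = np.array([-0.5, np.nan, 1.2, np.nan])  # NaN marks the missing entries

mis_idx = np.isnan(value)
obs_idx = ~mis_idx

# Arbitrary values standing in for the sampler's proposal for the missing entries
imputed = np.array([0.3, -2.0])
full_value = value.copy()
full_value[mis_idx] = imputed

# Elementwise log-density of the original, unsplit RV
full_logp = scipy.stats.norm(mu, sigma).logpdf(full_value)

# Same densities after lifting the index into the parameters:
# one RV for the observed entries and one for the missing entries
obs_logp = scipy.stats.norm(mu[obs_idx], sigma[obs_idx]).logpdf(value[obs_idx])
mis_logp = scipy.stats.norm(mu[mis_idx], sigma[mis_idx]).logpdf(imputed)

assert np.allclose(full_logp[obs_idx], obs_logp)
assert np.allclose(full_logp[mis_idx], mis_logp)
```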
This rewrite needs to be updated to work for multivariate distributions; there are a couple of comments in its code suggesting how to start approaching such a task. If an issue does not already exist, we should open one in Aesara, as those code changes would go there.
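For a multivariate RV, only the batch dimensions can be lifted this way: a selected row is still an `MvNormal`, whereas a selected entry of the core dimension is not. Below is a minimal SciPy sketch of that restriction, assuming a covariance shared across the batch (a batched covariance would have to be broadcast and indexed as well). `split_batch_logp` is a hypothetical helper, not Aesara code.

```python
import numpy as np
import scipy.stats


def split_batch_logp(value, mu, cov, row_idx):
    """Hypothetical helper: MvNormal log-densities of selected rows of a batch.

    `value` has shape (n, k); `mu` has shape (k,) or (n, k); `cov` has shape
    (k, k) and is assumed to be shared by all rows. `row_idx` is a sequence of
    row indices (or a boolean mask) over the batch dimension only; the core
    dimension k is never indexed, so every selected row is still an MvNormal.
    """
    value = np.asarray(value, dtype=float)
    n, k = value.shape
    # Broadcast the mean to the batch shape before indexing it, as a lifted
    # index would have to do with the parameters
    mu_rows = np.broadcast_to(mu, (n, k))[row_idx]
    return np.array(
        [
            scipy.stats.multivariate_normal(m, cov).logpdf(v)
            for m, v in zip(mu_rows, value[row_idx])
        ]
    )


value = np.array([[0.5, -1.0], [2.0, 0.3], [-0.7, 1.1]])
mu = np.array([0.0, 1.0])
cov = np.array([[1.0, 0.3], [0.3, 2.0]])

full = split_batch_logp(value, mu, cov, [0, 1, 2])
subset = split_batch_logp(value, mu, cov, [0, 2])
assert np.allclose(subset, full[[0, 2]])
```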
Importantly, even after this rewrite, we won't have the same flexibility we had in V3, where some portions of the base multidimensional value could be missing (let's call these "sub-values"), as in `pm.MvNormal("x", np.ones(2), np.eye(2), observed=np.array([[1, 1], [1, np.nan]]))`, because the missing portion of the distribution does not correspond to an `MvNormal` distribution. In contrast, `pm.MvNormal("x", np.ones(2), np.eye(2), observed=np.array([[1, 1], [np.nan, np.nan]]))` should work just fine, as we can split that into two `MvNormal`s.

On the bright side, we would retain the ability to transform imputed values, which was not the case in V3 and which would not be trivial if we were to somehow allow for "sub-value" imputation, because we don't know what the constraints on those "sub-values" are.
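Whether an observed array can be handled by such a split comes down to its missing-value mask: every core vector (here, every row) must be either fully observed or fully missing. Below is a minimal NumPy sketch of that check, using a hypothetical helper name. It returns the indices of the two groups of rows, or rejects arrays that contain missing sub-values.

```python
import numpy as np


def split_missing_rows(observed):
    """Hypothetical helper: group the rows of an (n, k) array by missingness.

    Returns (observed_rows, missing_rows) as integer index arrays, where the
    first group has no NaNs and the second is entirely NaN. Raises ValueError
    if a row is only partly missing, i.e. it contains a missing "sub-value".
    """
    mask = np.isnan(np.asarray(observed, dtype=float))
    all_missing = mask.all(axis=-1)
    any_missing = mask.any(axis=-1)
    partial = any_missing & ~all_missing
    if partial.any():
        raise ValueError(
            f"rows {np.flatnonzero(partial).tolist()} contain missing sub-values"
        )
    return np.flatnonzero(~any_missing), np.flatnonzero(all_missing)


print(split_missing_rows([[1, 1], [np.nan, np.nan]]))  # row 0 observed, row 1 missing

try:
    split_missing_rows([[1, 1], [1, np.nan]])
except ValueError as err:
    print(err)  # the second row contains a missing sub-value
```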
A different issue concerns partial observations of AePPL "derived" distributions, such as in #5169. These won't work out of the box, because `local_subtensor_rv_lift` is meant to work only with pure `RandomVariable`s. In theory, we could make that rewrite (or a copycat one) work with these "derived distributions", as long as we can provide the value and parameter support/mapping information that the rewrite requires, as well as the ability to re-create a node on the spot. In #5169 I had to implement most of these features anyway, but kept them accessible only via the `cls` object. They can easily be made more accessible via class properties or dispatching.

CC @brandonwillard
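As a rough illustration of the dispatching option, the support information could be registered per op with `functools.singledispatch` instead of being reachable only through `cls`. Everything below is hypothetical: `ScalarDerivedRV`, `VectorDerivedRV`, `SupportInfo`, and `support_info` are made-up names, and the field names only mimic the `ndim_supp`/`ndims_params` attributes of Aesara's `RandomVariable`.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import Tuple


@dataclass(frozen=True)
class SupportInfo:
    ndim_supp: int                 # core (support) dimensions of the value
    ndims_params: Tuple[int, ...]  # core dimensions of each parameter


class ScalarDerivedRV:
    """Hypothetical derived op with a scalar support and two scalar parameters."""


class VectorDerivedRV:
    """Hypothetical derived op with a vector support and one vector parameter."""


@singledispatch
def support_info(op):
    # Fallback for ops that have not registered any support information
    raise NotImplementedError(f"No support information registered for {type(op).__name__}")


@support_info.register
def _(op: ScalarDerivedRV) -> SupportInfo:
    return SupportInfo(ndim_supp=0, ndims_params=(0, 0))


@support_info.register
def _(op: VectorDerivedRV) -> SupportInfo:
    return SupportInfo(ndim_supp=1, ndims_params=(1,))


# A rewrite could then query the metadata without knowing the concrete class
print(support_info(VectorDerivedRV()))
```

Compared with class properties, dispatching lets the metadata be attached to ops defined elsewhere without modifying their classes.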