Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support automatic imputation for multivariate and symbolic distributions #6797

Merged
merged 3 commits into from
Jun 30, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jun 27, 2023

Closes #5260
Related to #6626
Related to #5255
Related to #6645

This PR creates a new Op: PartialObservedRV that splits the sample space according to a boolean mask and allows separate variables/values for these. This enables automatic imputation for cases not supported before. For multivariate cases it is not always possible to "attribute" the logp to one variable or the other, so they are all associated with the observed variable. This means the values will show up exclusively in the model log_likelihood, even if some (or all) entries were imputed. Directly related to: #5255

There is some logic to avoid the use of the PartialObservedRV for pure Multivariate variables when it's safe to do so (i.e., there is no mixed indexing across the support dims). This has the benefit that automatic transforms will still apply. It also avoids the logp issue mentioned above. This behavior is the same that existed until now for univariate pure RandomVariables, which keep behaving as before.

The new PartialObservedRV allows for symbolic (constant or mutable) mask, but this is not accessible to users using the current API based on MaskedArray or nan entries. Similarly, one can't use ConstantData/MutableData because those try to convert such arrays to MaskedArray which aren't supported in PyTensor: pymc-devs/pytensor#259. I added an early error for that.

An alternative would be to provide a separate API for users, something like pm.Imputed(name, dist, obs_mask, obs_data) where both data and the mask could be tensors wrapped in Data.


📚 Documentation preview 📚: https://pymc--6797.org.readthedocs.build/en/6797/

@@ -423,6 +423,11 @@ def Data(
# `convert_observed_data` takes care of parameter `value` and
# transforms it to something digestible for PyTensor.
arr = convert_observed_data(value)
if isinstance(arr, np.ma.MaskedArray):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would have failed in the call to as_tensor_variable or shared anyway

@codecov
Copy link

codecov bot commented Jun 27, 2023

Codecov Report

Merging #6797 (4f1f11c) into main (7b08fc1) will increase coverage by 0.09%.
The diff coverage is 98.48%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6797      +/-   ##
==========================================
+ Coverage   91.92%   92.02%   +0.09%     
==========================================
  Files          95       95              
  Lines       16197    16261      +64     
==========================================
+ Hits        14889    14964      +75     
+ Misses       1308     1297      -11     
Impacted Files Coverage Δ
pymc/model.py 90.94% <80.00%> (+1.02%) ⬆️
pymc/data.py 89.47% <100.00%> (+0.12%) ⬆️
pymc/distributions/distribution.py 97.06% <100.00%> (+0.48%) ⬆️

... and 9 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the partially_observed_rv branch 3 times, most recently from 41c5681 to 48458a1 Compare June 27, 2023 13:03
Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool stuff!!

pymc/distributions/distribution.py Outdated Show resolved Hide resolved
pymc/model.py Outdated

(missing_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
# Register FreeRV corresponding to unobserved components
self.register_rv(unobserved_rv, f"{name}_missing", transform=transform)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.register_rv(unobserved_rv, f"{name}_missing", transform=transform)
self.register_rv(unobserved_rv, f"{name}_unobserved", transform=transform)

this would match the _observed for the other and maybe avoid confusion?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could count as a breaking change, that's why I didn't change it. Worth the trouble?

Copy link
Member Author

@ricardoV94 ricardoV94 Jun 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think unobserved is a better name

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say it's okay to break. Imputation isn't that popular, and the breaking change isn't dangerously silent or sth like that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It changes the variable name, that's all

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

tests/backends/test_arviz.py Outdated Show resolved Hide resolved
tests/distributions/test_distribution.py Outdated Show resolved Hide resolved
Comment on lines +850 to +848
# Test that we can update a shared mask
mask.set_value(np.array([False]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤯

@ricardoV94 ricardoV94 force-pushed the partially_observed_rv branch from 48458a1 to e1fd175 Compare June 29, 2023 13:13
@ricardoV94 ricardoV94 added the major Include in major changes release notes section label Jun 30, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements major Include in major changes release notes section request discussion
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Partially observed multivariate distributions not implemented in V4
2 participants