Add rewrite for Mixture when comp_dists can be "fused" #6803

Open
Tracked by #7053
ricardoV94 opened this issue Jun 29, 2023 · 0 comments

ricardoV94 (Member) commented Jun 29, 2023

Description

The following distributions are equivalent:

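```python
import pymc as pm

# Form 1: a list of two separate scalar Normal components
mix_list = pm.Mixture.dist(w=[0.5, 0.5], comp_dists=[pm.Normal.dist(-1), pm.Normal.dist(1)])
# Form 2: a single batched Normal with mu=[-1, 1]
mix_batched = pm.Mixture.dist(w=[0.5, 0.5], comp_dists=pm.Normal.dist([-1, 1]))
```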

The second form is more efficient, because its logp is computed by a single vectorized (batched) Normal instead of by two separate Normal components.
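The equivalence can be checked outside PyMC with a minimal NumPy sketch (not PyMC's implementation). Both forms compute log sum_k w_k N(x | mu_k, 1); the list form evaluates each component's logp separately and stacks the results, while the batched form evaluates one Normal whose mu is the vector [-1, 1], broadcast against x. The helper names normal_logp, logsumexp, mixture_logp_list and mixture_logp_batched are hypothetical.

```python
import numpy as np


def normal_logp(x, mu, sigma=1.0):
    # Elementwise log-density of Normal(mu, sigma); broadcasts over arrays
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def logsumexp(a, axis=-1):
    # Numerically stable log(sum(exp(a))) along `axis`
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def mixture_logp_list(x, w, mus):
    # List form: one logp evaluation per separate component distribution
    comp = np.stack([normal_logp(x, mu) for mu in mus], axis=-1)
    return logsumexp(np.log(w) + comp, axis=-1)


def mixture_logp_batched(x, w, mus):
    # Batched form: a single Normal with vector mu, broadcast against x[..., None]
    comp = normal_logp(np.asarray(x)[..., None], np.asarray(mus))
    return logsumexp(np.log(w) + comp, axis=-1)


x = np.linspace(-3.0, 3.0, 7)
w = [0.5, 0.5]
mus = [-1.0, 1.0]
print(np.allclose(mixture_logp_list(x, w, mus), mixture_logp_batched(x, w, mus)))  # True
```

Both functions return the same values; the batched version performs one broadcasted density evaluation instead of one per component, which is roughly the kind of saving the proposed rewrite would recover automatically.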

We could add a rewrite to logprob_rewrites that converts the first form into the second, so that users are not penalized for using the first form (which may be more intuitive to some).

More generally, a rewrite of the form stack([rv1, rv2]) -> rv3, which replaces a stack of compatible random variables with a single batched one, could be useful in many places in the logprob submodule.
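A toy sketch of the stack([rv1, rv2]) -> rv3 idea, assuming a hypothetical Dist record (a family name plus a tuple of parameters) instead of real PyTensor graphs. Components are fused only when they share the same family and all have scalar parameters; stacking each parameter across components then yields one batched distribution. This is not the logprob_rewrites API: an actual rewrite would operate on RandomVariable nodes and would likely also need to check sizes, dtypes and broadcasting.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dist:
    # Hypothetical stand-in for a random variable: a family name plus its
    # parameters, e.g. Dist("Normal", (mu, sigma)). Not a PyMC/PyTensor type.
    family: str
    params: tuple


def try_fuse(comps):
    """Fuse stack([rv1, rv2, ...]) into one batched Dist, or return None."""
    if not comps:
        return None
    family, nparams = comps[0].family, len(comps[0].params)
    # All components must share the same family and parameter signature
    if any(c.family != family or len(c.params) != nparams for c in comps):
        return None
    # Keep the sketch simple: only fuse components with scalar parameters
    if any(isinstance(p, tuple) for c in comps for p in c.params):
        return None
    # Stack the i-th parameter of every component into one batched parameter
    fused = tuple(tuple(c.params[i] for c in comps) for i in range(nparams))
    return Dist(family, fused)


# Mirrors [pm.Normal.dist(-1), pm.Normal.dist(1)] with the default sigma=1
comps = [Dist("Normal", (-1.0, 1.0)), Dist("Normal", (1.0, 1.0))]
print(try_fuse(comps))
# Dist(family='Normal', params=((-1.0, 1.0), (1.0, 1.0)))
```

The fused result corresponds to mu=(-1, 1) and sigma=(1, 1), i.e. the batched pm.Normal.dist([-1, 1]) from the second form; mismatched families return None so the original graph would be left untouched.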
