PMF of Bernoullis #77

Open
FHoltorf opened this issue Apr 3, 2023 · 1 comment


FHoltorf commented Apr 3, 2023

Hey!

I am trying to do variational inference (VI) with models involving discrete RVs. For that purpose, it would be quite handy to have derivative estimators for the PMFs of Bernoulli (and other discrete) RVs. Considering the following toy example

using StochasticAD, Distributions
function func(p) 
    x = rand(Bernoulli(p))
    pdf(Bernoulli(p), x)
end
function func_alt(p) 
    x = rand(Bernoulli(p))
    p^x*(1-p)^(1-x)
end

it seems that I can only get derivative estimators for func_alt. Propagating derivative information through func appears to fail because of the way pdf(::Bernoulli, ::Bool/Real) is implemented.
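For reference, the target of the estimator can be worked out by hand. This assumes the quantity of interest is the derivative of the expected PMF value with respect to p, which the toy example suggests but does not state explicitly:

```latex
\[
\mathbb{E}_{x \sim \mathrm{Bernoulli}(p)}\!\left[ p^{x} (1-p)^{1-x} \right]
  = p \cdot p + (1-p)(1-p)
  = p^{2} + (1-p)^{2},
\qquad
\frac{d}{dp}\left[ p^{2} + (1-p)^{2} \right] = 4p - 2 .
\]
```

Under that assumption, an unbiased derivative estimator for either func or func_alt should average to 4p - 2, which is 0.8 at p = 0.7.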

Now my questions:

  1. Am I using StochasticAD incorrectly?
  2. If not, would it be easy to accommodate propagation of stochastic triples through pdf(::Bernoulli, ::Bool/Real) (and perhaps the equivalent for other distributions with discrete support where I assume similar problems would arise)?

Thanks!
Flemming

Owner

gaurav-arya commented Apr 3, 2023

The issue is that stochastic triples unfortunately cannot propagate through the ternary operator in Distributions.jl's implementation of the Bernoulli PMF. You could fix this by overloading Distributions.pdf to catch stochastic triple inputs and feed these into the experimental (undocumented) StochasticAD.propagate interface:

using StochasticAD, Distributions

# Register an overload of the pdf
using Functors; @functor Bernoulli
Distributions.pdf(d::Bernoulli, x::StochasticAD.StochasticTriple) = StochasticAD.propagate(pdf, d, x; keep_deltas = Val{true}())

derivative_estimate(func, 0.7)

This should work on the about-to-be-released 0.1.14. (Performance may not be ideal; StochasticAD.propagate is still experimental functionality.)

Let me know if you have any questions! In any case, let's leave this issue open until this works out of the box.

(Edit: added keep_deltas = Val{true}() to the propagate call; the previous version was not correct.)
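One way to sanity-check a setup like this is to average many single-sample estimates and compare them with the hand-computed derivative. The sketch below is not from the thread: it uses func_alt, which the original post reports as working, and assumes the quantity of interest is the derivative of E[pdf(Bernoulli(p), x)], which is p^2 + (1 - p)^2 with derivative 4p - 2. The sample count n is an arbitrary illustrative choice.

```julia
using StochasticAD, Distributions, Statistics

# Branch-free Bernoulli PMF from the original post (reported to work with StochasticAD).
function func_alt(p)
    x = rand(Bernoulli(p))
    p^x * (1 - p)^(1 - x)
end

p = 0.7
n = 100_000  # illustrative sample count, not tuned

# Average many single-sample derivative estimates.
estimate = mean(derivative_estimate(func_alt, p) for _ in 1:n)

# Hand-computed reference: d/dp [p^2 + (1 - p)^2] = 4p - 2.
reference = 4p - 2

println("estimate = ", estimate, ", reference = ", reference)
```

If the estimator is unbiased, the two printed values should agree up to Monte Carlo noise. Once the pdf overload above is registered, func can presumably be swapped in for func_alt to check that both versions agree.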
