I am trying to do variational inference (VI) with models involving discrete random variables (RVs). For that purpose it would be quite handy to get derivative estimators for the PMFs of Bernoulli (and other discrete) RVs. Consider the following toy example:
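```julia
using StochasticAD, Distributions

# Evaluate the Bernoulli PMF at a sample drawn from that same distribution
function func(p)
    x = rand(Bernoulli(p))
    pdf(Bernoulli(p), x)
end

# Same quantity, with the PMF written out explicitly as p^x * (1 - p)^(1 - x)
function func_alt(p)
    x = rand(Bernoulli(p))
    p^x * (1 - p)^(1 - x)
end
```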
It seems that I can only get derivative estimators for `func_alt`. When I try to propagate derivative information through `func`, it appears to fail because of the way `pdf(::Bernoulli, ::Bool/Real)` is implemented.
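For reference, the expectation that both `func` and `func_alt` estimate has a closed form in this toy case, which is handy for sanity-checking any derivative estimator. With $X \sim \mathrm{Bernoulli}(p)$ and $f(p, x) = p^x (1-p)^{1-x}$ (note that $p$ enters both the distribution and the integrand):

$$
\mathbb{E}[f(p, X)] = p \, f(p, 1) + (1-p) \, f(p, 0) = p^2 + (1-p)^2,
\qquad
\frac{d}{dp} \mathbb{E}[f(p, X)] = 2p - 2(1-p) = 4p - 2,
$$

so an unbiased derivative estimator should average to $0.8$ at $p = 0.7$.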
Now my questions:

1. Am I using StochasticAD incorrectly?
2. If not, would it be easy to support propagating stochastic triples through `pdf(::Bernoulli, ::Bool/Real)` (and perhaps the equivalent methods for other distributions with discrete support, where I assume similar problems would arise)?
Thanks!
Flemming
The issue is that stochastic triples unfortunately cannot propagate through the ternary operator in Distributions.jl's implementation of the Bernoulli PMF. You could fix this by overloading `Distributions.pdf` to catch stochastic-triple inputs and feed them into the experimental (undocumented) `StochasticAD.propagate` interface:
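```julia
using StochasticAD, Distributions
using Functors; @functor Bernoulli  # let propagate look inside Bernoulli for stochastic triples
# Register an overload of the pdf for stochastic-triple inputs
Distributions.pdf(d::Bernoulli, x::StochasticAD.StochasticTriple) =
    StochasticAD.propagate(pdf, d, x; keep_deltas = Val{true}())
derivative_estimate(func, 0.7)  # `func` as defined in the question
```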
This should work on the about-to-be-released version 0.1.14. (Performance may not be ideal, since `StochasticAD.propagate` is still experimental functionality.)
Let me know if you have any questions! In any case, let's leave this issue open until this works out of the box.
(Edit: added `keep_deltas = Val{true}()` to the `propagate` call; the previous version was not correct.)
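One way to check that an estimator behaves is to average many single-sample estimates and compare the result with the exact derivative. For this toy problem the expected PMF value is p² + (1 − p)², so the derivative is 4p − 2 (0.8 at p = 0.7). The snippet below is a minimal sketch of that check for `func_alt`, assuming only the exported `derivative_estimate` used above; the same loop should apply to `func` once the `pdf` overload is registered. The sample count of 10,000 is an arbitrary choice.

```julia
using StochasticAD, Distributions, Statistics

# The branch-free variant from the question, which propagates triples without any overload
function func_alt(p)
    x = rand(Bernoulli(p))
    p^x * (1 - p)^(1 - x)
end

p = 0.7
# Each call returns one single-sample derivative estimate; averaging reduces the variance
estimates = [derivative_estimate(func_alt, p) for _ in 1:10_000]
m = mean(estimates)
println("Monte Carlo mean: ", m, "  exact: ", 4p - 2)
```

If the printed mean drifts noticeably away from 4p − 2 as the sample count grows, the estimator is biased, which is exactly the symptom to look for when a triple is silently dropped inside a branch.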