Fix ZygoteAD bug with do_observe #235
This fixes [this issue](TuringLang/Turing.jl#1595) where Zygote fails to compile `dot_observe` because of an exception clause.
I see that it might be needed to fix the Zygote issue but I wonder if there exists a Zygote compatible alternative that does not allocate. It's always sad to make a correct implementation worse because of bugs in some AD backend (or, more generally, some dependency).
It would be good to add a Turing-independent test of this bug. Probably it would be best to just check if one can compute the gradient of something like `x -> sum(DynamicPPL.dot_observe(SampleFromPrior(), [Normal(), Normal()], x, VarInfo()))` with the different AD backends.
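A minimal sketch of such a backend comparison, under stated assumptions: it exercises only the summed log-likelihood rather than `DynamicPPL.dot_observe` itself (whose test setup is not shown in this thread), and it picks ForwardDiff and Zygote as the two backends to compare. The name `loglik_sum` is hypothetical.

```julia
using Distributions, ForwardDiff, Zygote

# Two standard normal observation distributions, as in the suggested test.
dists = [Normal(), Normal()]

# Hypothetical stand-in for the quantity that dot_observe accumulates.
loglik_sum(x) = sum(Distributions.loglikelihood.(dists, x))

x = [0.5, -1.0]

# Gradient from each backend; Zygote.gradient returns a tuple, one entry per argument.
grad_forwarddiff = ForwardDiff.gradient(loglik_sum, x)
grad_zygote = Zygote.gradient(loglik_sum, x)[1]
```

For standard normals the gradient of the log-density with respect to `x` is `-x`, which gives an analytic value to check both backends against.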
- return sum(zip(dists, value)) do (d, v)
-     Distributions.loglikelihood(d, v)
- end
+ return sum(Distributions.loglikelihood.(dists, value))
Can you check if Zygote is happy with replacing `return sum(Distributions.loglikelihood.(dists, value))` with `return mapreduce(Distributions.loglikelihood, +, dists, value)`?
Would be even a bit simpler than the original expression 🙂
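For reference, the three formulations discussed here compute the same value in a regular (non-AD) run. The sketch below only checks that equivalence; the distributions and values are made up for illustration, and it says nothing about which forms Zygote can differentiate.

```julia
using Distributions

# Illustrative inputs (not taken from the PR).
dists = [Normal(), Normal(1.0, 2.0)]
value = [0.3, -1.2]

# Original expression: iterates lazily over the zipped pairs.
ll_zip = sum(zip(dists, value)) do (d, v)
    Distributions.loglikelihood(d, v)
end

# Version from this PR: broadcasting materializes an intermediate array before summing.
ll_broadcast = sum(Distributions.loglikelihood.(dists, value))

# Suggested alternative: multi-argument mapreduce.
ll_mapreduce = mapreduce(Distributions.loglikelihood, +, dists, value)
```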
@devmotion Unfortunately, no. Zygote doesn't seem to like that. I get
ERROR: Can't differentiate loopinfo expression
That seems to be the `sum` bug that I mentioned in the original issue: FluxML/Zygote.jl#897. Just yesterday two PRs were opened that fix this problem; can you see if it works with these fixes? The relevant PR should be FluxML/Zygote.jl#956, but I'm not completely sure, maybe it's the other one 😄
Unfortunately, it seems that neither of them solves the problem, probably because `mapreduce` has its own implementation that does not rely on `sum` (pardon me if I'm wrong on this; see `reduce.jl`).
So to summarize, Zygote is a bit annoying here... Maybe we should just rewrite the primal for Zygote in https://github.com/TuringLang/DynamicPPL.jl/blob/master/src/compat/ad.jl (we can use ZygoteRules and don't have to depend on Zygote). This would not impact performance and allocations in regular executions and with other AD backends.
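The general idea can be sketched as follows, with loud caveats: `sum_loglik` is a hypothetical stand-in for the relevant `dot_observe` method, and the sketch uses `Zygote.@adjoint` and `Zygote.pullback` directly for simplicity, whereas the proposal above is to go through ZygoteRules so that no Zygote dependency is needed. It is not the code that ended up in DynamicPPL.

```julia
using Distributions, Zygote

# Hypothetical primal: keeps the original non-allocating formulation,
# so regular execution and other AD backends are unaffected.
function sum_loglik(dists, value)
    return sum(zip(dists, value)) do (d, v)
        Distributions.loglikelihood(d, v)
    end
end

# Only under Zygote: differentiate an equivalent formulation instead of the primal.
# Zygote.pullback returns (y, back), and back(Δ) yields one gradient per argument,
# which is the shape that @adjoint expects.
Zygote.@adjoint function sum_loglik(dists, value)
    return Zygote.pullback(dists, value) do ds, vs
        sum(map(Distributions.loglikelihood, ds, vs))
    end
end

dists = [Normal(), Normal()]
x = [0.5, -1.0]
grad = Zygote.gradient(v -> sum_loglik(dists, v), x)[1]
```

Because the rule is attached only to Zygote's machinery, a plain call such as `sum_loglik(dists, x)` still runs the original primal.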
@devmotion This is quite down the rabbit hole. I'm really sorry to say this, but I currently don't have time to look deeper into this issue.
No worries, I'll try to add a fix later today.
Thanks! 👍
I opened a PR with the changes that I had in mind: #236
This PR fixes TuringLang/Turing.jl#1595. It is an alternative to #235 that does not require us to rewrite the primal less efficiently, which would affect regular execution and other AD backends. Co-authored-by: David Widmann
Fixed by #236.