Fix ZygoteAD bug with dot_observe #235

Closed
Red-Portal wants to merge 1 commit

@Red-Portal (Member):

This fixes [this issue](TuringLang/Turing.jl#1595) where Zygote fails to compile `dot_observe` because of an exception clause.

@devmotion (Member):

I see that this might be needed to fix the Zygote issue, but I wonder if there is a Zygote-compatible alternative that does not allocate. It's always sad to make a correct implementation worse because of bugs in some AD backend (or, more generally, some dependency).

@devmotion (Member) left a comment:

It would be good to add a Turing-independent test for this bug. Probably it would be best to just check whether one can compute the gradient of something like `x -> sum(DynamicPPL.dot_observe(SampleFromPrior(), [Normal(), Normal()], x, VarInfo()))` with the different AD backends.
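
A check along these lines might look like the sketch below. It is only an illustration: `dot_loglik` is a hypothetical stand-in for the log-likelihood accumulation in `DynamicPPL.dot_observe` (so the snippet does not depend on DynamicPPL internals), and it assumes Distributions, ForwardDiff, and Zygote are installed.

```julia
using Distributions
using ForwardDiff
using Zygote

# Hypothetical stand-in for the accumulation done inside `dot_observe`.
dot_loglik(dists, x) = sum(loglikelihood.(dists, x))

dists = [Normal(), Normal()]
x = [0.5, -1.0]
f = x -> dot_loglik(dists, x)

grad_forwarddiff = ForwardDiff.gradient(f, x)
grad_zygote = Zygote.gradient(f, x)[1]

# For a standard normal, d/dx logpdf(Normal(), x) = -x, so both backends
# should agree with each other and with the analytic gradient.
@assert grad_forwarddiff ≈ -x
@assert grad_zygote ≈ grad_forwarddiff
```

A real regression test would call `dot_observe` itself and loop over every AD backend the package supports.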

    -    return sum(zip(dists, value)) do (d, v)
    -        Distributions.loglikelihood(d, v)
    -    end
    +    return sum(Distributions.loglikelihood.(dists, value))

@devmotion (Member):

Can you check if Zygote is happy with the following suggested change?

    -    return sum(Distributions.loglikelihood.(dists, value))
    +    return mapreduce(Distributions.loglikelihood, +, dists, value)
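
As a quick sanity check outside of any AD backend, the two formulations can be compared on plain inputs. This is a minimal sketch with made-up distributions and values, assuming only that Distributions is installed; it shows that the two expressions compute the same number and says nothing about their allocation behavior or AD compatibility.

```julia
using Distributions

# Illustrative inputs, not taken from the PR.
dists = [Normal(), Normal(1.0, 2.0)]
value = [0.3, -0.7]

# Broadcast first, then reduce.
broadcast_sum = sum(loglikelihood.(dists, value))

# Map and reduce in a single call.
mapreduce_sum = mapreduce(loglikelihood, +, dists, value)

@assert broadcast_sum ≈ mapreduce_sum
```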

@devmotion (Member):

Would be even a bit simpler than the original expression 🙂

@Red-Portal (Member, Author), Apr 27, 2021:

@devmotion Unfortunately, no. Zygote doesn't seem to like that. I get

ERROR: Can't differentiate loopinfo expression

@devmotion (Member):

That seems to be the `sum` bug that I mentioned in the original issue: FluxML/Zygote.jl#897. Just yesterday two PRs were opened that fix this problem; can you see if it works with these fixes? The relevant PR should be FluxML/Zygote.jl#956, but I'm not completely sure; maybe it's the other one 😄

@Red-Portal (Member, Author):

Unfortunately, it seems that neither of them solves the problem, probably because `mapreduce` has its own implementation that does not rely on `sum` (pardon me if I'm wrong on this; see reduce.jl).

@devmotion (Member):

So to summarize, Zygote is a bit annoying here... Maybe we should just rewrite the primal for Zygote in https://github.com/TuringLang/DynamicPPL.jl/blob/master/src/compat/ad.jl (we can use ZygoteRules and don't have to depend on Zygote). This would not impact performance and allocations in regular executions and with other AD backends.
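
The general idea can be sketched on a toy function. This is an illustration under stated assumptions, not necessarily what the eventual fix does: `total_loglik` and `total_loglik_fallback` are hypothetical names, and the snippet uses `Zygote.@adjoint` and `Zygote.pullback` directly to stay self-contained, whereas a package would define the rule through ZygoteRules to avoid depending on Zygote.

```julia
using Distributions
using Zygote

# Hypothetical primal: the formulation used in regular execution and by
# other AD backends. It is left untouched.
total_loglik(dists, xs) = sum(zip(dists, xs)) do (d, x)
    loglikelihood(d, x)
end

# Hypothetical fallback: an equivalent formulation that Zygote can
# differentiate. It is only used when Zygote computes the pullback.
total_loglik_fallback(dists, xs) = sum(map(loglikelihood, dists, xs))

# When Zygote differentiates `total_loglik`, delegate to the fallback.
# `Zygote.pullback` returns the value and a pullback function, which is
# the pair an `@adjoint` definition has to return.
Zygote.@adjoint function total_loglik(dists, xs)
    return Zygote.pullback(total_loglik_fallback, dists, xs)
end

dists = [Normal(), Normal()]
xs = [0.5, -1.0]
grad = Zygote.gradient(x -> total_loglik(dists, x), xs)[1]

# For standard normals the gradient of the log density is -x.
@assert grad ≈ -xs
```

Calling `total_loglik(dists, xs)` outside of Zygote still runs the original primal, so performance and allocations in regular execution are unaffected.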

@Red-Portal (Member, Author), Apr 27, 2021:

@devmotion This is quite down the rabbit hole. I'm really sorry to say this, but I currently don't have time to look deeper into this issue.

@devmotion (Member):

No worries, I'll try to add a fix later today.

@Red-Portal (Member, Author):

Thanks! 👍

@devmotion (Member):

I opened a PR with the changes that I had in mind: #236

bors bot pushed a commit that referenced this pull request Apr 27, 2021
This PR fixes TuringLang/Turing.jl#1595.

It is an alternative to #235 that does not require us to rewrite the primal less efficiently which would affect regular execution and other AD backends.

Co-authored-by: David Widmann <[email protected]>

@devmotion (Member):

Fixed by #236.

@devmotion devmotion closed this Apr 27, 2021
Linked issue:

ZygoteAD dot operator fails again
2 participants