-
Notifications
You must be signed in to change notification settings - Fork 146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix perturbation confusion #247
Conversation
👎 I think this may be too common of a use case. Is there an easy way to opt out for parallel usage? |
Well, it's not a common use case at the moment, because it won't work! There are a couple of options:
|
Oh wait, it only is when you parallel take the gradient inside of a function you're taking a gradient of? I read it as though all uses inside of |
It occurs when |
It's pretty crazy/cool that this works. When I tried implementing the global counter approach back in the days of #83, I got hung up on the insane metaprogramming/compile time regressions induced by the static tag selection logic; I never thought to use ForwardDiff's existing conversion methods to handle this! Good idea. We can merge this PR if you make the following tweaks:
# You'll have to move the Tag definition into dual.jl and adjust
# the TagMismatchError code to accommodate these methods.
@inline value(::Tag, x) = value(x)
@inline value(t::Tag, d::Dual) = throw(TagMismatchError(t, d))
@inline value(::T, d::Dual{T}) where {T<:Tag} = value(d)
@inline partials(::Tag, x, i...) = partials(x, i...)
@inline partials(t::Tag, d::Dual, i...) = throw(TagMismatchError(t, d))
@inline partials(::T, d::Dual{T}, i...) where {T<:Tag} = partials(d, i...) ...and then use those methods in the API implementation, instead of the unchecked versions. |
Sorry it's taken me a while to get around to this. Out of curiosity, why the |
Because otherwise, we'd introduce a large number of method ambiguities. Traditionally, this multiple-dispatch-centric problem is solved by promoting to single dispatch, but this promotion is too costly for the general case (which is one of the motivations behind Cassette, where the problem is solved via contextual dispatch). |
2261d04
to
5249226
Compare
This is an initial stab at fixing perturbation confusion. It relies on a global iterator which is incremented for each (function, signature) pair: since higher-order derivatives require earlier derivatives to be defined first, these should appear on the "outside" of any nested `Dual` objects. As I note in the comments, this could cause problems when using multiple processes, e.g. ``` ForwardDiff.gradient(x) do x @parallel for i = 1:n ... ForwardDiff.gradient(y) do y # something involving x end end end ```
Tags are process-local (will throw an error if you attempt to mix tags from different processes). Also changed how tags are generated (function is no longer part of the signature).
5249226
to
8f76a69
Compare
test/DualTest.jl
Outdated
# @test Dual{1}(FDNUM) / FDNUM2 === Dual{1}(FDNUM / FDNUM2) | ||
# @test FDNUM / Dual{1}(FDNUM2) === Dual{1}(FDNUM / FDNUM2) | ||
# @test Dual{1}(FDNUM / PRIMAL, FDNUM2 / PRIMAL) === Dual{1}(FDNUM, FDNUM2) / PRIMAL | ||
# end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to disable these tests since I got rid of some of the binary methods (the dispatch is now handled by the promotion machinery).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those tests are pretty important to keep working - they're one of the few places where we stress nested semantics directly. Assuming this PR is correct, it should be possible to modify these tests to pass without violating their original intent.
Okay, hopefully this should now work. I removed the function type from the |
Any suggestions about the appveyor failure? |
We still actually need the function type in the tag. This is to prevent people from accidentally reusing tagged |
Okay, this has been a bit of a reworking. Now the tag is parametrised by (function, eltype), and a generated function gives each tag a unique sequence number which is used for comparison. This has the side-benefit that it should now work across processes (since each process should safely generate its own sequence). Thoughts? |
We just ran into another perturbation confusion case (one part of JuliaRobotics/RigidBodyDynamics.jl#347) which is fixed by this PR (thanks!). I submitted a PR against sb/confused that adds a test case distilled from the RigidBodyDynamics issue: simonbyrne#1. |
I'm not sure what the bug actually is - I only skimmed the RigidBodyDynamics issue - but that test you filed doesn't involve any perturbation confusion AFAICT. The error message ForwardDiff is giving giving shows that it's an API problem; a config was "incorrectly" constructed/applied:
ForwardDiff should've worked on that case without this PR. In fact, it does work if you get rid of the I'm going to file a separate issue for figuring out what that bug actually is. Thanks for the test! |
@simonbyrne Is this ready to merge? It LGTM - awesome work! EDIT: Ah, I see, there's still the issue of how custom/opt-out tags promote... |
Yeah, basically if you use opt out tags you have to specify your own promotion. Not sure if there is much we can do about that. |
Fair enough; the easy way to see whether or not this breaks downstream code will be to merge it 😛 |
This is an initial stab at fixing perturbation confusion. It relies on a global iterator which is incremented for each (function, signature) pair: since higher-order derivatives require earlier derivatives to be defined first, these should appear on the "outside" of any nested
Dual
objects.As I note in the comments, this could cause problems when using multiple processes, e.g.