Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix division return type #233

Merged
merged 6 commits into from
Jul 10, 2020
Merged

Conversation

CarloLucibello
Copy link
Contributor

Fix FluxML/Zygote.jl#727.

Still have to add tests. Where are rrule_test, frule_test and scalar_test defined? Do they check types?

@nickrobinson251
Copy link
Contributor

Where are rrule_test, frule_test and scalar_test defined?

They are defined in ChainRulesTestUtils.jl.

(This is a separate package to keep dependencies separate)

Do they check types?

No, I think they only test via comparing to Finite Differencing

@nickrobinson251 nickrobinson251 requested a review from sethaxen July 8, 2020 09:06
@nickrobinson251 nickrobinson251 added the bug Something isn't working label Jul 8, 2020
@willtebbutt
Copy link
Member

Yeah, and FiniteDifferencing is still quite bad at preserving the types correctly, which is a bit of a shame.

@CarloLucibello in terms of testing this rule, the correctness tests should be fine as is, and it would be sufficient for you to add some basic regression tests on the type of the differential. Could you also please open an issue on ChainRulesTestUtils about this? It seems like the kind of thing that we should be testing for all of the time, and is probably part of a more general "testing that you've produced a valid differential type" in the rules.

Copy link
Member

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a minor comment and we also need the test. Thanks!
Oh, and go ahead and bump the version number when you're done.

x, Δx, x̄ = 10rand(T, 3)
y, Δy, ȳ = rand(T, 3)
Δz = randn(typeof(f(x, y)))

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
if T != Float32
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

had to exclude this case because test tolerance is too strict for Float32

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you decrease the tolerance using the keyword arguments of the testers?
e.g. something like frule_test(f, (x, Δx), (y, Δy); atol=0, rtol = sqrt(eps(T)))

@CarloLucibello
Copy link
Contributor Author

I can't find a rule for \ (which may need a similar change, see JuliaDiff/DiffRules.jl#46)

@sethaxen
Copy link
Member

sethaxen commented Jul 9, 2020

I can't find a rule for \ (which may need a similar change, see JuliaDiff/DiffRules.jl#46)

Here it is: https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/base.jl#L141

) where {T<:Union{Real,Complex}}
x::T1,
y::T2,
) where {T1<:Union{Real,Complex}, T2<:Union{Real,Complex}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this signature just changed so that the tests in the loop pass? If so, maybe it's better to just take that test out of the loop. The signature we had here is the one that is FastMath compatible. There is a hypot that takes mixed arguments, but it promotes before calling this one, so currently we would just let the AD handle that. Same with the rrule.

x, Δx, x̄ = 10rand(T, 3)
y, Δy, ȳ = rand(T, 3)
Δz = randn(typeof(f(x, y)))

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
if T != Float32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you decrease the tolerance using the keyword arguments of the testers?
e.g. something like frule_test(f, (x, Δx), (y, Δy); atol=0, rtol = sqrt(eps(T)))

@CarloLucibello
Copy link
Contributor Author

zygote failure is unrelated

@CarloLucibello
Copy link
Contributor Author

again Zygote failure unrelated. This should be ready to go

Copy link
Member

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks!

@CarloLucibello
Copy link
Contributor Author

merge and tag?

@willtebbutt
Copy link
Member

Please. Thanks for the contribution

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

wrong gradient type when dividing by integer
6 participants