Chain rules for certain functions do not respect numerical precision #307

Closed
torfjelde opened this issue Apr 21, 2021 · 0 comments · Fixed by #348

torfjelde commented Apr 21, 2021

Because of the use of irrational numbers, the adjoints of some functions mistakenly promote the numerical precision of the derivative/gradient. This happens because certain implementations first apply an operation to the irrational number, which by default usually converts it to Float64. For example, for erfc we first call sqrt(π), which returns a Float64. So instead of the Irrational being promoted to the type we expected for the output, the output gets promoted to Float64 (when we are using lower-precision floats such as Float32):

julia> using SpecialFunctions, ChainRulesCore

julia> y, ȳ = ChainRulesCore.frule((ChainRulesCore.NO_FIELDS, 1f0), SpecialFunctions.erfc, 1f0)
(0.1572992f0, -0.41510750774498784)

julia> typeof(y), typeof(ȳ)
(Float32, Float64)
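A minimal sketch of the mechanism, assuming the erfc derivative is written as -2/sqrt(π) * exp(-x^2). The helper names erfc_deriv_naive and erfc_deriv_typed are hypothetical and are not the SpecialFunctions.jl implementation. Converting π to the argument's type with oftype before taking the square root keeps a Float32 input in Float32. This assumes a floating-point argument, since converting π to an integer type would fail.

```julia
# Hypothetical helpers for illustration only (not SpecialFunctions.jl code):
# d/dx erfc(x) = -2/sqrt(π) * exp(-x^2)

# Naive version: sqrt(π) turns the Irrational into a Float64 first,
# so a Float32 input yields a Float64 derivative.
erfc_deriv_naive(x) = -2 / sqrt(π) * exp(-x^2)

# Type-preserving version: convert π to typeof(x) before calling sqrt,
# so every intermediate stays in the input's floating-point type.
erfc_deriv_typed(x) = -2 / sqrt(oftype(x, π)) * exp(-x^2)

x = 1f0
println(typeof(erfc_deriv_naive(x)))  # Float64
println(typeof(erfc_deriv_typed(x)))  # Float32
```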

This is essentially the same issue as in DiffRules (JuliaDiff/DiffRules.jl#55).

Does anyone have a better idea of what to do here, or should I just make a similar PR to SpecialFunctions.jl?
