-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add derivatives for besselix
, besseljx
, and besselyx
#350
Changes from 2 commits
89e2dc9
d1076fc
c26d35e
0af9568
30a020f
a9f9f52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -193,3 +193,66 @@ ChainRulesCore.@scalar_rule( | |||||
ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x) | ||||||
ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x)) | ||||||
ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x) | ||||||
|
||||||
# non-holomorphic functions | ||||||
function ChainRulesCore.frule((_, _, _), ::typeof(besselix), ν::Number, x::Number) | ||||||
Ω = besselix(ν, x) | ||||||
ΔΩ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||||||
return Ω, ΔΩ | ||||||
end | ||||||
function ChainRulesCore.rrule(::typeof(besselix), ν::Number, x::Number) | ||||||
Ω = besselix(ν, x) | ||||||
project_x = ChainRulesCore.ProjectTo(x) | ||||||
function besselix_pullback(ΔΩ) | ||||||
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||||||
a = (besselix(ν - 1, x) + besselix(ν + 1, x)) / 2 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One can use the recurrence relations to write this as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just reused the derivatives that are used for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it might be best to use the same relations for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's right, and they're the same in DiffRules as well. If you like, you can keep them similar to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree! |
||||||
b = - sign(real(x)) * Ω | ||||||
x̄ = project_x(conj(a) * ΔΩ + real(conj(b) * ΔΩ)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's more efficient to compute The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have never measured |
||||||
return ChainRulesCore.NoTangent(), ν̄, x̄ | ||||||
end | ||||||
return Ω, besselix_pullback | ||||||
end | ||||||
|
||||||
function ChainRulesCore.frule((_, _, _), ::typeof(besseljx), ν::Number, x::Number) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since everything should be identical for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could but so far the style in this file is to implement everything explicitly (eg. also derivatives There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense. |
||||||
Ω = besseljx(ν, x) | ||||||
ΔΩ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||||||
return Ω, ΔΩ | ||||||
end | ||||||
function ChainRulesCore.rrule(::typeof(besseljx), ν::Number, x::Number) | ||||||
Ω = besseljx(ν, x) | ||||||
project_x = ChainRulesCore.ProjectTo(x) | ||||||
function besseljx_pullback(ΔΩ) | ||||||
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||||||
a = (besseljx(ν - 1, x) - besseljx(ν + 1, x)) / 2 | ||||||
x̄ = if x isa Real | ||||||
project_x(conj(a) * ΔΩ) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If
Suggested change
|
||||||
else | ||||||
b = -sign(imag(x)) * Ω | ||||||
project_x(conj(a) * ΔΩ + real(conj(b) * ΔΩ) * im) | ||||||
end | ||||||
return ChainRulesCore.NoTangent(), ν̄, x̄ | ||||||
end | ||||||
return Ω, besseljx_pullback | ||||||
end | ||||||
|
||||||
function ChainRulesCore.frule((_, _, _), ::typeof(besselyx), ν::Number, x::Number) | ||||||
Ω = besselyx(ν, x) | ||||||
ΔΩ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||||||
return Ω, ΔΩ | ||||||
end | ||||||
function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number) | ||||||
Ω = besselyx(ν, x) | ||||||
project_x = ChainRulesCore.ProjectTo(x) | ||||||
function besselyx_pullback(ΔΩ) | ||||||
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||||||
a = (besselyx(ν - 1, x) - besselyx(ν + 1, x)) / 2 | ||||||
x̄ = if x isa Real | ||||||
project_x(conj(a) * ΔΩ) | ||||||
else | ||||||
b = -sign(imag(x)) * Ω | ||||||
project_x(conj(a) * ΔΩ + real(conj(b) * ΔΩ) * im) | ||||||
end | ||||||
return ChainRulesCore.NoTangent(), ν̄, x̄ | ||||||
end | ||||||
return Ω, besselyx_pullback | ||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My earlier comment might have gotten lost, but since
ZeroTangent() * ::NotImplemented
is aZeroTangent
, you can implement thefrule
in such a way that if the AD providesZeroTangent()
forΔν
, then thefrule
does not return aNotImplemented
:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yes, I completely forgot that
ZeroTangent() * ::NotImplemented = ZeroTangent()
. I'll fix the forward rules!There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed that currently it is not possible to test such definitions: JuliaDiff/ChainRulesCore.jl#477