-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add derivatives for besselix
, besseljx
, and besselyx
#350
Changes from 4 commits
89e2dc9
d1076fc
c26d35e
0af9568
30a020f
a9f9f52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -193,3 +193,108 @@ ChainRulesCore.@scalar_rule( | |
ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x) | ||
ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x)) | ||
ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x) | ||
|
||
# non-holomorphic functions | ||
function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselix), ν::Number, x::Number) | ||
# primal | ||
Ω = besselix(ν, x) | ||
|
||
# derivative | ||
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||
∂Ω_∂ν_Δν = ∂Ω_∂ν * Δν | ||
ΔΩ = if ∂Ω_∂ν_Δν isa ChainRulesCore.NotImplemented | ||
∂Ω_∂ν_Δν | ||
else | ||
a = (besselix(ν - 1, x) + besselix(ν + 1, x)) / 2 | ||
if Δx isa Real | ||
muladd(muladd(-sign(real(x)), Ω, a), Δx, ∂Ω_∂ν_Δν) | ||
else | ||
muladd(a, Δx, muladd(-sign(real(x)) * Ω, real(Δx), ∂Ω_∂ν_Δν)) | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can't be tested currently and requires JuliaDiff/ChainRulesCore.jl#477 (currently this branch is not reached with neither the default nor the |
||
end | ||
|
||
return Ω, ΔΩ | ||
end | ||
function ChainRulesCore.rrule(::typeof(besselix), ν::Number, x::Number) | ||
Ω = besselix(ν, x) | ||
project_x = ChainRulesCore.ProjectTo(x) | ||
function besselix_pullback(ΔΩ) | ||
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||
a = (besselix(ν - 1, x) + besselix(ν + 1, x)) / 2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One can use the recurrence relations to write this as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just reused the derivatives that are used for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it might be best to use the same relations for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's right, and they're the same in DiffRules as well. If you like, you can keep them similar to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree! |
||
x̄ = project_x(muladd(conj(a), ΔΩ, - sign(real(x)) * real(conj(Ω) * ΔΩ))) | ||
return ChainRulesCore.NoTangent(), ν̄, x̄ | ||
end | ||
return Ω, besselix_pullback | ||
end | ||
|
||
function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besseljx), ν::Number, x::Number) | ||
# primal | ||
Ω = besseljx(ν, x) | ||
|
||
# derivative | ||
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||
∂Ω_∂ν_Δν = ∂Ω_∂ν * Δν | ||
ΔΩ = if ∂Ω_∂ν_Δν isa ChainRulesCore.NotImplemented | ||
∂Ω_∂ν_Δν | ||
else | ||
a = (besseljx(ν - 1, x) - besseljx(ν + 1, x)) / 2 | ||
if Δx isa Real | ||
a * Δx | ||
else | ||
muladd(a, Δx, muladd(-sign(imag(x)) * Ω, imag(Δx), ∂Ω_∂ν_Δν)) | ||
end | ||
end | ||
|
||
return Ω, ΔΩ | ||
end | ||
function ChainRulesCore.rrule(::typeof(besseljx), ν::Number, x::Number) | ||
Ω = besseljx(ν, x) | ||
project_x = ChainRulesCore.ProjectTo(x) | ||
function besseljx_pullback(ΔΩ) | ||
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||
a = (besseljx(ν - 1, x) - besseljx(ν + 1, x)) / 2 | ||
x̄ = if x isa Real | ||
project_x(a * ΔΩ) | ||
else | ||
project_x(muladd(conj(a), ΔΩ, - sign(imag(x)) * real(conj(Ω) * ΔΩ) * im)) | ||
end | ||
return ChainRulesCore.NoTangent(), ν̄, x̄ | ||
end | ||
return Ω, besseljx_pullback | ||
end | ||
|
||
function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselyx), ν::Number, x::Number) | ||
# primal | ||
Ω = besselyx(ν, x) | ||
|
||
# derivative | ||
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||
∂Ω_∂ν_Δν = ∂Ω_∂ν * Δν | ||
ΔΩ = if ∂Ω_∂ν_Δν isa ChainRulesCore.NotImplemented | ||
∂Ω_∂ν_Δν | ||
else | ||
a = (besselyx(ν - 1, x) - besselyx(ν + 1, x)) / 2 | ||
if Δx isa Real | ||
a * Δx | ||
else | ||
muladd(a, Δx, muladd(-sign(imag(x)) * Ω, imag(Δx), ∂Ω_∂ν_Δν)) | ||
end | ||
end | ||
|
||
return Ω, ΔΩ | ||
end | ||
function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number) | ||
Ω = besselyx(ν, x) | ||
project_x = ChainRulesCore.ProjectTo(x) | ||
function besselyx_pullback(ΔΩ) | ||
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) | ||
a = (besselyx(ν - 1, x) - besselyx(ν + 1, x)) / 2 | ||
x̄ = if x isa Real | ||
project_x(a * ΔΩ) | ||
else | ||
project_x(muladd(conj(a), ΔΩ, - sign(imag(x)) * real(conj(Ω) * ΔΩ) * im)) | ||
end | ||
return ChainRulesCore.NoTangent(), ν̄, x̄ | ||
end | ||
return Ω, besselyx_pullback | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if this optimization is useful enough to justify the more verbose and less readable implementation. The optimization is also not performed when the derivative is defined with
@scalar_rule
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I don't think it's necessary. Odds are when this happens it will be because a user actually tried to differentiate wrt the order, and they will instantly realize that's not possible. Ideally the compiler would realize
a
is unused and not compute it, but that doesn't seem to be the case.Including it doesn't hurt though. You could always include in a comment the simpler expression so it's easier to follow. We do this in a few places in ChainRules.