Skip to content

Commit

Permalink
Add derivatives for besselix, besseljx, and besselyx (#350)
Browse files Browse the repository at this point in the history
* Add derivatives for `besselix`, `besseljx`, and `besselyx`

* Bump version

* Apply suggestions from @sethaxen's review (muladd + optimizations)

* Improve `frule`s

* Simplify `frule`s

* Apply suggestions from code review

Co-authored-by: Seth Axen <[email protected]>

Co-authored-by: Seth Axen <[email protected]>
  • Loading branch information
devmotion and sethaxen authored Oct 18, 2021
1 parent 1b29b0b commit c4b0b83
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
90 changes: 90 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,93 @@ ChainRulesCore.@scalar_rule(
ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x)
ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x))
ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x)

# non-holomorphic functions
function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselix), ν::Number, x::Number)
# primal
Ω = besselix(ν, x)

# derivative
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
a = (besselix- 1, x) + besselix+ 1, x)) / 2
ΔΩ = if Δx isa Real
muladd(muladd(-sign(real(x)), Ω, a), Δx, ∂Ω_∂ν * Δν)
else
muladd(a, Δx, muladd(-sign(real(x)) * real(Δx), Ω, ∂Ω_∂ν * Δν))
end

return Ω, ΔΩ
end
function ChainRulesCore.rrule(::typeof(besselix), ν::Number, x::Number)
Ω = besselix(ν, x)
project_x = ChainRulesCore.ProjectTo(x)
function besselix_pullback(ΔΩ)
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
a = (besselix- 1, x) + besselix+ 1, x)) / 2
= project_x(muladd(conj(a), ΔΩ, - sign(real(x)) * real(conj(Ω) * ΔΩ)))
return ChainRulesCore.NoTangent(), ν̄, x̄
end
return Ω, besselix_pullback
end

function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besseljx), ν::Number, x::Number)
# primal
Ω = besseljx(ν, x)

# derivative
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
a = (besseljx- 1, x) - besseljx+ 1, x)) / 2
ΔΩ = if Δx isa Real
muladd(a, Δx, ∂Ω_∂ν * Δν)
else
muladd(a, Δx, muladd(-sign(imag(x)) * imag(Δx), Ω, ∂Ω_∂ν * Δν))
end

return Ω, ΔΩ
end
function ChainRulesCore.rrule(::typeof(besseljx), ν::Number, x::Number)
Ω = besseljx(ν, x)
project_x = ChainRulesCore.ProjectTo(x)
function besseljx_pullback(ΔΩ)
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
a = (besseljx- 1, x) - besseljx+ 1, x)) / 2
= if x isa Real
project_x(a * ΔΩ)
else
project_x(muladd(conj(a), ΔΩ, - sign(imag(x)) * real(conj(Ω) * ΔΩ) * im))
end
return ChainRulesCore.NoTangent(), ν̄, x̄
end
return Ω, besseljx_pullback
end

function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselyx), ν::Number, x::Number)
# primal
Ω = besselyx(ν, x)

# derivative
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
a = (besselyx- 1, x) - besselyx+ 1, x)) / 2
ΔΩ = if Δx isa Real
muladd(a, Δx, ∂Ω_∂ν * Δν)
else
muladd(a, Δx, muladd(-sign(imag(x)) * imag(Δx), Ω, ∂Ω_∂ν * Δν))
end

return Ω, ΔΩ
end
function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number)
Ω = besselyx(ν, x)
project_x = ChainRulesCore.ProjectTo(x)
function besselyx_pullback(ΔΩ)
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
a = (besselyx- 1, x) - besselyx+ 1, x)) / 2
= if x isa Real
project_x(a * ΔΩ)
else
project_x(muladd(conj(a), ΔΩ, - sign(imag(x)) * real(conj(Ω) * ΔΩ) * im))
end
return ChainRulesCore.NoTangent(), ν̄, x̄
end
return Ω, besselyx_pullback
end
9 changes: 9 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,15 @@
for nu in (-1.5, 2.2, 4.0)
test_frule(besseli, nu, x)
test_rrule(besseli, nu, x)
test_frule(besselix, nu, x) # derivative is `NotImplemented`
test_frule(besselix, nu NoTangent(), x) # derivative is a number
test_rrule(besselix, nu, x)

test_frule(besselj, nu, x)
test_rrule(besselj, nu, x)
test_frule(besseljx, nu, x) # derivative is `NotImplemented`
test_frule(besseljx, nu NoTangent(), x) # derivative is a number
test_rrule(besseljx, nu, x)

test_frule(besselk, nu, x)
test_rrule(besselk, nu, x)
Expand All @@ -64,6 +70,9 @@

test_frule(bessely, nu, x)
test_rrule(bessely, nu, x)
test_frule(besselyx, nu, x) # derivative is `NotImplemented`
test_frule(besselyx, nu NoTangent(), x) # derivative is a number
test_rrule(besselyx, nu, x)

test_frule(hankelh1, nu, x)
test_rrule(hankelh1, nu, x)
Expand Down

2 comments on commit c4b0b83

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/46962

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.8.0 -m "<description of version>" c4b0b83f2eeadc76dcc8825460b7311d034b60ce
git push origin v1.8.0

Please sign in to comment.