diff --git a/Project.toml b/Project.toml index e0d9275d..205cf836 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SpecialFunctions" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.7.0" +version = "1.8.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/chainrules.jl b/src/chainrules.jl index fa7b5dd3..ad79dd03 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -193,3 +193,93 @@ ChainRulesCore.@scalar_rule( ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x) ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x)) ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x) + +# non-holomorphic functions +function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselix), ν::Number, x::Number) + # primal + Ω = besselix(ν, x) + + # derivative + ∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) + a = (besselix(ν - 1, x) + besselix(ν + 1, x)) / 2 + ΔΩ = if Δx isa Real + muladd(muladd(-sign(real(x)), Ω, a), Δx, ∂Ω_∂ν * Δν) + else + muladd(a, Δx, muladd(-sign(real(x)) * real(Δx), Ω, ∂Ω_∂ν * Δν)) + end + + return Ω, ΔΩ +end +function ChainRulesCore.rrule(::typeof(besselix), ν::Number, x::Number) + Ω = besselix(ν, x) + project_x = ChainRulesCore.ProjectTo(x) + function besselix_pullback(ΔΩ) + ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) + a = (besselix(ν - 1, x) + besselix(ν + 1, x)) / 2 + x̄ = project_x(muladd(conj(a), ΔΩ, - sign(real(x)) * real(conj(Ω) * ΔΩ))) + return ChainRulesCore.NoTangent(), ν̄, x̄ + end + return Ω, besselix_pullback +end + +function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besseljx), ν::Number, x::Number) + # primal + Ω = besseljx(ν, x) + + # derivative + ∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) + a = (besseljx(ν - 1, x) - besseljx(ν + 1, x)) / 2 + ΔΩ = if Δx isa Real + muladd(a, Δx, ∂Ω_∂ν * Δν) + else + muladd(a, Δx, muladd(-sign(imag(x)) * imag(Δx), Ω, ∂Ω_∂ν * Δν)) + end + + return Ω, ΔΩ +end +function ChainRulesCore.rrule(::typeof(besseljx), ν::Number, x::Number) + Ω = besseljx(ν, x) + project_x = ChainRulesCore.ProjectTo(x) + function besseljx_pullback(ΔΩ) + ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) + a = (besseljx(ν - 1, x) - besseljx(ν + 1, x)) / 2 + x̄ = if x isa Real + project_x(a * ΔΩ) + else + project_x(muladd(conj(a), ΔΩ, - sign(imag(x)) * real(conj(Ω) * ΔΩ) * im)) + end + return ChainRulesCore.NoTangent(), ν̄, x̄ + end + return Ω, besseljx_pullback +end + +function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselyx), ν::Number, x::Number) + # primal + Ω = besselyx(ν, x) + + # derivative + ∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) + a = (besselyx(ν - 1, x) - besselyx(ν + 1, x)) / 2 + ΔΩ = if Δx isa Real + muladd(a, Δx, ∂Ω_∂ν * Δν) + else + muladd(a, Δx, muladd(-sign(imag(x)) * imag(Δx), Ω, ∂Ω_∂ν * Δν)) + end + + return Ω, ΔΩ +end +function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number) + Ω = besselyx(ν, x) + project_x = ChainRulesCore.ProjectTo(x) + function besselyx_pullback(ΔΩ) + ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO) + a = (besselyx(ν - 1, x) - besselyx(ν + 1, x)) / 2 + x̄ = if x isa Real + project_x(a * ΔΩ) + else + project_x(muladd(conj(a), ΔΩ, - sign(imag(x)) * real(conj(Ω) * ΔΩ) * im)) + end + return ChainRulesCore.NoTangent(), ν̄, x̄ + end + return Ω, besselyx_pullback +end diff --git a/test/chainrules.jl b/test/chainrules.jl index d4be8285..506b4b48 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -53,9 +53,15 @@ for nu in (-1.5, 2.2, 4.0) test_frule(besseli, nu, x) test_rrule(besseli, nu, x) + test_frule(besselix, nu, x) # derivative is `NotImplemented` + test_frule(besselix, nu ⊢ NoTangent(), x) # derivative is a number + test_rrule(besselix, nu, x) test_frule(besselj, nu, x) test_rrule(besselj, nu, x) + test_frule(besseljx, nu, x) # derivative is `NotImplemented` + test_frule(besseljx, nu ⊢ NoTangent(), x) # derivative is a number + test_rrule(besseljx, nu, x) test_frule(besselk, nu, x) test_rrule(besselk, nu, x) @@ -64,6 +70,9 @@ test_frule(bessely, nu, x) test_rrule(bessely, nu, x) + test_frule(besselyx, nu, x) # derivative is `NotImplemented` + test_frule(besselyx, nu ⊢ NoTangent(), x) # derivative is a number + test_rrule(besselyx, nu, x) test_frule(hankelh1, nu, x) test_rrule(hankelh1, nu, x)