From f2ccad81b957f295ed209f973e61fd8b13882c5d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 26 Jan 2022 14:00:37 +0100 Subject: [PATCH] Add derivatives for `ellipk` and `ellipe` (#370) * Add derivatives for `ellipk` and `ellipe` * Bump version * Fix derivatives at 0 --- Project.toml | 2 +- src/chainrules.jl | 10 ++++++++++ src/ellip.jl | 18 +++++++++--------- test/chainrules.jl | 5 +++++ 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index d1799f40..776718c3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SpecialFunctions" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.0.0" +version = "2.1.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/chainrules.jl b/src/chainrules.jl index ad79dd03..76a8f07d 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -194,6 +194,16 @@ ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x) ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x)) ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x) +# elliptic integrals +ChainRulesCore.@scalar_rule( + ellipk(m), + iszero(m) ? oftype(Ω, π) / 8 : (ellipe(m) / (1 - m) - Ω) / (2 * m), +) +ChainRulesCore.@scalar_rule( + ellipe(m), + iszero(m) ? -oftype(Ω, π) / 8 : (Ω - ellipk(m)) / (2 * m), +) + # non-holomorphic functions function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselix), ν::Number, x::Number) # primal diff --git a/src/ellip.jl b/src/ellip.jl index a9f7ef25..5065f60d 100644 --- a/src/ellip.jl +++ b/src/ellip.jl @@ -39,7 +39,9 @@ For ``m<0``, followed by > As suggested in this paper, the domain is restricted to ``(-\infty,1]``. """ -function ellipk(m::Float64) +ellipk(m::Real) = _ellipk(float(m)) + +function _ellipk(m::Float64) flag_is_m_neg = false if m < 0.0 x = m / (m-1) #dealing with negative args @@ -214,7 +216,9 @@ For ``m<0``, followed by > As suggested in this paper, the domain is restricted to ``(-\infty,1]``. """ -function ellipe(m::Float64) +ellipe(m::Real) = _ellipe(float(m)) + +function _ellipe(m::Float64) flag_is_m_neg = false if m < 0.0 x = m / (m-1) #dealing with negative args @@ -346,11 +350,7 @@ function ellipe(m::Float64) end end -for f in (:ellipk,:ellipe) - @eval begin - ($f)(x::Float16) = Float16(($f)(Float64(x))) - ($f)(x::Float32) = Float32(($f)(Float64(x))) - ($f)(x::Real) = ($f)(float(x)) - ($f)(x::AbstractFloat) = throw(MethodError($f, (x, ""))) - end +# Support for Float32 and Float16 +for internalf in (:_ellipk, :_ellipe), T in (:Float16, :Float32) + @eval $internalf(x::$T) = $T($internalf(Float64(x))) end diff --git a/test/chainrules.jl b/test/chainrules.jl index 506b4b48..1754d591 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -37,6 +37,11 @@ test_scalar(airyaiprimex, x) end end + + if x isa Real && x < 1 + test_scalar(ellipk, x) + test_scalar(ellipe, x) + end end end