Skip to content

Commit

Permalink
Add more ChainRules derivatives (#348)
Browse files Browse the repository at this point in the history
* Add more ChainRules definitions

* Bump version

* Use irrational constants

* Convert irrational manually with `oftype`
  • Loading branch information
devmotion authored Sep 24, 2021
1 parent 1b7a377 commit 74fd16f
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 30 deletions.
46 changes: 38 additions & 8 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/321
"""

ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x))
ChainRulesCore.@scalar_rule(airyaix(x), airyaiprimex(x) + sqrt(x) * Ω)
ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x))
ChainRulesCore.@scalar_rule(airyaiprimex(x), x * airyaix(x) + sqrt(x) * Ω)
ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x))
ChainRulesCore.@scalar_rule(airybiprime(x), x * airybi(x))
ChainRulesCore.@scalar_rule(besselj0(x), -besselj1(x))
Expand All @@ -31,12 +33,18 @@ ChainRulesCore.@scalar_rule(
)
ChainRulesCore.@scalar_rule(dawson(x), 1 - (2 * x * Ω))
ChainRulesCore.@scalar_rule(digamma(x), trigamma(x))
ChainRulesCore.@scalar_rule(erf(x), (2 / sqrt(π)) * exp(-x * x))
ChainRulesCore.@scalar_rule(erfc(x), -(2 / sqrt(π)) * exp(-x * x))
ChainRulesCore.@scalar_rule(erfcinv(x), -(sqrt(π) / 2) * exp(Ω^2))
ChainRulesCore.@scalar_rule(erfcx(x), (2 * x * Ω) - (2 / sqrt(π)))
ChainRulesCore.@scalar_rule(erfi(x), (2 / sqrt(π)) * exp(x * x))
ChainRulesCore.@scalar_rule(erfinv(x), (sqrt(π) / 2) * exp(Ω^2))

# TODO: use `invsqrtπ` if it is added to IrrationalConstants
ChainRulesCore.@scalar_rule(erf(x), (2 * exp(-x^2)) / sqrtπ)
ChainRulesCore.@scalar_rule(erf(x, y), (- (2 * exp(-x^2)) / sqrtπ, (2 * exp(-y^2)) / sqrtπ))
ChainRulesCore.@scalar_rule(erfc(x), - (2 * exp(-x^2)) / sqrtπ)
ChainRulesCore.@scalar_rule(logerfc(x), - (2 * exp(-x^2 - Ω)) / sqrtπ)
ChainRulesCore.@scalar_rule(erfcinv(x), - (sqrtπ * (exp(Ω^2) / 2)))
ChainRulesCore.@scalar_rule(erfcx(x), 2 * (x * Ω - inv(oftype(Ω, sqrtπ))))
ChainRulesCore.@scalar_rule(logerfcx(x), 2 * (x - exp(-Ω) / sqrtπ))
ChainRulesCore.@scalar_rule(erfi(x), (2 * exp(x^2)) / sqrtπ)
ChainRulesCore.@scalar_rule(erfinv(x), sqrtπ * (exp(Ω^2) / 2))

ChainRulesCore.@scalar_rule(gamma(x), Ω * digamma(x))
ChainRulesCore.@scalar_rule(
gamma(a, x),
Expand Down Expand Up @@ -65,7 +73,7 @@ ChainRulesCore.@scalar_rule(
)
ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))

# binary
# Bessel functions
ChainRulesCore.@scalar_rule(
besselj(ν, x),
(
Expand Down Expand Up @@ -94,20 +102,42 @@ ChainRulesCore.@scalar_rule(
-(besselk(ν - 1, x) + besselk(ν + 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
besselkx(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
-(besselkx(ν - 1, x) + besselkx(ν + 1, x)) / 2 + Ω,
),
)
ChainRulesCore.@scalar_rule(
hankelh1(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh1(ν - 1, x) - hankelh1(ν + 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
hankelh1x(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh1x(ν - 1, x) - hankelh1x(ν + 1, x)) / 2 - im * Ω,
),
)
ChainRulesCore.@scalar_rule(
hankelh2(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh2(ν - 1, x) - hankelh2(ν + 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
hankelh2x(ν, x),
(
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh2x(ν - 1, x) - hankelh2x(ν + 1, x)) / 2 + im * Ω,
),
)

ChainRulesCore.@scalar_rule(
polygamma(m, x),
(
Expand Down Expand Up @@ -161,5 +191,5 @@ ChainRulesCore.@scalar_rule(
)
)
ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x)
ChainRulesCore.@scalar_rule(sinint(x), sinc(x / π))
ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x))
ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x)
73 changes: 51 additions & 22 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
test_scalar(erf, x)
test_scalar(erfc, x)
test_scalar(erfcx, x)
test_scalar(erfi, x)

test_scalar(airyai, x)
test_scalar(airyaiprime, x)
test_scalar(airybi, x)
test_scalar(airybiprime, x)

test_scalar(erfcx, x)
test_scalar(dawson, x)

if x isa Real
test_scalar(logerfc, x)
test_scalar(logerfcx, x)

test_scalar(invdigamma, x)
end

Expand All @@ -28,6 +31,11 @@
test_scalar(gamma, x)
test_scalar(digamma, x)
test_scalar(trigamma, x)

if x isa Real
test_scalar(airyaix, x)
test_scalar(airyaiprimex, x)
end
end
end
end
Expand All @@ -51,31 +59,38 @@

test_frule(besselk, nu, x)
test_rrule(besselk, nu, x)
test_frule(besselkx, nu, x)
test_rrule(besselkx, nu, x)

test_frule(bessely, nu, x)
test_rrule(bessely, nu, x)

# use complex numbers in `rrule` for FiniteDifferences
test_frule(hankelh1, nu, x)
test_rrule(hankelh1, nu, complex(x))
test_rrule(hankelh1, nu, x)
test_frule(hankelh1x, nu, x)
test_rrule(hankelh1x, nu, x)

# use complex numbers in `rrule` for FiniteDifferences
test_frule(hankelh2, nu, x)
test_rrule(hankelh2, nu, complex(x))
test_rrule(hankelh2, nu, x)
test_frule(hankelh2x, nu, x)
test_rrule(hankelh2x, nu, x)
end
end
end

@testset "beta and logbeta" begin
@testset "erf, beta, and logbeta" begin
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
for _x in test_points, _y in test_points
# ensure all complex if any complex for FiniteDifferences
x, y = promote(_x, _y)
for x in test_points, y in test_points
test_frule(beta, x, y)
test_rrule(beta, x, y)

test_frule(logbeta, x, y)
test_rrule(logbeta, x, y)

if x isa Real && y isa Real
test_frule(erf, x, y)
test_rrule(erf, x, y)
end
end
end

Expand All @@ -91,13 +106,11 @@
isreal(x) && x < 0 && continue
test_scalar(loggamma, x)
for a in test_points
# ensure all complex if any complex for FiniteDifferences
_a, _x = promote(a, x)
test_frule(gamma, _a, _x; rtol=1e-8)
test_rrule(gamma, _a, _x; rtol=1e-8)
test_frule(gamma, a, x; rtol=1e-8)
test_rrule(gamma, a, x; rtol=1e-8)

test_frule(loggamma, _a, _x)
test_rrule(loggamma, _a, _x)
test_frule(loggamma, a, x)
test_rrule(loggamma, a, x)
end

isreal(x) || continue
Expand All @@ -117,14 +130,11 @@
test_scalar(expintx, x)

for nu in (-1.5, 2.2, 4.0)
# ensure all complex if any complex for FiniteDifferences
_x, _nu = promote(x, nu)
test_frule(expint, nu, x)
test_rrule(expint, nu, x)

test_frule(expint, _nu, _x)
test_rrule(expint, _nu, _x)

test_frule(expintx, _nu, _x)
test_rrule(expintx, _nu, _x)
test_frule(expintx, nu, x)
test_rrule(expintx, nu, x)
end

isreal(x) || continue
Expand All @@ -133,4 +143,23 @@
test_scalar(cosint, x)
end
end

# https://github.com/JuliaMath/SpecialFunctions.jl/issues/307
@testset "promotions" begin
# one argument
for f in (erf, erfc, logerfc, erfcinv, erfcx, logerfcx, erfi, erfinv, sinint)
_, ẏ = frule((NoTangent(), 1f0), f, 1f0)
@test ẏ isa Float32
_, back = rrule(f, 1f0)
_, x̄ = back(1f0)
@test x̄ isa Float32
end

# two arguments
_, ẏ = frule((NoTangent(), 1f0, 1f0), erf, 1f0, 1f0)
@test ẏ isa Float32
_, back = rrule(erf, 1f0, 1f0)
_, x̄ = back(1f0)
@test x̄ isa Float32
end
end

0 comments on commit 74fd16f

Please sign in to comment.