-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve the Type consistency of Special Functions (Fixes #17474) #18584
Changes from 12 commits
b0b58a8
ba62160
bf783c9
63fc253
f62b049
dca8256
c96f578
706773b
5a11d1d
f057335
6334dda
e81314f
0f62f12
9fd1a32
734ca54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# This file is a part of Julia. License is MIT: http://julialang.org/license | ||
typealias ComplexOrReal{T} Union{T,Complex{T}} | ||
|
||
gamma(x::Float64) = nan_dom_err(ccall((:tgamma,libm), Float64, (Float64,), x), x) | ||
gamma(x::Float32) = nan_dom_err(ccall((:tgammaf,libm), Float32, (Float32,), x), x) | ||
|
@@ -72,7 +73,8 @@ function lgamma(z::Complex{Float64}) | |
else | ||
return Complex(NaN, NaN) | ||
end | ||
elseif x > 7 || yabs > 7 # use the Stirling asymptotic series for sufficiently large x or |y| | ||
elseif x > 7 || yabs > 7 | ||
# use the Stirling asymptotic series for sufficiently large x or |y| | ||
return lgamma_asymptotic(z) | ||
elseif x < 0.1 # use reflection formula to transform to x > 0 | ||
if x == 0 && y == 0 # return Inf with the correct imaginary part for z == 0 | ||
|
@@ -106,7 +108,8 @@ function lgamma(z::Complex{Float64}) | |
-2.2315475845357937976132853e-04,9.9457512781808533714662972e-05, | ||
-4.4926236738133141700224489e-05,2.0507212775670691553131246e-05) | ||
end | ||
# use recurrence relation lgamma(z) = lgamma(z+1) - log(z) to shift to x > 7 for asymptotic series | ||
# use recurrence relation lgamma(z) = lgamma(z+1) - log(z) | ||
# to shift to x > 7 for asymptotic series | ||
shiftprod = Complex(x,yabs) | ||
x += 1 | ||
sb = false # == signbit(imag(shiftprod)) == signbit(yabs) | ||
|
@@ -147,7 +150,7 @@ gamma(z::Complex) = exp(lgamma(z)) | |
|
||
Compute the digamma function of `x` (the logarithmic derivative of `gamma(x)`) | ||
""" | ||
function digamma(z::Union{Float64,Complex{Float64}}) | ||
function digamma(z::ComplexOrReal{Float64}) | ||
# Based on eq. (12), without looking at the accompanying source | ||
# code, of: K. S. Kölbig, "Programs for computing the logarithm of | ||
# the gamma function, and the digamma function, for complex | ||
|
@@ -181,7 +184,7 @@ end | |
|
||
Compute the trigamma function of `x` (the logarithmic second derivative of `gamma(x)`). | ||
""" | ||
function trigamma(z::Union{Float64,Complex{Float64}}) | ||
function trigamma(z::ComplexOrReal{Float64}) | ||
# via the derivative of the Kölbig digamma formulation | ||
x = real(z) | ||
if x <= 0 # reflection formula | ||
|
@@ -341,10 +344,7 @@ this definition is equivalent to the Hurwitz zeta function | |
``\\sum_{k=0}^\\infty (k+z)^{-s}``. For ``z=1``, it yields | ||
the Riemann zeta function ``\\zeta(s)``. | ||
""" | ||
zeta(s,z) | ||
|
||
function zeta(s::Union{Int,Float64,Complex{Float64}}, | ||
z::Union{Float64,Complex{Float64}}) | ||
function zeta(s::ComplexOrReal{Float64}, z::ComplexOrReal{Float64}) | ||
ζ = zero(promote_type(typeof(s), typeof(z))) | ||
|
||
(z == 1 || z == 0) && return oftype(ζ, zeta(s)) | ||
|
@@ -393,7 +393,8 @@ function zeta(s::Union{Int,Float64,Complex{Float64}}, | |
minus_z = -z | ||
ζ += pow_oftype(ζ, minus_z, minus_s) # ν = 0 term | ||
if xf != z | ||
ζ += pow_oftype(ζ, z - nx, minus_s) # real(z - nx) > 0, so use correct branch cut | ||
ζ += pow_oftype(ζ, z - nx, minus_s) | ||
# real(z - nx) > 0, so use correct branch cut | ||
# otherwise, if xf==z, then the definition skips this term | ||
end | ||
# do loop in different order, depending on the sign of s, | ||
|
@@ -446,10 +447,10 @@ end | |
""" | ||
polygamma(m, x) | ||
|
||
Compute the polygamma function of order `m` of argument `x` (the `(m+1)th` derivative of the | ||
logarithm of `gamma(x)`) | ||
Compute the polygamma function of order `m` of argument `x` | ||
(the `(m+1)th` derivative of the logarithm of `gamma(x)`) | ||
""" | ||
function polygamma(m::Integer, z::Union{Float64,Complex{Float64}}) | ||
function polygamma(m::Integer, z::ComplexOrReal{Float64}) | ||
m == 0 && return digamma(z) | ||
m == 1 && return trigamma(z) | ||
|
||
|
@@ -467,48 +468,26 @@ function polygamma(m::Integer, z::Union{Float64,Complex{Float64}}) | |
# constants. We throw a DomainError() since the definition is unclear. | ||
real(m) < 0 && throw(DomainError()) | ||
|
||
s = m+1 | ||
s = Float64(m+1) | ||
# It is safe to convert any integer (including `BigInt`) to a float here | ||
# as underflow occurs before precision issues. | ||
if real(z) <= 0 # reflection formula | ||
(zeta(s, 1-z) + signflip(m, cotderiv(m,z))) * (-gamma(s)) | ||
else | ||
signflip(m, zeta(s,z) * (-gamma(s))) | ||
end | ||
end | ||
|
||
# TODO: better way to do this | ||
f64(x::Real) = Float64(x) | ||
f64(z::Complex) = Complex128(z) | ||
f32(x::Real) = Float32(x) | ||
f32(z::Complex) = Complex64(z) | ||
f16(x::Real) = Float16(x) | ||
f16(z::Complex) = Complex32(z) | ||
|
||
# If we really cared about single precision, we could make a faster | ||
# Float32 version by truncating the Stirling series at a smaller cutoff. | ||
for (f,T) in ((:f32,Float32),(:f16,Float16)) | ||
@eval begin | ||
zeta(s::Integer, z::Union{$T,Complex{$T}}) = $f(zeta(Int(s), f64(z))) | ||
zeta(s::Union{Float64,Complex128}, z::Union{$T,Complex{$T}}) = zeta(s, f64(z)) | ||
zeta(s::Number, z::Union{$T,Complex{$T}}) = $f(zeta(f64(s), f64(z))) | ||
polygamma(m::Integer, z::Union{$T,Complex{$T}}) = $f(polygamma(Int(m), f64(z))) | ||
digamma(z::Union{$T,Complex{$T}}) = $f(digamma(f64(z))) | ||
trigamma(z::Union{$T,Complex{$T}}) = $f(trigamma(f64(z))) | ||
end | ||
end | ||
|
||
zeta(s::Integer, z::Number) = zeta(Int(s), f64(z)) | ||
zeta(s::Number, z::Number) = zeta(f64(s), f64(z)) | ||
for f in (:digamma, :trigamma) | ||
@eval begin | ||
$f(z::Number) = $f(f64(z)) | ||
end | ||
end | ||
polygamma(m::Integer, z::Number) = polygamma(m, f64(z)) | ||
""" | ||
invdigamma(x) | ||
|
||
# Inverse digamma function: | ||
# Implementation of fixed point algorithm described in | ||
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000 | ||
Compute the inverse [`digamma`](:func:`digamma`) function of `x`. | ||
""" | ||
function invdigamma(y::Float64) | ||
# Implementation of fixed point algorithm described in | ||
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000 | ||
|
||
# Closed form initial estimates | ||
if y >= -2.22 | ||
x_old = exp(y) + 0.5 | ||
|
@@ -530,19 +509,12 @@ function invdigamma(y::Float64) | |
|
||
return x_new | ||
end | ||
invdigamma(x::Float32) = Float32(invdigamma(Float64(x))) | ||
|
||
""" | ||
invdigamma(x) | ||
|
||
Compute the inverse [`digamma`](:func:`digamma`) function of `x`. | ||
""" | ||
invdigamma(x::Real) = invdigamma(Float64(x)) | ||
|
||
""" | ||
beta(x, y) | ||
|
||
Euler integral of the first kind ``\\operatorname{B}(x,y) = \\Gamma(x)\\Gamma(y)/\\Gamma(x+y)``. | ||
Euler integral of the first kind | ||
``\\operatorname{B}(x,y) = \\Gamma(x)\\Gamma(y)/\\Gamma(x+y)``. | ||
""" | ||
function beta(x::Number, w::Number) | ||
yx, sx = lgamma_r(x) | ||
|
@@ -559,9 +531,16 @@ function ``\\log(|\\operatorname{B}(x,y)|)``. | |
""" | ||
lbeta(x::Number, w::Number) = lgamma(x)+lgamma(w)-lgamma(x+w) | ||
|
||
# Riemann zeta function; algorithm is based on specializing the Hurwitz | ||
# zeta function above for z==1. | ||
function zeta(s::Union{Float64,Complex{Float64}}) | ||
|
||
""" | ||
zeta(s) | ||
|
||
Riemann zeta function ``\\zeta(s)``. | ||
""" | ||
function zeta(s::ComplexOrReal{Float64}) | ||
# Riemann zeta function; algorithm is based on specializing the Hurwitz | ||
# zeta function above for z==1. | ||
|
||
# blows up to ±Inf, but get correct sign of imaginary zero | ||
s == 1 && return NaN + zero(s) * imag(s) | ||
|
||
|
@@ -606,17 +585,14 @@ function zeta(s::Union{Float64,Complex{Float64}}) | |
return ζ | ||
end | ||
|
||
zeta(x::Integer) = zeta(Float64(x)) | ||
zeta(x::Real) = oftype(float(x),zeta(Float64(x))) | ||
|
||
""" | ||
zeta(s) | ||
|
||
Riemann zeta function ``\\zeta(s)``. | ||
""" | ||
zeta(z::Complex) = oftype(float(z),zeta(Complex128(z))) | ||
eta(x) | ||
|
||
function eta(z::Union{Float64,Complex{Float64}}) | ||
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``. | ||
""" | ||
function eta(z::ComplexOrReal{Float64}) | ||
δz = 1 - z | ||
if abs(real(δz)) + abs(imag(δz)) < 7e-3 # Taylor expand around z==1 | ||
return 0.6931471805599453094172321214581765 * | ||
|
@@ -630,12 +606,78 @@ function eta(z::Union{Float64,Complex{Float64}}) | |
return -zeta(z) * expm1(0.6931471805599453094172321214581765*δz) | ||
end | ||
end | ||
eta(x::Integer) = eta(Float64(x)) | ||
eta(x::Real) = oftype(float(x),eta(Float64(x))) | ||
|
||
""" | ||
eta(x) | ||
|
||
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``. | ||
""" | ||
eta(z::Complex) = oftype(float(z),eta(Complex128(z))) | ||
# Converting types that we can convert, and not ones we can not | ||
# Float16, and Float32 and their Complex equivalents can be converted to Float64 | ||
# and results converted back. | ||
# Otherwise, we need to make things use their own `float` converting methods | ||
# and in those cases, we do not convert back either as we assume | ||
# they also implement their own versions of the functions, with the correct return types. | ||
# This is the case for BitIntegers (which become `Float64` when `float`ed). | ||
# Otherwise, if they do not implement their version of the functions we | ||
# manually throw a `MethodError`. | ||
# This case occurs, when calling `float` on a type does not change its type, | ||
# as it is already a `float`, and would have hit own method, if one had existed. | ||
|
||
|
||
# If we really cared about single precision, we could make a faster | ||
# Float32 version by truncating the Stirling series at a smaller cutoff. | ||
# and if we really cared about half precision, we could make a faster | ||
# Float16 version, by using a precomputed table look-up. | ||
|
||
|
||
for T in (Float16, Float32, Float64) | ||
@eval f64(x::Complex{$T}) = Complex128(x) | ||
@eval f64(x::$T) = Float64(x) | ||
end | ||
|
||
|
||
for f in (:digamma, :trigamma, :zeta, :eta, :invdigamma) | ||
@eval begin | ||
function $f(z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}}) | ||
oftype(z, $f(f64(z))) | ||
end | ||
|
||
function $f(z::Number) | ||
x = float(z) | ||
typeof(x) === typeof(z) && throw(MethodError($f, (z,))) | ||
# There is nothing to fallback to, as this didn't change the argument types | ||
$f(x) | ||
end | ||
end | ||
end | ||
|
||
|
||
for T1 in (Float16, Float32, Float64), T2 in (Float16, Float32, Float64) | ||
(T1 == T2 == Float64) && continue # Avoid redefining base definition | ||
|
||
@eval function zeta(s::ComplexOrReal{$T1}, z::ComplexOrReal{$T2}) | ||
ζ = zeta(f64(s), f64(z)) | ||
convert(promote_type(typeof(s), typeof(z)), ζ) | ||
end | ||
end | ||
|
||
|
||
function zeta(s::Number, z::Number) | ||
t = float(s) | ||
x = float(z) | ||
if typeof(t) === typeof(s) && typeof(x) === typeof(z) | ||
# There is nothing to fallback to, since this didn't work | ||
throw(MethodError(zeta,(s,z))) | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should have a second fallback to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my other comment about the |
||
zeta(t, x) | ||
end | ||
|
||
|
||
function polygamma(m::Integer, z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}}) | ||
oftype(z, polygamma(m, f64(z))) | ||
end | ||
|
||
|
||
function polygamma(m::Integer, z::Number) | ||
x = float(z) | ||
typeof(x) === typeof(z) && throw(MethodError(polygamma, (m,z))) | ||
# There is nothing to fallback to, since this didn't work | ||
polygamma(m, x) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -936,8 +936,37 @@ end | |
end | ||
end | ||
|
||
@test Base.Math.f32(complex(1.0,1.0)) == complex(Float32(1.),Float32(1.)) | ||
@test Base.Math.f16(complex(1.0,1.0)) == complex(Float16(1.),Float16(1.)) | ||
|
||
@test Base.Math.f64(complex(1f0,1f0)) == complex(1.0, 1.0) | ||
@test Base.Math.f64(1f0) == 1.0 | ||
|
||
# no domain error is thrown for negative values | ||
@test invoke(cbrt, Tuple{AbstractFloat}, -1.0) == -1.0 | ||
|
||
# issue #17474 | ||
@test typeof(eta(big"2")) == BigFloat | ||
@test typeof(zeta(big"2")) == BigFloat | ||
@test typeof(digamma(big"2")) == BigFloat | ||
|
||
@test_throws MethodError trigamma(big"2") | ||
@test_throws MethodError trigamma(big"2.0") | ||
@test_throws MethodError invdigamma(big"2") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are these not available in mpfr/gmp? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct. They are not available, or if they are, then we are not importing them. Edit: |
||
@test_throws MethodError invdigamma(big"2.0") | ||
|
||
@test_throws MethodError eta(Complex(big"2")) | ||
@test_throws MethodError eta(Complex(big"2.0")) | ||
@test_throws MethodError zeta(Complex(big"2")) | ||
@test_throws MethodError zeta(Complex(big"2.0")) | ||
@test_throws MethodError zeta(1.0,big"2") | ||
@test_throws MethodError zeta(1.0,big"2.0") | ||
@test_throws MethodError zeta(big"1.0",2.0) | ||
@test_throws MethodError zeta(big"1",2.0) | ||
|
||
|
||
@test typeof(polygamma(3, 0x2)) == Float64 | ||
@test typeof(polygamma(big"3", 2f0)) == Float32 | ||
@test typeof(zeta(1, 2.0)) == Float64 | ||
@test typeof(zeta(1, 2f0)) == Float64 # BitIntegers result in Float64 returns | ||
@test typeof(zeta(2f0, complex(2f0,0f0))) == Complex{Float32} | ||
@test typeof(zeta(complex(1,1), 2f0)) == Complex{Float64} | ||
@test typeof(zeta(complex(1), 2.0)) == Complex{Float64} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"th" shouldn't be code highlighted