Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the Type consistency of Special Functions (Fixes #17474) #18584

Closed
wants to merge 15 commits into from
182 changes: 112 additions & 70 deletions base/special/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license
typealias ComplexOrReal{T} Union{T,Complex{T}}

gamma(x::Float64) = nan_dom_err(ccall((:tgamma,libm), Float64, (Float64,), x), x)
gamma(x::Float32) = nan_dom_err(ccall((:tgammaf,libm), Float32, (Float32,), x), x)
Expand Down Expand Up @@ -72,7 +73,8 @@ function lgamma(z::Complex{Float64})
else
return Complex(NaN, NaN)
end
elseif x > 7 || yabs > 7 # use the Stirling asymptotic series for sufficiently large x or |y|
elseif x > 7 || yabs > 7
# use the Stirling asymptotic series for sufficiently large x or |y|
return lgamma_asymptotic(z)
elseif x < 0.1 # use reflection formula to transform to x > 0
if x == 0 && y == 0 # return Inf with the correct imaginary part for z == 0
Expand Down Expand Up @@ -106,7 +108,8 @@ function lgamma(z::Complex{Float64})
-2.2315475845357937976132853e-04,9.9457512781808533714662972e-05,
-4.4926236738133141700224489e-05,2.0507212775670691553131246e-05)
end
# use recurrence relation lgamma(z) = lgamma(z+1) - log(z) to shift to x > 7 for asymptotic series
# use recurrence relation lgamma(z) = lgamma(z+1) - log(z)
# to shift to x > 7 for asymptotic series
shiftprod = Complex(x,yabs)
x += 1
sb = false # == signbit(imag(shiftprod)) == signbit(yabs)
Expand Down Expand Up @@ -147,7 +150,7 @@ gamma(z::Complex) = exp(lgamma(z))

Compute the digamma function of `x` (the logarithmic derivative of `gamma(x)`)
"""
function digamma(z::Union{Float64,Complex{Float64}})
function digamma(z::ComplexOrReal{Float64})
# Based on eq. (12), without looking at the accompanying source
# code, of: K. S. Kölbig, "Programs for computing the logarithm of
# the gamma function, and the digamma function, for complex
Expand Down Expand Up @@ -181,7 +184,7 @@ end

Compute the trigamma function of `x` (the logarithmic second derivative of `gamma(x)`).
"""
function trigamma(z::Union{Float64,Complex{Float64}})
function trigamma(z::ComplexOrReal{Float64})
# via the derivative of the Kölbig digamma formulation
x = real(z)
if x <= 0 # reflection formula
Expand Down Expand Up @@ -341,10 +344,7 @@ this definition is equivalent to the Hurwitz zeta function
``\\sum_{k=0}^\\infty (k+z)^{-s}``. For ``z=1``, it yields
the Riemann zeta function ``\\zeta(s)``.
"""
zeta(s,z)

function zeta(s::Union{Int,Float64,Complex{Float64}},
z::Union{Float64,Complex{Float64}})
function zeta(s::ComplexOrReal{Float64}, z::ComplexOrReal{Float64})
ζ = zero(promote_type(typeof(s), typeof(z)))

(z == 1 || z == 0) && return oftype(ζ, zeta(s))
Expand Down Expand Up @@ -393,7 +393,8 @@ function zeta(s::Union{Int,Float64,Complex{Float64}},
minus_z = -z
ζ += pow_oftype(ζ, minus_z, minus_s) # ν = 0 term
if xf != z
ζ += pow_oftype(ζ, z - nx, minus_s) # real(z - nx) > 0, so use correct branch cut
ζ += pow_oftype(ζ, z - nx, minus_s)
# real(z - nx) > 0, so use correct branch cut
# otherwise, if xf==z, then the definition skips this term
end
# do loop in different order, depending on the sign of s,
Expand Down Expand Up @@ -446,10 +447,10 @@ end
"""
polygamma(m, x)

Compute the polygamma function of order `m` of argument `x` (the `(m+1)th` derivative of the
logarithm of `gamma(x)`)
Compute the polygamma function of order `m` of argument `x`
(the `(m+1)th` derivative of the logarithm of `gamma(x)`)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"th" shouldn't be code highlighted

"""
function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
function polygamma(m::Integer, z::ComplexOrReal{Float64})
m == 0 && return digamma(z)
m == 1 && return trigamma(z)

Expand All @@ -467,48 +468,26 @@ function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
# constants. We throw a DomainError() since the definition is unclear.
real(m) < 0 && throw(DomainError())

s = m+1
s = Float64(m+1)
# It is safe to convert any integer (including `BigInt`) to a float here
# as underflow occurs before precision issues.
if real(z) <= 0 # reflection formula
(zeta(s, 1-z) + signflip(m, cotderiv(m,z))) * (-gamma(s))
else
signflip(m, zeta(s,z) * (-gamma(s)))
end
end

# TODO: better way to do this
f64(x::Real) = Float64(x)
f64(z::Complex) = Complex128(z)
f32(x::Real) = Float32(x)
f32(z::Complex) = Complex64(z)
f16(x::Real) = Float16(x)
f16(z::Complex) = Complex32(z)

# If we really cared about single precision, we could make a faster
# Float32 version by truncating the Stirling series at a smaller cutoff.
for (f,T) in ((:f32,Float32),(:f16,Float16))
@eval begin
zeta(s::Integer, z::Union{$T,Complex{$T}}) = $f(zeta(Int(s), f64(z)))
zeta(s::Union{Float64,Complex128}, z::Union{$T,Complex{$T}}) = zeta(s, f64(z))
zeta(s::Number, z::Union{$T,Complex{$T}}) = $f(zeta(f64(s), f64(z)))
polygamma(m::Integer, z::Union{$T,Complex{$T}}) = $f(polygamma(Int(m), f64(z)))
digamma(z::Union{$T,Complex{$T}}) = $f(digamma(f64(z)))
trigamma(z::Union{$T,Complex{$T}}) = $f(trigamma(f64(z)))
end
end

zeta(s::Integer, z::Number) = zeta(Int(s), f64(z))
zeta(s::Number, z::Number) = zeta(f64(s), f64(z))
for f in (:digamma, :trigamma)
@eval begin
$f(z::Number) = $f(f64(z))
end
end
polygamma(m::Integer, z::Number) = polygamma(m, f64(z))
"""
invdigamma(x)

# Inverse digamma function:
# Implementation of fixed point algorithm described in
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000
Compute the inverse [`digamma`](:func:`digamma`) function of `x`.
"""
function invdigamma(y::Float64)
# Implementation of fixed point algorithm described in
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000

# Closed form initial estimates
if y >= -2.22
x_old = exp(y) + 0.5
Expand All @@ -530,19 +509,12 @@ function invdigamma(y::Float64)

return x_new
end
invdigamma(x::Float32) = Float32(invdigamma(Float64(x)))

"""
invdigamma(x)

Compute the inverse [`digamma`](:func:`digamma`) function of `x`.
"""
invdigamma(x::Real) = invdigamma(Float64(x))

"""
beta(x, y)

Euler integral of the first kind ``\\operatorname{B}(x,y) = \\Gamma(x)\\Gamma(y)/\\Gamma(x+y)``.
Euler integral of the first kind
``\\operatorname{B}(x,y) = \\Gamma(x)\\Gamma(y)/\\Gamma(x+y)``.
"""
function beta(x::Number, w::Number)
yx, sx = lgamma_r(x)
Expand All @@ -559,9 +531,16 @@ function ``\\log(|\\operatorname{B}(x,y)|)``.
"""
lbeta(x::Number, w::Number) = lgamma(x)+lgamma(w)-lgamma(x+w)

# Riemann zeta function; algorithm is based on specializing the Hurwitz
# zeta function above for z==1.
function zeta(s::Union{Float64,Complex{Float64}})

"""
zeta(s)

Riemann zeta function ``\\zeta(s)``.
"""
function zeta(s::ComplexOrReal{Float64})
# Riemann zeta function; algorithm is based on specializing the Hurwitz
# zeta function above for z==1.

# blows up to ±Inf, but get correct sign of imaginary zero
s == 1 && return NaN + zero(s) * imag(s)

Expand Down Expand Up @@ -606,17 +585,14 @@ function zeta(s::Union{Float64,Complex{Float64}})
return ζ
end

zeta(x::Integer) = zeta(Float64(x))
zeta(x::Real) = oftype(float(x),zeta(Float64(x)))

"""
zeta(s)

Riemann zeta function ``\\zeta(s)``.
"""
zeta(z::Complex) = oftype(float(z),zeta(Complex128(z)))
eta(x)

function eta(z::Union{Float64,Complex{Float64}})
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
"""
function eta(z::ComplexOrReal{Float64})
δz = 1 - z
if abs(real(δz)) + abs(imag(δz)) < 7e-3 # Taylor expand around z==1
return 0.6931471805599453094172321214581765 *
Expand All @@ -630,12 +606,78 @@ function eta(z::Union{Float64,Complex{Float64}})
return -zeta(z) * expm1(0.6931471805599453094172321214581765*δz)
end
end
eta(x::Integer) = eta(Float64(x))
eta(x::Real) = oftype(float(x),eta(Float64(x)))

"""
eta(x)

Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
"""
eta(z::Complex) = oftype(float(z),eta(Complex128(z)))
# Converting types that we can convert, and not ones we can not
# Float16, and Float32 and their Complex equivalents can be converted to Float64
# and results converted back.
# Otherwise, we need to make things use their own `float` converting methods
# and in those cases, we do not convert back either as we assume
# they also implement their own versions of the functions, with the correct return types.
# This is the case for BitIntegers (which become `Float64` when `float`ed).
# Otherwise, if they do not implement their version of the functions we
# manually throw a `MethodError`.
# This case occurs, when calling `float` on a type does not change its type,
# as it is already a `float`, and would have hit own method, if one had existed.


# If we really cared about single precision, we could make a faster
# Float32 version by truncating the Stirling series at a smaller cutoff.
# and if we really cared about half precision, we could make a faster
# Float16 version, by using a precomputed table look-up.


for T in (Float16, Float32, Float64)
@eval f64(x::Complex{$T}) = Complex128(x)
@eval f64(x::$T) = Float64(x)
end


for f in (:digamma, :trigamma, :zeta, :eta, :invdigamma)
@eval begin
function $f(z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}})
oftype(z, $f(f64(z)))
end

function $f(z::Number)
x = float(z)
typeof(x) === typeof(z) && throw(MethodError($f, (z,)))
# There is nothing to fallback to, as this didn't change the argument types
$f(x)
end
end
end


for T1 in (Float16, Float32, Float64), T2 in (Float16, Float32, Float64)
(T1 == T2 == Float64) && continue # Avoid redefining base definition

@eval function zeta(s::ComplexOrReal{$T1}, z::ComplexOrReal{$T2})
ζ = zeta(f64(s), f64(z))
convert(promote_type(typeof(s), typeof(z)), ζ)
end
end


function zeta(s::Number, z::Number)
t = float(s)
x = float(z)
if typeof(t) === typeof(s) && typeof(x) === typeof(z)
# There is nothing to fallback to, since this didn't work
throw(MethodError(zeta,(s,z)))
end
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should have a second fallback to zeta(::Number, Number) maybe.
Or possibly to prevent looping just try and convert s to a float and see how that goes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my other comment about the zeta.

zeta(t, x)
end


function polygamma(m::Integer, z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}})
oftype(z, polygamma(m, f64(z)))
end


function polygamma(m::Integer, z::Number)
x = float(z)
typeof(x) === typeof(z) && throw(MethodError(polygamma, (m,z)))
# There is nothing to fallback to, since this didn't work
polygamma(m, x)
end
33 changes: 31 additions & 2 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -936,8 +936,37 @@ end
end
end

@test Base.Math.f32(complex(1.0,1.0)) == complex(Float32(1.),Float32(1.))
@test Base.Math.f16(complex(1.0,1.0)) == complex(Float16(1.),Float16(1.))

@test Base.Math.f64(complex(1f0,1f0)) == complex(1.0, 1.0)
@test Base.Math.f64(1f0) == 1.0

# no domain error is thrown for negative values
@test invoke(cbrt, Tuple{AbstractFloat}, -1.0) == -1.0

# issue #17474
@test typeof(eta(big"2")) == BigFloat
@test typeof(zeta(big"2")) == BigFloat
@test typeof(digamma(big"2")) == BigFloat

@test_throws MethodError trigamma(big"2")
@test_throws MethodError trigamma(big"2.0")
@test_throws MethodError invdigamma(big"2")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these not available in mpfr/gmp?

Copy link
Contributor Author

@oxinabox oxinabox Sep 20, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. They are not available, or if they are, then we are not importing them.
Let me check mpfr docs

Edit:
They are not available in MPFR:
http://www.mpfr.org/mpfr-current/mpfr.html#index-Special-functions
GMP doesn't have it since (as I understand it) these functions almost never return integers.

@test_throws MethodError invdigamma(big"2.0")

@test_throws MethodError eta(Complex(big"2"))
@test_throws MethodError eta(Complex(big"2.0"))
@test_throws MethodError zeta(Complex(big"2"))
@test_throws MethodError zeta(Complex(big"2.0"))
@test_throws MethodError zeta(1.0,big"2")
@test_throws MethodError zeta(1.0,big"2.0")
@test_throws MethodError zeta(big"1.0",2.0)
@test_throws MethodError zeta(big"1",2.0)


@test typeof(polygamma(3, 0x2)) == Float64
@test typeof(polygamma(big"3", 2f0)) == Float32
@test typeof(zeta(1, 2.0)) == Float64
@test typeof(zeta(1, 2f0)) == Float64 # BitIntegers result in Float64 returns
@test typeof(zeta(2f0, complex(2f0,0f0))) == Complex{Float32}
@test typeof(zeta(complex(1,1), 2f0)) == Complex{Float64}
@test typeof(zeta(complex(1), 2.0)) == Complex{Float64}