Skip to content

Commit

Permalink
=Work on JuliaLang#17474 squashme
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 4, 2016
1 parent 0254e89 commit 4b188ae
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 90 deletions.
181 changes: 94 additions & 87 deletions base/special/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# This file is a part of Julia. License is MIT: http://julialang.org/license
typealias ComplexOrReal{T} Union{T,Complex{T}}

gamma(x::Float64) = nan_dom_err(ccall((:tgamma,libm), Float64, (Float64,), x), x)
gamma(x::Float32) = nan_dom_err(ccall((:tgammaf,libm), Float32, (Float32,), x), x)
Expand Down Expand Up @@ -147,7 +148,7 @@ gamma(z::Complex) = exp(lgamma(z))
Compute the digamma function of `x` (the logarithmic derivative of `gamma(x)`)
"""
function digamma(z::Union{Float64,Complex{Float64}})
function digamma(z::ComplexOrReal{Float64})
# Based on eq. (12), without looking at the accompanying source
# code, of: K. S. Kölbig, "Programs for computing the logarithm of
# the gamma function, and the digamma function, for complex
Expand Down Expand Up @@ -181,7 +182,7 @@ end
Compute the trigamma function of `x` (the logarithmic second derivative of `gamma(x)`).
"""
function trigamma(z::Union{Float64,Complex{Float64}})
function trigamma(z::ComplexOrReal{Float64})
# via the derivative of the Kölbig digamma formulation
x = real(z)
if x <= 0 # reflection formula
Expand Down Expand Up @@ -343,8 +344,8 @@ the Riemann zeta function ``\\zeta(s)``.
"""
zeta(s,z)

function zeta(s::Union{Int,Float64,Complex{Float64}},
z::Union{Float64,Complex{Float64}})
function zeta(s::Union{Int, ComplexOrReal{Float64}},
z::ComplexOrReal{Float64})
ζ = zero(promote_type(typeof(s), typeof(z)))

(z == 1 || z == 0) && return oftype(ζ, zeta(s))
Expand Down Expand Up @@ -449,7 +450,7 @@ end
Compute the polygamma function of order `m` of argument `x` (the `(m+1)th` derivative of the
logarithm of `gamma(x)`)
"""
function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
function polygamma(m::Integer, z::ComplexOrReal{Float64})
m == 0 && return digamma(z)
m == 1 && return trigamma(z)

Expand All @@ -475,94 +476,17 @@ function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
end
end

# TODO: better way to do this
f64(x::Real) = Float64(x)
f64(z::Complex) = Complex128(z)
f32(x::Real) = Float32(x)
f32(z::Complex) = Complex64(z)
f16(x::Real) = Float16(x)
f16(z::Complex) = Complex32(z)

# If we really cared about single precision, we could make a faster
# Float32 version by truncating the Stirling series at a smaller cutoff.
# and if we really cared about half precision, we could make a faster
# Float16 version, by using a precomputed table look-up.
for (f,T) in ((:f32,Float32),(:f16,Float16))
@eval begin
zeta(s::Integer, z::Union{$T,Complex{$T}}) = $f(zeta(Int(s), f64(z)))
zeta(s::Union{Float64,Complex128}, z::Union{$T,Complex{$T}}) = zeta(s, f64(z))
zeta(s::Number, z::Union{$T,Complex{$T}}) = $f(zeta(f64(s), f64(z)))
polygamma(m::Integer, z::Union{$T,Complex{$T}}) = $f(polygamma(Int(m), f64(z)))
digamma(z::Union{$T,Complex{$T}}) = $f(digamma(f64(z)))
trigamma(z::Union{$T,Complex{$T}}) = $f(trigamma(f64(z)))
end
end


function zeta(s::Integer, z::Number)
x = float(z)
t = Int(s)
if typeof(x) === typeof(z) && typeof(t) === typeof(s)
throw MethodError(zeta,(s,t))
end
oftype(x, zeta(t, x))
end

function zeta(s::Number, z::Number)
x = float(z)
t = float(s)
if typeof(x) === typeof(z) && typeof(t) === typeof(s)
throw MethodError(zeta,(s,t))
end
oftype(x, zeta(t, x))
end


function polygamma(m::Integer, z::Number)
x = float(z)
typeof(x) == typeof(z) && throw(MethodError(polygamma, (m,z))
oftype(x,polygamma(m, x))
end


for f in (:digamma, :trigamma, :zeta, :eta, invdigamma)
@eval begin
$f(z::Base.BitInteger) = $f(Float64(z))
$f(z::Float32) = Float32($f(Float64(z)))
$f(z::Float16) = Float16($f(Float64(z)))

function $f(z::Number)
x = float(z)
typeof(x) == typeof(z) && throw(MethodError($f, (z,)))
oftype(x, $f(x))
end
end
end

for f in (:zeta, :eta)
@eval begin
$f{T<:Union{Base.BitInteger,Float32,Float16}}(x::t) = oftype(float(x), $f(Complex128(z)))

function $f(z::Complex)
x = float(z)
typeof(x) == typeof(z) && throw(MethodError($f, (z,)))
oftype(x, $f(Complex(x))
end
end
end



"""
invdigamma(x)
Compute the inverse [`digamma`](:func:`digamma`) function of `x`.
"""
function invdigamma(y::Float64)
# Implementation of fixed point algorithm described in
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000
# Implementation of fixed point algorithm described in
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000

# Closed form initial estimates
# Closed form initial estimates
if y >= -2.22
x_old = exp(y) + 0.5
x_new = x_old
Expand Down Expand Up @@ -610,7 +534,7 @@ lbeta(x::Number, w::Number) = lgamma(x)+lgamma(w)-lgamma(x+w)
Riemann zeta function ``\\zeta(s)``.
"""
function zeta(s::Union{Float64,Complex{Float64}})
function zeta(s::ComplexOrReal{Float64})
# Riemann zeta function; algorithm is based on specializing the Hurwitz
# zeta function above for z==1.

Expand Down Expand Up @@ -665,7 +589,7 @@ end
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
"""
function eta(z::Union{Float64,Complex{Float64}})
function eta(z::ComplexOrReal{Float64})
δz = 1 - z
if abs(real(δz)) + abs(imag(δz)) < 7e-3 # Taylor expand around z==1
return 0.6931471805599453094172321214581765 *
Expand All @@ -680,3 +604,86 @@ function eta(z::Union{Float64,Complex{Float64}})
end
end


# Converting types that we can convert, and not ones we can not
# If we really cared about single precision, we could make a faster
# Float32 version by truncating the Stirling series at a smaller cutoff.
# and if we really cared about half precision, we could make a faster
# Float16 version, by using a precomputed table look-up.

f64(x::Union{Base.BitInteger, Float16, Float32, Float64}) = Float64(x)
f64(x::Complex{Union{Base.BitInteger, Float16, Float32, Float64}}) = Complex128(x)
ofpromotedtype(as, c) = oftype(promote(as...), c)

for T in (Float32, Float16)
@eval begin
polygamma(m::Integer, z::ComplexOrReal{T}) = ofpromotedtype((m,z), polygamma(Int(m), f64(z)))
digamma(z::ComplexOrReal{$T}) = oftype(z, digamma(f64(z)))
trigamma(z::ComplexOrReal{$T}) = oftype(z, trigamma(f64(z)))
zeta(s::Integer, z::ComplexOrReal{$T}}) = ofpromotedtype((s,z), zeta(Int(s), f64(z)))
end
end

function zeta(s::ComplexOrReal{Union{Float16, Float32}},
z::ComplexOrReal{Union{Float16, Float32, Float64, Base.BitInteger})
ofpromotedtype((s, z), zeta(f64(s), f64(z)))
end


function zeta(s::Integer, z::Number)
x = float(z)
t = Int(s)
if typeof(x) === typeof(z) && typeof(t) === typeof(s)
# There is nothing to fallback to, since this didn't work
throw(MethodError(zeta,(s,t)))
end
ofpromotedtype((x,y), zeta(t, x))
end

function zeta(s::Number, z::Number)
x = float(z)
t = float(s)
if typeof(x) === typeof(z) && typeof(t) === typeof(s)
# There is nothing to fallback to, since this didn't work
throw(MethodError(zeta,(s,t)))
end
ofpromotedtype((x,t), zeta(t, x))
end


function polygamma(m::Integer, z::Number)
x = float(z)
typeof(x) == typeof(z) && throw(MethodError(polygamma, (m,z)))
# There is nothing to fallback to, since this didn't work
oftype(x, polygamma(m, x))
end


for f in (:digamma, :trigamma, :zeta, :eta, :invdigamma)
@eval begin
$f(z::Base.BitInteger) = $f(Float64(z))
$f(z::Float32) = Float32($f(Float64(z)))
$f(z::Float16) = Float16($f(Float64(z)))

function $f(z::Number)
x = float(z)
typeof(x) == typeof(z) && throw(MethodError($f, (z,)))
# There is nothing to fallback to, since this didn't work
oftype(x, $f(x))
end
end
end

for f in (:zeta, :eta)
@eval begin
$f{T<:Union{Base.BitInteger,Float32,Float16}}(z::Complex{T}) = oftype(float(z), $f(Complex128(z)))

function $f(z::Complex)
x = float(z)
typeof(x) == typeof(z) && throw(MethodError($f, (z,)))
# There is nothing to fallback to, since this didn't work
oftype(x, $f(x))
end
end
end

2 changes: 1 addition & 1 deletion base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ function __init__()
end

include = include_from_node1
include("precompile.jl")
#include("precompile.jl") #Don't commit me. Speed up testing l

end # baremodule Base

Expand Down
7 changes: 5 additions & 2 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,5 +885,8 @@ end
@test_throws MethodError trigamma(big"2")
@test_throws MethodError trigamma(big"2.0")
@test_throws MethodError invdigamma(big"2")
@test_throws MethodError invdiamma(big"2.0")

@test_throws MethodError invdigamma(big"2.0")
@test_throws MethodError eta(Complex(big"2"))
@test_throws MethodError eta(Complex(big"2.0"))
@test_throws MethodError zeta(Complex(big"2"))
@test_throws MethodError zeta(Complex(big"2.0"))

0 comments on commit 4b188ae

Please sign in to comment.