From 5466fb8e81eed725151076dea15def84db78e319 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 9 Feb 2017 16:48:59 +0800 Subject: [PATCH 1/3] Improve type consistency of special functions --- src/gamma.jl | 163 ++++++++++++++++++++++++++++++----------------- test/runtests.jl | 32 ++++++++++ 2 files changed, 136 insertions(+), 59 deletions(-) diff --git a/src/gamma.jl b/src/gamma.jl index e66b41af..2607c834 100644 --- a/src/gamma.jl +++ b/src/gamma.jl @@ -3,6 +3,8 @@ using Base.Math: signflip, f16, f32, f64 using Base.MPFR: ROUNDING_MODE, big_ln2 +typealias ComplexOrReal{T} Union{T,Complex{T}} + # Bernoulli numbers B_{2k}, using tabulated numerators and denominators from # the online encyclopedia of integer sequences. (They actually have data # up to k=249, but we stop here at k=20.) Used for generating the polygamma @@ -16,7 +18,7 @@ using Base.MPFR: ROUNDING_MODE, big_ln2 Compute the digamma function of `x` (the logarithmic derivative of `gamma(x)`). """ -function digamma(z::Union{Float64,Complex{Float64}}) +function digamma(z::ComplexOrReal{Float64}) # Based on eq. (12), without looking at the accompanying source # code, of: K. S. Kölbig, "Programs for computing the logarithm of # the gamma function, and the digamma function, for complex @@ -56,7 +58,7 @@ end Compute the trigamma function of `x` (the logarithmic second derivative of `gamma(x)`). """ -function trigamma(z::Union{Float64,Complex{Float64}}) +function trigamma(z::ComplexOrReal{Float64}) # via the derivative of the Kölbig digamma formulation x = real(z) if x <= 0 # reflection formula @@ -213,8 +215,7 @@ this definition is equivalent to the Hurwitz zeta function ``\\sum_{k=0}^\\infty (k+z)^{-s}``. For ``z=1``, it yields the Riemann zeta function ``\\zeta(s)``. """ -function zeta(s::Union{Int,Float64,Complex{Float64}}, - z::Union{Float64,Complex{Float64}}) +function zeta(s::ComplexOrReal{Float64}, z::ComplexOrReal{Float64}) ζ = zero(promote_type(typeof(s), typeof(z))) (z == 1 || z == 0) && return oftype(ζ, zeta(s)) @@ -263,7 +264,8 @@ function zeta(s::Union{Int,Float64,Complex{Float64}}, minus_z = -z ζ += pow_oftype(ζ, minus_z, minus_s) # ν = 0 term if xf != z - ζ += pow_oftype(ζ, z - nx, minus_s) # real(z - nx) > 0, so use correct branch cut + ζ += pow_oftype(ζ, z - nx, minus_s) + # real(z - nx) > 0, so use correct branch cut # otherwise, if xf==z, then the definition skips this term end # do loop in different order, depending on the sign of s, @@ -316,10 +318,10 @@ end """ polygamma(m, x) -Compute the polygamma function of order `m` of argument `x` (the `(m+1)th` derivative of the +Compute the polygamma function of order `m` of argument `x` (the `(m+1)`th derivative of the logarithm of `gamma(x)`) """ -function polygamma(m::Integer, z::Union{Float64,Complex{Float64}}) +function polygamma(m::Integer, z::ComplexOrReal{Float64}) m == 0 && return digamma(z) m == 1 && return trigamma(z) @@ -337,7 +339,9 @@ function polygamma(m::Integer, z::Union{Float64,Complex{Float64}}) # constants. We throw a DomainError() since the definition is unclear. real(m) < 0 && throw(DomainError()) - s = m+1 + s = Float64(m+1) + # It is safe to convert any integer (including `BigInt`) to a float here + # as underflow occurs before precision issues. if real(z) <= 0 # reflection formula (zeta(s, 1-z) + signflip(m, cotderiv(m,z))) * (-gamma(s)) else @@ -345,32 +349,15 @@ function polygamma(m::Integer, z::Union{Float64,Complex{Float64}}) end end -# If we really cared about single precision, we could make a faster -# Float32 version by truncating the Stirling series at a smaller cutoff. -for (f,T) in ((:f32,Float32),(:f16,Float16)) - @eval begin - zeta(s::Integer, z::Union{$T,Complex{$T}}) = $f(zeta(Int(s), f64(z))) - zeta(s::Union{Float64,Complex128}, z::Union{$T,Complex{$T}}) = zeta(s, f64(z)) - zeta(s::Number, z::Union{$T,Complex{$T}}) = $f(zeta(f64(s), f64(z))) - polygamma(m::Integer, z::Union{$T,Complex{$T}}) = $f(polygamma(Int(m), f64(z))) - digamma(z::Union{$T,Complex{$T}}) = $f(digamma(f64(z))) - trigamma(z::Union{$T,Complex{$T}}) = $f(trigamma(f64(z))) - end -end - -zeta(s::Integer, z::Number) = zeta(Int(s), f64(z)) -zeta(s::Number, z::Number) = zeta(f64(s), f64(z)) -for f in (:digamma, :trigamma) - @eval begin - $f(z::Number) = $f(f64(z)) - end -end -polygamma(m::Integer, z::Number) = polygamma(m, f64(z)) +""" + invdigamma(x) -# Inverse digamma function: -# Implementation of fixed point algorithm described in -# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000 +Compute the inverse [`digamma`](@ref) function of `x`. +""" function invdigamma(y::Float64) + # Implementation of fixed point algorithm described in + # "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000 + # Closed form initial estimates if y >= -2.22 x_old = exp(y) + 0.5 @@ -392,18 +379,16 @@ function invdigamma(y::Float64) return x_new end -invdigamma(x::Float32) = Float32(invdigamma(Float64(x))) """ - invdigamma(x) + zeta(s) -Compute the inverse [`digamma`](@ref) function of `x`. +Riemann zeta function ``\\zeta(s)``. """ -invdigamma(x::Real) = invdigamma(Float64(x)) +function zeta(s::ComplexOrReal{Float64}) + # Riemann zeta function; algorithm is based on specializing the Hurwitz + # zeta function above for z==1. -# Riemann zeta function; algorithm is based on specializing the Hurwitz -# zeta function above for z==1. -function zeta(s::Union{Float64,Complex{Float64}}) # blows up to ±Inf, but get correct sign of imaginary zero s == 1 && return NaN + zero(s) * imag(s) @@ -448,23 +433,18 @@ function zeta(s::Union{Float64,Complex{Float64}}) return ζ end -zeta(x::Integer) = zeta(Float64(x)) -zeta(x::Real) = oftype(float(x),zeta(Float64(x))) - -""" - zeta(s) - -Riemann zeta function ``\\zeta(s)``. -""" -zeta(z::Complex) = oftype(float(z),zeta(Complex128(z))) - function zeta(x::BigFloat) z = BigFloat() ccall((:mpfr_zeta, :libmpfr), Int32, (Ptr{BigFloat}, Ptr{BigFloat}, Int32), &z, &x, ROUNDING_MODE[]) return z end -function eta(z::Union{Float64,Complex{Float64}}) +""" + eta(x) + +Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``. +""" +function eta(z::ComplexOrReal{Float64}) δz = 1 - z if abs(real(δz)) + abs(imag(δz)) < 7e-3 # Taylor expand around z==1 return 0.6931471805599453094172321214581765 * @@ -478,17 +458,82 @@ function eta(z::Union{Float64,Complex{Float64}}) return -zeta(z) * expm1(0.6931471805599453094172321214581765*δz) end end -eta(x::Integer) = eta(Float64(x)) -eta(x::Real) = oftype(float(x),eta(Float64(x))) - -""" - eta(x) - -Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``. -""" -eta(z::Complex) = oftype(float(z),eta(Complex128(z))) function eta(x::BigFloat) x == 1 && return big_ln2() return -zeta(x) * expm1(big_ln2()*(1-x)) end + +# Converting types that we can convert, and not ones we can not +# Float16, and Float32 and their Complex equivalents can be converted to Float64 +# and results converted back. +# Otherwise, we need to make things use their own `float` converting methods +# and in those cases, we do not convert back either as we assume +# they also implement their own versions of the functions, with the correct return types. +# This is the case for BitIntegers (which become `Float64` when `float`ed). +# Otherwise, if they do not implement their version of the functions we +# manually throw a `MethodError`. +# This case occurs, when calling `float` on a type does not change its type, +# as it is already a `float`, and would have hit own method, if one had existed. + + +# If we really cared about single precision, we could make a faster +# Float32 version by truncating the Stirling series at a smaller cutoff. +# and if we really cared about half precision, we could make a faster +# Float16 version, by using a precomputed table look-up. + + +for T in (Float16, Float32, Float64) + @eval f64(x::Complex{$T}) = Complex128(x) + @eval f64(x::$T) = Float64(x) +end + + +for f in (:digamma, :trigamma, :zeta, :eta, :invdigamma) + @eval begin + function $f(z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}}) + oftype(z, $f(f64(z))) + end + + function $f(z::Number) + x = float(z) + typeof(x) === typeof(z) && throw(MethodError($f, (z,))) + # There is nothing to fallback to, as this didn't change the argument types + $f(x) + end + end +end + + +for T1 in (Float16, Float32, Float64), T2 in (Float16, Float32, Float64) + (T1 == T2 == Float64) && continue # Avoid redefining base definition + + @eval function zeta(s::ComplexOrReal{$T1}, z::ComplexOrReal{$T2}) + ζ = zeta(f64(s), f64(z)) + convert(promote_type(typeof(s), typeof(z)), ζ) + end +end + + +function zeta(s::Number, z::Number) + t = float(s) + x = float(z) + if typeof(t) === typeof(s) && typeof(x) === typeof(z) + # There is nothing to fallback to, since this didn't work + throw(MethodError(zeta,(s,z))) + end + zeta(t, x) +end + + +function polygamma(m::Integer, z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}}) + oftype(z, polygamma(m, f64(z))) +end + + +function polygamma(m::Integer, z::Number) + x = float(z) + typeof(x) === typeof(z) && throw(MethodError(polygamma, (m,z))) + # There is nothing to fallback to, since this didn't work + polygamma(m, x) +end diff --git a/test/runtests.jl b/test/runtests.jl index a926a5a9..8c161254 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -454,3 +454,35 @@ end @test typeof(SF.erfc(a)) == BigFloat end end + +@testset "Base Julia issue #17474" begin + @test SF.f64(complex(1f0,1f0)) == complex(1.0, 1.0) + @test SF.f64(1f0) == 1.0 + + @test typeof(SF.eta(big"2")) == BigFloat + @test typeof(SF.zeta(big"2")) == BigFloat + @test typeof(SF.digamma(big"2")) == BigFloat + + @test_throws MethodError SF.trigamma(big"2") + @test_throws MethodError SF.trigamma(big"2.0") + @test_throws MethodError SF.invdigamma(big"2") + @test_throws MethodError SF.invdigamma(big"2.0") + + @test_throws MethodError SF.eta(Complex(big"2")) + @test_throws MethodError SF.eta(Complex(big"2.0")) + @test_throws MethodError SF.zeta(Complex(big"2")) + @test_throws MethodError SF.zeta(Complex(big"2.0")) + @test_throws MethodError SF.zeta(1.0,big"2") + @test_throws MethodError SF.zeta(1.0,big"2.0") + @test_throws MethodError SF.zeta(big"1.0",2.0) + @test_throws MethodError SF.zeta(big"1",2.0) + + + @test typeof(SF.polygamma(3, 0x2)) == Float64 + @test typeof(SF.polygamma(big"3", 2f0)) == Float32 + @test typeof(SF.zeta(1, 2.0)) == Float64 + @test typeof(SF.zeta(1, 2f0)) == Float64 # BitIntegers result in Float64 returns + @test typeof(SF.zeta(2f0, complex(2f0,0f0))) == Complex{Float32} + @test typeof(SF.zeta(complex(1,1), 2f0)) == Complex{Float64} + @test typeof(SF.zeta(complex(1), 2.0)) == Complex{Float64} +end From 2b04804a486be64cd59b883cb63aceb2276df0cc Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Thu, 9 Feb 2017 19:20:49 -0800 Subject: [PATCH 2/3] Add signflip from Base --- src/gamma.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gamma.jl b/src/gamma.jl index 2607c834..017f8724 100644 --- a/src/gamma.jl +++ b/src/gamma.jl @@ -1,6 +1,5 @@ # This file contains code that was formerly a part of Julia. License is MIT: http://julialang.org/license -using Base.Math: signflip, f16, f32, f64 using Base.MPFR: ROUNDING_MODE, big_ln2 typealias ComplexOrReal{T} Union{T,Complex{T}} @@ -81,6 +80,9 @@ function trigamma(z::ComplexOrReal{Float64}) ψ += t*w * @evalpoly(w,0.16666666666666666,-0.03333333333333333,0.023809523809523808,-0.03333333333333333,0.07575757575757576,-0.2531135531135531,1.1666666666666667,-7.092156862745098) end +signflip(m::Number, z) = (-1+0im)^m * z +signflip(m::Integer, z) = iseven(m) ? z : -z + # (-1)^m d^m/dz^m cot(z) = p_m(cot z), where p_m is a polynomial # that satisfies the recurrence p_{m+1}(x) = p_m′(x) * (1 + x^2). # Note that p_m(x) has only even powers of x if m is odd, and From a1ef500412ec6fdddaf0e63ab43c2b7430c076a3 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Fri, 10 Feb 2017 11:33:13 -0800 Subject: [PATCH 3/3] Test for type equality in f64 tests --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8c161254..e043141f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -456,8 +456,8 @@ end end @testset "Base Julia issue #17474" begin - @test SF.f64(complex(1f0,1f0)) == complex(1.0, 1.0) - @test SF.f64(1f0) == 1.0 + @test SF.f64(complex(1f0,1f0)) === complex(1.0, 1.0) + @test SF.f64(1f0) === 1.0 @test typeof(SF.eta(big"2")) == BigFloat @test typeof(SF.zeta(big"2")) == BigFloat