Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid StackOverflowError with erf* functions #353

Merged
merged 2 commits into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 89 additions & 86 deletions src/erf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,73 @@ using Base.Math: @horner
using Base.MPFR: ROUNDING_MODE

for f in (:erf, :erfc)
internalf = Symbol(:_, f)
libopenlibmf = QuoteNode(f)
libopenlibmf0 = QuoteNode(Symbol(f, :f))
openspecfunf = QuoteNode(Symbol(:Faddeeva_, f))
mpfrf = QuoteNode(Symbol(:mpfr_, f))
@eval begin
($f)(x::Float64) = ccall(($(string(f)),libopenlibm), Float64, (Float64,), x)
($f)(x::Float32) = ccall(($(string(f,"f")),libopenlibm), Float32, (Float32,), x)
($f)(x::Real) = ($f)(float(x))
($f)(a::Float16) = Float16($f(Float32(a)))
($f)(a::Complex{Float16}) = Complex{Float16}($f(Complex{Float32}(a)))
function ($f)(x::BigFloat)
$f(x::Number) = $internalf(float(x))

$internalf(x::Float64) = ccall(($libopenlibmf, libopenlibm), Float64, (Float64,), x)
$internalf(x::Float32) = ccall(($libopenlibmf0, libopenlibm), Float32, (Float32,), x)
$internalf(x::Float16) = Float16($internalf(Float32(x)))

$internalf(z::Complex{Float64}) = Complex{Float64}(ccall(($openspecfunf, libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), z, zero(Float64)))
$internalf(z::Complex{Float32}) = Complex{Float32}(ccall(($openspecfunf, libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), Complex{Float64}(z), Float64(eps(Float32))))
$internalf(z::Complex{Float16}) = Complex{Float16}($internalf(Complex{Float32}(z)))

function $internalf(x::BigFloat)
z = BigFloat()
ccall(($(string(:mpfr_,f)), :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Int32), z, x, ROUNDING_MODE[])
ccall(($mpfrf, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Int32), z, x, ROUNDING_MODE[])
return z
end
($f)(x::AbstractFloat) = error("not implemented for ", typeof(x))
end
end

for f in (:erf, :erfc, :erfcx, :erfi, :Dawson)
fname = (f === :Dawson) ? :dawson : f
for f in (:erfcx, :erfi, :dawson)
internalf = Symbol(:_, f)
openspecfunfsym = Symbol(:Faddeeva_, f === :dawson ? :Dawson : f)
openspecfunfF64 = QuoteNode(Symbol(openspecfunfsym, :_re))
openspecfunfCF64 = QuoteNode(openspecfunfsym)
@eval begin
($fname)(z::Complex{Float64}) = Complex{Float64}(ccall(($(string("Faddeeva_",f)),libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), z, zero(Float64)))
($fname)(z::Complex{Float32}) = Complex{Float32}(ccall(($(string("Faddeeva_",f)),libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), Complex{Float64}(z), Float64(eps(Float32))))
$f(x::Number) = $internalf(float(x))

$internalf(x::Float64) = ccall(($openspecfunfF64, libopenspecfun), Float64, (Float64,), x)
$internalf(x::Float32) = Float32($internalf(Float64(x)))
$internalf(x::Float16) = Float16($internalf(Float64(x)))
stevengj marked this conversation as resolved.
Show resolved Hide resolved

($fname)(z::Complex) = ($fname)(float(z))
($fname)(z::Complex{<:AbstractFloat}) = throw(MethodError($fname,(z,)))
$internalf(z::Complex{Float64}) = Complex{Float64}(ccall(($openspecfunfCF64, libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), z, zero(Float64)))
$internalf(z::Complex{Float32}) = Complex{Float32}(ccall(($openspecfunfCF64, libopenspecfun), Complex{Float64}, (Complex{Float64}, Float64), Complex{Float64}(z), Float64(eps(Float32))))
$internalf(z::Complex{Float16}) = Complex{Float16}($internalf(Complex{Float32}(z)))
stevengj marked this conversation as resolved.
Show resolved Hide resolved
end
end

for f in (:erfcx, :erfi, :Dawson)
fname = (f === :Dawson) ? :dawson : f
@eval begin
($fname)(x::Float64) = ccall(($(string("Faddeeva_",f,"_re")),libopenspecfun), Float64, (Float64,), x)
($fname)(x::Float32) = Float32(ccall(($(string("Faddeeva_",f,"_re")),libopenspecfun), Float64, (Float64,), Float64(x)))

($fname)(x::Real) = ($fname)(float(x))
($fname)(x::AbstractFloat) = throw(MethodError($fname,(x,)))
# MPFR has an open TODO item for this function
# until then, we use [DLMF 7.12.1](https://dlmf.nist.gov/7.12.1) for the tail
function _erfcx(x::BigFloat)
if x <= (Clong == Int32 ? 0x1p15 : 0x1p30)
# any larger gives internal overflow
return exp(x^2)*erfc(x)
elseif !isfinite(x)
return 1/x
else
# asymptotic series
# starts to diverge at iteration i = 2^30 or 2^60
# final term will be < Γ(2*i+1)/(2^i * Γ(i+1)) / (2^(i+1))
# so good to (lgamma(2*i+1) - lgamma(i+1))/log(2) - 2*i - 1
# ≈ 3.07e10 or 6.75e19 bits
# which is larger than the memory of the respective machines
ϵ = eps(BigFloat)/4
v = 1/(2*x*x)
k = 1
s = w = -k*v
while abs(w) > ϵ
k += 2
w *= -k*v
s += w
end
return (1+s)/(x*sqrtπ)
end
end

Expand Down Expand Up @@ -204,7 +237,9 @@ Using the rational approximants tabulated in:
> <http://www.jstor.org/stable/2005402>
combined with Newton iterations for `BigFloat`.
"""
function erfinv(x::Float64)
erfinv(x::Real) = _erfinv(float(x))

function _erfinv(x::Float64)
a = abs(x)
if a >= 1.0
if x == 1.0
Expand Down Expand Up @@ -272,7 +307,7 @@ function erfinv(x::Float64)
end
end

function erfinv(x::Float32)
function _erfinv(x::Float32)
a = abs(x)
if a >= 1.0f0
if x == 1.0f0
Expand Down Expand Up @@ -315,7 +350,25 @@ function erfinv(x::Float32)
end
end

erfinv(x::Union{Integer,Rational}) = erfinv(float(x))
function _erfinv(y::BigFloat)
xfloat = erfinv(Float64(y))
if isfinite(xfloat)
x = BigFloat(xfloat)
else
# Float64 overflowed, use asymptotic estimate instead
# from erfc(x) ≈ exp(-x²)/x√π ≈ y ⟹ -log(yπ) ≈ x² + log(x) ≈ x²
x = copysign(sqrt(-log((1-abs(y))*sqrtπ)), y)
isfinite(x) || return x
end
sqrtπhalf = sqrtπ * big(0.5)
tol = 2eps(abs(x))
while true # Newton iterations
Δx = sqrtπhalf * (erf(x) - y) * exp(x^2)
x -= Δx
abs(Δx) < tol && break
end
return x
end

@doc raw"""
erfcinv(x)
Expand All @@ -341,7 +394,9 @@ Using the rational approximants tabulated in:
> <http://www.jstor.org/stable/2005402>
combined with Newton iterations for `BigFloat`.
"""
function erfcinv(y::Float64)
erfcinv(x::Real) = _erfcinv(float(x))

function _erfcinv(y::Float64)
if y > 0.0625
return erfinv(1.0 - y)
elseif y <= 0.0
Expand Down Expand Up @@ -393,7 +448,7 @@ function erfcinv(y::Float64)
end
end

function erfcinv(y::Float32)
function _erfcinv(y::Float32)
if y > 0.0625f0
return erfinv(1.0f0 - y)
elseif y <= 0.0f0
Expand All @@ -415,27 +470,7 @@ function erfcinv(y::Float32)
end
end

function erfinv(y::BigFloat)
xfloat = erfinv(Float64(y))
if isfinite(xfloat)
x = BigFloat(xfloat)
else
# Float64 overflowed, use asymptotic estimate instead
# from erfc(x) ≈ exp(-x²)/x√π ≈ y ⟹ -log(yπ) ≈ x² + log(x) ≈ x²
x = copysign(sqrt(-log((1-abs(y))*sqrtπ)), y)
isfinite(x) || return x
end
sqrtπhalf = sqrtπ * big(0.5)
tol = 2eps(abs(x))
while true # Newton iterations
Δx = sqrtπhalf * (erf(x) - y) * exp(x^2)
x -= Δx
abs(Δx) < tol && break
end
return x
end

function erfcinv(y::BigFloat)
function _erfcinv(y::BigFloat)
yfloat = Float64(y)
xfloat = erfcinv(yfloat)
if isfinite(xfloat)
Expand All @@ -461,36 +496,6 @@ function erfcinv(y::BigFloat)
return x
end

erfcinv(x::Union{Integer,Rational}) = erfcinv(float(x))

# MPFR has an open TODO item for this function
# until then, we use [DLMF 7.12.1](https://dlmf.nist.gov/7.12.1) for the tail
function erfcx(x::BigFloat)
if x <= (Clong == Int32 ? 0x1p15 : 0x1p30)
# any larger gives internal overflow
return exp(x^2)*erfc(x)
elseif !isfinite(x)
return 1/x
else
# asymptotic series
# starts to diverge at iteration i = 2^30 or 2^60
# final term will be < Γ(2*i+1)/(2^i * Γ(i+1)) / (2^(i+1))
# so good to (lgamma(2*i+1) - lgamma(i+1))/log(2) - 2*i - 1
# ≈ 3.07e10 or 6.75e19 bits
# which is larger than the memory of the respective machines
ϵ = eps(BigFloat)/4
v = 1/(2*x*x)
k = 1
s = w = -k*v
while abs(w) > ϵ
k += 2
w *= -k*v
s += w
end
return (1+s)/(x*sqrtπ)
end
end

@doc raw"""
logerfc(x)

Expand All @@ -511,7 +516,9 @@ See also: [`erfcx(x)`](@ref erfcx).
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
"""
function logerfc(x::Union{Float32, Float64, BigFloat})
logerfc(x::Real) = _logerfc(float(x))

function _logerfc(x::Union{Float32, Float64, BigFloat})
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
if x > 0.0
return log(erfcx(x)) - x^2
Expand All @@ -520,9 +527,6 @@ function logerfc(x::Union{Float32, Float64, BigFloat})
end
end

logerfc(x::Real) = logerfc(float(x))
logerfc(x::AbstractFloat) = throw(MethodError(logerfc, x))

@doc raw"""
logerfcx(x)

Expand All @@ -543,7 +547,9 @@ See also: [`erfcx(x)`](@ref erfcx).
Based on the [`erfc(x)`](@ref erfc) and [`erfcx(x)`](@ref erfcx) functions.
Currently only implemented for `Float32`, `Float64`, and `BigFloat`.
"""
function logerfcx(x::Union{Float32, Float64, BigFloat})
logerfcx(x::Real) = _logerfcx(float(x))

function _logerfcx(x::Union{Float32, Float64, BigFloat})
# Don't include Float16 in the Union, otherwise logerfc would currently work for x <= 0.0, but not x > 0.0
if x < 0.0
return log(erfc(x)) + x^2
Expand All @@ -552,9 +558,6 @@ function logerfcx(x::Union{Float32, Float64, BigFloat})
end
end

logerfcx(x::Real) = logerfcx(float(x))
logerfcx(x::AbstractFloat) = throw(MethodError(logerfcx, x))

@doc raw"""
logerf(x, y)

Expand Down
24 changes: 18 additions & 6 deletions test/erf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@test erfc(Float32(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float32)
@test erfc(Float64(1)) ≈ 0.15729920705028513066 rtol=2*eps(Float64)

@test_throws MethodError erfcx(Float16(1))
@test erfcx(Float16(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float16)
@test erfcx(Float32(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float32)
@test erfcx(Float64(1)) ≈ 0.42758357615580700442 rtol=2*eps(Float64)

Expand Down Expand Up @@ -38,7 +38,7 @@
@test logerfcx(Float32(1000)) ≈ -7.48012072190621214066734919080 rtol=2eps(Float32)
@test logerfcx(Float64(1000)) ≈ -7.48012072190621214066734919080 rtol=2eps(Float64)

@test_throws MethodError erfi(Float16(1))
@test erfi(Float16(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float16)
@test erfi(Float32(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float32)
@test erfi(Float64(1)) ≈ 1.6504257587975428760 rtol=2*eps(Float64)

Expand All @@ -52,7 +52,7 @@
@test erfcinv(Float32(0.15729920705028513066)) ≈ 1 rtol=2*eps(Float32)
@test erfcinv(Float64(0.15729920705028513066)) ≈ 1 rtol=2*eps(Float64)

@test_throws MethodError dawson(Float16(1))
@test dawson(Float16(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float16)
@test dawson(Float32(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float32)
@test dawson(Float64(1)) ≈ 0.53807950691276841914 rtol=2*eps(Float64)
end
Expand All @@ -66,19 +66,19 @@
@test erfc(ComplexF32(1+2im)) ≈ 1.5366435657785650340+5.0491437034470346695im
@test erfc(ComplexF64(1+2im)) ≈ 1.5366435657785650340+5.0491437034470346695im

@test_throws MethodError erfcx(ComplexF16(1))
@test erfcx(ComplexF16(1+2im)) ≈ 0.14023958136627794370-0.22221344017989910261im
@test erfcx(ComplexF32(1+2im)) ≈ 0.14023958136627794370-0.22221344017989910261im
@test erfcx(ComplexF64(1+2im)) ≈ 0.14023958136627794370-0.22221344017989910261im

@test_throws MethodError erfi(ComplexF16(1))
@test erfi(ComplexF16(1+2im)) ≈ -0.011259006028815025076+1.0036063427256517509im
@test erfi(ComplexF32(1+2im)) ≈ -0.011259006028815025076+1.0036063427256517509im
@test erfi(ComplexF64(1+2im)) ≈ -0.011259006028815025076+1.0036063427256517509im

@test_throws MethodError erfinv(Complex(1))

@test_throws MethodError erfcinv(Complex(1))

@test_throws MethodError dawson(ComplexF16(1))
@test dawson(ComplexF16(1+2im)) ≈ -13.388927316482919244-11.828715103889593303im
@test dawson(ComplexF32(1+2im)) ≈ -13.388927316482919244-11.828715103889593303im
@test dawson(ComplexF64(1+2im)) ≈ -13.388927316482919244-11.828715103889593303im
end
Expand Down Expand Up @@ -116,6 +116,18 @@
end
end

@testset "Other float types" begin
struct NotAFloat <: AbstractFloat end

@test_throws MethodError erf(NotAFloat())
@test_throws MethodError erfc(NotAFloat())
@test_throws MethodError erfcx(NotAFloat())
@test_throws MethodError erfi(NotAFloat())
@test_throws MethodError erfinv(NotAFloat())
@test_throws MethodError erfcinv(NotAFloat())
@test_throws MethodError dawson(NotAFloat())
end

@testset "inverse" begin
for elty in [Float32,Float64]
for x in exp10.(range(-200, stop=-0.01, length=50))
Expand Down