diff --git a/Project.toml b/Project.toml
index a5a1efa3..23b2fff2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LogExpFunctions"
 uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 authors = ["StatsFun.jl contributors, Tamas K. Papp "]
-version = "0.3.7"
+version = "0.3.8"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/src/basicfuns.jl b/src/basicfuns.jl
index 39a5d68e..26997cf0 100644
--- a/src/basicfuns.jl
+++ b/src/basicfuns.jl
@@ -152,9 +152,59 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`.
 
 This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
 transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).
+
+See:
+ * Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)
 """
-log1pexp(x::Real) = x < 18.0 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : oftype(exp(-x), x)
-log1pexp(x::Float32) = x < 9.0f0 ? log1p(exp(x)) : x < 16.0f0 ? x + exp(-x) : oftype(exp(-x), x)
+log1pexp(x::Real) = _log1pexp(float(x)) # ensures that BigInt/BigFloat, Int/Float64 etc. dispatch to the same algorithm
+
+# Approximations based on Maechler (2012)
+# Argument `x` is a floating point number due to the definition of `log1pexp` above
+function _log1pexp(x::Real)
+    x0, x1, x2 = _log1pexp_thresholds(x)
+    if x < x0
+        return exp(x)
+    elseif x < x1
+        return log1p(exp(x))
+    elseif x < x2
+        return x + exp(-x)
+    else
+        return x
+    end
+end
+
+#= The precision of BigFloat cannot be computed from the type only and computing
+thresholds is slow. Therefore prefer version without thresholds in this case. =#
+_log1pexp(x::BigFloat) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))
+
+#=
+Returns thresholds x0, x1, x2 such that:
+
+ * log1pexp(x) ≈ exp(x) for x ≤ x0
+ * log1pexp(x) ≈ log1p(exp(x)) for x0 < x ≤ x1
+ * log1pexp(x) ≈ x + exp(-x) for x1 < x ≤ x2
+ * log1pexp(x) ≈ x for x > x2
+
+where the tolerances of the approximations are on the order of eps(typeof(x)).
+For types for which `precision(x)` depends only on the type of `x`, the compiler
+should optimize away all computations done here.
+=#
+@inline function _log1pexp_thresholds(x::Real)
+    prec = precision(x)
+    logtwo = oftype(x, IrrationalConstants.logtwo)
+    x0 = -prec * logtwo
+    x1 = (prec - 1) * logtwo / 2
+    x2 = -x0 - log(-x0) * (1 + 1 / x0) # approximate root of e^-x == x * ϵ/2 via asymptotics of Lambert's W function
+    return (x0, x1, x2)
+end
+
+#=
+For common types we hard-code the thresholds to make absolutely sure they are not recomputed
+each time. Also, _log1pexp_thresholds is not elided by the compiler in Julia 1.0 / 1.6.
+=#
+@inline _log1pexp_thresholds(::Float64) = (-36.7368005696771, 18.021826694558577, 33.23111882352963)
+@inline _log1pexp_thresholds(::Float32) = (-16.635532f0, 7.9711924f0, 13.993f0)
+@inline _log1pexp_thresholds(::Float16) = (Float16(-7.625), Float16(3.467), Float16(5.86))
 
 """
 $(SIGNATURES)
diff --git a/test/basicfuns.jl b/test/basicfuns.jl
index 7842e596..c30a7ae5 100644
--- a/test/basicfuns.jl
+++ b/test/basicfuns.jl
@@ -110,15 +110,37 @@ end
 
 # log1pexp, log1mexp, log2mexp & logexpm1
 @testset "log1pexp" begin
-    @test log1pexp(2.0) ≈ log(1.0 + exp(2.0))
-    @test log1pexp(-2.0) ≈ log(1.0 + exp(-2.0))
-    @test log1pexp(10000) ≈ 10000.0
-    @test log1pexp(-10000) ≈ 0.0
-
-    @test log1pexp(2f0) ≈ log(1f0 + exp(2f0))
-    @test log1pexp(-2f0) ≈ log(1f0 + exp(-2f0))
-    @test log1pexp(10000f0) ≈ 10000f0
-    @test log1pexp(-10000f0) ≈ 0f0
+    for T in (Float16, Float32, Float64, BigFloat), x in 1:40
+        @test (@inferred log1pexp(+log(T(x)))) ≈ T(log1p(big(x)))
+        @test (@inferred log1pexp(-log(T(x)))) ≈ T(log1p(1/big(x)))
+    end
+
+    # special values
+    @test (@inferred log1pexp(0)) ≈ log(2)
+    @test (@inferred log1pexp(0f0)) ≈ log(2)
+    @test (@inferred log1pexp(big(0))) ≈ log(2)
+    @test (@inferred log1pexp(+1)) ≈ log1p(ℯ)
+    @test (@inferred log1pexp(-1)) ≈ log1p(ℯ) - 1
+
+    # large arguments
+    @test (@inferred log1pexp(1e4)) ≈ 1e4
+    @test (@inferred log1pexp(1f4)) ≈ 1f4
+    @test iszero(@inferred log1pexp(-1e4))
+    @test iszero(@inferred log1pexp(-1f4))
+
+    # compare to accurate but slower implementation
+    correct_log1pexp(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))
+    # large range needed to cover all branches, for all floats (from Float16 to BigFloat)
+    for T in (Int, Float16, Float32, Float64, BigInt, BigFloat), x in -300:300
+        @test (@inferred log1pexp(T(x))) ≈ float(T)(correct_log1pexp(big(x)))
+    end
+    # test BigFloat with multiple precisions
+    for prec in (10, 20, 50, 100), x in -300:300
+        setprecision(prec) do
+            y = big(float(x))
+            @test @inferred(log1pexp(y)) ≈ correct_log1pexp(y)
+        end
+    end
 end
 
 @testset "log1mexp" begin
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 3e16153d..18434d4a 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -57,14 +57,11 @@
         test_rrule(logcosh, x)
     end
 
-    # test all branches of `log1pexp`
-    for x in (-20.9, 15.4, 41.5)
-        test_frule(log1pexp, x)
-        test_rrule(log1pexp, x)
-    end
-    for x in (8.3f0, 12.5f0, 21.2f0)
-        test_frule(log1pexp, x; rtol=1f-3, atol=1f-3)
-        test_rrule(log1pexp, x; rtol=1f-3, atol=1f-3)
+    @testset "log1pexp" begin
+        for absx in (0, 1, 2, 10, 15, 20, 40), x in (-absx, absx)
+            test_scalar(log1pexp, Float64(x))
+            test_scalar(log1pexp, Float32(x); rtol=1f-3, atol=1f-3)
+        end
     end
 
     for x in (-10.2, -3.3, -0.3)
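As a sanity check on the hard-coded threshold tuples above, the same formula used by the generic `_log1pexp_thresholds(x::Real)` can be re-evaluated outside the package. The sketch below is not part of the PR: the helper name `rederive_thresholds` is hypothetical, and it only assumes that the `IrrationalConstants` package (already used in the diff) is available.

# Standalone sketch (not part of the diff): re-derive the thresholds from the
# same formula as `_log1pexp_thresholds`, assuming IrrationalConstants is installed.
using IrrationalConstants: logtwo

function rederive_thresholds(::Type{T}) where {T<:AbstractFloat}  # hypothetical helper, for illustration only
    prec = precision(T)                 # 53 for Float64, 24 for Float32, 11 for Float16
    x0 = -prec * T(logtwo)              # below x0, exp(x) matches log1p(exp(x)) to ~eps(T)
    x1 = (prec - 1) * T(logtwo) / 2     # below x1, log1p(exp(x)) is still accurate
    x2 = -x0 - log(-x0) * (1 + 1 / x0)  # approximate root of exp(-x) == x * eps/2
    return (x0, x1, x2)
end

# rederive_thresholds(Float64) should reproduce the hard-coded tuple
# (-36.7368005696771, 18.021826694558577, 33.23111882352963) from the diff.

Hard-coding the tuples for Float64, Float32 and Float16 then guarantees they are never recomputed at runtime, which matters on Julia 1.0 / 1.6 where, per the comment in the diff, the compiler does not elide the generic computation.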