From cd9a6647bdf88d330acdf496f75d287af7e8c4e6 Mon Sep 17 00:00:00 2001 From: cossio Date: Mon, 7 Mar 2022 16:32:36 +0100 Subject: [PATCH] generic log1pexp --- src/basicfuns.jl | 36 ++++++++++++++++++++++++++++++++++-- test/basicfuns.jl | 17 +++++++++++++---- test/chainrules.jl | 13 +++++-------- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 4ac992a4..513bf62b 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -152,9 +152,41 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`. This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref). + +See: + * Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf) + +Note: different than Maechler (2012), also uses bounds specific to Float32 and Float16. """ -log1pexp(x::Real) = x ≤ -37 ? exp(x) : x ≤ 18 ? log1p(exp(x)) : x ≤ 33.3 ? x + exp(-x) : float(x) -log1pexp(x::Float32) = x < 9f0 ? log1p(exp(x)) : x < 16f0 ? x + exp(-x) : oftype(exp(-x), x) +function log1pexp(x::Real) + t = log1p(exp(-abs(x))) + return x ≤ 0 ? t : t + x +end + +function log1pexp(x::Union{Float16,Float32,Float64}) + a, b, c = _log1pexp_branch_bounds(x) + if x ≤ a + return exp(x) + elseif x ≤ b + return log1p(exp(x)) + elseif x ≤ c + return x + exp(-x) + else + return x + end +end + +#= +Given the `approx` used in a branch of log1pexp(x) above, we find the first `x` (from above +or below) that is a root of + + T(log1pexp(big(x))) - approx(T(x)) + +This determines the branch bounds below. +=# +_log1pexp_branch_bounds(::Float64) = (-37.0, 18.0, 33.3) +_log1pexp_branch_bounds(::Float32) = (-15f0, 9f0, 14.5f0) +_log1pexp_branch_bounds(::Float16) = (Float16(-7), Float16(3), Float16(5.7)) """ $(SIGNATURES) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index de8ab7a0..3166248d 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -110,10 +110,19 @@ end # log1pexp, log1mexp, log2mexp & logexpm1 @testset "log1pexp" begin - # test every branch - for x in (0, 1, 2, 10, 20, 40), T in (Int, Float32, Float64) - @test (@inferred log1pexp(-T(x))) ≈ log1p(exp(big(-x))) - @test (@inferred log1pexp(+T(x))) ≈ log1p(exp(big(+x))) + # generic method + @test (@inferred log1pexp(big(0))) ≈ log(big(2)) + for x in 1:10, s in (-1, 1) + @test (@inferred log1pexp(log(big(x)))) ≈ log(big(1 + x)) + @test (@inferred log1pexp(-log(big(x)))) ≈ log(big(1 + 1//x)) + @test (@inferred log1pexp(big(x))) ≈ log(1 + exp(big(x))) + @test (@inferred log1pexp(-big(x))) ≈ log(1 + exp(-big(x))) + end + + # test branches of specialized approximations + for x in (0, 1, 2, 10, 15, 20, 40), T in (Float16, Float32, Float64) + @test (@inferred log1pexp(-T(x))) ≈ T(log1pexp(big(-x))) + @test (@inferred log1pexp(+T(x))) ≈ T(log1pexp(big(+x))) end # large arguments diff --git a/test/chainrules.jl b/test/chainrules.jl index 3e16153d..acd04599 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -57,14 +57,11 @@ test_rrule(logcosh, x) end - # test all branches of `log1pexp` - for x in (-20.9, 15.4, 41.5) - test_frule(log1pexp, x) - test_rrule(log1pexp, x) - end - for x in (8.3f0, 12.5f0, 21.2f0) - test_frule(log1pexp, x; rtol=1f-3, atol=1f-3) - test_rrule(log1pexp, x; rtol=1f-3, atol=1f-3) + @testset "log1pexp" begin + for x in (0, 1, 2, 10, 15, 20, 40), s in (-1, 1) + test_scalar(log1pexp, Float64(s * x)) + test_scalar(log1pexp, Float32(s * x); rtol=1f-3, atol=1f-3) + end end for x in (-10.2, -3.3, -0.3)