Skip to content

Commit

Permalink
generic log1pexp
Browse files Browse the repository at this point in the history
  • Loading branch information
cossio committed Mar 7, 2022
1 parent 9e6fc95 commit 844564f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
34 changes: 32 additions & 2 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,39 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`.
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref).
Uses a fast implementation for floats based on Mächler (2012), available at:
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf.
"""
function log1pexp(x::Real)
    # Piecewise evaluation of log(1 + exp(x)); each interval uses a cheaper expression.
    if x <= -37
        # exp(x) is tiny here, so log1p(exp(x)) is replaced by exp(x)
        return exp(x)
    elseif x <= 18
        return log1p(exp(x))
    elseif x <= 33.3
        # log(1 + exp(x)) = x + log1p(exp(-x)), with log1p(exp(-x)) replaced by exp(-x)
        return x + exp(-x)
    else
        # the correction term is dropped; `float` keeps a floating-point result
        # for integer input, like the other branches
        return float(x)
    end
end
function log1pexp(x::Float32)
    # single-precision cut-offs between the three expressions: 9 and 16
    if x < 9f0
        return log1p(exp(x))
    elseif x < 16f0
        return x + exp(-x)
    end
    # past the last cut-off the input itself is returned, converted to the type of exp(-x)
    return oftype(exp(-x), x)
end
function log1pexp(x::Real)
    # log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|)): `exp` only ever receives a
    # non-positive argument, so it cannot overflow for large positive `x`.
    t = log1p(exp(-abs(x)))
    # for x <= 0 the max(x, 0) term vanishes; otherwise add x back
    return x <= 0 ? t : t + x
end

function log1pexp(x::Union{Float16,Float32,Float64})
    # per-type cut-offs separating the four expressions used below
    a, b, c = _log1pexp_branch_bounds(x)
    if x <= a
        # exp(x) is tiny here, so log1p(exp(x)) is replaced by exp(x)
        return exp(x)
    elseif x <= b
        return log1p(exp(x))
    elseif x <= c
        # log(1 + exp(x)) = x + log1p(exp(-x)), with log1p(exp(-x)) replaced by exp(-x)
        return x + exp(-x)
    else
        # the correction term is dropped (NaN also ends up here and is returned as is)
        return x
    end
end

#=
Cut-offs `(a, b, c)` between the branches of `log1pexp(x)`, one method per
floating-point type.

For the `approx` used in a given branch of `log1pexp(x)`, the cut-off is the first `x`
(searching from above or from below) that is a root of
    T(log1pexp(big(x))) - approx(T(x))
=#
function _log1pexp_branch_bounds(::Float64)
    return (-37, 18, 33.3)
end

function _log1pexp_branch_bounds(::Float32)
    return (-15, 9, 14.5)
end

function _log1pexp_branch_bounds(::Float16)
    return (-7, 3, 5.7)
end

"""
$(SIGNATURES)
Expand Down
17 changes: 13 additions & 4 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,19 @@ end
# log1pexp, log1mexp, log2mexp & logexpm1

@testset "log1pexp" begin
# test every branch
for x in (0, 1, 2, 10, 20, 40), T in (Int, Float32, Float64)
@test (@inferred log1pexp(-T(x))) log1p(exp(big(-x)))
@test (@inferred log1pexp(+T(x))) log1p(exp(big(+x)))
# generic method
@test (@inferred log1pexp(big(0))) ≈ log(big(2))
# both signs are written out explicitly below, so no separate sign loop is needed
for x in 1:10
    # identities: log1pexp(log(x)) = log(1 + x) and log1pexp(-log(x)) = log(1 + 1/x)
    @test (@inferred log1pexp(log(big(x)))) ≈ log(big(1 + x))
    @test (@inferred log1pexp(-log(big(x)))) ≈ log(big(1 + 1//x))
    # compare against the naive formula evaluated with `big` arguments
    @test (@inferred log1pexp(big(x))) ≈ log(1 + exp(big(x)))
    @test (@inferred log1pexp(-big(x))) ≈ log(1 + exp(-big(x)))
end

# test branches of specialized approximations
for x in (0, 1, 2, 10, 15, 20, 40), T in (Float16, Float32, Float64)
    # reference value: `log1pexp` evaluated on a `big` argument, then converted to `T`
    @test (@inferred log1pexp(-T(x))) ≈ T(log1pexp(big(-x)))
    @test (@inferred log1pexp(+T(x))) ≈ T(log1pexp(big(+x)))
end

# large arguments
Expand Down
13 changes: 5 additions & 8 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,11 @@
test_rrule(logcosh, x)
end

# exercise every branch of `log1pexp` (forward and reverse rules)
foreach((-20.9, 15.4, 41.5)) do x
    test_frule(log1pexp, x)
    test_rrule(log1pexp, x)
end
for x in (8.3f0, 12.5f0, 21.2f0)
test_frule(log1pexp, x; rtol=1f-3, atol=1f-3)
test_rrule(log1pexp, x; rtol=1f-3, atol=1f-3)
@testset "log1pexp" begin
    # check both signs of every magnitude, once as Float64 and once as Float32
    for magnitude in (0, 1, 2, 10, 15, 20, 40)
        for sgn in (-1, 1)
            test_scalar(log1pexp, Float64(sgn * magnitude))
            # looser tolerances for the single-precision input
            test_scalar(log1pexp, Float32(sgn * magnitude); rtol=1f-3, atol=1f-3)
        end
    end
end

for x in (-10.2, -3.3, -0.3)
Expand Down

0 comments on commit 844564f

Please sign in to comment.