-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
log1pexp #37
log1pexp #37
Conversation
Benchmarks. using BenchmarkTools, LogExpFunctions
log1pexp_old(x::Float64) = x < 18 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : oftype(exp(-x), x) # before this PR
log1pexp_new(x::Float64) = x ≤ -37 ? exp(x) : x ≤ 18 ? log1p(exp(x)) : x ≤ 33.3 ? x + exp(-x) : float(x) # this PR
julia> @benchmark log1pexp_old(x) setup=(x=(rand() - 0.5) * 100) samples=10^6
BenchmarkTools.Trial: 240558 samples with 998 evaluations.
Range (min … max): 9.675 ns … 136.391 ns ┊ GC (min … max): 0.00% … 0.00%
Time (median): 19.892 ns ┊ GC (median): 0.00%
Time (mean ± σ): 19.025 ns ± 7.230 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▁█ ▂▄ ▅
▃███▆▃▁▁▃▁▃▄█▂▂▁▁▁▁▁▁▁▅▃▆▆▅██▁▃▁▁▁▁▁▁▁▂▁▃▁▄▂▁▆▁▁█▁▁▂▁▁▁▁▁▁▁▁ ▂
9.68 ns Histogram: frequency by time 35.2 ns <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark log1pexp_new(x) setup=(x=(rand() - 0.5) * 100) samples=10^6
BenchmarkTools.Trial: 281295 samples with 1000 evaluations.
Range (min … max): 2.794 ns … 343.354 ns ┊ GC (min … max): 0.00% … 0.00%
Time (median): 18.832 ns ┊ GC (median): 0.00%
Time (mean ± σ): 16.156 ns ± 8.600 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█ ▄ ▄
█▂▁▁▁▁▁▁▁▁▁▁▁▆▇▂▇▆▃▂▂▁▁▁▁▁▁▁▁▁▁▁█▂█▆▁▃▁▂▁▁▁▁▁▁▁▁█▁█▁▄▁▂▁▁▂▁▁ ▂
2.79 ns Histogram: frequency by time 32.1 ns <
Memory estimate: 0 bytes, allocs estimate: 0. |
Benchmark comparing to version with log1pexp_oftype(x::Float64) = x ≤ -37 ? exp(x) : x ≤ 18 ? log1p(exp(x)) : x ≤ 33.3 ? x + exp(-x) : oftype(exp(x), x)
julia> @benchmark log1pexp_new(x) setup=(x=(rand() - 0.5) * 100) samples=10^6
BenchmarkTools.Trial: 272085 samples with 1000 evaluations.
Range (min … max): 2.460 ns … 1.768 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 19.383 ns ┊ GC (median): 0.00%
Time (mean ± σ): 16.689 ns ± 11.719 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
█ ▁ ▃▂
█▄▂▂▂▂▂▂▂▂▂▂▁▂▇▅▄█▇▅▄▂▂▂▂▂▂▂▂▂▂▂▂▆▃██▂▇▃▆▂▂▂▂▂▂▂▂▅▂█▂▇▂▅▃▂▄ ▃
2.46 ns Histogram: frequency by time 31.2 ns <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark log1pexp_oftype(x) setup=(x=(rand() - 0.5) * 100) samples=10^6
BenchmarkTools.Trial: 233832 samples with 998 evaluations.
Range (min … max): 9.697 ns … 833.891 ns ┊ GC (min … max): 0.00% … 0.00%
Time (median): 20.884 ns ┊ GC (median): 0.00%
Time (mean ± σ): 19.594 ns ± 8.968 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▅█▇▆ ▆ █ ▄ ▁ ▂
▂▅▇█████▃▁▁▁▁▁▁▁▁▁▁▂▄▆█▂█▅█▂▃▁▁▁▁▁▂▅▁█▁█▁▇▄▁▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
9.7 ns Histogram: frequency by time 39.7 ns <
Memory estimate: 0 bytes, allocs estimate: 0. The |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update the ChainRules tests as well? IIRC there we also test every branch.
My main concern is whether there is a licensing issue since this is based on (the vignette of) Rmpfr which uses a GPL. I would prefer if someone who's more familiar with licenses (I'm definitely not) could confirm that there are no problems here and we don't have use the GPL.
The ChainRules tests fail for Are we sure |
I unified the function log1pexp(x::Real)
t = log1p(exp(-abs(x)))
return x ≤ 0 ? t : t + x
end
xs = 0:-0.01:-100
for T in (Float16, Float32, Float64)
for x in xs
correct = T(log1pexp(big(x)))
if iszero(correct - exp(T(x)))
println("Found crossing with `exp(x)` at x = $x for $T, in the interval $(extrema(xs))")
break
end
end
end
xs = 0:0.01:100
for T in (Float16, Float32, Float64)
for x in xs
correct = T(log1pexp(big(x)))
if iszero(correct - (T(x) + exp(-T(x))))
println("Found crossing with `x + exp(-x)` at x = $x for $T, in the interval $(extrema(xs))")
break
end
end
end
xs = 0:0.01:100
for T in (Float16, Float32, Float64)
for x in xs
if iszero(T(log1pexp(big(x))) - T(x))
println("Found crossing with `x` at x = $x for $T, in the interval $(extrema(xs))")
break
end
end
end I removed the ChainRules tests for |
844564f
to
5328ccb
Compare
@devmotion Note that the current julia> log1pexp(Float16(16)) # current master
Inf16 julia> log1pexp(Float16(16)) # after this PR
Float16(16.0) This is because it is using branch bounds tailored for Float64, which are not right for Float16. Does this example persuade you? I agree this is a different issue, but I could try to fix it also here. I'd say that fixing wrong results takes priority over a (minor) performance regression. |
As to the performance issue, let's compare the current implementation on master with the generic using BenchmarkTools
# generic fallback in this PR
function log1pexp_new(x::Real)
t = log1p(exp(-abs(x)))
return x ≤ 0 ? t : t + x
end
# current master
log1pexp_old(x::Real) = x < 18.0 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : oftype(exp(-x), x)
julia> @benchmark log1pexp_old(x) setup=(x=(rand() - 0.5) * 100) samples=10^6
BenchmarkTools.Trial: 278000 samples with 999 evaluations.
Range (min … max): 9.042 ns … 56.144 ns ┊ GC (min … max): 0.00% … 0.00%
Time (median): 17.880 ns ┊ GC (median): 0.00%
Time (mean ± σ): 16.501 ns ± 6.132 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▆█▇▆▄▃▃ ▂▁▇▆▄▃▃▁▂▃▂▁ ▃▆█▅▇▄▅▁▃ ▂▂▁▃ ▁ ▃ ▅▇ ▆▂ ▅▁ ▃ ▂ ▄
████████████████████▇▇▆▇███████████████▇█▆▇▇█▆██▇██▇██▇██▆█ █
9.04 ns Histogram: log(frequency) by time 29.1 ns <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark log1pexp_new(x) setup=(x=(rand() - 0.5) * 100) samples=10^6
BenchmarkTools.Trial: 218739 samples with 997 evaluations.
Range (min … max): 12.472 ns … 1.042 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 21.477 ns ┊ GC (median): 0.00%
Time (mean ± σ): 21.094 ns ± 8.230 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▇ █
▁▃▄▄▂▅▇▁▂▁▁▁▁▁▁▁▂▁▇▂▇▄▇▂█▁█▂▂▂▁▂▁▁▂▂▁▁▂▁▁▂▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁ ▂
12.5 ns Histogram: frequency by time 36 ns <
Memory estimate: 0 bytes, allocs estimate: 0. |
No, it doesn't convince me. The |
I simplified the implementation once more. Updated benchmarks:
whereas the current master takes about 9 secs minimum (see plot above). |
Unfortunately this version: @inline function _log1pexp_thresholds(x::Real)
prec = precision(x)
logtwo = oftype(x, IrrationalConstants.logtwo)
x0 = -prec * logtwo
x1 = (prec - 1) * logtwo / 2
x2 = -x0 - log(-x0) * (1 + 1 / x0) # approximate root of e^-x == x * ϵ/2 via asymptotics of Lambert's W function
return (x0, x1, x2)
end is not compiled away in Julia 1.0. Actually This supports having the hard-coded thresholds for |
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this looks good, I have only a question about the Float64
and Float16
values left (and I think it would also be fine to make the hardcoded values consistent with the generic fallback definition, but I don't have a strong opinion there).
In general, now I think the PR should be an improvement for all inputs x
where float(x)
is a fixed-precision number since then (at least in recent Julia versions) the compiler can optimize away thresholds even for non-standard types. It's only a bit problematic for variable precision numbers - in contrast to the @generated
version the results will be correct but since the thresholds have to be recomputed every time it might cause performance regressions. IMO this is a bit annoying but much better than silently returning wrong results, and it can be fixed in the same way as done here for BigFloat
.
Can you update the version number as well?
@devmotion I attach here the calculation of the thresholds I am using in this PR. |
It seems the approximations in your notes are slightly different from the values in the PR? julia> log(eps(Float64))
-36.04365338911715
julia> -log(2*eps(Float64)) / 2
17.675253104278607 |
The epsilon in the notes is |
Merge? |
My understanding is that the current solution is broken without that PR, so for now I would suggest waiting for that. @cossio, in the meantime, can you please add a |
No, it's not broken, we don't need the PR: #37 (comment) Would still be good to add tests for it, I think. I was only waiting in case @tpapp has some additional comments. I summarized mine above, and think it's worth it even though it might cause performance regressions for variable precision number types. |
@devmotion: thanks for the clarification. In that case, please feel free to merge. |
Based on
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
Close #13