From c8a4c28ffe7b6e4f8d5253e01cef091bb8d2f42c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 16 Dec 2021 10:35:14 +0100 Subject: [PATCH] Add `logcosh` (#32) * Add `logcosh` * Extend docstring Co-authored-by: Milan Bouchet-Valat Co-authored-by: Milan Bouchet-Valat --- Project.toml | 2 +- docs/src/index.md | 3 ++- src/LogExpFunctions.jl | 2 +- src/basicfuns.jl | 12 ++++++++++++ src/chainrules.jl | 1 + src/with_logabsdet_jacobian.jl | 8 ++++++++ test/basicfuns.jl | 16 ++++++++++++++++ test/chainrules.jl | 5 +++++ test/with_logabsdet_jacobian.jl | 3 +++ 9 files changed, 49 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 88dd896a..2b8bdde2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LogExpFunctions" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" authors = ["StatsFun.jl contributors, Tamas K. Papp "] -version = "0.3.5" +version = "0.3.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/index.md b/docs/src/index.md index 6948a683..8035ea06 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -4,7 +4,7 @@ Various special functions based on `log` and `exp` moved from [StatsFuns.jl](htt The original authors of these functions are the StatsFuns.jl contributors. -LogExpFunctions supports [`InverseFunctions.inverse`](https://github.com/JuliaMath/InverseFunctions.jl) and [`ChangesOfVariables.test_with_logabsdet_jacobian`](https://github.com/JuliaMath/ChangesOfVariables.jl) for `log1mexp`, `log1pexp`, `log1pexp`, `log2mexp`, `log2mexp`, `logexpm1`, `logistic`, `logistic` and `logit`. +LogExpFunctions supports [`InverseFunctions.inverse`](https://github.com/JuliaMath/InverseFunctions.jl) and [`ChangesOfVariables.test_with_logabsdet_jacobian`](https://github.com/JuliaMath/ChangesOfVariables.jl) for `log1mexp`, `log1pexp`, `log1pexp`, `log2mexp`, `log2mexp`, `logexpm1`, `logistic`, `logistic`, `logit`, and `logcosh` (no inverse). ```@docs xlogx @@ -12,6 +12,7 @@ xlogy xlog1py logistic logit +logcosh log1psq log1pexp log1mexp diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index dd262e38..eaee068d 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -11,7 +11,7 @@ import LinearAlgebra export xlogx, xlogy, xlog1py, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax, - softmax! + softmax!, logcosh include("basicfuns.jl") include("logsumexp.jl") diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 77715f49..24301887 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -94,6 +94,18 @@ logit(x::Real) = log(x / (one(x) - x)) """ $(SIGNATURES) +Return `log(cosh(x))`, carefully evaluated without intermediate calculation of `cosh(x)`. + +The implementation ensures `logcosh(-x) = logcosh(x)`. +""" +function logcosh(x::Real) + abs_x = abs(x) + return abs_x + log1pexp(- 2 * abs_x) - IrrationalConstants.logtwo +end + +""" +$(SIGNATURES) + Return `log(1+x^2)` evaluated carefully for `abs(x)` very small or very large. """ log1psq(x::Real) = log1p(abs2(x)) diff --git a/src/chainrules.jl b/src/chainrules.jl index 21c8481a..2b97decd 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -4,6 +4,7 @@ ChainRulesCore.@scalar_rule(xlog1py(x::Real, y::Real), (log1p(y), x / (1 + y),)) ChainRulesCore.@scalar_rule(logistic(x::Real), (Ω * (1 - Ω),)) ChainRulesCore.@scalar_rule(logit(x::Real), (inv(x * (1 - x)),)) +ChainRulesCore.@scalar_rule(logcosh(x::Real), tanh(x)) ChainRulesCore.@scalar_rule(log1psq(x::Real), (2 * x / (1 + x^2),)) ChainRulesCore.@scalar_rule(log1pexp(x::Real), (logistic(x),)) ChainRulesCore.@scalar_rule(log1mexp(x::Real), (-exp(x - Ω),)) diff --git a/src/with_logabsdet_jacobian.jl b/src/with_logabsdet_jacobian.jl index d26a766d..8b86c26d 100644 --- a/src/with_logabsdet_jacobian.jl +++ b/src/with_logabsdet_jacobian.jl @@ -27,3 +27,11 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logistic), x::Real) y = logistic(x) y, log(y * (1 - y)) end + +function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logcosh), x::Real) + abs_x = abs(x) + a = - 2 * abs_x + z = log1pexp(a) + y = abs_x + z - IrrationalConstants.logtwo + return y, log1mexp(a) - z +end diff --git a/test/basicfuns.jl b/test/basicfuns.jl index fa633a74..d6a7c81b 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -54,6 +54,22 @@ end @test logit(logistic(2)) ≈ 2.0 end +@testset "logcosh" begin + for x in (randn(), randn(Float32)) + @test @inferred(logcosh(x)) isa typeof(x) + @test logcosh(x) ≈ log(cosh(x)) + @test logcosh(-x) == logcosh(x) + end + + # special values + for x in (-Inf, Inf, -Inf32, Inf32) + @test @inferred(logcosh(x)) === oftype(x, Inf) + end + for x in (NaN, NaN32) + @test @inferred(logcosh(x)) === x + end +end + @testset "log1psq" begin @test iszero(log1psq(0.0)) @test log1psq(1.0) ≈ log1p(1.0) diff --git a/test/chainrules.jl b/test/chainrules.jl index 5760dc59..581f2982 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -31,6 +31,11 @@ test_rrule(logistic, x; rtol=1f-3, atol=1f-3) end + for x in (-randexp(), randexp()) + test_frule(logcosh, x) + test_rrule(logcosh, x) + end + # test all branches of `log1pexp` for x in (-20.9, 15.4, 41.5) test_frule(log1pexp, x) diff --git a/test/with_logabsdet_jacobian.jl b/test/with_logabsdet_jacobian.jl index 252fe17c..5ddf806d 100644 --- a/test/with_logabsdet_jacobian.jl +++ b/test/with_logabsdet_jacobian.jl @@ -16,4 +16,7 @@ ChangesOfVariables.test_with_logabsdet_jacobian(logistic, x, derivative) ChangesOfVariables.test_with_logabsdet_jacobian(logit, rand(), derivative) + + ChangesOfVariables.test_with_logabsdet_jacobian(logcosh, x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(logcosh, -x, derivative) end