From 8ab833259c2af3c54e7f8f5258240877393c27e8 Mon Sep 17 00:00:00 2001 From: cossio Date: Sat, 5 Mar 2022 00:12:29 +0100 Subject: [PATCH] xexpx, xexpy --- docs/src/index.md | 2 ++ src/LogExpFunctions.jl | 2 +- src/basicfuns.jl | 25 +++++++++++++++++++++++++ src/chainrules.jl | 9 +++++++++ test/basicfuns.jl | 24 ++++++++++++++++++++++++ test/chainrules.jl | 9 +++++++++ 6 files changed, 70 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 8035ea06..5f57a99d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -10,6 +10,8 @@ LogExpFunctions supports [`InverseFunctions.inverse`](https://github.com/JuliaMa xlogx xlogy xlog1py +xexpx +xexpy logistic logit logcosh diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index eaee068d..7a26fba0 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -9,7 +9,7 @@ import InverseFunctions import IrrationalConstants import LinearAlgebra -export xlogx, xlogy, xlog1py, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, +export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax, softmax!, logcosh diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 24301887..8292e107 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -44,6 +44,31 @@ function xlog1py(x::Number, y::Number) return iszero(x) && !isnan(y) ? zero(result) : result end +""" +$(SIGNATURES) + +Return `x * exp(x)` for `x > -Inf`, or zero if `x == -Inf`. + +```jldoctest +julia> xexpx(-Inf) +0.0 +``` +""" +xexpx(x::Real) = isfinite(x) ? x * exp(x) : exp(x) + +""" +$(SIGNATURES) + +Return `x * exp(y)` for `y > -Inf`, or zero if `y == -Inf`. + +```jldoctest +julia> xexpy(1.0, -Inf) +0.0 +``` +""" +xexpy(x::T, y::T) where {T<:Real} = isnan(x) || isfinite(y) ? x * exp(y) : exp(y) +xexpy(x::Real, y::Real) = xexpy(promote(x, y)...) + # The following bounds are precomputed versions of the following abstract # function, but the implicit interface for AbstractFloat doesn't uniformly # enforce that all floating point types implement nextfloat and prevfloat. diff --git a/src/chainrules.jl b/src/chainrules.jl index 2b97decd..9ebcea58 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -2,6 +2,15 @@ ChainRulesCore.@scalar_rule(xlogx(x::Real), (1 + log(x),)) ChainRulesCore.@scalar_rule(xlogy(x::Real, y::Real), (log(y), x / y,)) ChainRulesCore.@scalar_rule(xlog1py(x::Real, y::Real), (log1p(y), x / (1 + y),)) +ChainRulesCore.@scalar_rule( + xexpx(x::Real), + (iszero(x) ? one(Ω) : isfinite(x) ? Ω / x * (1 + x) : Ω,) +) +ChainRulesCore.@scalar_rule( + xexpy(x::Real, y::Real), + (iszero(x) ? oftype(Ω, exp(y)) : Ω / x, Ω) +) + ChainRulesCore.@scalar_rule(logistic(x::Real), (Ω * (1 - Ω),)) ChainRulesCore.@scalar_rule(logit(x::Real), (inv(x * (1 - x)),)) ChainRulesCore.@scalar_rule(logcosh(x::Real), tanh(x)) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index d6a7c81b..790246ef 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -44,6 +44,30 @@ @test iszero(xlog1py(0 + im * 0, -1 + im * Inf)) end +@testset "xexpx" begin + for x in (0, 0.0, 0f0, -Inf) + @test iszero(@inferred xexpx(x)) + end + @test isnan(@inferred xexpx(NaN)) + @test Inf == @inferred xexpx(Inf) + @test exp(1) ≈ @inferred xexpx(1) + @test 2exp(2) ≈ @inferred xexpx(2) +end + +@testset "xexpy" begin + @test iszero(@inferred xexpy(Inf, -Inf)) + @test isnan(@inferred xexpy(NaN, -Inf)) + @test isnan(@inferred xexpy(NaN, 1)) + @test isnan(@inferred xexpy(0, NaN)) + @test iszero(@inferred xexpy(0, -Inf)) + @test ℯ ≈ @inferred xexpy(1, 1) + @test iszero(@inferred xexpy(1., -Inf)) + @test 2exp(3) ≈ @inferred xexpy(2, 3) + for x = -10:10, y = -10:10 + @test x * exp(y) ≈ @inferred xexpy(x, y) + end +end + @testset "logistic & logit" begin @test logistic(2) ≈ 1.0 / (1.0 + exp(-2.0)) @test logistic(-750.0) === 0.0 diff --git a/test/chainrules.jl b/test/chainrules.jl index 581f2982..682a2432 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -13,6 +13,15 @@ end end + for x in -2.0:2.0 + test_frule(xexpx, x) + test_rrule(xexpx, x) + for y in -2.0:2.0 + test_frule(xexpy, x, y) + test_rrule(xexpy, x, y) + end + end + test_frule(logit, x) test_rrule(logit, x)