From a36c63ecc36bafcd86dfd414cecb315d5e79db23 Mon Sep 17 00:00:00 2001 From: cossio Date: Sat, 5 Mar 2022 00:12:29 +0100 Subject: [PATCH] xexpx, xexpy --- Project.toml | 2 +- docs/src/index.md | 2 ++ src/LogExpFunctions.jl | 2 +- src/basicfuns.jl | 30 ++++++++++++++++++++++++++++++ src/chainrules.jl | 36 ++++++++++++++++++++++++++++++++++++ test/basicfuns.jl | 32 ++++++++++++++++++++++++++++++++ test/chainrules.jl | 16 ++++++++++++++++ test/runtests.jl | 1 + 8 files changed, 119 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 2b8bdde2..a5a1efa3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LogExpFunctions" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" authors = ["StatsFun.jl contributors, Tamas K. Papp "] -version = "0.3.6" +version = "0.3.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/index.md b/docs/src/index.md index 8035ea06..5f57a99d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -10,6 +10,8 @@ LogExpFunctions supports [`InverseFunctions.inverse`](https://github.com/JuliaMa xlogx xlogy xlog1py +xexpx +xexpy logistic logit logcosh diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index eaee068d..7a26fba0 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -9,7 +9,7 @@ import InverseFunctions import IrrationalConstants import LinearAlgebra -export xlogx, xlogy, xlog1py, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, +export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, softmax, softmax!, logcosh diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 24301887..e0ae2751 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -44,6 +44,36 @@ function xlog1py(x::Number, y::Number) return iszero(x) && !isnan(y) ? zero(result) : result end +""" +$(SIGNATURES) + +Return `x * exp(x)` for `x > -Inf`, or zero if `x == -Inf`. + +```jldoctest +julia> xexpx(-Inf) +0.0 +``` +""" +function xexpx(x::Real) + expx = exp(x) + return iszero(expx) ? expx : x * expx +end + +""" +$(SIGNATURES) + +Return `x * exp(y)` for `y > -Inf`, or zero if `y == -Inf`. + +```jldoctest +julia> xexpy(1.0, -Inf) +0.0 +``` +""" +function xexpy(x::Real, y::Real) + expy = exp(y) + return iszero(expy) && !isnan(x) ? zero(x * expy) : x * expy +end + # The following bounds are precomputed versions of the following abstract # function, but the implicit interface for AbstractFloat doesn't uniformly # enforce that all floating point types implement nextfloat and prevfloat. diff --git a/src/chainrules.jl b/src/chainrules.jl index 2b97decd..7c408188 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -2,6 +2,42 @@ ChainRulesCore.@scalar_rule(xlogx(x::Real), (1 + log(x),)) ChainRulesCore.@scalar_rule(xlogy(x::Real, y::Real), (log(y), x / y,)) ChainRulesCore.@scalar_rule(xlog1py(x::Real, y::Real), (log1p(y), x / (1 + y),)) +function ChainRulesCore.frule((_, Δx), ::typeof(xexpx), x::Real) + expx = exp(x) + if iszero(expx) + Ω = expx + ΔΩ = expx * Δx + else + Ω = x * expx + ΔΩ = (1 + x) * expx * Δx + end + return Ω, ΔΩ +end + +function ChainRulesCore.rrule(::typeof(xexpx), x::Real) + expx = exp(x) + Ω = iszero(expx) ? expx : x * expx + function xexpx_pullback(ΔΩ) + Δx = iszero(expx) ? expx * ΔΩ : (1 + x) * expx * ΔΩ + return (ChainRulesCore.NoTangent(), Δx) + end + return Ω, xexpx_pullback +end + +function ChainRulesCore.frule((_, Δx, Δy), ::typeof(xexpy), x::Real, y::Real) + expy = exp(y) + Ω = iszero(expy) && !isnan(x) ? zero(x * expy) : x * expy + ΔΩ = expy * Δx + Ω * Δy + return Ω, ΔΩ +end + +function ChainRulesCore.rrule(::typeof(xexpy), x::Real, y::Real) + expy = exp(y) + Ω = iszero(expy) && !isnan(x) ? zero(x * expy) : x * expy + xexpy_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), ΔΩ * expy, ΔΩ * Ω) + return Ω, xexpy_pullback +end + ChainRulesCore.@scalar_rule(logistic(x::Real), (Ω * (1 - Ω),)) ChainRulesCore.@scalar_rule(logit(x::Real), (inv(x * (1 - x)),)) ChainRulesCore.@scalar_rule(logcosh(x::Real), tanh(x)) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index d6a7c81b..251f1e5f 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -44,6 +44,38 @@ @test iszero(xlog1py(0 + im * 0, -1 + im * Inf)) end +@testset "xexpx" begin + for x in (false, 0, 0.0, 0f0, -Inf, -Inf32) + @test (@inferred xexpx(x)) === zero(exp(x)) + end + for x in (NaN16, NaN32, NaN64, Inf16, Inf32, Inf64) + @test (@inferred xexpx(x)) === x + end + for x in (1, true, 1.0, 1f0) + @test (@inferred xexpx(x)) === exp(x) + end + for T in (Int, Float32, Float64), x in T.(-2:2) + @test (@inferred xexpx(x)) === x * exp(x) + end +end + +@testset "xexpy" begin + for x in (0, 1, 1.0, 1f0, Inf, Inf32), y in (-Inf, -Inf32) + @test (@inferred xexpy(x, y)) === zero(x * exp(y)) + end + for x in (0, 1, 1.0, 1f0, Inf, Inf32, -Inf, -Inf32, NaN, NaN32), nan in (NaN, NaN32) + @test (@inferred xexpy(x, nan)) === oftype(x * exp(nan), NaN) + @test (@inferred xexpy(nan, x)) === oftype(nan * exp(x), NaN) + end + Ts = (Int, Float32, Float64) + for Tx in Ts, Ty in Ts, x = -Tx(2):Tx(2), y = -Ty(2):Ty(2) + @test (@inferred xexpy(x, y)) ≈ x * exp(y) + end + for x in (randn(), randn(Float32)) + @test xexpy(x, x) ≈ xexpx(x) + end +end + @testset "logistic & logit" begin @test logistic(2) ≈ 1.0 / (1.0 + exp(-2.0)) @test logistic(-750.0) === 0.0 diff --git a/test/chainrules.jl b/test/chainrules.jl index 581f2982..509481dc 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -13,6 +13,22 @@ end end + @testset "xexpx, xexpy" begin + # regular branch + test_scalar(xexpx, randn()) + test_frule(xexpy, randn(), randn()) + test_rrule(xexpy, randn(), randn()) + # special cases (manually since FiniteDifferences/ChainRulesTestUtils fails at -Inf) + @test @inferred(frule((NoTangent(), 1), xexpx, -Inf)) === (0.0, 0.0) + Ω, back = @inferred(rrule(xexpx, -Inf)) + @test Ω === 0.0 + @test back(randn()) === (NoTangent(), 0.0) + @test @inferred(frule((NoTangent(), 1, 1), xexpy, x, -Inf)) === (0.0, 0.0) + Ω, back = @inferred(ChainRulesCore.rrule(xexpy, x, -Inf)) + @test Ω === 0.0 + @test back(randn()) === (NoTangent(), 0.0, 0.0) + end + test_frule(logit, x) test_rrule(logit, x) diff --git a/test/runtests.jl b/test/runtests.jl index bbbe4a95..f4c9ea3c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using LogExpFunctions using ChainRulesTestUtils +using ChainRulesCore using ChangesOfVariables using InverseFunctions using OffsetArrays