Skip to content

Commit

Permalink
Merge pull request #949 from devmotion/dw/logexpfunctions
Browse files Browse the repository at this point in the history
Define custom adjoints for LogExpFunctions instead of StatsFuns
  • Loading branch information
CarloLucibello authored May 2, 2021
2 parents 11d3c2d + 125f414 commit 5e2ba52
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 42 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ MacroTools = "0.5"
NaNMath = "0.3"
Requires = "1.1"
SpecialFunctions = "0.10, 1.0"
StatsFuns = "0.9.8"
ZygoteRules = "0.2.1"
julia = "1.3"

Expand All @@ -41,8 +42,9 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "StatsFuns", "Test"]
test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ include("lib/forward.jl")
include("lib/utils.jl")
include("lib/range.jl")
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")
@init @require StatsFuns="4c63d2b9-4356-54db-8cca-17b64c39e42c" include("lib/statsfuns.jl")
@init @require LogExpFunctions="2ab3a3ac-af41-5b50-aa03-7779005ae688" include("lib/logexpfunctions.jl")

# we need to define this late, so that the genfuncs see lib.jl
# Move using statements out of this file to help with sysimage building
Expand Down
3 changes: 1 addition & 2 deletions src/lib/statsfuns.jl → src/lib/logexpfunctions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import .StatsFuns
using .StatsFuns: xlogx, xlogy, logistic, logit, log1psq, log1pexp,
using .LogExpFunctions: xlogx, xlogy, logistic, logit, log1psq, log1pexp,
logsumexp, logaddexp, logsubexp
using Base.Broadcast: broadcasted

Expand Down
76 changes: 38 additions & 38 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1196,44 +1196,44 @@ end
@test gradcheck(x -> muladd(x[1], x[2], x[3]), [2.0, 3.0, 5.0])
end

import StatsFuns
import LogExpFunctions

Zygote.refresh()

@testset "xlogx" begin
@test gradcheck(x->2.5 * StatsFuns.xlogx(x[1]), [1.0])
@test gradcheck(x->2.5 * StatsFuns.xlogx(x[1]), [2.45])
@test gradtest(x -> StatsFuns.xlogx.(x), (3,3))
@test gradcheck(x->2.5 * LogExpFunctions.xlogx(x[1]), [1.0])
@test gradcheck(x->2.5 * LogExpFunctions.xlogx(x[1]), [2.45])
@test gradtest(x -> LogExpFunctions.xlogx.(x), (3,3))
end

@testset "xlogy" begin
@test gradcheck(x -> StatsFuns.xlogy(x[1], x[2]), [1.0, 2.0])
@test gradcheck(x -> StatsFuns.xlogy(x[1], x[2]), [0.0, 2.0])
@test gradtest((x,y) -> StatsFuns.xlogy.(x,y), (3,3), (3,3))
@test gradcheck(x -> LogExpFunctions.xlogy(x[1], x[2]), [1.0, 2.0])
@test gradcheck(x -> LogExpFunctions.xlogy(x[1], x[2]), [0.0, 2.0])
@test gradtest((x,y) -> LogExpFunctions.xlogy.(x,y), (3,3), (3,3))
end

@testset "logistic" begin
@test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-5.0])
@test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-1.0])
@test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-eps()])
@test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [0.0])
@test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [eps()])
@test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [1.0])
@test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [5.0])
@test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-5.0])
@test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-1.0])
@test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-eps()])
@test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [0.0])
@test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [eps()])
@test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [1.0])
@test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [5.0])
end

@testset "logit" begin
@test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.1])
@test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.3])
@test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.5])
@test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.7])
@test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.9])
@test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.1])
@test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.3])
@test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.5])
@test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.7])
@test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.9])
end

function test_log1pexp(T, xs)
y = T(4.3)
for x in xs
@test gradcheck(x->y * StatsFuns.log1pexp(x[1]), [x])
@test gradcheck(x->y * LogExpFunctions.log1pexp(x[1]), [x])
end
end

Expand All @@ -1249,43 +1249,43 @@ end
test_log1pexp(Float64, [33.3, 33.3 + eps(), 100.0])
end
end
@test gradcheck(x->2.5 * StatsFuns.log1pexp(x[1]), [1.0])
@test gradcheck(x->2.5 * StatsFuns.log1pexp(x[1]), [2.45])
@test gradtest(x -> StatsFuns.log1pexp.(x), (3,3))
@test gradcheck(x->2.5 * LogExpFunctions.log1pexp(x[1]), [1.0])
@test gradcheck(x->2.5 * LogExpFunctions.log1pexp(x[1]), [2.45])
@test gradtest(x -> LogExpFunctions.log1pexp.(x), (3,3))
end

@testset "log1psq" begin
rng = MersenneTwister(123456)
@testset "Float64" begin
for x in [-10.0, -5.0, -1.0, -eps(), 0.0, eps(), 1.0, 5.0, 10.0]
@test gradcheck(x->5.1 * StatsFuns.log1psq(x[1]), [x])
@test gradcheck(x->5.1 * LogExpFunctions.log1psq(x[1]), [x])
end
end
end

@testset "logaddexp" begin
@test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [1.0, 2.0])
@test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [1.0, -1.0])
@test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [-2.0, -3.0])
@test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [5.0, 5.0])
@test gradtest((x,y) -> StatsFuns.logaddexp.(x,y), (3,3), (3,3))
@test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [1.0, 2.0])
@test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [1.0, -1.0])
@test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [-2.0, -3.0])
@test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [5.0, 5.0])
@test gradtest((x,y) -> LogExpFunctions.logaddexp.(x,y), (3,3), (3,3))
end

@testset "logsubexp" begin
@test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [1.0, 2.0])
@test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [1.0, -1.0])
@test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [-2.0, -3.0])
@test gradtest((x,y) -> StatsFuns.logsubexp.(x,y), (3,3), (3,3))
@test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [1.0, 2.0])
@test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [1.0, -1.0])
@test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [-2.0, -3.0])
@test gradtest((x,y) -> LogExpFunctions.logsubexp.(x,y), (3,3), (3,3))
end

@testset "logsumexp" begin
rng = MersenneTwister(123456)
@testset "Float64" begin
@test gradtest(StatsFuns.logsumexp, randn(rng, 1))
@test gradtest(StatsFuns.logsumexp, randn(rng, 1, 1))
@test gradtest(StatsFuns.logsumexp, randn(rng, 3))
@test gradtest(StatsFuns.logsumexp, randn(rng, 3, 4, 5))
@test gradtest(x -> sum(StatsFuns.logsumexp(x; dims=1)), randn(rng, 4, 4))
@test gradtest(LogExpFunctions.logsumexp, randn(rng, 1))
@test gradtest(LogExpFunctions.logsumexp, randn(rng, 1, 1))
@test gradtest(LogExpFunctions.logsumexp, randn(rng, 3))
@test gradtest(LogExpFunctions.logsumexp, randn(rng, 3, 4, 5))
@test gradtest(x -> sum(LogExpFunctions.logsumexp(x; dims=1)), randn(rng, 4, 4))
end
end

Expand Down

0 comments on commit 5e2ba52

Please sign in to comment.