diff --git a/Project.toml b/Project.toml
index 68d7366a9..32849f933 100644
--- a/Project.toml
+++ b/Project.toml
@@ -33,6 +33,7 @@ MacroTools = "0.5"
 NaNMath = "0.3"
 Requires = "1.1"
 SpecialFunctions = "0.10, 1.0"
+StatsFuns = "0.9.8"
 ZygoteRules = "0.2.1"
 julia = "1.3"
 
@@ -41,8 +42,9 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "StatsFuns", "Test"]
+test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
diff --git a/src/Zygote.jl b/src/Zygote.jl
index 614cd9a53..5c0d743fd 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -40,7 +40,7 @@ include("lib/forward.jl")
 include("lib/utils.jl")
 include("lib/range.jl")
 @init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")
-@init @require StatsFuns="4c63d2b9-4356-54db-8cca-17b64c39e42c" include("lib/statsfuns.jl")
+@init @require LogExpFunctions="2ab3a3ac-af41-5b50-aa03-7779005ae688" include("lib/logexpfunctions.jl")
 
 # we need to define this late, so that the genfuncs see lib.jl
 # Move using statements out of this file to help with sysimage building
diff --git a/src/lib/statsfuns.jl b/src/lib/logexpfunctions.jl
similarity index 97%
rename from src/lib/statsfuns.jl
rename to src/lib/logexpfunctions.jl
index 85916cae8..1e5e4c0b6 100644
--- a/src/lib/statsfuns.jl
+++ b/src/lib/logexpfunctions.jl
@@ -1,5 +1,4 @@
-import .StatsFuns
-using .StatsFuns: xlogx, xlogy, logistic, logit, log1psq, log1pexp,
+using .LogExpFunctions: xlogx, xlogy, logistic, logit, log1psq, log1pexp,
   logsumexp, logaddexp, logsubexp
 
 using Base.Broadcast: broadcasted
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 6f1e5a996..1b509d159 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1196,44 +1196,44 @@ end
   @test gradcheck(x -> muladd(x[1], x[2], x[3]), [2.0, 3.0, 5.0])
 end
 
-import StatsFuns
+import LogExpFunctions
 Zygote.refresh()
 
 @testset "xlogx" begin
-  @test gradcheck(x->2.5 * StatsFuns.xlogx(x[1]), [1.0])
-  @test gradcheck(x->2.5 * StatsFuns.xlogx(x[1]), [2.45])
-  @test gradtest(x -> StatsFuns.xlogx.(x), (3,3))
+  @test gradcheck(x->2.5 * LogExpFunctions.xlogx(x[1]), [1.0])
+  @test gradcheck(x->2.5 * LogExpFunctions.xlogx(x[1]), [2.45])
+  @test gradtest(x -> LogExpFunctions.xlogx.(x), (3,3))
 end
 
 @testset "xlogy" begin
-  @test gradcheck(x -> StatsFuns.xlogy(x[1], x[2]), [1.0, 2.0])
-  @test gradcheck(x -> StatsFuns.xlogy(x[1], x[2]), [0.0, 2.0])
-  @test gradtest((x,y) -> StatsFuns.xlogy.(x,y), (3,3), (3,3))
+  @test gradcheck(x -> LogExpFunctions.xlogy(x[1], x[2]), [1.0, 2.0])
+  @test gradcheck(x -> LogExpFunctions.xlogy(x[1], x[2]), [0.0, 2.0])
+  @test gradtest((x,y) -> LogExpFunctions.xlogy.(x,y), (3,3), (3,3))
 end
 
 @testset "logistic" begin
-  @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-5.0])
-  @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-1.0])
-  @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-eps()])
-  @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [0.0])
-  @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [eps()])
-  @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [1.0])
-  @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [5.0])
+  @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-5.0])
+  @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-1.0])
+  @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-eps()])
+  @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [0.0])
+  @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [eps()])
+  @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [1.0])
+  @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [5.0])
 end
 
 @testset "logit" begin
-  @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.1])
-  @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.3])
-  @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.5])
-  @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.7])
-  @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.9])
+  @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.1])
+  @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.3])
+  @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.5])
+  @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.7])
+  @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.9])
 end
 
 function test_log1pexp(T, xs)
   y = T(4.3)
   for x in xs
-    @test gradcheck(x->y * StatsFuns.log1pexp(x[1]), [x])
+    @test gradcheck(x->y * LogExpFunctions.log1pexp(x[1]), [x])
   end
 end
 
@@ -1249,43 +1249,43 @@ end
       test_log1pexp(Float64, [33.3, 33.3 + eps(), 100.0])
     end
   end
-  @test gradcheck(x->2.5 * StatsFuns.log1pexp(x[1]), [1.0])
-  @test gradcheck(x->2.5 * StatsFuns.log1pexp(x[1]), [2.45])
-  @test gradtest(x -> StatsFuns.log1pexp.(x), (3,3))
+  @test gradcheck(x->2.5 * LogExpFunctions.log1pexp(x[1]), [1.0])
+  @test gradcheck(x->2.5 * LogExpFunctions.log1pexp(x[1]), [2.45])
+  @test gradtest(x -> LogExpFunctions.log1pexp.(x), (3,3))
 end
 
 @testset "log1psq" begin
   rng = MersenneTwister(123456)
   @testset "Float64" begin
     for x in [-10.0, -5.0, -1.0, -eps(), 0.0, eps(), 1.0, 5.0, 10.0]
-      @test gradcheck(x->5.1 * StatsFuns.log1psq(x[1]), [x])
+      @test gradcheck(x->5.1 * LogExpFunctions.log1psq(x[1]), [x])
     end
   end
 end
 
 @testset "logaddexp" begin
-  @test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [1.0, 2.0])
-  @test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [1.0, -1.0])
-  @test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [-2.0, -3.0])
-  @test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [5.0, 5.0])
-  @test gradtest((x,y) -> StatsFuns.logaddexp.(x,y), (3,3), (3,3))
+  @test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [1.0, 2.0])
+  @test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [1.0, -1.0])
+  @test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [-2.0, -3.0])
+  @test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [5.0, 5.0])
+  @test gradtest((x,y) -> LogExpFunctions.logaddexp.(x,y), (3,3), (3,3))
 end
 
 @testset "logsubexp" begin
-  @test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [1.0, 2.0])
-  @test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [1.0, -1.0])
-  @test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [-2.0, -3.0])
-  @test gradtest((x,y) -> StatsFuns.logsubexp.(x,y), (3,3), (3,3))
+  @test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [1.0, 2.0])
+  @test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [1.0, -1.0])
+  @test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [-2.0, -3.0])
+  @test gradtest((x,y) -> LogExpFunctions.logsubexp.(x,y), (3,3), (3,3))
 end
 
 @testset "logsumexp" begin
   rng = MersenneTwister(123456)
   @testset "Float64" begin
-    @test gradtest(StatsFuns.logsumexp, randn(rng, 1))
-    @test gradtest(StatsFuns.logsumexp, randn(rng, 1, 1))
-    @test gradtest(StatsFuns.logsumexp, randn(rng, 3))
-    @test gradtest(StatsFuns.logsumexp, randn(rng, 3, 4, 5))
-    @test gradtest(x -> sum(StatsFuns.logsumexp(x; dims=1)), randn(rng, 4, 4))
+    @test gradtest(LogExpFunctions.logsumexp, randn(rng, 1))
+    @test gradtest(LogExpFunctions.logsumexp, randn(rng, 1, 1))
+    @test gradtest(LogExpFunctions.logsumexp, randn(rng, 3))
+    @test gradtest(LogExpFunctions.logsumexp, randn(rng, 3, 4, 5))
+    @test gradtest(x -> sum(LogExpFunctions.logsumexp(x; dims=1)), randn(rng, 4, 4))
   end
 end