From 1a99d7457d594225f23647d90a76cc0f28d31760 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 09:40:17 +0300 Subject: [PATCH 01/23] reactivate mean function AD tests --- src/mean_function.jl | 2 ++ test/Project.toml | 1 + test/mean_function.jl | 6 +++--- test/test_util.jl | 26 +++++++------------------- 4 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/mean_function.jl b/src/mean_function.jl index 6eacec36..f5389d8a 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -1,5 +1,7 @@ abstract type MeanFunction end +# (m::MeanFunction)(x::AbstractVector) = _map_meanfunction(m, x) + """ ZeroMean{T<:Real} <: MeanFunction diff --git a/test/Project.toml b/test/Project.toml index 57222240..9ac8a7c3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/test/mean_function.jl b/test/mean_function.jl index 4f0ebb1d..1e83fce0 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -10,7 +10,7 @@ for x in [x] @test AbstractGPs._map_meanfunction(f, x) == zeros(size(x)) - # differentiable_mean_function_tests(f, randn(rng, P), x) + differentiable_mean_function_tests(f, randn(rng, P), x) end # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but @@ -32,7 +32,7 @@ for x in [x] @test AbstractGPs._map_meanfunction(m, x) == fill(c, N) - # differentiable_mean_function_tests(m, randn(rng, N), x) + differentiable_mean_function_tests(m, randn(rng, N), x) end end @testset "CustomMean" begin @@ -42,6 +42,6 @@ f = CustomMean(foo_mean) @test AbstractGPs._map_meanfunction(f, x) == map(foo_mean, x) - # differentiable_mean_function_tests(f, randn(rng, N), x) + differentiable_mean_function_tests(f, randn(rng, N), x) end end diff --git a/test/test_util.jl b/test/test_util.jl index e2aa2c83..92fe8217 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -73,8 +73,9 @@ end Test _very_ basic consistency properties of the mean function `m`. """ function mean_function_tests(m::MeanFunction, x::AbstractVector) - @test AbstractGPs._map_meanfunction(m, x) isa AbstractVector - @test length(ew(m, x)) == length(x) + mean = AbstractGPs._map_meanfunction(m, x) + @test mean isa AbstractVector + @test length(mean) == length(x) end """ @@ -88,8 +89,8 @@ Ensure that the gradient w.r.t. the inputs of `MeanFunction` `m` are approximate """ function differentiable_mean_function_tests( m::MeanFunction, - ȳ::AbstractVector{<:Real}, - x::AbstractVector{<:Real}; + ȳ::AbstractVector, + x::AbstractVector; rtol=_rtol, atol=_atol, ) @@ -98,23 +99,10 @@ function differentiable_mean_function_tests( # Check adjoint. @assert length(ȳ) == length(x) - return adjoint_test(x -> ew(m, x), ȳ, x; rtol=rtol, atol=atol) + adjoint_test(x -> AbstractGPs._map_meanfunction(m, x), ȳ, x; rtol=rtol, atol=atol) + return nothing end -# function differentiable_mean_function_tests( -# m::MeanFunction, -# ȳ::AbstractVector{<:Real}, -# x::ColVecs{<:Real}; -# rtol=_rtol, -# atol=_atol, -# ) -# # Run forward tests. -# mean_function_tests(m, x) - -# @assert length(ȳ) == length(x) -# adjoint_test(X->ew(m, ColVecs(X)), ȳ, x.X; rtol=rtol, atol=atol) -# end - function differentiable_mean_function_tests( rng::AbstractRNG, m::MeanFunction, x::AbstractVector; rtol=_rtol, atol=_atol ) From caeffdc6115a6af9f6d973c279bcf422f9ed05b9 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 09:43:34 +0300 Subject: [PATCH 02/23] format --- test/test_util.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/test_util.jl b/test/test_util.jl index 92fe8217..0fc4f62e 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -88,11 +88,7 @@ end Ensure that the gradient w.r.t. the inputs of `MeanFunction` `m` are approximately correct. """ function differentiable_mean_function_tests( - m::MeanFunction, - ȳ::AbstractVector, - x::AbstractVector; - rtol=_rtol, - atol=_atol, + m::MeanFunction, ȳ::AbstractVector, x::AbstractVector; rtol=_rtol, atol=_atol ) # Run forward tests. mean_function_tests(m, x) From 9f6227fe50ebee76262e677c756384d66226398b Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 09:58:10 +0300 Subject: [PATCH 03/23] fix test --- test/mean_function.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 1e83fce0..7a32c3cb 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -1,16 +1,14 @@ @testset "mean_functions" begin @testset "ZeroMean" begin - P = 3 - Q = 2 - D = 4 - # X = ColVecs(randn(rng, D, P)) - x = randn(P) - x̄ = randn(P) + rng, D, N = MersenneTwister(123456), 5, 3 + # X = ColVecs(randn(rng, D, N)) + x = randn(rng, N) + x̄ = randn(rng, N) f = ZeroMean{Float64}() for x in [x] @test AbstractGPs._map_meanfunction(f, x) == zeros(size(x)) - differentiable_mean_function_tests(f, randn(rng, P), x) + differentiable_mean_function_tests(f, randn(rng, N), x) end # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but @@ -18,7 +16,7 @@ # for now. y, pb = rrule(AbstractGPs._map_meanfunction, f, x) @test y == AbstractGPs._map_meanfunction(f, x) - Δmap, Δf, Δx = pb(randn(P)) + Δmap, Δf, Δx = pb(randn(rng, N)) @test iszero(Δmap) @test iszero(Δf) @test iszero(Δx) From f13c902c2e33a2ca2f4b6526a498edb8454b1510 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 09:58:17 +0300 Subject: [PATCH 04/23] revert FillArray --- src/mean_function.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mean_function.jl b/src/mean_function.jl index f5389d8a..7979984b 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -12,7 +12,7 @@ struct ZeroMean{T<:Real} <: MeanFunction end """ This is an AbstractGPs-internal workaround for AD issues; ideally we would just extend Base.map """ -_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x)) +_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = zeros(T, length(x)) function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector) map_ZeroMean_pullback(Δ) = (NoTangent(), NoTangent(), ZeroTangent()) @@ -30,7 +30,7 @@ struct ConstMean{T<:Real} <: MeanFunction c::T end -_map_meanfunction(m::ConstMean, x::AbstractVector) = Fill(m.c, length(x)) +_map_meanfunction(m::ConstMean, x::AbstractVector) = fill(m.c, length(x)) """ CustomMean{Tf} <: MeanFunction From 6ee4c760705deccb75441c4d2d213e998643917c Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 10:07:32 +0300 Subject: [PATCH 05/23] extend mean function tests to ColVecs/RowVecs --- test/mean_function.jl | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 7a32c3cb..5c9dc60a 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -22,24 +22,31 @@ @test iszero(Δx) end @testset "ConstMean" begin - rng, D, N = MersenneTwister(123456), 5, 3 - # X = ColVecs(randn(rng, D, N)) - x = randn(rng, N) + rng, N, D = MersenneTwister(123456), 5, 3 + x1 = randn(rng, N) + xD = ColVecs(randn(rng, D, N)) + xD′ = RowVecs(randn(rng, N, D)) + c = randn(rng) m = ConstMean(c) - for x in [x] + for x in [x1, xD, xD′] @test AbstractGPs._map_meanfunction(m, x) == fill(c, N) differentiable_mean_function_tests(m, randn(rng, N), x) end end @testset "CustomMean" begin - rng, N, D = MersenneTwister(123456), 11, 2 - x = randn(rng, N) + rng, N, D = MersenneTwister(123456), 5, 3 + x1 = randn(rng, N) + xD = ColVecs(randn(rng, D, N)) + xD′ = RowVecs(randn(rng, N, D)) + foo_mean = x -> sum(abs2, x) - f = CustomMean(foo_mean) + m = CustomMean(foo_mean) - @test AbstractGPs._map_meanfunction(f, x) == map(foo_mean, x) - differentiable_mean_function_tests(f, randn(rng, N), x) + for x in [x1, xD, xD′] + @test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x) + differentiable_mean_function_tests(m, randn(rng, N), x) + end end end From 9c85f6d4faaea1671f28a2cf393230d0a932b8c6 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 10:10:22 +0300 Subject: [PATCH 06/23] extend ZeroMean tests too --- test/mean_function.jl | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 5c9dc60a..2fe16f3f 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -1,25 +1,26 @@ @testset "mean_functions" begin @testset "ZeroMean" begin - rng, D, N = MersenneTwister(123456), 5, 3 - # X = ColVecs(randn(rng, D, N)) - x = randn(rng, N) - x̄ = randn(rng, N) - f = ZeroMean{Float64}() + rng, N, D = MersenneTwister(123456), 5, 3 + x1 = randn(rng, N) + xD = ColVecs(randn(rng, D, N)) + xD′ = RowVecs(randn(rng, N, D)) - for x in [x] - @test AbstractGPs._map_meanfunction(f, x) == zeros(size(x)) - differentiable_mean_function_tests(f, randn(rng, N), x) - end + m = ZeroMean{Float64}() - # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but - # currently ChainRulesTestUtils isn't up to handling this, so this will have to do - # for now. - y, pb = rrule(AbstractGPs._map_meanfunction, f, x) - @test y == AbstractGPs._map_meanfunction(f, x) - Δmap, Δf, Δx = pb(randn(rng, N)) - @test iszero(Δmap) - @test iszero(Δf) - @test iszero(Δx) + for x in [x1, xD, xD′] + @test AbstractGPs._map_meanfunction(m, x) == zeros(size(x)) + differentiable_mean_function_tests(m, randn(rng, N), x) + + # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but + # currently ChainRulesTestUtils isn't up to handling this, so this will have to do + # for now. + y, pb = rrule(AbstractGPs._map_meanfunction, f, x) + @test y == AbstractGPs._map_meanfunction(f, x) + Δmap, Δf, Δx = pb(randn(rng, N)) + @test iszero(Δmap) + @test iszero(Δf) + @test iszero(Δx) + end end @testset "ConstMean" begin rng, N, D = MersenneTwister(123456), 5, 3 From 83ccb64347f3ab38786be1da3cb1923de6412c45 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 10:17:16 +0300 Subject: [PATCH 07/23] bugfix --- test/mean_function.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 2fe16f3f..3bda5334 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -14,8 +14,8 @@ # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but # currently ChainRulesTestUtils isn't up to handling this, so this will have to do # for now. - y, pb = rrule(AbstractGPs._map_meanfunction, f, x) - @test y == AbstractGPs._map_meanfunction(f, x) + y, pb = rrule(AbstractGPs._map_meanfunction, m, x) + @test y == AbstractGPs._map_meanfunction(m, x) Δmap, Δf, Δx = pb(randn(rng, N)) @test iszero(Δmap) @test iszero(Δf) From 29892935695cb2a999524d99216442e8b071acb1 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 10:17:32 +0300 Subject: [PATCH 08/23] add missing zero() definition for ColVecs/RowVecs --- test/test_util.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_util.jl b/test/test_util.jl index 0fc4f62e..76956ba8 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -27,6 +27,10 @@ end Base.zero(d::Dict) = Dict([(key, zero(val)) for (key, val) in d]) Base.zero(x::Array) = zero.(x) +# TODO should move into KernelFunctions.jl +Base.zero(x::ColVecs) = ColVecs(zero(x.X)) +Base.zero(x::RowVecs) = RowVecs(zero(x.X)) + # My version of isapprox function fd_isapprox(x_ad::Nothing, x_fd, rtol, atol) return fd_isapprox(x_fd, zero(x_fd), rtol, atol) From 80434548d72089c08a3a89105a3d5a78f99f81af Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 6 Apr 2022 10:30:12 +0300 Subject: [PATCH 09/23] patch bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 403e508b..d49580d7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AbstractGPs" uuid = "99985d1d-32ba-4be9-9821-2ec096f28918" authors = ["JuliaGaussianProcesses Team"] -version = "0.5.11" +version = "0.5.12" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 940777af0175a384b0833be3732d2948b2f1efed Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Apr 2022 09:49:30 +0300 Subject: [PATCH 10/23] revert Project.toml --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 9ac8a7c3..57222240 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,4 @@ [deps] -AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" From f3b736c20117628b9a82571e464f9e6bef41227a Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Apr 2022 13:44:33 +0300 Subject: [PATCH 11/23] mean function rrules --- src/mean_function.jl | 9 +++-- test/runtests.jl | 80 ++++++++++++++++++++++---------------------- 2 files changed, 47 insertions(+), 42 deletions(-) diff --git a/src/mean_function.jl b/src/mean_function.jl index 7979984b..ebf2c03c 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -15,7 +15,7 @@ This is an AbstractGPs-internal workaround for AD issues; ideally we would just _map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = zeros(T, length(x)) function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector) - map_ZeroMean_pullback(Δ) = (NoTangent(), NoTangent(), ZeroTangent()) + map_ZeroMean_pullback(Δ) = (NoTangent(), ZeroTangent(), ZeroTangent()) return _map_meanfunction(m, x), map_ZeroMean_pullback end @@ -32,6 +32,11 @@ end _map_meanfunction(m::ConstMean, x::AbstractVector) = fill(m.c, length(x)) +function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ConstMean, x::AbstractVector) + map_ConstMean_pullback(Δ) = (NoTangent(), Tangent{ConstMean}(; c=sum(Δ)), ZeroTangent()) + return _map_meanfunction(m, x), map_ConstMean_pullback +end + """ CustomMean{Tf} <: MeanFunction @@ -42,4 +47,4 @@ struct CustomMean{Tf} <: MeanFunction f::Tf end -_map_meanfunction(f::CustomMean, x::AbstractVector) = map(f.f, x) +_map_meanfunction(m::CustomMean, x::AbstractVector) = map(m.f, x) diff --git a/test/runtests.jl b/test/runtests.jl index 4166a913..060be5ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,19 +45,19 @@ include("test_util.jl") @testset "AbstractGPs" begin if GROUP == "All" || GROUP == "AbstractGPs" - @testset "util" begin - include("util/common_covmat_ops.jl") - include("util/plotting.jl") - end - println(" ") - @info "Ran util tests" + #@testset "util" begin + # include("util/common_covmat_ops.jl") + # include("util/plotting.jl") + #end + #println(" ") + #@info "Ran util tests" - @testset "abstract_gp" begin - include("abstract_gp.jl") - include("finite_gp_projection.jl") - end - println(" ") - @info "Ran abstract_gp tests" + #@testset "abstract_gp" begin + # include("abstract_gp.jl") + # include("finite_gp_projection.jl") + #end + #println(" ") + #@info "Ran abstract_gp tests" @testset "gp" begin include("mean_function.jl") @@ -66,37 +66,37 @@ include("test_util.jl") println(" ") @info "Ran gp tests" - @testset "posterior_gp" begin - include("exact_gpr_posterior.jl") - include("sparse_approximations.jl") - end - println(" ") - @info "Ran posterior_gp tests" + #@testset "posterior_gp" begin + # include("exact_gpr_posterior.jl") + # include("sparse_approximations.jl") + #end + #println(" ") + #@info "Ran posterior_gp tests" - include("latent_gp.jl") - println(" ") - @info "Ran latent_gp tests" + #include("latent_gp.jl") + #println(" ") + #@info "Ran latent_gp tests" - include("deprecations.jl") - println(" ") - @info "Ran deprecation tests" + #include("deprecations.jl") + #println(" ") + #@info "Ran deprecation tests" - @testset "doctests" begin - DocMeta.setdocmeta!( - AbstractGPs, - :DocTestSetup, - :(using AbstractGPs, Random, Documenter, LinearAlgebra); - recursive=true, - ) - doctest( - AbstractGPs; - doctestfilters=[ - r"{([a-zA-Z0-9]+,\s?)+[a-zA-Z0-9]+}", - r"(Array{[a-zA-Z0-9]+,\s?1}|\s?Vector{[a-zA-Z0-9]+})", - r"(Array{[a-zA-Z0-9]+,\s?2}|\s?Matrix{[a-zA-Z0-9]+})", - ], - ) - end + #@testset "doctests" begin + # DocMeta.setdocmeta!( + # AbstractGPs, + # :DocTestSetup, + # :(using AbstractGPs, Random, Documenter, LinearAlgebra); + # recursive=true, + # ) + # doctest( + # AbstractGPs; + # doctestfilters=[ + # r"{([a-zA-Z0-9]+,\s?)+[a-zA-Z0-9]+}", + # r"(Array{[a-zA-Z0-9]+,\s?1}|\s?Vector{[a-zA-Z0-9]+})", + # r"(Array{[a-zA-Z0-9]+,\s?2}|\s?Matrix{[a-zA-Z0-9]+})", + # ], + # ) + #end end if (GROUP == "All" || GROUP == "PPL") && VERSION >= v"1.5" From 750ef77685df0b279265a83380b46b0eea6e11c4 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Apr 2022 15:34:10 +0300 Subject: [PATCH 12/23] revert runtests.jl --- test/runtests.jl | 80 ++++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 060be5ee..4166a913 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,19 +45,19 @@ include("test_util.jl") @testset "AbstractGPs" begin if GROUP == "All" || GROUP == "AbstractGPs" - #@testset "util" begin - # include("util/common_covmat_ops.jl") - # include("util/plotting.jl") - #end - #println(" ") - #@info "Ran util tests" + @testset "util" begin + include("util/common_covmat_ops.jl") + include("util/plotting.jl") + end + println(" ") + @info "Ran util tests" - #@testset "abstract_gp" begin - # include("abstract_gp.jl") - # include("finite_gp_projection.jl") - #end - #println(" ") - #@info "Ran abstract_gp tests" + @testset "abstract_gp" begin + include("abstract_gp.jl") + include("finite_gp_projection.jl") + end + println(" ") + @info "Ran abstract_gp tests" @testset "gp" begin include("mean_function.jl") @@ -66,37 +66,37 @@ include("test_util.jl") println(" ") @info "Ran gp tests" - #@testset "posterior_gp" begin - # include("exact_gpr_posterior.jl") - # include("sparse_approximations.jl") - #end - #println(" ") - #@info "Ran posterior_gp tests" + @testset "posterior_gp" begin + include("exact_gpr_posterior.jl") + include("sparse_approximations.jl") + end + println(" ") + @info "Ran posterior_gp tests" - #include("latent_gp.jl") - #println(" ") - #@info "Ran latent_gp tests" + include("latent_gp.jl") + println(" ") + @info "Ran latent_gp tests" - #include("deprecations.jl") - #println(" ") - #@info "Ran deprecation tests" + include("deprecations.jl") + println(" ") + @info "Ran deprecation tests" - #@testset "doctests" begin - # DocMeta.setdocmeta!( - # AbstractGPs, - # :DocTestSetup, - # :(using AbstractGPs, Random, Documenter, LinearAlgebra); - # recursive=true, - # ) - # doctest( - # AbstractGPs; - # doctestfilters=[ - # r"{([a-zA-Z0-9]+,\s?)+[a-zA-Z0-9]+}", - # r"(Array{[a-zA-Z0-9]+,\s?1}|\s?Vector{[a-zA-Z0-9]+})", - # r"(Array{[a-zA-Z0-9]+,\s?2}|\s?Matrix{[a-zA-Z0-9]+})", - # ], - # ) - #end + @testset "doctests" begin + DocMeta.setdocmeta!( + AbstractGPs, + :DocTestSetup, + :(using AbstractGPs, Random, Documenter, LinearAlgebra); + recursive=true, + ) + doctest( + AbstractGPs; + doctestfilters=[ + r"{([a-zA-Z0-9]+,\s?)+[a-zA-Z0-9]+}", + r"(Array{[a-zA-Z0-9]+,\s?1}|\s?Vector{[a-zA-Z0-9]+})", + r"(Array{[a-zA-Z0-9]+,\s?2}|\s?Matrix{[a-zA-Z0-9]+})", + ], + ) + end end if (GROUP == "All" || GROUP == "PPL") && VERSION >= v"1.5" From a5365d688d79907a1894cca0b2b1f4cf009f951e Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Apr 2022 15:51:06 +0300 Subject: [PATCH 13/23] clean up mean_function tests without reactivating AD --- test/mean_function.jl | 66 +++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 4f0ebb1d..cd271c0d 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -1,47 +1,53 @@ @testset "mean_functions" begin @testset "ZeroMean" begin - P = 3 - Q = 2 - D = 4 - # X = ColVecs(randn(rng, D, P)) - x = randn(P) - x̄ = randn(P) - f = ZeroMean{Float64}() + rng, N, D = MersenneTwister(123456), 5, 3 + x1 = randn(rng, N) + xD = ColVecs(randn(rng, D, N)) + xD′ = RowVecs(randn(rng, N, D)) - for x in [x] - @test AbstractGPs._map_meanfunction(f, x) == zeros(size(x)) - # differentiable_mean_function_tests(f, randn(rng, P), x) - end + m = ZeroMean{Float64}() + + for x in [x1, xD, xD′] + @test AbstractGPs._map_meanfunction(m, x) == zeros(size(x)) + #differentiable_mean_function_tests(m, randn(rng, N), x) - # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but - # currently ChainRulesTestUtils isn't up to handling this, so this will have to do - # for now. - y, pb = rrule(AbstractGPs._map_meanfunction, f, x) - @test y == AbstractGPs._map_meanfunction(f, x) - Δmap, Δf, Δx = pb(randn(P)) - @test iszero(Δmap) - @test iszero(Δf) - @test iszero(Δx) + # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but + # currently ChainRulesTestUtils isn't up to handling this, so this will have to do + # for now. + y, pb = rrule(AbstractGPs._map_meanfunction, m, x) + @test y == AbstractGPs._map_meanfunction(m, x) + Δmap, Δf, Δx = pb(randn(rng, N)) + @test iszero(Δmap) + @test iszero(Δf) + @test iszero(Δx) + end end @testset "ConstMean" begin - rng, D, N = MersenneTwister(123456), 5, 3 - # X = ColVecs(randn(rng, D, N)) - x = randn(rng, N) + rng, N, D = MersenneTwister(123456), 5, 3 + x1 = randn(rng, N) + xD = ColVecs(randn(rng, D, N)) + xD′ = RowVecs(randn(rng, N, D)) + c = randn(rng) m = ConstMean(c) - for x in [x] + for x in [x1, xD, xD′] @test AbstractGPs._map_meanfunction(m, x) == fill(c, N) - # differentiable_mean_function_tests(m, randn(rng, N), x) + #differentiable_mean_function_tests(m, randn(rng, N), x) end end @testset "CustomMean" begin - rng, N, D = MersenneTwister(123456), 11, 2 - x = randn(rng, N) + rng, N, D = MersenneTwister(123456), 5, 3 + x1 = randn(rng, N) + xD = ColVecs(randn(rng, D, N)) + xD′ = RowVecs(randn(rng, N, D)) + foo_mean = x -> sum(abs2, x) - f = CustomMean(foo_mean) + m = CustomMean(foo_mean) - @test AbstractGPs._map_meanfunction(f, x) == map(foo_mean, x) - # differentiable_mean_function_tests(f, randn(rng, N), x) + for x in [x1, xD, xD′] + @test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x) + #differentiable_mean_function_tests(m, randn(rng, N), x) + end end end From 1b0916863ea81b71a10feef0a0734716fb075575 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Apr 2022 16:58:28 +0300 Subject: [PATCH 14/23] unify x...= --- test/mean_function.jl | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index cd271c0d..10940ebf 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -1,14 +1,15 @@ @testset "mean_functions" begin - @testset "ZeroMean" begin - rng, N, D = MersenneTwister(123456), 5, 3 - x1 = randn(rng, N) - xD = ColVecs(randn(rng, D, N)) - xD′ = RowVecs(randn(rng, N, D)) + rng = MersenneTwister(123456) + N, D = 5, 3 + x1 = randn(rng, N) + xD = ColVecs(randn(rng, D, N)) + xD′ = RowVecs(randn(rng, N, D)) + @testset "ZeroMean" begin m = ZeroMean{Float64}() for x in [x1, xD, xD′] - @test AbstractGPs._map_meanfunction(m, x) == zeros(size(x)) + @test AbstractGPs._map_meanfunction(m, x) == zeros(N) #differentiable_mean_function_tests(m, randn(rng, N), x) # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but @@ -22,12 +23,8 @@ @test iszero(Δx) end end - @testset "ConstMean" begin - rng, N, D = MersenneTwister(123456), 5, 3 - x1 = randn(rng, N) - xD = ColVecs(randn(rng, D, N)) - xD′ = RowVecs(randn(rng, N, D)) + @testset "ConstMean" begin c = randn(rng) m = ConstMean(c) @@ -36,12 +33,8 @@ #differentiable_mean_function_tests(m, randn(rng, N), x) end end - @testset "CustomMean" begin - rng, N, D = MersenneTwister(123456), 5, 3 - x1 = randn(rng, N) - xD = ColVecs(randn(rng, D, N)) - xD′ = RowVecs(randn(rng, N, D)) + @testset "CustomMean" begin foo_mean = x -> sum(abs2, x) m = CustomMean(foo_mean) From 37aeb0a52690db7fd7b12455076eab0aa6cdc64f Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Apr 2022 17:17:58 +0300 Subject: [PATCH 15/23] rename --- test/mean_function.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 10940ebf..a699ac29 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -2,13 +2,13 @@ rng = MersenneTwister(123456) N, D = 5, 3 x1 = randn(rng, N) - xD = ColVecs(randn(rng, D, N)) - xD′ = RowVecs(randn(rng, N, D)) + xD_colvecs = ColVecs(randn(rng, D, N)) + xD_rowvecs = RowVecs(randn(rng, N, D)) @testset "ZeroMean" begin m = ZeroMean{Float64}() - for x in [x1, xD, xD′] + for x in [x1, xD_colvecs, xD_rowvecs] @test AbstractGPs._map_meanfunction(m, x) == zeros(N) #differentiable_mean_function_tests(m, randn(rng, N), x) @@ -28,7 +28,7 @@ c = randn(rng) m = ConstMean(c) - for x in [x1, xD, xD′] + for x in [x1, xD_colvecs, xD_rowvecs] @test AbstractGPs._map_meanfunction(m, x) == fill(c, N) #differentiable_mean_function_tests(m, randn(rng, N), x) end @@ -38,7 +38,7 @@ foo_mean = x -> sum(abs2, x) m = CustomMean(foo_mean) - for x in [x1, xD, xD′] + for x in [x1, xD_colvecs, xD_rowvecs] @test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x) #differentiable_mean_function_tests(m, randn(rng, N), x) end From a7a4d8c7cd9a64a395c1287fef626ee22bf7d0c1 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 7 Apr 2022 17:39:33 +0300 Subject: [PATCH 16/23] remove code moved into KernelFunctions --- test/test_util.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_util.jl b/test/test_util.jl index 76956ba8..0fc4f62e 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -27,10 +27,6 @@ end Base.zero(d::Dict) = Dict([(key, zero(val)) for (key, val) in d]) Base.zero(x::Array) = zero.(x) -# TODO should move into KernelFunctions.jl -Base.zero(x::ColVecs) = ColVecs(zero(x.X)) -Base.zero(x::RowVecs) = RowVecs(zero(x.X)) - # My version of isapprox function fd_isapprox(x_ad::Nothing, x_fd, rtol, atol) return fd_isapprox(x_fd, zero(x_fd), rtol, atol) From 263a56c755c6153cabb02ab783c778a38c9989bf Mon Sep 17 00:00:00 2001 From: ST John Date: Fri, 8 Apr 2022 09:38:59 +0300 Subject: [PATCH 17/23] remove no longer needed test --- test/mean_function.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 6e3e1e8d..853b193e 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -11,16 +11,6 @@ for x in [x1, xD_colvecs, xD_rowvecs] @test AbstractGPs._map_meanfunction(m, x) == zeros(N) differentiable_mean_function_tests(m, randn(rng, N), x) - - # Manually verify the ChainRule. Really, this should employ FiniteDifferences, but - # currently ChainRulesTestUtils isn't up to handling this, so this will have to do - # for now. - y, pb = rrule(AbstractGPs._map_meanfunction, m, x) - @test y == AbstractGPs._map_meanfunction(m, x) - Δmap, Δf, Δx = pb(randn(rng, N)) - @test iszero(Δmap) - @test iszero(Δf) - @test iszero(Δx) end end From 97c52842a2f46b0089e30b086a77e1749b624533 Mon Sep 17 00:00:00 2001 From: st-- Date: Fri, 8 Apr 2022 23:04:56 +0300 Subject: [PATCH 18/23] Update src/mean_function.jl Co-authored-by: David Widmann --- src/mean_function.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mean_function.jl b/src/mean_function.jl index ebf2c03c..74efbef8 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -1,7 +1,5 @@ abstract type MeanFunction end -# (m::MeanFunction)(x::AbstractVector) = _map_meanfunction(m, x) - """ ZeroMean{T<:Real} <: MeanFunction From 01a7ac0f94c9c885e7fb92b9ea0f2e7f8c4df2cf Mon Sep 17 00:00:00 2001 From: st-- Date: Sat, 9 Apr 2022 22:34:52 +0300 Subject: [PATCH 19/23] Apply suggestions from code review & revert FillArray changes Co-authored-by: willtebbutt --- src/mean_function.jl | 4 ++-- test/test_util.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mean_function.jl b/src/mean_function.jl index 74efbef8..25efec00 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -10,7 +10,7 @@ struct ZeroMean{T<:Real} <: MeanFunction end """ This is an AbstractGPs-internal workaround for AD issues; ideally we would just extend Base.map """ -_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = zeros(T, length(x)) + _map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x)) function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector) map_ZeroMean_pullback(Δ) = (NoTangent(), ZeroTangent(), ZeroTangent()) @@ -28,7 +28,7 @@ struct ConstMean{T<:Real} <: MeanFunction c::T end -_map_meanfunction(m::ConstMean, x::AbstractVector) = fill(m.c, length(x)) +_map_meanfunction(m::ConstMean, x::AbstractVector) = Fill(m.c, length(x)) function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ConstMean, x::AbstractVector) map_ConstMean_pullback(Δ) = (NoTangent(), Tangent{ConstMean}(; c=sum(Δ)), ZeroTangent()) diff --git a/test/test_util.jl b/test/test_util.jl index 0fc4f62e..e32911a5 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -95,7 +95,7 @@ function differentiable_mean_function_tests( # Check adjoint. @assert length(ȳ) == length(x) - adjoint_test(x -> AbstractGPs._map_meanfunction(m, x), ȳ, x; rtol=rtol, atol=atol) + adjoint_test(x -> collect(AbstractGPs._map_meanfunction(m, x)), ȳ, x; rtol=rtol, atol=atol) return nothing end From 4b0a68372affed9611e7fa4ba82779a683df7581 Mon Sep 17 00:00:00 2001 From: st-- Date: Sat, 9 Apr 2022 22:36:26 +0300 Subject: [PATCH 20/23] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/mean_function.jl | 2 +- test/test_util.jl | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mean_function.jl b/src/mean_function.jl index 25efec00..143b2be1 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -10,7 +10,7 @@ struct ZeroMean{T<:Real} <: MeanFunction end """ This is an AbstractGPs-internal workaround for AD issues; ideally we would just extend Base.map """ - _map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x)) +_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x)) function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector) map_ZeroMean_pullback(Δ) = (NoTangent(), ZeroTangent(), ZeroTangent()) diff --git a/test/test_util.jl b/test/test_util.jl index e32911a5..e525b1fa 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -95,7 +95,9 @@ function differentiable_mean_function_tests( # Check adjoint. @assert length(ȳ) == length(x) - adjoint_test(x -> collect(AbstractGPs._map_meanfunction(m, x)), ȳ, x; rtol=rtol, atol=atol) + adjoint_test( + x -> collect(AbstractGPs._map_meanfunction(m, x)), ȳ, x; rtol=rtol, atol=atol + ) return nothing end From f1df8b501c0c0c624518b53f7fe0eb89a6839cc2 Mon Sep 17 00:00:00 2001 From: ST John Date: Sat, 9 Apr 2022 23:16:59 +0300 Subject: [PATCH 21/23] unify testcases --- test/mean_function.jl | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 853b193e..5e80a206 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -5,31 +5,22 @@ xD_colvecs = ColVecs(randn(rng, D, N)) xD_rowvecs = RowVecs(randn(rng, N, D)) - @testset "ZeroMean" begin - m = ZeroMean{Float64}() + zero_mean_testcase = (; mean_function=ZeroMean(), calc_expected=_ -> zeros(N)) - for x in [x1, xD_colvecs, xD_rowvecs] - @test AbstractGPs._map_meanfunction(m, x) == zeros(N) - differentiable_mean_function_tests(m, randn(rng, N), x) - end - end - - @testset "ConstMean" begin - c = randn(rng) - m = ConstMean(c) - - for x in [x1, xD_colvecs, xD_rowvecs] - @test AbstractGPs._map_meanfunction(m, x) == fill(c, N) - differentiable_mean_function_tests(m, randn(rng, N), x) - end - end + c = randn(rng) + const_mean_testcase = (; mean_function=ConstMean(c), calc_expected=_ -> fill(c, N)) - @testset "CustomMean" begin - foo_mean = x -> sum(abs2, x) - m = CustomMean(foo_mean) + foo_mean = x -> sum(abs2, x) + custom_mean_testcase = (; + mean_function=CustomMean(foo_mean), calc_expected=x -> map(foo_mean, x) + ) + @testset "$(typeof(testcase.mean_function))" for testcase in [ + zero_mean_testcase, const_mean_testcase, custom_mean_testcase + ] for x in [x1, xD_colvecs, xD_rowvecs] - @test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x) + m = testcase.mean_function + @test AbstractGPs._map_meanfunction(m, x) == testcase.calc_expected(x) differentiable_mean_function_tests(m, randn(rng, N), x) end end From 145091df2000e483e0c9328394661407a2f735ff Mon Sep 17 00:00:00 2001 From: ST John Date: Sat, 9 Apr 2022 23:20:45 +0300 Subject: [PATCH 22/23] remove rrules and ChainRulesCore --- Project.toml | 2 -- src/AbstractGPs.jl | 1 - src/mean_function.jl | 10 ---------- test/Project.toml | 2 -- test/runtests.jl | 1 - 5 files changed, 16 deletions(-) diff --git a/Project.toml b/Project.toml index d49580d7..7897c5e1 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["JuliaGaussianProcesses Team"] version = "0.5.12" [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" @@ -19,7 +18,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "1" Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13" IrrationalConstants = "0.1" diff --git a/src/AbstractGPs.jl b/src/AbstractGPs.jl index 9792cbea..acb3c577 100644 --- a/src/AbstractGPs.jl +++ b/src/AbstractGPs.jl @@ -1,6 +1,5 @@ module AbstractGPs -using ChainRulesCore using Distributions using FillArrays using LinearAlgebra diff --git a/src/mean_function.jl b/src/mean_function.jl index 143b2be1..e691b154 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -12,11 +12,6 @@ This is an AbstractGPs-internal workaround for AD issues; ideally we would just """ _map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x)) -function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector) - map_ZeroMean_pullback(Δ) = (NoTangent(), ZeroTangent(), ZeroTangent()) - return _map_meanfunction(m, x), map_ZeroMean_pullback -end - ZeroMean() = ZeroMean{Float64}() """ @@ -30,11 +25,6 @@ end _map_meanfunction(m::ConstMean, x::AbstractVector) = Fill(m.c, length(x)) -function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ConstMean, x::AbstractVector) - map_ConstMean_pullback(Δ) = (NoTangent(), Tangent{ConstMean}(; c=sum(Δ)), ZeroTangent()) - return _map_meanfunction(m, x), map_ConstMean_pullback -end - """ CustomMean{Tf} <: MeanFunction diff --git a/test/Project.toml b/test/Project.toml index 57222240..5051d678 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,4 @@ [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -14,7 +13,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ChainRulesCore = "1" Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" Documenter = "0.24, 0.25, 0.26, 0.27" FillArrays = "0.11, 0.12, 0.13" diff --git a/test/runtests.jl b/test/runtests.jl index 4166a913..415c0032 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,7 +23,6 @@ using AbstractGPs: TestUtils using Documenter -using ChainRulesCore using Distributions: MvNormal, PDMat, loglikelihood, Distributions using FillArrays using FiniteDifferences From 41a01da4bb9429276c8209079c34a5036d08e62f Mon Sep 17 00:00:00 2001 From: ST John Date: Sat, 9 Apr 2022 23:34:24 +0300 Subject: [PATCH 23/23] pass rng --- test/mean_function.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mean_function.jl b/test/mean_function.jl index 5e80a206..511ea687 100644 --- a/test/mean_function.jl +++ b/test/mean_function.jl @@ -21,7 +21,7 @@ for x in [x1, xD_colvecs, xD_rowvecs] m = testcase.mean_function @test AbstractGPs._map_meanfunction(m, x) == testcase.calc_expected(x) - differentiable_mean_function_tests(m, randn(rng, N), x) + differentiable_mean_function_tests(rng, m, x) end end end