Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reactivate AD tests: mean functions #313

Merged
merged 30 commits into from
Apr 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1a99d74
reactivate mean function AD tests
st-- Apr 6, 2022
caeffdc
format
st-- Apr 6, 2022
9f6227f
fix test
st-- Apr 6, 2022
f13c902
revert FillArray
st-- Apr 6, 2022
6ee4c76
extend mean function tests to ColVecs/RowVecs
st-- Apr 6, 2022
9c85f6d
extend ZeroMean tests too
st-- Apr 6, 2022
83ccb64
bugfix
st-- Apr 6, 2022
2989293
add missing zero() definition for ColVecs/RowVecs
st-- Apr 6, 2022
8043454
patch bump
st-- Apr 6, 2022
940777a
revert Project.toml
st-- Apr 7, 2022
f3b736c
mean function rrules
st-- Apr 7, 2022
750ef77
revert runtests.jl
st-- Apr 7, 2022
bc0ee6a
Merge branch 'master' of github.com:JuliaGaussianProcesses/AbstractGP…
st-- Apr 7, 2022
a5365d6
clean up mean_function tests without reactivating AD
st-- Apr 7, 2022
75f95cd
Merge branch 'st/meanfunctiontest' into st/ad_tests
st-- Apr 7, 2022
1b09168
unify x...=
st-- Apr 7, 2022
80b7813
Merge branch 'st/meanfunctiontest' into st/ad_tests
st-- Apr 7, 2022
37aeb0a
rename
st-- Apr 7, 2022
a7a4d8c
remove code moved into KernelFunctions
st-- Apr 7, 2022
81841c1
Merge branch 'st/ad_tests' of github.com:JuliaGaussianProcesses/Abstr…
st-- Apr 7, 2022
77f6ebf
Merge branch 'st/meanfunctiontest' into st/ad_tests
st-- Apr 8, 2022
263a56c
remove no longer needed test
st-- Apr 8, 2022
fd69656
Merge branch 'master' into st/ad_tests
st-- Apr 8, 2022
97c5284
Update src/mean_function.jl
st-- Apr 8, 2022
9e93b93
Merge branch 'master' into st/ad_tests
st-- Apr 9, 2022
01a7ac0
Apply suggestions from code review
st-- Apr 9, 2022
4b0a683
Apply suggestions from code review
st-- Apr 9, 2022
f1df8b5
unify testcases
st-- Apr 9, 2022
145091d
remove rrules and ChainRulesCore
st-- Apr 9, 2022
41a01da
pass rng
st-- Apr 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
name = "AbstractGPs"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
authors = ["JuliaGaussianProcesses Team"]
version = "0.5.11"
version = "0.5.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
Expand All @@ -19,7 +18,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "1"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
IrrationalConstants = "0.1"
Expand Down
1 change: 0 additions & 1 deletion src/AbstractGPs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module AbstractGPs

using ChainRulesCore
using Distributions
using FillArrays
using LinearAlgebra
Expand Down
7 changes: 1 addition & 6 deletions src/mean_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@ This is an AbstractGPs-internal workaround for AD issues; ideally we would just
"""
_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x))
st-- marked this conversation as resolved.
Show resolved Hide resolved

function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector)
map_ZeroMean_pullback(Δ) = (NoTangent(), NoTangent(), ZeroTangent())
return _map_meanfunction(m, x), map_ZeroMean_pullback
end

ZeroMean() = ZeroMean{Float64}()

"""
Expand All @@ -40,4 +35,4 @@ struct CustomMean{Tf} <: MeanFunction
f::Tf
end

_map_meanfunction(f::CustomMean, x::AbstractVector) = map(f.f, x)
_map_meanfunction(m::CustomMean, x::AbstractVector) = map(m.f, x)
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -14,7 +13,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "1"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
Documenter = "0.24, 0.25, 0.26, 0.27"
FillArrays = "0.11, 0.12, 0.13"
Expand Down
45 changes: 13 additions & 32 deletions test/mean_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,23 @@
xD_colvecs = ColVecs(randn(rng, D, N))
xD_rowvecs = RowVecs(randn(rng, N, D))

@testset "ZeroMean" begin
m = ZeroMean{Float64}()
zero_mean_testcase = (; mean_function=ZeroMean(), calc_expected=_ -> zeros(N))

for x in [x1, xD_colvecs, xD_rowvecs]
@test AbstractGPs._map_meanfunction(m, x) == zeros(N)
#differentiable_mean_function_tests(m, randn(rng, N), x)

# Manually verify the ChainRule. Really, this should employ FiniteDifferences, but
# currently ChainRulesTestUtils isn't up to handling this, so this will have to do
# for now.
y, pb = rrule(AbstractGPs._map_meanfunction, m, x)
@test y == AbstractGPs._map_meanfunction(m, x)
Δmap, Δf, Δx = pb(randn(rng, N))
@test iszero(Δmap)
@test iszero(Δf)
@test iszero(Δx)
end
end

@testset "ConstMean" begin
c = randn(rng)
m = ConstMean(c)

for x in [x1, xD_colvecs, xD_rowvecs]
@test AbstractGPs._map_meanfunction(m, x) == fill(c, N)
#differentiable_mean_function_tests(m, randn(rng, N), x)
end
end
c = randn(rng)
const_mean_testcase = (; mean_function=ConstMean(c), calc_expected=_ -> fill(c, N))

@testset "CustomMean" begin
foo_mean = x -> sum(abs2, x)
m = CustomMean(foo_mean)
foo_mean = x -> sum(abs2, x)
custom_mean_testcase = (;
mean_function=CustomMean(foo_mean), calc_expected=x -> map(foo_mean, x)
)

@testset "$(typeof(testcase.mean_function))" for testcase in [
zero_mean_testcase, const_mean_testcase, custom_mean_testcase
]
for x in [x1, xD_colvecs, xD_rowvecs]
@test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x)
#differentiable_mean_function_tests(m, randn(rng, N), x)
m = testcase.mean_function
@test AbstractGPs._map_meanfunction(m, x) == testcase.calc_expected(x)
differentiable_mean_function_tests(rng, m, x)
end
end
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ using AbstractGPs:
TestUtils

using Documenter
using ChainRulesCore
using Distributions: MvNormal, PDMat, loglikelihood, Distributions
using FillArrays
using FiniteDifferences
Expand Down
30 changes: 8 additions & 22 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ end
Test _very_ basic consistency properties of the mean function `m`.
"""
function mean_function_tests(m::MeanFunction, x::AbstractVector)
@test AbstractGPs._map_meanfunction(m, x) isa AbstractVector
@test length(ew(m, x)) == length(x)
mean = AbstractGPs._map_meanfunction(m, x)
@test mean isa AbstractVector
@test length(mean) == length(x)
end

"""
Expand All @@ -87,34 +88,19 @@ end
Ensure that the gradient w.r.t. the inputs of `MeanFunction` `m` are approximately correct.
"""
function differentiable_mean_function_tests(
m::MeanFunction,
ȳ::AbstractVector{<:Real},
x::AbstractVector{<:Real};
rtol=_rtol,
atol=_atol,
m::MeanFunction, ȳ::AbstractVector, x::AbstractVector; rtol=_rtol, atol=_atol
)
# Run forward tests.
mean_function_tests(m, x)

# Check adjoint.
@assert length(ȳ) == length(x)
return adjoint_test(x -> ew(m, x), ȳ, x; rtol=rtol, atol=atol)
adjoint_test(
x -> collect(AbstractGPs._map_meanfunction(m, x)), ȳ, x; rtol=rtol, atol=atol
)
return nothing
end

# function differentiable_mean_function_tests(
# m::MeanFunction,
# ȳ::AbstractVector{<:Real},
# x::ColVecs{<:Real};
# rtol=_rtol,
# atol=_atol,
# )
# # Run forward tests.
# mean_function_tests(m, x)

# @assert length(ȳ) == length(x)
# adjoint_test(X->ew(m, ColVecs(X)), ȳ, x.X; rtol=rtol, atol=atol)
# end

function differentiable_mean_function_tests(
rng::AbstractRNG, m::MeanFunction, x::AbstractVector; rtol=_rtol, atol=_atol
)
Expand Down