reactivate AD tests: mean functions #313

Merged · 30 commits · Apr 10, 2022

Commits
1a99d74
reactivate mean function AD tests
st-- Apr 6, 2022
caeffdc
format
st-- Apr 6, 2022
9f6227f
fix test
st-- Apr 6, 2022
f13c902
revert FillArray
st-- Apr 6, 2022
6ee4c76
extend mean function tests to ColVecs/RowVecs
st-- Apr 6, 2022
9c85f6d
extend ZeroMean tests too
st-- Apr 6, 2022
83ccb64
bugfix
st-- Apr 6, 2022
2989293
add missing zero() definition for ColVecs/RowVecs
st-- Apr 6, 2022
8043454
patch bump
st-- Apr 6, 2022
940777a
revert Project.toml
st-- Apr 7, 2022
f3b736c
mean function rrules
st-- Apr 7, 2022
750ef77
revert runtests.jl
st-- Apr 7, 2022
bc0ee6a
Merge branch 'master' of github.com:JuliaGaussianProcesses/AbstractGP…
st-- Apr 7, 2022
a5365d6
clean up mean_function tests without reactivating AD
st-- Apr 7, 2022
75f95cd
Merge branch 'st/meanfunctiontest' into st/ad_tests
st-- Apr 7, 2022
1b09168
unify x...=
st-- Apr 7, 2022
80b7813
Merge branch 'st/meanfunctiontest' into st/ad_tests
st-- Apr 7, 2022
37aeb0a
rename
st-- Apr 7, 2022
a7a4d8c
remove code moved into KernelFunctions
st-- Apr 7, 2022
81841c1
Merge branch 'st/ad_tests' of github.com:JuliaGaussianProcesses/Abstr…
st-- Apr 7, 2022
77f6ebf
Merge branch 'st/meanfunctiontest' into st/ad_tests
st-- Apr 8, 2022
263a56c
remove no longer needed test
st-- Apr 8, 2022
fd69656
Merge branch 'master' into st/ad_tests
st-- Apr 8, 2022
97c5284
Update src/mean_function.jl
st-- Apr 8, 2022
9e93b93
Merge branch 'master' into st/ad_tests
st-- Apr 9, 2022
01a7ac0
Apply suggestions from code review
st-- Apr 9, 2022
4b0a683
Apply suggestions from code review
st-- Apr 9, 2022
f1df8b5
unify testcases
st-- Apr 9, 2022
145091d
remove rrules and ChainRulesCore
st-- Apr 9, 2022
41a01da
pass rng
st-- Apr 9, 2022
6 changes: 4 additions & 2 deletions src/mean_function.jl
@@ -1,5 +1,7 @@
abstract type MeanFunction end

# (m::MeanFunction)(x::AbstractVector) = _map_meanfunction(m, x)
Member Author:

Should we define something like this?

Member:

What would be the motivation? Generally, I don't like implicit mapping or broadcasting. IIRC the function only exists to work around Zygote AD issues.

In any case, IMO it does not belong in this PR.

Member Author:

No, it doesn't belong in here; I just thought that while people are thinking about mean functions we could consider it. Will remove it before merging.
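For context, the question refers to the commented-out line in the diff above; uncommenting it would make mean functions directly callable on input vectors. A minimal sketch of what that would look like (illustrative only, not part of this PR):

```julia
# Hypothetical sketch: evaluating a MeanFunction on a vector of inputs
# would delegate to the internal mapping helper.
(m::MeanFunction)(x::AbstractVector) = _map_meanfunction(m, x)

# With this definition, users could write
#   m = ConstMean(2.0)
#   m(randn(5))
# instead of calling AbstractGPs._map_meanfunction(m, randn(5)) directly.
```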


st-- marked this conversation as resolved.
"""
ZeroMean{T<:Real} <: MeanFunction

@@ -10,7 +12,7 @@ struct ZeroMean{T<:Real} <: MeanFunction end
"""
This is an AbstractGPs-internal workaround for AD issues; ideally we would just extend Base.map
"""
_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x))
st-- marked this conversation as resolved.
_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = zeros(T, length(x))
st-- marked this conversation as resolved.

function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector)
map_ZeroMean_pullback(Δ) = (NoTangent(), NoTangent(), ZeroTangent())
@@ -28,7 +30,7 @@ struct ConstMean{T<:Real} <: MeanFunction
c::T
end

_map_meanfunction(m::ConstMean, x::AbstractVector) = Fill(m.c, length(x))
_map_meanfunction(m::ConstMean, x::AbstractVector) = fill(m.c, length(x))
st-- marked this conversation as resolved.

"""
CustomMean{Tf} <: MeanFunction
1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,4 +1,5 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Member:

The tested package is usually not part of test/Project.toml, as Pkg adds it automatically (https://pkgdocs.julialang.org/v1/creating-packages/#Test-specific-dependencies-in-Julia-1.2-and-above); otherwise one has to add and update compat entries:

Suggested change
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
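After applying the suggestion, test/Project.toml would keep only the genuine test dependencies (the UUIDs below are taken from the diff; a sketch):

```toml
# test/Project.toml after the suggested change: the tested package itself is
# omitted, since Pkg adds it automatically when tests are run via Pkg.test.
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
```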

Member Author:

OK, I added it because I wanted to be able to run some of the tests locally, and when I used julia --project=test it complained about AbstractGPs not being in the project.

Member:

Yeah, that's not an intended workflow and not officially supported. Tests are supposed to be run with Pkg.test.

Member Author:

So what's the intended workflow for "I don't want to run all the tests because that takes a really long time, I just want to run the tests that I'm currently working on in this one file"?

Member Author:

(Who says what's intended, officially supported, and supposed to be done? Where would I find out about that?)

Member:

With the test setup in e.g. KernelFunctions (and AbstractGPs?), with all imports and utilities in runtests.jl, it seems one has to run runtests.jl anyway, or load the packages manually. In general, the first two options might also be a bit misleading if e.g. other tests mutate the RNG.

Member:

@st-- I use TestEnv.jl (loaded at startup). In the working repo:

TestEnv.activate()

and then all test modules are loaded, as well as the current repo. But I agree it's a mess.
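The workflow described above can be sketched as follows (assuming TestEnv.jl is installed in the default/startup environment; the file path is illustrative):

```julia
# From the Julia REPL, with the package repository as the active project:
using TestEnv
TestEnv.activate()   # activates this package's test environment

# The test dependencies are now available, so a single test file can be run
# without going through the full Pkg.test suite:
using Test
include("test/mean_function.jl")
```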

Member Author:

OK, TestEnv seems to work, so I've reverted this change. How would I have been able to find this out by myself? Is this something we should/could document somewhere?

Member Author:

I like a lot of things about Julia (e.g. the Pkg manager is so much better than in Python land), but I have to say I miss pytest. It makes it so easy to select which subset of tests you want to run, e.g. just the slow ones, or just the fast ones, or only the ones in these files...

I think it'd help if, instead of stuffing all imports into runtests.jl, we had them at the top of each individual test file. Then it'd be easy to just include() one of them.
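A sketch of the layout suggested above, with each test file carrying its own imports so it can be include()d in isolation (file name and test contents illustrative):

```julia
# test/mean_function.jl (hypothetical layout): imports live at the top of the
# file rather than in runtests.jl, so `include("test/mean_function.jl")`
# works on its own once the test environment is active.
using AbstractGPs
using Random
using Test

@testset "ConstMean basics" begin
    rng = MersenneTwister(123456)
    m = ConstMean(2.0)
    x = randn(rng, 3)
    @test AbstractGPs._map_meanfunction(m, x) == fill(2.0, 3)
end
```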

Member:

I guess one learns about such things by e.g. asking, googling, or attending JuliaCon. IIRC a while ago TestEnv was also discussed in the Turing Slack (probably during or after JuliaCon).

I don't think we should document anything here. It's not a common task, it's not JuliaGP-specific, and there's no single preferred approach in the Julia ecosystem. In particular, we don't care about how people test their PRs locally, whether they run all tests or only parts of them (this can be tricky even with separate files + TestEnv, see below), and whether they use TestEnv or not.

One nice thing about the test setup is that it is very flexible, since it's just one long script, possibly with includes. One can split it up or use a single file, use multiple test sets or not, run some tests conditionally on environment variables... But of course the flexibility comes at a cost, e.g. when trying to run only a subset of tests.

In general, running subsets of tests seems most convenient if they are 1) separated into different files and/or gated by switches such as environment variables (used e.g. in SciML, but also the Turing ecosystem and even AbstractGPs) and 2) put into separate modules, e.g. with SafeTestsets, with explicit imports, test utilities, and generally no leakage from e.g. mutating the RNG. Without 2), running tests separately may fail and, even if all pass, does not guarantee that running all tests together is successful.
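The environment-variable switch in 1) might look like this in runtests.jl (the GROUP name and file split are illustrative; a pattern along these lines is used in the SciML and Turing ecosystems):

```julia
# runtests.jl sketch: gate groups of test files on an environment variable,
# so setting e.g. GROUP=MeanFunctions before Pkg.test runs only the
# mean-function tests, while the default runs everything.
const GROUP = get(ENV, "GROUP", "All")

if GROUP in ("All", "MeanFunctions")
    include("mean_function.jl")
end
if GROUP in ("All", "Util")
    include("util.jl")
end
```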

ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
41 changes: 23 additions & 18 deletions test/mean_function.jl
@@ -1,47 +1,52 @@
@testset "mean_functions" begin
@testset "ZeroMean" begin
P = 3
Q = 2
D = 4
# X = ColVecs(randn(rng, D, P))
x = randn(P)
x̄ = randn(P)
rng, D, N = MersenneTwister(123456), 5, 3
# X = ColVecs(randn(rng, D, N))
x = randn(rng, N)
x̄ = randn(rng, N)
f = ZeroMean{Float64}()

for x in [x]
@test AbstractGPs._map_meanfunction(f, x) == zeros(size(x))
# differentiable_mean_function_tests(f, randn(rng, P), x)
differentiable_mean_function_tests(f, randn(rng, N), x)
end

# Manually verify the ChainRule. Really, this should employ FiniteDifferences, but
# currently ChainRulesTestUtils isn't up to handling this, so this will have to do
# for now.
y, pb = rrule(AbstractGPs._map_meanfunction, f, x)
@test y == AbstractGPs._map_meanfunction(f, x)
Δmap, Δf, Δx = pb(randn(P))
Δmap, Δf, Δx = pb(randn(rng, N))
@test iszero(Δmap)
@test iszero(Δf)
@test iszero(Δx)
end
@testset "ConstMean" begin
rng, D, N = MersenneTwister(123456), 5, 3
# X = ColVecs(randn(rng, D, N))
x = randn(rng, N)
rng, N, D = MersenneTwister(123456), 5, 3
st-- marked this conversation as resolved.
x1 = randn(rng, N)
xD = ColVecs(randn(rng, D, N))
xD′ = RowVecs(randn(rng, N, D))

c = randn(rng)
m = ConstMean(c)

for x in [x]
for x in [x1, xD, xD′]
@test AbstractGPs._map_meanfunction(m, x) == fill(c, N)
# differentiable_mean_function_tests(m, randn(rng, N), x)
differentiable_mean_function_tests(m, randn(rng, N), x)
end
end
@testset "CustomMean" begin
rng, N, D = MersenneTwister(123456), 11, 2
x = randn(rng, N)
rng, N, D = MersenneTwister(123456), 5, 3
x1 = randn(rng, N)
xD = ColVecs(randn(rng, D, N))
xD′ = RowVecs(randn(rng, N, D))

foo_mean = x -> sum(abs2, x)
f = CustomMean(foo_mean)
m = CustomMean(foo_mean)
st-- marked this conversation as resolved.

@test AbstractGPs._map_meanfunction(f, x) == map(foo_mean, x)
# differentiable_mean_function_tests(f, randn(rng, N), x)
for x in [x1, xD, xD′]
@test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x)
differentiable_mean_function_tests(m, randn(rng, N), x)
Member Author:

Should we just call

Suggested change
differentiable_mean_function_tests(m, randn(rng, N), x)
differentiable_mean_function_tests(rng, m, x)

instead (and remove the y = ...), or otherwise remove that (currently unused) method definition of differentiable_mean_function_tests?

Member:

Seems like it would simplify things a bit, so I'm in favour.

Member Author:

Which one? Remove the (rng, m, x) method, or apply this suggestion?

Member:

I would be inclined to retain the method that just requires an rng, since we can now be confident that it will generate an appropriate tangent due to the call to collect.

end
end
end
28 changes: 6 additions & 22 deletions test/test_util.jl
@@ -73,8 +73,9 @@ end
Test _very_ basic consistency properties of the mean function `m`.
"""
function mean_function_tests(m::MeanFunction, x::AbstractVector)
@test AbstractGPs._map_meanfunction(m, x) isa AbstractVector
@test length(ew(m, x)) == length(x)
mean = AbstractGPs._map_meanfunction(m, x)
@test mean isa AbstractVector
@test length(mean) == length(x)
end

"""
@@ -87,34 +88,17 @@ end
Ensure that the gradient w.r.t. the inputs of `MeanFunction` `m` are approximately correct.
"""
function differentiable_mean_function_tests(
m::MeanFunction,
ȳ::AbstractVector{<:Real},
x::AbstractVector{<:Real};
rtol=_rtol,
atol=_atol,
m::MeanFunction, ȳ::AbstractVector, x::AbstractVector; rtol=_rtol, atol=_atol
)
# Run forward tests.
mean_function_tests(m, x)

# Check adjoint.
@assert length(ȳ) == length(x)
return adjoint_test(x -> ew(m, x), ȳ, x; rtol=rtol, atol=atol)
adjoint_test(x -> AbstractGPs._map_meanfunction(m, x), ȳ, x; rtol=rtol, atol=atol)
st-- marked this conversation as resolved.
return nothing
end

# function differentiable_mean_function_tests(
# m::MeanFunction,
# ȳ::AbstractVector{<:Real},
# x::ColVecs{<:Real};
# rtol=_rtol,
# atol=_atol,
# )
# # Run forward tests.
# mean_function_tests(m, x)

# @assert length(ȳ) == length(x)
# adjoint_test(X->ew(m, ColVecs(X)), ȳ, x.X; rtol=rtol, atol=atol)
# end

function differentiable_mean_function_tests(
rng::AbstractRNG, m::MeanFunction, x::AbstractVector; rtol=_rtol, atol=_atol
)