reactivate AD tests: mean functions #313

Merged
merged 30 commits into from
Apr 10, 2022

st-- (Member) commented Apr 6, 2022

Turns out there were a bunch of AD tests in the code. They were just commented out (and a bit broken). This is the work-in-progress attempt to reactivate them, and thereby resolve #311. Help welcome!

@st-- added the "help wanted" label on Apr 6, 2022
codecov bot commented Apr 6, 2022

Codecov Report

Merging #313 (41a01da) into master (d99311e) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #313      +/-   ##
==========================================
- Coverage   97.64%   97.62%   -0.02%     
==========================================
  Files          10       10              
  Lines         382      379       -3     
==========================================
- Hits          373      370       -3     
  Misses          9        9              
Impacted Files          Coverage Δ
src/mean_function.jl    100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@@ -1,5 +1,7 @@
abstract type MeanFunction end

# (m::MeanFunction)(x::AbstractVector) = _map_meanfunction(m, x)

Member Author

should we define something like this ?

Member

What would be the motivation? Generally, I don't like implicit mapping or broadcasting. IIRC the function only exists to work around Zygote AD issues.

In any case, IMO it does not belong in this PR.

Member Author

no, it doesn't belong in here, I just thought while people are thinking about mean functions we can consider it. will remove it before merging.
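
For context, the trade-off raised in this thread can be sketched roughly as follows. This is only an illustration and not the AbstractGPs implementation: `ToyMean`, `ToyConstMean`, and `map_mean` are hypothetical names standing in for `MeanFunction`, a concrete mean type, and `_map_meanfunction`.

```julia
# Hypothetical stand-ins; these are not the package's actual types.
abstract type ToyMean end

struct ToyConstMean{T<:Real} <: ToyMean
    c::T
end

# Explicit mapping: the caller states that the mean is applied to a
# collection of inputs and gets one value per input back.
map_mean(m::ToyConstMean, x::AbstractVector) = fill(m.c, length(x))

# Implicit mapping (the shape of the commented-out proposal): calling the
# object on a vector would silently map over it. Left commented out here,
# as in the diff under discussion.
# (m::ToyMean)(x::AbstractVector) = map_mean(m, x)

m = ToyConstMean(2.0)
x = [0.1, 0.2, 0.3]
y = map_mean(m, x)
```

With the explicit form, the call site shows that `x` is treated as a collection of inputs rather than as a single vector-valued input.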

test/mean_function.jl
# differentiable_mean_function_tests(f, randn(rng, N), x)
for x in [x1, xD, xD′]
@test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x)
differentiable_mean_function_tests(m, randn(rng, N), x)

Member Author

Should we just call

Suggested change
- differentiable_mean_function_tests(m, randn(rng, N), x)
+ differentiable_mean_function_tests(rng, m, x)

instead (and remove the y = ...), or otherwise remove that (currently unused) method definition of differentiable_mean_function_tests?

Member

Seems like it would simplify things a bit, so I'm in favour.

Member Author

which one? remove the (rng, m, x) method, or apply this suggestion?

Member

I would be inclined to retain the method that just requires a rng, since we can now be confident that it will generate an appropriate tangent due to the call to collect.
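
One plausible reading of the remark about `collect` is sketched below. It assumes the test helper materialises the mean-function output before drawing a random tangent of the same shape; `random_cotangent` is a hypothetical name, and a range is used here as a stand-in for a lazy array such as a `Fill`.

```julia
using Random

# Hypothetical helper: draw a random cotangent matching the shape of `y`.
# `collect` turns a lazy or structured array into a dense `Array`, so the
# random draw is a plain dense array regardless of how `y` is stored.
function random_cotangent(rng::AbstractRNG, y::AbstractArray)
    dense = collect(y)                  # e.g. a range becomes a Vector
    return randn(rng, size(dense)...)   # dense Float64 array of the same size
end

rng = MersenneTwister(123)
y_lazy = range(0.0, 1.0; length=4)      # lazy array, no stored elements
Δ = random_cotangent(rng, y_lazy)
```

Passing only `rng` keeps the call sites short, since the caller no longer has to construct a correctly sized `randn(rng, N)` by hand.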

@st-- changed the title from "reactivate AD tests" to "reactivate AD tests: mean functions" on Apr 6, 2022
@@ -1,4 +1,5 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

Member

The tested package is usually not part of test/Project.toml as Pkg adds it automatically (https://pkgdocs.julialang.org/v1/creating-packages/#Test-specific-dependencies-in-Julia-1.2-and-above) and otherwise one has to add and update compat entries:

Suggested change
- AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

Member Author

Ok, I added it because I wanted to be able to run some of the tests locally, and when I used julia --project=test it complained about AbstractGPs not being in the project

Member

Yeah, that's not an intended workflow and not officially supported. Tests are supposed to be run with Pkg.test.

Member Author

so what's the intended workflow for "I don't want to run all the tests because that takes a really long time, I just want to run the tests that I'm currently working on in this one file"?

Member Author

(who says what's intended, officially supported, and supposed to be done? where would I find out about that?)

Member

With the test setup in eg KernelFunctions (and AbstractGPs?) with all imports and utilities in runtests.jl one has to run runtests.jl anyway, it seems, or load the packages manually. In general, also the first two options might be a bit misleading if eg other tests mutate the RNG.

Member

@st-- I use TestEnv.jl (loaded at startup). In the working repo:

TestEnv.activate()

and then all test modules are loaded as well as the current repo.
But I agree it's a mess

Member Author

Ok, TestEnv seems to work, so I've reverted this change. How would I have been able to find this out by myself? Is this something we should/could document somewhere ?

Member Author

I like a lot of things about Julia (e.g. the Pkg manager is so much better than in python land). But I have to say I miss pytest. It makes it so easy to select which subset of tests you want to run. E.g. just the slow ones, or just the fast ones, or the ones in these files only...

I think it'd help if instead of stuffing all imports into runtests.jl, having them at the top of each individual test file. Then it'd be easy to just include() one of them.

Member

I guess one learns about such things by eg asking, googling, or attending JuliaCon. IIRC a while ago TestEnv was also discussed in the Turing slack (probably during or after JuliaCon).

I don't think we should document anything here. It's not a common task, it's not JuliaGP specific, and there's not a single preferred approach in the Julia ecosystem. In particular we don't care about how people test their PRs locally, if they run all tests or only parts of it (this can be tricky even with separate files +TestEnv, see below) and if they use TestEnv or not.

One nice thing about the test setup is that it is very flexible, since it's just one long script, possibly with includes. One can split it up or just use a single file, one can use multiple test sets or not, one can run some tests conditionally on environment variables... But of course the flexibility comes at a cost, e.g. when trying to run only a subset of tests.

In general, running subsets of tests seems most convenient if they are 1) separated in different files and/or with different switches such as environment variables (used eg in SciML but also the Turing ecosystem and even AbstractGPs) and 2) put into separate modules, eg with SafeTestsets, with explicit imports, test utilities, and generally avoiding any other leakage eg from mutating the RNG. Without 2) running tests separately may fail and, if all are passing, does not guarantee that running all tests together is successful.
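
The environment-variable switch mentioned above could look roughly like the sketch below. It is a generic illustration, not the setup used in AbstractGPs, SciML, or Turing: the variable name `TOY_TEST_GROUP` and the group names are hypothetical, and the inline `@test` lines stand in for `include`-ing per-group test files.

```julia
using Test

# Hypothetical variable name; the thread does not name one. Defaults to
# running every group when the variable is unset.
const GROUP = get(ENV, "TOY_TEST_GROUP", "All")

@testset "toy package" begin
    if GROUP in ("All", "mean_function")
        @testset "mean_function" begin
            # In a real suite this would be include("mean_function.jl").
            @test zero(1.0) == 0.0
        end
    end
    if GROUP in ("All", "util")
        @testset "util" begin
            @test length(collect(1:3)) == 3
        end
    end
end
```

Setting `TOY_TEST_GROUP=mean_function` in the shell before starting the tests would then run only the first group. As noted above, this does not by itself prevent leakage between groups, for example through shared RNG state.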

Comment on lines 30 to 33
# TODO should move into KernelFunctions.jl
Base.zero(x::ColVecs) = ColVecs(zero(x.X))
Base.zero(x::RowVecs) = RowVecs(zero(x.X))

Member

Yeah this should really not exist here.

Suggested change
- # TODO should move into KernelFunctions.jl
- Base.zero(x::ColVecs) = ColVecs(zero(x.X))
- Base.zero(x::RowVecs) = RowVecs(zero(x.X))

Member Author

👍 I'll open a PR in KernelFunctions

Member Author

I'll update it here once JuliaGaussianProcesses/KernelFunctions.jl#444 is merged

Member Author (st--), Apr 7, 2022

Though note that this is just for the tests. Above, it also defines

Base.zero(d::Dict) = Dict([(key, zero(val)) for (key, val) in d])
Base.zero(x::Array) = zero.(x)

(The zero(::Array) seems to actually be implemented in Base, so not sure why that's here; I'm assuming only the Dict() one is actually needed)
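
As a sketch of how the `Dict` case could be handled without extending `Base.zero` for a type the test code does not own, a test-local helper with its own name would do. `zero_like` is a hypothetical name and this is not what the PR ended up doing.

```julia
# Hypothetical test-local helper. Unlike `Base.zero(d::Dict) = ...`, it adds
# no method to a Base function for a Base type, so it cannot change behaviour
# for any other code loaded in the same session.
zero_like(x::Number) = zero(x)
zero_like(x::AbstractArray) = zero(x)
zero_like(d::AbstractDict) = Dict(key => zero_like(val) for (key, val) in d)

params = Dict(:c => 1.5, :w => [1.0, 2.0])
z = zero_like(params)
```

The helper recurses through nested dictionaries in the same way the pirated method would, but only code that calls `zero_like` explicitly is affected.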

st-- (Member Author) commented Apr 7, 2022

@devmotion you requested changes, but it's not clear to me what changes you're actually requesting, could you please clarify?
regarding the definition of zero(::ColVecs), see my reply #313 (comment) though I also opened JuliaGaussianProcesses/KernelFunctions.jl#444

devmotion (Member) commented Apr 7, 2022

I would like my comments to be addressed. Ie, no new commented out methods, no unrelated and unneeded changes, no new type piracy (doesn't matter if it's for tests or in the package), and understanding what causes the AD issues before reverting the FillArray changes (since possibly one can fix them in a better way).

devmotion (Member):

I unsubscribed, please find another reviewer.

@st-- requested review from willtebbutt and theogf on April 8, 2022 13:28
Base automatically changed from st/meanfunctiontest to master on April 8, 2022 16:23
@st-- dismissed devmotion’s stale review on April 8, 2022 20:04 (reason given: unsubscribed)

willtebbutt (Member) left a comment

I think this broadly looks good -- I've given some thoughts on how we might sort out the AD testing related problems. I would like to re-review once we've sorted out the AD testing issues, but I think it's basically there.

src/mean_function.jl
Comment on lines 33 to 36
function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ConstMean, x::AbstractVector)
map_ConstMean_pullback(Δ) = (NoTangent(), Tangent{ConstMean}(; c=sum(Δ)), ZeroTangent())
return _map_meanfunction(m, x), map_ConstMean_pullback
end

Member

Also, was this introduced to attempt to deal with the AD testing problems?

Member Author

yes, that was my attempt at working around the lack of FillArray rrules/projectors by doing it directly on the mean function call... do you think it should get removed again ?
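
To see why `sum(Δ)` is the cotangent for `c` in the pullback above: every output entry equals `c`, so the derivative of `dot(Δ, fill(c, n))` with respect to `c` is the sum of the entries of `Δ`. The sketch below checks this with a central finite difference using only the standard library. `const_mean_map` and `const_mean_pullback` are hypothetical stand-ins, not the package's `_map_meanfunction` or its `rrule`.

```julia
using LinearAlgebra

# Hypothetical stand-in for mapping a constant mean over n inputs.
const_mean_map(c::Real, n::Integer) = fill(c, n)

# Hand-written pullback in the spirit of the rrule above: given the output
# cotangent Δ, the cotangent of c is sum(Δ).
const_mean_pullback(Δ::AbstractVector) = sum(Δ)

# Central finite-difference check of c -> dot(Δ, const_mean_map(c, n)).
Δ = [0.5, -1.0, 2.0]
c, h = 1.3, 1e-6
fd = (dot(Δ, const_mean_map(c + h, 3)) - dot(Δ, const_mean_map(c - h, 3))) / (2h)
grad = const_mean_pullback(Δ)
```

The two values should agree up to floating-point error, which is the same kind of comparison an AD test performs against the AD backend's result.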

& revert FillArray changes

Co-authored-by: willtebbutt <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
st-- (Member Author) commented Apr 9, 2022

@willtebbutt thanks, looks like your collect() actually fixed the issues, and now it's all happy even with Zero and Fill. Also, I checked and neither of the two rrules defined in mean_function.jl for ZeroMean and ConstMean, respectively, are actually needed for the AD test to pass. Should we keep (/add) or remove them ?

Re #313 (comment) please let me know which way around you prefer.

willtebbutt (Member):

> @willtebbutt thanks, looks like your collect() actually fixed the issues, and now it's all happy even with Zero and Fill. Also, I checked and neither of the two rrules defined in mean_function.jl for ZeroMean and ConstMean, respectively, are actually needed for the AD test to pass. Should we keep (/add) or remove them ?

Excellent. Please remove the rules since they're now not needed.

st-- (Member Author) commented Apr 9, 2022

Will do.

What do you think of f1df8b5 ?

st-- (Member Author) commented Apr 9, 2022

> Please remove the rules since they're now not needed.

I'm wondering if they might make it a bit faster though (as they're shortcutting some backprop)?

willtebbutt (Member):

> I'm wondering if they might make it a bit faster though (as they're shortcutting some backprop)?

They really shouldn't make stuff faster in this case -- it's simple enough that Zygote ought to be able to infer this properly. In my view, the additional code complexity probably isn't worth it.

willtebbutt (Member):

> What do you think of f1df8b5 ?

I like it. Nice to remove some code duplication.

willtebbutt (Member) left a comment

LGTM

@st-- merged commit cd8f069 into master on Apr 10, 2022
@st-- deleted the st/ad_tests branch on April 10, 2022 09:09