Test against Enzyme #318

Merged (22 commits) on Jul 11, 2024
Changes from 1 commit
6 changes: 6 additions & 0 deletions test/ad/utils.jl
@@ -2,6 +2,12 @@
const AD = get(ENV, "AD", "All")

function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
    for b in broken
        if !(b in (:ForwardDiff, :Zygote, :ReverseDiff, :Enzyme, :EnzymeForward, :EnzymeReverse))
            error("Unknown broken AD backend: $b")
        end
    end

    finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1]
    et = eltype(finitediff)
Member

Why is this needed? Shouldn't Enzyme return the correct types automatically?

Member Author

In forward mode it returns tuples, and if the gradient is empty, the result is a Tuple{}. This resulted in comparing an empty Float64[] to an empty Union{}[], which failed. See EnzymeAD/Enzyme.jl#1584
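
As a rough illustration of the element-type issue described above, the following minimal sketch uses an empty tuple as a stand-in for an empty forward-mode result (an assumption for illustration; it does not call Enzyme) and compares collecting it with and without an explicit element type:

```julia
# Stand-in for an empty forward-mode gradient (assumption: the real value
# would come from the AD backend, which is not called here).
empty_grad = ()

# Without an element type, collect has nothing to infer from except Tuple{}.
untyped = collect(empty_grad)

# With an explicit element type, the result matches the finite-difference side.
typed = collect(Float64, empty_grad)

println(typeof(untyped))
println(typeof(typed))
```

Passing the element type of the reference gradient, as the `et = eltype(finitediff)` line prepares for, keeps both sides of the comparison the same array type even when there are no entries.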

Member

But the gradient should never be empty? Such a test would be quite useless, so I assume we don't run into this special case here? So maybe a simple collect (without specifying the element type) would be sufficient?

Member Author

It's a corner case, but we ran into it here, when d==1:

for d in [1, 2, 5]

Member

I'd suggest not testing AD in the case d = 1, or just checking that the gradient is empty. We already handle this case in a special way, e.g. in:

test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false)
test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false)

Member Author

It passes now though, and I don't really see a downside to testing it? Good to know for instance that nothing crashes even if you hit this corner case.

Member

No, I would also prefer not to remove the test completely (even though I think it's of very limited use). As I wrote:

"just checking that the gradient is empty"

Member Author

That's what the current test is effectively doing, because when d = 1 finitediff returns an empty array. The eltype thing just makes sure that the check becomes Float64[] == Float64[], rather than Union{}[] == Float64[]. We could put in a specific case for d == 1 in the test file, but this seems like more work to me, because you need to make it cater to different AD backends and specify manually that the result should be empty.

Member

I meant something different: removing the eltype/collect completely and only adding a special case to this weird test of the CorrBijector, since we already have special cases for d = 1 there anyway.

Member Author

Yeah, that's how I understood you, but that would require something like adding another argument to test_ad called expect_empty that would skip comparing to finitediff and instead check that the gradient is empty (for all AD backends), and setting that argument to d == 1. To me that seems more complicated, with a bunch of if statements, compared to adding the one-liner enforcing eltype. I can do it if you prefer it, I just don't see the benefit.
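
For illustration only, the two options discussed in this thread could be sketched roughly as follows. Both helper names, their signatures, and the simplified element-wise tolerance check are hypothetical and are not taken from the PR; the real test_ad compares against FiniteDifferences and several AD backends.

```julia
# Hypothetical helpers for illustration; not part of the PR.

# Option A: coerce the AD result to the reference element type, then compare.
function check_with_eltype(ad_result, reference; atol=1e-6)
    grad = collect(eltype(reference), ad_result)
    return length(grad) == length(reference) &&
           all(abs.(grad .- reference) .<= atol)
end

# Option B: add a flag that skips the comparison when an empty gradient is expected.
function check_with_flag(ad_result, reference; expect_empty=false, atol=1e-6)
    if expect_empty
        # Only check emptiness; the reference is not consulted in this branch.
        return isempty(ad_result)
    end
    grad = collect(ad_result)
    return length(grad) == length(reference) &&
           all(abs.(grad .- reference) .<= atol)
end

# Empty case (as with d == 1 in the thread): both approaches accept it.
println(check_with_eltype((), Float64[]))
println(check_with_flag((), Float64[]; expect_empty=true))
```

Option A needs no extra argument at the call site, whereas Option B requires each caller to know in advance when the gradient will be empty, which is the trade-off raised above.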
