Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: test Zygote-generated gradients using ChainRulesTestUtils test_rrule #987

Closed
wants to merge 13 commits into from

Conversation

mzgubic
Copy link
Collaborator

@mzgubic mzgubic commented Jun 3, 2021

A real rrule_f example to go with JuliaDiff/ChainRulesTestUtils.jl#166.

To do:

  • improve docstrings for rrule_via_ad
  • add an example where rrule_via_ad is used to test a gradient that is difficult to evaluate without finite differencing

Some benchmarking shows that test_rrule is about 40 times slower than gradtest. Do we understand why?

julia> @testset "time" begin @btime test_rrule((x, W, b) -> identity.(W*x .+ b), rand(5), rand(2,5), rand(2); rrule_f=rrule_via_ad) end
  2.186 ms (7482 allocations: 574.81 KiB)
Test Summary: |  Pass  Total
time          | 30536  30536
Test.DefaultTestSet("time", Any[Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false)  …  Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false), Test.DefaultTestSet("test_rrule: #144 on Vector{Float64},Matrix{Float64},Vector{Float64}", Any[], 8, false, false)], 0, false, false)

julia> @testset "time" begin @btime @test gradtest((x, W, b) -> identity.(W*x .+ b), 5, (2,5), 2) end
  46.975 μs (488 allocations: 20.05 KiB)
Test Summary: |  Pass  Total
time          | 96738  96738
Test.DefaultTestSet("time", Any[], 96738, false, false)

end

"""
zygote2differential(x)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is essentially a generalisation of wrap_chainrules_input that preserves the information about the primal type.

I am keeping it separate because it might eventually be moved to ZygoteRules when we support ChainRules types internally in Zygote.

As an aside: That PR is nearly complete (only higher order AD is broken) but 6 months old and would need a massive rebase. The hope is that after we have the possibility of rules for higher order functions by calling back into AD, the number of rules in Zygote would be much smaller and the PR would become easier.

@mzgubic
Copy link
Collaborator Author

mzgubic commented Jun 25, 2021

superseeded by #990

@mzgubic mzgubic closed this Jun 25, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant