Skip to content

Commit

Permalink
Update tests to use a global counter instead of incorrect gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 18, 2024
1 parent be6f8d2 commit dd29f0f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 47 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@ Statistics = "1"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesCore", "LinearAlgebra", "PDMats", "Test"]
test = ["PDMats", "Test"]
98 changes: 54 additions & 44 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,26 @@ f(::MyStruct, x) = sum(4x .+ 1)
f(x, y::MyStruct) = sum(4x .+ 1)
f(x) = sum(4x .+ 1)

rrule_f_singleargs = Ref(0)
rrule_f_mystruct_x = Ref(0)
rrule_f_x_mystruct = Ref(0)

function ChainRulesCore.rrule(::typeof(f), x)
rrule_f_singleargs[] += 1
r = f(x)
function back(d)
#=
The proper derivative of `f` is 4, but in order to
check if `ChainRulesCore.rrule` had taken over the computation,
we define a rrule that returns 3 as `f`'s derivative.
After importing this rrule into Tracker, if we get 3
rather than 4 when we compute the derivative of `f`, it means
the importing mechanism works.
=#
return NoTangent(), fill(3 * d, size(x))
end
back(d) = NoTangent(), fill(4 * d, size(x))
return r, back
end
function ChainRulesCore.rrule(::typeof(f), ::MyStruct, x)
rrule_f_mystruct_x[] += 1
r = f(MyStruct(), x)
function back(d)
return NoTangent(), NoTangent(), fill(3 * d, size(x))
end
back(d) = NoTangent(), NoTangent(), fill(4 * d, size(x))
return r, back
end
function ChainRulesCore.rrule(::typeof(f), x, ::MyStruct)
rrule_f_x_mystruct[] += 1
r = f(x, MyStruct())
function back(d)
return NoTangent(), fill(3 * d, size(x)), NoTangent()
end
back(d) = NoTangent(), fill(4 * d, size(x)), NoTangent()
return r, back
end

Expand All @@ -49,12 +41,12 @@ Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray, y::MyStruct)

g(x, y) = sum(4x .+ 4y)

rrule_g_x_y = Ref(0)

function ChainRulesCore.rrule(::typeof(g), x, y)
rrule_g_x_y[] += 1
r = g(x, y)
function back(d)
# same as above, use 3 and 5 as the derivatives
return NoTangent(), fill(3 * d, size(x)), fill(5 * d, size(x))
end
back(d) = NoTangent(), fill(4 * d, size(x)), fill(4 * d, size(x))
return r, back
end

Expand All @@ -69,38 +61,43 @@ Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y::Tracker.TrackedArray
output, back = ChainRulesCore.rrule(f, input)
_, d = back(1)
@test output == f(input)
@test d == fill(3, size(input))
@test d == fill(4, size(input))
@test rrule_f_singleargs[] == 1
# function g
inputs = rand(3, 3), rand(3, 3)
output, back = ChainRulesCore.rrule(g, inputs...)
_, d1, d2 = back(1)
@test output == g(inputs...)
@test d1 == fill(3, size(inputs[1]))
@test d2 == fill(5, size(inputs[2]))
@test d1 == fill(4, size(inputs[1]))
@test d2 == fill(4, size(inputs[2]))
@test rrule_g_x_y[] == 1
end

@testset "custom struct input" begin
input = rand(3, 3)
output, back = ChainRulesCore.rrule(f, MyStruct(), input)
_, _, d = back(1)
@test output == f(MyStruct(), input)
@test d == fill(3, size(input))
@test d == fill(4, size(input))
@test rrule_f_mystruct_x[] == 1

output, back = ChainRulesCore.rrule(f, input, MyStruct())
_, d, _ = back(1)
@test output == f(input, MyStruct())
@test d == fill(3, size(input))
@test d == fill(4, size(input))
@test rrule_f_x_mystruct[] == 1
end

### Functions with varargs and kwargs
# Varargs
f_vararg(x, args...) = sum(4x .+ sum(args))

rrule_f_vararg = Ref(0)

function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
rrule_f_vararg[] += 1
r = f_vararg(x, args...)
function back(d)
return (NoTangent(), fill(3 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
end
back(d) = (NoTangent(), fill(4 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
return r, back
end

Expand All @@ -109,17 +106,19 @@ Tracker.@grad_from_chainrules f_vararg(x::Tracker.TrackedArray, args...)
@testset "Function with Varargs" begin
grads = Tracker.gradient(x -> f_vararg(x, 1, 2, 3) + 2, rand(3, 3))

@test grads[1] == fill(3, (3, 3))
@test grads[1] == fill(4, (3, 3))
@test rrule_f_vararg[] == 1
end

# Varargs and kwargs
f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))

rrule_f_kw = Ref(0)

function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
rrule_f_kw[] += 1
r = f_kw(x, args...; k=k, kwargs...)
function back(d)
return (NoTangent(), fill(3 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
end
back(d) = (NoTangent(), fill(4 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
return r, back
end

Expand All @@ -129,38 +128,46 @@ Tracker.@grad_from_chainrules f_kw(x::Tracker.TrackedArray, args...; k=1, kwargs
inputs = rand(3, 3)
results = Tracker.gradient(x -> f_kw(x, 1, 2, 3; k=2, j=3) + 2, inputs)

@test results[1] == fill(3, size(inputs))
@test results[1] == fill(4, size(inputs))
@test rrule_f_kw[] == 1
end

### Mix @grad and @grad_from_chainrules

h(x) = 10x
h(x::Tracker.TrackedArray) = Tracker.track(h, x)

grad_hcalls = Ref(0)

Tracker.@grad function h(x)
grad_hcalls[] += 1
xv = Tracker.data(x)
return h(xv), Δ -> (Δ * 7,) # use 7 as its derivative
return h(xv), Δ -> (Δ * 10,) # the derivative of h(x) = 10x is 10
end

@testset "Tracker and ChainRules Mixed" begin
t(x) = g(x, h(x))
inputs = rand(3, 3)
results = Tracker.gradient(t, inputs)
@test results[1] == fill(38, size(inputs)) # 38 = 3 + 5 * 7
@test results[1] == fill(44, size(inputs)) # 44 = 4 + 4 * 10
@test rrule_g_x_y[] == 2
@test grad_hcalls[] == 1
end

### Isolated Scope
module IsolatedModuleForTestingScoping
using ChainRulesCore

using ChainRulesCore, Test
using Tracker: Tracker, @grad_from_chainrules

f(x) = sum(4x .+ 1)

rrule_f_singleargs = Ref(0)

function ChainRulesCore.rrule(::typeof(f), x)
rrule_f_singleargs[] += 1
r = f(x)
function back(d)
# return a distinguishable but improper grad
return NoTangent(), fill(3 * d, size(x))
end
back(d) = NoTangent(), fill(4 * d, size(x))
return r, back
end

Expand All @@ -169,15 +176,18 @@ end
module SubModule
using Test
using Tracker: Tracker
using ..IsolatedModuleForTestingScoping: f
using ..IsolatedModuleForTestingScoping: f, rrule_f_singleargs

@testset "rrule in Isolated Scope" begin
inputs = rand(3, 3)
results = Tracker.gradient(x -> f(x) + 2, inputs)

@test results[1] == fill(3, size(inputs))
@test results[1] == fill(4, size(inputs))
@test rrule_f_singleargs[] == 1
end

end # end of SubModule

end # end of IsolatedModuleForTestingScoping

end

2 comments on commit dd29f0f

@ChrisRackauckas
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/105239

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.34 -m "<description of version>" dd29f0f8924a91c6e34e4a1cfbac8912aed1396a
git push origin v0.2.34

Please sign in to comment.