Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an import from ChainRules macro #168

Merged
merged 3 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "Tracker"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.33"
version = "0.2.34"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -27,6 +28,7 @@ TrackerPDMatsExt = "PDMats"

[compat]
Adapt = "3, 4"
ChainRulesCore = "1.23"
DiffRules = "1.4"
ForwardDiff = "0.10"
Functors = "0.3, 0.4"
Expand Down
2 changes: 2 additions & 0 deletions src/Tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using MacroTools: @q, @forward

using DiffRules
using ForwardDiff
import ChainRulesCore as CRC
import LogExpFunctions
import NaNMath
import SpecialFunctions
Expand Down Expand Up @@ -71,6 +72,7 @@ end

include("idset.jl")
include("params.jl")
include("macros.jl")
include("lib/real.jl")
include("lib/array.jl")
include("back.jl")
Expand Down
60 changes: 9 additions & 51 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,59 +560,17 @@ dims)
return Y, dropout_back
end

depthwiseconv(x::TrackedArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)

@grad depthwiseconv(x, w, cdims::DepthwiseConvDims; kw...) =
depthwiseconv(data(x), data(w), cdims; kw...),
Δ -> nobacksies(:depthwiseconv,
(NNlib.∇depthwiseconv_data(data.((Δ, w))..., cdims; kw...),
NNlib.∇depthwiseconv_filter(data.((x, Δ))..., cdims; kw...),
nothing))

conv(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)
conv(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)
conv(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)

@grad conv(x, w, cdims::DenseConvDims; kw...) =
conv(data(x), data(w), cdims; kw...),
Δ -> nobacksies(:conv,
(NNlib.∇conv_data(data.((Δ, w))..., cdims; kw...),
NNlib.∇conv_filter(data.((x, Δ))..., cdims; kw...),
nothing))

∇conv_data(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)
∇conv_data(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)
∇conv_data(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)

@grad function ∇conv_data(y, w, cdims::DenseConvDims; kw...)
return (
∇conv_data(data(y), data(w), cdims; kw...),
Δ -> begin
return nobacksies(:conv,
(NNlib.conv(data.((Δ, w))..., cdims; kw...),
NNlib.∇conv_filter(data.((Δ, y))..., cdims; kw...),
nothing)
)
end
)
end

maxpool(x::TrackedArray, pdims::PoolDims; kw...) = track(maxpool, x, pdims; kw...)

@grad function maxpool(x, pdims::PoolDims; kw...)
y = maxpool(data(x), pdims; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., pdims; kw...)), nothing)
for (xType, wType) in [(:TrackedArray, :TrackedArray), (:AbstractArray, :TrackedArray),
(:TrackedArray, :AbstractArray)]
@eval begin
@grad_from_chainrules depthwiseconv(::$xType, ::$wType, ::DepthwiseConvDims; kw...)
@grad_from_chainrules conv(::$xType, ::$wType, ::DenseConvDims; kw...)
@grad_from_chainrules ∇conv_data(::$xType, ::$wType, ::DenseConvDims; kw...)
end
end

meanpool(x::TrackedArray, pdims::PoolDims; kw...) = track(meanpool, x, pdims; kw...)


@grad function meanpool(x, pdims::PoolDims; kw...)
y = meanpool(data(x), pdims; kw...)
y, Δ -> (nobacksies(:meanpool, NNlib.∇meanpool(data.((Δ, y, x))..., pdims; kw...)), nothing)
end
@grad_from_chainrules maxpool(::TrackedArray, ::PoolDims; kw...)
@grad_from_chainrules meanpool(::TrackedArray, ::PoolDims; kw...)

# Broadcasting

Expand Down
91 changes: 91 additions & 0 deletions src/macros.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
@grad_from_chainrules f(args...; kwargs...)

The `@grad_from_chainrules` macro provides a way to import adjoints(rrule) defined in
ChainRules to Tracker. One must provide a method signature to import the corresponding
rrule. In the provided method signature, one should replace the types of arguments to which
one wants to take derivatives with respect with Tracker.TrackedReal and Tracker.TrackedArray
respectively. For example, we can import rrule of `f(x::Real, y::Array)`` like below:

Tracker.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
Tracker.@grad_from_chainrules f(x::TrackedReal, y::Array)
Tracker.@grad_from_chainrules f(x::Real, y::TrackedArray)

Acceptable type annotations are `TrackedReal`, `TrackedArray`, `TrackedVector`, and
`TrackedMatrix`. These can have parameters like `TrackedArray{Float32}`.
"""
macro grad_from_chainrules(fcall)
@assert isdefined(__module__, :Tracker) "Tracker not found in module $__module__. Please load `Tracker.jl`."
Meta.isexpr(fcall, :call) && length(fcall.args) ≥ 2 ||
error("`@grad_from_chainrules` has to be applied to a function signature")

f = fcall.args[1]
# Check if kwargs... splatting is present
kws_var = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[2].args[1].args[1] :
nothing
rem_args = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[3:end] :
fcall.args[2:end]
xs = map(rem_args) do x
Meta.isexpr(x, :(::)) || return x
length(x.args) == 1 && return :($(gensym())::$(x.args[1])) # ::T without var name
@assert length(x.args) == 2
return :($(x.args[1])::$(x.args[2])) # x::T
end
xs_untyped = map(xs) do x
Meta.isexpr(x, :(::)) || return x
return x.args[1]
end

untrack_args = map(enumerate(xs)) do (i, x)
Meta.isexpr(x, :(::)) || return (x, nothing)
name, type = x.args
type = __strip_type(type)
type in (:TrackedArray, :TrackedVector, :TrackedMatrix, :TrackedReal) || return (name, nothing)
xdata = gensym(name)
return xdata, :($(xdata) = $(Tracker.data)($(name)))
end
untrack_calls = filter(Base.Fix2(!==, nothing), last.(untrack_args))
@assert length(untrack_calls) > 0 "No tracked arguments found."
var_names = first.(untrack_args)

f_sym = Meta.quot(Symbol(f))

if kws_var === nothing
return esc(quote
$(f)($(xs...)) = $(Tracker.track)($(f), $(xs_untyped...))
function Tracker._forward(::typeof($(f)), $(xs...))
$(untrack_calls...)
y, pb_f = $(CRC.rrule)($(f), $(var_names...))
∇internal_generated = let pb_f = pb_f # Avoid Boxing
Δ -> return Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end]))
end
return y, ∇internal_generated
end
end)
end
return esc(quote
function $(f)($(xs...); $(kws_var)...)
return Tracker.track($(f), $(xs_untyped...); $(kws_var)...)
end
function Tracker._forward(::typeof($(f)), $(xs...); $(kws_var)...)
$(untrack_calls...)
y, pb_f = $(CRC.rrule)($(f), $(var_names...); $(kws_var)...)
∇internal_generated = let pb_f = pb_f # Avoid Boxing
Δ -> Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end]))
end
return y, ∇internal_generated
end
end)
end

@inline __no_crctangent(::CRC.NoTangent) = nothing
@inline __no_crctangent(::CRC.ZeroTangent) = nothing
@inline __no_crctangent(x::CRC.AbstractThunk) = CRC.unthunk(x)
@inline __no_crctangent(x) = x

@inline function __strip_type(type)
Meta.isexpr(type, :curly) && (type = type.args[1]) # Strip parameters from types
Meta.isexpr(type, :(.)) && (type = type.args[2]) # Strip Tracker from Tracker.<...>
type isa QuoteNode && (type = type.value) # Unwrap a QuoteNode
return type
end
193 changes: 193 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Adapted from https://github.com/JuliaDiff/Tracker.jl/blob/master/test/ChainRulesTests.jl
module ChainRulesTest # Run in isolatex environment

using LinearAlgebra
using ChainRulesCore
using Tracker
using Test

struct MyStruct end
f(::MyStruct, x) = sum(4x .+ 1)
f(x, y::MyStruct) = sum(4x .+ 1)
f(x) = sum(4x .+ 1)

rrule_f_singleargs = Ref(0)
rrule_f_mystruct_x = Ref(0)
rrule_f_x_mystruct = Ref(0)

function ChainRulesCore.rrule(::typeof(f), x)
rrule_f_singleargs[] += 1
r = f(x)
back(d) = NoTangent(), fill(4 * d, size(x))
return r, back
end
function ChainRulesCore.rrule(::typeof(f), ::MyStruct, x)
rrule_f_mystruct_x[] += 1
r = f(MyStruct(), x)
back(d) = NoTangent(), NoTangent(), fill(4 * d, size(x))
return r, back
end
function ChainRulesCore.rrule(::typeof(f), x, ::MyStruct)
rrule_f_x_mystruct[] += 1
r = f(x, MyStruct())
back(d) = NoTangent(), fill(4 * d, size(x)), NoTangent()
return r, back
end

Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray)
# test arg type hygiene
Tracker.@grad_from_chainrules f(::MyStruct, x::Tracker.TrackedArray)
Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray, y::MyStruct)

g(x, y) = sum(4x .+ 4y)

rrule_g_x_y = Ref(0)

function ChainRulesCore.rrule(::typeof(g), x, y)
rrule_g_x_y[] += 1
r = g(x, y)
back(d) = NoTangent(), fill(4 * d, size(x)), fill(4 * d, size(x))
return r, back
end

Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y)
Tracker.@grad_from_chainrules g(x, y::Tracker.TrackedArray)
Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y::Tracker.TrackedArray)

@testset "rrule in ChainRules and Tracker" begin
## ChainRules
# function f
input = rand(3, 3)
output, back = ChainRulesCore.rrule(f, input)
_, d = back(1)
@test output == f(input)
@test d == fill(4, size(input))
@test rrule_f_singleargs[] == 1
# function g
inputs = rand(3, 3), rand(3, 3)
output, back = ChainRulesCore.rrule(g, inputs...)
_, d1, d2 = back(1)
@test output == g(inputs...)
@test d1 == fill(4, size(inputs[1]))
@test d2 == fill(4, size(inputs[2]))
@test rrule_g_x_y[] == 1
end

@testset "custom struct input" begin
input = rand(3, 3)
output, back = ChainRulesCore.rrule(f, MyStruct(), input)
_, _, d = back(1)
@test output == f(MyStruct(), input)
@test d == fill(4, size(input))
@test rrule_f_mystruct_x[] == 1

output, back = ChainRulesCore.rrule(f, input, MyStruct())
_, d, _ = back(1)
@test output == f(input, MyStruct())
@test d == fill(4, size(input))
@test rrule_f_x_mystruct[] == 1
end

### Functions with varargs and kwargs
# Varargs
f_vararg(x, args...) = sum(4x .+ sum(args))

rrule_f_vararg = Ref(0)

function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
rrule_f_vararg[] += 1
r = f_vararg(x, args...)
back(d) = (NoTangent(), fill(4 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
return r, back
end

Tracker.@grad_from_chainrules f_vararg(x::Tracker.TrackedArray, args...)

@testset "Function with Varargs" begin
grads = Tracker.gradient(x -> f_vararg(x, 1, 2, 3) + 2, rand(3, 3))

@test grads[1] == fill(4, (3, 3))
@test rrule_f_vararg[] == 1
end

# Vargs and kwargs
f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))

rrule_f_kw = Ref(0)

function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
rrule_f_kw[] += 1
r = f_kw(x, args...; k=k, kwargs...)
back(d) = (NoTangent(), fill(4 * d, size(x)), ntuple(_ -> NoTangent(), length(args))...)
return r, back
end

Tracker.@grad_from_chainrules f_kw(x::Tracker.TrackedArray, args...; k=1, kwargs...)

@testset "Function with Varargs and kwargs" begin
inputs = rand(3, 3)
results = Tracker.gradient(x -> f_kw(x, 1, 2, 3; k=2, j=3) + 2, inputs)

@test results[1] == fill(4, size(inputs))
@test rrule_f_kw[] == 1
end

### Mix @grad and @grad_from_chainrules

h(x) = 10x
h(x::Tracker.TrackedArray) = Tracker.track(h, x)

grad_hcalls = Ref(0)

Tracker.@grad function h(x)
grad_hcalls[] += 1
xv = Tracker.data(x)
return h(xv), Δ -> (Δ * 10,) # use 7 asits derivatives
end

@testset "Tracker and ChainRules Mixed" begin
t(x) = g(x, h(x))
inputs = rand(3, 3)
results = Tracker.gradient(t, inputs)
@test results[1] == fill(44, size(inputs)) # 44 = 4 + 4 * 10
@test rrule_g_x_y[] == 2
@test grad_hcalls[] == 1
end

### Isolated Scope
module IsolatedModuleForTestingScoping

using ChainRulesCore, Test
using Tracker: Tracker, @grad_from_chainrules

f(x) = sum(4x .+ 1)

rrule_f_singleargs = Ref(0)

function ChainRulesCore.rrule(::typeof(f), x)
rrule_f_singleargs[] += 1
r = f(x)
back(d) = NoTangent(), fill(4 * d, size(x))
return r, back
end

@grad_from_chainrules f(x::Tracker.TrackedArray)

module SubModule
using Test
using Tracker: Tracker
using ..IsolatedModuleForTestingScoping: f, rrule_f_singleargs

@testset "rrule in Isolated Scope" begin
inputs = rand(3, 3)
results = Tracker.gradient(x -> f(x) + 2, inputs)

@test results[1] == fill(4, size(inputs))
@test rrule_f_singleargs[] == 1
end

end # end of SubModule

end # end of IsolatedModuleForTestingScoping

end
Loading
Loading