Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an import from ChainRules macro #168

Merged
merged 3 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "Tracker"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.33"
version = "0.2.34"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -27,6 +28,7 @@ TrackerPDMatsExt = "PDMats"

[compat]
Adapt = "3, 4"
ChainRulesCore = "1.23"
DiffRules = "1.4"
ForwardDiff = "0.10"
Functors = "0.3, 0.4"
Expand All @@ -45,8 +47,10 @@ Statistics = "1"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["PDMats", "Test"]
test = ["ChainRulesCore", "LinearAlgebra", "PDMats", "Test"]
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions src/Tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using MacroTools: @q, @forward

using DiffRules
using ForwardDiff
import ChainRulesCore as CRC
import LogExpFunctions
import NaNMath
import SpecialFunctions
Expand Down Expand Up @@ -71,6 +72,7 @@ end

include("idset.jl")
include("params.jl")
include("macros.jl")
include("lib/real.jl")
include("lib/array.jl")
include("back.jl")
Expand Down
60 changes: 9 additions & 51 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,59 +560,17 @@ dims)
return Y, dropout_back
end

# Send `depthwiseconv` through `track` whenever at least one of the two array
# arguments is a `TrackedArray`.
function depthwiseconv(x::TrackedArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...)
    return track(depthwiseconv, x, w, cdims; kw...)
end
function depthwiseconv(x::AbstractArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...)
    return track(depthwiseconv, x, w, cdims; kw...)
end
function depthwiseconv(x::TrackedArray, w::AbstractArray, cdims::DepthwiseConvDims; kw...)
    return track(depthwiseconv, x, w, cdims; kw...)
end

# Forward pass on the untracked data plus a pullback returning one entry per positional
# argument: the input gradient, the filter gradient, and `nothing` for `cdims`.
@grad function depthwiseconv(x, w, cdims::DepthwiseConvDims; kw...)
    primal = depthwiseconv(data(x), data(w), cdims; kw...)
    function depthwiseconv_pullback(Δ)
        ∇x = NNlib.∇depthwiseconv_data(data(Δ), data(w), cdims; kw...)
        ∇w = NNlib.∇depthwiseconv_filter(data(x), data(Δ), cdims; kw...)
        # NOTE(review): `nobacksies` is defined outside this view; the symbol looks like
        # a label naming the operation - confirm in back.jl.
        return nobacksies(:depthwiseconv, (∇x, ∇w, nothing))
    end
    return primal, depthwiseconv_pullback
end

# Send `conv` through `track` whenever at least one of the two array arguments is a
# `TrackedArray`.
function conv(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...)
    return track(conv, x, w, cdims; kw...)
end
function conv(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...)
    return track(conv, x, w, cdims; kw...)
end
function conv(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...)
    return track(conv, x, w, cdims; kw...)
end

# Forward pass on the untracked data plus a pullback returning one entry per positional
# argument: the input gradient, the filter gradient, and `nothing` for `cdims`.
@grad function conv(x, w, cdims::DenseConvDims; kw...)
    primal = conv(data(x), data(w), cdims; kw...)
    function conv_pullback(Δ)
        ∇x = NNlib.∇conv_data(data(Δ), data(w), cdims; kw...)
        ∇w = NNlib.∇conv_filter(data(x), data(Δ), cdims; kw...)
        # NOTE(review): `nobacksies` is defined outside this view; the symbol looks like
        # a label naming the operation - confirm in back.jl.
        return nobacksies(:conv, (∇x, ∇w, nothing))
    end
    return primal, conv_pullback
end

# Send `∇conv_data` through `track` whenever at least one of the two array arguments is
# a `TrackedArray`.
∇conv_data(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)
∇conv_data(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)
∇conv_data(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)

# Forward pass on the untracked data plus a pullback returning one entry per positional
# argument: the gradient for `y`, the gradient for `w`, and `nothing` for `cdims`.
@grad function ∇conv_data(y, w, cdims::DenseConvDims; kw...)
    return (
        ∇conv_data(data(y), data(w), cdims; kw...),
        Δ -> begin
            # The label names this operation, matching how the sibling rules label
            # themselves (`:depthwiseconv`, `:maxpool`, `:meanpool`); it was `:conv`.
            return nobacksies(:∇conv_data,
                (NNlib.conv(data.((Δ, w))..., cdims; kw...),
                 NNlib.∇conv_filter(data.((Δ, y))..., cdims; kw...),
                 nothing)
            )
        end
    )
end

maxpool(x::TrackedArray, pdims::PoolDims; kw...) = track(maxpool, x, pdims; kw...)

@grad function maxpool(x, pdims::PoolDims; kw...)
y = maxpool(data(x), pdims; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., pdims; kw...)), nothing)
# Generate the `@grad_from_chainrules` imports for the three convolution functions, once
# per argument-type combination in which at least one array argument is tracked.
for (XT, WT) in ((:TrackedArray, :TrackedArray),
                 (:AbstractArray, :TrackedArray),
                 (:TrackedArray, :AbstractArray))
    @eval @grad_from_chainrules depthwiseconv(::$XT, ::$WT, ::DepthwiseConvDims; kw...)
    @eval @grad_from_chainrules conv(::$XT, ::$WT, ::DenseConvDims; kw...)
    @eval @grad_from_chainrules ∇conv_data(::$XT, ::$WT, ::DenseConvDims; kw...)
end

# Send `meanpool` through `track` when its input is a `TrackedArray`.
function meanpool(x::TrackedArray, pdims::PoolDims; kw...)
    return track(meanpool, x, pdims; kw...)
end

# Forward pass on the untracked data plus a pullback returning the input gradient and
# `nothing` for `pdims`.
@grad function meanpool(x, pdims::PoolDims; kw...)
    pooled = meanpool(data(x), pdims; kw...)
    function meanpool_pullback(Δ)
        ∇x = NNlib.∇meanpool(data(Δ), data(pooled), data(x), pdims; kw...)
        return (nobacksies(:meanpool, ∇x), nothing)
    end
    return pooled, meanpool_pullback
end
# Import the pooling gradients through `@grad_from_chainrules`; only the array
# argument is tracked, `PoolDims` is passed along as-is.
@grad_from_chainrules maxpool(::TrackedArray, ::PoolDims; kw...)
@grad_from_chainrules meanpool(::TrackedArray, ::PoolDims; kw...)

# Broadcasting

Expand Down
91 changes: 91 additions & 0 deletions src/macros.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
    @grad_from_chainrules f(args...; kwargs...)

Import the reverse rule (`rrule`) that ChainRules defines for `f` into Tracker.

Supply the method signature whose rule should be imported. Annotate every argument you
want to differentiate with respect to with a tracked type (`Tracker.TrackedReal` or
`Tracker.TrackedArray`). For example, the rule of `f(x::Real, y::Array)` is imported
with:

    Tracker.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
    Tracker.@grad_from_chainrules f(x::TrackedReal, y::Array)
    Tracker.@grad_from_chainrules f(x::Real, y::TrackedArray)

Acceptable type annotations are `TrackedReal`, `TrackedArray`, `TrackedVector`, and
`TrackedMatrix`. These can have parameters like `TrackedArray{Float32}` and may be
qualified as `Tracker.TrackedArray`.

If the signature contains a `;`, the generated methods accept arbitrary keyword
arguments and forward all of them unchanged to `rrule`. Keyword names, types and
defaults written in the signature are not used by the generated code.

The macro throws at expansion time when `Tracker` is not bound in the calling module,
when its argument is not a call signature, or when no argument has a tracked type.
"""
macro grad_from_chainrules(fcall)
    # The generated code refers to `Tracker.…` by name, so that binding must exist in
    # the caller's module. Thrown explicitly (not via `@assert`) so the check cannot be
    # compiled out; the exception type is the one `@assert` would have produced.
    isdefined(__module__, :Tracker) || throw(AssertionError(
        "Tracker not found in module $__module__. Please load `Tracker.jl`."))
    Meta.isexpr(fcall, :call) && length(fcall.args) ≥ 2 ||
        error("`@grad_from_chainrules` has to be applied to a function signature")

    f = fcall.args[1]
    # A `;` in the signature puts an `Expr(:parameters, …)` right after the function
    # name. Its contents are never inspected: a fresh splat name stands in for all
    # keywords, so typed/defaulted keywords and an empty `f(x;)` are all accepted.
    has_kws = Meta.isexpr(fcall.args[2], :parameters)
    kws_var = has_kws ? gensym(:kwargs) : nothing
    rem_args = has_kws ? fcall.args[3:end] : fcall.args[2:end]

    # Give every positional argument a name so it can be forwarded.
    xs = map(rem_args) do x
        Meta.isexpr(x, :(::)) || return x
        length(x.args) == 1 && return :($(gensym())::$(x.args[1])) # ::T without var name
        @assert length(x.args) == 2
        return :($(x.args[1])::$(x.args[2])) # x::T
    end
    xs_untyped = map(xs) do x
        Meta.isexpr(x, :(::)) || return x
        return x.args[1]
    end

    # For each argument: the name to hand to `rrule`, plus (for tracked arguments only)
    # the statement that extracts the untracked data into that name.
    untrack_args = map(xs) do x
        Meta.isexpr(x, :(::)) || return (x, nothing)
        name, type = x.args
        type = __strip_type(type)
        type in (:TrackedArray, :TrackedVector, :TrackedMatrix, :TrackedReal) ||
            return (name, nothing)
        xdata = gensym(name)
        return xdata, :($(xdata) = $(Tracker.data)($(name)))
    end
    untrack_calls = filter(!isnothing, last.(untrack_args))
    isempty(untrack_calls) && throw(AssertionError("No tracked arguments found."))
    var_names = first.(untrack_args)

    f_sym = Meta.quot(Symbol(f))

    if !has_kws
        return esc(quote
            $(f)($(xs...)) = $(Tracker.track)($(f), $(xs_untyped...))
            function Tracker._forward(::typeof($(f)), $(xs...))
                $(untrack_calls...)
                y, pb_f = $(CRC.rrule)($(f), $(var_names...))
                ∇internal_generated = let pb_f = pb_f # Avoid Boxing
                    # Drop the tangent for `f` itself (first entry) and convert the rest.
                    Δ -> Tracker.nobacksies(
                        $(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end]))
                end
                return y, ∇internal_generated
            end
        end)
    end
    return esc(quote
        function $(f)($(xs...); $(kws_var)...)
            return $(Tracker.track)($(f), $(xs_untyped...); $(kws_var)...)
        end
        function Tracker._forward(::typeof($(f)), $(xs...); $(kws_var)...)
            $(untrack_calls...)
            y, pb_f = $(CRC.rrule)($(f), $(var_names...); $(kws_var)...)
            ∇internal_generated = let pb_f = pb_f # Avoid Boxing
                # Drop the tangent for `f` itself (first entry) and convert the rest.
                Δ -> Tracker.nobacksies(
                    $(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end]))
            end
            return y, ∇internal_generated
        end
    end)
end

# Convert one entry of an `rrule` pullback result:
#   * `NoTangent` / `ZeroTangent` become `nothing`,
#   * thunks are evaluated with `unthunk`,
#   * anything else is returned unchanged.
@inline __no_crctangent(::Union{CRC.NoTangent, CRC.ZeroTangent}) = nothing
@inline __no_crctangent(thunk::CRC.AbstractThunk) = CRC.unthunk(thunk)
@inline __no_crctangent(tangent) = tangent

# Reduce a type expression to its bare name.
#
# Three layers are peeled in order: type parameters (`T{...}` -> `T`), a qualifying
# prefix (`A.T` -> the quoted `T`), and the `QuoteNode` wrapper left by that
# qualification. Anything that matches none of these is returned untouched.
@inline function __strip_type(type)
    stripped = type
    if Meta.isexpr(stripped, :curly)
        stripped = stripped.args[1]     # T{P...} -> T
    end
    if Meta.isexpr(stripped, :(.))
        stripped = stripped.args[2]     # A.T -> QuoteNode(:T)
    end
    if stripped isa QuoteNode
        stripped = stripped.value       # QuoteNode(:T) -> :T
    end
    return stripped
end
183 changes: 183 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Adapted from https://github.com/JuliaDiff/Tracker.jl/blob/master/test/ChainRulesTests.jl
module ChainRulesTest # Run in isolated environment

using LinearAlgebra
using ChainRulesCore
using Tracker
using Test

# Marker type with no fields; it only exists to give `f` extra dispatch signatures.
struct MyStruct end

# All three methods compute the same value, `sum(4x .+ 1)`; a `MyStruct` in either
# position is ignored.
function f(::MyStruct, x)
    return sum(4x .+ 1)
end
function f(x, y::MyStruct)
    return sum(4x .+ 1)
end
function f(x)
    return sum(4x .+ 1)
end

# Reverse rule for `f(x)`. Returns the primal value and a pullback that reports a
# slope of 3 for every element of `x`.
function ChainRulesCore.rrule(::typeof(f), x)
    r = f(x)
    function back(d)
        #=
        The proper derivative of `f` is 4, but in order to
        check if `ChainRulesCore.rrule` had taken over the computation,
        we define a rrule that returns 3 as `f`'s derivative.

        After importing this rrule into Tracker, if we get 3
        rather than 4 when we compute the derivative of `f`, it means
        the importing mechanism works.
        =#
        return NoTangent(), fill(3 * d, size(x))
    end
    return r, back
end
# Reverse rule for `f(::MyStruct, x)`: slope 3 for `x`, no tangent for the function or
# for the `MyStruct` argument.
function ChainRulesCore.rrule(::typeof(f), ::MyStruct, x)
    primal = f(MyStruct(), x)
    pullback(d) = (NoTangent(), NoTangent(), fill(3 * d, size(x)))
    return primal, pullback
end
# Reverse rule for `f(x, ::MyStruct)`: slope 3 for `x`, no tangent for the function or
# for the `MyStruct` argument.
function ChainRulesCore.rrule(::typeof(f), x, ::MyStruct)
    primal = f(x, MyStruct())
    pullback(d) = (NoTangent(), fill(3 * d, size(x)), NoTangent())
    return primal, pullback
end

# Import the `rrule`s of `f` into Tracker, one signature at a time.
Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray)
# test arg type hygiene
# (`MyStruct` is local to this module, and one signature leaves the argument unnamed)
Tracker.@grad_from_chainrules f(::MyStruct, x::Tracker.TrackedArray)
Tracker.@grad_from_chainrules f(x::Tracker.TrackedArray, y::MyStruct)

# Two-argument test function; the true slope is 4 with respect to both `x` and `y`.
function g(x, y)
    return sum(4x .+ 4y)
end

# Reverse rule for `g(x, y)`. Returns the primal value and a pullback reporting a
# slope of 3 for `x` and 5 for `y`.
function ChainRulesCore.rrule(::typeof(g), x, y)
    r = g(x, y)
    function back(d)
        # same as above, use 3 and 5 as the derivatives
        # Each tangent takes the shape of its own argument (`y`'s was sized from `x`).
        return NoTangent(), fill(3 * d, size(x)), fill(5 * d, size(y))
    end
    return r, back
end

# Import the `rrule` of `g` for each combination in which at least one argument is a
# tracked array; an argument without annotation is not tracked.
Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y)
Tracker.@grad_from_chainrules g(x, y::Tracker.TrackedArray)
Tracker.@grad_from_chainrules g(x::Tracker.TrackedArray, y::Tracker.TrackedArray)

@testset "rrule in ChainRules and Tracker" begin
    ## ChainRules
    # function f: the hand-written rule must report slope 3
    x = rand(3, 3)
    y, pullback = ChainRulesCore.rrule(f, x)
    ∇ = pullback(1)
    @test y == f(x)
    @test ∇[2] == fill(3, size(x))

    # function g: slopes 3 and 5 for the first and second argument
    a, b = rand(3, 3), rand(3, 3)
    y, pullback = ChainRulesCore.rrule(g, a, b)
    ∇ = pullback(1)
    @test y == g(a, b)
    @test ∇[2] == fill(3, size(a))
    @test ∇[3] == fill(5, size(b))
end

@testset "custom struct input" begin
    x = rand(3, 3)

    # `MyStruct` in the first position: the tangent for `x` is the last entry
    y, pullback = ChainRulesCore.rrule(f, MyStruct(), x)
    ∇ = pullback(1)
    @test y == f(MyStruct(), x)
    @test ∇[3] == fill(3, size(x))

    # `MyStruct` in the last position: the tangent for `x` is the middle entry
    y, pullback = ChainRulesCore.rrule(f, x, MyStruct())
    ∇ = pullback(1)
    @test y == f(x, MyStruct())
    @test ∇[2] == fill(3, size(x))
end

### Functions with varargs and kwargs
# Varargs
# Test function with trailing varargs: the summed extras are added to every element
# of `4x` before the total is taken.
function f_vararg(x, args...)
    return sum(4x .+ sum(args))
end

# Reverse rule for `f_vararg`: slope 3 for `x`, and one `NoTangent()` for the function
# and for each of the extra positional arguments.
function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
    primal = f_vararg(x, args...)
    function pullback(d)
        extras = ntuple(_ -> NoTangent(), length(args))
        return (NoTangent(), fill(3 * d, size(x)), extras...)
    end
    return primal, pullback
end

# Import the varargs rule; only `x` is tracked.
Tracker.@grad_from_chainrules f_vararg(x::Tracker.TrackedArray, args...)

@testset "Function with Varargs" begin
    x0 = rand(3, 3)
    ∇s = Tracker.gradient(x -> f_vararg(x, 1, 2, 3) + 2, x0)

    # Slope 3 comes from the imported rule.
    @test first(∇s) == fill(3, (3, 3))
end

# Vargs and kwargs
# Test function with varargs and keywords. The keyword `j` is read from `kwargs`, so it
# has to be supplied by the caller.
function f_kw(x, args...; k=1, kwargs...)
    return sum(4x .+ sum(args) .+ (k + kwargs[:j]))
end

# Reverse rule for `f_kw`: slope 3 for `x`, and one `NoTangent()` for the function and
# for each extra positional argument. Keywords are passed through to the primal call.
function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
    primal = f_kw(x, args...; k=k, kwargs...)
    function pullback(d)
        extras = ntuple(_ -> NoTangent(), length(args))
        return (NoTangent(), fill(3 * d, size(x)), extras...)
    end
    return primal, pullback
end

# Import the rule for the varargs + keyword signature; only `x` is tracked.
Tracker.@grad_from_chainrules f_kw(x::Tracker.TrackedArray, args...; k=1, kwargs...)

@testset "Function with Varargs and kwargs" begin
    x0 = rand(3, 3)
    ∇s = Tracker.gradient(x -> f_kw(x, 1, 2, 3; k=2, j=3) + 2, x0)

    # Slope 3 comes from the imported rule.
    @test first(∇s) == fill(3, size(x0))
end

### Mix @grad and @grad_from_chainrules

# `h` gets its gradient from Tracker's own `@grad`, not from a ChainRules rule.
function h(x)
    return 10x
end
function h(x::Tracker.TrackedArray)
    return Tracker.track(h, x)
end
Tracker.@grad function h(x)
    primal = h(Tracker.data(x))
    pullback(Δ) = (Δ * 7,) # use 7 as its derivative
    return primal, pullback
end

@testset "Tracker and ChainRules Mixed" begin
    # `g` uses an imported ChainRules rule, `h` a native Tracker `@grad`.
    composed(x) = g(x, h(x))
    x0 = rand(3, 3)
    ∇s = Tracker.gradient(composed, x0)
    @test first(∇s) == fill(38, size(x0)) # 38 = 3 + 5 * 7
end

### Isolated Scope
# Checks that a rule imported in one module is used when the function is called from
# another module, which imports only `f` and `Tracker`.
module IsolatedModuleForTestingScoping
using ChainRulesCore
using Tracker: Tracker, @grad_from_chainrules

function f(x)
    return sum(4x .+ 1)
end

function ChainRulesCore.rrule(::typeof(f), x)
    primal = f(x)
    # return a distinguishable but improper grad
    pullback(d) = (NoTangent(), fill(3 * d, size(x)))
    return primal, pullback
end

@grad_from_chainrules f(x::Tracker.TrackedArray)

module SubModule
using Test
using Tracker: Tracker
using ..IsolatedModuleForTestingScoping: f
@testset "rrule in Isolated Scope" begin
    x0 = rand(3, 3)
    ∇s = Tracker.gradient(x -> f(x) + 2, x0)

    @test first(∇s) == fill(3, size(x0))
end

end # end of SubModule
end # end of IsolatedModuleForTestingScoping

end
Loading
Loading