Skip to content

Commit

Permalink
Add an import from ChainRules macro
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 15, 2024
1 parent 596f1c5 commit ec33546
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 52 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "Tracker"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.33"
version = "0.2.34"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -27,6 +28,7 @@ TrackerPDMatsExt = "PDMats"

[compat]
Adapt = "3, 4"
ChainRulesCore = "1.23"
DiffRules = "1.4"
ForwardDiff = "0.10"
Functors = "0.3, 0.4"
Expand Down
2 changes: 2 additions & 0 deletions src/Tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using MacroTools: @q, @forward

using DiffRules
using ForwardDiff
import ChainRulesCore as CRC
import LogExpFunctions
import NaNMath
import SpecialFunctions
Expand Down Expand Up @@ -71,6 +72,7 @@ end

include("idset.jl")
include("params.jl")
include("macros.jl")
include("lib/real.jl")
include("lib/array.jl")
include("back.jl")
Expand Down
60 changes: 9 additions & 51 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,59 +560,17 @@ dims)
return Y, dropout_back
end

depthwiseconv(x::TrackedArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)
depthwiseconv(x::AbstractArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)
depthwiseconv(x::TrackedArray, w::AbstractArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)

@grad depthwiseconv(x, w, cdims::DepthwiseConvDims; kw...) =
depthwiseconv(data(x), data(w), cdims; kw...),
Δ -> nobacksies(:depthwiseconv,
(NNlib.∇depthwiseconv_data(data.((Δ, w))..., cdims; kw...),
NNlib.∇depthwiseconv_filter(data.((x, Δ))..., cdims; kw...),
nothing))

conv(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)
conv(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)
conv(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...)

@grad conv(x, w, cdims::DenseConvDims; kw...) =
conv(data(x), data(w), cdims; kw...),
Δ -> nobacksies(:conv,
(NNlib.∇conv_data(data.((Δ, w))..., cdims; kw...),
NNlib.∇conv_filter(data.((x, Δ))..., cdims; kw...),
nothing))

∇conv_data(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)
∇conv_data(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)
∇conv_data(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...)

@grad function ∇conv_data(y, w, cdims::DenseConvDims; kw...)
return (
∇conv_data(data(y), data(w), cdims; kw...),
Δ -> begin
return nobacksies(:conv,
(NNlib.conv(data.((Δ, w))..., cdims; kw...),
NNlib.∇conv_filter(data.((Δ, y))..., cdims; kw...),
nothing)
)
end
)
end

maxpool(x::TrackedArray, pdims::PoolDims; kw...) = track(maxpool, x, pdims; kw...)

@grad function maxpool(x, pdims::PoolDims; kw...)
y = maxpool(data(x), pdims; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., pdims; kw...)), nothing)
for (xType, wType) in [(:TrackedArray, :TrackedArray), (:AbstractArray, :TrackedArray),
(:TrackedArray, :AbstractArray)]
@eval begin
@grad_from_chainrules depthwiseconv(::$xType, ::$wType, ::DepthwiseConvDims; kw...)
@grad_from_chainrules conv(::$xType, ::$wType, ::DenseConvDims; kw...)
@grad_from_chainrules ∇conv_data(::$xType, ::$wType, ::DenseConvDims; kw...)
end
end

meanpool(x::TrackedArray, pdims::PoolDims; kw...) = track(meanpool, x, pdims; kw...)


@grad function meanpool(x, pdims::PoolDims; kw...)
y = meanpool(data(x), pdims; kw...)
y, Δ -> (nobacksies(:meanpool, NNlib.∇meanpool(data.((Δ, y, x))..., pdims; kw...)), nothing)
end
@grad_from_chainrules maxpool(::TrackedArray, ::PoolDims; kw...)
@grad_from_chainrules meanpool(::TrackedArray, ::PoolDims; kw...)

# Broadcasting

Expand Down
84 changes: 84 additions & 0 deletions src/macros.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
@grad_from_chainrules f(args...; kwargs...)
The `@grad_from_chainrules` macro provides a way to import adjoints(rrule) defined in
ChainRules to Tracker. One must provide a method signature to import the corresponding
rrule. In the provided method signature, one should replace the types of arguments to which
one wants to take derivatives with respect with Tracker.TrackedReal and Tracker.TrackedArray
respectively. For example, we can import rrule of `f(x::Real, y::Array)`` like below:
Tracker.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
Tracker.@grad_from_chainrules f(x::TrackedReal, y::Array)
Tracker.@grad_from_chainrules f(x::Real, y::TrackedArray)
Acceptable type annotations are `TrackedReal`, `TrackedArray`, `TrackedVector`, and
`TrackedMatrix`. These can have parameters like `TrackedArray{Float32}`.
"""
macro grad_from_chainrules(fcall)
@assert isdefined(__module__, :Tracker) "Tracker not found in module $__module__. Please load `Tracker.jl`."
Meta.isexpr(fcall, :call) && length(fcall.args) 2 ||
error("`@grad_from_chainrules` has to be applied to a function signature")

f = fcall.args[1]
# Check if kwargs... splatting is present
kws_var = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[2].args[1].args[1] :
nothing
rem_args = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[3:end] :
fcall.args[2:end]
xs = map(rem_args) do x
Meta.isexpr(x, :(::)) || return x
length(x.args) == 1 && return :($(gensym())::$(x.args[1])) # ::T without var name
@assert length(x.args) == 2
return :($(x.args[1])::$(x.args[2])) # x::T
end
xs_untyped = map(xs) do x
Meta.isexpr(x, :(::)) || return x
return x.args[1]
end

untrack_args = map(enumerate(xs)) do (i, x)
Meta.isexpr(x, :(::)) || return (x, nothing)
name, type = x.args
Meta.isexpr(type, :curly) && (type = type.args[1]) # Strip parameters from types
type in (:TrackedArray, :TrackedVector, :TrackedMatrix, :TrackedReal) || return (name, nothing)
xdata = gensym(name)
return xdata, :($(xdata) = $(Tracker.data)($(name)))
end
untrack_calls = filter(Base.Fix2(!==, nothing), last.(untrack_args))
@assert length(untrack_calls) > 0 "No tracked arguments found."
var_names = first.(untrack_args)

f_sym = Meta.quot(Symbol(f))

if kws_var === nothing
return esc(quote
$(f)($(xs...)) = $(Tracker.track)($(f), $(xs_untyped...))
function Tracker._forward(::typeof($(f)), $(xs...))
$(untrack_calls...)
y, pb_f = $(CRC.rrule)($(f), $(var_names...))
∇internal_generated = let pb_f = pb_f # Avoid Boxing
Δ -> return Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end]))
end
return y, ∇internal_generated
end
end)
end
return esc(quote
function $(f)($(xs...); $(kws_var)...)
return Tracker.track($(f), $(xs_untyped...); $(kws_var)...)
end
function Tracker._forward(::typeof($(f)), $(xs...); $(kws_var)...)
$(untrack_calls...)
y, pb_f = $(CRC.rrule)($(f), $(var_names...); $(kws_var)...)
∇internal_generated = let pb_f = pb_f # Avoid Boxing
Δ -> Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end]))
end
return y, ∇internal_generated
end
end)
end

@inline __no_crctangent(::CRC.NoTangent) = nothing
@inline __no_crctangent(::CRC.ZeroTangent) = nothing
@inline __no_crctangent(x::CRC.AbstractThunk) = CRC.unthunk(x)
@inline __no_crctangent(x) = x

0 comments on commit ec33546

Please sign in to comment.