# Patch summary (commentary only; everything from the first `diff --git` line on is unchanged).
#
# Project.toml
#   - Bumps `version` from "0.2.33" to "0.2.34".
#   - Adds ChainRulesCore to [deps] and a `ChainRulesCore = "1.23"` entry to [compat].
#
# src/Tracker.jl
#   - Adds `import ChainRulesCore as CRC`.
#   - Adds `include("macros.jl")`, placed before the `lib/real.jl` / `lib/array.jl` includes.
#
# src/lib/array.jl
#   - Removes the hand-written `track` forwarding methods and `@grad` definitions for
#     `depthwiseconv`, `conv`, `∇conv_data`, `maxpool` and `meanpool`.
#   - Replaces them with `@grad_from_chainrules` declarations: an `@eval` loop over the three
#     (xType, wType) combinations for the convolution functions, plus one declaration each
#     for `maxpool` and `meanpool`.
#
# src/macros.jl (new file)
#   - Defines the macro `@grad_from_chainrules`. For the given signature it emits
#       (1) a method of `f` that forwards to `Tracker.track`, and
#       (2) a `Tracker._forward` method that unwraps the arguments annotated as
#           TrackedArray / TrackedVector / TrackedMatrix / TrackedReal with `Tracker.data`,
#           calls `CRC.rrule`, and returns the primal together with a pullback closure.
#     The closure calls the rrule pullback on `data(Δ)`, drops the first entry of its result
#     (`[2:end]`), maps the rest through `__no_crctangent`, and wraps it in `nobacksies`.
#   - Defines `__no_crctangent`: `NoTangent` and `ZeroTangent` become `nothing`, an
#     `AbstractThunk` is unthunked, and anything else is passed through unchanged.
#
# NOTE(review): newlines and indentation appear collapsed in this copy of the patch, so the
#   hunk line counts (e.g. `@@ -0,0 +1,84 @@`) cannot be checked from here.
# NOTE(review): `kws_var` is read as `fcall.args[2].args[1].args[1]`, i.e. only the first
#   entry of the `;` parameters block is inspected. Presumably only a single `kw...` splat is
#   supported; a signature such as `f(x; a=1)` looks like it would be misread — confirm.
# NOTE(review): the loop variable `i` in `map(enumerate(xs)) do (i, x)` is not used.
# NOTE(review): the two generated branches appear to differ only in keyword forwarding and in
#   `Δ -> return Tracker.nobacksies(...)` versus `Δ -> Tracker.nobacksies(...)`.
# NOTE(review): the macro validates its input with `@assert`; an explicit `error(...)`, as
#   already used for the signature check, may be preferable — confirm intent.
# NOTE(review): in the docstring, "f(x::Real, y::Array)" is followed by a doubled closing
#   backtick, and "adjoints(rrule)" is missing a space before the parenthesis.
# NOTE(review): the removed pullbacks returned `nothing` explicitly for the dims argument;
#   the new path presumably relies on the rrule returning `NoTangent` for it — verify against
#   the rrules these signatures resolve to.
diff --git a/Project.toml b/Project.toml index 4722385..dd704ac 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "Tracker" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.33" +version = "0.2.34" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -27,6 +28,7 @@ TrackerPDMatsExt = "PDMats" [compat] Adapt = "3, 4" +ChainRulesCore = "1.23" DiffRules = "1.4" ForwardDiff = "0.10" Functors = "0.3, 0.4" diff --git a/src/Tracker.jl b/src/Tracker.jl index 0a927ac..8ec1d1d 100644 --- a/src/Tracker.jl +++ b/src/Tracker.jl @@ -5,6 +5,7 @@ using MacroTools: @q, @forward using DiffRules using ForwardDiff +import ChainRulesCore as CRC import LogExpFunctions import NaNMath import SpecialFunctions @@ -71,6 +72,7 @@ end include("idset.jl") include("params.jl") +include("macros.jl") include("lib/real.jl") include("lib/array.jl") include("back.jl") diff --git a/src/lib/array.jl b/src/lib/array.jl index a359c95..2022ff4 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -560,59 +560,17 @@ dims) return Y, dropout_back end -depthwiseconv(x::TrackedArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...) -depthwiseconv(x::AbstractArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...) -depthwiseconv(x::TrackedArray, w::AbstractArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...) - -@grad depthwiseconv(x, w, cdims::DepthwiseConvDims; kw...) 
= - depthwiseconv(data(x), data(w), cdims; kw...), - Δ -> nobacksies(:depthwiseconv, - (NNlib.∇depthwiseconv_data(data.((Δ, w))..., cdims; kw...), - NNlib.∇depthwiseconv_filter(data.((x, Δ))..., cdims; kw...), - nothing)) - -conv(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...) -conv(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...) -conv(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(conv, x, w, cdims; kw...) - -@grad conv(x, w, cdims::DenseConvDims; kw...) = - conv(data(x), data(w), cdims; kw...), - Δ -> nobacksies(:conv, - (NNlib.∇conv_data(data.((Δ, w))..., cdims; kw...), - NNlib.∇conv_filter(data.((x, Δ))..., cdims; kw...), - nothing)) - -∇conv_data(x::TrackedArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...) -∇conv_data(x::AbstractArray, w::TrackedArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...) -∇conv_data(x::TrackedArray, w::AbstractArray, cdims::DenseConvDims; kw...) = track(∇conv_data, x, w, cdims; kw...) - -@grad function ∇conv_data(y, w, cdims::DenseConvDims; kw...) - return ( - ∇conv_data(data(y), data(w), cdims; kw...), - Δ -> begin - return nobacksies(:conv, - (NNlib.conv(data.((Δ, w))..., cdims; kw...), - NNlib.∇conv_filter(data.((Δ, y))..., cdims; kw...), - nothing) - ) - end - ) -end - -maxpool(x::TrackedArray, pdims::PoolDims; kw...) = track(maxpool, x, pdims; kw...) - -@grad function maxpool(x, pdims::PoolDims; kw...) - y = maxpool(data(x), pdims; kw...) - y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., pdims; kw...)), nothing) +for (xType, wType) in [(:TrackedArray, :TrackedArray), (:AbstractArray, :TrackedArray), + (:TrackedArray, :AbstractArray)] + @eval begin + @grad_from_chainrules depthwiseconv(::$xType, ::$wType, ::DepthwiseConvDims; kw...) + @grad_from_chainrules conv(::$xType, ::$wType, ::DenseConvDims; kw...) 
+ @grad_from_chainrules ∇conv_data(::$xType, ::$wType, ::DenseConvDims; kw...) + end end -meanpool(x::TrackedArray, pdims::PoolDims; kw...) = track(meanpool, x, pdims; kw...) - - -@grad function meanpool(x, pdims::PoolDims; kw...) - y = meanpool(data(x), pdims; kw...) - y, Δ -> (nobacksies(:meanpool, NNlib.∇meanpool(data.((Δ, y, x))..., pdims; kw...)), nothing) -end +@grad_from_chainrules maxpool(::TrackedArray, ::PoolDims; kw...) +@grad_from_chainrules meanpool(::TrackedArray, ::PoolDims; kw...) # Broadcasting diff --git a/src/macros.jl b/src/macros.jl new file mode 100644 index 0000000..af1d0e2 --- /dev/null +++ b/src/macros.jl @@ -0,0 +1,84 @@ +""" + @grad_from_chainrules f(args...; kwargs...) + +The `@grad_from_chainrules` macro provides a way to import adjoints(rrule) defined in +ChainRules to Tracker. One must provide a method signature to import the corresponding +rrule. In the provided method signature, one should replace the types of arguments to which +one wants to take derivatives with respect with Tracker.TrackedReal and Tracker.TrackedArray +respectively. For example, we can import rrule of `f(x::Real, y::Array)`` like below: + + Tracker.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray) + Tracker.@grad_from_chainrules f(x::TrackedReal, y::Array) + Tracker.@grad_from_chainrules f(x::Real, y::TrackedArray) + +Acceptable type annotations are `TrackedReal`, `TrackedArray`, `TrackedVector`, and +`TrackedMatrix`. These can have parameters like `TrackedArray{Float32}`. +""" +macro grad_from_chainrules(fcall) + @assert isdefined(__module__, :Tracker) "Tracker not found in module $__module__. Please load `Tracker.jl`." + Meta.isexpr(fcall, :call) && length(fcall.args) ≥ 2 || + error("`@grad_from_chainrules` has to be applied to a function signature") + + f = fcall.args[1] + # Check if kwargs... splatting is present + kws_var = Meta.isexpr(fcall.args[2], :parameters) ? 
fcall.args[2].args[1].args[1] : + nothing + rem_args = Meta.isexpr(fcall.args[2], :parameters) ? fcall.args[3:end] : + fcall.args[2:end] + xs = map(rem_args) do x + Meta.isexpr(x, :(::)) || return x + length(x.args) == 1 && return :($(gensym())::$(x.args[1])) # ::T without var name + @assert length(x.args) == 2 + return :($(x.args[1])::$(x.args[2])) # x::T + end + xs_untyped = map(xs) do x + Meta.isexpr(x, :(::)) || return x + return x.args[1] + end + + untrack_args = map(enumerate(xs)) do (i, x) + Meta.isexpr(x, :(::)) || return (x, nothing) + name, type = x.args + Meta.isexpr(type, :curly) && (type = type.args[1]) # Strip parameters from types + type in (:TrackedArray, :TrackedVector, :TrackedMatrix, :TrackedReal) || return (name, nothing) + xdata = gensym(name) + return xdata, :($(xdata) = $(Tracker.data)($(name))) + end + untrack_calls = filter(Base.Fix2(!==, nothing), last.(untrack_args)) + @assert length(untrack_calls) > 0 "No tracked arguments found." + var_names = first.(untrack_args) + + f_sym = Meta.quot(Symbol(f)) + + if kws_var === nothing + return esc(quote + $(f)($(xs...)) = $(Tracker.track)($(f), $(xs_untyped...)) + function Tracker._forward(::typeof($(f)), $(xs...)) + $(untrack_calls...) + y, pb_f = $(CRC.rrule)($(f), $(var_names...)) + ∇internal_generated = let pb_f = pb_f # Avoid Boxing + Δ -> return Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end])) + end + return y, ∇internal_generated + end + end) + end + return esc(quote + function $(f)($(xs...); $(kws_var)...) + return Tracker.track($(f), $(xs_untyped...); $(kws_var)...) + end + function Tracker._forward(::typeof($(f)), $(xs...); $(kws_var)...) + $(untrack_calls...) + y, pb_f = $(CRC.rrule)($(f), $(var_names...); $(kws_var)...) 
+ ∇internal_generated = let pb_f = pb_f # Avoid Boxing + Δ -> Tracker.nobacksies($(f_sym), $(__no_crctangent).(pb_f($(data)(Δ))[2:end])) + end + return y, ∇internal_generated + end + end) +end + +@inline __no_crctangent(::CRC.NoTangent) = nothing +@inline __no_crctangent(::CRC.ZeroTangent) = nothing +@inline __no_crctangent(x::CRC.AbstractThunk) = CRC.unthunk(x) +@inline __no_crctangent(x) = x