From f0a460ef01f5973aad55a3542e7237a382fbe359 Mon Sep 17 00:00:00 2001
From: Chirag Tyagi
Date: Mon, 25 Mar 2024 18:33:59 +0530
Subject: [PATCH 1/4] remove scan

---
 src/back.jl | 93 +++++++++++++++++++++++++++++++++++------------------
 1 file changed, 61 insertions(+), 32 deletions(-)

diff --git a/src/back.jl b/src/back.jl
index e0cc739..4bff2a6 100644
--- a/src/back.jl
+++ b/src/back.jl
@@ -1,11 +1,14 @@
 # The AD generates fairly large backtraces that are unhelpful if you interrupt
 # while training; this just cleans that up.
 macro interrupts(ex)
-  :(try $(esc(ex))
+  :(
+    try
+      $(esc(ex))
     catch e
       e isa InterruptException || rethrow()
       throw(e)
-    end)
+    end
+  )
 end
 
 # In-place gradients
@@ -44,19 +47,39 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)
 
+# function back(x::Tracked, Δ, once)
+#   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
+#   ref = x.ref -= 1
+#   grad = if isdefined(x, :grad)
+#     x.grad = accum!(x.grad, Δ)
+#   elseif ref > 0
+#     x.grad = Δ
+#   else
+#     Δ
+#   end
+#   if ref == 0
+#     back_(x.f, grad, once)
+#     once && !x.isleaf && (x.f = Call(missing, ()))
+#   end
+#   return
+# end
+
 function back(x::Tracked, Δ, once)
-  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
-  ref = x.ref -= 1
-  grad = if isdefined(x, :grad)
-    x.grad = accum!(x.grad, Δ)
-  elseif ref > 0
-    x.grad = Δ
-  else
-    Δ
-  end
-  if ref == 0
-    back_(x.f, grad, once)
-    once && !x.isleaf && (x.f = Call(missing, ()))
+  if x.isleaf                      # leaf nodes only accumulate their gradient
+    x.grad = accum!(x.grad, Δ)
+    return
+  end
+  ref = x.ref                      # reference count: 0 when no other consumer is pending
+  grad = isdefined(x, :grad) ? x.grad : nothing  # existing gradient, if one was recorded
+  if isnothing(grad) || ref == 0   # no gradient yet, or x is not referenced elsewhere
+    x.grad = Δ
+  else
+    x.grad = accum!(grad, Δ)       # accumulate Δ into the existing gradient
+  end
+  if ref == 0                      # all contributions received: propagate further
+    back_(x.f, x.grad, once)       # backpropagate through the call that produced x
+    once && (x.f = Call(missing, ()))  # invalidate the call so `back!` cannot run twice
   end
   return
 end
@@ -71,13 +94,19 @@ back(::Nothing, Δ, once) = return
 # Refcounts are also probably not safe in some situations (e.g. back called
 # from within a backpropagator)
 
-function back!(x, Δ; once = true)
-  istracked(x) || return
-  scan(x)
-  back(tracker(x), Δ, once)
-  return
-end
+function back!(x, Δ; once=true)
+  back(tracker(x), Δ, once)  # start the walk from the tracker of x; `scan` is no longer needed
+  return
+end
+
+# function back!(x, Δ; once=true)
+#   istracked(x) || return
+#   scan(x)
+#   back(tracker(x), Δ, once)
+#   return
+# end
 
 function extract_grad!(x)
   x̄ = copy(grad(x))
   x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
@@ -161,7 +190,7 @@ function gradient_nested(f, args...)
   return back(1)
 end
 
-gradient(f, xs...; nest = false) =
+gradient(f, xs...; nest=false) =
   nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
 
 # Jacobians and Hessians
@@ -219,14 +248,14 @@ julia> withgradient(model, rand(Float32, 2)) do m, x
 ```
 """
 function withgradient(f, xs...)
-  pxs = fmap(param, xs; exclude = isnumeric, walk = _trainable_walk)
-  l = f(pxs...)
-  l1 = l isa Union{Tuple, NamedTuple} ? first(l) : l
-  val = l isa Union{Tuple, NamedTuple} ? fmap(data, l) : data(l)
-  losscheck(l1)
-  l1 isa TrackedReal || return (; val, grad = map(_ -> nothing, xs))
-  @interrupts back!(l1)
-  (; val, grad = rec_grad(pxs))
+  pxs = fmap(param, xs; exclude=isnumeric, walk=_trainable_walk)
+  l = f(pxs...)
+  l1 = l isa Union{Tuple,NamedTuple} ? first(l) : l
+  val = l isa Union{Tuple,NamedTuple} ? fmap(data, l) : data(l)
+  losscheck(l1)
+  l1 isa TrackedReal || return (; val, grad=map(_ -> nothing, xs))
+  @interrupts back!(l1)
+  (; val, grad=rec_grad(pxs))
 end
@@ -234,7 +263,7 @@ function _trainable_walk(f, x)
   isempty(func) && return x
   done = map(f, _trainable(x)) # recurse only into trainable fields, this contains `nothing` elsewhere
   map(func, merge(func, done)) do n, t
-      isnothing(t) ? n : t
+    isnothing(t) ? n : t
   end |> re # reconstruct the whole thing
 end
 _trainable_walk(f, x::Tuple) = map(f, x)
@@ -247,9 +276,9 @@ rec_grad(x::Number) = nothing
 rec_grad(x::Union{Tuple,NamedTuple,AbstractArray}) = map(rec_grad, x)
 
 rec_grad(::Tuple{}) = nothing
-rec_grad(::NamedTuple{(), Tuple{}}) = nothing
+rec_grad(::NamedTuple{(),Tuple{}}) = nothing
 function rec_grad(x::T) where {T}
-  F = fieldnames(T)
-  isempty(F) && return nothing
-  map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
+    F = fieldnames(T)
+    isempty(F) && return nothing
+    map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
 end
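
The `scan` pass that this series removes existed to set up per-node reference
counts before each backward walk. The sketch below (hypothetical `Node` and
`backprop!` names, not Tracker's actual types) shows why those counts matter:
a value consumed by several operations must collect every incoming gradient
contribution before it propagates to its own parents, otherwise contributions
are lost or the upstream graph is walked more than once.

# Hypothetical sketch, not Tracker code.
mutable struct Node
    grad::Float64
    ref::Int                               # pending downstream uses
    parents::Vector{Tuple{Node,Float64}}   # (parent, local derivative)
end
Node() = Node(0.0, 0, Tuple{Node,Float64}[])

function backprop!(n::Node, Δ::Float64)
    n.grad += Δ            # accumulate this contribution
    n.ref -= 1
    n.ref > 0 && return    # wait until every consumer has reported in
    for (p, d) in n.parents
        backprop!(p, n.grad * d)
    end
end

x = Node(); x.ref = 2                      # y = x + x uses x twice
y = Node(); y.ref = 1                      # one consumer: the loss seed
y.parents = [(x, 1.0), (x, 1.0)]
backprop!(y, 1.0)
@assert x.grad == 2.0                      # both contributions, propagated once
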
From 736771239980f92349802e40348e8a954b353856 Mon Sep 17 00:00:00 2001
From: Chirag Tyagi
Date: Mon, 25 Mar 2024 18:34:51 +0530
Subject: [PATCH 2/4] remove scan

---
 src/back.jl | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/src/back.jl b/src/back.jl
index 4bff2a6..8b77cec 100644
--- a/src/back.jl
+++ b/src/back.jl
@@ -19,20 +19,20 @@ zero_grad!(x::AbstractArray) = (x .= 0)
 
 scan(c::Call) = foreach(scan, c.args)
 
-function scan(x::Tracked)
-  x.isleaf && return
-  ref = x.ref += 1
-  if ref == 1
-    scan(x.f)
-    isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
-  end
-  return
-end
+# function scan(x::Tracked)
+#   x.isleaf && return
+#   ref = x.ref += 1
+#   if ref == 1
+#     scan(x.f)
+#     isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
+#   end
+#   return
+# end
 
-function scan(x)
-  istracked(x) && scan(tracker(x))
-  return
-end
+# function scan(x)
+#   istracked(x) && scan(tracker(x))
+#   return
+# end
 
 function back_(c::Call, Δ, once)
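
The counting pass itself is the easy part. On the toy `Node` type from the
sketch above, a scan-like pass might look as follows (again hypothetical, not
the package's API). Like the `scan` being commented out here, it does two jobs
before each backward run: increment `ref` once per use, and clear any stale
gradient on the first visit to each node.

function scan!(n::Node, seen=IdDict{Node,Bool}())
    n.ref += 1                 # one increment per incoming edge
    haskey(seen, n) && return  # recurse (and reset the gradient) once per node
    seen[n] = true
    n.grad = 0.0
    for (p, _) in n.parents
        scan!(p, seen)
    end
end

scan!(y)   # restores y.ref == 1 and x.ref == 2 for the next backward pass
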
From 3f126ac6fe37b0592bd4c04715dab3ba428aa9ec Mon Sep 17 00:00:00 2001
From: Chirag Tyagi
Date: Wed, 27 Mar 2024 22:26:01 +0530
Subject: [PATCH 3/4] handle using ref count

---
 src/back.jl | 68 ++++++++++++++++++++++++++++++++---------------------
 1 file changed, 41 insertions(+), 27 deletions(-)

diff --git a/src/back.jl b/src/back.jl
index 8b77cec..1cb5674 100644
--- a/src/back.jl
+++ b/src/back.jl
@@ -17,7 +17,7 @@ init_grad(x) = zero(x)
 zero_grad!(x) = zero(x)
 zero_grad!(x::AbstractArray) = (x .= 0)
 
-scan(c::Call) = foreach(scan, c.args)
+# scan(c::Call) = foreach(scan, c.args)
 
 # function scan(x::Tracked)
 #   x.isleaf && return
 #   ref = x.ref += 1
 #   if ref == 1
 #     scan(x.f)
 #     isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
 #   end
 #   return
 # end
 
 # function scan(x)
 #   istracked(x) && scan(tracker(x))
 #   return
 # end
 
-function back_(c::Call, Δ, once)
-  Δs = c.func(Δ)
-  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
-    error("Gradient is not a tuple of length $(length(c.args))")
-  foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
-end
+# function back_(c::Call, Δ, once)
+#   Δs = c.func(Δ)
+#   (Δs isa Tuple && length(Δs) >= length(c.args)) ||
+#     error("Gradient is not a tuple of length $(length(c.args))")
+#   foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
+# end
+
+function back_(c::Call, Δ)
+  Δs = c.func(Δ)
+  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
+    error("Gradient is not a tuple of length $(length(c.args))")
+  foreach((x, d) -> back(x, d), c.args, data.(Δs))
+end
 
-back_(::Call{Nothing}, Δ, once) = nothing
-back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
+# back_(::Call{Nothing}, Δ, once) = nothing
+# back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
+
+back_(::Call{Nothing}, Δ) = nothing
+back_(::Call{Missing}, Δ) = error("`back!` was already used")
 
 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)
 
 # function back(x::Tracked, Δ, once)
 #   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
 #   ref = x.ref -= 1
 #   grad = if isdefined(x, :grad)
 #     x.grad = accum!(x.grad, Δ)
 #   elseif ref > 0
 #     x.grad = Δ
 #   else
 #     Δ
 #   end
 #   if ref == 0
 #     back_(x.f, grad, once)
 #     once && !x.isleaf && (x.f = Call(missing, ()))
 #   end
 #   return
 # end
 
-function back(x::Tracked, Δ, once)
-  if x.isleaf                      # leaf nodes only accumulate their gradient
-    x.grad = accum!(x.grad, Δ)
-    return
-  end
-  ref = x.ref                      # reference count: 0 when no other consumer is pending
-  grad = isdefined(x, :grad) ? x.grad : nothing  # existing gradient, if one was recorded
-  if isnothing(grad) || ref == 0   # no gradient yet, or x is not referenced elsewhere
-    x.grad = Δ
-  else
-    x.grad = accum!(grad, Δ)       # accumulate Δ into the existing gradient
-  end
-  if ref == 0                      # all contributions received: propagate further
-    back_(x.f, x.grad, once)       # backpropagate through the call that produced x
-    once && (x.f = Call(missing, ()))  # invalidate the call so `back!` cannot run twice
-  end
-  return
-end
+function back(x::Tracked, Δ)
+  # Increment the reference count
+  x.ref += 1
+
+  # Accumulate or propagate, depending on how many references are still pending
+  if x.ref == 1
+    # No other reference is pending: record Δ and backpropagate immediately
+    x.grad = Δ
+    back_(x.f, Δ)
+  else
+    # Other references are still pending: accumulate Δ into the gradient buffer
+    x.grad = accum!(x.grad, Δ)
+  end
+
+  # Decrement the reference count
+  x.ref -= 1
+
+  return
+end
 
-back(::Nothing, Δ, once) = return
+# back(::Nothing, Δ, once) = return
+back(::Nothing, Δ) = return
 
 # Interface methods
 
 # TODO: if an error occurs in `back` the refcounts will be broken
 # and `back` will silently fail to update.
 # Refcounts are also probably not safe in some situations (e.g. back called
 # from within a backpropagator)
 
 function back!(x, Δ; once=true)
-  back(tracker(x), Δ, once)  # start the walk from the tracker of x; `scan` is no longer needed
+  # back(tracker(x), Δ, once)  # start the walk from the tracker of x; `scan` is no longer needed
+  back(tracker(x), Δ)  # start the walk from the tracker of x
   return
 end
 
-# function back!(x, Δ; once=true)
-#   istracked(x) || return
-#   scan(x)
-#   back(tracker(x), Δ, once)
-#   return
-# end
+function back!(x, Δ; once=true)
+  istracked(x) || return
+  # scan(x)
+  # back(tracker(x), Δ, once)
+  back(tracker(x), Δ)
+  return
+end
 
 function extract_grad!(x)
   x̄ = copy(grad(x))
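
None of these patches change the public entry point, so a smoke test can stay
at the API level. A usage sketch against Tracker's exported `param`, `back!`,
and `grad`, as they behave in released versions of the package:

using Tracker: param, back!, grad

W = param([1.0 2.0; 3.0 4.0])   # tracked parameter
x = [1.0, 1.0]
y = sum(W * x)                  # tracked scalar loss
back!(y, 1)                     # seed the reverse pass with Δ = 1
grad(W)                         # d(sum(W*x))/dW == [1.0 1.0; 1.0 1.0]
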
From 302b0c0ba9719bfbac9fd496821182585f391ffa Mon Sep 17 00:00:00 2001
From: Chirag Tyagi
Date: Fri, 29 Mar 2024 02:14:00 +0530
Subject: [PATCH 4/4] modified approach

---
 src/back.jl | 65 ++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 49 insertions(+), 16 deletions(-)

diff --git a/src/back.jl b/src/back.jl
index 1cb5674..23c8383 100644
--- a/src/back.jl
+++ b/src/back.jl
@@ -75,26 +75,53 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ)
 # end
 
-function back(x::Tracked, Δ)
-  # Increment the reference count
-  x.ref += 1
-
-  # Accumulate or propagate, depending on how many references are still pending
-  if x.ref == 1
-    # No other reference is pending: record Δ and backpropagate immediately
-    x.grad = Δ
-    back_(x.f, Δ)
-  else
-    # Other references are still pending: accumulate Δ into the gradient buffer
-    x.grad = accum!(x.grad, Δ)
-  end
-
-  # Decrement the reference count
-  x.ref -= 1
-
-  return
-end
+# function back(x::Tracked, Δ)
+#   # Increment the reference count
+#   x.ref += 1
+#
+#   # Accumulate or propagate, depending on how many references are still pending
+#   if x.ref == 1
+#     # No other reference is pending: record Δ and backpropagate immediately
+#     x.grad = Δ
+#     back_(x.f, Δ)
+#   else
+#     # Other references are still pending: accumulate Δ into the gradient buffer
+#     x.grad = accum!(x.grad, Δ)
+#   end
+#
+#   # Decrement the reference count
+#   x.ref -= 1
+#
+#   return
+# end
+
+function back(x::Tracked, Δ)
+  if x.isleaf
+    x.grad = accum!(x.grad, Δ)   # leaves only accumulate their gradient
+    return
+  end
+
+  x.ref -= 1                     # one consumer has reported its contribution
+  if isdefined(x, :grad)
+    x.grad = accum!(x.grad, Δ)   # accumulate into the existing gradient
+  else
+    x.grad = Δ                   # the first contribution becomes the gradient
+  end
+
+  if x.ref == 0                  # every consumer has reported: propagate
+    back_(x.f, x.grad)           # continue through the call that produced x
+  end
+  return
+end
 
 # back(::Nothing, Δ, once) = return
 back(::Nothing, Δ) = return
 
 # Interface methods
 
 # TODO: if an error occurs in `back` the refcounts will be broken
 # and `back` will silently fail to update.
 # Refcounts are also probably not safe in some situations (e.g. back called
 # from within a backpropagator)
 
-function back!(x, Δ; once=true)
-  # back(tracker(x), Δ, once)  # start the walk from the tracker of x; `scan` is no longer needed
-  back(tracker(x), Δ)  # start the walk from the tracker of x
-  return
-end
+function back!(x, Δ)
+  istracked(x) || return
+  back(tracker(x), Δ)
+end
+
+# function back!(x, Δ; once=true)
+#   # back(tracker(x), Δ, once)  # start the walk from the tracker of x; `scan` is no longer needed
+#   back(tracker(x), Δ)  # start the walk from the tracker of x
+#   return
+# end
 
-function back!(x, Δ; once=true)
-  istracked(x) || return
-  # scan(x)
-  # back(tracker(x), Δ, once)
-  back(tracker(x), Δ)
-  return
-end
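
Since each rework above changes when `back` propagates versus accumulates, the
sharing case is the one worth testing: a value consumed twice must end up with
both contributions, each counted exactly once. A quick check with the public
API (the expected value is plain calculus, d(x^2)/dx = 2x = 6 at x = 3):

using Tracker: param, back!, grad

x = param(3.0)
y = x * x                # x feeds the multiplication twice
back!(y, 1)
@assert grad(x) == 6.0   # both contributions, counted exactly once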