-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathLuxZygoteExt.jl
141 lines (120 loc) · 5.43 KB
/
LuxZygoteExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
module LuxZygoteExt

using ADTypes: AutoZygote
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using Lux: Lux, DISABLE_AUTOMATIC_NESTED_AD_SWITCH
using Setfield: @set!
using Zygote: Zygote

const CRC = ChainRulesCore

# Training-API gradient computation with Zygote.
#
# `objective_function(model, ps, st, data)` must return `(loss, st, stats)`. Only `loss` is
# differentiated; the returned states are stored into the train state.
# Returns `(grads, loss, stats, ts)`, where `grads` is the gradient w.r.t. `ts.parameters`
# and `ts` carries the updated states.
function Lux.Experimental.compute_gradients(::AutoZygote, objective_function::F, data,
        ts::Lux.Experimental.TrainState) where {F}
    (loss, st, stats), back = Zygote.pullback(
        objective_function, ts.model, ts.parameters, ts.states, data)
    # Seed only the loss output; states and stats receive no cotangent. Index 2 picks the
    # gradient w.r.t. `ts.parameters` (pullback argument order: model, ps, st, data).
    grads = back((one(loss), nothing, nothing))[2]
    @set! ts.states = st
    return grads, loss, stats, ts
end

# Nested AD Handling
#
# `Zygote.gradient` / `Zygote.jacobian` calls on a `Lux.StatefulLuxLayer` (optionally composed
# with another function) are rerouted through `__internal_gradient_capture` /
# `__internal_jacobian_capture`. These take the layer parameters as an explicit argument, so
# the custom `rrule`s below can return cotangents for both the input and the parameters.
#
# NOTE(review): a `ComposedFunction` whose `outer` AND `inner` are both `StatefulLuxLayer`s
# matches two of the methods below equally well, which presumably yields a method-ambiguity
# error — confirm whether that case needs a dedicated method.

## Zygote.gradient call

# Gradient of `x -> f(x, args...)` w.r.t. `x` only. `args` (the parameters) is kept as an
# explicit argument so that the `rrule` below can differentiate w.r.t. it.
@inline function __internal_gradient_capture(f::F, x, args...) where {F}
    return Zygote.gradient(@closure(x->f(x, args...)), x)
end

# `outer ∘ inner` with a stateful `outer`: capture `outer`'s parameters.
@inline function Zygote.gradient(
        f::Base.ComposedFunction{<:Lux.StatefulLuxLayer, F}, x::AbstractArray) where {F}
    return __internal_gradient_capture(
        @closure((x, ps)->f.outer(f.inner(x), ps)), x, f.outer.ps)
end

# `outer ∘ inner` with a stateful `inner`: `f(x, ps)` calls `f.outer(f.inner(x, ps))`.
@inline function Zygote.gradient(
        f::Base.ComposedFunction{F, <:Lux.StatefulLuxLayer}, x::AbstractArray) where {F}
    return __internal_gradient_capture(f, x, f.inner.ps)
end

# A bare stateful layer: capture its own parameters.
@inline function Zygote.gradient(f::Lux.StatefulLuxLayer, x::AbstractArray)
    return __internal_gradient_capture(f, x, f.ps)
end

# Reverse rule for the captured gradient.
#
# With ForwardDiff loaded (and the automatic switch not disabled), the pullback is computed
# with `Lux.__forwarddiff_jvp` applied to `Zygote.gradient(f, x, ps)` in the direction of
# the incoming cotangent. That is forward-over-reverse, instead of Zygote differentiating
# Zygote.
# Otherwise it falls back to `rrule_via_ad` on `Zygote.gradient`.
function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
        ::typeof(__internal_gradient_capture), f::F, x::AbstractArray, ps) where {F}
    if !Lux._is_extension_loaded(Val(:ForwardDiff)) || DISABLE_AUTOMATIC_NESTED_AD_SWITCH
        if !DISABLE_AUTOMATIC_NESTED_AD_SWITCH
            @warn "Load ForwardDiff.jl for better nested AD handling." maxlog=1
        end
        # Reverse-over-reverse fallback. The primal of `Zygote.gradient(f, x, ps)` is
        # `(∂x, ∂ps)`, but this function only returns `(∂x,)`. Trim the primal and pad the
        # incoming cotangent with a zero for the unused `∂ps` slot.
        y, pb_f = CRC.rrule_via_ad(cfg, Zygote.gradient, f, x, ps)
        ∇fallback = @closure Δ -> begin
            Δ isa CRC.AbstractZero && return ntuple(Returns(CRC.NoTangent()), 4)
            return pb_f((only(Δ), CRC.ZeroTangent()))
        end
        return (first(y),), ∇fallback
    end
    y = __internal_gradient_capture(f, x, ps)
    ∇internal_gradient_capture = @closure Δ -> begin
        Δ isa CRC.AbstractZero && return ntuple(Returns(CRC.NoTangent()), 4)
        Δx = CRC.unthunk(only(Δ))
        # A zero cotangent for `∂x` cannot be reshaped; nothing flows back in that case.
        Δx isa CRC.AbstractZero && return ntuple(Returns(CRC.NoTangent()), 4)
        Δ_ = reshape(Δx, size(x))
        ∂x, ∂ps = Lux.__forwarddiff_jvp(
            @closure((x, ps)->Zygote.gradient(f, x, ps)), x, Δ_, ps)
        # No cotangents for the captured function itself (`f`) or for the rrule's function
        # argument.
        return CRC.NoTangent(), CRC.NoTangent(), ∂x, ∂ps
    end
    return y, ∇internal_gradient_capture
end

## Zygote.jacobian call

# Jacobian of `x -> f(x, args...)` w.r.t. `x` only. `args` (the parameters) is kept as an
# explicit argument so that the `rrule` below can differentiate w.r.t. it.
@inline function __internal_jacobian_capture(f::F, x, args...) where {F}
    return Zygote.jacobian(@closure(x->f(x, args...)), x)
end

# `outer ∘ inner` with a stateful `outer`: capture `outer`'s parameters. Inside the closure,
# `ps` is only the closure argument; the value passed in comes from `f.outer.ps`.
@inline function Zygote.jacobian(
        f::Base.ComposedFunction{<:Lux.StatefulLuxLayer, F}, x::AbstractArray) where {F}
    return __internal_jacobian_capture(
        @closure((x, ps)->f.outer(f.inner(x), ps)), x, f.outer.ps)
end

# `outer ∘ inner` with a stateful `inner`: `f(x, ps)` calls `f.outer(f.inner(x, ps))`.
@inline function Zygote.jacobian(
        f::Base.ComposedFunction{F, <:Lux.StatefulLuxLayer}, x::AbstractArray) where {F}
    return __internal_jacobian_capture(f, x, f.inner.ps)
end

# A bare stateful layer: capture its own parameters.
@inline function Zygote.jacobian(f::Lux.StatefulLuxLayer, x::AbstractArray)
    return __internal_jacobian_capture(f, x, f.ps)
end

# Reverse rule for the captured jacobian.
#
# With ForwardDiff available, the pullback is built row by row. For output row `i`, take
# the gradient of `y[i]` w.r.t. `(x, ps)` and push it forward (`Lux.__forwarddiff_jvp`) along
# row `i` of the cotangent. The per-row results are summed with `Lux.__internal_add`.
function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
        ::typeof(__internal_jacobian_capture), f::F, x::AbstractArray, ps) where {F}
    if !Lux._is_extension_loaded(Val(:ForwardDiff)) || DISABLE_AUTOMATIC_NESTED_AD_SWITCH
        if !DISABLE_AUTOMATIC_NESTED_AD_SWITCH
            @warn "Load ForwardDiff.jl for better nested AD handling." maxlog=1
        end
        # Reverse-over-reverse fallback; expected to fail in most cases (NOTE(review):
        # presumably because `Zygote.jacobian` expects array arguments and `ps` usually is
        # not one — confirm).
        # The primal is `(Jx, Jps)`, but this function only returns `(Jx,)`. Trim the primal
        # and pad the incoming cotangent with a zero for the unused `Jps` slot.
        y, pb_f = CRC.rrule_via_ad(cfg, Zygote.jacobian, f, x, ps)
        ∇fallback = @closure Δ -> begin
            Δ isa CRC.AbstractZero && return ntuple(Returns(CRC.NoTangent()), 4)
            return pb_f((only(Δ), CRC.ZeroTangent()))
        end
        return (first(y),), ∇fallback
    end
    J = __internal_jacobian_capture(f, x, ps)
    ∇internal_jacobian_capture = Δ_ -> begin
        Δ_ isa CRC.AbstractZero && return ntuple(Returns(CRC.NoTangent()), 4)
        ΔJ = CRC.unthunk(only(Δ_))
        # A zero cotangent for the jacobian carries no information to propagate.
        ΔJ isa CRC.AbstractZero && return ntuple(Returns(CRC.NoTangent()), 4)
        Δ = Lux.__compactify_if_structured_matrix(only(J), ΔJ)
        ∂x, ∂ps = mapreduce(Lux.__internal_add, enumerate(eachrow(Δ))) do (i, Δᵢ)
            # Scalar function selecting the i-th (flattened) output element.
            __f = (x, p) -> sum(vec(f(x, p))[i:i])
            ∂xᵢ, ∂psᵢ = Lux.__forwarddiff_jvp(
                @closure((x, ps)->Zygote.gradient(__f, x, ps)), x, reshape(Δᵢ, size(x)), ps)
            return ∂xᵢ, ∂psᵢ
        end
        return CRC.NoTangent(), CRC.NoTangent(), ∂x, ∂ps
    end
    return J, ∇internal_jacobian_capture
end

# Zygote-specific workaround for ForwardDiff calls inside Zygote.
## This relies on Zygote exposing a `ForwardDiff` binding (checked with `isdefined`). If
## Zygote ever moves its ForwardDiff support into a package extension, this block will need
## its own extension file.
@static if isdefined(Zygote, :ForwardDiff)
    using Zygote: ForwardDiff

    # Zygote special-cases `_pullback` for the 2-argument `ForwardDiff.jacobian` /
    # `ForwardDiff.gradient`. These methods redirect such calls on stateful layers to the
    # 4-argument form (explicit config plus the `Val(true)` tag-check flag). That form has
    # no such `_pullback`, so a ChainRules `rrule` is used instead (presumably the one Lux
    # defines for it in its ForwardDiff extension — confirm).
    function Zygote._pullback(cx::Zygote.AContext,
            ::typeof(ForwardDiff.jacobian),
            f::Union{Base.ComposedFunction{<:Any, <:Lux.StatefulLuxLayer},
                Base.ComposedFunction{<:Lux.StatefulLuxLayer, <:Any},
                Lux.StatefulLuxLayer},
            x::AbstractArray)
        return Zygote._pullback(
            cx, ForwardDiff.jacobian, f, x, ForwardDiff.JacobianConfig(f, x), Val(true))
    end

    function Zygote._pullback(cx::Zygote.AContext,
            ::typeof(ForwardDiff.gradient),
            f::Union{Base.ComposedFunction{<:Any, <:Lux.StatefulLuxLayer},
                Base.ComposedFunction{<:Lux.StatefulLuxLayer, <:Any},
                Lux.StatefulLuxLayer},
            x::AbstractArray)
        return Zygote._pullback(
            cx, ForwardDiff.gradient, f, x, ForwardDiff.GradientConfig(f, x), Val(true))
    end
end

end