-
-
Notifications
You must be signed in to change notification settings - Fork 54
/
LinearSolveEnzymeExt.jl
244 lines (212 loc) · 7.06 KB
/
LinearSolveEnzymeExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
module LinearSolveEnzymeExt

using LinearSolve
using LinearSolve.LinearAlgebra
using EnzymeCore
using EnzymeCore: EnzymeRules

# Raise the shared "algorithm has no derivative rule" error for `alg`.
function _unsupported_alg_error(alg)
    error("Algorithm $(alg) is currently not supported by Enzyme rules on " *
          "LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which " *
          "algorithm is missing the adjoint handling")
end

# Package a forward-mode result according to what Enzyme requested via `config`:
# both primal and shadow -> `Duplicated`, only one of them -> that value, neither -> `nothing`.
function _forward_return(config, res, dres)
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        return Duplicated(res, dres)
    elseif EnzymeRules.needs_shadow(config)
        return dres
    elseif EnzymeRules.needs_primal(config)
        return res
    else
        return nothing
    end
end

# Forward-mode rule (width 1) for `LinearSolve.init`.
# Builds the primal cache from `prob.val` and a shadow cache from `prob.dval`
# with the same algorithm and keyword arguments.
function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
        func::Const{typeof(LinearSolve.init)}, ::Type{RT},
        prob::EnzymeCore.Annotation{LP},
        alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    @assert !(prob isa Const)
    res = func.val(prob.val, alg.val; kwargs...)
    if RT <: Const
        return EnzymeRules.needs_primal(config) ? res : nothing
    end
    dres = func.val(prob.dval, alg.val; kwargs...)
    # A shadow field whose values equal the primal's is zeroed (treated as "no tangent").
    # NOTE(review): this is a value comparison (`==`), presumably to catch shadows that
    # were built from the primal data; a genuine tangent equal to the primal would also
    # be zeroed — confirm intent.
    if dres.b == res.b
        dres.b .= false
    end
    if dres.A == res.A
        dres.A .= false
    end
    return _forward_return(config, res, dres)
end

# Forward-mode rule (width 1) for `LinearSolve.solve!`.
# Differentiating A*u = b gives A*du = db - dA*u, so the tangent `du` is obtained by
# re-solving with the already-prepared primal cache and a modified right-hand side.
function EnzymeRules.forward(
        config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
        ::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
        kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    @assert !(linsolve isa Const)
    res = func.val(linsolve.val; kwargs...)
    if RT <: Const
        return EnzymeRules.needs_primal(config) ? res : nothing
    end
    if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
        _unsupported_alg_error(linsolve.val.alg)
    end
    res = deepcopy(res) # Without this copy, the next solve will end up mutating the result
    b = linsolve.val.b
    linsolve.val.b = linsolve.dval.b - linsolve.dval.A * res.u
    dres = try
        # NOTE(review): this second solve reuses `linsolve.val`, so the cache's own
        # solution storage may now hold du rather than u — confirm callers don't rely on it.
        func.val(linsolve.val; kwargs...)
    finally
        # Restore the primal right-hand side even if the tangent solve throws.
        linsolve.val.b = b
    end
    return _forward_return(config, res, dres)
end

# Reverse-mode augmented primal for `LinearSolve.init`.
# Builds the primal cache and one shadow cache per batch lane. The tape is
# `(d_A, d_b, prob_d_A, prob_d_b)`: the caches' shadow A/b and the problem's shadow A/b.
# For width > 1 each entry is a tuple with one element per lane.
function EnzymeRules.augmented_primal(
        config, func::Const{typeof(LinearSolve.init)},
        ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const;
        kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    res = func.val(prob.val, alg.val; kwargs...)
    if EnzymeRules.width(config) == 1
        dres = func.val(prob.dval, alg.val; kwargs...)
        tape = (dres.A, dres.b, prob.dval.A, prob.dval.b)
    else
        dres = ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            func.val(prob.dval[i], alg.val; kwargs...)
        end
        tape = (map(dc -> dc.A, dres), map(dc -> dc.b, dres),
            map(dp -> dp.A, prob.dval), map(dp -> dp.b, prob.dval))
    end
    return EnzymeRules.AugmentedReturn(res, dres, tape)
end

# Add the gradient held in a cache-owned shadow array `d` into the problem's shadow
# array `prob_d`, then clear `d`. Does nothing when both are the same object.
function _move_shadow!(prob_d, d)
    if d !== prob_d
        prob_d .+= d
        d .= 0
    end
    return nothing
end

# Reverse-mode rule for `LinearSolve.init`. Moves the gradients accumulated on the
# cache's shadow A/b back onto the shadow problem's A/b.
function EnzymeRules.reverse(
        config, func::Const{typeof(LinearSolve.init)}, ::Type{RT},
        cache, prob::EnzymeCore.Annotation{LP}, alg::Const;
        kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
    d_A, d_b, prob_d_A, prob_d_b = cache
    if EnzymeRules.width(config) == 1
        _move_shadow!(prob_d_A, d_A)
        _move_shadow!(prob_d_b, d_b)
    else
        for (lane_prob_d_A, lane_d_A, lane_prob_d_b, lane_d_b) in
            zip(prob_d_A, d_A, prob_d_b, d_b)
            _move_shadow!(lane_prob_d_A, lane_d_A)
            _move_shadow!(lane_prob_d_b, lane_d_b)
        end
    end
    return (nothing, nothing)
end

# Reverse-mode augmented primal for `LinearSolve.solve!`.
# For y = A \ b with output cotangent dy, the reverse pass computes
#   z = A^{-T} dy,   dA -= z * yᵀ,   db += z.
# The tape is `(y, dys, solved_cache, dAs, dbs)`; `dys`, `dAs`, and `dbs` hold one entry
# per batch lane (bare `dys` when width == 1).
function EnzymeRules.augmented_primal(
        config, func::Const{typeof(LinearSolve.solve!)},
        ::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
        kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    res = func.val(linsolve.val; kwargs...)
    if EnzymeRules.width(config) == 1
        # Shadow solution with a zeroed `u` into which Enzyme accumulates dy.
        dres = deepcopy(res)
        dres.u .= 0
        dys = dres.u
        dAs = (linsolve.dval.A,)
        dbs = (linsolve.dval.b,)
    else
        dres = ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            deepcopy(res)
        end
        for dr in dres
            dr.u .= 0
        end
        dys = map(dr -> dr.u, dres)
        dAs = map(dc -> dc.A, linsolve.dval)
        dbs = map(dc -> dc.b, linsolve.dval)
    end
    # Deep-copy the solved cache (including `cacheval`) for the reverse pass;
    # presumably guards against later mutation of `linsolve.val`.
    solved_cache = deepcopy(linsolve.val)
    tape = (copy(res.u), dys, solved_cache, dAs, dbs)
    return EnzymeRules.AugmentedReturn(res, dres, tape)
end

# Solve the adjoint system A' z = dy using state left in `lc` by the primal solve.
# Uses the stored factorization when available, a fresh Krylov solve on transpose(A),
# or the default algorithm's adjoint hook. Otherwise it throws an unsupported-algorithm error.
function _adjoint_solve(lc, dy)
    cacheval = lc.cacheval
    if cacheval isa Factorization
        return cacheval' \ dy
    elseif cacheval isa Tuple && cacheval[1] isa Factorization
        return cacheval[1]' \ dy
    elseif lc.alg isa LinearSolve.AbstractKrylovSubspaceMethod
        # Doesn't modify `A`, so it's safe to just reuse it
        invprob = LinearSolve.LinearProblem(transpose(lc.A), dy)
        # `solve` returns a solution object; the adjoint vector is its `u`.
        return LinearSolve.solve(invprob, lc.alg;
            abstol = lc.abstol,
            reltol = lc.reltol,
            verbose = lc.verbose).u
    elseif lc.alg isa LinearSolve.DefaultLinearSolver
        return LinearSolve.defaultalg_adjoint_eval(lc, dy)
    else
        _unsupported_alg_error(lc.alg)
    end
end

# Reverse-mode rule for `LinearSolve.solve!`. For each batch lane it accumulates
# dA -= z * yᵀ and db += z with z = A^{-T} dy, then clears dy.
function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
        ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP};
        kwargs...) where {RT, LP <: LinearSolve.LinearCache}
    y, dys_taped, solved_cache, dAs, dbs = cache
    @assert !(linsolve isa Const)
    @assert !(linsolve isa Active)
    dys = EnzymeRules.width(config) == 1 ? (dys_taped,) : dys_taped
    for (dA, db, dy) in zip(dAs, dbs, dys)
        z = _adjoint_solve(solved_cache, dy)
        dA .-= z * transpose(y)
        db .+= z
        fill!(dy, zero(eltype(dy)))
    end
    return (nothing,)
end

end