# DiffEqBaseEnzymeExt.jl
module DiffEqBaseEnzymeExt
using DiffEqBase
import DiffEqBase: value
using Enzyme
import Enzyme: Const
using ChainRulesCore
# Enzyme augmented-primal rule for `DiffEqBase.solve_up`, restricted to
# `ConfigWidth{1}` configurations and a `Duplicated{RT}` return activity.
#
# The solve is delegated to `DiffEqBase._solve_adjoint`, tagged with
# `SciMLBase.EnzymeOriginator()`. Element 1 of its result is handed back as the
# primal; element 2 is stored on the tape next to the shadow (the matching `reverse`
# rule in this module calls it with that shadow).
#
# Returns `AugmentedReturn{RT, RT, Any}(primal, shadow, tape)` where `shadow` is a
# deep copy of the primal whose `u` entries have been zeroed in place, and
# `tape == (shadow, result[2])`.
function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1},
        func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
        sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
        u0, p, args...; kwargs...) where {RT}
    clobbered = Enzyme.EnzymeRules.overwritten(config)

    # Deep-copy an input only when its slot is flagged as overwritten AND the value is
    # mutable; anything else is passed through untouched.
    @inline preserve(x, slot) = (clobbered[slot] && ismutable(x)) ? deepcopy(x) : x

    # Slots 2..5 are used for prob, sensealg, u0 and p below, so the i-th trailing
    # argument lives in slot i + 5 (slot 1 is presumably `func` -- confirm against
    # the EnzymeRules documentation).
    @inline preserve_extra(i) = preserve(args[i].val, i + 5)

    solved = DiffEqBase._solve_adjoint(
        preserve(prob.val, 2), preserve(sensealg.val, 3),
        preserve(u0.val, 4), preserve(p.val, 5),
        SciMLBase.EnzymeOriginator(),
        ntuple(preserve_extra, Val(length(args)))...; kwargs...)

    primal = solved[1]

    # The shadow starts as a structural copy of the primal with every `u` entry
    # reset to zero via in-place broadcast.
    shadow = deepcopy(primal)::RT
    foreach(block -> (block .= 0), shadow.u)

    tape = (shadow, solved[2])
    return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(primal, shadow, tape)
end
# Enzyme reverse rule for `DiffEqBase.solve_up`, restricted to `ConfigWidth{1}`
# configurations and a `Duplicated{RT}` return activity.
#
# `tape` is expected to hold `(shadow, pullback)`. The pullback is called with the
# shadow, and each resulting tangent is paired positionally with
# `(func, prob, sensealg, u0, p, args...)`. Tangents are added in place into the
# `dval` of every argument that is not `Enzyme.Const`, skipping tangents equal to
# `ChainRulesCore.NoTangent()`. Afterwards the shadow's `u` entries are zeroed.
#
# Returns a tuple of `length(args) + 4` `nothing`s.
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},
        func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
        sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
        u0, p, args...; kwargs...) where {RT}
    shadow = tape[1]::RT
    pullback = tape[2]

    tangents = pullback(shadow)
    targets = (func, prob, sensealg, u0, p, args...)

    for (tangent, target) in zip(tangents, targets)
        # Constant arguments carry no shadow to update, and a `NoTangent` has nothing
        # to contribute; the `Const` check is evaluated first.
        (target isa Enzyme.Const || tangent == ChainRulesCore.NoTangent()) && continue
        target.dval .+= tangent
    end

    # Clear the shadow's `u` entries in place once the tangents have been accumulated.
    foreach(block -> (block .= 0), shadow.u)

    return ntuple(_unused -> nothing, Val(length(args) + 4))
end
end