This repository has been archived by the owner on Oct 31, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 26
/
ad.jl
78 lines (72 loc) · 2.98 KB
/
ad.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Forward-mode AD through a nonlinear solve via the implicit function theorem.
#
# The problem is rebuilt on the primal values returned by `value`, solved with `alg`,
# and the sensitivity of the root `u*` to the parameters is recovered as
#     du*/dp = -(∂f/∂p) / (∂f/∂u)    evaluated at (u*, p),
# contracted with the partials carried by each entry of `prob.p`.
#
# Returns `(sol, partials)`: the primal solution and the summed partials, which the
# `solve` methods in this file pair with `sol.u` to build a `Dual`.
#
# NOTE(review): ∂f/∂u is taken with `ForwardDiff.derivative`, which expects a scalar
# `u`; `SVector` problems admitted by the callers' signatures likely fail here — confirm.
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
    fun = prob.f
    pval = value(prob.p)

    # Same kind of problem as the input, but carrying only primal values.
    primal_prob = if prob isa IntervalNonlinearProblem
        IntervalNonlinearProblem(fun, value(prob.tspan), pval; prob.kwargs...)
    else
        NonlinearProblem(fun, value(prob.u0), pval; prob.kwargs...)
    end
    sol = solve(primal_prob, alg, args...; kwargs...)
    ustar = sol.u

    # ∂f/∂p at the root: `derivative` for a scalar p, `gradient` for an array p.
    f_of_p = Base.Fix1(fun, ustar)
    dfdp = if pval isa Number
        ForwardDiff.derivative(f_of_p, pval)
    else
        ForwardDiff.gradient(f_of_p, pval)
    end
    # ∂f/∂u at the root, negated once up front for the implicit-function formula.
    neg_dfdu = -ForwardDiff.derivative(Base.Fix2(fun, pval), ustar)

    partials = sum(zip(dfdp, prob.p)) do (dfdp_i, pdual_i)
        (dfdp_i / neg_dfdu) * ForwardDiff.partials(pdual_i)
    end
    return sol, partials
end
# Dual-number support for a `NonlinearProblem` whose parameter `p` is a single `Dual`:
# solve on primal values, then re-attach derivative information to the root.
# NOTE(review): the signature admits `SVector` `u0`, but `Dual{T, V, P}(u, ...)` below
# expects a scalar primal value — `SVector` problems likely fail here; confirm.
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
        iip, <:Dual{T, V, P}},
    alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
    kwargs...) where {iip, T, V, P}
    primal, dpart = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
    u_dual = Dual{T, V, P}(primal.u, dpart)
    return SciMLBase.build_solution(prob, alg, u_dual, primal.resid;
        retcode = primal.retcode)
end
# Dual-number support for a `NonlinearProblem` whose parameter `p` is an array of
# `Dual`s: solve on primal values, then re-attach derivative information to the root.
# NOTE(review): the signature admits `SVector` `u0`, but `Dual{T, V, P}(u, ...)` below
# expects a scalar primal value — `SVector` problems likely fail here; confirm.
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
        iip, <:AbstractArray{<:Dual{T, V, P}}},
    alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
    kwargs...) where {iip, T, V, P}
    primal, dpart = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
    u_dual = Dual{T, V, P}(primal.u, dpart)
    return SciMLBase.build_solution(prob, alg, u_dual, primal.resid;
        retcode = primal.retcode)
end
# Shared body of the `IntervalNonlinearProblem` methods generated below. It solves on
# primal values via `scalar_nlsolve_ad`, then attaches the resulting parameter partials
# to the root and reuses the same partials for both bracket endpoints.
# `D` is the concrete `Dual{T, V, P}` type taken from the incoming parameters.
function _interval_nlsolve_ad(prob, alg, ::Type{D}, args...;
    kwargs...) where {D <: Dual}
    sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
    return SciMLBase.build_solution(prob, alg, D(sol.u, partials), sol.resid;
        retcode = sol.retcode,
        left = D(sol.left, partials),
        right = D(sol.right, partials))
end

# One method pair per bracketing algorithm, to avoid method ambiguities (presumably
# with algorithm-specific `solve` methods defined elsewhere — confirm before widening
# the algorithm type here).
for Alg in (Bisection,)
    # Scalar `Dual` parameter.
    @eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
                <:Dual{T, V, P}},
            alg::$Alg, args...;
            kwargs...) where {uType, iip, T, V, P}
        return _interval_nlsolve_ad(prob, alg, Dual{T, V, P}, args...; kwargs...)
    end
    # Array of `Dual` parameters.
    @eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
                <:AbstractArray{<:Dual{T, V, P}}},
            alg::$Alg, args...;
            kwargs...) where {uType, iip, T, V, P}
        return _interval_nlsolve_ad(prob, alg, Dual{T, V, P}, args...; kwargs...)
    end
end