-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathautodiff.jl
32 lines (30 loc) · 1.2 KB
/
autodiff.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
    OnceDifferentiable(f!, x, F, autodiff = :central, chunk = ForwardDiff.Chunk(x))

Build a `OnceDifferentiable` for the in-place residual function `f!(F, x)`, generating
its Jacobian automatically instead of requiring a hand-written one.

# Arguments
- `f!`: in-place function that writes the residual at `x` into its first argument.
- `x`: seed array; fixes the shape and element type of the input.
- `F`: residual template; fixes the shape and element type of `f!`'s output. All
  residual-shaped work buffers are allocated from `F`, so it need not match `x`.
- `autodiff`: `:central` (or `false`) to use finite differences from
  `DiffEqDiffTools`; `:forward` (or `true`) to use forward-mode `ForwardDiff`.
- `chunk`: `ForwardDiff` chunk setting; only used by the `:forward` backend.

# Throws
- `ErrorException` (via `error`) for any other `autodiff` value.
"""
function OnceDifferentiable(f!, x::AbstractArray, F::AbstractArray,
                            autodiff::Union{Symbol, Bool} = :central,
                            chunk = ForwardDiff.Chunk(x))
    # NOTE: the two branches deliberately use different closure names. Same-named local
    # functions defined in different branches of one scope are not independent in Julia
    # (JuliaLang/julia#15602), so reusing `j!`/`fj!` in both branches would clash.
    if autodiff == :central || autodiff == false
        # One input-shaped buffer plus two residual-shaped buffers that the
        # finite-difference routine evaluates `f!` into; the latter must follow `F`.
        central_cache = DiffEqDiffTools.JacobianCache(similar(x), similar(F), similar(F))
        function fj!(Fout, Jout, xin)
            f!(Fout, xin)
            DiffEqDiffTools.finite_difference_jacobian!(Jout, f!, xin, central_cache)
            return Fout
        end
        function j!(Jout, xin)
            # Fresh scratch residual per call. `F` is only read here, never reassigned,
            # so the captured argument is not boxed.
            fj!(similar(F), Jout, xin)
            return Jout
        end
        return OnceDifferentiable(f!, j!, fj!, x, F)
    elseif autodiff == :forward || autodiff == true
        # JacobianConfig takes the output (residual) template before the input.
        jac_cfg = ForwardDiff.JacobianConfig(f!, F, x, chunk)
        ForwardDiff.checktag(jac_cfg, f!, x)
        # Residual buffer shared by both closures; `g!` has no caller-supplied output.
        F2 = copy(F)
        function g!(Jout, xin)
            return ForwardDiff.jacobian!(Jout, f!, F2, xin, jac_cfg, Val{false}())
        end
        function fg!(Fout, Jout, xin)
            # DiffResult makes jacobian! write the value into `Fout`, the Jacobian into `Jout`.
            jac_res = DiffBase.DiffResult(Fout, Jout)
            ForwardDiff.jacobian!(jac_res, f!, F2, xin, jac_cfg, Val{false}())
            return DiffBase.value(jac_res)
        end
        return OnceDifferentiable(f!, g!, fg!, x, F)
    else
        error("The autodiff value $(autodiff) is not supported. Use :central or :forward.")
    end
end