Skip to content

Commit

Permalink
Add forward mode to line search
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Jun 14, 2024
1 parent a74c321 commit b7c54f3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 26 deletions.
82 changes: 60 additions & 22 deletions src/globalization/line_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ end
ϕdϕ
method
alpha
grad_op
deriv_op
u_cache
fu_cache
stats::NLStats
Expand All @@ -110,25 +110,59 @@ function __internal_init(
@warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
Detected $(autodiff). Falling back to AutoFiniteDiff."
end
grad_op = @closure (u, fu, p) -> last(__value_derivative(
autodiff, Base.Fix2(f, p), u)) * fu
deriv_op = @closure (du, u, fu, p) -> last(__value_derivative(
autodiff, Base.Fix2(f, p), u)) *
fu *
du
else
if SciMLBase.has_jvp(f)
# Both forward and reverse AD can be used for line search.
# We prefer forward AD for better performance; however, reverse AD is also supported if the user explicitly requests it.
# 0. If AutoFiniteDiff is requested, no derivative operator is built (`deriv_op = nothing`);
# 1. If jvp is available, we use forward AD;
# 2. If vjp is available, we use reverse AD;
# 3. If a reverse-mode AD type is requested, we use reverse AD;
# 4. Finally, we use forward AD.
if alg.autodiff isa AutoFiniteDiff
deriv_op = nothing
elseif SciMLBase.has_jvp(f)
if isinplace(prob)
g_cache = __similar(u)
grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p)
jvp_cache = __similar(fu)
deriv_op = @closure (du, u, fu, p) -> begin
f.jvp(jvp_cache, du, u, p)
dot(fu, jvp_cache)
end
else
grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p)
deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p))
end
else
elseif SciMLBase.has_vjp(f)
if isinplace(prob)
vjp_cache = __similar(u)
deriv_op = @closure (du, u, fu, p) -> begin
f.vjp(vjp_cache, fu, u, p)
dot(du, vjp_cache)
end
else
deriv_op = @closure (du, u, fu, p) -> dot(du, f.vjp(fu, u, p))
end
elseif alg.autodiff !== nothing &&
ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode
autodiff = get_concrete_reverse_ad(
alg.autodiff, prob; check_reverse_mode = true)
vjp_op = VecJacOperator(prob, fu, u; autodiff)
if isinplace(prob)
g_cache = __similar(u)
grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p)
vjp_cache = __similar(u)
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p))
else
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
end
else
autodiff = get_concrete_forward_ad(
alg.autodiff, prob; check_forward_mode = true)
jvp_op = JacVecOperator(prob, fu, u; autodiff)
if isinplace(prob)
jvp_cache = __similar(fu)
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
else
grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p)
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))
end
end
end
Expand All @@ -143,33 +177,37 @@ function __internal_init(
return @fastmath internalnorm(fu_cache)^2 / 2
end

dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
@bb @. u_cache = u + α * du
fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
stats.nf += 1
g₀ = grad_op(u_cache, fu_cache, p)
return dot(g₀, du)
return deriv_op(du, u_cache, fu_cache, p)
end

ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
@bb @. u_cache = u + α * du
fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
stats.nf += 1
g₀ = grad_op(u_cache, fu_cache, p)
deriv = deriv_op(du, u_cache, fu_cache, p)
obj = @fastmath internalnorm(fu_cache)^2 / 2
return obj, dot(g₀, du)
return obj, deriv
end

return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha),
grad_op, u_cache, fu_cache, stats)
deriv_op, u_cache, fu_cache, stats)
end

function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...)
ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache)
dϕ = @closure α -> cache.dϕ(
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
ϕdϕ = @closure α -> cache.ϕdϕ(
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
if cache.deriv_op !== nothing
dϕ = @closure α -> cache.dϕ(
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
ϕdϕ = @closure α -> cache.ϕdϕ(
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
else
dϕ = @closure α -> FiniteDiff.finite_difference_derivative(ϕ, α)
ϕdϕ = @closure α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α))
end

ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))

Expand Down
8 changes: 4 additions & 4 deletions test/core/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
@testitem "NewtonRaphson" setup=[CoreRootfindTesting] tags=[:core] timeout=3600 begin
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
ad in (AutoFiniteDiff(), AutoZygote())
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff())

linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
Expand Down Expand Up @@ -466,7 +466,7 @@ end
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian) Update Rule: $(update_rule)" for lsmethod in (
Static(), StrongWolfe(), BackTracking(),
HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
ad in (AutoFiniteDiff(), AutoZygote()),
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()),
init_jacobian in (Val(:identity), Val(:true_jacobian)),
update_rule in (Val(:good_broyden), Val(:bad_broyden), Val(:diagonal))

Expand Down Expand Up @@ -515,7 +515,7 @@ end
@testitem "Klement" setup=[CoreRootfindTesting] tags=[:core] skip=:(Sys.isapple()) timeout=3600 begin
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian)" for lsmethod in (
Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
ad in (AutoFiniteDiff(), AutoZygote()),
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()),
init_jacobian in (Val(:identity), Val(:true_jacobian), Val(:true_jacobian_diagonal))

linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
Expand Down Expand Up @@ -565,7 +565,7 @@ end
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
Static(), StrongWolfe(), BackTracking(),
HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
ad in (AutoFiniteDiff(), AutoZygote())
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff())

linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
Expand Down

0 comments on commit b7c54f3

Please sign in to comment.