From b7c54f3844bb413ac0a35e2da9628b04e803d9a2 Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Fri, 7 Jun 2024 11:59:44 -0400 Subject: [PATCH] Add forward mode to line search --- src/globalization/line_search.jl | 82 +++++++++++++++++++++++--------- test/core/rootfind_tests.jl | 8 ++-- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/src/globalization/line_search.jl b/src/globalization/line_search.jl index e09a8d188..4fca295e7 100644 --- a/src/globalization/line_search.jl +++ b/src/globalization/line_search.jl @@ -90,7 +90,7 @@ end ϕdϕ method alpha - grad_op + deriv_op u_cache fu_cache stats::NLStats @@ -110,25 +110,59 @@ function __internal_init( @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \ Detected $(autodiff). Falling back to AutoFiniteDiff." end - grad_op = @closure (u, fu, p) -> last(__value_derivative( - autodiff, Base.Fix2(f, p), u)) * fu + deriv_op = @closure (du, u, fu, p) -> last(__value_derivative( + autodiff, Base.Fix2(f, p), u)) * + fu * + du else - if SciMLBase.has_jvp(f) + # Both forward and reverse AD can be used for line-search. + # We prefer forward AD for better performance, however, reverse AD is also supported if user explicitly requests it. + # 1. If jvp is available, we use forward AD; + # 2. If vjp is available, we use reverse AD; + # 3. If reverse type is requested, we use reverse AD; + # 4. Finally, we use forward AD. + if alg.autodiff isa AutoFiniteDiff + deriv_op = nothing + elseif SciMLBase.has_jvp(f) if isinplace(prob) - g_cache = __similar(u) - grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p) + jvp_cache = __similar(fu) + deriv_op = @closure (du, u, fu, p) -> begin + f.jvp(jvp_cache, du, u, p) + dot(fu, jvp_cache) + end else - grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p) + deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p)) end - else + elseif SciMLBase.has_vjp(f) + if isinplace(prob) + vjp_cache = __similar(u) + deriv_op = @closure (du, u, fu, p) -> begin + f.vjp(vjp_cache, fu, u, p) + dot(du, vjp_cache) + end + else + deriv_op = @closure (du, u, fu, p) -> dot(du, f.vjp(fu, u, p)) + end + elseif alg.autodiff !== nothing && + ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode autodiff = get_concrete_reverse_ad( alg.autodiff, prob; check_reverse_mode = true) vjp_op = VecJacOperator(prob, fu, u; autodiff) if isinplace(prob) - g_cache = __similar(u) - grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p) + vjp_cache = __similar(u) + deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p)) + else + deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p)) + end + else + autodiff = get_concrete_forward_ad( + alg.autodiff, prob; check_forward_mode = true) + jvp_op = JacVecOperator(prob, fu, u; autodiff) + if isinplace(prob) + jvp_cache = __similar(fu) + deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p)) else - grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p) + deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p)) end end end @@ -143,33 +177,37 @@ function __internal_init( return @fastmath internalnorm(fu_cache)^2 / 2 end - dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin + dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin @bb @. u_cache = u + α * du fu_cache = evaluate_f!!(f, fu_cache, u_cache, p) stats.nf += 1 - g₀ = grad_op(u_cache, fu_cache, p) - return dot(g₀, du) + return deriv_op(du, u_cache, fu_cache, p) end - ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin + ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin @bb @. u_cache = u + α * du fu_cache = evaluate_f!!(f, fu_cache, u_cache, p) stats.nf += 1 - g₀ = grad_op(u_cache, fu_cache, p) + deriv = deriv_op(du, u_cache, fu_cache, p) obj = @fastmath internalnorm(fu_cache)^2 / 2 - return obj, dot(g₀, du) + return obj, deriv end return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha), - grad_op, u_cache, fu_cache, stats) + deriv_op, u_cache, fu_cache, stats) end function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...) ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache) - dϕ = @closure α -> cache.dϕ( - cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op) - ϕdϕ = @closure α -> cache.ϕdϕ( - cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op) + if cache.deriv_op !== nothing + dϕ = @closure α -> cache.dϕ( + cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op) + ϕdϕ = @closure α -> cache.ϕdϕ( + cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op) + else + dϕ = @closure α -> FiniteDiff.finite_difference_derivative(ϕ, α) + ϕdϕ = @closure α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α)) + end ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u))) diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 3ed50a2c7..880b34ff1 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -55,7 +55,7 @@ end @testitem "NewtonRaphson" setup=[CoreRootfindTesting] tags=[:core] timeout=3600 begin @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in ( Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()), - ad in (AutoFiniteDiff(), AutoZygote()) + ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()) linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad) u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) @@ -466,7 +466,7 @@ end @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian) Update Rule: $(update_rule)" for lsmethod in ( Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente(), LiFukushimaLineSearch()), - ad in (AutoFiniteDiff(), AutoZygote()), + ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()), init_jacobian in (Val(:identity), Val(:true_jacobian)), update_rule in (Val(:good_broyden), Val(:bad_broyden), Val(:diagonal)) @@ -515,7 +515,7 @@ end @testitem "Klement" setup=[CoreRootfindTesting] tags=[:core] skip=:(Sys.isapple()) timeout=3600 begin @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian)" for lsmethod in ( Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()), - ad in (AutoFiniteDiff(), AutoZygote()), + ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()), init_jacobian in (Val(:identity), Val(:true_jacobian), Val(:true_jacobian_diagonal)) linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad) @@ -565,7 +565,7 @@ end @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in ( Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente(), LiFukushimaLineSearch()), - ad in (AutoFiniteDiff(), AutoZygote()) + ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()) linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad) u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)