diff --git a/src/globalization/line_search.jl b/src/globalization/line_search.jl index e09a8d188..4ca263ca2 100644 --- a/src/globalization/line_search.jl +++ b/src/globalization/line_search.jl @@ -90,7 +90,7 @@ end ϕdϕ method alpha - grad_op + deriv_op u_cache fu_cache stats::NLStats @@ -110,25 +110,39 @@ function __internal_init( @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \ Detected $(autodiff). Falling back to AutoFiniteDiff." end - grad_op = @closure (u, fu, p) -> last(__value_derivative( - autodiff, Base.Fix2(f, p), u)) * fu + deriv_op = @closure (du, u, fu, p) -> last(__value_derivative( + autodiff, Base.Fix2(f, p), u)) * fu * du else + # Both forward- and reverse-mode AD can be used for the line search. + # We prefer forward-mode AD for better performance; reverse-mode AD is also supported if the user explicitly requests it. + # 1. If the function provides a jvp, we use it (forward mode); + # 2. Otherwise, if a reverse-mode AD type is requested, we use reverse-mode AD; + # 3. Otherwise, we use forward-mode AD. if SciMLBase.has_jvp(f) if isinplace(prob) - g_cache = __similar(u) - grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p) + jvp_cache = __similar(fu) + deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(jvp_cache, du, u, p)) else - grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p) + deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p)) end - else + elseif alg.autodiff !== nothing && ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode autodiff = get_concrete_reverse_ad( alg.autodiff, prob; check_reverse_mode = true) vjp_op = VecJacOperator(prob, fu, u; autodiff) if isinplace(prob) g_cache = __similar(u) - grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p) + deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(g_cache, fu, u, p)) + else + deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p)) + end + else + autodiff = get_concrete_forward_ad(alg.autodiff, prob; check_forward_mode = true) + jvp_op = JacVecOperator(prob, fu, u; autodiff) + if isinplace(prob) 
+ jvp_cache = __similar(fu) + deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p)) else - grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p) + deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p)) end end end @@ -143,33 +157,32 @@ function __internal_init( return @fastmath internalnorm(fu_cache)^2 / 2 end - dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin + dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin @bb @. u_cache = u + α * du fu_cache = evaluate_f!!(f, fu_cache, u_cache, p) stats.nf += 1 - g₀ = grad_op(u_cache, fu_cache, p) - return dot(g₀, du) + return deriv_op(du, u_cache, fu_cache, p) end - ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin + ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin @bb @. u_cache = u + α * du fu_cache = evaluate_f!!(f, fu_cache, u_cache, p) stats.nf += 1 - g₀ = grad_op(u_cache, fu_cache, p) + deriv = deriv_op(du, u_cache, fu_cache, p) obj = @fastmath internalnorm(fu_cache)^2 / 2 - return obj, dot(g₀, du) + return obj, deriv end return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha), - grad_op, u_cache, fu_cache, stats) + deriv_op, u_cache, fu_cache, stats) end function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...) 
ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache) dϕ = @closure α -> cache.dϕ( - cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op) + cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op) ϕdϕ = @closure α -> cache.ϕdϕ( - cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op) + cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op) ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u))) diff --git a/test/core/nlls_tests.jl b/test/core/nlls_tests.jl index 483107f69..f368ec897 100644 --- a/test/core/nlls_tests.jl +++ b/test/core/nlls_tests.jl @@ -27,12 +27,12 @@ const θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1 solvers = [] for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES(), KrylovJL_LSMR()] - vjp_autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoFiniteDiff()] : + autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoForwardDiff()] : [nothing] for linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()], - vjp_autodiff in vjp_autodiffs + autodiff in autodiffs - push!(solvers, GaussNewton(; linsolve, linesearch, vjp_autodiff)) + push!(solvers, GaussNewton(; linsolve, linesearch, autodiff)) end end append!(solvers, diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 3ed50a2c7..9d93e9680 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -55,7 +55,7 @@ end @testitem "NewtonRaphson" setup=[CoreRootfindTesting] tags=[:core] timeout=3600 begin @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in ( Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()), - ad in (AutoFiniteDiff(), AutoZygote()) + ad in (AutoForwardDiff(), AutoZygote()) linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad) u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) @@ -466,7 +466,7 @@ end @testset "LineSearch: 
$(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian) Update Rule: $(update_rule)" for lsmethod in ( Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente(), LiFukushimaLineSearch()), - ad in (AutoFiniteDiff(), AutoZygote()), + ad in (AutoForwardDiff(), AutoZygote()), init_jacobian in (Val(:identity), Val(:true_jacobian)), update_rule in (Val(:good_broyden), Val(:bad_broyden), Val(:diagonal)) @@ -515,7 +515,7 @@ end @testitem "Klement" setup=[CoreRootfindTesting] tags=[:core] skip=:(Sys.isapple()) timeout=3600 begin @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian)" for lsmethod in ( Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()), - ad in (AutoFiniteDiff(), AutoZygote()), + ad in (AutoForwardDiff(), AutoZygote()), init_jacobian in (Val(:identity), Val(:true_jacobian), Val(:true_jacobian_diagonal)) linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad) @@ -565,7 +565,7 @@ end @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in ( Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente(), LiFukushimaLineSearch()), - ad in (AutoFiniteDiff(), AutoZygote()) + ad in (AutoForwardDiff(), AutoZygote()) linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad) u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)