Commit 82944d0
Add forward mode to line search
tansongchen committed Jun 7, 2024
1 parent a74c321 commit 82944d0
Showing 3 changed files with 38 additions and 25 deletions.
49 changes: 31 additions & 18 deletions src/globalization/line_search.jl
@@ -90,7 +90,7 @@ end
     ϕdϕ
     method
     alpha
-    grad_op
+    deriv_op
     u_cache
     fu_cache
     stats::NLStats
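Note: this is a signature change as well as a rename. grad_op(u, fu, p) produced the full gradient J(u)ᵀ fu, while deriv_op(du, u, fu, p) produces only the scalar directional derivative along the search direction du.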
@@ -110,25 +110,39 @@ function __internal_init(
             @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
                    Detected $(autodiff). Falling back to AutoFiniteDiff."
         end
-        grad_op = @closure (u, fu, p) -> last(__value_derivative(
-            autodiff, Base.Fix2(f, p), u)) * fu
+        deriv_op = @closure (du, u, fu, p) -> last(__value_derivative(
+            autodiff, Base.Fix2(f, p), u)) * fu * du
     else
+        # Both forward and reverse AD can be used for the line search.
+        # We prefer forward AD for better performance, but reverse AD is also
+        # supported if the user explicitly requests it:
+        #   1. if a jvp is available, use forward AD;
+        #   2. if a reverse-mode AD type is requested, use reverse AD;
+        #   3. otherwise, use forward AD.
         if SciMLBase.has_jvp(f)
             if isinplace(prob)
-                g_cache = __similar(u)
-                grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p)
+                jvp_cache = __similar(fu)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(jvp_cache, du, u, p))
             else
-                grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p))
             end
-        else
+        elseif alg.autodiff !== nothing && ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode
             autodiff = get_concrete_reverse_ad(
                 alg.autodiff, prob; check_reverse_mode = true)
             vjp_op = VecJacOperator(prob, fu, u; autodiff)
             if isinplace(prob)
                 g_cache = __similar(u)
-                grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(g_cache, fu, u, p))
             else
-                grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
             end
+        else
+            autodiff = get_concrete_forward_ad(alg.autodiff, prob; check_forward_mode = true)
+            jvp_op = JacVecOperator(prob, fu, u; autodiff)
+            if isinplace(prob)
+                jvp_cache = __similar(fu)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
+            else
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))
+            end
         end
     end
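Why forward mode suffices here: the line-search objective is ϕ(α) = ‖f(u + α du)‖² / 2, so ϕ′(α) = ⟨f(u + α du), J(u + α du) du⟩, which is a single Jacobian-vector product contracted with the residual; the full gradient Jᵀ f that the old grad_op computed via a VJP is never needed. A minimal sketch checking that identity (the toy residual, point, and direction below are illustrative, not from the patch):

using ForwardDiff, LinearAlgebra

f(u) = [u[1]^2 + u[2] - 3, u[1] - u[2]^2]        # toy residual (assumed)
u, du = [1.0, 2.0], [0.5, -1.0]

ϕ(α) = sum(abs2, f(u .+ α .* du)) / 2            # line-search objective
jvp(α) = ForwardDiff.derivative(β -> f(u .+ β .* du), α)  # J(u + α du) * du
# ϕ'(α) equals ⟨f(u + α du), J(u + α du) du⟩:
@assert ForwardDiff.derivative(ϕ, 0.3) ≈ dot(f(u .+ 0.3 .* du), jvp(0.3))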
@@ -143,33 +143,32 @@ function __internal_init(
         return @fastmath internalnorm(fu_cache)^2 / 2
     end

-    dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
+    dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
         @bb @. u_cache = u + α * du
         fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
         stats.nf += 1
-        g₀ = grad_op(u_cache, fu_cache, p)
-        return dot(g₀, du)
+        return deriv_op(du, u_cache, fu_cache, p)
     end

-    ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
+    ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
         @bb @. u_cache = u + α * du
         fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
         stats.nf += 1
-        g₀ = grad_op(u_cache, fu_cache, p)
+        deriv = deriv_op(du, u_cache, fu_cache, p)
         obj = @fastmath internalnorm(fu_cache)^2 / 2
-        return obj, dot(g₀, du)
+        return obj, deriv
     end

     return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha),
-        grad_op, u_cache, fu_cache, stats)
+        deriv_op, u_cache, fu_cache, stats)
 end

 function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...)
     ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache)
     dϕ = @closure α -> cache.dϕ(
-        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
+        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
     ϕdϕ = @closure α -> cache.ϕdϕ(
-        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
+        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)

     ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))

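The α-closures built in __internal_solve! are what get handed to the LineSearches.jl method stored in alg.method. A hedged sketch of that interface on a 1-D toy objective, assuming BackTracking's ls(ϕ, α₀, ϕ(0), dϕ(0)) calling convention from LineSearches.jl (the objective is made up for illustration):

using LineSearches

ϕ(α) = (α - 1)^2 / 2       # toy merit function, minimized at α = 1
dϕ(α) = α - 1              # its derivative

ls = BackTracking()
# Assumed to return the accepted step length and the objective value there:
α, ϕα = ls(ϕ, 1.0, ϕ(0.0), dϕ(0.0))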
6 changes: 3 additions & 3 deletions test/core/nlls_tests.jl
@@ -27,12 +27,12 @@ const θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1

 solvers = []
 for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES(), KrylovJL_LSMR()]
-    vjp_autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoFiniteDiff()] :
-        [nothing]
+    autodiffs = linsolve isa KrylovJL ? [nothing, AutoZygote(), AutoForwardDiff()] :
+        [nothing]
     for linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()],
-        vjp_autodiff in vjp_autodiffs
+        autodiff in autodiffs

-        push!(solvers, GaussNewton(; linsolve, linesearch, vjp_autodiff))
+        push!(solvers, GaussNewton(; linsolve, linesearch, autodiff))
     end
 end
 append!(solvers,
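With the kwarg renamed from vjp_autodiff to autodiff, the NLLS test matrix now exercises forward mode (AutoForwardDiff) for the Krylov solvers. A sketch of one such configuration in use; the residual and data are invented for illustration, and only the GaussNewton(; linsolve, linesearch, autodiff) call mirrors the test:

using NonlinearSolve, LinearSolve, LineSearches

loss(θ, p) = θ .^ 2 .- p                          # toy residual (assumed)
prob = NonlinearLeastSquaresProblem(loss, [1.0, 1.5], [2.0, 3.0])
alg = GaussNewton(; linsolve = KrylovJL_GMRES(), linesearch = BackTracking(),
    autodiff = AutoForwardDiff())
sol = solve(prob, alg)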
8 changes: 4 additions & 4 deletions test/core/rootfind_tests.jl
@@ -55,7 +55,7 @@ end
 @testitem "NewtonRaphson" setup=[CoreRootfindTesting] tags=[:core] timeout=3600 begin
     @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
             Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
-        ad in (AutoFiniteDiff(), AutoZygote())
+        ad in (AutoForwardDiff(), AutoZygote())

         linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
         u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
@@ -466,7 +466,7 @@ end
     @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian) Update Rule: $(update_rule)" for lsmethod in (
             Static(), StrongWolfe(), BackTracking(),
             HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
-        ad in (AutoFiniteDiff(), AutoZygote()),
+        ad in (AutoForwardDiff(), AutoZygote()),
         init_jacobian in (Val(:identity), Val(:true_jacobian)),
         update_rule in (Val(:good_broyden), Val(:bad_broyden), Val(:diagonal))

@@ -515,7 +515,7 @@ end
 @testitem "Klement" setup=[CoreRootfindTesting] tags=[:core] skip=:(Sys.isapple()) timeout=3600 begin
     @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian)" for lsmethod in (
             Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
-        ad in (AutoFiniteDiff(), AutoZygote()),
+        ad in (AutoForwardDiff(), AutoZygote()),
         init_jacobian in (Val(:identity), Val(:true_jacobian), Val(:true_jacobian_diagonal))

         linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
@@ -565,7 +565,7 @@ end
     @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
             Static(), StrongWolfe(), BackTracking(),
             HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
-        ad in (AutoFiniteDiff(), AutoZygote())
+        ad in (AutoForwardDiff(), AutoZygote())

         linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
         u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
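The rootfinding tests make the same swap, so the line search is driven by forward-mode AD end to end. A sketch of the configuration under test, with an invented problem (only the LineSearchesJL and NewtonRaphson calls mirror the test file):

using NonlinearSolve, LineSearches

f(u, p) = u .* u .- p                             # toy system, roots at ±sqrt(p)
prob = NonlinearProblem(f, [1.0, 1.0], 2.0)
linesearch = LineSearchesJL(; method = BackTracking(), autodiff = AutoForwardDiff())
sol = solve(prob, NewtonRaphson(; linesearch))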
