Add forward mode to line search #446

Merged: 1 commit on Jun 14, 2024
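This PR switches the line-search derivative from a gradient operator to a directional-derivative operator. A line search only needs the slope of the merit function ϕ(α) = ‖f(u + α·du)‖²/2 along the fixed direction du, and that slope is ϕ′(α) = ⟨f(uα), J(uα)·du⟩, so a single forward-mode JVP suffices; the old grad_op computed the full gradient J(uα)ᵀ·f(uα), which forces reverse mode (a VJP). The new deriv_op is selected in order: user-supplied jvp, then user-supplied vjp, then reverse AD when explicitly requested, then forward AD; with AutoFiniteDiff the slope is instead finite-differenced from ϕ itself. (The previous code also checked SciMLBase.has_jvp(f) but then called f.vjp; the rewrite untangles that.)

A standalone sketch of the identity behind the change (the residual f, point u, and direction du here are invented for illustration, not taken from the PR):

using ForwardDiff, LinearAlgebra

f(u) = [u[1]^2 + u[2] - 3.0, u[1] - u[2]^2]    # toy residual
u, du = [1.0, 2.0], [-0.5, 0.25]
ϕ(α) = norm(f(u .+ α .* du))^2 / 2             # line-search merit function

α = 0.1
uα = u .+ α .* du
jvp = ForwardDiff.derivative(t -> f(uα .+ t .* du), 0.0)   # J(uα) * du
@assert dot(f(uα), jvp) ≈ ForwardDiff.derivative(ϕ, α)     # ϕ′(α) = ⟨f, J du⟩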
82 changes: 60 additions & 22 deletions src/globalization/line_search.jl
@@ -90,7 +90,7 @@ end
     ϕdϕ
     method
     alpha
-    grad_op
+    deriv_op
     u_cache
     fu_cache
     stats::NLStats
@@ -110,25 +110,59 @@ function __internal_init(
            @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
                   Detected $(autodiff). Falling back to AutoFiniteDiff."
        end
-        grad_op = @closure (u, fu, p) -> last(__value_derivative(
-            autodiff, Base.Fix2(f, p), u)) * fu
+        deriv_op = @closure (du, u, fu, p) -> last(__value_derivative(
+            autodiff, Base.Fix2(f, p), u)) *
+                                              fu *
+                                              du
     else
-        if SciMLBase.has_jvp(f)
+        # Both forward and reverse AD can be used for line-search.
+        # We prefer forward AD for better performance, however, reverse AD is also supported if user explicitly requests it.
+        # 1. If jvp is available, we use forward AD;
+        # 2. If vjp is available, we use reverse AD;
+        # 3. If reverse type is requested, we use reverse AD;
+        # 4. Finally, we use forward AD.
+        if alg.autodiff isa AutoFiniteDiff
+            deriv_op = nothing
+        elseif SciMLBase.has_jvp(f)
             if isinplace(prob)
-                g_cache = __similar(u)
-                grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p)
+                jvp_cache = __similar(fu)
+                deriv_op = @closure (du, u, fu, p) -> begin
+                    f.jvp(jvp_cache, du, u, p)
+                    dot(fu, jvp_cache)
+                end
             else
-                grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p))
             end
-        else
+        elseif SciMLBase.has_vjp(f)
+            if isinplace(prob)
+                vjp_cache = __similar(u)
+                deriv_op = @closure (du, u, fu, p) -> begin
+                    f.vjp(vjp_cache, fu, u, p)
+                    dot(du, vjp_cache)
+                end
+            else
+                deriv_op = @closure (du, u, fu, p) -> dot(du, f.vjp(fu, u, p))
+            end
+        elseif alg.autodiff !== nothing &&
+               ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode
             autodiff = get_concrete_reverse_ad(
                 alg.autodiff, prob; check_reverse_mode = true)
             vjp_op = VecJacOperator(prob, fu, u; autodiff)
             if isinplace(prob)
-                g_cache = __similar(u)
-                grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p)
+                vjp_cache = __similar(u)
+                deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p))
             else
-                grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
+            end
+        else
+            autodiff = get_concrete_forward_ad(
+                alg.autodiff, prob; check_forward_mode = true)
+            jvp_op = JacVecOperator(prob, fu, u; autodiff)
+            if isinplace(prob)
+                jvp_cache = __similar(fu)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
+            else
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))
             end
         end
     end
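For context, a hypothetical call that would exercise branch 1 above. The residual and its analytic JVP are invented for illustration; the out-of-place signature jvp(v, u, p) is inferred from the f.jvp(du, u, p) call in the diff:

using NonlinearSolve, LineSearches

fo(u, p) = u .^ 2 .- p                  # residual: roots at ±sqrt(p)
fo_jvp(v, u, p) = 2 .* u .* v           # J(u) * v, since J = diagm(2u)
prob = NonlinearProblem(NonlinearFunction(fo; jvp = fo_jvp), [1.2, 1.5], 2.0)
solve(prob, NewtonRaphson(; linesearch = LineSearchesJL(; method = BackTracking())))

With jvp supplied, the line search uses it directly and never builds an AD operator.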
@@ -143,33 +177,37 @@ function __internal_init(
         return @fastmath internalnorm(fu_cache)^2 / 2
     end

-    dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
+    dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
         @bb @. u_cache = u + α * du
         fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
         stats.nf += 1
-        g₀ = grad_op(u_cache, fu_cache, p)
-        return dot(g₀, du)
+        return deriv_op(du, u_cache, fu_cache, p)
     end

-    ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
+    ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
         @bb @. u_cache = u + α * du
         fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
         stats.nf += 1
-        g₀ = grad_op(u_cache, fu_cache, p)
+        deriv = deriv_op(du, u_cache, fu_cache, p)
         obj = @fastmath internalnorm(fu_cache)^2 / 2
-        return obj, dot(g₀, du)
+        return obj, deriv
     end

     return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha),
-        grad_op, u_cache, fu_cache, stats)
+        deriv_op, u_cache, fu_cache, stats)
 end

 function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...)
     ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache)
-    dϕ = @closure α -> cache.dϕ(
-        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
-    ϕdϕ = @closure α -> cache.ϕdϕ(
-        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
+    if cache.deriv_op !== nothing
+        dϕ = @closure α -> cache.dϕ(
+            cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
+        ϕdϕ = @closure α -> cache.ϕdϕ(
+            cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
+    else
+        dϕ = @closure α -> FiniteDiff.finite_difference_derivative(ϕ, α)
+        ϕdϕ = @closure α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α))
+    end

     ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))
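When deriv_op is nothing (the AutoFiniteDiff request handled above), the slope falls back to differencing the scalar merit function, costing extra residual evaluations per trial α instead of one JVP/VJP. A minimal standalone illustration of that fallback (f, u, and du are invented):

using FiniteDiff, LinearAlgebra

f(u) = [u[1]^2 - 2.0, u[1] * u[2] - 1.0]
u, du = [1.0, 1.0], [0.5, -0.25]
ϕ(α) = norm(f(u .+ α .* du))^2 / 2

dϕ₀ = FiniteDiff.finite_difference_derivative(ϕ, 0.0)   # ≈ ϕ′(0)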

8 changes: 4 additions & 4 deletions test/core/rootfind_tests.jl
@@ -55,7 +55,7 @@ end
 @testitem "NewtonRaphson" setup=[CoreRootfindTesting] tags=[:core] timeout=3600 begin
     @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
             Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
-        ad in (AutoFiniteDiff(), AutoZygote())
+        ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff())

         linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
         u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
@@ -466,7 +466,7 @@ end
     @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian) Update Rule: $(update_rule)" for lsmethod in (
             Static(), StrongWolfe(), BackTracking(),
             HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
-        ad in (AutoFiniteDiff(), AutoZygote()),
+        ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()),
         init_jacobian in (Val(:identity), Val(:true_jacobian)),
         update_rule in (Val(:good_broyden), Val(:bad_broyden), Val(:diagonal))
@@ -515,7 +515,7 @@ end
 @testitem "Klement" setup=[CoreRootfindTesting] tags=[:core] skip=:(Sys.isapple()) timeout=3600 begin
     @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian)" for lsmethod in (
             Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
-        ad in (AutoFiniteDiff(), AutoZygote()),
+        ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()),
         init_jacobian in (Val(:identity), Val(:true_jacobian), Val(:true_jacobian_diagonal))

         linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
@@ -565,7 +565,7 @@ end
     @testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
             Static(), StrongWolfe(), BackTracking(),
             HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
-        ad in (AutoFiniteDiff(), AutoZygote())
+        ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff())

         linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
         u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)