From 756ad8cd8946f7de5069d036dbbc8f8cb16c9431 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 28 Oct 2023 15:22:13 -0400 Subject: [PATCH] Short circuit linesearch for now --- src/linesearch.jl | 26 ++++++++++++++++++-------- src/utils.jl | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/linesearch.jl b/src/linesearch.jl index 608bc6cff..0d37dc222 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -8,7 +8,7 @@ differentiation for fast Vector Jacobian Products. ### Arguments - - `method`: the line search algorithm to use. Defaults to `Static()`, which means that the + - `method`: the line search algorithm to use. Defaults to `nothing`, which means that the step size is fixed to the value of `alpha`. - `autodiff`: the automatic differentiation backend to use for the line search. Defaults to `AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP. @@ -22,7 +22,7 @@ differentiation for fast Vector Jacobian Products. α end -function LineSearch(; method = Static(), autodiff = AutoFiniteDiff(), alpha = true) +function LineSearch(; method = nothing, autodiff = AutoFiniteDiff(), alpha = true) return LineSearch(method, autodiff, alpha) end @@ -30,11 +30,23 @@ end return init_linesearch_cache(ls.method, ls, args...) end +@concrete struct NoLineSearchCache + α +end + +function init_linesearch_cache(::Nothing, ls, f::F, u, p, fu, iip) where {F} + return NoLineSearchCache(convert(eltype(u), ls.α)) +end + +perform_linesearch!(cache::NoLineSearchCache, u, du) = cache.α + # LineSearches.jl doesn't have a supertype so default to that -function init_linesearch_cache(_, ls, f::F, u, p, fu, iip) where {F <: Function} +function init_linesearch_cache(_, ls, f::F, u, p, fu, iip) where {F} return LineSearchesJLCache(ls, f, u, p, fu, iip) end +# FIXME: The closures lead to too many unnecessary runtime dispatches which leads to the +# massive increase in precompilation times. # Wrapper over LineSearches.jl algorithms @concrete mutable struct LineSearchesJLCache f @@ -45,8 +57,7 @@ end ls end -function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _, - ::Val{false}) where {F <: Function} +function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _, ::Val{false}) where {F} eval_f(u, du, α) = eval_f(u - α * du) eval_f(u) = f(u, p) @@ -87,8 +98,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _, return LineSearchesJLCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls) end -function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, - IIP::Val{iip}) where {iip, F <: Function} +function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) where {iip, F} fu = iip ? deepcopy(fu1) : nothing u_ = _mutable_zero(u) @@ -202,7 +212,7 @@ end end function init_linesearch_cache(alg::LiFukushimaLineSearch, ls::LineSearch, f::F, _u, p, _fu, - ::Val{iip}) where {iip, F <: Function} + ::Val{iip}) where {iip, F} fu = iip ? deepcopy(_fu) : nothing u = iip ? deepcopy(_u) : nothing return LiFukushimaLineSearchCache{iip}(f, p, u, fu, alg, ls.α) diff --git a/src/utils.jl b/src/utils.jl index 9349c2505..424163604 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -216,7 +216,7 @@ function __get_concrete_algorithm(alg, prob) use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff() else tag = NonlinearSolveTag() - use_sparse_ad ? AutoSparseForwardDiff(; tag) : AutoForwardDiff(; tag) + (use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(; tag) end return set_ad(alg, ad) end