Skip to content

Commit

Permalink
Short circuit linesearch for now
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 28, 2023
1 parent 411a649 commit 756ad8c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
26 changes: 18 additions & 8 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ differentiation for fast Vector Jacobian Products.
### Arguments
- `method`: the line search algorithm to use. Defaults to `Static()`, which means that the
- `method`: the line search algorithm to use. Defaults to `nothing`, which means that the
step size is fixed to the value of `alpha`.
- `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
`AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
Expand All @@ -22,19 +22,31 @@ differentiation for fast Vector Jacobian Products.
α
end

function LineSearch(; method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
function LineSearch(; method = nothing, autodiff = AutoFiniteDiff(), alpha = true)
return LineSearch(method, autodiff, alpha)
end

@inline function init_linesearch_cache(ls::LineSearch, args...)
return init_linesearch_cache(ls.method, ls, args...)
end

@concrete struct NoLineSearchCache
α
end

function init_linesearch_cache(::Nothing, ls, f::F, u, p, fu, iip) where {F}
return NoLineSearchCache(convert(eltype(u), ls.α))
end

perform_linesearch!(cache::NoLineSearchCache, u, du) = cache.α

# LineSearches.jl doesn't have a supertype so default to that
function init_linesearch_cache(_, ls, f::F, u, p, fu, iip) where {F <: Function}
function init_linesearch_cache(_, ls, f::F, u, p, fu, iip) where {F}
return LineSearchesJLCache(ls, f, u, p, fu, iip)

Check warning on line 45 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L44-L45

Added lines #L44 - L45 were not covered by tests
end

# FIXME: The closures lead to too many unnecessary runtime dispatches which leads to the
# massive increase in precompilation times.
# Wrapper over LineSearches.jl algorithms
@concrete mutable struct LineSearchesJLCache
f
Expand All @@ -45,8 +57,7 @@ end
ls
end

function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _,
::Val{false}) where {F <: Function}
function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _, ::Val{false}) where {F}

Check warning on line 60 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L60

Added line #L60 was not covered by tests
eval_f(u, du, α) = eval_f(u - α * du)
eval_f(u) = f(u, p)

Expand Down Expand Up @@ -87,8 +98,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u::Number, p, _,
return LineSearchesJLCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
end

function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1,
IIP::Val{iip}) where {iip, F <: Function}
function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) where {iip, F}

Check warning on line 101 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L101

Added line #L101 was not covered by tests
fu = iip ? deepcopy(fu1) : nothing
u_ = _mutable_zero(u)

Expand Down Expand Up @@ -202,7 +212,7 @@ end
end

function init_linesearch_cache(alg::LiFukushimaLineSearch, ls::LineSearch, f::F, _u, p, _fu,

Check warning on line 214 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L214

Added line #L214 was not covered by tests
::Val{iip}) where {iip, F <: Function}
::Val{iip}) where {iip, F}
fu = iip ? deepcopy(_fu) : nothing
u = iip ? deepcopy(_u) : nothing
return LiFukushimaLineSearchCache{iip}(f, p, u, fu, alg, ls.α)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ function __get_concrete_algorithm(alg, prob)
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
else
tag = NonlinearSolveTag()
use_sparse_ad ? AutoSparseForwardDiff(; tag) : AutoForwardDiff(; tag)
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(; tag)
end
return set_ad(alg, ad)
end
Expand Down

0 comments on commit 756ad8c

Please sign in to comment.