Skip to content

Commit

Permalink
Integrate the linear solve better into the algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 25, 2024
1 parent dd4a2fb commit 7450590
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 18 deletions.
21 changes: 21 additions & 0 deletions src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,27 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
cache.descent_cache, J, cache.fu, cache.u; new_jacobian, cache.kwargs...)
end
end

if !descent_result.linsolve_success
if new_jacobian && cache.steps_since_last_reset == 0

Check warning on line 296 in src/core/approximate_jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/core/approximate_jacobian.jl#L296

Added line #L296 was not covered by tests
# Extremely pathological case. Jacobian was just reset and linear solve
# failed. Should ideally never happen in practice unless true jacobian init
# is used.
cache.retcode = LinearSolveFailureCode
cache.force_stop = true
return

Check warning on line 302 in src/core/approximate_jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/core/approximate_jacobian.jl#L300-L302

Added lines #L300 - L302 were not covered by tests
else
# Force a reinit because the problem is currently un-solvable
if !haskey(cache.kwargs, :verbose) || cache.kwargs[:verbose]
@warn "Linear Solve Failed but Jacobian Information is not current. \

Check warning on line 306 in src/core/approximate_jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/core/approximate_jacobian.jl#L305-L306

Added lines #L305 - L306 were not covered by tests
Retrying with reinitialized Approximate Jacobian."
end
cache.force_reinit = true
__step!(cache; recompute_jacobian = true)
return

Check warning on line 311 in src/core/approximate_jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/core/approximate_jacobian.jl#L309-L311

Added lines #L309 - L311 were not covered by tests
end
end

δu, descent_intermediates = descent_result.δu, descent_result.extras

if descent_result.success
Expand Down
21 changes: 21 additions & 0 deletions src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,27 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
cache.descent_cache, J, cache.fu, cache.u; new_jacobian, cache.kwargs...)
end
end

if !descent_result.linsolve_success
if new_jacobian

Check warning on line 235 in src/core/generalized_first_order.jl

View check run for this annotation

Codecov / codecov/patch

src/core/generalized_first_order.jl#L235

Added line #L235 was not covered by tests
# Jacobian Information is current and linear solve failed terminate the solve
cache.retcode = LinearSolveFailureCode
cache.force_stop = true
return

Check warning on line 239 in src/core/generalized_first_order.jl

View check run for this annotation

Codecov / codecov/patch

src/core/generalized_first_order.jl#L237-L239

Added lines #L237 - L239 were not covered by tests
else
# Jacobian Information is not current and linear solve failed, recompute
# Jacobian
if !haskey(cache.kwargs, :verbose) || cache.kwargs[:verbose]
@warn "Linear Solve Failed but Jacobian Information is not current. \

Check warning on line 244 in src/core/generalized_first_order.jl

View check run for this annotation

Codecov / codecov/patch

src/core/generalized_first_order.jl#L243-L244

Added lines #L243 - L244 were not covered by tests
Retrying with updated Jacobian."
end
# In the 2nd call the `new_jacobian` is guaranteed to be `true`.
cache.make_new_jacobian = true
__step!(cache; recompute_jacobian = true, kwargs...)
return

Check warning on line 250 in src/core/generalized_first_order.jl

View check run for this annotation

Codecov / codecov/patch

src/core/generalized_first_order.jl#L248-L250

Added lines #L248 - L250 were not covered by tests
end
end

δu, descent_intermediates = descent_result.δu, descent_result.extras

if descent_result.success
Expand Down
10 changes: 7 additions & 3 deletions src/descent/common.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
DescentResult(; δu = missing, u = missing, success::Bool = true,
linsolve_success::Bool = true, extras = (;))
Construct a `DescentResult` object.
Expand All @@ -9,6 +10,7 @@ Construct a `DescentResult` object.
- `u`: The new iterate. This is provided only for multi-step methods currently.
- `success`: Certain Descent Algorithms can reject a descent direction for example
[`GeodesicAcceleration`](@ref).
- `linsolve_success`: Whether the line search was successful.
- `extras`: A named tuple containing intermediates computed during the solve.
For example, [`GeodesicAcceleration`](@ref) returns `NamedTuple{(:v, :a)}` containing
the "velocity" and "acceleration" terms.
Expand All @@ -17,10 +19,12 @@ Construct a `DescentResult` object.
δu
u
success::Bool
linsolve_success::Bool
extras
end

function DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
function DescentResult(; δu = missing, u = missing, success::Bool = true,
linsolve_success::Bool = true, extras = (;))
@assert δu !== missing || u !== missing
return DescentResult(δu, u, success, extras)
return DescentResult(δu, u, success, linsolve_success, extras)
end
8 changes: 6 additions & 2 deletions src/descent/damped_newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,14 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu,
end

@static_timeit cache.timer "linear solve" begin
δu = cache.lincache(;
linres = cache.lincache(;
A, b, reuse_A_if_factorization = !new_jacobian && !recompute_A,
kwargs..., linu = _vec(δu))
δu = _restructure(get_du(cache, idx), δu)
δu = _restructure(get_du(cache, idx), linres.u)
if !linres.success
set_du!(cache, δu, idx)

Check warning on line 208 in src/descent/damped_newton.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/damped_newton.jl#L208

Added line #L208 was not covered by tests
return DescentResult(; δu, success = false, linsolve_success = false)
end
end

@bb @. δu *= -1
Expand Down
17 changes: 13 additions & 4 deletions src/descent/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,14 @@ function __internal_solve!(
@bb δu = J × vec(fu)
else
@static_timeit cache.timer "linear solve" begin
δu = cache.lincache(;
linres = cache.lincache(;
A = J, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
δu = _restructure(get_du(cache, idx), δu)
δu = _restructure(get_du(cache, idx), linres.u)
if !linres.success
set_du!(cache, δu, idx)

Check warning on line 89 in src/descent/newton.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/newton.jl#L89

Added line #L89 was not covered by tests
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
@bb @. δu *= -1
Expand All @@ -102,10 +106,15 @@ function __internal_solve!(
end
@bb cache.Jᵀfu_cache = transpose(J) × vec(fu)
@static_timeit cache.timer "linear solve" begin
δu = cache.lincache(; A = __maybe_symmetric(cache.JᵀJ_cache), b = cache.Jᵀfu_cache,
linres = cache.lincache(;
A = __maybe_symmetric(cache.JᵀJ_cache), b = cache.Jᵀfu_cache,
kwargs..., linu = _vec(δu), du = _vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
δu = _restructure(get_du(cache, idx), δu)
δu = _restructure(get_du(cache, idx), linres.u)
if !linres.success
set_du!(cache, δu, idx)

Check warning on line 115 in src/descent/newton.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/newton.jl#L115

Added line #L115 was not covered by tests
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
@bb @. δu *= -1
set_du!(cache, δu, idx)
Expand Down
8 changes: 6 additions & 2 deletions src/descent/steepest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,14 @@ function __internal_solve!(cache::SteepestDescentCache{INV}, J, fu, u, idx::Val
if INV
A = J === nothing ? nothing : transpose(J)
@static_timeit cache.timer "linear solve" begin
δu = cache.lincache(;
linres = cache.lincache(;

Check warning on line 57 in src/descent/steepest.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/steepest.jl#L57

Added line #L57 was not covered by tests
A, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu),
reuse_A_if_factorization = !new_jacobian || idx !== Val(1))
δu = _restructure(get_du(cache, idx), δu)
δu = _restructure(get_du(cache, idx), linres.u)
if !linres.success
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)

Check warning on line 63 in src/descent/steepest.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/steepest.jl#L60-L63

Added lines #L60 - L63 were not covered by tests
end
end
else
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(false)`."
Expand Down
18 changes: 11 additions & 7 deletions src/internal/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ function LinearSolverCache(alg, linsolve, A, b, u; kwargs...)
return LinearSolverCache(lincache, linsolve, nothing, nothing, nothing, precs, 0, 0)
end

@kwdef @concrete struct LinearSolveResult
u
success::Bool = true
end

# Direct Linear Solve Case without Caching
function (cache::LinearSolverCache{Nothing})(;
A = nothing, b = nothing, linu = nothing, kwargs...)
Expand All @@ -119,7 +124,7 @@ function (cache::LinearSolverCache{Nothing})(;
else
res = cache.A \ cache.b
end
return res
return LinearSolveResult(; u = res)
end

# Use LinearSolve.jl
Expand Down Expand Up @@ -154,11 +159,7 @@ function (cache::LinearSolverCache)(;
cache.lincache.Pr = Pr
end

# display(A)

linres = solve!(cache.lincache)
# @show cache.lincache.cacheval
# @show LinearAlgebra.issuccess(cache.lincache.cacheval)
cache.lincache = linres.cache
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling
if linres.retcode === ReturnCode.Failure
Expand All @@ -185,11 +186,14 @@ function (cache::LinearSolverCache)(;
end
linres = solve!(cache.additional_lincache)
cache.additional_lincache = linres.cache
return linres.u
linres.retcode === ReturnCode.Failure &&

Check warning on line 189 in src/internal/linear_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/linear_solve.jl#L187-L189

Added lines #L187 - L189 were not covered by tests
return LinearSolveResult(; u = linres.u, success = false)
return LinearSolveResult(; u = linres.u)

Check warning on line 191 in src/internal/linear_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/linear_solve.jl#L191

Added line #L191 was not covered by tests
end
return LinearSolveResult(; u = linres.u, success = false)

Check warning on line 193 in src/internal/linear_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/linear_solve.jl#L193

Added line #L193 was not covered by tests
end

return linres.u
return LinearSolveResult(; u = linres.u)
end

@inline __update_A!(cache::LinearSolverCache, ::Nothing, reuse) = cache
Expand Down

0 comments on commit 7450590

Please sign in to comment.