Skip to content

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed May 30, 2024
1 parent 575f4ce commit 16ff129
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
30 changes: 19 additions & 11 deletions src/descent/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ end

@internal_caches HalleyDescentCache :lincache

function __internal_init(
prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; shared::Val{N} = Val(1),
pre_inverted::Val{INV} = False, linsolve_kwargs = (;), abstol = nothing,
reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N}
function __internal_init(prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats,
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False,
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
timer = get_timer_output(), kwargs...) where {INV, N}
@bb δu = similar(u)
@bb b = similar(u)
@bb fu = similar(fu)
Expand All @@ -48,23 +48,27 @@ function __internal_init(
end
INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer)
lincache = LinearSolverCache(
alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, linsolve_kwargs...)
alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...)
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer)
end

function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) where {INV}
δu = get_du(cache, idx)
skip_solve && return δu, true, (;)
skip_solve && return DescentResult(; δu)
if INV
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`."
@bb δu = J × vec(fu)

Check warning on line 61 in src/descent/halley.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/halley.jl#L60-L61

Added lines #L60 - L61 were not covered by tests
else
@static_timeit cache.timer "linear solve 1" begin

Check warning on line 63 in src/descent/halley.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/halley.jl#L63

Added line #L63 was not covered by tests
δu = cache.lincache(;
linres = cache.lincache(;
A = J, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
δu = _restructure(get_du(cache, idx), δu)
δu = _restructure(get_du(cache, idx), linres.u)
if !linres.success
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)

Check warning on line 70 in src/descent/halley.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/halley.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
end
end
end
b = cache.b
Expand All @@ -75,15 +79,19 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
@bb b = J × vec(hvvp)

Check warning on line 79 in src/descent/halley.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/halley.jl#L79

Added line #L79 was not covered by tests
else
@static_timeit cache.timer "linear solve 2" begin

Check warning on line 81 in src/descent/halley.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/halley.jl#L81

Added line #L81 was not covered by tests
b = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b),
linres = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b),
du = _vec(b), reuse_A_if_factorization = true)
b = _restructure(cache.b, b)
b = _restructure(cache.b, linres.u)
if !linres.success
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)

Check warning on line 87 in src/descent/halley.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/halley.jl#L86-L87

Added lines #L86 - L87 were not covered by tests
end
end
end
@bb @. δu = δu * δu / (b / 2 - δu)
set_du!(cache, δu, idx)
cache.b = b
return δu, true, (;)
return DescentResult(; δu)
end

function evaluate_hvvp(
Expand Down
2 changes: 1 addition & 1 deletion test/core/23_test_problems_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
test_on_library(problems, dicts, alg_ops, broken_tests)
end

@testitem "Halley" setup=[RobustnessTesting] begin
@testitem "Halley" setup=[RobustnessTesting] tags=[:core] begin
alg_ops = (Halley(),)

broken_tests = Dict(alg => Int[] for alg in alg_ops)
Expand Down

0 comments on commit 16ff129

Please sign in to comment.