Skip to content

Commit

Permalink
Add step! for polyalgorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 16, 2024
1 parent c50c21a commit d879ed5
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.5.6"
version = "3.5.7"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/core/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args..
res = @static_timeit cache.timer "solve" begin
__step!(cache, args...; kwargs...)
end
cache.nsteps += 1

hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)

if timeit
cache.total_time += time() - time_start
Expand Down
68 changes: 57 additions & 11 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,35 @@ function Base.show(io::IO, alg::NonlinearSolvePolyAlgorithm{pType, N}) where {pT
end
end

@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} <:
AbstractNonlinearSolveCache{iip, false}
@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit} <:
AbstractNonlinearSolveCache{iip, timeit}
caches
alg
best::Int
current::Int
nsteps::Int
total_time::Float64
maxtime
retcode::ReturnCode.T
force_stop::Bool
end

function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches)
cache.current = 1
cache.nsteps = 0
cache.total_time = 0.0

Check warning on line 64 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
end

for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...;
kwargs...) where {N}
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N}(
map(solver -> SciMLBase.__init(prob,
solver, args...; kwargs...), alg.algs),
alg, 1)
maxtime = nothing, kwargs...) where {N}
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
alg.algs), alg, -1, 1, 0, 0.0, maxtime,
ReturnCode.Default, false)
end
end
end
Expand All @@ -91,7 +99,7 @@ end
u = $(sol_syms[i]).u
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u,
fu; retcode = ReturnCode.Success, stats,
fu; retcode = $(sol_syms[i]).retcode, stats,
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
end
cache.current = $(i + 1)
Expand All @@ -105,12 +113,11 @@ end
end
push!(calls,
quote
retcode = ReturnCode.MaxIters

fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.caches[1].internalnorm, fus)
stats = cache.caches[idx].stats
u = cache.caches[idx].u
u = get_u(cache.caches[idx])
retcode = cache.caches[idx].retcode

Check warning on line 120 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L119-L120

Added lines #L119 - L120 were not covered by tests

return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u,
fus[idx]; retcode, stats, cache.caches[idx].trace)
Expand All @@ -119,6 +126,45 @@ end
return Expr(:block, calls...)
end

@generated function __step!(
cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N}
calls = []
cache_syms = [gensym("cache") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
__step!($(cache_syms[i]), args...; kwargs...)
if !not_terminated($(cache_syms[i]))
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
cache.best = $(i)
cache.force_stop = true
cache.retcode = $(cache_syms[i]).retcode
else
cache.current = $(i + 1)

Check warning on line 145 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L145

Added line #L145 was not covered by tests
end
end
return
end
end)
end

push!(calls,
quote
if !(1 cache.current length(cache.caches))
minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches)
cache.best = idx
cache.retcode = cache.caches[cache.best].retcode
cache.force_stop = true
return

Check warning on line 160 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L155-L160

Added lines #L155 - L160 were not covered by tests
end
end
)

return Expr(:block, calls...)
end

for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
@inline __is_complex(::Type{Complex}) = true
@inline __is_complex(::Type{T}) where {T} = false

function __findmin_caches(f, caches)
return __findmin(f get_fu, caches)

Check warning on line 98 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L97-L98

Added lines #L97 - L98 were not covered by tests
end
function __findmin(f, x)
return findmin(x) do xᵢ
fx = f(xᵢ)
Expand Down
25 changes: 25 additions & 0 deletions test/misc/polyalg_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,31 @@
cache = init(probN, custom_polyalg; abstol = 1e-9)
solver = solve!(cache)
@test SciMLBase.successful_retcode(solver)

# Test the step interface
cache = init(probN; abstol = 1e-9)
for i in 1:10000
step!(cache)
cache.force_stop && break
end
@test SciMLBase.successful_retcode(cache.retcode)
cache = init(probN, RobustMultiNewton(); abstol = 1e-9)
for i in 1:10000
step!(cache)
cache.force_stop && break
end
@test SciMLBase.successful_retcode(cache.retcode)
cache = init(probN, FastShortcutNonlinearPolyalg(); abstol = 1e-9)
for i in 1:10000
step!(cache)
cache.force_stop && break
end
@test SciMLBase.successful_retcode(cache.retcode)
cache = init(probN, custom_polyalg; abstol = 1e-9)
for i in 1:10000
step!(cache)
cache.force_stop && break
end
end

@testitem "Testing #153 Singular Exception" begin
Expand Down

0 comments on commit d879ed5

Please sign in to comment.