From d879ed548bfe935d2c7ed47727b5925edd372b63 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 16 Feb 2024 15:58:27 -0500 Subject: [PATCH] Add step! for polyalgorithms --- Project.toml | 2 +- src/core/generic.jl | 3 +- src/default.jl | 68 ++++++++++++++++++++++++++++++++------ src/utils.jl | 3 ++ test/misc/polyalg_tests.jl | 25 ++++++++++++++ 5 files changed, 88 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 59b33a84d..063a32081 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.5.6" +version = "3.5.7" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/core/generic.jl b/src/core/generic.jl index 849a259f1..70f7badbc 100644 --- a/src/core/generic.jl +++ b/src/core/generic.jl @@ -51,7 +51,8 @@ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args.. res = @static_timeit cache.timer "solve" begin __step!(cache, args...; kwargs...) end - cache.nsteps += 1 + + hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1) if timeit cache.total_time += time() - time_start diff --git a/src/default.jl b/src/default.jl index dafdf7902..0f9d920b1 100644 --- a/src/default.jl +++ b/src/default.jl @@ -44,27 +44,35 @@ function Base.show(io::IO, alg::NonlinearSolvePolyAlgorithm{pType, N}) where {pT end end -@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} <: - AbstractNonlinearSolveCache{iip, false} +@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit} <: + AbstractNonlinearSolveCache{iip, timeit} caches alg + best::Int current::Int + nsteps::Int + total_time::Float64 + maxtime + retcode::ReturnCode.T + force_stop::Bool end function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...) foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches) cache.current = 1 + cache.nsteps = 0 + cache.total_time = 0.0 end for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS)) algType = NonlinearSolvePolyAlgorithm{pType} @eval begin function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...; - kwargs...) where {N} - return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N}( - map(solver -> SciMLBase.__init(prob, - solver, args...; kwargs...), alg.algs), - alg, 1) + maxtime = nothing, kwargs...) where {N} + return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}( + map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...), + alg.algs), alg, -1, 1, 0, 0.0, maxtime, + ReturnCode.Default, false) end end end @@ -91,7 +99,7 @@ end u = $(sol_syms[i]).u fu = get_fu($(cache_syms[i])) return SciMLBase.build_solution($(sol_syms[i]).prob, cache.alg, u, - fu; retcode = ReturnCode.Success, stats, + fu; retcode = $(sol_syms[i]).retcode, stats, original = $(sol_syms[i]), trace = $(sol_syms[i]).trace) end cache.current = $(i + 1) @@ -105,12 +113,11 @@ end end push!(calls, quote - retcode = ReturnCode.MaxIters - fus = tuple($(Tuple(resids)...)) minfu, idx = __findmin(cache.caches[1].internalnorm, fus) stats = cache.caches[idx].stats - u = cache.caches[idx].u + u = get_u(cache.caches[idx]) + retcode = cache.caches[idx].retcode return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx]; retcode, stats, cache.caches[idx].trace) @@ -119,6 +126,45 @@ end return Expr(:block, calls...) end +@generated function __step!( + cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N} + calls = [] + cache_syms = [gensym("cache") for i in 1:N] + for i in 1:N + push!(calls, + quote + $(cache_syms[i]) = cache.caches[$(i)] + if $(i) == cache.current + __step!($(cache_syms[i]), args...; kwargs...) + if !not_terminated($(cache_syms[i])) + if SciMLBase.successful_retcode($(cache_syms[i]).retcode) + cache.best = $(i) + cache.force_stop = true + cache.retcode = $(cache_syms[i]).retcode + else + cache.current = $(i + 1) + end + end + return + end + end) + end + + push!(calls, + quote + if !(1 ≤ cache.current ≤ length(cache.caches)) + minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches) + cache.best = idx + cache.retcode = cache.caches[cache.best].retcode + cache.force_stop = true + return + end + end + ) + + return Expr(:block, calls...) +end + for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS)) algType = NonlinearSolvePolyAlgorithm{pType} @eval begin diff --git a/src/utils.jl b/src/utils.jl index 7f4c2c439..aec66bdd4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -94,6 +94,9 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x) @inline __is_complex(::Type{Complex}) = true @inline __is_complex(::Type{T}) where {T} = false +function __findmin_caches(f, caches) + return __findmin(f ∘ get_fu, caches) +end function __findmin(f, x) return findmin(x) do xᵢ fx = f(xᵢ) diff --git a/test/misc/polyalg_tests.jl b/test/misc/polyalg_tests.jl index 10df60429..1a80e3d38 100644 --- a/test/misc/polyalg_tests.jl +++ b/test/misc/polyalg_tests.jl @@ -28,6 +28,31 @@ cache = init(probN, custom_polyalg; abstol = 1e-9) solver = solve!(cache) @test SciMLBase.successful_retcode(solver) + + # Test the step interface + cache = init(probN; abstol = 1e-9) + for i in 1:10000 + step!(cache) + cache.force_stop && break + end + @test SciMLBase.successful_retcode(cache.retcode) + cache = init(probN, RobustMultiNewton(); abstol = 1e-9) + for i in 1:10000 + step!(cache) + cache.force_stop && break + end + @test SciMLBase.successful_retcode(cache.retcode) + cache = init(probN, FastShortcutNonlinearPolyalg(); abstol = 1e-9) + for i in 1:10000 + step!(cache) + cache.force_stop && break + end + @test SciMLBase.successful_retcode(cache.retcode) + cache = init(probN, custom_polyalg; abstol = 1e-9) + for i in 1:10000 + step!(cache) + cache.force_stop && break + end end @testitem "Testing #153 Singular Exception" begin