diff --git a/Project.toml b/Project.toml index 20dd03f68..09059a431 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.7.0" +version = "3.7.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/abstract_types.jl b/src/abstract_types.jl index a08f53317..86dcb963c 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -173,8 +173,8 @@ not applicable. Else a boolean value is returned. """ concrete_jac(::AbstractNonlinearSolveAlgorithm) = nothing -function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm{name}) where {name} - __show_algorithm(io, alg, name, 0) +function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm) + __show_algorithm(io, alg, get_name(alg), 0) end get_name(::AbstractNonlinearSolveAlgorithm{name}) where {name} = name @@ -207,6 +207,23 @@ Abstract Type for all NonlinearSolve.jl Caches. """ abstract type AbstractNonlinearSolveCache{iip, timeit} end +function Base.show(io::IO, cache::AbstractNonlinearSolveCache) + __show_cache(io, cache, 0) +end + +function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0) + println(io, "$(nameof(typeof(cache)))(") + __show_algorithm(io, cache.alg, + (" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4) + println(io, ",") + println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",") + println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",") + println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",") + println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",") + println(io, " "^(indent + 4) * "retcode = ", cache.retcode) + print(io, " "^(indent) * ")") +end + SciMLBase.isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip get_fu(cache::AbstractNonlinearSolveCache) = cache.fu diff --git a/src/adtypes.jl b/src/adtypes.jl index 45507ecdf..9ed3107c5 100644 --- a/src/adtypes.jl +++ b/src/adtypes.jl @@ -83,13 +83,13 @@ AutoSparseForwardDiff Uses [`PolyesterForwardDiff.jl`](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) to compute the jacobian. This is essentially parallelized `ForwardDiff.jl`. - - Supports both inplace and out-of-place functions + - Supports both inplace and out-of-place functions ### Keyword Arguments - - `chunksize`: Count of dual numbers that can be propagated simultaneously. Setting - this number to a high value will lead to slowdowns. Use - [`NonlinearSolve.pickchunksize`](@ref) to get a proper value. + - `chunksize`: Count of dual numbers that can be propagated simultaneously. Setting + this number to a high value will lead to slowdowns. Use + [`NonlinearSolve.pickchunksize`](@ref) to get a proper value. """ AutoPolyesterForwardDiff diff --git a/src/core/generic.jl b/src/core/generic.jl index 22fc3e9d0..9aafba49b 100644 --- a/src/core/generic.jl +++ b/src/core/generic.jl @@ -47,11 +47,13 @@ Performs one step of the nonlinear solver. """ function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit}, args...; kwargs...) where {iip, timeit} + not_terminated(cache) || return timeit && (time_start = time()) res = @static_timeit cache.timer "solve" begin __step!(cache, args...; kwargs...) end - cache.nsteps += 1 + + hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1) if timeit cache.total_time += time() - time_start diff --git a/src/default.jl b/src/default.jl index 8b17a9393..da2420656 100644 --- a/src/default.jl +++ b/src/default.jl @@ -44,26 +44,56 @@ function Base.show(io::IO, alg::NonlinearSolvePolyAlgorithm{pType, N}) where {pT end end -@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} <: - AbstractNonlinearSolveCache{iip, false} +@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit} <: + AbstractNonlinearSolveCache{iip, timeit} caches alg + best::Int current::Int + nsteps::Int + total_time::Float64 + maxtime + retcode::ReturnCode.T + force_stop::Bool + maxiters::Int +end + +function Base.show( + io::IO, cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N} + problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem") + println(io, "NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms") + best_alg = ifelse(cache.best == -1, "nothing", cache.best) + println(io, "Best algorithm: $(best_alg)") + println(io, "Current algorithm: $(cache.current)") + println(io, "nsteps: $(cache.nsteps)") + println(io, "retcode: $(cache.retcode)") + __show_cache(io, cache.caches[cache.current], 0) end function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...) foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches) cache.current = 1 + cache.nsteps = 0 + cache.total_time = 0.0 end for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS)) algType = NonlinearSolvePolyAlgorithm{pType} @eval begin - function SciMLBase.__init( - prob::$probType, alg::$algType{N}, args...; kwargs...) where {N} - return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N}( - map(solver -> SciMLBase.__init(prob, solver, args...; kwargs...), alg.algs), - alg, 1) + function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...; + maxtime = nothing, maxiters = 1000, kwargs...) where {N} + return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}( + map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...), + alg.algs), + alg, + -1, + 1, + 0, + 0.0, + maxtime, + ReturnCode.Default, + false, + maxiters) end end end @@ -89,7 +119,7 @@ end fu = get_fu($(cache_syms[i])) return SciMLBase.build_solution( $(sol_syms[i]).prob, cache.alg, u, fu; - retcode = ReturnCode.Success, stats, + retcode = $(sol_syms[i]).retcode, stats, original = $(sol_syms[i]), trace = $(sol_syms[i]).trace) end cache.current = $(i + 1) @@ -103,12 +133,11 @@ end end push!(calls, quote - retcode = ReturnCode.MaxIters - fus = tuple($(Tuple(resids)...)) minfu, idx = __findmin(cache.caches[1].internalnorm, fus) stats = cache.caches[idx].stats - u = cache.caches[idx].u + u = get_u(cache.caches[idx]) + retcode = cache.caches[idx].retcode return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx]; retcode, stats, cache.caches[idx].trace) @@ -117,6 +146,45 @@ end return Expr(:block, calls...) end +@generated function __step!( + cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N} + calls = [] + cache_syms = [gensym("cache") for i in 1:N] + for i in 1:N + push!(calls, + quote + $(cache_syms[i]) = cache.caches[$(i)] + if $(i) == cache.current + __step!($(cache_syms[i]), args...; kwargs...) + $(cache_syms[i]).nsteps += 1 + if !not_terminated($(cache_syms[i])) + if SciMLBase.successful_retcode($(cache_syms[i]).retcode) + cache.best = $(i) + cache.force_stop = true + cache.retcode = $(cache_syms[i]).retcode + else + cache.current = $(i + 1) + end + end + return + end + end) + end + + push!(calls, + quote + if !(1 ≤ cache.current ≤ length(cache.caches)) + minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches) + cache.best = idx + cache.retcode = cache.caches[cache.best].retcode + cache.force_stop = true + return + end + end) + + return Expr(:block, calls...) +end + for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS)) algType = NonlinearSolvePolyAlgorithm{pType} @eval begin diff --git a/src/utils.jl b/src/utils.jl index 7f4c2c439..aec66bdd4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -94,6 +94,9 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x) @inline __is_complex(::Type{Complex}) = true @inline __is_complex(::Type{T}) where {T} = false +function __findmin_caches(f, caches) + return __findmin(f ∘ get_fu, caches) +end function __findmin(f, x) return findmin(x) do xᵢ fx = f(xᵢ) diff --git a/test/misc/polyalg_tests.jl b/test/misc/polyalg_tests.jl index d9433d494..6c1f17639 100644 --- a/test/misc/polyalg_tests.jl +++ b/test/misc/polyalg_tests.jl @@ -28,6 +28,31 @@ cache = init(probN, custom_polyalg; abstol = 1e-9) solver = solve!(cache) @test SciMLBase.successful_retcode(solver) + + # Test the step interface + cache = init(probN; abstol = 1e-9) + for i in 1:10000 + step!(cache) + cache.force_stop && break + end + @test SciMLBase.successful_retcode(cache.retcode) + cache = init(probN, RobustMultiNewton(); abstol = 1e-9) + for i in 1:10000 + step!(cache) + cache.force_stop && break + end + @test SciMLBase.successful_retcode(cache.retcode) + cache = init(probN, FastShortcutNonlinearPolyalg(); abstol = 1e-9) + for i in 1:10000 + step!(cache) + cache.force_stop && break + end + @test SciMLBase.successful_retcode(cache.retcode) + cache = init(probN, custom_polyalg; abstol = 1e-9) + for i in 1:10000 + step!(cache) + cache.force_stop && break + end end @testitem "Testing #153 Singular Exception" begin diff --git a/testing.jl b/testing.jl new file mode 100644 index 000000000..9a3110f40 --- /dev/null +++ b/testing.jl @@ -0,0 +1,14 @@ +using NonlinearSolve + +f(u, p) = u .* u .- 2 + +u0 = [1.0, 1.0] + +prob = NonlinearProblem(f, u0) + +nlcache = init(prob); + +for i in 1:10 + step!(nlcache) + @show nlcache.retcode +end