Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add step! for polyalgorithms #378

Merged
merged 4 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.7.0"
version = "3.7.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
21 changes: 19 additions & 2 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ not applicable. Else a boolean value is returned.
"""
concrete_jac(::AbstractNonlinearSolveAlgorithm) = nothing

function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm{name}) where {name}
__show_algorithm(io, alg, name, 0)
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
__show_algorithm(io, alg, get_name(alg), 0)
end

get_name(::AbstractNonlinearSolveAlgorithm{name}) where {name} = name
Expand Down Expand Up @@ -207,6 +207,23 @@ Abstract Type for all NonlinearSolve.jl Caches.
"""
abstract type AbstractNonlinearSolveCache{iip, timeit} end

function Base.show(io::IO, cache::AbstractNonlinearSolveCache)
__show_cache(io, cache, 0)
end

function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
println(io, "$(nameof(typeof(cache)))(")
__show_algorithm(io, cache.alg,
(" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4)
println(io, ",")
println(io, (" "^(indent + 4)) * "u = ", get_u(cache), ",")
println(io, (" "^(indent + 4)) * "residual = ", get_fu(cache), ",")
println(io, (" "^(indent + 4)) * "inf-norm(residual) = ", norm(get_fu(cache), Inf), ",")
println(io, " "^(indent + 4) * "nsteps = ", get_nsteps(cache), ",")
println(io, " "^(indent + 4) * "retcode = ", cache.retcode)
print(io, " "^(indent) * ")")
end

SciMLBase.isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

get_fu(cache::AbstractNonlinearSolveCache) = cache.fu
Expand Down
8 changes: 4 additions & 4 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ AutoSparseForwardDiff
Uses [`PolyesterForwardDiff.jl`](https://github.com/JuliaDiff/PolyesterForwardDiff.jl)
to compute the jacobian. This is essentially parallelized `ForwardDiff.jl`.

- Supports both inplace and out-of-place functions
- Supports both inplace and out-of-place functions

### Keyword Arguments

- `chunksize`: Count of dual numbers that can be propagated simultaneously. Setting
this number to a high value will lead to slowdowns. Use
[`NonlinearSolve.pickchunksize`](@ref) to get a proper value.
- `chunksize`: Count of dual numbers that can be propagated simultaneously. Setting
this number to a high value will lead to slowdowns. Use
[`NonlinearSolve.pickchunksize`](@ref) to get a proper value.
"""
AutoPolyesterForwardDiff

Expand Down
4 changes: 3 additions & 1 deletion src/core/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ Performs one step of the nonlinear solver.
"""
function SciMLBase.step!(cache::AbstractNonlinearSolveCache{iip, timeit},
args...; kwargs...) where {iip, timeit}
not_terminated(cache) || return
timeit && (time_start = time())
res = @static_timeit cache.timer "solve" begin
__step!(cache, args...; kwargs...)
end
cache.nsteps += 1

hasfield(typeof(cache), :nsteps) && (cache.nsteps += 1)

if timeit
cache.total_time += time() - time_start
Expand Down
90 changes: 79 additions & 11 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,56 @@
end
end

@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} <:
AbstractNonlinearSolveCache{iip, false}
@concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit} <:
AbstractNonlinearSolveCache{iip, timeit}
caches
alg
best::Int
current::Int
nsteps::Int
total_time::Float64
maxtime
retcode::ReturnCode.T
force_stop::Bool
maxiters::Int
end

function Base.show(

Check warning on line 61 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L61

Added line #L61 was not covered by tests
io::IO, cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N}
problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem")
println(io, "NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms")
best_alg = ifelse(cache.best == -1, "nothing", cache.best)
println(io, "Best algorithm: $(best_alg)")
println(io, "Current algorithm: $(cache.current)")
println(io, "nsteps: $(cache.nsteps)")
println(io, "retcode: $(cache.retcode)")
__show_cache(io, cache.caches[cache.current], 0)

Check warning on line 70 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L63-L70

Added lines #L63 - L70 were not covered by tests
end

function reinit_cache!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwargs...)
foreach(c -> reinit_cache!(c, args...; kwargs...), cache.caches)
cache.current = 1
cache.nsteps = 0
cache.total_time = 0.0

Check warning on line 77 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L76-L77

Added lines #L76 - L77 were not covered by tests
end

for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
function SciMLBase.__init(
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N}(
map(solver -> SciMLBase.__init(prob, solver, args...; kwargs...), alg.algs),
alg, 1)
function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...;
maxtime = nothing, maxiters = 1000, kwargs...) where {N}
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
alg.algs),
alg,
-1,
1,
0,
0.0,
maxtime,
ReturnCode.Default,
false,
maxiters)
end
end
end
Expand All @@ -89,7 +119,7 @@
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution(
$(sol_syms[i]).prob, cache.alg, u, fu;
retcode = ReturnCode.Success, stats,
retcode = $(sol_syms[i]).retcode, stats,
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
end
cache.current = $(i + 1)
Expand All @@ -103,12 +133,11 @@
end
push!(calls,
quote
retcode = ReturnCode.MaxIters

fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.caches[1].internalnorm, fus)
stats = cache.caches[idx].stats
u = cache.caches[idx].u
u = get_u(cache.caches[idx])
retcode = cache.caches[idx].retcode

Check warning on line 140 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L139-L140

Added lines #L139 - L140 were not covered by tests

return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
retcode, stats, cache.caches[idx].trace)
Expand All @@ -117,6 +146,45 @@
return Expr(:block, calls...)
end

@generated function __step!(
cache::NonlinearSolvePolyAlgorithmCache{iip, N}, args...; kwargs...) where {iip, N}
calls = []
cache_syms = [gensym("cache") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
__step!($(cache_syms[i]), args...; kwargs...)
$(cache_syms[i]).nsteps += 1
if !not_terminated($(cache_syms[i]))
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
cache.best = $(i)
cache.force_stop = true
cache.retcode = $(cache_syms[i]).retcode
else
cache.current = $(i + 1)

Check warning on line 166 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L166

Added line #L166 was not covered by tests
end
end
return
end
end)
end

push!(calls,
quote
if !(1 ≤ cache.current ≤ length(cache.caches))
minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches)
cache.best = idx
cache.retcode = cache.caches[cache.best].retcode
cache.force_stop = true
return

Check warning on line 181 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L176-L181

Added lines #L176 - L181 were not covered by tests
end
end)

return Expr(:block, calls...)
end

for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@
@inline __is_complex(::Type{Complex}) = true
@inline __is_complex(::Type{T}) where {T} = false

function __findmin_caches(f, caches)
return __findmin(f ∘ get_fu, caches)

Check warning on line 98 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L97-L98

Added lines #L97 - L98 were not covered by tests
end
function __findmin(f, x)
return findmin(x) do xᵢ
fx = f(xᵢ)
Expand Down
25 changes: 25 additions & 0 deletions test/misc/polyalg_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,31 @@
cache = init(probN, custom_polyalg; abstol = 1e-9)
solver = solve!(cache)
@test SciMLBase.successful_retcode(solver)

# Test the step interface
cache = init(probN; abstol = 1e-9)
for i in 1:10000
step!(cache)
cache.force_stop && break
end
@test SciMLBase.successful_retcode(cache.retcode)
cache = init(probN, RobustMultiNewton(); abstol = 1e-9)
for i in 1:10000
step!(cache)
cache.force_stop && break
end
@test SciMLBase.successful_retcode(cache.retcode)
cache = init(probN, FastShortcutNonlinearPolyalg(); abstol = 1e-9)
for i in 1:10000
step!(cache)
cache.force_stop && break
end
@test SciMLBase.successful_retcode(cache.retcode)
cache = init(probN, custom_polyalg; abstol = 1e-9)
for i in 1:10000
step!(cache)
cache.force_stop && break
end
end

@testitem "Testing #153 Singular Exception" begin
Expand Down
14 changes: 14 additions & 0 deletions testing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using NonlinearSolve

f(u, p) = u .* u .- 2

u0 = [1.0, 1.0]

prob = NonlinearProblem(f, u0)

nlcache = init(prob);

for i in 1:10
step!(nlcache)
@show nlcache.retcode
end
Loading