Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle polyalgorithm aliasing correctly #392

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.8.0"
version = "3.8.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf, SciMLBase,
SimpleNonlinearSolve, SparseArrays, SparseDiffTools

import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing
import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing,
ismutable
import DiffEqBase: AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode,
Expand Down
123 changes: 102 additions & 21 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
force_stop::Bool
maxiters::Int
internalnorm
u0
u0_aliased
alias_u0::Bool
end

function Base.show(
Expand All @@ -91,11 +94,24 @@
@eval begin
function SciMLBase.__init(
prob::$probType, alg::$algType{N}, args...; maxtime = nothing,
maxiters = 1000, internalnorm = DEFAULT_NORM, kwargs...) where {N}
maxiters = 1000, internalnorm = DEFAULT_NORM,
alias_u0 = false, verbose = true, kwargs...) where {N}
if (alias_u0 && !ismutable(prob.u0))
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \

Check warning on line 100 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L100

Added line #L100 was not covered by tests
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing

Check warning on line 102 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L102

Added line #L102 was not covered by tests
end
u0 = prob.u0
if alias_u0
u0_aliased = copy(u0)
else
u0_aliased = u0 # Irrelevant
end
alias_u0 && (prob = remake(prob; u0 = u0_aliased))
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(
solver -> SciMLBase.__init(
prob, solver, args...; maxtime, internalnorm, kwargs...),
solver -> SciMLBase.__init(prob, solver, args...; maxtime,
internalnorm, alias_u0, verbose, kwargs...),
alg.algs),
alg,
-1,
Expand All @@ -106,7 +122,10 @@
ReturnCode.Default,
false,
maxiters,
internalnorm)
internalnorm,
u0,
u0_aliased,
alias_u0)
end
end
end
Expand All @@ -120,20 +139,30 @@

cache_syms = [gensym("cache") for i in 1:N]
sol_syms = [gensym("sol") for i in 1:N]
u_result_syms = [gensym("u_result") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0)
$(sol_syms[i]) = SciMLBase.solve!($(cache_syms[i]))
if SciMLBase.successful_retcode($(sol_syms[i]))
stats = $(sol_syms[i]).stats
u = $(sol_syms[i]).u
if cache.alias_u0
copyto!(cache.u0, $(sol_syms[i]).u)
$(u_result_syms[i]) = cache.u0
else
$(u_result_syms[i]) = $(sol_syms[i]).u
end
fu = get_fu($(cache_syms[i]))
return SciMLBase.build_solution(
$(sol_syms[i]).prob, cache.alg, u, fu;
retcode = $(sol_syms[i]).retcode, stats,
$(sol_syms[i]).prob, cache.alg, $(u_result_syms[i]),
fu; retcode = $(sol_syms[i]).retcode, stats,
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace)
elseif cache.alias_u0
# For safety we need to maintain a copy of the solution
$(u_result_syms[i]) = copy($(sol_syms[i]).u)
end
cache.current = $(i + 1)
end
Expand All @@ -144,14 +173,29 @@
for (sym, resid) in zip(cache_syms, resids)
push!(calls, :($(resid) = @isdefined($(sym)) ? get_fu($(sym)) : nothing))
end
push!(calls, quote
fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.internalnorm, fus)
stats = __compile_stats(cache.caches[idx])

Check warning on line 179 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L177-L179

Added lines #L177 - L179 were not covered by tests
end)
for i in 1:N
push!(calls, quote
if idx == $(i)
if cache.alias_u0
u = $(u_result_syms[i])

Check warning on line 185 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L183-L185

Added lines #L183 - L185 were not covered by tests
else
u = get_u(cache.caches[$i])

Check warning on line 187 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L187

Added line #L187 was not covered by tests
end
end
end)
end
push!(calls,
quote
fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.internalnorm, fus)
stats = __compile_stats(cache.caches[idx])
u = get_u(cache.caches[idx])
retcode = cache.caches[idx].retcode

if cache.alias_u0
copyto!(cache.u0, u)
u = cache.u0

Check warning on line 197 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L195-L197

Added lines #L195 - L197 were not covered by tests
end
return SciMLBase.build_solution(cache.caches[idx].prob, cache.alg, u, fus[idx];
retcode, stats, cache.caches[idx].trace)
end)
Expand Down Expand Up @@ -200,22 +244,52 @@
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
@generated function SciMLBase.__solve(
prob::$probType, alg::$algType{N}, args...; kwargs...) where {N}
calls = [:(current = alg.start_index)]
@generated function SciMLBase.__solve(prob::$probType, alg::$algType{N}, args...;
alias_u0 = false, verbose = true, kwargs...) where {N}
sol_syms = [gensym("sol") for _ in 1:N]
prob_syms = [gensym("prob") for _ in 1:N]
u_result_syms = [gensym("u_result") for _ in 1:N]
calls = [quote
current = alg.start_index
if (alias_u0 && !ismutable(prob.u0))
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \

Check warning on line 255 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L255

Added line #L255 was not covered by tests
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing

Check warning on line 257 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L257

Added line #L257 was not covered by tests
end
u0 = prob.u0
if alias_u0
u0_aliased = similar(u0)
else
u0_aliased = u0 # Irrelevant
end
end]
for i in 1:N
cur_sol = sol_syms[i]
push!(calls,
quote
if current == $i
$(cur_sol) = SciMLBase.__solve(
prob, alg.algs[$(i)], args...; kwargs...)
if alias_u0
copyto!(u0_aliased, u0)
$(prob_syms[i]) = remake(prob; u0 = u0_aliased)
else
$(prob_syms[i]) = prob
end
$(cur_sol) = SciMLBase.__solve($(prob_syms[i]), alg.algs[$(i)],
args...; alias_u0, verbose, kwargs...)
if SciMLBase.successful_retcode($(cur_sol))
if alias_u0
copyto!(u0, $(cur_sol).u)
$(u_result_syms[i]) = u0
else
$(u_result_syms[i]) = $(cur_sol).u
end
return SciMLBase.build_solution(
prob, alg, $(cur_sol).u, $(cur_sol).resid;
prob, alg, $(u_result_syms[i]), $(cur_sol).resid;
$(cur_sol).retcode, $(cur_sol).stats,
original = $(cur_sol), trace = $(cur_sol).trace)
elseif alias_u0
# For safety we need to maintain a copy of the solution
$(u_result_syms[i]) = copy($(cur_sol).u)
end
current = $(i + 1)
end
Expand All @@ -236,9 +310,16 @@
push!(calls,
quote
if idx == $i
return SciMLBase.build_solution(prob, alg, $(sol_syms[i]).u,
$(sol_syms[i]).resid; $(sol_syms[i]).retcode,
$(sol_syms[i]).stats, $(sol_syms[i]).trace)
if alias_u0
copyto!(u0, $(u_result_syms[i]))
$(u_result_syms[i]) = u0

Check warning on line 315 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L314-L315

Added lines #L314 - L315 were not covered by tests
else
$(u_result_syms[i]) = $(sol_syms[i]).u
end
return SciMLBase.build_solution(
prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid;
$(sol_syms[i]).retcode, $(sol_syms[i]).stats,
$(sol_syms[i]).trace, original = $(sol_syms[i]))
end
end)
end
Expand Down
13 changes: 7 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,16 @@
@inline __is_complex(::Type{Complex}) = true
@inline __is_complex(::Type{T}) where {T} = false

function __findmin_caches(f, caches)
return __findmin(f ∘ get_fu, caches)
end
function __findmin(f, x)
return findmin(x) do xᵢ
@inline __findmin_caches(f, caches) = __findmin(f ∘ get_fu, caches)

Check warning on line 97 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L97

Added line #L97 was not covered by tests
# FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (at least according to `isnan`)
@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x)
@inline function __findmin(f, x)
fmin = @closure xᵢ -> begin
xᵢ === nothing && return Inf
fx = f(xᵢ)
return isnan(fx) ? Inf : fx
return ifelse(isnan(fx), Inf, fx)
end
return findmin(fmin, x)
end

@inline __can_setindex(x) = can_setindex(x)
Expand Down
25 changes: 25 additions & 0 deletions test/misc/aliasing_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@testitem "PolyAlgorithm Aliasing" begin
    using NonlinearProblemLibrary

    # Pick a problem on which the first solvers of the polyalgorithm fail and push the
    # iterate away from the initial guess. Unless aliasing is handled correctly, every
    # later algorithm starts from that diverged value and fails as well.
    testcases = NonlinearProblemLibrary.nlprob_23_testcases
    prob = testcases["Generalized Rosenbrock function"].prob
    u0 = copy(prob.u0)

    # Keyword arguments shared by the `solve` and the `init` / `solve!` paths; built
    # afresh for every call so each run gets its own termination mode object.
    aliasing_kwargs() = (; abstol = 1e-6, alias_u0 = true,
        termination_condition = AbsNormTerminationMode())

    # Path 1: direct `solve`. If aliasing is not handled properly this will diverge.
    prob = remake(prob; u0 = copy(u0))
    sol = solve(prob; aliasing_kwargs()...)

    # With `alias_u0 = true` the returned solution must be the problem's own `u0` array.
    @test sol.u === prob.u0
    @test SciMLBase.successful_retcode(sol.retcode)

    # Path 2: `init` followed by `solve!`, starting again from a fresh copy of `u0`.
    prob = remake(prob; u0 = copy(u0))
    cache = init(prob; aliasing_kwargs()...)
    sol = solve!(cache)

    @test sol.u === prob.u0
    @test SciMLBase.successful_retcode(sol.retcode)
end
Loading