Line | Exclusive | Inclusive | Code |
---|---|---|---|
1 | module NonlinearSolve | ||
2 | |||
3 | if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods")) | ||
4 | @eval Base.Experimental.@max_methods 1 | ||
5 | end | ||
6 | |||
7 | import Reexport: @reexport | ||
8 | import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload | ||
9 | |||
10 | @recompile_invalidations begin | ||
11 | using ADTypes, DiffEqBase, LazyArrays, LineSearches, LinearAlgebra, LinearSolve, Printf, | ||
12 | SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, StaticArrays | ||
13 | |||
14 | import ADTypes: AbstractFiniteDifferencesMode | ||
15 | import ArrayInterface: undefmatrix, restructure, can_setindex, | ||
16 | matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing | ||
17 | import ConcreteStructs: @concrete | ||
18 | import EnumX: @enumx | ||
19 | import FastBroadcast: @.. | ||
20 | import FiniteDiff | ||
21 | import ForwardDiff | ||
22 | import ForwardDiff: Dual | ||
23 | import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A | ||
24 | import MaybeInplace: setindex_trait, @bb, CanSetindex, CannotSetindex | ||
25 | import RecursiveArrayTools: ArrayPartition, | ||
26 | AbstractVectorOfArray, recursivecopy!, recursivefill! | ||
27 | import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace | ||
28 | import SciMLOperators: FunctionOperator | ||
29 | import StaticArrays: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix | ||
30 | import UnPack: @unpack | ||
31 | end | ||
32 | |||
33 | @reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve | ||
34 | import DiffEqBase: AbstractNonlinearTerminationMode, | ||
35 | AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode, | ||
36 | NonlinearSafeTerminationReturnCode, get_termination_mode | ||
37 | |||
38 | const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences, | ||
39 | ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode} | ||
40 | |||
41 | # Type-Inference Friendly Check for Extension Loading | ||
42 | is_extension_loaded(::Val) = false | ||
43 | |||
44 | abstract type AbstractNonlinearSolveLineSearchAlgorithm end | ||
45 | |||
46 | abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end | ||
47 | abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end | ||
48 | |||
49 | abstract type AbstractNonlinearSolveCache{iip} end | ||
50 | |||
51 | isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip | ||
52 | |||
53 | function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(cache); | ||
54 | p = cache.p, abstol = cache.abstol, reltol = cache.reltol, | ||
55 | maxiters = cache.maxiters, alias_u0 = false, termination_condition = missing, | ||
56 | kwargs...) where {iip} | ||
57 | cache.p = p | ||
58 | if iip | ||
59 | recursivecopy!(get_u(cache), u0) | ||
60 | cache.f(get_fu(cache), get_u(cache), p) | ||
61 | else | ||
62 | cache.u = __maybe_unaliased(u0, alias_u0) | ||
63 | set_fu!(cache, cache.f(cache.u, p)) | ||
64 | end | ||
65 | |||
66 | reset!(cache.trace) | ||
67 | |||
68 | # Some algorithms store multiple termination caches | ||
69 | if hasfield(typeof(cache), :tc_cache) | ||
70 | # TODO: We need an efficient way to reset this upstream | ||
71 | tc = termination_condition === missing ? get_termination_mode(cache.tc_cache) : | ||
72 | termination_condition | ||
73 | abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, get_fu(cache), | ||
74 | get_u(cache), tc) | ||
75 | cache.tc_cache = tc_cache | ||
76 | end | ||
77 | |||
78 | if hasfield(typeof(cache), :ls_cache) | ||
79 | # TODO: A more efficient way to do this | ||
80 | cache.ls_cache = init_linesearch_cache(cache.alg.linesearch, cache.f, | ||
81 | get_u(cache), p, get_fu(cache), Val(iip)) | ||
82 | end | ||
83 | |||
84 | hasfield(typeof(cache), :uf) && cache.uf !== nothing && (cache.uf.p = p) | ||
85 | |||
86 | cache.abstol = abstol | ||
87 | cache.reltol = reltol | ||
88 | cache.maxiters = maxiters | ||
89 | cache.stats.nf = 1 | ||
90 | cache.stats.nsteps = 1 | ||
91 | cache.force_stop = false | ||
92 | cache.retcode = ReturnCode.Default | ||
93 | |||
94 | __reinit_internal!(cache; u0, p, abstol, reltol, maxiters, alias_u0, | ||
95 | termination_condition, kwargs...) | ||
96 | |||
97 | return cache | ||
98 | end | ||
99 | |||
100 | __reinit_internal!(::AbstractNonlinearSolveCache; kwargs...) = nothing | ||
101 | |||
102 | function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm) | ||
103 | str = "$(nameof(typeof(alg)))(" | ||
104 | modifiers = String[] | ||
105 | if __getproperty(alg, Val(:ad)) !== nothing | ||
106 | push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()") | ||
107 | end | ||
108 | if __getproperty(alg, Val(:linsolve)) !== nothing | ||
109 | push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()") | ||
110 | end | ||
111 | if __getproperty(alg, Val(:linesearch)) !== nothing | ||
112 | ls = alg.linesearch | ||
113 | if ls isa LineSearch | ||
114 | ls.method !== nothing && | ||
115 | push!(modifiers, "linesearch = $(nameof(typeof(ls.method)))()") | ||
116 | else | ||
117 | push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()") | ||
118 | end | ||
119 | end | ||
120 | append!(modifiers, __alg_print_modifiers(alg)) | ||
121 | if __getproperty(alg, Val(:radius_update_scheme)) !== nothing | ||
122 | push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)") | ||
123 | end | ||
124 | str = str * join(modifiers, ", ") | ||
125 | print(io, "$(str))") | ||
126 | return nothing | ||
127 | end | ||
128 | |||
129 | __alg_print_modifiers(_) = String[] | ||
130 | |||
131 | 284 (99 %) |
568 (197 %)
samples spent in __solve
284 (50 %) (incl.) when called from #solve_call#34 line 606 284 (50 %) (incl.) when called from __solve line 131
284 (100 %)
samples spent calling
#__solve#3
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
|
|
132 | alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...) | ||
133 | 119 (41 %) |
119 (100 %)
samples spent calling
init
cache = init(prob, alg, args...; kwargs...)
|
|
134 | 165 (57 %) |
165 (100 %)
samples spent calling
solve!
return solve!(cache)
|
|
135 | end | ||
136 | |||
137 | function not_terminated(cache::AbstractNonlinearSolveCache) | ||
138 | return !cache.force_stop && cache.stats.nsteps < cache.maxiters | ||
139 | end | ||
140 | |||
141 | get_fu(cache::AbstractNonlinearSolveCache) = cache.fu | ||
142 | set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu) | ||
143 | get_u(cache::AbstractNonlinearSolveCache) = cache.u | ||
144 | SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u) | ||
145 | |||
146 | function SciMLBase.solve!(cache::AbstractNonlinearSolveCache) | ||
147 | while not_terminated(cache) | ||
148 | 165 (57 %) |
165 (100 %)
samples spent calling
perform_step!
perform_step!(cache)
|
|
149 | cache.stats.nsteps += 1 | ||
150 | end | ||
151 | |||
152 | # The solver might have set a different `retcode` | ||
153 | if cache.retcode == ReturnCode.Default | ||
154 | if cache.stats.nsteps == cache.maxiters | ||
155 | cache.retcode = ReturnCode.MaxIters | ||
156 | else | ||
157 | cache.retcode = ReturnCode.Success | ||
158 | end | ||
159 | end | ||
160 | |||
161 | trace = __getproperty(cache, Val{:trace}()) | ||
162 | if trace !== nothing | ||
163 | update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing, | ||
164 | nothing, nothing; last = Val(true)) | ||
165 | end | ||
166 | |||
167 | return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache); | ||
168 | cache.retcode, cache.stats, trace) | ||
169 | end | ||
170 | |||
171 | include("utils.jl") | ||
172 | include("trace.jl") | ||
173 | include("extension_algs.jl") | ||
174 | include("linesearch.jl") | ||
175 | include("raphson.jl") | ||
176 | include("trustRegion.jl") | ||
177 | include("levenberg.jl") | ||
178 | include("gaussnewton.jl") | ||
179 | include("dfsane.jl") | ||
180 | include("pseudotransient.jl") | ||
181 | include("broyden.jl") | ||
182 | include("klement.jl") | ||
183 | include("lbroyden.jl") | ||
184 | include("jacobian.jl") | ||
185 | include("ad.jl") | ||
186 | include("default.jl") | ||
187 | |||
188 | @setup_workload begin | ||
189 | nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1), | ||
190 | (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]), | ||
191 | (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1])) | ||
192 | probs_nls = NonlinearProblem[] | ||
193 | for T in (Float32, Float64), (fn, u0) in nlfuncs | ||
194 | push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2))) | ||
195 | end | ||
196 | |||
197 | nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), | ||
198 | Broyden(), Klement(), DFSane(), nothing) | ||
199 | |||
200 | probs_nlls = NonlinearLeastSquaresProblem[] | ||
201 | nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]), | ||
202 | (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]), | ||
203 | (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, | ||
204 | resid_prototype = zeros(1)), [0.1, 0.0]), | ||
205 | (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), | ||
206 | resid_prototype = zeros(4)), [0.1, 0.1])) | ||
207 | for (fn, u0) in nlfuncs | ||
208 | push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0)) | ||
209 | end | ||
210 | nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]), | ||
211 | (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), | ||
212 | Float32[0.1, 0.1]), | ||
213 | (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p, | ||
214 | resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]), | ||
215 | (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), | ||
216 | resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1])) | ||
217 | for (fn, u0) in nlfuncs | ||
218 | push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0)) | ||
219 | end | ||
220 | |||
221 | nlls_algs = (LevenbergMarquardt(), GaussNewton(), | ||
222 | LevenbergMarquardt(; linsolve = LUFactorization()), | ||
223 | GaussNewton(; linsolve = LUFactorization())) | ||
224 | |||
225 | @compile_workload begin | ||
226 | for prob in probs_nls, alg in nls_algs | ||
227 | solve(prob, alg, abstol = 1e-2) | ||
228 | end | ||
229 | for prob in probs_nlls, alg in nlls_algs | ||
230 | solve(prob, alg, abstol = 1e-2) | ||
231 | end | ||
232 | end | ||
233 | end | ||
234 | |||
235 | export RadiusUpdateSchemes | ||
236 | |||
237 | export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient, | ||
238 | Broyden, Klement, LimitedMemoryBroyden | ||
239 | export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL | ||
240 | export NonlinearSolvePolyAlgorithm, | ||
241 | RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg | ||
242 | |||
243 | export LineSearch, LiFukushimaLineSearch | ||
244 | |||
245 | # Export the termination conditions from DiffEqBase | ||
246 | export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode, | ||
247 | NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode, | ||
248 | AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode, | ||
249 | RelSafeBestTerminationMode, AbsSafeBestTerminationMode | ||
250 | |||
251 | # Tracing Functionality | ||
252 | export TraceAll, TraceMinimal, TraceWithJacobianConditionNumber | ||
253 | |||
254 | end # module |