StatProfilerHTML.jl report
Generated on Thu, 21 Dec 2023 13:06:16
File source code
Line  Inclusive  Code   (bracketed notes under sampled lines give the profiler's call attribution)
1 @concrete struct KrylovJᵀJ
2 JᵀJ
3 Jᵀ
4 end
5
6 __maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ
7
8 isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)
9
10 # Select whether to use sparse differentiation
11 sparsity_detection_alg(_, _) = NoSparsityDetection()
12 function sparsity_detection_alg(f, ad::AbstractSparseADType)
13 if f.sparsity === nothing
14 if f.jac_prototype === nothing
15 if is_extension_loaded(Val(:Symbolics))
16 return SymbolicsSparsityDetection()
17 else
18 return ApproximateJacobianSparsity()
19 end
20 else
21 jac_prototype = f.jac_prototype
22 end
23 elseif f.sparsity isa SparseDiffTools.AbstractSparsityDetection
24 if f.jac_prototype === nothing
25 return f.sparsity
26 else
27 jac_prototype = f.jac_prototype
28 end
29 elseif f.sparsity isa AbstractMatrix
30 jac_prototype = f.sparsity
31 elseif f.jac_prototype isa AbstractMatrix
32 jac_prototype = f.jac_prototype
33 else
34 error("`sparsity::typeof($(typeof(f.sparsity)))` & \
35 `jac_prototype::typeof($(typeof(f.jac_prototype)))` is not supported. \
36 Use `sparsity::AbstractMatrix` or `sparsity::AbstractSparsityDetection` or \
37 set to `nothing`. `jac_prototype` can be set to `nothing` or an \
38 `AbstractMatrix`.")
39 end
40
41 if SciMLBase.has_colorvec(f)
42 return PrecomputedJacobianColorvec(; jac_prototype, f.colorvec,
43 partition_by_rows = ad isa ADTypes.AbstractSparseReverseMode)
44 else
45 return JacPrototypeSparsityDetection(; jac_prototype)
46 end
47 end
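The dispatch above is a small decision table: dense AD always gets `NoSparsityDetection`, while sparse AD picks a detection algorithm from `sparsity`, `jac_prototype`, and `colorvec`. A minimal sketch of two representative calls, assuming the ADTypes generation this file targets (where `AutoSparseForwardDiff() isa AbstractSparseADType`); the functions passed to `NonlinearFunction` are placeholders:

    using SciMLBase, ADTypes, SparseArrays, LinearAlgebra

    f_plain = NonlinearFunction((u, p) -> u .^ 2)
    sparsity_detection_alg(f_plain, AutoForwardDiff())        # -> NoSparsityDetection()

    f_proto = NonlinearFunction((u, p) -> u .^ 2; jac_prototype = sparse(1.0I, 4, 4))
    sparsity_detection_alg(f_proto, AutoSparseForwardDiff())  # -> JacPrototypeSparsityDetection(; jac_prototype)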
48
49 # No-op for the Jacobian when it is not an AbstractArray -- e.g., a JacVec operator
50 jacobian!!(J, _) = J
51 # The `!!` notation follows BangBang.jl: `J` may be a freshly returned Jacobian (e.g., from an oop `f.jac`),
52 # and we want to avoid a wasteful `copyto!`
53 function jacobian!!(J::Union{AbstractMatrix{<:Number}, Nothing}, cache)
54 @unpack f, uf, u, p, jac_cache, alg, fu_cache = cache
55 cache.stats.njacs += 1
56 iip = isinplace(cache)
57 if iip
58 if has_jac(f)
59 f.jac(J, u, p)
60 else
61 113 (42 %)  sparse_jacobian!(J, alg.ad, jac_cache, uf, fu_cache, u)
    [113 (42 %) samples in jacobian!!; 100 % (incl.) when called from perform_step! line 129; 100 % spent calling sparse_jacobian!]
62 end
63 return J
64 else
65 if has_jac(f)
66 return f.jac(u, p)
67 elseif can_setindex(typeof(J))
68 return sparse_jacobian!(J, alg.ad, jac_cache, uf, u)
69 else
70 return sparse_jacobian(alg.ad, jac_cache, uf, u)
71 end
72 end
73 end
74 # Scalar case
75 function jacobian!!(::Number, cache)
76 cache.stats.njacs += 1
77 return last(value_derivative(cache.uf, cache.u))
78 end
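As the comment on line 51 notes, `!!` marks maybe-in-place updates. A standalone sketch of the convention (the helper `set_vals!!` is hypothetical, not part of this file):

    set_vals!!(J::AbstractMatrix, vals) = (copyto!(J, vals); J)  # mutable: update in place
    set_vals!!(J, vals) = vals                                   # immutable/opaque: return the new value

Callers must therefore always rebind the result, e.g. `cache.J = jacobian!!(cache.J, cache)`, rather than relying on mutation.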
79
80 # Build Jacobian Caches
81 78 (29 %)  function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val{iip};
    [78 (29 %) samples in jacobian_caches; 100 % (incl.) when called from #__init#44 line 104; 71 (91 %) + 7 (9 %) samples spent calling #jacobian_caches#82]
82 linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true),
83 linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init, F}
84 uf = SciMLBase.JacobianWrapper{iip}(f, p)
85
86 haslinsolve = hasfield(typeof(alg), :linsolve)
87
88 has_analytic_jac = has_jac(f)
89 linsolve_needs_jac = (concrete_jac(alg) === nothing &&
90 (!haslinsolve || (haslinsolve && (alg.linsolve === nothing ||
91 needs_concrete_A(alg.linsolve)))))
92 alg_wants_jac = (concrete_jac(alg) !== nothing && concrete_jac(alg))
93
94 # NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
95 fu = f.resid_prototype === nothing ? (iip ? zero(u) : f(u, p)) :
96 (iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
97 if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
98 sd = sparsity_detection_alg(f, alg.ad)
99 ad = alg.ad
100 jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, u) :
101 sparse_jacobian_cache(ad, sd, uf, __maybe_mutable(u, ad); fx = fu)
102 else
103 jac_cache = nothing
104 end
105
106 J = if !(linsolve_needs_jac || alg_wants_jac)
107 if f.jvp === nothing
108 # We don't need to construct the Jacobian
109 JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
110 else
111 if iip
112 jvp = (_, u, v) -> (du_ = similar(fu); f.jvp(du_, v, u, p); du_)
113 jvp! = (du_, _, u, v) -> f.jvp(du_, v, u, p)
114 else
115 jvp = (_, u, v) -> f.jvp(v, u, p)
116 jvp! = (du_, _, u, v) -> (du_ .= f.jvp(v, u, p))
117 end
118 op = SparseDiffTools.FwdModeAutoDiffVecProd(f, u, (), jvp, jvp!)
119 FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(false),
120 p, islinear = true)
121 end
122 else
123 if has_analytic_jac
124 f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
125 elseif f.jac_prototype === nothing
126 7 (3 %)  init_jacobian(jac_cache; preserve_immutable = Val(true))
    [7 (3 %) samples in #jacobian_caches#82; 100 % (incl.) when called from jacobian_caches line 81; 100 % spent calling init_jacobian]
127 else
128 f.jac_prototype
129 end
130 end
131
132 du = copy(u)
133
134 if needsJᵀJ
135 JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f,
136 vjp_autodiff = __get_nonsparse_ad(__getproperty(alg, Val(:vjp_autodiff))),
137 jvp_autodiff = __get_nonsparse_ad(alg.ad))
138 else
139 JᵀJ, Jᵀfu = nothing, nothing
140 end
141
142 if linsolve_init
143 if alg isa PseudoTransient && J isa SciMLOperators.AbstractSciMLOperator
144 linprob_A = J - inv(convert(eltype(u), alg.alpha_initial)) * I
145 else
146 linprob_A = needsJᵀJ ? __maybe_symmetric(JᵀJ) : J
147 end
148 71 (26 %)  linsolve = linsolve_caches(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg;
    [71 (26 %) samples in #jacobian_caches#82; 100 % (incl.) when called from jacobian_caches line 81; 100 % spent calling linsolve_caches]
149 linsolve_kwargs)
150 else
151 linsolve = nothing
152 end
153
154 return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu
155 end
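A sketch of a typical call site, matching the return statement above (`alg`, `f`, `u`, `p`, and the `Val` flags are placeholders):

    uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu = jacobian_caches(
        alg, f, u, p, Val(iip); linsolve_with_JᵀJ = Val(true))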
156
157 ## Special Handling for Scalars
158 function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
159 ::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
160 kwargs...) where {needsJᵀJ, F}
161 # NOTE: Scalar `u` assumes scalar output from `f`
162 uf = SciMLBase.JacobianWrapper{false}(f, p)
163 return uf, FakeLinearSolveJLCache(u, u), u, zero(u), nothing, u, u, u
164 end
165
166 # Linear Solve Cache
167 71 (26 %)  function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;))
    [142 (52 %) samples in linsolve_caches; 71 (50 %, incl.) when called from linsolve_caches line 167 and 71 (50 %, incl.) from #jacobian_caches#82 line 148; 100 % spent calling #linsolve_caches#92]
168 if A isa Number ||
169 (alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;))
170 # Default handling for SArrays in LinearSolve is not great. Some parts are patched,
171 # but there are still quite a few unnecessary allocations
172 return FakeLinearSolveJLCache(A, b)
173 end
174
175 linprob = LinearProblem(A, _vec(b); u0 = _vec(u), linsolve_kwargs...)
176
177 weight = __init_ones(u)
178
179 Pl, Pr = wrapprecs(alg.precs(A, nothing, u, p, nothing, nothing, nothing, nothing,
180 nothing)..., weight)
181 71 (26 %)  return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
    [100 % of samples spent calling init]
182 end
183 linsolve_caches(A::KrylovJᵀJ, b, u, p, alg) = linsolve_caches(A.JᵀJ, b, u, p, alg)
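Once built, the cache returned by `init` is reused across iterations by mutating its fields, which is what makes `alias_A`/`alias_b` above worthwhile. A hedged sketch of the LinearSolve.jl caching interface (`lincache`, `J_new`, and `fu_new` are placeholder names):

    using LinearSolve
    lincache.A = J_new          # swap in the updated operator
    lincache.b = _vec(fu_new)   # and the new right-hand side
    sol = solve!(lincache)      # sol.u is the step; factorization machinery is reused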
184
185 __init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J)
186 function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
187 JᵀJ = J' * J
188 Jᵀfu = J' * fu
189 return JᵀJ, Jᵀfu
190 end
191 function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
192 JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
193 return JᵀJ, J' * fu
194 end
195 function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing,
196 vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...)
197 # FIXME: Proper fix to this requires the FunctionOperator patch
198 if f !== nothing && f.vjp !== nothing
199 @warn "Currently we don't make use of the user-provided `vjp`. This is planned to be \
200 fixed in the near future."
201 end
202 autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
203 Jᵀ = VecJac(uf, u; fu, autodiff)
204 JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u)
205 JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ)
206 Jᵀfu = Jᵀ * fu
207 return JᵀJ, Jᵀfu
208 end
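Note that `JᵀJ_op` never materializes the normal-equations matrix: `Jᵀ * J` composes two operators, so each application costs one Jacobian-vector and one vector-Jacobian product. A sketch of the identity being exploited (`v` is an arbitrary probe vector):

    v = randn(length(u))
    JᵀJ.JᵀJ * v ≈ Jᵀ * (J * v)   # matrix-free normal-equations product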
209
210 function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
211 if vjp_autodiff === nothing
212 if isinplace(uf)
213 # An in-place VecJac can only use FiniteDiff
214 return AutoFiniteDiff()
215 else
216 # Short circuit if we see that FiniteDiff was used for J computation
217 jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
218 # Use Zygote if it is loaded, otherwise fall back to FiniteDiff
219 is_extension_loaded(Val{:Zygote}()) && return AutoZygote()
220 return AutoFiniteDiff()
221 end
222 else
223 ad = __get_nonsparse_ad(vjp_autodiff)
224 if isinplace(uf) && ad isa AutoZygote
225 @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \
226 Falling back to finite differencing."
227 return AutoFiniteDiff()
228 end
229 return ad
230 end
231 end
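The selection logic collapses to a few representative cases; a sketch, where `uf_iip` and `uf_oop` stand for hypothetical in-place and out-of-place wrappers:

    __concrete_vjp_autodiff(nothing, AutoForwardDiff(), uf_iip) # -> AutoFiniteDiff()   (iip VecJac)
    __concrete_vjp_autodiff(nothing, AutoFiniteDiff(), uf_oop)  # -> AutoFiniteDiff()   (match the AD used for J)
    __concrete_vjp_autodiff(nothing, AutoForwardDiff(), uf_oop) # -> AutoZygote() if loaded, else AutoFiniteDiff()
    __concrete_vjp_autodiff(AutoZygote(), nothing, uf_oop)      # -> AutoZygote()       (explicit user choice)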
232
233 # Gradient operator; falls back to a scalar implementation for `u::Number`
234 function __gradient_operator(uf, u; autodiff, kwargs...)
235 if !(autodiff isa AutoFiniteDiff || autodiff isa AutoZygote)
236 _ad = autodiff
237 number_ad = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),
238 AutoFiniteDiff())
239 if u isa Number
240 autodiff = number_ad
241 else
242 if isinplace(uf)
243 autodiff = AutoFiniteDiff()
244 else
245 autodiff = ifelse(is_extension_loaded(Val{:Zygote}()), AutoZygote(),
246 AutoFiniteDiff())
247 end
248 end
249 if _ad !== nothing && _ad !== autodiff
250 @warn "$(_ad) not supported for VecJac. Using $(autodiff) instead."
251 end
252 end
253 return u isa Number ? GradientScalar(uf, u, autodiff) :
254 VecJac(uf, u; autodiff, kwargs...)
255 end
256
257 @concrete mutable struct GradientScalar
258 uf
259 u
260 autodiff
261 end
262
263 function Base.:*(jvp::GradientScalar, v::Number)
264 if jvp.autodiff isa AutoForwardDiff
265 T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
266 out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, one(v)))
267 return ForwardDiff.extract_derivative(T, out)
268 elseif jvp.autodiff isa AutoFiniteDiff
269 J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype)
270 return J
271 else
272 error("Only ForwardDiff & FiniteDiff is currently supported.")
273 end
274 end
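A minimal usage sketch of the scalar path (values are hypothetical). Since the Dual is seeded with `one(v)`, the product returns the derivative at `jvp.u`; `v` only determines the seed's type:

    g = GradientScalar(u -> u^3, 2.0, AutoForwardDiff())
    g * 1.0   # ≈ 12.0, i.e. d(u^3)/du at u = 2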
275
276 # Generic Handling of Krylov Methods for Normal Form Linear Solves
277 function __update_JᵀJ!(cache::AbstractNonlinearSolveCache, J = nothing)
278 if !(cache.JᵀJ isa KrylovJᵀJ)
279 J_ = ifelse(J === nothing, cache.J, J)
280 @bb cache.JᵀJ = transpose(J_) × J_
281 end
282 end
283
284 function __update_Jᵀf!(cache::AbstractNonlinearSolveCache, J = nothing)
285 if cache.JᵀJ isa KrylovJᵀJ
286 @bb cache.Jᵀf = cache.JᵀJ.Jᵀ × cache.fu
287 else
288 J_ = ifelse(J === nothing, cache.J, J)
289 @bb cache.Jᵀf = transpose(J_) × vec(cache.fu)
290 end
291 end
292
293 # Left-Right Multiplication
294 __lr_mul(cache::AbstractNonlinearSolveCache) = __lr_mul(cache, cache.JᵀJ, cache.Jᵀf)
295 function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ::KrylovJᵀJ, Jᵀf)
296 @bb cache.lr_mul_cache = JᵀJ.JᵀJ × vec(Jᵀf)
297 return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache))
298 end
299 function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ, Jᵀf)
300 @bb cache.lr_mul_cache = JᵀJ × vec(Jᵀf)
301 return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache))
302 end
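Both methods compute the same quadratic form: writing g = Jᵀf, the return value is ⟨g, JᵀJ·g⟩ = ‖J·g‖². A dense-matrix sketch of that identity (random placeholder data):

    using LinearAlgebra
    J = randn(5, 3); g = randn(3)          # g plays the role of Jᵀf
    dot(g, (J' * J) * g) ≈ norm(J * g)^2   # the scalar __lr_mul returns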