Line | Code | Exclusive | Inclusive |
---|---|---|---|
1 | @concrete struct KrylovJᵀJ | ||
2 | JᵀJ | ||
3 | Jᵀ | ||
4 | end | ||
5 | |||
6 | __maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ | ||
7 | |||
8 | isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) | ||
9 | |||
10 | # Decide whether or not to use sparse differentiation | ||
11 | sparsity_detection_alg(_, _) = NoSparsityDetection() | ||
12 | function sparsity_detection_alg(f, ad::AbstractSparseADType) | ||
13 | if f.sparsity === nothing | ||
14 | if f.jac_prototype === nothing | ||
15 | if is_extension_loaded(Val(:Symbolics)) | ||
16 | return SymbolicsSparsityDetection() | ||
17 | else | ||
18 | return ApproximateJacobianSparsity() | ||
19 | end | ||
20 | else | ||
21 | jac_prototype = f.jac_prototype | ||
22 | end | ||
23 | elseif f.sparsity isa SparseDiffTools.AbstractSparsityDetection | ||
24 | if f.jac_prototype === nothing | ||
25 | return f.sparsity | ||
26 | else | ||
27 | jac_prototype = f.jac_prototype | ||
28 | end | ||
29 | elseif f.sparsity isa AbstractMatrix | ||
30 | jac_prototype = f.sparsity | ||
31 | elseif f.jac_prototype isa AbstractMatrix | ||
32 | jac_prototype = f.jac_prototype | ||
33 | else | ||
34 | error("`sparsity::typeof($(typeof(f.sparsity)))` & \ | ||
35 | `jac_prototype::typeof($(typeof(f.jac_prototype)))` are not supported. \ | ||
36 | Use `sparsity::AbstractMatrix` or `sparsity::AbstractSparsityDetection` or \ | ||
37 | set to `nothing`. `jac_prototype` can be set to `nothing` or an \ | ||
38 | `AbstractMatrix`.") | ||
39 | end | ||
40 | |||
41 | if SciMLBase.has_colorvec(f) | ||
42 | return PrecomputedJacobianColorvec(; jac_prototype, f.colorvec, | ||
43 | partition_by_rows = ad isa ADTypes.AbstractSparseReverseMode) | ||
44 | else | ||
45 | return JacPrototypeSparsityDetection(; jac_prototype) | ||
46 | end | ||
47 | end | ||
48 | |||
49 | # No-op for the Jacobian if it is not an AbstractArray, e.g., a JacVec operator | ||
50 | jacobian!!(J, _) = J | ||
51 | # The `!!` notation follows BangBang.jl: with an out-of-place `f.jac`, `J` may be a | ||
52 | # freshly returned Jacobian, and we want to avoid a wasteful `copyto!` | ||
53 | function jacobian!!(J::Union{AbstractMatrix{<:Number}, Nothing}, cache) | ||
54 | @unpack f, uf, u, p, jac_cache, alg, fu_cache = cache | ||
55 | cache.stats.njacs += 1 | ||
56 | iip = isinplace(cache) | ||
57 | if iip | ||
58 | if has_jac(f) | ||
59 | f.jac(J, u, p) | ||
60 | else | ||
61 | sparse_jacobian!(J, alg.ad, jac_cache, uf, fu_cache, u) | | 96 (33 %) |
62 | end | ||
63 | return J | ||
64 | else | ||
65 | if has_jac(f) | ||
66 | return f.jac(u, p) | ||
67 | elseif can_setindex(typeof(J)) | ||
68 | return sparse_jacobian!(J, alg.ad, jac_cache, uf, u) | ||
69 | else | ||
70 | return sparse_jacobian(alg.ad, jac_cache, uf, u) | ||
71 | end | ||
72 | end | ||
73 | end | ||
74 | # Scalar case | ||
75 | function jacobian!!(::Number, cache) | ||
76 | cache.stats.njacs += 1 | ||
77 | return last(value_derivative(cache.uf, cache.u)) | ||
78 | end | ||
79 | |||
80 | # Build Jacobian Caches | ||
81 | function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val{iip}; | | 119 (41 %) |
82 | linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true), | ||
83 | linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init, F} | ||
84 | uf = SciMLBase.JacobianWrapper{iip}(f, p) | ||
85 | |||
86 | haslinsolve = hasfield(typeof(alg), :linsolve) | ||
87 | |||
88 | has_analytic_jac = has_jac(f) | ||
89 | linsolve_needs_jac = (concrete_jac(alg) === nothing && | ||
90 | (!haslinsolve || alg.linsolve === nothing || | ||
91 | needs_concrete_A(alg.linsolve))) | ||
92 | alg_wants_jac = (concrete_jac(alg) !== nothing && concrete_jac(alg)) | ||
93 | |||
94 | # NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere | ||
95 | fu = f.resid_prototype === nothing ? (iip ? zero(u) : f(u, p)) : | ||
96 | (iip ? deepcopy(f.resid_prototype) : f.resid_prototype) | ||
97 | if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) | ||
98 | sd = sparsity_detection_alg(f, alg.ad) | ||
99 | ad = alg.ad | ||
100 | jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, u) : | ||
101 | sparse_jacobian_cache(ad, sd, uf, __maybe_mutable(u, ad); fx = fu) | ||
102 | else | ||
103 | jac_cache = nothing | ||
104 | end | ||
105 | |||
106 | J = if !(linsolve_needs_jac || alg_wants_jac) | ||
107 | if f.jvp === nothing | ||
108 | # We don't need to construct the Jacobian | ||
109 | JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad)) | ||
110 | else | ||
111 | if iip | ||
112 | jvp = (_, u, v) -> (du_ = similar(fu); f.jvp(du_, v, u, p); du_) | ||
113 | jvp! = (du_, _, u, v) -> f.jvp(du_, v, u, p) | ||
114 | else | ||
115 | jvp = (_, u, v) -> f.jvp(v, u, p) | ||
116 | jvp! = (du_, _, u, v) -> (du_ .= f.jvp(v, u, p)) | ||
117 | end | ||
118 | op = SparseDiffTools.FwdModeAutoDiffVecProd(f, u, (), jvp, jvp!) | ||
119 | FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(false), | ||
120 | p, islinear = true) | ||
121 | end | ||
122 | else | ||
123 | if has_analytic_jac | ||
124 | f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype | ||
125 | elseif f.jac_prototype === nothing | ||
126 | init_jacobian(jac_cache; preserve_immutable = Val(true)) | ||
127 | else | ||
128 | f.jac_prototype | ||
129 | end | ||
130 | end | ||
131 | |||
132 | du = copy(u) | ||
133 | |||
134 | if needsJᵀJ | ||
135 | JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f, | ||
136 | vjp_autodiff = __get_nonsparse_ad(__getproperty(alg, Val(:vjp_autodiff))), | ||
137 | jvp_autodiff = __get_nonsparse_ad(alg.ad)) | ||
138 | else | ||
139 | JᵀJ, Jᵀfu = nothing, nothing | ||
140 | end | ||
141 | |||
142 | if linsolve_init | ||
143 | if alg isa PseudoTransient && J isa SciMLOperators.AbstractSciMLOperator | ||
144 | linprob_A = J - inv(convert(eltype(u), alg.alpha_initial)) * I | ||
145 | else | ||
146 | linprob_A = needsJᵀJ ? __maybe_symmetric(JᵀJ) : J | ||
147 | end | ||
148 | linsolve = linsolve_caches(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg; | | 119 (41 %) |
149 | linsolve_kwargs) | ||
150 | else | ||
151 | linsolve = nothing | ||
152 | end | ||
153 | |||
154 | return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu | ||
155 | end | ||
156 | |||
157 | ## Special Handling for Scalars | ||
158 | function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p, | ||
159 | ::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false), | ||
160 | kwargs...) where {needsJᵀJ, F} | ||
161 | # NOTE: Scalar `u` assumes scalar output from `f` | ||
162 | uf = SciMLBase.JacobianWrapper{false}(f, p) | ||
163 | return uf, FakeLinearSolveJLCache(u, u), u, zero(u), nothing, u, u, u | ||
164 | end | ||
165 | |||
166 | # Linear Solve Cache | ||
167 | function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;)) | | 119 (41 %) |
168 | if A isa Number || | ||
169 | (alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;)) | ||
170 | # Default handling for SArrays in LinearSolve is not great. Some parts are patched, | ||
171 | # but quite a few unnecessary allocations remain | ||
172 | return FakeLinearSolveJLCache(A, b) | ||
173 | end | ||
174 | |||
175 | linprob = LinearProblem(A, _vec(b); u0 = _vec(u), linsolve_kwargs...) | ||
176 | |||
177 | weight = __init_ones(u) | ||
178 | |||
179 | Pl, Pr = wrapprecs(alg.precs(A, nothing, u, p, nothing, nothing, nothing, nothing, | ||
180 | nothing)..., weight) | ||
181 | return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr) | | 119 (41 %) |
182 | end | ||
183 | linsolve_caches(A::KrylovJᵀJ, b, u, p, alg) = linsolve_caches(A.JᵀJ, b, u, p, alg) | ||
184 | |||
185 | __init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J) | ||
186 | function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...) | ||
187 | JᵀJ = J' * J | ||
188 | Jᵀfu = J' * fu | ||
189 | return JᵀJ, Jᵀfu | ||
190 | end | ||
191 | function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...) | ||
192 | JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef) # storage only; computed later | ||
193 | return JᵀJ, J' * fu | ||
194 | end | ||
195 | function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing, | ||
196 | vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...) | ||
197 | # FIXME: Proper fix to this requires the FunctionOperator patch | ||
198 | if f !== nothing && f.vjp !== nothing | ||
199 | @warn "Currently we don't make use of the user-provided `vjp`. This is planned to be \ | ||
200 | fixed in the near future." | ||
201 | end | ||
202 | autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) | ||
203 | Jᵀ = VecJac(uf, u; fu, autodiff) | ||
204 | JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u) | ||
205 | JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ) | ||
206 | Jᵀfu = Jᵀ * fu | ||
207 | return JᵀJ, Jᵀfu | ||
208 | end | ||
209 | |||
210 | function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) | ||
211 | if vjp_autodiff === nothing | ||
212 | if isinplace(uf) | ||
213 | # For in-place functions, VecJac currently supports only FiniteDiff | ||
214 | return AutoFiniteDiff() | ||
215 | else | ||
216 | # Short circuit if we see that FiniteDiff was used for J computation | ||
217 | jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff | ||
218 | # Use Zygote if it is loaded; otherwise fall back to FiniteDiff | ||
219 | is_extension_loaded(Val{:Zygote}()) && return AutoZygote() | ||
220 | return AutoFiniteDiff() | ||
221 | end | ||
222 | else | ||
223 | ad = __get_nonsparse_ad(vjp_autodiff) | ||
224 | if isinplace(uf) && ad isa AutoZygote | ||
225 | @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \ | ||
226 | Falling back to finite differencing." | ||
227 | return AutoFiniteDiff() | ||
228 | end | ||
229 | return ad | ||
230 | end | ||
231 | end | ||
232 | |||
233 | # Gradient operator, with a scalar fallback | ||
234 | function __gradient_operator(uf, u; autodiff, kwargs...) | ||
235 | if !(autodiff isa AutoFiniteDiff || autodiff isa AutoZygote) | ||
236 | _ad = autodiff | ||
237 | number_ad = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(), | ||
238 | AutoFiniteDiff()) | ||
239 | if u isa Number | ||
240 | autodiff = number_ad | ||
241 | else | ||
242 | if isinplace(uf) | ||
243 | autodiff = AutoFiniteDiff() | ||
244 | else | ||
245 | autodiff = ifelse(is_extension_loaded(Val{:Zygote}()), AutoZygote(), | ||
246 | AutoFiniteDiff()) | ||
247 | end | ||
248 | end | ||
249 | if _ad !== nothing && _ad !== autodiff | ||
250 | @warn "$(_ad) not supported for VecJac. Using $(autodiff) instead." | ||
251 | end | ||
252 | end | ||
253 | return u isa Number ? GradientScalar(uf, u, autodiff) : | ||
254 | VecJac(uf, u; autodiff, kwargs...) | ||
255 | end | ||
256 | |||
257 | @concrete mutable struct GradientScalar | ||
258 | uf | ||
259 | u | ||
260 | autodiff | ||
261 | end | ||
262 | |||
263 | function Base.:*(jvp::GradientScalar, v::Number) | ||
264 | if jvp.autodiff isa AutoForwardDiff | ||
265 | T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u))) | ||
266 | out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, one(v))) | ||
267 | return ForwardDiff.extract_derivative(T, out) | ||
268 | elseif jvp.autodiff isa AutoFiniteDiff | ||
269 | J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype) | ||
270 | return J | ||
271 | else | ||
272 | error("Only ForwardDiff & FiniteDiff is currently supported.") | ||
273 | end | ||
274 | end | ||
275 | |||
276 | # Generic Handling of Krylov Methods for Normal Form Linear Solves | ||
277 | function __update_JᵀJ!(cache::AbstractNonlinearSolveCache, J = nothing) | ||
278 | if !(cache.JᵀJ isa KrylovJᵀJ) | ||
279 | J_ = ifelse(J === nothing, cache.J, J) | ||
280 | @bb cache.JᵀJ = transpose(J_) × J_ | ||
281 | end | ||
282 | end | ||
283 | |||
284 | function __update_Jᵀf!(cache::AbstractNonlinearSolveCache, J = nothing) | ||
285 | if cache.JᵀJ isa KrylovJᵀJ | ||
286 | @bb cache.Jᵀf = cache.JᵀJ.Jᵀ × cache.fu | ||
287 | else | ||
288 | J_ = ifelse(J === nothing, cache.J, J) | ||
289 | @bb cache.Jᵀf = transpose(J_) × vec(cache.fu) | ||
290 | end | ||
291 | end | ||
292 | |||
293 | # Left-right multiplication: the quadratic form `Jᵀf' * JᵀJ * Jᵀf` | ||
294 | __lr_mul(cache::AbstractNonlinearSolveCache) = __lr_mul(cache, cache.JᵀJ, cache.Jᵀf) | ||
295 | function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ::KrylovJᵀJ, Jᵀf) | ||
296 | @bb cache.lr_mul_cache = JᵀJ.JᵀJ × vec(Jᵀf) | ||
297 | return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache)) | ||
298 | end | ||
299 | function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ, Jᵀf) | ||
300 | @bb cache.lr_mul_cache = JᵀJ × vec(Jᵀf) | ||
301 | return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache)) | ||
302 | end | ||
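
A few sketches of the machinery profiled above follow. First, `sparsity_detection_alg` (listing lines 10–47) inspects the `sparsity` and `jac_prototype` fields of the `NonlinearFunction`. A minimal usage sketch, assuming a toy in-place residual `f!` with a known lower-bidiagonal pattern (the residual and pattern are illustrative, not from the package):

```julia
using NonlinearSolve, SparseArrays

# Toy in-place residual whose Jacobian is lower bidiagonal
f!(du, u, p) = (du .= u .^ 2; du[2:end] .-= u[1:(end - 1)]; du)

jac_prototype = spdiagm(0 => ones(10), -1 => ones(9))
nf = NonlinearFunction(f!; jac_prototype)
prob = NonlinearProblem(nf, ones(10))

# With `jac_prototype` set and `sparsity === nothing`, `sparsity_detection_alg`
# returns `JacPrototypeSparsityDetection(; jac_prototype)`; if a `colorvec` were
# also supplied, it would return a `PrecomputedJacobianColorvec` instead.
```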
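The hot call at listing line 61 goes through SparseDiffTools' cache-then-fill API. A standalone sketch of that pattern (the residual `g!` and sizes are made up for illustration):

```julia
using SparseDiffTools, ADTypes, SparseArrays, LinearAlgebra

g!(fx, x) = (fx .= 3.0 .* x)  # toy in-place residual with a diagonal Jacobian
x = rand(5)
fx = similar(x)

sd = JacPrototypeSparsityDetection(; jac_prototype = sparse(1.0I, 5, 5))
jac_cache = sparse_jacobian_cache(AutoSparseForwardDiff(), sd, g!, fx, x)
J = init_jacobian(jac_cache)  # allocated once, matching the sparsity pattern
sparse_jacobian!(J, AutoSparseForwardDiff(), jac_cache, g!, fx, x)  # fills J in place
```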
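For scalar problems, `jacobian!!(::Number, cache)` calls `value_derivative`, which evaluates f(u) and f′(u) in a single ForwardDiff pass. A self-contained sketch of that helper's assumed shape (renamed to make clear it is an illustration, not the package's exact definition):

```julia
using ForwardDiff

function value_derivative_sketch(f::F, u::Number) where {F}
    T = typeof(ForwardDiff.Tag(f, typeof(u)))
    out = f(ForwardDiff.Dual{T}(u, one(u)))  # one dual-number evaluation
    return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
end

value_derivative_sketch(x -> x^2 - 2, 1.5)  # (0.25, 3.0)
```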
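`SciMLBase.JacobianWrapper{iip}(f, p)` (listing line 84) simply closes over `p` so the differentiation tools see a single-argument function. A conceptual stand-in, not SciMLBase's actual definition:

```julia
# Conceptual stand-in for SciMLBase.JacobianWrapper{false}
struct ParamClosure{F, P}
    f::F
    p::P
end
(w::ParamClosure)(u) = w.f(u, w.p)

w = ParamClosure((u, p) -> u .^ 2 .- p, 2.0)
w([1.0, 2.0])  # [-1.0, 2.0]
```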
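`linsolve_caches` bottoms out in the standard LinearSolve.jl `init`/`solve!` pattern (the sampled call at listing line 181), which is what lets a Newton loop reuse the factorization cache across iterations. A minimal sketch:

```julia
using LinearSolve, LinearAlgebra

A = rand(4, 4) + 4I
b = rand(4)

cache = init(LinearProblem(A, b); alias_A = true, alias_b = true)
solve!(cache)      # first solve: factorizes A

cache.b = rand(4)  # next iteration: new right-hand side, cached factorization reused
solve!(cache)
```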
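`KrylovJᵀJ` and the `FunctionOperator` branch of `__init_JᵀJ` keep the normal equations matrix-free: `JᵀJ * v` is evaluated as `Jᵀ * (J * v)` and `J` is never materialized. A sketch with a toy out-of-place residual (`h` and `u` are illustrative):

```julia
using SparseDiffTools, SciMLOperators

h(u) = u .^ 2 .- 2.0  # toy residual
u = ones(4)

J  = JacVec(h, u)     # matrix-free J * v (forward mode)
Jᵀ = VecJac(h, u)     # matrix-free Jᵀ * v
JᵀJ = SciMLOperators.cache_operator(Jᵀ * J, u)
JᵀJ * u               # ≈ Jᵀ * (J * u)
```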
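Finally, the `@bb ... ×` updates in `__update_JᵀJ!`, `__update_Jᵀf!`, and `__lr_mul` are BangBang-style maybe-mutating assignments: mutate in place when the destination supports `setindex!`, allocate otherwise. Roughly (a conceptual sketch, not the macro's actual expansion):

```julia
using LinearAlgebra
using ArrayInterface: can_setindex

# Roughly what `@bb out = transpose(J) × J` does
function jtj!!(out, J)
    can_setindex(typeof(out)) || return transpose(J) * J  # immutable (e.g. SArray): allocate
    mul!(out, transpose(J), J)                            # mutable: reuse `out`
    return out
end

jtj!!(zeros(3, 3), rand(4, 3))
```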
302 | end |