Skip to content

Commit

Permalink
Use Function Wrappers from SciMLBase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 17, 2023
1 parent cf8dd0c commit 874c09e
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ SciMLBase = "2.4"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "1.9"
SparseDiffTools = "2.11"
SparseDiffTools = "2.12"
StaticArraysCore = "1.4"
UnPack = "1.0"
Zygote = "0.6"
Expand Down
15 changes: 2 additions & 13 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
@concrete struct JacobianWrapper{iip} <: Function
f
p
end

# Previous Implementation did not hold onto `iip`, but this causes problems in packages
# where we check for the presence of function signatures to check which dispatch to call
(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

sparsity_detection_alg(_, _) = NoSparsityDetection()
function sparsity_detection_alg(f, ad::AbstractSparseADType)
if f.sparsity === nothing
Expand Down Expand Up @@ -52,7 +41,7 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val{iip};
linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true),
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init, F}
uf = JacobianWrapper{iip}(f, p)
uf = SciMLBase.JacobianWrapper{iip}(f, p)

haslinsolve = hasfield(typeof(alg), :linsolve)

Expand Down Expand Up @@ -152,7 +141,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
kwargs...) where {needsJᵀJ, F}
# NOTE: Scalar `u` assumes scalar output from `f`
uf = JacobianWrapper{false}(f, p)
uf = SciMLBase.JacobianWrapper{false}(f, p)
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, nothing, u, nothing, nothing, u
end
2 changes: 1 addition & 1 deletion src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
end

function g!(u, fu)
op = VecJac(f, u, p; fu = fu1, autodiff)
op = VecJac(SciMLBase.JacobianWrapper(f, p), u; fu = fu1, autodiff)
if iip
mul!(g₀, op, fu)
return g₀
Expand Down

0 comments on commit 874c09e

Please sign in to comment.