diff --git a/Project.toml b/Project.toml index fcc210bef..0eb72ea4e 100644 --- a/Project.toml +++ b/Project.toml @@ -60,7 +60,7 @@ SciMLBase = "2.4" SciMLOperators = "0.3" SimpleNonlinearSolve = "0.1.23" SparseArrays = "1.9" -SparseDiffTools = "2.11" +SparseDiffTools = "2.12" StaticArraysCore = "1.4" UnPack = "1.0" Zygote = "0.6" diff --git a/src/jacobian.jl b/src/jacobian.jl index 368e0bb70..ac824559b 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -1,14 +1,3 @@ -@concrete struct JacobianWrapper{iip} <: Function - f - p -end - -# Previous Implementation did not hold onto `iip`, but this causes problems in packages -# where we check for the presence of function signatures to check which dispatch to call -(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p) -(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p))) -(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p) - sparsity_detection_alg(_, _) = NoSparsityDetection() function sparsity_detection_alg(f, ad::AbstractSparseADType) if f.sparsity === nothing @@ -52,7 +41,7 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u)) function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val{iip}; linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true), linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init, F} - uf = JacobianWrapper{iip}(f, p) + uf = SciMLBase.JacobianWrapper{iip}(f, p) haslinsolve = hasfield(typeof(alg), :linsolve) @@ -152,7 +141,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, ::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false), kwargs...) where {needsJᵀJ, F} # NOTE: Scalar `u` assumes scalar output from `f` - uf = JacobianWrapper{false}(f, p) + uf = SciMLBase.JacobianWrapper{false}(f, p) needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u return uf, nothing, u, nothing, nothing, u end diff --git a/src/linesearch.jl b/src/linesearch.jl index a2b396b06..598934d03 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -122,7 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe end function g!(u, fu) - op = VecJac(f, u, p; fu = fu1, autodiff) + op = VecJac(SciMLBase.JacobianWrapper(f, p), u; fu = fu1, autodiff) if iip mul!(g₀, op, fu) return g₀