Skip to content

Commit

Permalink
Merge pull request #285 from avik-pal/ap/fix_jacvec
Browse files Browse the repository at this point in the history
Use Function Wrappers from SciMLBase
  • Loading branch information
ChrisRackauckas authored Nov 17, 2023
2 parents cf8dd0c + af40e41 commit 0026bc1
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 18 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.8.1"
version = "2.8.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -56,11 +56,11 @@ NonlinearProblemLibrary = "0.1"
PrecompileTools = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "2.4"
SciMLBase = "2.8.2"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "1.9"
SparseDiffTools = "2.11"
SparseDiffTools = "2.12"
StaticArraysCore = "1.4"
UnPack = "1.0"
Zygote = "0.6"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ There are multiple return codes which can mean the solve was successful, and thu
general command `SciMLBase.successful_retcode` to check whether the solution process exited as
intended:

```@example
```@example 1
SciMLBase.successful_retcode(sol)
```

Expand Down
15 changes: 2 additions & 13 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
@concrete struct JacobianWrapper{iip} <: Function
f
p
end

# Previous Implementation did not hold onto `iip`, but this causes problems in packages
# where we check for the presence of function signatures to check which dispatch to call
(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

sparsity_detection_alg(_, _) = NoSparsityDetection()
function sparsity_detection_alg(f, ad::AbstractSparseADType)
if f.sparsity === nothing
Expand Down Expand Up @@ -52,7 +41,7 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val{iip};
linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true),
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init, F}
uf = JacobianWrapper{iip}(f, p)
uf = SciMLBase.JacobianWrapper{iip}(f, p)

haslinsolve = hasfield(typeof(alg), :linsolve)

Expand Down Expand Up @@ -152,7 +141,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
kwargs...) where {needsJᵀJ, F}
# NOTE: Scalar `u` assumes scalar output from `f`
uf = JacobianWrapper{false}(f, p)
uf = SciMLBase.JacobianWrapper{false}(f, p)
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, nothing, u, nothing, nothing, u
end
2 changes: 1 addition & 1 deletion src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
end

function g!(u, fu)
op = VecJac(f, u, p; fu = fu1, autodiff)
op = VecJac(SciMLBase.JacobianWrapper(f, p), u; fu = fu1, autodiff)
if iip
mul!(g₀, op, fu)
return g₀
Expand Down

0 comments on commit 0026bc1

Please sign in to comment.