Skip to content

Commit

Permalink
Allow specifying custom jvp
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 17, 2023
1 parent 1c416ba commit ad257fd
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 4 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.8.0"
version = "2.8.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -19,6 +19,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Expand Down Expand Up @@ -48,14 +49,15 @@ FastLevenbergMarquardt = "0.1"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LeastSquaresOptim = "0.8"
LinearAlgebra = "1.9"
LineSearches = "7"
LinearAlgebra = "1.9"
LinearSolve = "2.12"
NonlinearProblemLibrary = "0.1"
PrecompileTools = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "2.4"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "1.9"
SparseDiffTools = "2.11"
Expand Down
1 change: 1 addition & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
import RecursiveArrayTools: ArrayPartition,
AbstractVectorOfArray, recursivecopy!, recursivefill!
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
import SciMLOperators: FunctionOperator
import StaticArraysCore: StaticArray, SVector, SArray, MArray
import UnPack: @unpack

Expand Down
17 changes: 15 additions & 2 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,21 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
if f.jvp === nothing
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
else
if iip
jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
jvp! = (du, _, u, v) -> f.jvp(du, v, u, p)
else
jvp = (_, u, v) -> f.jvp(v, u, p)
jvp! = (du, _, u, v) -> (du .= f.jvp(v, u, p))
end
op = SparseDiffTools.FwdModeAutoDiffVecProd(f, u, (), jvp, jvp!)
FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(false),
p, islinear = true)
end
else
if has_analytic_jac
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
Expand Down
38 changes: 38 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -973,3 +973,41 @@ end
termination_condition).u .≈ sqrt(2.0))
end
end

# Miscelleneous Tests
@testset "Custom JVP" begin
function F(u::Vector{Float64}, p::Vector{Float64})
Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99))
return u + 0.1 * u .* Δ * u - p
end

function F!(du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64})
Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99))
du .= u + 0.1 * u .* Δ * u - p
return nothing
end

function JVP(v::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64})
Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99))
return v + 0.1 * (u .* Δ * v + v .* Δ * u)
end

function JVP!(du::Vector{Float64}, v::Vector{Float64}, u::Vector{Float64},
p::Vector{Float64})
Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99))
du .= v + 0.1 * (u .* Δ * v + v .* Δ * u)
return nothing
end

u0 = rand(100)

prob = NonlinearProblem(NonlinearFunction{false}(F; jvp = JVP), u0, u0)
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()))

@test norm(F(sol.u, u0)) 1e-8

prob = NonlinearProblem(NonlinearFunction{true}(F!; jvp = JVP!), u0, u0)
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()))

@test norm(F(sol.u, u0)) 1e-8
end

0 comments on commit ad257fd

Please sign in to comment.