Skip to content

Commit

Permalink
Merge pull request #76 from avik-pal/ap/inplace_raphson
Browse files Browse the repository at this point in the history
Add support for inplace BatchedSimpleNewtonRaphson
  • Loading branch information
ChrisRackauckas authored Aug 1, 2023
2 parents c91ce26 + dfb620f commit d4196c2
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "0.1.18"
version = "0.1.19"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
27 changes: 21 additions & 6 deletions lib/SimpleNonlinearSolve/src/batched/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ end
function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphson;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)
@assert !iip "BatchedSimpleNewtonRaphson currently only supports out-of-place nonlinear problems."
iip &&
@assert alg_autodiff(alg) "Inplace BatchedSimpleNewtonRaphson currently only supports autodiff."
u, f, reconstruct = _construct_batched_problem_structure(prob)

tc = alg.termination_condition
Expand All @@ -35,12 +36,26 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphs
rtol = _get_tolerance(reltol, tc.reltol, T)
termination_condition = tc(storage)

if iip
𝓙 = similar(xₙ, length(xₙ), length(xₙ))
fₙ = similar(xₙ)
jac_cfg = ForwardDiff.JacobianConfig(f, fₙ, xₙ)
end

for i in 1:maxiters
if alg_autodiff(alg)
fₙ, 𝓙 = value_derivative(f, xₙ)
if iip
value_derivative!(𝓙, fₙ, f, xₙ, jac_cfg)
else
fₙ = f(xₙ)
𝓙 = FiniteDiff.finite_difference_jacobian(f, xₙ, diff_type(alg), eltype(xₙ), fₙ)
if alg_autodiff(alg)
fₙ, 𝓙 = value_derivative(f, xₙ)
else
fₙ = f(xₙ)
𝓙 = FiniteDiff.finite_difference_jacobian(f,
xₙ,
diff_type(alg),
eltype(xₙ),
fₙ)
end
end

iszero(fₙ) && return DiffEqBase.build_solution(prob,
Expand All @@ -66,7 +81,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphs

if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
xₙ = storage.u
fₙ = f(xₙ)
@maybeinplace iip fₙ=f(xₙ)
end

return DiffEqBase.build_solution(prob,
Expand Down
14 changes: 14 additions & 0 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ function value_derivative(f::F, x::R) where {F, R}
end
value_derivative(f::F, x::AbstractArray) where {F} = f(x), ForwardDiff.jacobian(f, x)

"""
value_derivative!(J, y, f!, x, cfg = JacobianConfig(f!, y, x))
Inplace version of [`SimpleNonlinearSolve.value_derivative`](@ref).
"""
function value_derivative!(J::AbstractMatrix,
y::AbstractArray,
f!::F,
x::AbstractArray,
cfg::ForwardDiff.JacobianConfig = ForwardDiff.JacobianConfig(f!, y, x)) where {F}
ForwardDiff.jacobian!(J, f!, y, x, cfg)
return y, J
end

value(x) = x
value(x::Dual) = ForwardDiff.value(x)
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
Expand Down
15 changes: 7 additions & 8 deletions lib/SimpleNonlinearSolve/test/inplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,28 @@ using SimpleNonlinearSolve,
StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test,
NNlib

# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane
# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane, BatchedSimpleNewtonRaphson
function f!(du::AbstractArray{<:Number, N},
u::AbstractArray{<:Number, N},
p::AbstractVector) where {N}
u_ = reshape(u, :, size(u, N))
du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :),
ntuple(_ -> 1, N - 1)...,
size(u, N))
du .= reshape(sum(abs2, u_; dims = 1) .- u_ .- reshape(p, 1, :), size(u))
return du
end

function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector)
du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :)
du .= sum(abs2, u; dims = 1) .- u .- reshape(p, 1, :)
return du
end

function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector)
du .= sum(abs2, u) .- p
du .= sum(abs2, u) .- u .- p
return du
end

@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true),
SimpleDFSane(batched = true))
@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(; batched = true),
SimpleDFSane(; batched = true),
SimpleNewtonRaphson(; batched = true))
@testset "T: $T" for T in (Float32, Float64)
p = rand(T, 5)
@testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5))
Expand Down

0 comments on commit d4196c2

Please sign in to comment.