Skip to content

Commit

Permalink
Merge pull request #114 from avik-pal/ap/inplace_duals
Browse files Browse the repository at this point in the history
Add ForwardDiff Inplace Overloads
  • Loading branch information
avik-pal authored Dec 27, 2023
2 parents 3f63d6c + f66e913 commit ef11f39
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 156 deletions.
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.1.0"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
118 changes: 64 additions & 54 deletions lib/SimpleNonlinearSolve/src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,44 @@
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
f = prob.f
# Entry point for `NonlinearProblem`s whose parameter is a ForwardDiff `Dual` (or an array
# of them). The work is delegated to `__nlsolve_ad`; its solution and partials are then
# combined by `__nlsolve_dual_soln` into the `u` of the returned solution.
function SciMLBase.solve(
        prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
            <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
        alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
        kwargs...) where {T, V, P, iip}
    primal_sol, du_dp = __nlsolve_ad(prob, alg, args...; kwargs...)
    u_dual = __nlsolve_dual_soln(primal_sol.u, du_dp, prob.p)
    return SciMLBase.build_solution(prob, alg, u_dual, primal_sol.resid;
        retcode = primal_sol.retcode, stats = primal_sol.stats,
        original = primal_sol.original)
end

# Handle Ambiguities
# One `solve` method is generated per listed algorithm type for
# `IntervalNonlinearProblem`s whose parameter is a `Dual` (or an array of them). Besides
# `u`, the `left` and `right` fields are also re-wrapped as duals with the same partials.
for interval_alg in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
    @eval function SciMLBase.solve(
            prob::IntervalNonlinearProblem{uType, iip,
                <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
            alg::$(interval_alg), args...; kwargs...) where {uType, T, V, P, iip}
        primal_sol, du_dp = __nlsolve_ad(prob, alg, args...; kwargs...)
        u_dual = __nlsolve_dual_soln(primal_sol.u, du_dp, prob.p)
        return SciMLBase.build_solution(prob, alg, u_dual, primal_sol.resid;
            retcode = primal_sol.retcode, stats = primal_sol.stats,
            original = primal_sol.original,
            left = Dual{T, V, P}(primal_sol.left, du_dp),
            right = Dual{T, V, P}(primal_sol.right, du_dp))
    end
end

function __nlsolve_ad(prob, alg, args...; kwargs...)
p = value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = value.(prob.tspan)
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
else
u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)
end

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, p)
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, p)

z_arr = -inv(f_x) * f_p
z_arr = -f_x \ f_p

pp = prob.p
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
Expand All @@ -30,60 +53,47 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

# Out-of-place `NonlinearProblem` with a single `Dual` parameter: solve through
# `scalar_nlsolve_ad`, then rebuild the solution with a dual-valued `u`.
function SciMLBase.solve(
        prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray}, false,
            <:Dual{T, V, P}},
        alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
        kwargs...) where {T, V, P}
    primal_sol, du_dp = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
    u_dual = scalar_nlsolve_dual_soln(primal_sol.u, du_dp, prob.p)
    return SciMLBase.build_solution(prob, alg, u_dual, primal_sol.resid;
        retcode = primal_sol.retcode)
end

# Out-of-place `NonlinearProblem` whose parameter is an array of `Dual`s: solve through
# `scalar_nlsolve_ad`, then rebuild the solution with a dual-valued `u`.
function SciMLBase.solve(
        prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray}, false,
            <:AbstractArray{<:Dual{T, V, P}}},
        alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
        kwargs...) where {T, V, P}
    primal_sol, du_dp = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
    u_dual = scalar_nlsolve_dual_soln(primal_sol.u, du_dp, prob.p)
    return SciMLBase.build_solution(prob, alg, u_dual, primal_sol.resid;
        retcode = primal_sol.retcode)
end

function scalar_nlsolve_∂f_∂p(f, u, p)
ff = p isa Number ? ForwardDiff.derivative :
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
return ff(Base.Fix1(f, u), p)
# Differentiate the residual `f` with respect to the parameter `p` at fixed `u`.
#
# For in-place problems `f(du, u, p)` is wrapped so that it allocates its own output
# (with an element type wide enough for both `u` and `p`) and returns it; out-of-place
# problems simply fix the first argument of `f`.
#
# The ForwardDiff entry point is picked from the kinds of `p` and `u`:
#   - `p` a number           -> `derivative`, reshaped to a column (`:, 1`)
#   - `u` a number (p array) -> `gradient`, reshaped to a row (`1, :`)
#   - otherwise              -> `jacobian`
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
    residual_of_p = if isinplace(prob)
        function (q)
            out = similar(u, promote_type(eltype(u), eltype(q)))
            f(out, u, q)
            return out
        end
    else
        Base.Fix1(f, u)
    end
    p isa Number && return __reshape(ForwardDiff.derivative(residual_of_p, p), :, 1)
    u isa Number && return __reshape(ForwardDiff.gradient(residual_of_p, p), 1, :)
    return ForwardDiff.jacobian(residual_of_p, p)
end

function scalar_nlsolve_∂f_∂u(f, u, p)
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
return ff(Base.Fix2(f, p), u)
# Differentiate the residual `f` with respect to the state `u` at fixed `p`.
#
# In-place problems go through the `(f!, y, x)` form of `ForwardDiff.jacobian` with a
# scratch output buffer shaped like `u`. Out-of-place problems fix the second argument
# of `f` and use `derivative` for a scalar `u`, `jacobian` otherwise.
@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
    if isinplace(prob)
        buffer = similar(u)
        residual! = (out, x) -> f(out, x, p)
        return ForwardDiff.jacobian(residual!, buffer, u)
    end
    residual = Base.Fix2(f, p)
    return u isa Number ? ForwardDiff.derivative(residual, u) :
           ForwardDiff.jacobian(residual, u)
end

function scalar_nlsolve_dual_soln(u::Number, partials,
# Scalar solution: wrap `u` and its partials into one `Dual`, taking the tag/value/partial
# type parameters from the (dual-valued) problem parameter passed as the third argument.
@inline __nlsolve_dual_soln(u::Number, partials,
    ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} =
    Dual{T, V, P}(u, partials)

function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
end

# avoid ambiguities
for Alg in [Bisection]
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:Dual{T, V, P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
end
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:AbstractArray{<:Dual{T, V, P}}}, alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
end
_partials = _restructure(u, partials)
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
end
11 changes: 9 additions & 2 deletions lib/SimpleNonlinearSolve/src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
setindex_trait(x) === CannotSetindex() && (A = dfx)

# Factorize Once and Reuse
dfx_fact = factorize(dfx)
dfx_fact = if dfx isa Number
dfx
else
fact = lu(dfx; check = false)
!issuccess(fact) && return build_solution(prob, alg, x, fx;
retcode = ReturnCode.Unstable)
fact
end

aᵢ = dfx_fact \ _vec(fx)
A_ = _vec(A)
Expand All @@ -64,7 +71,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;

@bb Aaᵢ = A × aᵢ
@bb A .*= -1
bᵢ = dfx_fact \ Aaᵢ
bᵢ = dfx_fact \ _vec(Aaᵢ)

cᵢ_ = _vec(cᵢ)
@bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ))
Expand Down
3 changes: 3 additions & 0 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,6 @@ end
return AutoFiniteDiff()
end
end

# `reshape` that also accepts scalars: a number is returned untouched (the requested
# dimensions are ignored), an array is forwarded to `reshape`.
@inline function __reshape(x::Number, args...)
    return x
end
@inline function __reshape(x::AbstractArray, args...)
    return reshape(x, args...)
end
98 changes: 0 additions & 98 deletions lib/SimpleNonlinearSolve/test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,36 +64,6 @@ const TERMINATION_CONDITIONS = [
autodiff = AutoForwardDiff())) == 0
end

# Out-of-place solve with an `SVector` initial guess: every component of the solution
# must equal `sqrt(p)`, and the derivative of the last component with respect to `p`
# must equal `1 / (2 * sqrt(p))`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f` are defined outside this view.
@testset "[OOP] Immutable AD" begin
    for p in [1.0, 100.0]
        @test begin
            res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
            res_true = sqrt(p)
            all(res.u .≈ res_true)
        end
        # The comparison operator was missing here; `@test a b` is not a valid test.
        @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
            @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
    end
end

# Out-of-place solve with a scalar initial guess, swept over `p` in 1.0:0.1:100.0: the
# solution must equal `sqrt(p)` and its derivative with respect to `p` must equal
# `1 / (2 * sqrt(p))`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f` are defined outside this view.
@testset "[OOP] Scalar AD" begin
    for p in 1.0:0.1:100.0
        @test begin
            res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
            res_true = sqrt(p)
            # Restored the approximate-equality operator between the two values.
            res.u ≈ res_true
        end
        @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
            p) ≈ 1 / (2 * sqrt(p))
    end
end

# Two-parameter check: the solution must equal `sqrt(p[2] / p[1])`, and the Jacobian of
# the solve with respect to `p` must match the Jacobian of that closed form `t`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f2` are defined outside this view.
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
# Both comparisons below had lost their `≈` operator.
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
    p) ≈ ForwardDiff.jacobian(t, p)

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

Expand Down Expand Up @@ -124,36 +94,6 @@ end
autodiff = AutoForwardDiff())) == 0
end

# Out-of-place solve with an `SVector` initial guess: every component of the solution
# must equal `sqrt(p)`, and the derivative of the last component with respect to `p`
# must equal `1 / (2 * sqrt(p))`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f` are defined outside this view.
@testset "[OOP] Immutable AD" begin
    for p in [1.0, 100.0]
        @test begin
            res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
            res_true = sqrt(p)
            all(res.u .≈ res_true)
        end
        # The comparison operator was missing here; `@test a b` is not a valid test.
        @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
            @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
    end
end

# Out-of-place solve with a scalar initial guess, swept over `p` in 1.0:0.1:100.0: the
# solution must equal `sqrt(p)` and its derivative with respect to `p` must equal
# `1 / (2 * sqrt(p))`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f` are defined outside this view.
@testset "[OOP] Scalar AD" begin
    for p in 1.0:0.1:100.0
        @test begin
            res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
            res_true = sqrt(p)
            # Restored the approximate-equality operator between the two values.
            res.u ≈ res_true
        end
        @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
            p) ≈ 1 / (2 * sqrt(p))
    end
end

# Two-parameter check: the solution must equal `sqrt(p[2] / p[1])`, and the Jacobian of
# the solve with respect to `p` must match the Jacobian of that closed form `t`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f2` are defined outside this view.
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
# Both comparisons below had lost their `≈` operator.
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
    p) ≈ ForwardDiff.jacobian(t, p)

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

Expand Down Expand Up @@ -195,44 +135,6 @@ end
@test (@ballocated $(benchmark_nlsolve_oop)($quadratic_f, 1.0, 2.0)) == allocs
end

# `SVector` initial guess, comparing magnitudes (via `abs`) against `sqrt(p)` and
# `1 / (2 * sqrt(p))`. When the result contains a NaN or an entry outside
# (1e-5, 1e5), the same checks are recorded with `@test_broken` instead of `@test`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f` are defined outside this view.
@testset "[OOP] Immutable AD" begin
    for p in [1.0, 100.0]
        res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)

        if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
            @test_broken all(abs.(res) .≈ sqrt(p))
            # The comparison operator was missing from this check.
            @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
                @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p))
        else
            @test all(abs.(res) .≈ sqrt(p))
            @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
                @SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p)))
        end
    end
end

# Scalar initial guess swept over `p` in 1.0:0.1:100.0, comparing magnitudes (via `abs`)
# against `sqrt(p)` and `1 / (2 * sqrt(p))`. When the result contains a NaN, the same
# checks are recorded with `@test_broken` instead of `@test`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f` are defined outside this view.
@testset "[OOP] Scalar AD" begin
    for p in 1.0:0.1:100.0
        res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)

        if any(x -> isnan(x), res)
            # The `≈` operator was missing from the three two-operand checks below.
            @test_broken abs(res.u) ≈ sqrt(p)
            @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
                1.0, p).u, p)) ≈ 1 / (2 * sqrt(p))
        else
            @test abs(res.u) ≈ sqrt(p)
            @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
                1.0, p).u, p)), 1 / (2 * sqrt(p)))
        end
    end
end

# Two-parameter check: the solution must equal `sqrt(p[2] / p[1])`, and the Jacobian of
# the solve with respect to `p` must match the Jacobian of that closed form `t`.
# NOTE(review): `benchmark_nlsolve_oop` and `quadratic_f2` are defined outside this view.
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
# Both comparisons below had lost their `≈` operator.
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
    p) ≈ ForwardDiff.jacobian(t, p)

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

Expand Down
82 changes: 82 additions & 0 deletions lib/SimpleNonlinearSolve/test/forward_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
using ForwardDiff, SimpleNonlinearSolve, StaticArrays, Test, LinearAlgebra
import SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm

# In-place residual of `u^2 = p`: writes `u^2 - p` element-wise into `du` and returns it.
function test_f!(du, u, p)
    du .= u .^ 2 .- p
    return du
end

# Out-of-place residual of `u^2 = p`: returns `u^2 - p`, broadcast over the inputs.
function test_f(u, p)
    return u .^ 2 .- p
end

# Analytic derivative of the solution of `u^2 - p = 0` with respect to `p`.
# The solution is `u = sqrt(p)`, so the derivative is `1 / (2 * sqrt(p))`; the previous
# text had `1 / (2 * p)`, which only agrees at `p == 1`.
#
# The first argument is used only for dispatch and shape:
#   - scalar `u`                -> a scalar
#   - array `u`, scalar `p`     -> an array shaped like `u`, filled with the derivative
#   - array `p` (element-wise)  -> a diagonal matrix with `1 / (2 * sqrt(pᵢ))` on the
#                                  diagonal, `p` being flattened with `vec`
jacobian_f(::Number, p) = 1 / (2 * sqrt(p))
jacobian_f(::Number, p::Number) = 1 / (2 * sqrt(p))
jacobian_f(u, p::Number) = one.(u) .* (1 / (2 * sqrt(p)))
jacobian_f(u, p::AbstractArray) = diagm(vec(@. 1 / (2 * sqrt(p))))

# Build a one-argument function `p -> solution u` for the requested problem form.
# `Val(:iip)` uses the in-place residual `test_f!`, `Val(:oop)` the out-of-place
# `test_f`; any other mode yields `nothing`.
function solve_with(::Val{mode}, u, alg) where {mode}
    if mode === :iip
        return p -> solve(NonlinearProblem(test_f!, u, p), alg).u
    elseif mode === :oop
        return p -> solve(NonlinearProblem(test_f, u, p), alg).u
    end
    return nothing
end

# `__compatible(a, b)` tells the test loops whether a combination should be exercised;
# the loops `continue` past any combination for which it returns `false`.

# (u0, mode): out-of-place works for every `u0`; in-place is only enabled for arrays
# that are not `StaticArray`s (numbers and static arrays are excluded).
__compatible(::Any, ::Val{:oop}) = true
__compatible(::Number, ::Val{:iip}) = false
__compatible(::AbstractArray, ::Val{:iip}) = true
__compatible(::StaticArray, ::Val{:iip}) = false

# (u0, p): a scalar `p` pairs with any `u0`; an array `p` requires an array `u0` of the
# same size.
__compatible(::Any, ::Number) = true
__compatible(::Number, ::AbstractArray) = false
__compatible(u::AbstractArray, p::AbstractArray) = size(u) == size(p)

# (u0, alg): every algorithm is enabled for every kind of `u0`.
__compatible(u::Number, ::AbstractSimpleNonlinearSolveAlgorithm) = true
__compatible(u::AbstractArray, ::AbstractSimpleNonlinearSolveAlgorithm) = true
__compatible(u::StaticArray, ::AbstractSimpleNonlinearSolveAlgorithm) = true

# (alg, mode): all algorithms run in both modes, except `SimpleHalley`, which is skipped
# for in-place problems.
__compatible(::AbstractSimpleNonlinearSolveAlgorithm, ::Val{:iip}) = true
__compatible(::AbstractSimpleNonlinearSolveAlgorithm, ::Val{:oop}) = true
__compatible(::SimpleHalley, ::Val{:iip}) = false

# For each solver, differentiate the solution of `u^2 = p` with respect to `p` through
# ForwardDiff and compare its magnitude with the analytic value from `jacobian_f`.
# Combinations rejected by `__compatible` are skipped, and a combination is only
# differentiated when the plain (non-dual) out-of-place solve reports success.
#
# NOTE(review): a mismatch is only reported through `@error`; the `@test` runs solely in
# the branch where `isapprox` already held, so this testset cannot fail on a wrong
# derivative. Left as is, since it looks deliberate — confirm.
@testset "ForwardDiff.jl Integration: $(alg)" for alg in (SimpleNewtonRaphson(),
    SimpleTrustRegion(), SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())
    us = (2.0, @SVector[1.0, 1.0], [1.0, 1.0], ones(2, 2), @SArray ones(2, 2))

    @testset "Scalar AD" begin
        for p in 1.0:0.1:100.0, u0 in us, mode in (:iip, :oop)
            __compatible(u0, alg) || continue
            __compatible(u0, Val(mode)) || continue
            __compatible(alg, Val(mode)) || continue

            # Convergence gate: primal solve with the out-of-place residual.
            sol = solve(NonlinearProblem(test_f, u0, p), alg)
            if SciMLBase.successful_retcode(sol)
                gs = abs.(ForwardDiff.derivative(solve_with(Val{mode}(), u0, alg), p))
                gs_true = abs.(jacobian_f(u0, p))
                if !(isapprox(gs, gs_true, atol = 1e-5))
                    @show sol.retcode, sol.u
                    @error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true
                else
                    # The `≈` between the two operands had been lost.
                    @test abs.(gs) ≈ abs.(gs_true) atol=1e-5
                end
            end
        end
    end

    @testset "Jacobian" begin
        for u0 in us, p in ([2.0, 1.0], [2.0 1.0; 3.0 4.0]), mode in (:iip, :oop)
            __compatible(u0, p) || continue
            __compatible(u0, alg) || continue
            __compatible(u0, Val(mode)) || continue
            __compatible(alg, Val(mode)) || continue

            # Convergence gate: primal solve with the out-of-place residual.
            sol = solve(NonlinearProblem(test_f, u0, p), alg)
            if SciMLBase.successful_retcode(sol)
                gs = abs.(ForwardDiff.jacobian(solve_with(Val{mode}(), u0, alg), p))
                gs_true = abs.(jacobian_f(u0, p))
                if !(isapprox(gs, gs_true, atol = 1e-5))
                    @show sol.retcode, sol.u
                    @error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_jacobian=gs true_jacobian=gs_true
                else
                    # The `≈` between the two operands had been lost.
                    @test abs.(gs) ≈ abs.(gs_true) atol=1e-5
                end
            end
        end
    end
end
3 changes: 2 additions & 1 deletion lib/SimpleNonlinearSolve/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ const GROUP = get(ENV, "GROUP", "All")

@time @testset "SimpleNonlinearSolve.jl" begin
if GROUP == "All" || GROUP == "Core"
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
@time @safetestset "Basic Tests" include("basictests.jl")
@time @safetestset "Forward AD" include("forward_ad.jl")
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
@time @safetestset "Least Squares Tests" include("least_squares.jl")
@time @safetestset "23 Test Problems" include("23_test_problems.jl")
Expand Down

0 comments on commit ef11f39

Please sign in to comment.