Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

specialize Newton on static arrays #50

Merged
merged 4 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ julia = "1.6"
[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "Test", "ForwardDiff"]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff"]
22 changes: 15 additions & 7 deletions src/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fx = float(prob.u0)
T = typeof(x)
atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5)
rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5)
atol = xatol !== nothing ? xatol : oneunit(eltype(T)) * (eps(one(eltype(T))))^(4//5)
rtol = xrtol !== nothing ? xrtol : eps(one(eltype(T)))^(4//5)

if typeof(x) <: Number
xo = oftype(one(eltype(x)), Inf)
else
xo = map(x->oftype(one(eltype(x)), Inf),x)
end

xo = oftype(x, Inf)
for i in 1:maxiters
if alg_autodiff(alg)
fx, dfx = value_derivative(f, x)
elseif x isa AbstractArray
fx = f(x)
dfx = FiniteDiff.finite_difference_jacobian(f, x, alg.diff_type, eltype(x), fx)
else
fx = f(x)
dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx)
Expand Down Expand Up @@ -49,12 +57,12 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)

end
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
end
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ function value_derivative(f::F, x::R) where {F,R}
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
end

# Todo: improve this dispatch
value_derivative(f::F, x::SVector) where F = f(x),ForwardDiff.jacobian(f, x)

value(x) = x
value(x::Dual) = ForwardDiff.value(x)
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
181 changes: 181 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
using NonlinearSolve
using StaticArrays
using BenchmarkTools
using Test

function benchmark_immutable(f, u0)
probN = NonlinearProblem{false}(f, u0)
solver = init(probN, NewtonRaphson(), tol = 1e-9)
sol = solve!(solver)
end

function benchmark_mutable(f, u0)
probN = NonlinearProblem{false}(f, u0)
solver = init(probN, NewtonRaphson(), tol = 1e-9)
sol = (reinit!(solver, probN); solve!(solver))
end

function benchmark_scalar(f, u0)
probN = NonlinearProblem{false}(f, u0)
sol = (solve(probN, NewtonRaphson()))
end

function ff(u,p)
u .* u .- 2
end
const cu0 = @SVector[1.0, 1.0]
function sf(u,p)
u * u - 2
end
const csu0 = 1.0

sol = benchmark_immutable(ff, cu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_mutable(ff, cu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test sol.u * sol.u - 2 < 1e-9

@test (@ballocated benchmark_immutable(ff, cu0)) == 0
@test (@ballocated benchmark_mutable(ff, cu0)) < 200
@test (@ballocated benchmark_scalar(sf, csu0)) == 0

# AD Tests
using ForwardDiff

# Immutable
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]

g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, NewtonRaphson(), tol = 1e-9)
return sol.u[end]
end

for p in 1.0:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p))
end

# Scalar
f, u0 = (u, p) -> u * u - p, 1.0

# NewtonRaphson
g = function (p)
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
sol = solve(probN, NewtonRaphson())
return sol.u
end

@test ForwardDiff.derivative(g, 1.0) ≈ 0.5

for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p))
end

u0 = (1.0, 20.0)
# Falsi
g = function (p)
probN = NonlinearProblem{false}(f, typeof(p).(u0), p)
sol = solve(probN, Falsi())
return sol.left
end

for p in 1.1:0.1:100.0
@test g(p) ≈ sqrt(p)
@test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p))
end

f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
for alg in [Bisection(), Falsi()]
global g, p
g = function (p)
probN = NonlinearProblem{false}(f, u0, p)
sol = solve(probN, Bisection())
return [sol.left]
end

@test g(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p)
end

gnewton = function (p)
probN = NonlinearProblem{false}(f, 0.5, p)
sol = solve(probN, NewtonRaphson())
return [sol.u]
end
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

# Error Checks

f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
probN = NonlinearProblem(f, u0)

@test solve(probN, NewtonRaphson()).u[end] ≈ sqrt(2.0)
@test solve(probN, NewtonRaphson(); immutable = false).u[end] ≈ sqrt(2.0)
@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] ≈ sqrt(2.0)
@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] ≈ sqrt(2.0)

for u0 in [1.0, [1, 1.0]]
local f, probN, sol
f = (u, p) -> u .* u .- 2.0
probN = NonlinearProblem(f, u0)
sol = sqrt(2) * u0

@test solve(probN, NewtonRaphson()).u ≈ sol
@test solve(probN, NewtonRaphson()).u ≈ sol
@test solve(probN, NewtonRaphson(;autodiff=false)).u ≈ sol
end

# Bisection Tests
f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0)
probB = NonlinearProblem(f, u0)

# Falsi
solver = init(probB, Falsi())
sol = solve!(solver)
@test sol.left ≈ sqrt(2.0)

# this should call the fast scalar overload
@test solve(probB, Bisection()).left ≈ sqrt(2.0)

# these should call the iterator version
solver = init(probB, Bisection())
@test solver isa NonlinearSolve.BracketingImmutableSolver
@test solve!(solver).left ≈ sqrt(2.0)

# Garuntee Tests for Bisection
f = function (u, p)
if u < 2.0
return u - 2.0
elseif u > 3.0
return u - 3.0
else
return 0.0
end
end
probB = NonlinearProblem(f, (0.0, 4.0))

solver = init(probB, Bisection(;exact_left = true))
sol = solve!(solver)
@test f(sol.left, nothing) < 0.0
@test f(nextfloat(sol.left), nothing) >= 0.0

solver = init(probB, Bisection(;exact_right = true))
sol = solve!(solver)
@test f(sol.right, nothing) > 0.0
@test f(prevfloat(sol.right), nothing) <= 0.0

solver = init(probB, Bisection(;exact_left = true, exact_right = true); immutable = false)
sol = solve!(solver)
@test f(sol.left, nothing) < 0.0
@test f(nextfloat(sol.left), nothing) >= 0.0
@test f(sol.right, nothing) > 0.0
@test f(prevfloat(sol.right), nothing) <= 0.0
Loading