Skip to content

Commit

Permalink
Merge pull request #50 from SciML/staticarrays
Browse files Browse the repository at this point in the history
specialize Newton on static arrays
  • Loading branch information
ChrisRackauckas authored Jan 18, 2022
2 parents eb6b2e6 + 47bcc55 commit 51c443e
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 189 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ julia = "1.6"
[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "Test", "ForwardDiff"]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff"]
22 changes: 15 additions & 7 deletions src/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fx = float(prob.u0)
T = typeof(x)
atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5)
rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5)
atol = xatol !== nothing ? xatol : oneunit(eltype(T)) * (eps(one(eltype(T))))^(4//5)
rtol = xrtol !== nothing ? xrtol : eps(one(eltype(T)))^(4//5)

if typeof(x) <: Number
xo = oftype(one(eltype(x)), Inf)
else
xo = map(x->oftype(one(eltype(x)), Inf),x)
end

xo = oftype(x, Inf)
for i in 1:maxiters
if alg_autodiff(alg)
fx, dfx = value_derivative(f, x)
elseif x isa AbstractArray
fx = f(x)
dfx = FiniteDiff.finite_difference_jacobian(f, x, alg.diff_type, eltype(x), fx)
else
fx = f(x)
dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx)
Expand Down Expand Up @@ -49,12 +57,12 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)

end
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
end
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ function value_derivative(f::F, x::R) where {F,R}
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
end

# Todo: improve this dispatch
# Static-vector method: evaluate `f` at `x`, then compute the Jacobian of `f` at `x`
# with ForwardDiff, and return both as the tuple `(value, jacobian)`.
function value_derivative(f::F, x::SVector) where {F}
    fx = f(x)
    jac = ForwardDiff.jacobian(f, x)
    return fx, jac
end

value(x) = x
value(x::Dual) = ForwardDiff.value(x)
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
181 changes: 181 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
using NonlinearSolve
using StaticArrays
using BenchmarkTools
using Test

# Wrap `f` and `u0` in a `NonlinearProblem{false}`, build a `NewtonRaphson` solver
# with `init` (keyword `tol = 1e-9`), and return whatever `solve!` returns.
function benchmark_immutable(f, u0)
    problem = NonlinearProblem{false}(f, u0)
    cache = init(problem, NewtonRaphson(); tol = 1e-9)
    return solve!(cache)
end

# Same setup as `benchmark_immutable`, but the solver is passed through `reinit!`
# (with the same problem) before `solve!` is called; returns the result of `solve!`.
# NOTE(review): despite the name, `init` receives exactly the same arguments as in
# `benchmark_immutable` (no `immutable = false`) -- confirm this is intended.
function benchmark_mutable(f, u0)
    problem = NonlinearProblem{false}(f, u0)
    cache = init(problem, NewtonRaphson(); tol = 1e-9)
    reinit!(cache, problem)
    return solve!(cache)
end

# Build a `NonlinearProblem{false}` from `f` and `u0` and solve it directly with
# `solve(..., NewtonRaphson())` (no `init`/`solve!` round trip); returns the solution.
function benchmark_scalar(f, u0)
    problem = NonlinearProblem{false}(f, u0)
    return solve(problem, NewtonRaphson())
end

# Broadcasted residual `u * u - 2`, applied element-wise to `u`.
# The parameter argument `p` is accepted but not used.
function ff(u, p)
    return @. u * u - 2
end
# Static-vector initial guess passed together with `ff` to `benchmark_immutable`
# and `benchmark_mutable` below.
const cu0 = @SVector[1.0, 1.0]
# Non-broadcasted residual `u * u - 2`; the parameter argument `p` is not used.
function sf(u, p)
    squared = u * u
    return squared - 2
end
# Scalar initial guess passed together with `sf` to `benchmark_scalar` below.
const csu0 = 1.0

# Each entry point must report the default return code and drive the residual
# `u * u - 2` to (near) zero. The residual is compared in absolute value so that a
# large negative residual cannot slip through the check.
sol = benchmark_immutable(ff, cu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_mutable(ff, cu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
@test abs(sol.u * sol.u - 2) < 1e-9

# Allocation checks: the immutable and scalar paths must not allocate at all, the
# `reinit!` path is allowed a small allocation budget.
@test (@ballocated benchmark_immutable(ff, cu0)) == 0
@test (@ballocated benchmark_mutable(ff, cu0)) < 200
@test (@ballocated benchmark_scalar(sf, csu0)) == 0

# AD Tests
using ForwardDiff

# Immutable
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]

# Solve `u * u - p = 0` for a given `p` and return the last component of the root.
# NOTE(review): the problem is built from the scalar `csu0`, not from the `SVector`
# `u0` defined just above, so `u0` is unused here -- confirm which one is intended.
g = function (p)
    probN = NonlinearProblem{false}(f, csu0, p)
    sol = solve(probN, NewtonRaphson(), tol = 1e-9)
    return sol.u[end]
end

# The root is sqrt(p), so its derivative with respect to p is 1 / (2 * sqrt(p)).
for p in 1.0:0.1:100.0
    @test g(p) ≈ sqrt(p)
    @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Scalar
f, u0 = (u, p) -> u * u - p, 1.0

# NewtonRaphson
# `oftype(p, u0)` converts the initial guess to the type of `p`, so the same closure
# can be called with plain floats and under `ForwardDiff.derivative`.
g = function (p)
    probN = NonlinearProblem{false}(f, oftype(p, u0), p)
    sol = solve(probN, NewtonRaphson())
    return sol.u
end

@test ForwardDiff.derivative(g, 1.0) ≈ 0.5

# The root is sqrt(p), so its derivative with respect to p is 1 / (2 * sqrt(p)).
for p in 1.1:0.1:100.0
    @test g(p) ≈ sqrt(p)
    @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Bracketing interval for the root; relies on `f = (u, p) -> u * u - p` defined above.
u0 = (1.0, 20.0)
# Falsi
# `typeof(p).(u0)` converts both interval end points to the type of `p` before the
# problem is constructed.
g = function (p)
    probN = NonlinearProblem{false}(f, typeof(p).(u0), p)
    sol = solve(probN, Falsi())
    return sol.left
end

# The root is sqrt(p), so its derivative with respect to p is 1 / (2 * sqrt(p)).
for p in 1.1:0.1:100.0
    @test g(p) ≈ sqrt(p)
    @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p))
end

# Two-parameter problem `p[1] * u * u - p[2] = 0` on the bracket `(1.0, 100.0)`;
# `t` is the closed-form root used as the reference for values and Jacobians.
f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
# NOTE(review): the loop variable `alg` is never used -- the body always solves with
# `Bisection()`, so `Falsi()` is not exercised by this loop. Confirm whether
# `solve(probN, alg)` was intended before changing it.
for alg in [Bisection(), Falsi()]
    global g, p
    g = function (p)
        probN = NonlinearProblem{false}(f, u0, p)
        sol = solve(probN, Bisection())
        return [sol.left]
    end

    @test g(p) ≈ [sqrt(p[2] / p[1])]
    @test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p)
end

# Same two-parameter problem solved with `NewtonRaphson` from the scalar guess 0.5;
# uses the globals `f`, `p` and the reference solution `t` defined above.
gnewton = function (p)
    probN = NonlinearProblem{false}(f, 0.5, p)
    sol = solve(probN, NewtonRaphson())
    return [sol.u]
end
@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)

# Error Checks

f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
probN = NonlinearProblem(f, u0)

# Every solver configuration must land on sqrt(2) for the static-vector problem.
# NOTE(review): the `autodiff = false` assertion appears twice in a row; the second
# copy is kept as-is but looks redundant.
@test solve(probN, NewtonRaphson()).u[end] ≈ sqrt(2.0)
@test solve(probN, NewtonRaphson(); immutable = false).u[end] ≈ sqrt(2.0)
@test solve(probN, NewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0)
@test solve(probN, NewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0)

# Repeat for a scalar and for a plain `Vector` initial guess; the expected root is
# `sqrt(2) * u0` because every entry of `u0` is 1.
for u0 in [1.0, [1, 1.0]]
    local f, probN, sol
    f = (u, p) -> u .* u .- 2.0
    probN = NonlinearProblem(f, u0)
    sol = sqrt(2) * u0

    # NOTE(review): the first two assertions are identical.
    @test solve(probN, NewtonRaphson()).u ≈ sol
    @test solve(probN, NewtonRaphson()).u ≈ sol
    @test solve(probN, NewtonRaphson(; autodiff = false)).u ≈ sol
end

# Bisection Tests
# Root of `u * u - 2` bracketed by the interval (1.0, 2.0).
f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0)
probB = NonlinearProblem(f, u0)

# Falsi
solver = init(probB, Falsi())
sol = solve!(solver)
@test sol.left ≈ sqrt(2.0)

# this should call the fast scalar overload
@test solve(probB, Bisection()).left ≈ sqrt(2.0)

# these should call the iterator version
solver = init(probB, Bisection())
@test solver isa NonlinearSolve.BracketingImmutableSolver
@test solve!(solver).left ≈ sqrt(2.0)

# Guarantee Tests for Bisection
# Piecewise test function: `u - 2` below 2, `u - 3` above 3, and exactly zero on the
# whole interval [2, 3]. The parameter argument `p` is not used.
f = function (u, p)
    u < 2.0 && return u - 2.0
    u > 3.0 && return u - 3.0
    return 0.0
end
# Bracketing problem on (0.0, 4.0) for the piecewise `f` defined above.
probB = NonlinearProblem(f, (0.0, 4.0))

# exact_left: `f` must be negative at `sol.left` and non-negative at the next float.
solver = init(probB, Bisection(;exact_left = true))
sol = solve!(solver)
@test f(sol.left, nothing) < 0.0
@test f(nextfloat(sol.left), nothing) >= 0.0

# exact_right: `f` must be positive at `sol.right` and non-positive at the previous float.
solver = init(probB, Bisection(;exact_right = true))
sol = solve!(solver)
@test f(sol.right, nothing) > 0.0
@test f(prevfloat(sol.right), nothing) <= 0.0

# Both options together, with `immutable = false` passed to `init`: all four
# end-point conditions above must hold for the same solution.
solver = init(probB, Bisection(;exact_left = true, exact_right = true); immutable = false)
sol = solve!(solver)
@test f(sol.left, nothing) < 0.0
@test f(nextfloat(sol.left), nothing) >= 0.0
@test f(sol.right, nothing) > 0.0
@test f(prevfloat(sol.right), nothing) <= 0.0
Loading

0 comments on commit 51c443e

Please sign in to comment.