Skip to content

Commit

Permalink
more general fixes
Browse files Browse the repository at this point in the history
Retcode conformity
More tests
Decrease dependencies
SnoopPrecompile
  • Loading branch information
ChrisRackauckas committed Nov 24, 2022
1 parent 37ef686 commit 4bd17e1
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 139 deletions.
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
Expand All @@ -27,9 +27,9 @@ LinearSolve = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "1.73"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1"
StaticArrays = "0.12,1.0"
SnoopPrecompile = "1"
StaticArraysCore = "1.4"
UnPack = "1.0"
julia = "1.6"

Expand All @@ -38,7 +38,8 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff"]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays"]
22 changes: 19 additions & 3 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@ using Reexport
using UnPack: @unpack
using FiniteDiff, ForwardDiff
using ForwardDiff: Dual
using Setfield
using StaticArrays
using RecursiveArrayTools
using LinearAlgebra
using StaticArraysCore
using RecursiveArrayTools
import ArrayInterfaceCore
import LinearSolve
using DiffEqBase

@reexport using SciMLBase
@reexport using SimpleNonlinearSolve

import SciMLBase: _unwrap_val

abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
AbstractNonlinearSolveAlgorithm end
Expand All @@ -31,6 +32,21 @@ include("jacobian.jl")
include("raphson.jl")
include("ad.jl")

import SnoopPrecompile

SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
for alg in (NewtonRaphson,)
solve(prob, alg(), abstol = T(1e-2))
end

prob = NonlinearProblem{true}((du, u, p) -> du[1] = u[1] * u[1] - p[1], T[0.1], T[2])
for alg in (NewtonRaphson,)
solve(prob, alg(), abstol = T(1e-2))
end

end end

export NewtonRaphson

end # module
4 changes: 2 additions & 2 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, iip,
<:Dual{T, V, P}}, alg::NewtonRaphson,
args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
Expand Down
9 changes: 6 additions & 3 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,14 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson
end

function perform_step!(cache::NewtonRaphsonCache{true})
@unpack u, fu, f, p, cache = cache
@unpack u, fu, f, p, alg = cache
@unpack J, linsolve, du1 = cache
calc_J!(J, cache, cache)

# u = u - J \ fu
linsolve = dolinsolve(alg.precs, linsolve, A = J, b = fu, u = du1,
linres = dolinsolve(alg.precs, linsolve, A = J, b = fu, linu = du1,
p = p, reltol = cache.abstol)
cache.linsolve = linsolve
cache.linsolve = linres.cache
@. u = u - du1
f(fu, u, p)

Expand Down Expand Up @@ -150,6 +151,8 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache)

if cache.iter == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu;
Expand Down
125 changes: 2 additions & 123 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,111 +1,10 @@
"""
@add_kwonly function_definition
Define keyword-only version of the `function_definition`.
@add_kwonly function f(x; y=1)
...
end
expands to:
function f(x; y=1)
...
end
function f(; x = error("No argument x"), y=1)
...
end
"""
macro add_kwonly(ex)
esc(add_kwonly(ex))
end

add_kwonly(ex::Expr) = add_kwonly(Val{ex.head}, ex)

function add_kwonly(::Type{<:Val}, ex)
error("add_only does not work with expression $(ex.head)")
end

function add_kwonly(::Union{Type{Val{:function}},
Type{Val{:(=)}}}, ex::Expr)
body = ex.args[2:end] # function body
default_call = ex.args[1] # e.g., :(f(a, b=2; c=3))
kwonly_call = add_kwonly(default_call)
if kwonly_call === nothing
return ex
end

return quote
begin
$ex
$(Expr(ex.head, kwonly_call, body...))
end
end
end

function add_kwonly(::Type{Val{:where}}, ex::Expr)
default_call = ex.args[1]
rest = ex.args[2:end]
kwonly_call = add_kwonly(default_call)
if kwonly_call === nothing
return nothing
end
return Expr(:where, kwonly_call, rest...)
end

function add_kwonly(::Type{Val{:call}}, default_call::Expr)
# default_call is, e.g., :(f(a, b=2; c=3))
funcname = default_call.args[1] # e.g., :f
required = [] # required positional arguments; e.g., [:a]
optional = [] # optional positional arguments; e.g., [:(b=2)]
default_kwargs = []
for arg in default_call.args[2:end]
if isa(arg, Symbol)
push!(required, arg)
elseif arg.head == :(::)
push!(required, arg)
elseif arg.head == :kw
push!(optional, arg)
elseif arg.head == :parameters
@assert default_kwargs == [] # can I have :parameters twice?
default_kwargs = arg.args
else
error("Not expecting to see: $arg")
end
end
if isempty(required) && isempty(optional)
# If the function is already keyword-only, do nothing:
return nothing
end
if isempty(required)
# It's not clear what should be done. Let's not support it at
# the moment:
error("At least one positional mandatory argument is required.")
end

kwonly_kwargs = Expr(:parameters,
[Expr(:kw, pa, :(error($("No argument $pa"))))
for pa in required]..., optional..., default_kwargs...)
kwonly_call = Expr(:call, funcname, kwonly_kwargs)
# e.g., :(f(; a=error(...), b=error(...), c=1, d=2))

return kwonly_call
end

function num_types_in_tuple(sig)
length(sig.parameters)
end

function num_types_in_tuple(sig::UnionAll)
length(Base.unwrap_unionall(sig).parameters)
end

@inline UNITLESS_ABS2(x) = real(abs2(x))
@inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u)
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}
sqrt(real(sum(abs2, u)) / length(u))
end
@inline function DEFAULT_NORM(u::StaticArray{T}) where {T <: Union{AbstractFloat, Complex}}
@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where {T <: Union{AbstractFloat, Complex}}
sqrt(real(sum(abs2, u)) / length(u))
end
@inline function DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray)
Expand All @@ -114,23 +13,6 @@ end
@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2, u)) / length(u))
@inline DEFAULT_NORM(u) = norm(u)

"""
prevfloat_tdir(x, x0, x1)
Move `x` one floating point towards x0.
"""
function prevfloat_tdir(x, x0, x1)
x1 > x0 ? prevfloat(x) : nextfloat(x)
end

function nextfloat_tdir(x, x0, x1)
x1 > x0 ? nextfloat(x) : prevfloat(x)
end

function max_tdir(a, b, x0, x1)
x1 > x0 ? max(a, b) : min(a, b)
end

alg_autodiff(alg::AbstractNewtonAlgorithm{CS, AD}) where {CS, AD} = AD
alg_autodiff(alg) = false

Expand All @@ -146,15 +28,12 @@ function value_derivative(f::F, x::R) where {F, R}
end

# Todo: improve this dispatch
value_derivative(f::F, x::SVector) where {F} = f(x), ForwardDiff.jacobian(f, x)
value_derivative(f::F, x::StaticArraysCore.SVector) where {F} = f(x), ForwardDiff.jacobian(f, x)

value(x) = x
value(x::Dual) = ForwardDiff.value(x)
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

_unwrap_val(::Val{B}) where {B} = B
_unwrap_val(B) = B

_vec(v) = vec(v)
_vec(v::Number) = v
_vec(v::AbstractVector) = v
Expand Down
21 changes: 18 additions & 3 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,34 @@ end
const csu0 = 1.0

sol = benchmark_immutable(ff, cu0)
@test sol.retcode === ReturnCode.Default
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_mutable(ff, cu0)
@test sol.retcode === ReturnCode.Default
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Default
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9

@test (@ballocated benchmark_immutable(ff, cu0)) < 200
@test (@ballocated benchmark_mutable(ff, cu0)) < 200
@test (@ballocated benchmark_scalar(sf, csu0)) < 400

function benchmark_inplace(f, u0)
probN = NonlinearProblem{true}(f, u0)
solver = init(probN, NewtonRaphson(), abstol = 1e-9)
sol = solve!(solver)
end

function ffiip(du, u, p)
du .= u .* u .- 2
end
u0 = [1.0, 1.0]

sol = benchmark_inplace(ffiip, u0)
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)

# AD Tests
using ForwardDiff

Expand Down

0 comments on commit 4bd17e1

Please sign in to comment.