Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more general fixes #99

Merged
merged 2 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
Expand All @@ -27,9 +27,9 @@ LinearSolve = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "1.73"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1"
StaticArrays = "0.12,1.0"
SnoopPrecompile = "1"
StaticArraysCore = "1.4"
UnPack = "1.0"
julia = "1.6"

Expand All @@ -38,7 +38,8 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff"]
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays"]
21 changes: 18 additions & 3 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@ using Reexport
using UnPack: @unpack
using FiniteDiff, ForwardDiff
using ForwardDiff: Dual
using Setfield
using StaticArrays
using RecursiveArrayTools
using LinearAlgebra
using StaticArraysCore
using RecursiveArrayTools
import ArrayInterfaceCore
import LinearSolve
using DiffEqBase

@reexport using SciMLBase
@reexport using SimpleNonlinearSolve

import SciMLBase: _unwrap_val

abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
AbstractNonlinearSolveAlgorithm end
Expand All @@ -31,6 +32,20 @@ include("jacobian.jl")
include("raphson.jl")
include("ad.jl")

import SnoopPrecompile

SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
for alg in (NewtonRaphson,)
solve(prob, alg(), abstol = T(1e-2))
end

prob = NonlinearProblem{true}((du, u, p) -> du[1] = u[1] * u[1] - p[1], T[0.1], T[2])
for alg in (NewtonRaphson,)
solve(prob, alg(), abstol = T(1e-2))
end
end end

export NewtonRaphson

end # module
6 changes: 4 additions & 2 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:Dual{T, V, P}}, alg::NewtonRaphson,
args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
Expand Down
11 changes: 7 additions & 4 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,14 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson
end

function perform_step!(cache::NewtonRaphsonCache{true})
@unpack u, fu, f, p, cache = cache
@unpack u, fu, f, p, alg = cache
@unpack J, linsolve, du1 = cache
calc_J!(J, cache, cache)

# u = u - J \ fu
linsolve = dolinsolve(alg.precs, linsolve, A = J, b = fu, u = du1,
p = p, reltol = cache.abstol)
cache.linsolve = linsolve
linres = dolinsolve(alg.precs, linsolve, A = J, b = fu, linu = du1,
p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du1
f(fu, u, p)

Expand Down Expand Up @@ -150,6 +151,8 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache)

if cache.iter == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu;
Expand Down
130 changes: 7 additions & 123 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,111 +1,13 @@
"""
@add_kwonly function_definition
Define keyword-only version of the `function_definition`.
@add_kwonly function f(x; y=1)
...
end
expands to:
function f(x; y=1)
...
end
function f(; x = error("No argument x"), y=1)
...
end
"""
macro add_kwonly(ex)
esc(add_kwonly(ex))
end

add_kwonly(ex::Expr) = add_kwonly(Val{ex.head}, ex)

function add_kwonly(::Type{<:Val}, ex)
error("add_only does not work with expression $(ex.head)")
end

function add_kwonly(::Union{Type{Val{:function}},
Type{Val{:(=)}}}, ex::Expr)
body = ex.args[2:end] # function body
default_call = ex.args[1] # e.g., :(f(a, b=2; c=3))
kwonly_call = add_kwonly(default_call)
if kwonly_call === nothing
return ex
end

return quote
begin
$ex
$(Expr(ex.head, kwonly_call, body...))
end
end
end

function add_kwonly(::Type{Val{:where}}, ex::Expr)
default_call = ex.args[1]
rest = ex.args[2:end]
kwonly_call = add_kwonly(default_call)
if kwonly_call === nothing
return nothing
end
return Expr(:where, kwonly_call, rest...)
end

function add_kwonly(::Type{Val{:call}}, default_call::Expr)
# default_call is, e.g., :(f(a, b=2; c=3))
funcname = default_call.args[1] # e.g., :f
required = [] # required positional arguments; e.g., [:a]
optional = [] # optional positional arguments; e.g., [:(b=2)]
default_kwargs = []
for arg in default_call.args[2:end]
if isa(arg, Symbol)
push!(required, arg)
elseif arg.head == :(::)
push!(required, arg)
elseif arg.head == :kw
push!(optional, arg)
elseif arg.head == :parameters
@assert default_kwargs == [] # can I have :parameters twice?
default_kwargs = arg.args
else
error("Not expecting to see: $arg")
end
end
if isempty(required) && isempty(optional)
# If the function is already keyword-only, do nothing:
return nothing
end
if isempty(required)
# It's not clear what should be done. Let's not support it at
# the moment:
error("At least one positional mandatory argument is required.")
end

kwonly_kwargs = Expr(:parameters,
[Expr(:kw, pa, :(error($("No argument $pa"))))
for pa in required]..., optional..., default_kwargs...)
kwonly_call = Expr(:call, funcname, kwonly_kwargs)
# e.g., :(f(; a=error(...), b=error(...), c=1, d=2))

return kwonly_call
end

function num_types_in_tuple(sig)
length(sig.parameters)
end

function num_types_in_tuple(sig::UnionAll)
length(Base.unwrap_unionall(sig).parameters)
end

@inline UNITLESS_ABS2(x) = real(abs2(x))
@inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u)
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}
sqrt(real(sum(abs2, u)) / length(u))
end
@inline function DEFAULT_NORM(u::StaticArray{T}) where {T <: Union{AbstractFloat, Complex}}
@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where {
T <: Union{
AbstractFloat,
Complex}}
sqrt(real(sum(abs2, u)) / length(u))
end
@inline function DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray)
Expand All @@ -114,23 +16,6 @@ end
@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2, u)) / length(u))
@inline DEFAULT_NORM(u) = norm(u)

"""
prevfloat_tdir(x, x0, x1)
Move `x` one floating point towards x0.
"""
function prevfloat_tdir(x, x0, x1)
x1 > x0 ? prevfloat(x) : nextfloat(x)
end

function nextfloat_tdir(x, x0, x1)
x1 > x0 ? nextfloat(x) : prevfloat(x)
end

function max_tdir(a, b, x0, x1)
x1 > x0 ? max(a, b) : min(a, b)
end

alg_autodiff(alg::AbstractNewtonAlgorithm{CS, AD}) where {CS, AD} = AD
alg_autodiff(alg) = false

Expand All @@ -146,15 +31,14 @@ function value_derivative(f::F, x::R) where {F, R}
end

# Todo: improve this dispatch
value_derivative(f::F, x::SVector) where {F} = f(x), ForwardDiff.jacobian(f, x)
function value_derivative(f::F, x::StaticArraysCore.SVector) where {F}
f(x), ForwardDiff.jacobian(f, x)
end

value(x) = x
value(x::Dual) = ForwardDiff.value(x)
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

_unwrap_val(::Val{B}) where {B} = B
_unwrap_val(B) = B

_vec(v) = vec(v)
_vec(v::Number) = v
_vec(v::AbstractVector) = v
Expand Down
21 changes: 18 additions & 3 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,34 @@ end
const csu0 = 1.0

sol = benchmark_immutable(ff, cu0)
@test sol.retcode === ReturnCode.Default
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_mutable(ff, cu0)
@test sol.retcode === ReturnCode.Default
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Default
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9

@test (@ballocated benchmark_immutable(ff, cu0)) < 200
@test (@ballocated benchmark_mutable(ff, cu0)) < 200
@test (@ballocated benchmark_scalar(sf, csu0)) < 400

function benchmark_inplace(f, u0)
probN = NonlinearProblem{true}(f, u0)
solver = init(probN, NewtonRaphson(), abstol = 1e-9)
sol = solve!(solver)
end

function ffiip(du, u, p)
du .= u .* u .- 2
end
u0 = [1.0, 1.0]

sol = benchmark_inplace(ffiip, u0)
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)

# AD Tests
using ForwardDiff

Expand Down