diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..453925c3f --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "sciml" \ No newline at end of file diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml new file mode 100644 index 000000000..2a3517a0f --- /dev/null +++ b/.github/workflows/FormatCheck.yml @@ -0,0 +1,42 @@ +name: format-check + +on: + push: + branches: + - 'master' + - 'release-' + tags: '*' + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: [1] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v1 + - name: Install JuliaFormatter and format + # This will use the latest version by default but you can set the version like so: + # + # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" + write(stdout, out) + exit(1) + end' diff --git a/docs/make.jl b/docs/make.jl index 6ec622d93..fb811e160 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,18 +2,14 @@ using Documenter, NonlinearSolve include("pages.jl") -makedocs( - sitename="NonlinearSolve.jl", - authors="Chris Rackauckas", - modules=[NonlinearSolve,NonlinearSolve.SciMLBase], - clean=true,doctest=false, - format = Documenter.HTML(analytics = "UA-90474609-3", - assets = ["assets/favicon.ico"], - canonical="https://nonlinearsolve.sciml.ai/stable/"), - pages=pages -) +makedocs(sitename = "NonlinearSolve.jl", + authors = "Chris Rackauckas", + modules = [NonlinearSolve, NonlinearSolve.SciMLBase], + clean = true, doctest = false, + format = Documenter.HTML(analytics = "UA-90474609-3", + assets = ["assets/favicon.ico"], + canonical = "https://nonlinearsolve.sciml.ai/stable/"), + pages = pages) -deploydocs( - repo = "github.com/SciML/NonlinearSolve.jl.git"; - push_preview = true -) +deploydocs(repo = "github.com/SciML/NonlinearSolve.jl.git"; + push_preview = true) diff --git a/docs/pages.jl b/docs/pages.jl index 5d4903ca2..8de820131 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,18 +1,12 @@ # Put in a separate page so it can be used by SciMLDocs.jl -pages=[ +pages = [ "Home" => "index.md", - "Tutorials" => Any[ - "tutorials/nonlinear.md", - "tutorials/iterator_interface.md" - ], - "Basics" => Any[ - "basics/NonlinearProblem.md", - "basics/NonlinearFunctions.md", - "basics/FAQ.md" - ], - "Solvers" => Any[ - "solvers/NonlinearSystemSolvers.md", - "solvers/BracketingSolvers.md" - ] -] \ No newline at end of file + "Tutorials" => Any["tutorials/nonlinear.md", + "tutorials/iterator_interface.md"], + "Basics" => Any["basics/NonlinearProblem.md", + "basics/NonlinearFunctions.md", + "basics/FAQ.md"], + "Solvers" => Any["solvers/NonlinearSystemSolvers.md", + "solvers/BracketingSolvers.md"], +] diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index fb1c822ae..b31f112bb 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -16,7 +16,7 @@ import RecursiveFactorization abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end -abstract type AbstractNewtonAlgorithm{CS,AD} <: AbstractNonlinearSolveAlgorithm end +abstract type AbstractNewtonAlgorithm{CS, AD} <: AbstractNonlinearSolveAlgorithm end abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolveAlgorithm end include("utils.jl") diff --git a/src/bisection.jl b/src/bisection.jl index c00ed19c2..8270ba9a3 100644 --- a/src/bisection.jl +++ b/src/bisection.jl @@ -1,114 +1,116 @@ struct Bisection <: AbstractBracketingAlgorithm - exact_left::Bool - exact_right::Bool + exact_left::Bool + exact_right::Bool end -function Bisection(;exact_left=false, exact_right=false) - Bisection(exact_left, exact_right) +function Bisection(; exact_left = false, exact_right = false) + Bisection(exact_left, exact_right) end struct BisectionCache{uType} - state::Int - left::uType - right::uType + state::Int + left::uType + right::uType end function alg_cache(alg::Bisection, left, right, p, ::Val{true}) - BisectionCache(0, left, right) + BisectionCache(0, left, right) end function alg_cache(alg::Bisection, left, right, p, ::Val{false}) - BisectionCache(0, left, right) + BisectionCache(0, left, right) end function perform_step(solver::BracketingImmutableSolver, alg::Bisection, cache) - @unpack f, p, left, right, fl, fr, cache = solver - - if cache.state == 0 - fzero = zero(fl) - fl * fr > fzero && error("Bracket became non-containing in between iterations. This could mean that " - * "input function crosses the x axis multiple times. Bisection is not the right method to solve this.") - - mid = (left + right) / 2 - - if left == mid || right == mid - @set! solver.force_stop = true - @set! solver.retcode = FLOATING_POINT_LIMIT - return solver - end - - fm = f(mid, p) - - if iszero(fm) - if alg.exact_left - @set! cache.state = 1 - @set! cache.right = mid - @set! cache.left = mid - @set! solver.cache = cache - elseif alg.exact_right - @set! solver.left = prevfloat_tdir(mid, left, right) - solver = sync_residuals!(solver) - @set! cache.state = 2 - @set! cache.left = mid - @set! solver.cache = cache - else - @set! solver.left = prevfloat_tdir(mid, left, right) - @set! solver.right = nextfloat_tdir(mid, left, right) - solver = sync_residuals!(solver) - @set! solver.force_stop = true - return solver - end - else - if sign(fm) == sign(fl) - @set! solver.left = mid - @set! solver.fl = fm - else - @set! solver.right = mid - @set! solver.fr = fm - end - end - elseif cache.state == 1 - mid = (left + cache.right) / 2 - - if cache.right == mid || left == mid - if alg.exact_right - @set! cache.state = 2 - @set! solver.cache = cache - return solver - else - @set! solver.right = nextfloat_tdir(mid, left, right) - solver = sync_residuals!(solver) - @set! solver.force_stop = true - return solver - end - end - - fm = f(mid, p) + @unpack f, p, left, right, fl, fr, cache = solver - if iszero(fm) - @set! cache.right = mid - @set! solver.cache = cache - else - @set! solver.left = mid - @set! solver.fl = fm - end - else - mid = (cache.left + right) / 2 - - if right == mid || cache.left == mid - @set! solver.force_stop = true - return solver - end - - fm = f(mid, p) + if cache.state == 0 + fzero = zero(fl) + fl * fr > fzero && + error("Bracket became non-containing in between iterations. This could mean that " + * + "input function crosses the x axis multiple times. Bisection is not the right method to solve this.") + + mid = (left + right) / 2 - if iszero(fm) - @set! cache.left = mid - @set! solver.cache = cache + if left == mid || right == mid + @set! solver.force_stop = true + @set! solver.retcode = FLOATING_POINT_LIMIT + return solver + end + + fm = f(mid, p) + + if iszero(fm) + if alg.exact_left + @set! cache.state = 1 + @set! cache.right = mid + @set! cache.left = mid + @set! solver.cache = cache + elseif alg.exact_right + @set! solver.left = prevfloat_tdir(mid, left, right) + solver = sync_residuals!(solver) + @set! cache.state = 2 + @set! cache.left = mid + @set! solver.cache = cache + else + @set! solver.left = prevfloat_tdir(mid, left, right) + @set! solver.right = nextfloat_tdir(mid, left, right) + solver = sync_residuals!(solver) + @set! solver.force_stop = true + return solver + end + else + if sign(fm) == sign(fl) + @set! solver.left = mid + @set! solver.fl = fm + else + @set! solver.right = mid + @set! solver.fr = fm + end + end + elseif cache.state == 1 + mid = (left + cache.right) / 2 + + if cache.right == mid || left == mid + if alg.exact_right + @set! cache.state = 2 + @set! solver.cache = cache + return solver + else + @set! solver.right = nextfloat_tdir(mid, left, right) + solver = sync_residuals!(solver) + @set! solver.force_stop = true + return solver + end + end + + fm = f(mid, p) + + if iszero(fm) + @set! cache.right = mid + @set! solver.cache = cache + else + @set! solver.left = mid + @set! solver.fl = fm + end else - @set! solver.right = mid - @set! solver.fr = fm + mid = (cache.left + right) / 2 + + if right == mid || cache.left == mid + @set! solver.force_stop = true + return solver + end + + fm = f(mid, p) + + if iszero(fm) + @set! cache.left = mid + @set! solver.cache = cache + else + @set! solver.right = mid + @set! solver.fr = fm + end end - end - solver + solver end diff --git a/src/falsi.jl b/src/falsi.jl index 2700bdbe9..f184088cb 100644 --- a/src/falsi.jl +++ b/src/falsi.jl @@ -1,45 +1,46 @@ -struct Falsi <: AbstractBracketingAlgorithm -end +struct Falsi <: AbstractBracketingAlgorithm end function alg_cache(alg::Falsi, left, right, p, ::Val{true}) - nothing + nothing end function alg_cache(alg::Falsi, left, right, p, ::Val{false}) - nothing + nothing end function perform_step(solver, alg::Falsi, cache) - @unpack f, p, left, right, fl, fr = solver + @unpack f, p, left, right, fl, fr = solver - fzero = zero(fl) - fl * fr > fzero && error("Bracket became non-containing in between iterations. This could mean that " - * "input function crosses the x axis multiple times. Bisection is not the right method to solve this.") + fzero = zero(fl) + fl * fr > fzero && + error("Bracket became non-containing in between iterations. This could mean that " + * + "input function crosses the x axis multiple times. Bisection is not the right method to solve this.") - mid = (fr * left - fl * right) / (fr - fl) - - if right == mid || right == mid - @set! solver.force_stop = true - @set! solver.retcode = FLOATING_POINT_LIMIT - return solver - end - - fm = f(mid, p) - - if iszero(fm) - # todo: phase 2 bisection similar to the raw method - @set! solver.force_stop = true - @set! solver.left = mid - @set! solver.fl = fm - @set! solver.retcode = EXACT_SOLUTION_LEFT - else - if sign(fm) == sign(fl) - @set! solver.left = mid - @set! solver.fl = fm + mid = (fr * left - fl * right) / (fr - fl) + + if right == mid || right == mid + @set! solver.force_stop = true + @set! solver.retcode = FLOATING_POINT_LIMIT + return solver + end + + fm = f(mid, p) + + if iszero(fm) + # todo: phase 2 bisection similar to the raw method + @set! solver.force_stop = true + @set! solver.left = mid + @set! solver.fl = fm + @set! solver.retcode = EXACT_SOLUTION_LEFT else - @set! solver.right = mid - @set! solver.fr = fm + if sign(fm) == sign(fl) + @set! solver.left = mid + @set! solver.fl = fm + else + @set! solver.right = mid + @set! solver.fr = fm + end end - end - return solver + return solver end diff --git a/src/raphson.jl b/src/raphson.jl index 2fe548f1b..d4c768f16 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -1,10 +1,12 @@ -struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS,AD} +struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS, AD} diff_type::DT linsolve::L end -function NewtonRaphson(;autodiff=true,chunk_size=12,diff_type=Val{:forward},linsolve=DEFAULT_LINSOLVE) - NewtonRaphson{chunk_size, autodiff, typeof(diff_type), typeof(linsolve)}(diff_type, linsolve) +function NewtonRaphson(; autodiff = true, chunk_size = 12, diff_type = Val{:forward}, + linsolve = DEFAULT_LINSOLVE) + NewtonRaphson{chunk_size, autodiff, typeof(diff_type), typeof(linsolve)}(diff_type, + linsolve) end mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC} @@ -16,7 +18,7 @@ mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC} end function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true}) - uf = JacobianWrapper(f,p) + uf = JacobianWrapper(f, p) linsolve = alg.linsolve(Val{:init}, f, u) J = false .* u .* u' du1 = zero(u) @@ -28,7 +30,9 @@ function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true}) du2 = zero(u) jac_config = FiniteDiff.JacobianCache(tmp, du1, du2, alg.diff_type) else - jac_config = FiniteDiff.JacobianCache(Complex{eltype(tmp)}.(tmp),Complex{eltype(du1)}.(du1),nothing,alg.diff_type,eltype(u)) + jac_config = FiniteDiff.JacobianCache(Complex{eltype(tmp)}.(tmp), + Complex{eltype(du1)}.(du1), nothing, + alg.diff_type, eltype(u)) end end NewtonRaphsonCache(uf, linsolve, J, du1, jac_config) diff --git a/src/scalar.jl b/src/scalar.jl index f97a9e312..b81d259bf 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -1,187 +1,226 @@ -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...) - f = Base.Fix2(prob.f, prob.p) - x = float(prob.u0) - fx = float(prob.u0) - T = typeof(x) - atol = xatol !== nothing ? xatol : oneunit(eltype(T)) * (eps(one(eltype(T))))^(4//5) - rtol = xrtol !== nothing ? xrtol : eps(one(eltype(T)))^(4//5) - - if typeof(x) <: Number - xo = oftype(one(eltype(x)), Inf) - else - xo = map(x->oftype(one(eltype(x)), Inf),x) - end - - for i in 1:maxiters - if alg_autodiff(alg) - fx, dfx = value_derivative(f, x) - elseif x isa AbstractArray - fx = f(x) - dfx = FiniteDiff.finite_difference_jacobian(f, x, alg.diff_type, eltype(x), fx) +function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}}, + alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, + maxiters = 1000, kwargs...) + f = Base.Fix2(prob.f, prob.p) + x = float(prob.u0) + fx = float(prob.u0) + T = typeof(x) + atol = xatol !== nothing ? xatol : oneunit(eltype(T)) * (eps(one(eltype(T))))^(4 // 5) + rtol = xrtol !== nothing ? xrtol : eps(one(eltype(T)))^(4 // 5) + + if typeof(x) <: Number + xo = oftype(one(eltype(x)), Inf) else - fx = f(x) - dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx) + xo = map(x -> oftype(one(eltype(x)), Inf), x) end - iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT)) - Δx = dfx \ fx - x -= Δx - if isapprox(x, xo, atol=atol, rtol=rtol) - return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT)) + + for i in 1:maxiters + if alg_autodiff(alg) + fx, dfx = value_derivative(f, x) + elseif x isa AbstractArray + fx = f(x) + dfx = FiniteDiff.finite_difference_jacobian(f, x, alg.diff_type, eltype(x), fx) + else + fx = f(x) + dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), + fx) + end + iszero(fx) && + return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(DEFAULT)) + Δx = dfx \ fx + x -= Δx + if isapprox(x, xo, atol = atol, rtol = rtol) + return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(DEFAULT)) + end + xo = x end - xo = x - end - return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(MAXITERS_EXCEED)) + return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(MAXITERS_EXCEED)) end function scalar_nlsolve_ad(prob, alg, args...; kwargs...) - f = prob.f - p = value(prob.p) - u0 = value(prob.u0) - - newprob = NonlinearProblem(f, u0, p; prob.kwargs...) - sol = solve(newprob, alg, args...; kwargs...) - - uu = sol.u - if p isa Number - f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p) - else - f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p) - end - - f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu) - pp = prob.p - sumfun = let f_x′ = -f_x - ((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p) - end - partials = sum(sumfun, zip(f_p, pp)) - return sol, partials -end + f = prob.f + p = value(prob.p) + u0 = value(prob.u0) + + newprob = NonlinearProblem(f, u0, p; prob.kwargs...) + sol = solve(newprob, alg, args...; kwargs...) + + uu = sol.u + if p isa Number + f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p) + else + f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p) + end -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} - sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode) + f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu) + pp = prob.p + sumfun = let f_x′ = -f_x + ((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p) + end + partials = sum(sumfun, zip(f_p, pp)) + return sol, partials +end +function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip, + <:Dual{T, V, P}}, alg::NewtonRaphson, + args...; kwargs...) where {iip, T, V, P} + sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) + return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid; + retcode = sol.retcode) end -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number,SVector}, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} - sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode) +function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip, + <:AbstractArray{<:Dual{T, V, P}}}, + alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} + sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) + return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid; + retcode = sol.retcode) end # avoid ambiguities for Alg in [Bisection] - @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} - sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials)) - #return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) - end - @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} - sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials)) - #return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) - end + @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T, V, P}}, + alg::$Alg, args...; + kwargs...) where {uType, iip, T, V, P} + sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) + return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), + sol.resid; retcode = sol.retcode, + left = Dual{T, V, P}(sol.left, partials), + right = Dual{T, V, P}(sol.right, partials)) + #return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) + end + @eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, + <:AbstractArray{<:Dual{T, V, P}}}, + alg::$Alg, args...; + kwargs...) where {uType, iip, T, V, P} + sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) + return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), + sol.resid; retcode = sol.retcode, + left = Dual{T, V, P}(sol.left, partials), + right = Dual{T, V, P}(sol.right, partials)) + #return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid) + end end -function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000, kwargs...) - f = Base.Fix2(prob.f, prob.p) - left, right = prob.u0 - fl, fr = f(left), f(right) +function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000, + kwargs...) + f = Base.Fix2(prob.f, prob.p) + left, right = prob.u0 + fl, fr = f(left), f(right) - if iszero(fl) - return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right) - end + if iszero(fl) + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = Symbol(EXACT_SOLUTION_LEFT), left = left, + right = right) + end - i = 1 - if !iszero(fr) - while i < maxiters - mid = (left + right) / 2 - (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) - fm = f(mid) - if iszero(fm) - right = mid - break - end - if sign(fl) == sign(fm) - fl = fm - left = mid - else - fr = fm - right = mid - end - i += 1 + i = 1 + if !iszero(fr) + while i < maxiters + mid = (left + right) / 2 + (mid == left || mid == right) && + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = Symbol(FLOATING_POINT_LIMIT), + left = left, right = right) + fm = f(mid) + if iszero(fm) + right = mid + break + end + if sign(fl) == sign(fm) + fl = fm + left = mid + else + fr = fm + right = mid + end + i += 1 + end end - end - - while i < maxiters - mid = (left + right) / 2 - (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) - fm = f(mid) - if iszero(fm) - right = mid - fr = fm - else - left = mid - fl = fm + + while i < maxiters + mid = (left + right) / 2 + (mid == left || mid == right) && + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = Symbol(FLOATING_POINT_LIMIT), + left = left, right = right) + fm = f(mid) + if iszero(fm) + right = mid + fr = fm + else + left = mid + fl = fm + end + i += 1 end - i += 1 - end - return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right) + return SciMLBase.build_solution(prob, alg, left, fl; retcode = Symbol(MAXITERS_EXCEED), + left = left, right = right) end -function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000, kwargs...) - f = Base.Fix2(prob.f, prob.p) - left, right = prob.u0 - fl, fr = f(left), f(right) +function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000, + kwargs...) + f = Base.Fix2(prob.f, prob.p) + left, right = prob.u0 + fl, fr = f(left), f(right) - if iszero(fl) - return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right) - end + if iszero(fl) + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = Symbol(EXACT_SOLUTION_LEFT), left = left, + right = right) + end - i = 1 - if !iszero(fr) - while i < maxiters - if nextfloat_tdir(left, prob.u0...) == right - return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) - end - mid = (fr * left - fl * right) / (fr - fl) - for i in 1:10 - mid = max_tdir(left, prevfloat_tdir(mid, prob.u0...), prob.u0...) - end - if mid == right || mid == left - break - end - fm = f(mid) - if iszero(fm) - right = mid - break - end - if sign(fl) == sign(fm) - fl = fm - left = mid - else - fr = fm - right = mid - end - i += 1 + i = 1 + if !iszero(fr) + while i < maxiters + if nextfloat_tdir(left, prob.u0...) == right + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = Symbol(FLOATING_POINT_LIMIT), + left = left, right = right) + end + mid = (fr * left - fl * right) / (fr - fl) + for i in 1:10 + mid = max_tdir(left, prevfloat_tdir(mid, prob.u0...), prob.u0...) + end + if mid == right || mid == left + break + end + fm = f(mid) + if iszero(fm) + right = mid + break + end + if sign(fl) == sign(fm) + fl = fm + left = mid + else + fr = fm + right = mid + end + i += 1 + end end - end - - while i < maxiters - mid = (left + right) / 2 - (mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right) - fm = f(mid) - if iszero(fm) - right = mid - fr = fm - elseif sign(fm) == sign(fl) - left = mid - fl = fm - else - right = mid - fr = fm + + while i < maxiters + mid = (left + right) / 2 + (mid == left || mid == right) && + return SciMLBase.build_solution(prob, alg, left, fl; + retcode = Symbol(FLOATING_POINT_LIMIT), + left = left, right = right) + fm = f(mid) + if iszero(fm) + right = mid + fr = fm + elseif sign(fm) == sign(fl) + left = mid + fl = fm + else + right = mid + fr = fm + end + i += 1 end - i += 1 - end - return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right) + return SciMLBase.build_solution(prob, alg, left, fl; retcode = Symbol(MAXITERS_EXCEED), + left = left, right = right) end diff --git a/src/solve.jl b/src/solve.jl index 5235056bf..efd73d65a 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,76 +1,79 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...) - solver = init(prob, alg, args...; kwargs...) - sol = solve!(solver) + solver = init(prob, alg, args...; kwargs...) + sol = solve!(solver) end -function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...; - alias_u0 = false, - maxiters = 1000, - kwargs... - ) where {uType, iip} +function SciMLBase.init(prob::NonlinearProblem{uType, iip}, + alg::AbstractBracketingAlgorithm, args...; + alias_u0 = false, + maxiters = 1000, + kwargs...) where {uType, iip} + if !(prob.u0 isa Tuple) + error("You need to pass a tuple of u0 in bracketing algorithms.") + end - if !(prob.u0 isa Tuple) - error("You need to pass a tuple of u0 in bracketing algorithms.") - end + if eltype(prob.u0) isa AbstractArray + error("Bracketing Algorithms work for scalar arguments only") + end - if eltype(prob.u0) isa AbstractArray - error("Bracketing Algorithms work for scalar arguments only") - end - - if alias_u0 - left, right = prob.u0 - else - left, right = deepcopy(prob.u0) - end - f = prob.f - p = prob.p - fl = f(left, p) - fr = f(right, p) - cache = alg_cache(alg, left, right,p, Val(iip)) - return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip,prob) + if alias_u0 + left, right = prob.u0 + else + left, right = deepcopy(prob.u0) + end + f = prob.f + p = prob.p + fl = f(left, p) + fr = f(right, p) + cache = alg_cache(alg, left, right, p, Val(iip)) + return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, + DEFAULT, cache, iip, prob) end -function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...; - alias_u0 = false, - maxiters = 1000, - tol = 1e-6, - internalnorm = DEFAULT_NORM, - kwargs... - ) where {uType, iip} - - if alias_u0 - u = prob.u0 - else - u = deepcopy(prob.u0) - end - f = prob.f - p = prob.p - if iip - fu = zero(u) - f(fu, u, p) - else - fu = f(u, p) - end - cache = alg_cache(alg, f, u, p, Val(iip)) - return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip, prob) +function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, + args...; + alias_u0 = false, + maxiters = 1000, + tol = 1e-6, + internalnorm = DEFAULT_NORM, + kwargs...) where {uType, iip} + if alias_u0 + u = prob.u0 + else + u = deepcopy(prob.u0) + end + f = prob.f + p = prob.p + if iip + fu = zero(u) + f(fu, u, p) + else + fu = f(u, p) + end + cache = alg_cache(alg, f, u, p, Val(iip)) + return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, + DEFAULT, tol, cache, iip, prob) end function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver) - solver = mic_check(solver) - while !solver.force_stop && solver.iter < solver.maxiters - solver = perform_step(solver, solver.alg, Val(solver.iip)) - @set! solver.iter += 1 - end - if solver.iter == solver.maxiters - @set! solver.retcode = MAXITERS_EXCEED - end - if typeof(solver) <: NewtonImmutableSolver - SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu;retcode=Symbol(solver.retcode)) - else - SciMLBase.build_solution(solver.prob, solver.alg, solver.left,solver.fl;retcode=Symbol(solver.retcode),left = solver.left,right = solver.right) - end + solver = mic_check(solver) + while !solver.force_stop && solver.iter < solver.maxiters + solver = perform_step(solver, solver.alg, Val(solver.iip)) + @set! solver.iter += 1 + end + if solver.iter == solver.maxiters + @set! solver.retcode = MAXITERS_EXCEED + end + if typeof(solver) <: NewtonImmutableSolver + SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu; + retcode = Symbol(solver.retcode)) + else + SciMLBase.build_solution(solver.prob, solver.alg, solver.left, solver.fl; + retcode = Symbol(solver.retcode), left = solver.left, + right = solver.right) + end end """ @@ -80,22 +83,22 @@ end Checks before running main solving iterations. """ function mic_check(solver::BracketingImmutableSolver) - @unpack f, fl, fr = solver - flr = fl * fr - fzero = zero(flr) - (flr > fzero) && error("Non bracketing interval passed in bracketing method.") - if fl == fzero - @set! solver.force_stop = true - @set! solver.retcode = EXACT_SOLUTION_LEFT - elseif fr == fzero - @set! solver.force_stop = true - @set! solver.retcode = EXACT_SOLUTION_RIGHT - end - solver + @unpack f, fl, fr = solver + flr = fl * fr + fzero = zero(flr) + (flr > fzero) && error("Non bracketing interval passed in bracketing method.") + if fl == fzero + @set! solver.force_stop = true + @set! solver.retcode = EXACT_SOLUTION_LEFT + elseif fr == fzero + @set! solver.force_stop = true + @set! solver.retcode = EXACT_SOLUTION_RIGHT + end + solver end function mic_check(solver::NewtonImmutableSolver) - solver + solver end """ @@ -103,16 +106,18 @@ end Reinitialize solver to the original starting conditions """ -function SciMLBase.reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, true}) where {uType} - @. solver.u = prob.u0 - @set! solver.iter = 1 - @set! solver.force_stop = false - return solver +function SciMLBase.reinit!(solver::NewtonImmutableSolver, + prob::NonlinearProblem{uType, true}) where {uType} + @. solver.u = prob.u0 + @set! solver.iter = 1 + @set! solver.force_stop = false + return solver end -function SciMLBase.reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, false}) where {uType} - @set! solver.u = prob.u0 - @set! solver.iter = 1 - @set! solver.force_stop = false - return solver +function SciMLBase.reinit!(solver::NewtonImmutableSolver, + prob::NonlinearProblem{uType, false}) where {uType} + @set! solver.u = prob.u0 + @set! solver.iter = 1 + @set! solver.force_stop = false + return solver end diff --git a/src/types.jl b/src/types.jl index 9e673188a..c272e09f7 100644 --- a/src/types.jl +++ b/src/types.jl @@ -6,7 +6,8 @@ FLOATING_POINT_LIMIT end -struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType} <: AbstractImmutableNonlinearSolver +struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType + } <: AbstractImmutableNonlinearSolver iter::Int f::fType alg::algType @@ -28,7 +29,8 @@ end # typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache) # end -struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType, probType} <: AbstractImmutableNonlinearSolver +struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, + cacheType, probType} <: AbstractImmutableNonlinearSolver iter::Int f::fType alg::algType @@ -50,9 +52,8 @@ end # typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) # end - function sync_residuals!(solver::BracketingImmutableSolver) @set! solver.fl = solver.f(solver.left, solver.p) @set! solver.fr = solver.f(solver.right, solver.p) solver -end \ No newline at end of file +end diff --git a/src/utils.jl b/src/utils.jl index c7aa8d3a2..29411bfa4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -17,88 +17,87 @@ expands to: end """ macro add_kwonly(ex) - esc(add_kwonly(ex)) + esc(add_kwonly(ex)) end add_kwonly(ex::Expr) = add_kwonly(Val{ex.head}, ex) -function add_kwonly(::Type{<: Val}, ex) - error("add_only does not work with expression $(ex.head)") +function add_kwonly(::Type{<:Val}, ex) + error("add_only does not work with expression $(ex.head)") end function add_kwonly(::Union{Type{Val{:function}}, Type{Val{:(=)}}}, ex::Expr) - body = ex.args[2:end] # function body - default_call = ex.args[1] # e.g., :(f(a, b=2; c=3)) - kwonly_call = add_kwonly(default_call) - if kwonly_call === nothing - return ex - end - - return quote - begin - $ex - $(Expr(ex.head, kwonly_call, body...)) + body = ex.args[2:end] # function body + default_call = ex.args[1] # e.g., :(f(a, b=2; c=3)) + kwonly_call = add_kwonly(default_call) + if kwonly_call === nothing + return ex + end + + return quote + begin + $ex + $(Expr(ex.head, kwonly_call, body...)) + end end - end end function add_kwonly(::Type{Val{:where}}, ex::Expr) - default_call = ex.args[1] - rest = ex.args[2:end] - kwonly_call = add_kwonly(default_call) - if kwonly_call === nothing - return nothing - end - return Expr(:where, kwonly_call, rest...) + default_call = ex.args[1] + rest = ex.args[2:end] + kwonly_call = add_kwonly(default_call) + if kwonly_call === nothing + return nothing + end + return Expr(:where, kwonly_call, rest...) end function add_kwonly(::Type{Val{:call}}, default_call::Expr) - # default_call is, e.g., :(f(a, b=2; c=3)) - funcname = default_call.args[1] # e.g., :f - required = [] # required positional arguments; e.g., [:a] - optional = [] # optional positional arguments; e.g., [:(b=2)] - default_kwargs = [] - for arg in default_call.args[2:end] - if isa(arg, Symbol) - push!(required, arg) - elseif arg.head == :(::) - push!(required, arg) - elseif arg.head == :kw - push!(optional, arg) - elseif arg.head == :parameters - @assert default_kwargs == [] # can I have :parameters twice? - default_kwargs = arg.args - else - error("Not expecting to see: $arg") + # default_call is, e.g., :(f(a, b=2; c=3)) + funcname = default_call.args[1] # e.g., :f + required = [] # required positional arguments; e.g., [:a] + optional = [] # optional positional arguments; e.g., [:(b=2)] + default_kwargs = [] + for arg in default_call.args[2:end] + if isa(arg, Symbol) + push!(required, arg) + elseif arg.head == :(::) + push!(required, arg) + elseif arg.head == :kw + push!(optional, arg) + elseif arg.head == :parameters + @assert default_kwargs == [] # can I have :parameters twice? + default_kwargs = arg.args + else + error("Not expecting to see: $arg") + end + end + if isempty(required) && isempty(optional) + # If the function is already keyword-only, do nothing: + return nothing + end + if isempty(required) + # It's not clear what should be done. Let's not support it at + # the moment: + error("At least one positional mandatory argument is required.") end - end - if isempty(required) && isempty(optional) - # If the function is already keyword-only, do nothing: - return nothing - end - if isempty(required) - # It's not clear what should be done. Let's not support it at - # the moment: - error("At least one positional mandatory argument is required.") - end - kwonly_kwargs = Expr(:parameters, [ - Expr(:kw, pa, :(error($("No argument $pa")))) - for pa in required - ]..., optional..., default_kwargs...) - kwonly_call = Expr(:call, funcname, kwonly_kwargs) - # e.g., :(f(; a=error(...), b=error(...), c=1, d=2)) + kwonly_kwargs = Expr(:parameters, + [Expr(:kw, pa, :(error($("No argument $pa")))) + for pa in required]..., optional..., default_kwargs...) + kwonly_call = Expr(:call, funcname, kwonly_kwargs) + # e.g., :(f(; a=error(...), b=error(...), c=1, d=2)) - return kwonly_call + return kwonly_call end function num_types_in_tuple(sig) - length(sig.parameters) + length(sig.parameters) end function num_types_in_tuple(sig::UnionAll) - length(Base.unwrap_unionall(sig).parameters) + length(Base.unwrap_unionall(sig).parameters) end ### Default Linsolve @@ -109,88 +108,95 @@ end # gmres if operator mutable struct DefaultLinSolve - A - iterable + A::Any + iterable::Any end DefaultLinSolve() = DefaultLinSolve(nothing, nothing) -function (p::DefaultLinSolve)(x,A,b,update_matrix=false;tol=nothing, kwargs...) - if p.iterable isa Vector && eltype(p.iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector - F = LU{eltype(A)}(A, p.iterable, zero(LinearAlgebra.BlasInt)) - ldiv!(x, F, b) - return nothing - end - if update_matrix - if typeof(A) <: Matrix - blasvendor = BLAS.vendor() - # if the user doesn't use OpenBLAS, we assume that is a better BLAS - # implementation like MKL - # - # RecursiveFactorization seems to be consistantly winning below 100 - # https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213 - if ArrayInterfaceCore.can_setindex(x) && (size(A,1) <= 100 || ((blasvendor === :openblas || blasvendor === :openblas64) && size(A,1) <= 500)) - p.A = RecursiveFactorization.lu!(A) - else - p.A = lu!(A) - end - elseif typeof(A) <: Tridiagonal - p.A = lu!(A) - elseif typeof(A) <: Union{SymTridiagonal} - p.A = ldlt!(A) - elseif typeof(A) <: Union{Symmetric,Hermitian} - p.A = bunchkaufman!(A) - elseif typeof(A) <: SparseMatrixCSC - p.A = lu(A) - elseif ArrayInterfaceCore.isstructured(A) - p.A = factorize(A) - elseif !(typeof(A) <: AbstractDiffEqOperator) - # Most likely QR is the one that is overloaded - # Works on things like CuArrays - p.A = qr(A) +function (p::DefaultLinSolve)(x, A, b, update_matrix = false; tol = nothing, kwargs...) + if p.iterable isa Vector && eltype(p.iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector + F = LU{eltype(A)}(A, p.iterable, zero(LinearAlgebra.BlasInt)) + ldiv!(x, F, b) + return nothing end - end - - if typeof(A) <: Union{Matrix,SymTridiagonal,Tridiagonal,Symmetric,Hermitian} # No 2-arg form for SparseArrays! - x .= b - ldiv!(p.A,x) - # Missing a little bit of efficiency in a rare case - #elseif typeof(A) <: DiffEqArrayOperator - # ldiv!(x,p.A,b) - elseif ArrayInterfaceCore.isstructured(A) || A isa SparseMatrixCSC - ldiv!(x,p.A,b) - elseif typeof(A) <: AbstractDiffEqOperator - # No good starting guess, so guess zero - if p.iterable === nothing - p.iterable = IterativeSolvers.gmres_iterable!(x,A,b;initially_zero=true,restart=5,maxiter=5,tol=1e-16,kwargs...) - p.iterable.reltol = tol + if update_matrix + if typeof(A) <: Matrix + blasvendor = BLAS.vendor() + # if the user doesn't use OpenBLAS, we assume that is a better BLAS + # implementation like MKL + # + # RecursiveFactorization seems to be consistantly winning below 100 + # https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213 + if ArrayInterfaceCore.can_setindex(x) && (size(A, 1) <= 100 || + ((blasvendor === :openblas || blasvendor === :openblas64) && + size(A, 1) <= 500)) + p.A = RecursiveFactorization.lu!(A) + else + p.A = lu!(A) + end + elseif typeof(A) <: Tridiagonal + p.A = lu!(A) + elseif typeof(A) <: Union{SymTridiagonal} + p.A = ldlt!(A) + elseif typeof(A) <: Union{Symmetric, Hermitian} + p.A = bunchkaufman!(A) + elseif typeof(A) <: SparseMatrixCSC + p.A = lu(A) + elseif ArrayInterfaceCore.isstructured(A) + p.A = factorize(A) + elseif !(typeof(A) <: AbstractDiffEqOperator) + # Most likely QR is the one that is overloaded + # Works on things like CuArrays + p.A = qr(A) + end end - x .= false - iter = p.iterable - purge_history!(iter, x, b) - for residual in iter + if typeof(A) <: Union{Matrix, SymTridiagonal, Tridiagonal, Symmetric, Hermitian} # No 2-arg form for SparseArrays! + x .= b + ldiv!(p.A, x) + # Missing a little bit of efficiency in a rare case + #elseif typeof(A) <: DiffEqArrayOperator + # ldiv!(x,p.A,b) + elseif ArrayInterfaceCore.isstructured(A) || A isa SparseMatrixCSC + ldiv!(x, p.A, b) + elseif typeof(A) <: AbstractDiffEqOperator + # No good starting guess, so guess zero + if p.iterable === nothing + p.iterable = IterativeSolvers.gmres_iterable!(x, A, b; initially_zero = true, + restart = 5, maxiter = 5, + tol = 1e-16, kwargs...) + p.iterable.reltol = tol + end + x .= false + iter = p.iterable + purge_history!(iter, x, b) + + for residual in iter + end + else + ldiv!(x, p.A, b) end - else - ldiv!(x,p.A,b) - end - return nothing + return nothing end -function (p::DefaultLinSolve)(::Type{Val{:init}},f,u0_prototype) - DefaultLinSolve() +function (p::DefaultLinSolve)(::Type{Val{:init}}, f, u0_prototype) + DefaultLinSolve() end const DEFAULT_LINSOLVE = DefaultLinSolve() @inline UNITLESS_ABS2(x) = real(abs2(x)) -@inline DEFAULT_NORM(u::Union{AbstractFloat,Complex}) = @fastmath abs(u) -@inline DEFAULT_NORM(u::Array{T}) where T<:Union{AbstractFloat,Complex} = - sqrt(real(sum(abs2,u)) / length(u)) -@inline DEFAULT_NORM(u::StaticArray{T}) where T<:Union{AbstractFloat,Complex} = - sqrt(real(sum(abs2,u)) / length(u)) -@inline DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray) = - sum(sqrt(real(sum(UNITLESS_ABS2,_u)) / length(_u)) for _u in u.u) -@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2,u)) / length(u)) +@inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u) +@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}} + sqrt(real(sum(abs2, u)) / length(u)) +end +@inline function DEFAULT_NORM(u::StaticArray{T}) where {T <: Union{AbstractFloat, Complex}} + sqrt(real(sum(abs2, u)) / length(u)) +end +@inline function DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray) + sum(sqrt(real(sum(UNITLESS_ABS2, _u)) / length(_u)) for _u in u.u) +end +@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2, u)) / length(u)) @inline DEFAULT_NORM(u) = norm(u) """ @@ -199,18 +205,18 @@ const DEFAULT_LINSOLVE = DefaultLinSolve() Move `x` one floating point towards x0. """ function prevfloat_tdir(x, x0, x1) - x1 > x0 ? prevfloat(x) : nextfloat(x) + x1 > x0 ? prevfloat(x) : nextfloat(x) end function nextfloat_tdir(x, x0, x1) - x1 > x0 ? nextfloat(x) : prevfloat(x) + x1 > x0 ? nextfloat(x) : prevfloat(x) end function max_tdir(a, b, x0, x1) - x1 > x0 ? max(a, b) : min(a, b) + x1 > x0 ? max(a, b) : min(a, b) end -alg_autodiff(alg::AbstractNewtonAlgorithm{CS,AD}) where {CS,AD} = AD +alg_autodiff(alg::AbstractNewtonAlgorithm{CS, AD}) where {CS, AD} = AD alg_autodiff(alg) = false """ @@ -218,14 +224,14 @@ alg_autodiff(alg) = false Compute `f(x), d/dx f(x)` in the most efficient way. """ -function value_derivative(f::F, x::R) where {F,R} +function value_derivative(f::F, x::R) where {F, R} T = typeof(ForwardDiff.Tag(f, R)) out = f(ForwardDiff.Dual{T}(x, one(x))) ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out) end # Todo: improve this dispatch -value_derivative(f::F, x::SVector) where F = f(x),ForwardDiff.jacobian(f, x) +value_derivative(f::F, x::SVector) where {F} = f(x), ForwardDiff.jacobian(f, x) value(x) = x value(x::Dual) = ForwardDiff.value(x) diff --git a/test/basictests.jl b/test/basictests.jl index bdcb4b457..e37293b62 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,181 +1,181 @@ -using NonlinearSolve -using StaticArrays -using BenchmarkTools -using Test - -function benchmark_immutable(f, u0) - probN = NonlinearProblem{false}(f, u0) - solver = init(probN, NewtonRaphson(), tol = 1e-9) - sol = solve!(solver) -end - -function benchmark_mutable(f, u0) - probN = NonlinearProblem{false}(f, u0) - solver = init(probN, NewtonRaphson(), tol = 1e-9) - sol = (reinit!(solver, probN); solve!(solver)) -end - -function benchmark_scalar(f, u0) - probN = NonlinearProblem{false}(f, u0) - sol = (solve(probN, NewtonRaphson())) -end - -function ff(u,p) - u .* u .- 2 -end -const cu0 = @SVector[1.0, 1.0] -function sf(u,p) - u * u - 2 -end -const csu0 = 1.0 - -sol = benchmark_immutable(ff, cu0) -@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) -@test all(sol.u .* sol.u .- 2 .< 1e-9) -sol = benchmark_mutable(ff, cu0) -@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) -@test all(sol.u .* sol.u .- 2 .< 1e-9) -sol = benchmark_scalar(sf, csu0) -@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) -@test sol.u * sol.u - 2 < 1e-9 - -@test (@ballocated benchmark_immutable(ff, cu0)) == 0 -@test (@ballocated benchmark_mutable(ff, cu0)) < 200 -@test (@ballocated benchmark_scalar(sf, csu0)) == 0 - -# AD Tests -using ForwardDiff - -# Immutable -f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0] - -g = function (p) - probN = NonlinearProblem{false}(f, csu0, p) - sol = solve(probN, NewtonRaphson(), tol = 1e-9) - return sol.u[end] -end - -for p in 1.0:0.1:100.0 - @test g(p) ≈ sqrt(p) - @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p)) -end - -# Scalar -f, u0 = (u, p) -> u * u - p, 1.0 - -# NewtonRaphson -g = function (p) - probN = NonlinearProblem{false}(f, oftype(p, u0), p) - sol = solve(probN, NewtonRaphson()) - return sol.u -end - -@test ForwardDiff.derivative(g, 1.0) ≈ 0.5 - -for p in 1.1:0.1:100.0 - @test g(p) ≈ sqrt(p) - @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p)) -end - -u0 = (1.0, 20.0) -# Falsi -g = function (p) - probN = NonlinearProblem{false}(f, typeof(p).(u0), p) - sol = solve(probN, Falsi()) - return sol.left -end - -for p in 1.1:0.1:100.0 - @test g(p) ≈ sqrt(p) - @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p)) -end - -f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0) -t = (p) -> [sqrt(p[2] / p[1])] -p = [0.9, 50.0] -for alg in [Bisection(), Falsi()] - global g, p - g = function (p) - probN = NonlinearProblem{false}(f, u0, p) - sol = solve(probN, Bisection()) - return [sol.left] - end - - @test g(p) ≈ [sqrt(p[2] / p[1])] - @test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p) -end - -gnewton = function (p) - probN = NonlinearProblem{false}(f, 0.5, p) - sol = solve(probN, NewtonRaphson()) - return [sol.u] -end -@test gnewton(p) ≈ [sqrt(p[2] / p[1])] -@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) - -# Error Checks - -f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0] -probN = NonlinearProblem(f, u0) - -@test solve(probN, NewtonRaphson()).u[end] ≈ sqrt(2.0) -@test solve(probN, NewtonRaphson(); immutable = false).u[end] ≈ sqrt(2.0) -@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] ≈ sqrt(2.0) -@test solve(probN, NewtonRaphson(;autodiff=false)).u[end] ≈ sqrt(2.0) - -for u0 in [1.0, [1, 1.0]] - local f, probN, sol - f = (u, p) -> u .* u .- 2.0 - probN = NonlinearProblem(f, u0) - sol = sqrt(2) * u0 - - @test solve(probN, NewtonRaphson()).u ≈ sol - @test solve(probN, NewtonRaphson()).u ≈ sol - @test solve(probN, NewtonRaphson(;autodiff=false)).u ≈ sol -end - -# Bisection Tests -f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0) -probB = NonlinearProblem(f, u0) - -# Falsi -solver = init(probB, Falsi()) -sol = solve!(solver) -@test sol.left ≈ sqrt(2.0) - -# this should call the fast scalar overload -@test solve(probB, Bisection()).left ≈ sqrt(2.0) - -# these should call the iterator version -solver = init(probB, Bisection()) -@test solver isa NonlinearSolve.BracketingImmutableSolver -@test solve!(solver).left ≈ sqrt(2.0) - -# Garuntee Tests for Bisection -f = function (u, p) - if u < 2.0 - return u - 2.0 - elseif u > 3.0 - return u - 3.0 - else - return 0.0 - end -end -probB = NonlinearProblem(f, (0.0, 4.0)) - -solver = init(probB, Bisection(;exact_left = true)) -sol = solve!(solver) -@test f(sol.left, nothing) < 0.0 -@test f(nextfloat(sol.left), nothing) >= 0.0 - -solver = init(probB, Bisection(;exact_right = true)) -sol = solve!(solver) -@test f(sol.right, nothing) > 0.0 -@test f(prevfloat(sol.right), nothing) <= 0.0 - -solver = init(probB, Bisection(;exact_left = true, exact_right = true); immutable = false) -sol = solve!(solver) -@test f(sol.left, nothing) < 0.0 -@test f(nextfloat(sol.left), nothing) >= 0.0 -@test f(sol.right, nothing) > 0.0 -@test f(prevfloat(sol.right), nothing) <= 0.0 +using NonlinearSolve +using StaticArrays +using BenchmarkTools +using Test + +function benchmark_immutable(f, u0) + probN = NonlinearProblem{false}(f, u0) + solver = init(probN, NewtonRaphson(), tol = 1e-9) + sol = solve!(solver) +end + +function benchmark_mutable(f, u0) + probN = NonlinearProblem{false}(f, u0) + solver = init(probN, NewtonRaphson(), tol = 1e-9) + sol = (reinit!(solver, probN); solve!(solver)) +end + +function benchmark_scalar(f, u0) + probN = NonlinearProblem{false}(f, u0) + sol = (solve(probN, NewtonRaphson())) +end + +function ff(u, p) + u .* u .- 2 +end +const cu0 = @SVector[1.0, 1.0] +function sf(u, p) + u * u - 2 +end +const csu0 = 1.0 + +sol = benchmark_immutable(ff, cu0) +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) +@test all(sol.u .* sol.u .- 2 .< 1e-9) +sol = benchmark_mutable(ff, cu0) +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) +@test all(sol.u .* sol.u .- 2 .< 1e-9) +sol = benchmark_scalar(sf, csu0) +@test sol.retcode === Symbol(NonlinearSolve.DEFAULT) +@test sol.u * sol.u - 2 < 1e-9 + +@test (@ballocated benchmark_immutable(ff, cu0)) == 0 +@test (@ballocated benchmark_mutable(ff, cu0)) < 200 +@test (@ballocated benchmark_scalar(sf, csu0)) == 0 + +# AD Tests +using ForwardDiff + +# Immutable +f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0] + +g = function (p) + probN = NonlinearProblem{false}(f, csu0, p) + sol = solve(probN, NewtonRaphson(), tol = 1e-9) + return sol.u[end] +end + +for p in 1.0:0.1:100.0 + @test g(p) ≈ sqrt(p) + @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) +end + +# Scalar +f, u0 = (u, p) -> u * u - p, 1.0 + +# NewtonRaphson +g = function (p) + probN = NonlinearProblem{false}(f, oftype(p, u0), p) + sol = solve(probN, NewtonRaphson()) + return sol.u +end + +@test ForwardDiff.derivative(g, 1.0) ≈ 0.5 + +for p in 1.1:0.1:100.0 + @test g(p) ≈ sqrt(p) + @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) +end + +u0 = (1.0, 20.0) +# Falsi +g = function (p) + probN = NonlinearProblem{false}(f, typeof(p).(u0), p) + sol = solve(probN, Falsi()) + return sol.left +end + +for p in 1.1:0.1:100.0 + @test g(p) ≈ sqrt(p) + @test ForwardDiff.derivative(g, p) ≈ 1 / (2 * sqrt(p)) +end + +f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0) +t = (p) -> [sqrt(p[2] / p[1])] +p = [0.9, 50.0] +for alg in [Bisection(), Falsi()] + global g, p + g = function (p) + probN = NonlinearProblem{false}(f, u0, p) + sol = solve(probN, Bisection()) + return [sol.left] + end + + @test g(p) ≈ [sqrt(p[2] / p[1])] + @test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p) +end + +gnewton = function (p) + probN = NonlinearProblem{false}(f, 0.5, p) + sol = solve(probN, NewtonRaphson()) + return [sol.u] +end +@test gnewton(p) ≈ [sqrt(p[2] / p[1])] +@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p) + +# Error Checks + +f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0] +probN = NonlinearProblem(f, u0) + +@test solve(probN, NewtonRaphson()).u[end] ≈ sqrt(2.0) +@test solve(probN, NewtonRaphson(); immutable = false).u[end] ≈ sqrt(2.0) +@test solve(probN, NewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0) +@test solve(probN, NewtonRaphson(; autodiff = false)).u[end] ≈ sqrt(2.0) + +for u0 in [1.0, [1, 1.0]] + local f, probN, sol + f = (u, p) -> u .* u .- 2.0 + probN = NonlinearProblem(f, u0) + sol = sqrt(2) * u0 + + @test solve(probN, NewtonRaphson()).u ≈ sol + @test solve(probN, NewtonRaphson()).u ≈ sol + @test solve(probN, NewtonRaphson(; autodiff = false)).u ≈ sol +end + +# Bisection Tests +f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0) +probB = NonlinearProblem(f, u0) + +# Falsi +solver = init(probB, Falsi()) +sol = solve!(solver) +@test sol.left ≈ sqrt(2.0) + +# this should call the fast scalar overload +@test solve(probB, Bisection()).left ≈ sqrt(2.0) + +# these should call the iterator version +solver = init(probB, Bisection()) +@test solver isa NonlinearSolve.BracketingImmutableSolver +@test solve!(solver).left ≈ sqrt(2.0) + +# Garuntee Tests for Bisection +f = function (u, p) + if u < 2.0 + return u - 2.0 + elseif u > 3.0 + return u - 3.0 + else + return 0.0 + end +end +probB = NonlinearProblem(f, (0.0, 4.0)) + +solver = init(probB, Bisection(; exact_left = true)) +sol = solve!(solver) +@test f(sol.left, nothing) < 0.0 +@test f(nextfloat(sol.left), nothing) >= 0.0 + +solver = init(probB, Bisection(; exact_right = true)) +sol = solve!(solver) +@test f(sol.right, nothing) > 0.0 +@test f(prevfloat(sol.right), nothing) <= 0.0 + +solver = init(probB, Bisection(; exact_left = true, exact_right = true); immutable = false) +sol = solve!(solver) +@test f(sol.left, nothing) < 0.0 +@test f(nextfloat(sol.left), nothing) >= 0.0 +@test f(sol.right, nothing) > 0.0 +@test f(prevfloat(sol.right), nothing) <= 0.0 diff --git a/test/runtests.jl b/test/runtests.jl index 58afd0c94..0cf5ed83b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,13 +3,11 @@ using SafeTestsets const LONGER_TESTS = false const GROUP = get(ENV, "GROUP", "All") -const is_APPVEYOR = Sys.iswindows() && haskey(ENV,"APPVEYOR") +const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR") @time begin if GROUP == "All" || GROUP == "Interface" - #@time @safetestset "Linear Solver Tests" begin include("interface/linear_solver_test.jl") end - @time @safetestset "Basic Tests + Some AD" begin include("basictests.jl") end -end - -end + #@time @safetestset "Linear Solver Tests" begin include("interface/linear_solver_test.jl") end + @time @safetestset "Basic Tests + Some AD" begin include("basictests.jl") end +end end