Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

Commit

Permalink
Merge pull request #27 from SciML/base
Browse files Browse the repository at this point in the history
Use DiffEqBase high level handling
  • Loading branch information
ChrisRackauckas authored Jan 17, 2023
2 parents 9aab971 + 0d593de commit 2893a8e
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 32 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.5"

[deps]
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -15,6 +16,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[compat]
ArrayInterfaceCore = "0.1.1"
DiffEqBase = "6.114"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
Reexport = "0.2, 1"
Expand Down
7 changes: 4 additions & 3 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ForwardDiff: Dual
using StaticArraysCore
using LinearAlgebra
import ArrayInterfaceCore
using DiffEqBase

@reexport using SciMLBase

Expand All @@ -28,11 +29,11 @@ import SnoopPrecompile
SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
for alg in (SimpleNewtonRaphson, Broyden, Klement)
solve(prob_no_brack, alg(), tol = T(1e-2))
solve(prob_no_brack, alg(), abstol = T(1e-2))
end

for alg in (TrustRegion(10.0),)
solve(prob_no_brack, alg, tol = T(1e-2))
solve(prob_no_brack, alg, abstol = T(1e-2))
end

#=
Expand All @@ -47,7 +48,7 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)

prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
for alg in (Bisection, Falsi)
solve(prob_brack, alg(), tol = T(1e-2))
solve(prob_brack, alg(), abstol = T(1e-2))
end
end end

Expand Down
8 changes: 4 additions & 4 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ and static array problems.
"""
struct Broyden <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.solve(prob::NonlinearProblem,
alg::Broyden, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
function SciMLBase.__solve(prob::NonlinearProblem,
alg::Broyden, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fₙ = f(x)
Expand Down
8 changes: 4 additions & 4 deletions src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ This method is non-allocating on scalar problems.
"""
struct Klement <: AbstractSimpleNonlinearSolveAlgorithm end

function SciMLBase.solve(prob::NonlinearProblem,
alg::Klement, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
function SciMLBase.__solve(prob::NonlinearProblem,
alg::Klement, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fₙ = f(x)
Expand Down
8 changes: 4 additions & 4 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
end
end

function SciMLBase.solve(prob::NonlinearProblem,
alg::SimpleNewtonRaphson, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
function SciMLBase.__solve(prob::NonlinearProblem,
alg::SimpleNewtonRaphson, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
fx = float(prob.u0)
Expand Down
8 changes: 4 additions & 4 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ struct TrustRegion{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
end
end

function SciMLBase.solve(prob::NonlinearProblem,
alg::TrustRegion, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
function SciMLBase.__solve(prob::NonlinearProblem,
alg::TrustRegion, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
T = typeof(x)
Expand Down
23 changes: 10 additions & 13 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9

@test (@ballocated benchmark_scalar(sf, csu0)) == 0
if VERSION >= v"1.7"
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
end

# Broyden
function benchmark_scalar(f, u0)
Expand All @@ -33,7 +35,9 @@ end
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
if VERSION >= v"1.7"
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
end

# Klement
function benchmark_scalar(f, u0)
Expand All @@ -44,7 +48,9 @@ end
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
if VERSION >= v"1.7"
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
end

# TrustRegion
function benchmark_scalar(f, u0)
Expand All @@ -66,7 +72,7 @@ for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
TrustRegion(10.0)]
g = function (p)
probN = NonlinearProblem{false}(f, csu0, p)
sol = solve(probN, alg, tol = 1e-9)
sol = solve(probN, alg, abstol = 1e-9)
return sol.u[end]
end

Expand Down Expand Up @@ -137,20 +143,11 @@ f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
probN = NonlinearProblem(f, u0)

@test solve(probN, SimpleNewtonRaphson()).u[end] sqrt(2.0)
@test solve(probN, SimpleNewtonRaphson(); immutable = false).u[end] sqrt(2.0)
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)

@test solve(probN, TrustRegion(10.0)).u[end] sqrt(2.0)
@test solve(probN, TrustRegion(10.0); immutable = false).u[end] sqrt(2.0)
@test solve(probN, TrustRegion(10.0; autodiff = false)).u[end] sqrt(2.0)
@test solve(probN, TrustRegion(10.0; autodiff = false)).u[end] sqrt(2.0)

@test solve(probN, Broyden()).u[end] sqrt(2.0)
@test solve(probN, Broyden(); immutable = false).u[end] sqrt(2.0)

@test solve(probN, Klement()).u[end] sqrt(2.0)
@test solve(probN, Klement(); immutable = false).u[end] sqrt(2.0)

for u0 in [1.0, [1, 1.0]]
local f, probN, sol
Expand Down

0 comments on commit 2893a8e

Please sign in to comment.