Skip to content

Commit

Permalink
Format and add a format CI
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 8, 2023
1 parent b1d8bea commit 0ca6ab0
Show file tree
Hide file tree
Showing 18 changed files with 261 additions and 223 deletions.
29 changes: 29 additions & 0 deletions lib/SimpleNonlinearSolve/.github/workflows/FormatPR.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: format-pr
on:
schedule:
- cron: '0 0 * * *'
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install JuliaFormatter and format
run: |
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
julia -e 'using JuliaFormatter; format(".")'
# https://github.com/marketplace/actions/create-pull-request
# https://github.com/peter-evans/create-pull-request#reference-example
- name: Create Pull Request
id: cpr
uses: peter-evans/create-pull-request@v5
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Format .jl files
title: 'Automatic JuliaFormatter.jl run'
branch: auto-juliaformatter-pr
delete-branch: true
labels: formatting, automated pr, no changelog
- name: Check outputs
run: |
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function _init_J_batched(x::AbstractMatrix{T}) where {T}
end

function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)
f = Base.Fix2(prob.f, prob.p)
Expand Down Expand Up @@ -74,7 +74,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ)
J⁻¹ += _batched_mul(((Δxₙ .- J⁻¹Δfₙ) ./
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
Expand Down
46 changes: 26 additions & 20 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ end

function __init__()
@static if !isdefined(Base, :get_extension)
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin include("../ext/SimpleBatchedNonlinearSolveExt.jl") end
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin
include("../ext/SimpleBatchedNonlinearSolveExt.jl")
end
end
end

Expand All @@ -42,31 +44,35 @@ include("alefeld.jl")

import PrecompileTools

PrecompileTools.@compile_workload begin for T in (Float32, Float64)
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
for alg in (SimpleNewtonRaphson, Halley, Broyden, Klement, SimpleTrustRegion,
SimpleDFSane)
solve(prob_no_brack, alg(), abstol = T(1e-2))
end
PrecompileTools.@compile_workload begin
for T in (Float32, Float64)
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
for alg in (SimpleNewtonRaphson, Halley, Broyden, Klement, SimpleTrustRegion,
SimpleDFSane)
solve(prob_no_brack, alg(), abstol = T(1e-2))
end

#=
for alg in (SimpleNewtonRaphson,)
for u0 in ([1., 1.], StaticArraysCore.SA[1.0, 1.0])
u0 = T.(.1)
probN = NonlinearProblem{false}((u,p) -> u .* u .- p, u0, T(2))
solve(probN, alg(), tol = T(1e-2))
#=
for alg in (SimpleNewtonRaphson,)
for u0 in ([1., 1.], StaticArraysCore.SA[1.0, 1.0])
u0 = T.(.1)
probN = NonlinearProblem{false}((u,p) -> u .* u .- p, u0, T(2))
solve(probN, alg(), tol = T(1e-2))
end
end
end
=#
=#

prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p, T.((0.0, 2.0)), T(2))
for alg in (Bisection, Falsi, Ridder, Brent, Alefeld)
solve(prob_brack, alg(), abstol = T(1e-2))
prob_brack = IntervalNonlinearProblem{false}((u, p) -> u * u - p,
T.((0.0, 2.0)),
T(2))
for alg in (Bisection, Falsi, Ridder, Brent, Alefeld)
solve(prob_brack, alg(), abstol = T(1e-2))
end
end
end end
end

# DiffEq styled algorithms
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld

end # module
52 changes: 26 additions & 26 deletions lib/SimpleNonlinearSolve/src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,50 +29,50 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:Dual{T, V, P}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; kwargs...) where {iip, T, V, P}
iip,
<:Dual{T, V, P}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
retcode = sol.retcode)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
kwargs...) where {iip, T, V, P}
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
retcode = sol.retcode)
end

# avoid ambiguities
for Alg in [Bisection]
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:Dual{T, V, P}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
<:Dual{T, V, P}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
sol.resid; retcode = sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
<:AbstractArray{
<:Dual{T,
V,
P}
}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
<:AbstractArray{
<:Dual{T,
V,
P},
}},
alg::$Alg, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
sol.resid; retcode = sol.retcode,
left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
end
end
68 changes: 34 additions & 34 deletions lib/SimpleNonlinearSolve/src/alefeld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@ algorithm 4.1 because, in certain sense, the second algorithm(4.2) is an optimal
struct Alefeld <: AbstractBracketingAlgorithm end

function SciMLBase.solve(prob::IntervalNonlinearProblem,
alg::Alefeld, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
alg::Alefeld, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
a, b = prob.tspan
c = a - (b - a) / (f(b) - f(a)) * f(a)

fc = f(c)
(a == c || b == c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = a,
right = b)
retcode = ReturnCode.FloatingPointLimit,
left = a,
right = b)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = a,
right = b)
retcode = ReturnCode.Success,
left = a,
right = b)
a, b, d = _bracket(f, a, b, c)
e = zero(a) # Set e as 0 before iteration to avoid a non-value f(e)

Expand All @@ -45,14 +45,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
ē, fc = d, f(c)
(a == c || b == c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = a,
right = b)
retcode = ReturnCode.FloatingPointLimit,
left = a,
right = b)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = a,
right = b)
retcode = ReturnCode.Success,
left = a,
right = b)
ā, b̄, d̄ = _bracket(f, a, b, c)

# The second bracketing block
Expand All @@ -68,14 +68,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
fc = f(c)
(ā == c ||== c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = ā,
right = b̄)
retcode = ReturnCode.Success,
left = ā,
right = b̄)
ā, b̄, d̄ = _bracket(f, ā, b̄, c)

# The third bracketing block
Expand All @@ -91,14 +91,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
fc = f(c)
(ā == c ||== c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = ā,
right = b̄)
retcode = ReturnCode.Success,
left = ā,
right = b̄)
ā, b̄, d = _bracket(f, ā, b̄, c)

# The last bracketing block
Expand All @@ -110,14 +110,14 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,
fc = f(c)
(ā == c ||== c) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
retcode = ReturnCode.FloatingPointLimit,
left = ā,
right = b̄)
iszero(fc) &&
return SciMLBase.build_solution(prob, alg, c, fc;
retcode = ReturnCode.Success,
left = ā,
right = b̄)
retcode = ReturnCode.Success,
left = ā,
right = b̄)
a, b, d = _bracket(f, ā, b̄, c)
end
end
Expand All @@ -132,7 +132,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem,

# Reuturn solution when run out of max interation
return SciMLBase.build_solution(prob, alg, c, fc; retcode = ReturnCode.MaxIters,
left = a, right = b)
left = a, right = b)
end

# Define subrotine function bracket, check fc before bracket to return solution
Expand Down
18 changes: 9 additions & 9 deletions lib/SimpleNonlinearSolve/src/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ function Bisection(; exact_left = false, exact_right = false)
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...;
maxiters = 1000,
kwargs...)
maxiters = 1000,
kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.tspan
fl, fr = f(left), f(right)

if iszero(fl)
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
retcode = ReturnCode.ExactSolutionLeft, left = left,
right = right)
end

i = 1
Expand All @@ -38,8 +38,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -60,8 +60,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
mid = (left + right) / 2
(mid == left || mid == right) &&
return SciMLBase.build_solution(prob, alg, left, fl;
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
retcode = ReturnCode.FloatingPointLimit,
left = left, right = right)
fm = f(mid)
if iszero(fm)
right = mid
Expand All @@ -74,5 +74,5 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
end

return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
left = left, right = right)
left = left, right = right)
end
Loading

0 comments on commit 0ca6ab0

Please sign in to comment.