Skip to content

Commit

Permalink
Merge pull request SciML#347 from oscardssmith/os/robustmultinewton-a…
Browse files Browse the repository at this point in the history
…utodiff

make RobustMultiNewton always respect autodiff choice
  • Loading branch information
ChrisRackauckas authored Jan 6, 2024
2 parents d47b131 + 8614ebe commit 138de9b
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/src/solvers/FixedPointSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ robust.

### SIAMFANLEquations.jl

- `SIAMFANLEquationsJL(; method = :anderson)`: Anderson acceleration for fixed point problems.
- `SIAMFANLEquationsJL(; method = :anderson)`: Anderson acceleration for fixed point problems.
7 changes: 4 additions & 3 deletions ext/NonlinearSolveSIAMFANLEquationsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
rtol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, T)

if prob.u0 isa Number
f = method == :anderson ? (du, u) -> (du = prob.f(u, prob.p)) : ((u) -> prob.f(u, prob.p))
f = method == :anderson ? (du, u) -> (du = prob.f(u, prob.p)) :
((u) -> prob.f(u, prob.p))

if method == :newton
sol = nsolsc(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT)
Expand All @@ -55,7 +56,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
elseif method == :anderson
f, u = NonlinearSolve.__construct_f(prob; alias_u0,
make_fixed_point = Val(true), can_handle_arbitrary_dims = Val(true))
sol = aasol(f, [prob.u0], m, __zeros_like(u, 1, 2*m+4); maxit = maxiters,
sol = aasol(f, [prob.u0], m, __zeros_like(u, 1, 2 * m + 4); maxit = maxiters,
atol, rtol, beta = beta)
end

Expand Down Expand Up @@ -110,7 +111,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
elseif method == :anderson
f!, u = NonlinearSolve.__construct_f(prob; alias_u0,
can_handle_arbitrary_dims = Val(true), make_fixed_point = Val(true))
sol = aasol(f!, u, m, zeros(T, N, 2*m+4), atol = atol, rtol = rtol,
sol = aasol(f!, u, m, zeros(T, N, 2 * m + 4), atol = atol, rtol = rtol,
maxit = maxiters, beta = beta)
end
else
Expand Down
2 changes: 1 addition & 1 deletion src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ function RobustMultiNewton(::Type{T} = Float64; concrete_jac = nothing, linsolve
# Let's atleast have something here for complex numbers
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),)
else
algs = (TrustRegion(; concrete_jac, linsolve, precs),
algs = (TrustRegion(; concrete_jac, linsolve, precs, autodiff),
TrustRegion(; concrete_jac, linsolve, precs, autodiff,
radius_update_scheme = RadiusUpdateSchemes.Bastin),
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
Expand Down
23 changes: 23 additions & 0 deletions test/misc/no_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using LinearAlgebra, NonlinearSolve, Test

@testset "[IIP] no AD" begin
f_iip = Base.Experimental.@opaque (du, u, p) -> du .= u .* u .- p
u0 = [0.0]
prob = NonlinearProblem(f_iip, u0, 1.0)
for alg in [RobustMultiNewton(autodiff = AutoFiniteDiff()())]
sol = solve(prob, alg)
@test isapprox(only(sol.u), 1.0)
@test SciMLBase.successful_retcode(sol.retcode)
end
end

@testset "[OOP] no AD" begin
f_oop = Base.Experimental.@opaque (u, p) -> u .* u .- p
u0 = [0.0]
prob = NonlinearProblem{false}(f_oop, u0, 1.0)
for alg in [RobustMultiNewton(autodiff = AutoFiniteDiff())]
sol = solve(prob, alg)
@test isapprox(only(sol.u), 1.0)
@test SciMLBase.successful_retcode(sol.retcode)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ end
@time @safetestset "Matrix Resizing" include("misc/matrix_resizing.jl")
@time @safetestset "Infeasible Problems" include("misc/infeasible.jl")
@time @safetestset "Banded Matrices" include("misc/banded_matrices.jl")
@time @safetestset "No AD" include("misc/no_ad.jl")
end

if GROUP == "GPU"
Expand Down
6 changes: 4 additions & 2 deletions test/wrappers/fixedpoint.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NonlinearSolve, FixedPointAcceleration, SpeedMapping, NLsolve, SIAMFANLEquations, LinearAlgebra, Test
using NonlinearSolve,
FixedPointAcceleration, SpeedMapping, NLsolve, SIAMFANLEquations, LinearAlgebra, Test

# Simple Scalar Problem
@testset "Simple Scalar Problem" begin
Expand Down Expand Up @@ -29,7 +30,8 @@ end
@test maximum(abs.(solve(prob, SpeedMappingJL()).resid)) 1e-10
@test maximum(abs.(solve(prob, SpeedMappingJL(; orders = [3, 2])).resid)) 1e-10
@test maximum(abs.(solve(prob, SpeedMappingJL(; stabilize = true)).resid)) 1e-10
@test maximum(abs.(solve(prob, SIAMFANLEquationsJL(; method = :anderson)).resid)) 1e-10
@test maximum(abs.(solve(prob, SIAMFANLEquationsJL(; method = :anderson)).resid))
1e-10

@test_broken maximum(abs.(solve(prob, NLsolveJL(; method = :anderson)).resid)) 1e-10
end
Expand Down

0 comments on commit 138de9b

Please sign in to comment.