Skip to content

Commit

Permalink
Add more solvers to GPU testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2023
1 parent 417701c commit 22b7f25
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),

Check warning on line 68 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
zero(fu), p, J⁻¹, zero(_vec(fu)'), _mutable_zero(u), false, 0, alg.max_resets,
maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
zero(fu), p, J⁻¹, zero(reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
end
Expand Down
19 changes: 18 additions & 1 deletion test/gpu.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using CUDA, NonlinearSolve

CUDA.allowscalar(false)

A = cu(rand(4, 4))
u0 = cu(rand(4))
b = cu(rand(4))
Expand All @@ -9,4 +11,19 @@ function f(du, u, p)
end

prob = NonlinearProblem(f, u0)
sol = solve(prob, NewtonRaphson())

# TrustRegion is broken
for alg in (NewtonRaphson(), LevenbergMarquardt(; linsolve = QRFactorization()),
PseudoTransient(; alpha_initial = 10.0f0), GeneralKlement(), GeneralBroyden())
@test_nowarn sol = solve(prob, alg; abstol = 1f-8, reltol = 1f-8)
end

f(u, p) = A * u .+ b

prob = NonlinearProblem{false}(f, u0)

# TrustRegion is broken
for alg in (NewtonRaphson(), LevenbergMarquardt(; linsolve = QRFactorization()),
PseudoTransient(; alpha_initial = 10.0f0), GeneralKlement(), GeneralBroyden())
@test_nowarn sol = solve(prob, alg; abstol = 1f-8, reltol = 1f-8)
end

0 comments on commit 22b7f25

Please sign in to comment.