Skip to content

Commit

Permalink
Add more solvers to GPU testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2023
1 parent 417701c commit 4f632dd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),

Check warning on line 68 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
zero(fu), p, J⁻¹, zero(_vec(fu)'), _mutable_zero(u), false, 0, alg.max_resets,
maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
zero(fu), p, J⁻¹, zero(reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
end
Expand Down
2 changes: 1 addition & 1 deletion src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob

if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray)
if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u0 isa StaticArray)
linsolve_with_JᵀJ = Val(false)

Check warning on line 86 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
else
linsolve_with_JᵀJ = Val(true)

Check warning on line 88 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L88

Added line #L88 was not covered by tests
Expand Down
19 changes: 18 additions & 1 deletion test/gpu.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using CUDA, NonlinearSolve

CUDA.allowscalar(false)

A = cu(rand(4, 4))
u0 = cu(rand(4))
b = cu(rand(4))
Expand All @@ -9,4 +11,19 @@ function f(du, u, p)
end

prob = NonlinearProblem(f, u0)
sol = solve(prob, NewtonRaphson())

# TrustRegion is broken
for alg in (NewtonRaphson(), LevenbergMarquardt(; linsolve = QRFactorization()),
PseudoTransient(; alpha_initial = 10.0f0), GeneralKlement(), GeneralBroyden())
@test_nowarn sol = solve(prob, alg; abstol = 1.0f-8, reltol = 1.0f-8)
end

f(u, p) = A * u .+ b

prob = NonlinearProblem{false}(f, u0)

# TrustRegion is broken
for alg in (NewtonRaphson(), LevenbergMarquardt(; linsolve = QRFactorization()),
PseudoTransient(; alpha_initial = 10.0f0), GeneralKlement(), GeneralBroyden())
@test_nowarn sol = solve(prob, alg; abstol = 1.0f-8, reltol = 1.0f-8)
end

0 comments on commit 4f632dd

Please sign in to comment.