Skip to content

Commit

Permalink
Add more solvers to GPU testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2023
1 parent 417701c commit be854ec
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),

Check warning on line 68 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
zero(fu), p, J⁻¹, zero(_vec(fu)'), _mutable_zero(u), false, 0, alg.max_resets,
maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
zero(fu), p, J⁻¹, zero(reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
end
Expand Down
2 changes: 1 addition & 1 deletion src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob

if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray)
if !needs_square_A(alg.linsolve) && !(u0 isa Number) && !(u0 isa StaticArray)
linsolve_with_JᵀJ = Val(false)

Check warning on line 86 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
else
linsolve_with_JᵀJ = Val(true)

Check warning on line 88 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L88

Added line #L88 was not covered by tests
Expand Down
4 changes: 2 additions & 2 deletions src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,12 @@ end
return nothing

Check warning on line 249 in src/lbroyden.jl

View check run for this annotation

Codecov / codecov/patch

src/lbroyden.jl#L246-L249

Added lines #L246 - L249 were not covered by tests
end
mul!(xᵀVᵀ[:, 1:η], x', Vᵀ)
mul!(y', xᵀVᵀ[:, 1:η], U)
mul!(reshape(y, 1, :), xᵀVᵀ[:, 1:η], U)
return nothing

Check warning on line 253 in src/lbroyden.jl

View check run for this annotation

Codecov / codecov/patch

src/lbroyden.jl#L251-L253

Added lines #L251 - L253 were not covered by tests
end

@views function __lbroyden_rmatvec(U::AbstractMatrix, Vᵀ::AbstractMatrix, x::AbstractVector)

Check warning on line 256 in src/lbroyden.jl

View check run for this annotation

Codecov / codecov/patch

src/lbroyden.jl#L256

Added line #L256 was not covered by tests
# Computes xᵀ × Vᵀ × U
size(U, 1) == 0 && return x
return (x' * Vᵀ) * U
return (reshape(x, 1, :) * Vᵀ) * U

Check warning on line 259 in src/lbroyden.jl

View check run for this annotation

Codecov / codecov/patch

src/lbroyden.jl#L258-L259

Added lines #L258 - L259 were not covered by tests
end
19 changes: 18 additions & 1 deletion test/gpu.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using CUDA, NonlinearSolve

CUDA.allowscalar(false)

A = cu(rand(4, 4))
u0 = cu(rand(4))
b = cu(rand(4))
Expand All @@ -9,4 +11,19 @@ function f(du, u, p)
end

prob = NonlinearProblem(f, u0)
sol = solve(prob, NewtonRaphson())

# TrustRegion is broken
for alg in (NewtonRaphson(), LevenbergMarquardt(; linsolve = QRFactorization()),
PseudoTransient(; alpha_initial = 10.0f0), GeneralKlement(), GeneralBroyden())
@test_nowarn sol = solve(prob, alg; abstol = 1.0f-8, reltol = 1.0f-8)
end

f(u, p) = A * u .+ b

prob = NonlinearProblem{false}(f, u0)

# TrustRegion is broken
for alg in (NewtonRaphson(), LevenbergMarquardt(; linsolve = QRFactorization()),
PseudoTransient(; alpha_initial = 10.0f0), GeneralKlement(), GeneralBroyden())
@test_nowarn sol = solve(prob, alg; abstol = 1.0f-8, reltol = 1.0f-8)
end

0 comments on commit be854ec

Please sign in to comment.