Skip to content

Commit

Permalink
Generic _axpy!
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 25, 2023
1 parent 026a600 commit 06186c0
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 11 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
group:
- All
Expand Down
2 changes: 1 addition & 1 deletion src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function perform_step!(cache::GeneralBroydenCache{true})

mul!(_vec(du), J⁻¹, -_vec(fu))
α = perform_linesearch!(cache.lscache, u, du)
axpy!(α, du, u)
_axpy!(α, du, u)
f(fu2, u, p)

cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
Expand Down
2 changes: 1 addition & 1 deletion src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ function perform_step!(cache::GeneralKlementCache{true})

# Line Search
α = perform_linesearch!(cache.lscache, u, du)
axpy!(α, du, u)
_axpy!(α, du, u)
f(cache.fu2, u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
Expand Down
2 changes: 1 addition & 1 deletion src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
T = eltype(u)

α = perform_linesearch!(cache.lscache, u, du)
axpy!(α, du, u)
_axpy!(α, du, u)
f(cache.fu2, u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
Expand Down
2 changes: 1 addition & 1 deletion src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function perform_step!(cache::NewtonRaphsonCache{true})

# Line Search
α = perform_linesearch!(cache.lscache, u, du)
axpy!(-α, du, u)
_axpy!(-α, du, u)
f(cache.fu1, u, p)

cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,13 @@ _try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false
_reshape(x, args...) = reshape(x, args...)
_reshape(x::Number, args...) = x

@generated function _axpy!(α, x, y)
hasmethod(axpy!, Tuple{α, x, y}) && return :(axpy!(α, x, y))
return :(@. y += α * x)
end

# Needs Square Matrix
# FIXME: Remove once https://github.com/SciML/LinearSolve.jl/pull/400 is merged and tagged
"""
needs_square_A(alg)
Expand Down
8 changes: 4 additions & 4 deletions test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ end
GeneralBroyden(; linesearch = BackTracking()))

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 3, 4, 5, 6, 11, 12, 13, 14, 21]
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 9, 11, 13, 22]
broken_tests[alg_ops[1]] = [1, 3, 4, 5, 6, 11, 12, 13, 14]
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 9, 11, 13, 15, 16, 21, 22]
broken_tests[alg_ops[3]] = [1, 2, 4, 5, 6, 11, 12, 13, 14, 21]

test_on_library(problems, dicts, alg_ops, broken_tests)
Expand All @@ -100,8 +100,8 @@ end

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 2, 4, 5, 6, 7, 11, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 12, 13, 22]
broken_tests[alg_ops[3]] = [1, 2, 4, 5, 6, 8, 11, 12, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 13, 22]
broken_tests[alg_ops[3]] = [1, 2, 5, 6, 11, 12, 13, 22]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
Expand Down
8 changes: 5 additions & 3 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
nlls_problems = [prob_oop, prob_iip]
solvers = [
GaussNewton(),
GaussNewton(; linsolve = CholeskyFactorization()),
LevenbergMarquardt(),
LevenbergMarquardt(; linsolve = CholeskyFactorization()),
GaussNewton(; linsolve = LUFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
]

# Compile time on v"1.9" is too high!
VERSION v"1.10-" && append!(solvers,
[LevenbergMarquardt(), LevenbergMarquardt(; linsolve = LUFactorization())])

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
Expand Down

0 comments on commit 06186c0

Please sign in to comment.