Skip to content

Commit

Permalink
Merge pull request #87 from SciML/matrix_resizing
Browse files Browse the repository at this point in the history
Add matrix resizing and fix cases with u0 as a matrix
  • Loading branch information
ChrisRackauckas authored Oct 20, 2023
2 parents d7d9c45 + 580c439 commit aa27e6a
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 16 deletions.
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "0.1.22"
version = "0.1.23"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
6 changes: 3 additions & 3 deletions lib/SimpleNonlinearSolve/src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;
xₙ₋₁ = x
fₙ₋₁ = fₙ
for _ in 1:maxiters
xₙ = xₙ₋₁ - J⁻¹ * fₙ₋₁
xₙ = xₙ₋₁ - _restructure(xₙ₋₁, J⁻¹ * _vec(fₙ₋₁))
fₙ = f(xₙ)
Δxₙ = xₙ - xₙ₋₁
Δfₙ = fₙ - fₙ₋₁
J⁻¹Δfₙ = J⁻¹ * Δfₙ
J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)
J⁻¹Δfₙ = _restructure(Δfₙ, J⁻¹ * _vec(Δfₙ))
J⁻¹ += _restructure(J⁻¹, ((_vec(Δxₙ) .- _vec(J⁻¹Δfₙ)) ./ (_vec(Δxₙ)' * _vec(J⁻¹Δfₙ))) * (_vec(Δxₙ)' * J⁻¹))

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
Expand Down
8 changes: 4 additions & 4 deletions lib/SimpleNonlinearSolve/src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
F = lu(J, check = false)
end

tmp = F \ fₙ₋₁
tmp = _restructure(fₙ₋₁, F \ _vec(fₙ₋₁))
xₙ = xₙ₋₁ - tmp
fₙ = f(xₙ)

Expand All @@ -92,10 +92,10 @@ function SciMLBase.__solve(prob::NonlinearProblem,
Δfₙ = fₙ - fₙ₋₁

# Prevent division by 0
denominator = max.(J' .^ 2 * Δxₙ .^ 2, 1e-9)
denominator = _restructure(Δxₙ, max.(J' .^ 2 * _vec(Δxₙ) .^ 2, 1e-9))

k = (Δfₙ - J * Δxₙ) ./ denominator
J += (k * Δxₙ' .* J) * J
k = (Δfₙ - _restructure(Δxₙ, J * _vec(Δxₙ))) ./ denominator
J += (_vec(k) * _vec(Δxₙ)' .* J) * J

xₙ₋₁ = xₙ
fₙ₋₁ = fₙ
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
end
iszero(fx) &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
Δx = dfx \ fx
Δx = _restructure(fx, dfx \ _vec(fx))
x -= Δx
if isapprox(x, xo, atol = atol, rtol = rtol)
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
Expand Down
7 changes: 7 additions & 0 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,10 @@ function dogleg_method(H, g, Δ)
tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd
return δsd + tau * δN_δsd
end

@inline _vec(v) = vec(v)
@inline _vec(v::Number) = v
@inline _vec(v::AbstractVector) = v

@inline _restructure(y::Number, x::Number) = x
@inline _restructure(y, x) = ArrayInterface.restructure(y,x)
11 changes: 11 additions & 0 deletions lib/SimpleNonlinearSolve/test/matrix_resizing_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using SimpleNonlinearSolve

ff(u, p) = u .* u .- p
u0 = rand(2,2)
p = 2.0
vecprob = NonlinearProblem(ff, vec(u0), p)
prob = NonlinearProblem(ff, u0, p)

for alg in (Klement(), Broyden(), SimpleNewtonRaphson())
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
end
10 changes: 3 additions & 7 deletions lib/SimpleNonlinearSolve/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,8 @@ const GROUP = get(ENV, "GROUP", "All")

@time begin
if GROUP == "All" || GROUP == "Core"
@time @safetestset "Basic Tests + Some AD" begin
include("basictests.jl")
end

@time @safetestset "Inplace Tests" begin
include("inplace.jl")
end
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
@time @safetestset "Inplace Tests" include("inplace.jl")
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
end
end

0 comments on commit aa27e6a

Please sign in to comment.