diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index 61fe5f16a..28a6b0ed7 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "0.1.22" +version = "0.1.23" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/SimpleNonlinearSolve/src/broyden.jl b/lib/SimpleNonlinearSolve/src/broyden.jl index 6c5c3ce7f..9f3c22505 100644 --- a/lib/SimpleNonlinearSolve/src/broyden.jl +++ b/lib/SimpleNonlinearSolve/src/broyden.jl @@ -58,12 +58,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; xₙ₋₁ = x fₙ₋₁ = fₙ for _ in 1:maxiters - xₙ = xₙ₋₁ - J⁻¹ * fₙ₋₁ + xₙ = xₙ₋₁ - _restructure(xₙ₋₁, J⁻¹ * _vec(fₙ₋₁)) fₙ = f(xₙ) Δxₙ = xₙ - xₙ₋₁ Δfₙ = fₙ - fₙ₋₁ - J⁻¹Δfₙ = J⁻¹ * Δfₙ - J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹) + J⁻¹Δfₙ = _restructure(Δfₙ, J⁻¹ * _vec(Δfₙ)) + J⁻¹ += _restructure(J⁻¹, ((_vec(Δxₙ) .- _vec(J⁻¹Δfₙ)) ./ (_vec(Δxₙ)' * _vec(J⁻¹Δfₙ))) * (_vec(Δxₙ)' * J⁻¹)) if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) diff --git a/lib/SimpleNonlinearSolve/src/klement.jl b/lib/SimpleNonlinearSolve/src/klement.jl index 00264d32f..799ba5987 100644 --- a/lib/SimpleNonlinearSolve/src/klement.jl +++ b/lib/SimpleNonlinearSolve/src/klement.jl @@ -75,7 +75,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, F = lu(J, check = false) end - tmp = F \ fₙ₋₁ + tmp = _restructure(fₙ₋₁, F \ _vec(fₙ₋₁)) xₙ = xₙ₋₁ - tmp fₙ = f(xₙ) @@ -92,10 +92,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, Δfₙ = fₙ - fₙ₋₁ # Prevent division by 0 - denominator = max.(J' .^ 2 * Δxₙ .^ 2, 1e-9) + denominator = _restructure(Δxₙ, max.(J' .^ 2 * _vec(Δxₙ) .^ 2, 1e-9)) - k = (Δfₙ - J * Δxₙ) ./ denominator - J += (k * Δxₙ' .* J) * J + k = (Δfₙ - _restructure(Δxₙ, J * _vec(Δxₙ))) ./ denominator + J += (_vec(k) * _vec(Δxₙ)' .* J) * J xₙ₋₁ = xₙ fₙ₋₁ = fₙ diff --git a/lib/SimpleNonlinearSolve/src/raphson.jl b/lib/SimpleNonlinearSolve/src/raphson.jl index 48b8f7591..c36dc3504 100644 --- a/lib/SimpleNonlinearSolve/src/raphson.jl +++ b/lib/SimpleNonlinearSolve/src/raphson.jl @@ -100,7 +100,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, end iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) - Δx = dfx \ fx + Δx = _restructure(fx, dfx \ _vec(fx)) x -= Δx if isapprox(x, xo, atol = atol, rtol = rtol) return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 12462a05c..c0245ece0 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -82,3 +82,10 @@ function dogleg_method(H, g, Δ) tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd return δsd + tau * δN_δsd end + +@inline _vec(v) = vec(v) +@inline _vec(v::Number) = v +@inline _vec(v::AbstractVector) = v + +@inline _restructure(y::Number, x::Number) = x +@inline _restructure(y, x) = ArrayInterface.restructure(y,x) \ No newline at end of file diff --git a/lib/SimpleNonlinearSolve/test/matrix_resizing_tests.jl b/lib/SimpleNonlinearSolve/test/matrix_resizing_tests.jl new file mode 100644 index 000000000..9612cbb68 --- /dev/null +++ b/lib/SimpleNonlinearSolve/test/matrix_resizing_tests.jl @@ -0,0 +1,11 @@ +using SimpleNonlinearSolve + +ff(u, p) = u .* u .- p +u0 = rand(2,2) +p = 2.0 +vecprob = NonlinearProblem(ff, vec(u0), p) +prob = NonlinearProblem(ff, u0, p) + +for alg in (Klement(), Broyden(), SimpleNewtonRaphson()) + @test vec(solve(prob, alg).u) == solve(vecprob, alg).u +end diff --git a/lib/SimpleNonlinearSolve/test/runtests.jl b/lib/SimpleNonlinearSolve/test/runtests.jl index bea57ea0a..98a01bdba 100644 --- a/lib/SimpleNonlinearSolve/test/runtests.jl +++ b/lib/SimpleNonlinearSolve/test/runtests.jl @@ -4,12 +4,8 @@ const GROUP = get(ENV, "GROUP", "All") @time begin if GROUP == "All" || GROUP == "Core" - @time @safetestset "Basic Tests + Some AD" begin - include("basictests.jl") - end - - @time @safetestset "Inplace Tests" begin - include("inplace.jl") - end + @time @safetestset "Basic Tests + Some AD" include("basictests.jl") + @time @safetestset "Inplace Tests" include("inplace.jl") + @time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl") end end