From 580c439b58999e6552a25e1236836e539eccf986 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 20 Oct 2023 08:57:38 -0400 Subject: [PATCH] don't restructure on number --- lib/SimpleNonlinearSolve/src/broyden.jl | 6 +++--- lib/SimpleNonlinearSolve/src/klement.jl | 6 +++--- lib/SimpleNonlinearSolve/src/raphson.jl | 2 +- lib/SimpleNonlinearSolve/src/utils.jl | 5 ++++- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/SimpleNonlinearSolve/src/broyden.jl b/lib/SimpleNonlinearSolve/src/broyden.jl index 2f518df83..9f3c22505 100644 --- a/lib/SimpleNonlinearSolve/src/broyden.jl +++ b/lib/SimpleNonlinearSolve/src/broyden.jl @@ -58,12 +58,12 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; xₙ₋₁ = x fₙ₋₁ = fₙ for _ in 1:maxiters - xₙ = xₙ₋₁ - ArrayInterface.restructure(xₙ₋₁, J⁻¹ * _vec(fₙ₋₁)) + xₙ = xₙ₋₁ - _restructure(xₙ₋₁, J⁻¹ * _vec(fₙ₋₁)) fₙ = f(xₙ) Δxₙ = xₙ - xₙ₋₁ Δfₙ = fₙ - fₙ₋₁ - J⁻¹Δfₙ = ArrayInterface.restructure(Δfₙ, J⁻¹ * _vec(Δfₙ)) - J⁻¹ += ArrayInterface.restructure(J⁻¹, ((_vec(Δxₙ) .- _vec(J⁻¹Δfₙ)) ./ (_vec(Δxₙ)' * _vec(J⁻¹Δfₙ))) * (_vec(Δxₙ)' * J⁻¹)) + J⁻¹Δfₙ = _restructure(Δfₙ, J⁻¹ * _vec(Δfₙ)) + J⁻¹ += _restructure(J⁻¹, ((_vec(Δxₙ) .- _vec(J⁻¹Δfₙ)) ./ (_vec(Δxₙ)' * _vec(J⁻¹Δfₙ))) * (_vec(Δxₙ)' * J⁻¹)) if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success) diff --git a/lib/SimpleNonlinearSolve/src/klement.jl b/lib/SimpleNonlinearSolve/src/klement.jl index 7235487c5..799ba5987 100644 --- a/lib/SimpleNonlinearSolve/src/klement.jl +++ b/lib/SimpleNonlinearSolve/src/klement.jl @@ -75,7 +75,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, F = lu(J, check = false) end - tmp = ArrayInterface.restructure(fₙ₋₁, F \ _vec(fₙ₋₁)) + tmp = _restructure(fₙ₋₁, F \ _vec(fₙ₋₁)) xₙ = xₙ₋₁ - tmp fₙ = f(xₙ) @@ -92,9 +92,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, Δfₙ = fₙ - fₙ₋₁ # Prevent division by 0 - denominator = ArrayInterface.restructure(Δxₙ, max.(J' .^ 2 * _vec(Δxₙ) .^ 2, 1e-9)) + denominator = _restructure(Δxₙ, max.(J' .^ 2 * _vec(Δxₙ) .^ 2, 1e-9)) - k = (Δfₙ - ArrayInterface.restructure(Δxₙ, J * _vec(Δxₙ))) ./ denominator + k = (Δfₙ - _restructure(Δxₙ, J * _vec(Δxₙ))) ./ denominator J += (_vec(k) * _vec(Δxₙ)' .* J) * J xₙ₋₁ = xₙ diff --git a/lib/SimpleNonlinearSolve/src/raphson.jl b/lib/SimpleNonlinearSolve/src/raphson.jl index a7d9858bb..c36dc3504 100644 --- a/lib/SimpleNonlinearSolve/src/raphson.jl +++ b/lib/SimpleNonlinearSolve/src/raphson.jl @@ -100,7 +100,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, end iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) - Δx = ArrayInterface.restructure(fx, dfx \ _vec(fx)) + Δx = _restructure(fx, dfx \ _vec(fx)) x -= Δx if isapprox(x, xo, atol = atol, rtol = rtol) return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 4f5617786..c0245ece0 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -85,4 +85,7 @@ end @inline _vec(v) = vec(v) @inline _vec(v::Number) = v -@inline _vec(v::AbstractVector) = v \ No newline at end of file +@inline _vec(v::AbstractVector) = v + +@inline _restructure(y::Number, x::Number) = x +@inline _restructure(y, x) = ArrayInterface.restructure(y,x) \ No newline at end of file