diff --git a/Project.toml b/Project.toml index 9fb70f62b..89d584fb5 100644 --- a/Project.toml +++ b/Project.toml @@ -65,6 +65,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -76,4 +77,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"] +test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath"] diff --git a/src/levenberg.jl b/src/levenberg.jl index 047ca16c2..16867033a 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -177,7 +177,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, else uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs, linsolve_with_JᵀJ) - JᵀJ = similar(u) + JᵀJ = similar(_vec(u)) J² = similar(J) v = similar(du) end @@ -214,7 +214,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, # Preserve Types mat_tmp = vcat(J, DᵀD) fill!(mat_tmp, zero(eltype(u))) - rhs_tmp = vcat(fu1, u) + rhs_tmp = vcat(_vec(fu1), _vec(u)) fill!(rhs_tmp, zero(eltype(u))) linsolve = __setup_linsolve(mat_tmp, rhs_tmp, u, p, alg) end diff --git a/test/polyalgs.jl b/test/polyalgs.jl index 4497eae97..4f861c20b 100644 --- a/test/polyalgs.jl +++ b/test/polyalgs.jl @@ -1,4 +1,4 @@ -using NonlinearSolve, Test +using NonlinearSolve, Test, NaNMath f(u, p) = u .* u .- 2 u0 = [1.0, 1.0] @@ -38,7 +38,8 @@ sol = solve(prob) @test SciMLBase.successful_retcode(sol) # https://github.com/SciML/NonlinearSolve.jl/issues/187 -ff(u, p) = 0.5 / 1.5 * log.(u ./ (1.0 .- u)) .- 2.0 * u .+ 1.0 +# If we use a General Nonlinear Solver the solution might go out of the domain! +ff(u, p) = 0.5 / 1.5 * NaNMath.log.(u ./ (1.0 .- u)) .- 2.0 * u .+ 1.0 uspan = (0.02, 0.1) prob = IntervalNonlinearProblem(ff, uspan) @@ -48,5 +49,5 @@ sol = solve(prob) u0 = 0.06 p = 2.0 prob = NonlinearProblem(ff, u0, p) -solver = solve(prob) +sol = solve(prob) @test SciMLBase.successful_retcode(sol)