diff --git a/Project.toml b/Project.toml index 9afd3d558..99647f278 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.5.4" +version = "3.5.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/internal/linear_solve.jl b/src/internal/linear_solve.jl index 60ef18901..319b80005 100644 --- a/src/internal/linear_solve.jl +++ b/src/internal/linear_solve.jl @@ -49,15 +49,32 @@ function reinit_cache!(cache::LinearSolverCache, args...; kwargs...) cache.nfactors = 0 end +@inline __fix_strange_type_combination(A, b, u) = u +@inline function __fix_strange_type_combination(A, b, u::SArray) + A isa SArray && b isa SArray && return u + @warn "Solving Linear System A::$(typeof(A)) x::$(typeof(u)) = b::$(typeof(u)) is not \ + properly supported. Converting `x` to a mutable array. Check the return type \ + of the nonlinear function provided for optimal performance." maxlog=1 + return MArray(u) +end + +@inline __set_lincache_u!(cache, u) = (cache.lincache.u = u) +@inline function __set_lincache_u!(cache, u::SArray) + cache.lincache.u isa MArray && return __set_lincache_u!(cache, MArray(u)) + cache.lincache.u = u +end + function LinearSolverCache(alg, linsolve, A, b, u; kwargs...) + u_fixed = __fix_strange_type_combination(A, b, u) + if (A isa Number && b isa Number) || (linsolve === nothing && A isa SMatrix) || (A isa Diagonal) || (linsolve isa typeof(\)) return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0) end - @bb u_ = copy(u) + @bb u_ = copy(u_fixed) linprob = LinearProblem(A, b; u0 = u_, kwargs...) - weight = __init_ones(u) + weight = __init_ones(u_fixed) if __hasfield(alg, Val(:precs)) precs = alg.precs Pl_, Pr_ = precs(A, nothing, u, nothing, nothing, nothing, nothing, nothing, @@ -97,7 +114,7 @@ function (cache::LinearSolverCache)(; A = nothing, b = nothing, linu = nothing, __update_A!(cache, A, reuse_A_if_factorization) b !== nothing && (cache.lincache.b = b) - linu !== nothing && (cache.lincache.u = linu) + linu !== nothing && __set_lincache_u!(cache, linu) Plprev = cache.lincache.Pl isa ComposePreconditioner ? cache.lincache.Pl.outer : cache.lincache.Pl diff --git a/test/misc/polyalg_tests.jl b/test/misc/polyalg_tests.jl index 761270c9a..10df60429 100644 --- a/test/misc/polyalg_tests.jl +++ b/test/misc/polyalg_tests.jl @@ -178,7 +178,7 @@ end end @testitem "[OOP] Infeasible" setup=[InfeasibleFunction] begin - using StaticArrays + using LinearAlgebra, StaticArrays u0 = [0.0, 0.0, 0.0] prob = NonlinearProblem(f1_infeasible, u0) @@ -189,8 +189,12 @@ end u0 = @SVector [0.0, 0.0, 0.0] prob = NonlinearProblem(f1_infeasible, u0) - sol = solve(prob) - @test all(!isnan, sol.u) - @test !SciMLBase.successful_retcode(sol.retcode) + try + sol = solve(prob) + @test all(!isnan, sol.u) + @test !SciMLBase.successful_retcode(sol.retcode) + catch err + @test err isa LinearAlgebra.SingularException + end end