From c2791b9aa51726744601f1884737ae0bed26547d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Jan 2024 06:18:21 -0500 Subject: [PATCH] Allow linsolve to be \ --- Project.toml | 2 +- src/internal/linear_solve.jl | 15 +++++---------- src/utils.jl | 1 + test/core/rootfind.jl | 6 +++--- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 75f45bce7..81ac8a571 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.5.0" +version = "3.5.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/internal/linear_solve.jl b/src/internal/linear_solve.jl index 184edf660..72e05caa0 100644 --- a/src/internal/linear_solve.jl +++ b/src/internal/linear_solve.jl @@ -49,17 +49,11 @@ function reinit_cache!(cache::LinearSolverCache, args...; kwargs...) cache.nfactors = 0 end -@inline function LinearSolverCache(alg, linsolve, A::Number, b::Number, u; kwargs...) - return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0) -end -@inline function LinearSolverCache(alg, ::Nothing, A::SMatrix, b, u; kwargs...) - # Default handling for SArrays caching in LinearSolve is not the best. Override it here - return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0) -end -@inline function LinearSolverCache(alg, linsolve, A::Diagonal, b, u; kwargs...) - return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0) -end function LinearSolverCache(alg, linsolve, A, b, u; kwargs...) + if (A isa Number && b isa Number) || (linsolve === nothing && A isa SMatrix) || + (A isa Diagonal) || (linsolve isa typeof(\)) + return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0) + end @bb b_ = copy(b) @bb u_ = copy(u) linprob = LinearProblem(A, b_; u0 = u_, kwargs...) @@ -193,3 +187,4 @@ end @inline __needs_square_A(::Nothing, ::Number) = false @inline __needs_square_A(::Nothing, _) = false @inline __needs_square_A(linsolve, _) = LinearSolve.needs_square_A(linsolve) +@inline __needs_square_A(::typeof(\), _) = false diff --git a/src/utils.jl b/src/utils.jl index ef8fdf713..7f4c2c439 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,6 +18,7 @@ end end @inline __needs_concrete_A(::Nothing) = false +@inline __needs_concrete_A(::typeof(\)) = true @inline __needs_concrete_A(linsolve) = needs_concrete_A(linsolve) @inline __maybe_mutable(x, ::AutoSparseEnzyme) = __mutable(x) diff --git a/test/core/rootfind.jl b/test/core/rootfind.jl index ff26c3a08..f64dcda88 100644 --- a/test/core/rootfind.jl +++ b/test/core/rootfind.jl @@ -72,7 +72,7 @@ const TERMINATION_CONDITIONS = [ ] @testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([ - 1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES()) + 1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES(), \) ad isa AutoZygote && continue if prec === :Random prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing) @@ -139,7 +139,7 @@ end RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan, RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin] u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) - linear_solvers = [nothing, LUFactorization(), KrylovJL_GMRES()] + linear_solvers = [nothing, LUFactorization(), KrylovJL_GMRES(), \] @testset "[OOP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme) linear_solver: $(linsolve)" for u0 in u0s, radius_update_scheme in radius_update_schemes, linsolve in linear_solvers @@ -471,7 +471,7 @@ end precs = [NonlinearSolve.DEFAULT_PRECS, :Random] @testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([ - 1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES()) + 1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES(), \) ad isa AutoZygote && continue if prec === :Random prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)