From 2fdc12205eec56771c367ce33b7f6f9f250bdbcb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 Oct 2023 16:53:23 -0400 Subject: [PATCH] Add a function to check if square A is needed --- src/NonlinearSolve.jl | 6 ++---- src/gaussnewton.jl | 4 +--- src/levenberg.jl | 4 +--- src/utils.jl | 21 +++++++++++++++++++++ 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 6e3a6f804..2b26b3721 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -88,10 +88,8 @@ import PrecompileTools for T in (Float32, Float64) prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) - # precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), - # PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing) - # DON'T MERGE - precompile_algs = () + precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), + PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing) for alg in precompile_algs solve(prob, alg, abstol = T(1e-2)) diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 307119af3..ab18bf0ed 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -82,9 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob - # Use QR if the user did not specify a linear solver - if alg.linsolve === nothing || alg.linsolve isa QRFactorization || - alg.linsolve isa FastQRFactorization + if !needs_square_A(alg.linsolve) && !(u isa Number) linsolve_with_JᵀJ = Val(false) else linsolve_with_JᵀJ = Val(true) diff --git a/src/levenberg.jl b/src/levenberg.jl index 3480e6c63..38507ab2e 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -164,9 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, u = alias_u0 ? u0 : deepcopy(u0) fu1 = evaluate_f(prob, u) - # Use QR if the user did not specify a linear solver - if (alg.linsolve === nothing || alg.linsolve isa QRFactorization || - alg.linsolve isa FastQRFactorization) && !(u isa Number) + if !needs_square_A(alg.linsolve) && !(u isa Number) linsolve_with_JᵀJ = Val(false) else linsolve_with_JᵀJ = Val(true) diff --git a/src/utils.jl b/src/utils.jl index 688322329..56b5c1076 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -256,3 +256,24 @@ function _try_factorize_and_check_singular!(linsolve, X) return _issingular(X), false end _try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false + +# Needs Square Matrix +""" + needs_square_A(alg) + +Returns `true` if the algorithm requires a square matrix. +""" +needs_square_A(::Nothing) = false +function needs_square_A(alg) + try + A = ones(Float64, 3, 2) + b = ones(Float64, 3) + solve(LinearProblem(A, b), alg) + return false + catch err + return true + end +end +for alg in (:QRFactorization, :FastQRFactorization) + @eval needs_square_A(::$(alg)) = false +end