From 93e2815b8aab711a529b5bd3d94425b8aaed0a2e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 Oct 2023 16:53:23 -0400 Subject: [PATCH] Add a function to check if square A is needed --- src/NonlinearSolve.jl | 6 ++---- src/gaussnewton.jl | 4 +--- src/levenberg.jl | 4 +--- src/utils.jl | 7 +++++++ 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 6e3a6f804..2b26b3721 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -88,10 +88,8 @@ import PrecompileTools for T in (Float32, Float64) prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) - # precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), - # PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing) - # DON'T MERGE - precompile_algs = () + precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), + PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing) for alg in precompile_algs solve(prob, alg, abstol = T(1e-2)) diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 307119af3..ab18bf0ed 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -82,9 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob - # Use QR if the user did not specify a linear solver - if alg.linsolve === nothing || alg.linsolve isa QRFactorization || - alg.linsolve isa FastQRFactorization + if !needs_square_A(alg.linsolve) && !(u isa Number) linsolve_with_JᵀJ = Val(false) else linsolve_with_JᵀJ = Val(true) diff --git a/src/levenberg.jl b/src/levenberg.jl index 3480e6c63..38507ab2e 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -164,9 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, u = alias_u0 ? u0 : deepcopy(u0) fu1 = evaluate_f(prob, u) - # Use QR if the user did not specify a linear solver - if (alg.linsolve === nothing || alg.linsolve isa QRFactorization || - alg.linsolve isa FastQRFactorization) && !(u isa Number) + if !needs_square_A(alg.linsolve) && !(u isa Number) linsolve_with_JᵀJ = Val(false) else linsolve_with_JᵀJ = Val(true) diff --git a/src/utils.jl b/src/utils.jl index 688322329..9c8bcac80 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -256,3 +256,10 @@ function _try_factorize_and_check_singular!(linsolve, X) return _issingular(X), false end _try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false + +# Needs Square Matrix +needs_square_A(::Nothing) = false +needs_square_A(_) = true +for alg in (:QRFactorization, :FastQRFactorization) + @eval needs_square_A(::$(alg)) = false +end