From 08d25cc6494bb044481ac534a9b9d2394afe099a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 4 Nov 2023 15:50:14 -0400 Subject: [PATCH] Add a default for NLLS --- Project.toml | 2 +- src/default.jl | 21 +++++++++++++++------ test/nonlinear_least_squares.jl | 1 + 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 6411e91a5..076596ee4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "2.6.1" +version = "2.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/default.jl b/src/default.jl index 178d72f1c..8c11c6bbf 100644 --- a/src/default.jl +++ b/src/default.jl @@ -268,12 +268,21 @@ end ## Defaults -function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::Nothing, args...; - kwargs...) where {uType, iip} - SciMLBase.__init(prob, FastShortcutNonlinearPolyalg(), args...; kwargs...) +function SciMLBase.__init(prob::NonlinearProblem, ::Nothing, args...; kwargs...) + return SciMLBase.__init(prob, FastShortcutNonlinearPolyalg(), args...; kwargs...) end -function SciMLBase.__solve(prob::NonlinearProblem{uType, iip}, alg::Nothing, args...; - kwargs...) where {uType, iip} - SciMLBase.__solve(prob, FastShortcutNonlinearPolyalg(), args...; kwargs...) +function SciMLBase.__solve(prob::NonlinearProblem, ::Nothing, args...; kwargs...) + return SciMLBase.__solve(prob, FastShortcutNonlinearPolyalg(), args...; kwargs...) +end + +# FIXME: We default to using LM currently. But once we have line searches for GN implemented +# we should default to a polyalgorithm. +function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, ::Nothing, args...; kwargs...) + return SciMLBase.__init(prob, LevenbergMarquardt(), args...; kwargs...) +end + +function SciMLBase.__solve(prob::NonlinearLeastSquaresProblem, ::Nothing, args...; + kwargs...) + return SciMLBase.__solve(prob, LevenbergMarquardt(), args...; kwargs...) end diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl index c7a02dc58..9cdbdd08a 100644 --- a/test/nonlinear_least_squares.jl +++ b/test/nonlinear_least_squares.jl @@ -34,6 +34,7 @@ solvers = [ LevenbergMarquardt(; linsolve = LUFactorization()), LeastSquaresOptimJL(:lm), LeastSquaresOptimJL(:dogleg), + nothing ] for prob in nlls_problems, solver in solvers