From 5e43b5661dd7b80480bfe3fca73b7514d5329041 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 22 Feb 2024 15:45:27 -0500 Subject: [PATCH] Patches for DiffEqCallbacks --- .github/workflows/Downstream.yml | 1 + Project.toml | 2 +- src/core/generic.jl | 10 +++++---- src/default.jl | 38 ++++++++++++++++++-------------- 4 files changed, 30 insertions(+), 21 deletions(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index e5b27bddf..b0c39af7a 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -25,6 +25,7 @@ jobs: - {user: SciML, repo: OrdinaryDiffEq.jl, group: Interface} - {user: SciML, repo: OrdinaryDiffEq.jl, group: Regression} - {user: SciML, repo: BoundaryValueDiffEq.jl, group: All} + - {user: SciML, repo: DiffEqCallbacks.jl, group: All} steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 09059a431..a06689370 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.7.1" +version = "3.7.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/core/generic.jl b/src/core/generic.jl index 9aafba49b..b318ae581 100644 --- a/src/core/generic.jl +++ b/src/core/generic.jl @@ -24,11 +24,13 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache) update_trace!(cache.trace, get_nsteps(cache), get_u(cache), get_fu(cache), nothing, nothing, nothing; last = True) - stats = ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache), - get_nsolve(cache), get_nsteps(cache)) - return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), - get_fu(cache); cache.retcode, stats, cache.trace) + get_fu(cache); cache.retcode, stats = __compile_stats(cache), cache.trace) +end + +function __compile_stats(cache::AbstractNonlinearSolveCache) + return ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache), + get_nsolve(cache), get_nsteps(cache)) end """ diff --git a/src/default.jl b/src/default.jl index da2420656..805001c0e 100644 --- a/src/default.jl +++ b/src/default.jl @@ -56,6 +56,7 @@ end retcode::ReturnCode.T force_stop::Bool maxiters::Int + internalnorm end function Base.show( @@ -80,10 +81,13 @@ end for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS)) algType = NonlinearSolvePolyAlgorithm{pType} @eval begin - function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...; - maxtime = nothing, maxiters = 1000, kwargs...) where {N} + function SciMLBase.__init( + prob::$probType, alg::$algType{N}, args...; maxtime = nothing, + maxiters = 1000, internalnorm = DEFAULT_NORM, kwargs...) where {N} return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}( - map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...), + map( + solver -> SciMLBase.__init( + prob, solver, args...; maxtime, internalnorm, kwargs...), alg.algs), alg, -1, @@ -93,7 +97,8 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb maxtime, ReturnCode.Default, false, - maxiters) + maxiters, + internalnorm) end end end @@ -134,8 +139,8 @@ end push!(calls, quote fus = tuple($(Tuple(resids)...)) - minfu, idx = __findmin(cache.caches[1].internalnorm, fus) - stats = cache.caches[idx].stats + minfu, idx = __findmin(cache.internalnorm, fus) + stats = __compile_stats(cache.caches[idx]) u = get_u(cache.caches[idx]) retcode = cache.caches[idx].retcode @@ -171,16 +176,15 @@ end end) end - push!(calls, - quote - if !(1 ≤ cache.current ≤ length(cache.caches)) - minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches) - cache.best = idx - cache.retcode = cache.caches[cache.best].retcode - cache.force_stop = true - return - end - end) + push!(calls, quote + if !(1 ≤ cache.current ≤ length(cache.caches)) + minfu, idx = __findmin(cache.internalnorm, cache.caches) + cache.best = idx + cache.retcode = cache.caches[cache.best].retcode + cache.force_stop = true + return + end + end) return Expr(:block, calls...) end @@ -353,9 +357,11 @@ function FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS, kwargs...) where {T} if __is_complex(T) algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...), + LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...), LevenbergMarquardt(; linsolve, precs, kwargs...)) else algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...), + LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...), TrustRegion(; concrete_jac, linsolve, precs, kwargs...), GaussNewton(; concrete_jac, linsolve, precs, linesearch = LineSearchesJL(; method = BackTracking()), kwargs...),