Skip to content

Commit

Permalink
Patches for DiffEqCallbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 22, 2024
1 parent 156e65b commit 8767492
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Interface}
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Regression}
- {user: SciML, repo: BoundaryValueDiffEq.jl, group: All}
- {user: SciML, repo: DiffEqCallbacks.jl, group: All}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.7.1"
version = "3.7.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
10 changes: 6 additions & 4 deletions src/core/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
update_trace!(cache.trace, get_nsteps(cache), get_u(cache),
get_fu(cache), nothing, nothing, nothing; last = True)

stats = ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
get_nsolve(cache), get_nsteps(cache))
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, stats = __compile_stats(cache), cache.trace)
end

return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache),
get_fu(cache); cache.retcode, stats, cache.trace)
function __compile_stats(cache::AbstractNonlinearSolveCache)
return ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
get_nsolve(cache), get_nsteps(cache))
end

"""
Expand Down
38 changes: 22 additions & 16 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ end
retcode::ReturnCode.T
force_stop::Bool
maxiters::Int
internalnorm
end

function Base.show(
Expand All @@ -80,10 +81,13 @@ end
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
algType = NonlinearSolvePolyAlgorithm{pType}
@eval begin
function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...;
maxtime = nothing, maxiters = 1000, kwargs...) where {N}
function SciMLBase.__init(
prob::$probType, alg::$algType{N}, args...; maxtime = nothing,
maxiters = 1000, internalnorm = DEFAULT_NORM, kwargs...) where {N}
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
map(
solver -> SciMLBase.__init(
prob, solver, args...; maxtime, internalnorm, kwargs...),
alg.algs),
alg,
-1,
Expand All @@ -93,7 +97,8 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
maxtime,
ReturnCode.Default,
false,
maxiters)
maxiters,
internalnorm)
end
end
end
Expand Down Expand Up @@ -134,8 +139,8 @@ end
push!(calls,
quote
fus = tuple($(Tuple(resids)...))
minfu, idx = __findmin(cache.caches[1].internalnorm, fus)
stats = cache.caches[idx].stats
minfu, idx = __findmin(cache.internalnorm, fus)
stats = __compile_stats(cache.caches[idx])

Check warning on line 143 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L142-L143

Added lines #L142 - L143 were not covered by tests
u = get_u(cache.caches[idx])
retcode = cache.caches[idx].retcode

Expand Down Expand Up @@ -171,16 +176,15 @@ end
end)
end

push!(calls,
quote
if !(1 cache.current length(cache.caches))
minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches)
cache.best = idx
cache.retcode = cache.caches[cache.best].retcode
cache.force_stop = true
return
end
end)
push!(calls, quote
if !(1 cache.current length(cache.caches))
minfu, idx = __findmin(cache.internalnorm, cache.caches)
cache.best = idx
cache.retcode = cache.caches[cache.best].retcode
cache.force_stop = true
return

Check warning on line 185 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L180-L185

Added lines #L180 - L185 were not covered by tests
end
end)

return Expr(:block, calls...)
end
Expand Down Expand Up @@ -353,9 +357,11 @@ function FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing,
linsolve = nothing, precs = DEFAULT_PRECS, kwargs...) where {T}
if __is_complex(T)
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...),
LevenbergMarquardt(; linsolve, precs, kwargs...))
else
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...),
TrustRegion(; concrete_jac, linsolve, precs, kwargs...),
GaussNewton(; concrete_jac, linsolve, precs,
linesearch = LineSearchesJL(; method = BackTracking()), kwargs...),
Expand Down

0 comments on commit 8767492

Please sign in to comment.