Skip to content

Commit

Permalink
Merge pull request #265 from avik-pal/ap/fixes
Browse files Browse the repository at this point in the history
Proper handling of complex numbers and failures
  • Loading branch information
ChrisRackauckas authored Nov 1, 2023
2 parents 1e4c3c0 + c489b23 commit 52b421e
Show file tree
Hide file tree
Showing 33 changed files with 831 additions and 847 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ on:
- master
paths-ignore:
- 'docs/**'
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
runs-on: ubuntu-latest
Expand Down
6 changes: 5 additions & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ on:
branches: [master]
tags: [v*]
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: ${{ matrix.package.repo }}/${{ matrix.package.group }}/${{ matrix.julia-version }}
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.130"
DiffEqBase = "6.136"
EnumX = "1"
Enzyme = "0.11"
FastBroadcast = "0.1.9, 0.2"
Expand All @@ -56,7 +56,7 @@ RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "2.4"
SimpleNonlinearSolve = "0.1.23"
SparseDiffTools = "2.6"
SparseDiffTools = "2.9"
StaticArraysCore = "1.4"
UnPack = "1.0"
Zygote = "0.6"
Expand Down
4 changes: 4 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
NonlinearSolveMINPACK = "c100e077-885d-495a-a2ea-599e143bf69d"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLNLSolve = "e9a6253c-8580-4d32-9898-8661bb511710"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -19,12 +21,14 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
AlgebraicMultigrid = "0.5, 0.6"
ArrayInterface = "6, 7"
BenchmarkTools = "1"
DiffEqBase = "6.136"
Documenter = "1"
IncompleteLU = "0.2"
LinearSolve = "2"
ModelingToolkit = "8"
NonlinearSolve = "1, 2"
NonlinearSolveMINPACK = "0.1"
SciMLBase = "2.4"
SciMLNLSolve = "0.1"
SimpleNonlinearSolve = "0.1.5"
StaticArrays = "1"
Expand Down
7 changes: 3 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Documenter, NonlinearSolve, SimpleNonlinearSolve, Sundials, SciMLNLSolve,
NonlinearSolveMINPACK, SteadyStateDiffEq
NonlinearSolveMINPACK, SteadyStateDiffEq, SciMLBase, DiffEqBase

cp("./docs/Manifest.toml", "./docs/src/assets/Manifest.toml", force = true)
cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
Expand All @@ -8,9 +8,8 @@ include("pages.jl")

makedocs(sitename = "NonlinearSolve.jl",
authors = "Chris Rackauckas",
modules = [NonlinearSolve, NonlinearSolve.SciMLBase, NonlinearSolve.DiffEqBase,
SimpleNonlinearSolve, Sundials, SciMLNLSolve, NonlinearSolveMINPACK,
SteadyStateDiffEq],
modules = [NonlinearSolve, SciMLBase, DiffEqBase, SimpleNonlinearSolve, Sundials,
SciMLNLSolve, NonlinearSolveMINPACK, SteadyStateDiffEq],
clean = true, doctest = false, linkcheck = true,
linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615"],
warnonly = [:missing_docs, :cross_references],
Expand Down
1 change: 0 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ pages = ["index.md",
"Handling Large Ill-Conditioned and Sparse Systems" => "tutorials/large_systems.md",
"Symbolic System Definition and Acceleration via ModelingToolkit" => "tutorials/modelingtoolkit.md",
"tutorials/small_compile.md",
"tutorials/termination_conditions.md",
"tutorials/iterator_interface.md"],
"Basics" => Any["basics/NonlinearProblem.md",
"basics/NonlinearFunctions.md",
Expand Down
71 changes: 68 additions & 3 deletions docs/src/basics/TerminationCondition.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,75 @@
# [Termination Conditions](@id termination_condition)

Provides a API to specify termination conditions for [`NonlinearProblem`](@ref) and
[`SteadyStateProblem`](@ref). For details on the various termination modes, i.e.,
NLSolveTerminationMode, see the documentation for [`NLSolveTerminationCondition`](@ref).
[`SteadyStateProblem`](@ref). For details on the various termination modes:

## Termination Condition API
## Termination Conditions

The termination condition is constructed as:

```julia
cache = init(du, u, AbsNormTerminationMode(); abstol = 1e-9, reltol = 1e-9)
```

If `abstol` and `reltol` are not supplied, then we choose a default based on the element
types of `du` and `u`.

We can query the `cache` using `DiffEqBase.get_termination_mode`, `DiffEqBase.get_abstol`
and `DiffEqBase.get_reltol`.

To test for termination simply call the `cache`:

```julia
terminated = cache(du, u, uprev)
```

!!! note

The default for NonlinearSolve.jl is `AbsSafeBestTerminationMode`!

### Absolute Tolerance

```@docs
AbsTerminationMode
AbsNormTerminationMode
AbsSafeTerminationMode
AbsSafeBestTerminationMode
```

### Relative Tolerance

```@docs
RelTerminationMode
RelNormTerminationMode
RelSafeTerminationMode
RelSafeBestTerminationMode
```

### Both Absolute and Relative Tolerance

```@docs
NormTerminationMode
SteadyStateDiffEqTerminationMode
SimpleNonlinearSolveTerminationMode
```

### Return Codes

```@docs
DiffEqBase.NonlinearSafeTerminationReturnCode
DiffEqBase.NonlinearSafeTerminationReturnCode.Success
DiffEqBase.NonlinearSafeTerminationReturnCode.Default
DiffEqBase.NonlinearSafeTerminationReturnCode.Failure
DiffEqBase.NonlinearSafeTerminationReturnCode.PatienceTermination
DiffEqBase.NonlinearSafeTerminationReturnCode.ProtectiveTermination
```

## [Deprecated] Termination Condition API

!!! warning

This is deprecated. Currently only parts of `SimpleNonlinearSolve` uses this API. That
will also be phased out soon!

```@docs
NLSolveTerminationCondition
Expand Down
3 changes: 0 additions & 3 deletions docs/src/tutorials/termination_conditions.md

This file was deleted.

4 changes: 2 additions & 2 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ end
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtJL, args...; abstol = 1e-8, reltol = 1e-8,
verbose = false, maxiters = 1000, kwargs...)
alg::FastLevenbergMarquardtJL, args...; abstol = 1e-8, reltol = 1e-8,
verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"
Expand Down
2 changes: 1 addition & 1 deletion ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresOptimJL,
args...; abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
args...; abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

f! = FunctionWrapper{iip}(prob.f, prob.p)
Expand Down
20 changes: 16 additions & 4 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ PrecompileTools.@recompile_invalidations begin

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix,
matrix_colors, parameterless_type, ismutable, issingular,fast_scalar_indexing
matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing
import ConcreteStructs: @concrete
import EnumX: @enumx
import ForwardDiff
Expand All @@ -30,6 +30,9 @@ PrecompileTools.@recompile_invalidations begin
end

@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
import DiffEqBase: AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, get_termination_mode

const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
Expand All @@ -44,7 +47,7 @@ abstract type AbstractNonlinearSolveCache{iip} end
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
alg::AbstractNonlinearSolveAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end
Expand All @@ -53,6 +56,9 @@ function not_terminated(cache::AbstractNonlinearSolveCache)
return !cache.force_stop && cache.stats.nsteps < cache.maxiters
end
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)
get_u(cache::AbstractNonlinearSolveCache) = cache.u
SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
while not_terminated(cache)
Expand All @@ -69,7 +75,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
end
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache);
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, cache.stats)
end

Expand All @@ -96,7 +102,7 @@ PrecompileTools.@compile_workload begin
NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, T[0.1], T[2]))

precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
PseudoTransient(), GeneralBroyden(), GeneralKlement(), DFSane(), nothing)

for prob in probs, alg in precompile_algs
solve(prob, alg, abstol = T(1e-2))
Expand All @@ -113,4 +119,10 @@ export RobustMultiNewton, FastShortcutNonlinearPolyalg

export LineSearch, LiFukushimaLineSearch

# Export the termination conditions from DiffEqBase
export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode,
NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode,
AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode,
RelSafeBestTerminationMode, AbsSafeBestTerminationMode

end # module
12 changes: 6 additions & 6 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
false, <:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {T, V, P}
false, <:Dual{T, V, P}}, alg::AbstractNonlinearSolveAlgorithm, args...;
kwargs...) where {T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
false, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {T, V, P}
false, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNonlinearSolveAlgorithm,
args...; kwargs...) where {T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
Expand All @@ -53,11 +53,11 @@ function scalar_nlsolve_∂f_∂u(f, u, p)
end

function scalar_nlsolve_dual_soln(u::Number, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return Dual{T, V, P}(u, partials)
end

function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
end
Loading

0 comments on commit 52b421e

Please sign in to comment.