Skip to content

Commit

Permalink
feat: add override_init_get_nlsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 11, 2024
1 parent c61b13d commit 924e92b
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 14 deletions.
12 changes: 10 additions & 2 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve!
import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: AbstractADType
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset
using Expronicon.ADT: @match

Expand Down Expand Up @@ -351,7 +351,15 @@ struct CheckInit <: DAEInitializationAlgorithm end
"""
$(TYPEDEF)
"""
struct OverrideInit <: DAEInitializationAlgorithm end
struct OverrideInit{T, F} <: DAEInitializationAlgorithm
abstol::T
nlsolve::F
end

function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
OverrideInit(abstol, nlsolve)
end
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)

# PDE Discretizations

Expand Down
53 changes: 49 additions & 4 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,51 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
end

struct NoNonlinearSolverError <: Exception
end

function Base.showerror(io::IO, err::NoNonlinearSolverError)
println(io, """
This problem requires initialization and thus a nonlinear solve, but no nonlinear \
solve has been loaded. If you are using OrdinaryDiffEq, import the \
`OrdinaryDiffEqNonlinearSolve` package or pass a custom `nlsolve` into the \
`initializealg`. If you are not using `OrdinaryDiffEq`, please open an issue in \
the appropriate library with an MWE.
""")
end

"""
$(TYPEDSIGNATURES)
Given a user-provided nonlinear solve algorithm `alg`, `iip::Union{Val{true}, Val{false}}`
indicating whether the initialization problem is in-place or not, the initial state
vector of the initialization problem, the initialization problem (either a
`NonlinearProblem` or `NonlinearLeastSquaresProblem`) and a boolean `autodiff`
indicating whether to use `AutoForwardDiff` or `AutoFiniteDiff`, return a nonlinear
solve algorithm to use for solving the initialization. If `alg` is not nothing, it will
be returned as-is. If the initialization problem is trivial (`u === nothing`) the trivial
`nothing` algorithm will be used. Otherwise, requires `NonlinearSolve.jl` to
automatically find an appropriate solver.
"""
override_init_get_nlsolve(alg, iip, u, prob, autodiff = false) = alg

for iip in (Val{true}, Val{false}), prob in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval function override_init_get_nlsolve(
::Nothing, ::$(iip), u::Nothing, ::$(prob), autodiff = false)
nothing
end
end

function override_init_get_nlsolve(
::Nothing, isinplace, u, initializeprob::NonlinearProblem, autodiff = false)
throw(NoNonlinearSolverError())
end

function override_init_get_nlsolve(
::Nothing, isinplace, u, initializeprob::NonlinearLeastSquaresProblem, autodiff = false)
throw(NoNonlinearSolverError())
end

"""
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
Expand Down Expand Up @@ -160,7 +205,7 @@ argument, failing which this function will throw an error. The success value ret
depends on the success of the nonlinear solve.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, autodiff = false, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

Expand All @@ -171,9 +216,9 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

if nlsolve_alg === nothing
throw(OverrideInitMissingAlgorithm())
end
nlsolve_alg = override_init_get_nlsolve(
something(nlsolve_alg, alg.nlsolve, Some(nothing)),
Val{isinplace(initprob)}(), state_values(initprob), initprob, autodiff)

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
Expand Down
44 changes: 36 additions & 8 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,48 @@ end
integ = init(prob; initializealg = NoInit())

@testset "Errors without `nlsolve_alg`" begin
@test_throws SciMLBase.OverrideInitMissingAlgorithm SciMLBase.get_initial_values(
@test_throws SciMLBase.NoNonlinearSolverError SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
end

@testset "Solves" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())
@testset "with explicit alg" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())

@test u0 [2.0, 2.0]
@test p 1.0
@test success
@test u0 [2.0, 2.0]
@test p 1.0
@test success

initprob.p[1] = 1.0
initprob.p[1] = 1.0
end
@testset "with alg in `OverrideInit`" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()),
Val(false))

@test u0 [2.0, 2.0]
@test p 1.0
@test success

initprob.p[1] = 1.0
end
@testset "with trivial problem and no alg" begin
iprob = NonlinearProblem((u, p) -> 0.0, nothing, 1.0)
iprobmap = (_) -> [1.0, 1.0]
initdata = SciMLBase.OverrideInitData(iprob, nothing, iprobmap, nothing)
_fn = ODEFunction(rhs2; initialization_data = initdata)
_prob = ODEProblem(_fn, [2.0, 0.0], (0.0, 1.0), 1.0)
_integ = init(_prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false))

@test u0 [1.0, 1.0]
@test p 1.0
@test success
end
end

@testset "Solves with non-integrator value provider" begin
Expand Down

0 comments on commit 924e92b

Please sign in to comment.