Skip to content

Commit

Permalink
Merge pull request #24 from PALEOtoolkit/solver_robustness
Browse files Browse the repository at this point in the history
Solver robustness fixes
sjdaines authored Aug 25, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents 2d77d31 + 3a63963 commit c9f3936
Showing 9 changed files with 48 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PALEOmodel"
uuid = "bf7b4fbe-ccb1-42c5-83c2-e6e9378b660c"
authors = ["Stuart Daines <[email protected]>"]
version = "0.15.7"
version = "0.15.8"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
8 changes: 6 additions & 2 deletions src/JacobianAD.jl
Original file line number Diff line number Diff line change
@@ -53,7 +53,9 @@ function jac_config_ode(
jac_cellranges=modeldata.cellranges_all,
init_logger=Logging.NullLogger(),
)
@info "jac_config_ode: jac_ad=$jac_ad"
@info "jac_config_ode: jac_ad=$jac_ad"

PB.check_modeldata(model, modeldata)

iszero(PALEOmodel.num_total(modeldata.solver_view_all)) ||
throw(ArgumentError("model contains implicit variables, solve as a DAE"))
@@ -158,7 +160,9 @@ function jac_config_dae(
implicit_cellranges=modeldata.cellranges_all,
init_logger=Logging.NullLogger(),
)
@info "jac_config_dae: jac_ad=$jac_ad"
@info "jac_config_dae: jac_ad=$jac_ad"

PB.check_modeldata(model, modeldata)

# generate arrays with ODE layout for model Variables
state_sms_vars_data = similar(PALEOmodel.get_statevar_sms(modeldata.solver_view_all))
25 changes: 9 additions & 16 deletions src/Kinsol.jl
Original file line number Diff line number Diff line change
@@ -11,6 +11,8 @@ module Kinsol

import Sundials

# import Infiltrator

###########################################################
# Internals: Julia <-> C wrapper functions
###########################################################
@@ -25,24 +27,17 @@ mutable struct UserFunctionAndData{F1, F2, F3, F4}
data::Any
end

UserFunctionAndData(func, data) = UserFunctionAndData(func, nothing, nothing, nothing, data)
UserFunctionAndData(func) = func
UserFunctionAndData(func, psetup::Nothing, psolve::Nothing, jv::Nothing, data::Nothing) = func

# Julia adaptor function with C types, passed in to Kinsol C code as a callback
# wraps C types and forwards to the Julia user function
function kinsolfun(y::Sundials.N_Vector, fy::Sundials.N_Vector, userfun::UserFunctionAndData)
# @Infiltrator.infiltrate
function kinsolfun(
y::Sundials.N_Vector,
fy::Sundials.N_Vector,
userfun::UserFunctionAndData
)
userfun.func(convert(Vector, fy), convert(Vector, y), userfun.data)
return Sundials.KIN_SUCCESS
end

function kinsolfun(y::Sundials.N_Vector, fy::Sundials.N_Vector, userfun)
# @Infiltrator.infiltrate
userfun(convert(Vector, fy), convert(Vector, y))
return Sundials.KIN_SUCCESS
end

function kinprecsetup(
u::Sundials.N_Vector,
uscale::Sundials.N_Vector,
@@ -69,7 +64,6 @@ function kinprecsolve(
v::Sundials.N_Vector,
userfun::UserFunctionAndData
)

retval = userfun.psolve(
convert(Vector, u),
convert(Vector, uscale),
@@ -87,7 +81,7 @@ function kinjactimesvec(
u::Sundials.N_Vector,
new_u::Ptr{Cint},
userfun::UserFunctionAndData
)
)
retval = userfun.jv(
convert(Vector, v),
convert(Vector, Jv),
@@ -136,6 +130,7 @@ function kin_create(
# use the user_data field to pass a function
# see: https://github.com/JuliaLang/julia/issues/2554
userfun = UserFunctionAndData(f, psetupfun, psolvefun, jvfun, userdata)
# push!(handles, userfun) # TODO prevent userfun from being garbage collected ?
function getkinsolfun(userfun::T) where {T}
@cfunction(kinsolfun, Cint, (Sundials.N_Vector, Sundials.N_Vector, Ref{T}))
end
@@ -220,10 +215,8 @@ function kin_solve(
flag = Sundials.@checkflag Sundials.KINSetNoInitSetup(kmem, noInitSetup) true

## Solve problem
# @Infiltrator.infiltrate
returnflag = Sundials.KINSol(kmem, y, strategy, y_scale, f_scale)


## Get stats
nfevals = [0]
flag = Sundials.@checkflag Sundials.KINGetNumFuncEvals(kmem, nfevals)
4 changes: 3 additions & 1 deletion src/ODE.jl
Original file line number Diff line number Diff line change
@@ -52,6 +52,7 @@ function ODEfunction(
jac_ad_t_sparsity=nothing,
init_logger=Logging.NullLogger(),
)
PB.check_modeldata(model, modeldata)

# check for implicit total variables
PALEOmodel.num_total(modeldata.solver_view_all) == 0 ||
@@ -101,8 +102,9 @@ function DAEfunction(
jac_ad_t_sparsity=nothing,
init_logger=Logging.NullLogger(),
)

@info "DAEfunction: using Jacobian $jac_ad"

PB.check_modeldata(model, modeldata)

jac, jac_prototype, odeimplicit = PALEOmodel.JacobianAD.jac_config_dae(
jac_ad, model, initial_state, modeldata, jac_ad_t_sparsity,
6 changes: 6 additions & 0 deletions src/ODELocalIMEX.jl
Original file line number Diff line number Diff line change
@@ -59,6 +59,8 @@ function integrateLocalIMEXEuler(

@info "integrateLocalIMEXEuler: Δt_outer=$Δt_outer (yr)"

PB.check_modeldata(run.model, modeldata)

solver_view_outer = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_outer)
@info "solver_view_outer: $(solver_view_outer)"

@@ -106,6 +108,8 @@ function timestep_LocalImplicit(
deriv_only=false,
integrator_barrier=nothing,
)
PB.check_modeldata(model, modeldata)

length(cellranges) == 1 || error("timestep_LocalImplicit only single cellrange supported")
cellrange = cellranges[1]

@@ -180,6 +184,7 @@ function create_timestep_LocalImplicit_ctxt(
niter_max,
Lnorm_inf_max
)
PB.check_modeldata(model, modeldata)

lictxt = PALEOmodel.ODELocalIMEX.getLocalImplicitContext(
model, modeldata, cellrange, exclude_var_nameroots,
@@ -196,6 +201,7 @@ function getLocalImplicitContext(
request_adchunksize=ForwardDiff.DEFAULT_CHUNK_THRESHOLD,
init_logger=Logging.NullLogger(),
)
PB.check_modeldata(model, modeldata)

# create SolverViews for first cell, to work out how many dof we need
cellrange_cell = PB.CellRange(cellrange.domain, cellrange.operatorID, first(cellrange.indices) )
12 changes: 10 additions & 2 deletions src/ODEfixed.jl
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ function integrateEuler(
outputwriter=run.output,
report_interval=1000
)
PB.check_modeldata(run.model, modeldata)

timesteppers = [
[(
@@ -96,6 +97,8 @@ function integrateSplitEuler(

@info "integrateSplitEuler: Δt_outer=$Δt_outer (yr) n_inner=$n_inner"

PB.check_modeldata(run.model, modeldata)

solver_view_outer = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_outer)
@info "solver_view_outer: $(solver_view_outer)"
solver_view_inner = PALEOmodel.create_solver_view(run.model, modeldata, cellranges_inner)
@@ -159,11 +162,12 @@ function integrateEulerthreads(
outputwriter=run.output,
report_interval=1000,
)
PB.check_modeldata(run.model, modeldata)

nt = Threads.nthreads()
nt == 1 || modeldata.threadsafe ||
error("integrateEulerthreads: Threads.nthreads() = $nt but modeldata is not thread safe "*
"(check initialize!(run::Run, ...))")
"(check initialize!(run.model, ...))")

lc = length(cellranges)
lc == nt ||
@@ -235,7 +239,8 @@ function integrateSplitEulerthreads(
outputwriter=run.output,
report_interval=1000,
)

PB.check_modeldata(run.model, modeldata)

nt = Threads.nthreads()
nt == 1 || modeldata.threadsafe ||
error("integrateEulerthreads: Threads.nthreads() = $nt but modeldata is not thread safe (check initialize!(run::Run, ...))")
@@ -338,6 +343,7 @@ function create_timestep_Euler_ctxt(
n_substep=1,
verbose=false,
)
PB.check_modeldata(model, modeldata)

num_constraints = PALEOmodel.num_algebraic_constraints(solver_view)
iszero(num_constraints) || error("DAE problem with $num_constraints algebraic constraints")
@@ -368,6 +374,7 @@ function integrateFixed(
outputwriter=run.output,
report_interval=1000
)
PB.check_modeldata(run.model, modeldata)

nevals = 0

@@ -447,6 +454,7 @@ function integrateFixedthreads(
outputwriter=run.output,
report_interval=1000
)
PB.check_modeldata(run.model, modeldata)

nevals = 0

8 changes: 6 additions & 2 deletions src/Run.jl
Original file line number Diff line number Diff line change
@@ -26,11 +26,15 @@ function Base.show(io::IO, ::MIME"text/plain", run::Run)
end


initialize!(run::Run; kwargs...) = initialize!(run.model; kwargs...)
function initialize!(run::Run; kwargs...)
Base.depwarn("call to deprecated initialize!(run::Run; ...), please update your code to use initialize!(run.model; ...)", :initialize!, force=true)

return initialize!(run.model; kwargs...)
end

"""
initialize!(model::PB.Model; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData)
initialize!(run::Run; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData)
[deprecated] initialize!(run::Run; kwargs...) -> (initial_state::Vector, modeldata::PB.ModelData)
Initialize `model` or `run.model` and return:
- an `initial_state` Vector
5 changes: 5 additions & 0 deletions src/SteadyState.jl
Original file line number Diff line number Diff line change
@@ -53,6 +53,7 @@ function steadystate(
use_norm::Bool=false,
BLAS_num_threads=1,
)
PB.check_modeldata(run.model, modeldata)

LinearAlgebra.BLAS.set_num_threads(BLAS_num_threads)
@info "steadystate: using BLAS with $(LinearAlgebra.BLAS.get_num_threads()) threads"
@@ -189,6 +190,8 @@ function steadystate_ptc(
verbose=false,
BLAS_num_threads=1
)
PB.check_modeldata(run.model, modeldata)

!use_norm || ArgumentError("use_norm=true not supported")

nlsolveF = nlsolveF_PTC(
@@ -272,6 +275,8 @@ function nlsolveF_PTC(
request_adchunksize=10,
jac_cellranges=modeldata.cellranges_all,
)
PB.check_modeldata(model, modeldata)

sv = modeldata.solver_view_all

# We only support explicit ODE-like configurations (no DAE constraints or implicit variables)
2 changes: 2 additions & 0 deletions src/SteadyStateKinsol.jl
Original file line number Diff line number Diff line change
@@ -62,6 +62,8 @@ function steadystate_ptc(
verbose=false,
BLAS_num_threads=1
)
PB.check_modeldata(run.model, modeldata)

# start, end times
tss, tss_max = tspan

2 comments on commit c9f3936

@sjdaines
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/67060

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.15.8 -m "<description of version>" c9f3936ea5d0e74fd3c74d572a4625cc45d99e54
git push origin v0.15.8

Please sign in to comment.