Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change post to pre hooks #323

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions ext/ClimaTimeSteppersBenchmarkToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
"T_exp_T_lim!" => 4,
"lim!" => 4,
"dss!" => 4,
"post_explicit!" => 3,
"post_implicit!" => 4,
"pre_explicit!" => 3,
"pre_implicit!" => 4,
"step!" => 1,
)
function n_calls_per_step(alg::CTS.RosenbrockAlgorithm)
Expand All @@ -47,8 +47,8 @@ function n_calls_per_step(alg::CTS.RosenbrockAlgorithm)
"T_exp_T_lim!" => CTS.n_stages(alg.tableau),
"lim!" => 0,
"dss!" => CTS.n_stages(alg.tableau),
"post_explicit!" => 0,
"post_implicit!" => CTS.n_stages(alg.tableau),
"pre_explicit!" => 0,
"pre_implicit!" => CTS.n_stages(alg.tableau),
"step!" => 1,
)
end
Expand All @@ -60,7 +60,7 @@ function maybe_push!(trials₀, name, f!, args, kwargs, only)
end

const allowed_names =
["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "post_explicit!", "post_implicit!", "step!"]
["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "pre_explicit!", "pre_implicit!", "step!"]

"""
benchmark_step(
Expand Down Expand Up @@ -89,8 +89,8 @@ Benchmark a DistributedODEIntegrator given:
- "T_exp_T_lim!"
- "lim!"
- "dss!"
- "post_explicit!"
- "post_implicit!"
- "pre_explicit!"
- "pre_implicit!"
- "step!"
"""
function CTS.benchmark_step(
Expand Down Expand Up @@ -123,8 +123,8 @@ function CTS.benchmark_step(
maybe_push!(trials₀, "T_exp_T_lim!", remaining_fun(integrator), remaining_args(integrator), kwargs, only)
maybe_push!(trials₀, "lim!", f.lim!, (Xlim, p, t, u), kwargs, only)
maybe_push!(trials₀, "dss!", f.dss!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "post_explicit!", f.post_explicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "post_implicit!", f.post_implicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "pre_explicit!", f.pre_explicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "pre_implicit!", f.pre_implicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "step!", SciMLBase.step!, (integrator, ), kwargs, only)
#! format: on

Expand Down
10 changes: 5 additions & 5 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFuncti
T_imp!::TI
lim!::L
dss!::D
post_explicit!::PE
post_implicit!::PI
pre_explicit!::PE
pre_implicit!::PI
function ClimaODEFunction(;
T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ...
T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ...
lim! = (u, p, t, u_ref) -> nothing,
dss! = (u, p, t) -> nothing,
post_explicit! = (u, p, t) -> nothing,
post_implicit! = (u, p, t) -> nothing,
pre_explicit! = (u, p, t) -> nothing,
pre_implicit! = (u, p, t) -> nothing,
)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, pre_explicit!, pre_implicit!)

if !isnothing(T_exp_T_lim!)
@assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`"
Expand Down
4 changes: 2 additions & 2 deletions src/integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ function DiffEqBase.__init(
tdir,
)
if prob.f isa ClimaODEFunction
(; post_explicit!) = prob.f
isnothing(post_explicit!) || post_explicit!(u0, p, t0)
(; pre_explicit!) = prob.f
isnothing(pre_explicit!) || pre_explicit!(u0, p, t0)
end
DiffEqBase.initialize!(callback, u0, t0, integrator)
return integrator
Expand Down
51 changes: 18 additions & 33 deletions src/nl_solvers/newtons_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end
Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov
method without directly using the Jacobian `j(x[n])`, and instead only using
`x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`.
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, pre_implicit!)`.
The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can
be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where
`x_prototype` is `similar` to `x` (and also to `Δx` and `f`).
Expand All @@ -151,13 +151,13 @@ end

allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = similar(x_prototype), f2 = similar(x_prototype))

function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)
function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, pre_implicit!)
(; default_step, step_adjustment) = alg
(; x2, f2) = cache
FT = eltype(x)
ε = FT(step_adjustment) * default_step(Δx, x)
@. x2 = x + ε * Δx
isnothing(post_implicit!) || post_implicit!(x2)
isnothing(pre_implicit!) || pre_implicit!(x2)
f!(f2, x2)
@. jΔx = (f2 - f) / ε
end
Expand Down Expand Up @@ -343,7 +343,7 @@ end
Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such
that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the
value of the forcing term on iteration `n`. This is done by calling
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`,
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, pre_implicit!, j = nothing)`,
where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an
approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place.
The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
Expand Down Expand Up @@ -428,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype)
)
end

NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)
NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, pre_implicit!, j = nothing)
(; jacobian_free_jvp, forcing_term, solve_kwargs) = alg
(; disable_preconditioner, debugger) = alg
type = solver_type(alg)
(; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache
jΔx!(jΔx, Δx) =
isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) :
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!)
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, pre_implicit!)
opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!)
M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j
print_debug!(debugger, debugger_cache, opj, M)
Expand Down Expand Up @@ -567,32 +567,22 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing)
)
end

solve_newton!(
alg::NewtonsMethod,
cache::Nothing,
x,
f!,
j! = nothing,
post_implicit! = nothing,
post_implicit_last! = nothing,
) = nothing

NVTX.@annotate function solve_newton!(
alg::NewtonsMethod,
cache,
x,
f!,
j! = nothing,
post_implicit! = nothing,
post_implicit_last! = nothing,
)
solve_newton!(alg::NewtonsMethod, cache::Nothing, x, f!, j! = nothing, pre_implicit! = nothing) = nothing

NVTX.@annotate function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, pre_implicit! = nothing)
(; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg
(; krylov_method_cache, convergence_checker_cache) = cache
(; Δx, f, j) = cache
if (!isnothing(j)) && needs_update!(update_j, NewNewtonSolve())
j!(j, x)
if !isnothing(pre_implicit!) && !isempty(1:max_iters)
pre_implicit!(x)
if (!isnothing(j)) && needs_update!(update_j, NewNewtonSolve())
j!(j, x)
end
end
for n in 1:max_iters
if !isnothing(pre_implicit!)
n 1 && pre_implicit!(x)
end
# Compute Δx[n].
if (!isnothing(j)) && needs_update!(update_j, NewNewtonIteration())
j!(j, x)
Expand All @@ -605,20 +595,15 @@ NVTX.@annotate function solve_newton!(
ldiv!(Δx, j, f)
end
else
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j)
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, pre_implicit!, j)
end
is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"

x .-= Δx
# Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed.
# Check for convergence if necessary.
if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n)
isnothing(post_implicit_last!) || post_implicit_last!(x)
break
elseif n == max_iters
isnothing(post_implicit_last!) || post_implicit_last!(x)
else
isnothing(post_implicit!) || post_implicit!(x)
end
if is_verbose(verbose) && n == max_iters
@warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
Expand Down
28 changes: 14 additions & 14 deletions src/solvers/hard_coded_ars343.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
(; u, p, t, dt, sol, alg) = integrator
(; f) = sol.prob
(; T_imp!, lim!, dss!) = f
(; post_explicit!, post_implicit!) = f
(; pre_explicit!, pre_implicit!) = f
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
(; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache
Expand Down Expand Up @@ -34,7 +34,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
lim!(U, p, t_exp, u)
@. U += dt * a_exp[i, 1] * T_exp[1]
dss!(U, p, t_exp)
post_explicit!(U, p, t_exp)

@. temp = U # used in closures
let i = i
Expand All @@ -46,21 +45,22 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> begin
post_implicit!(Ui, p, t_imp)
call_pre_implicit! = Ui -> begin
pre_implicit!(Ui, p, t_imp)
end
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_pre_implicit!,
)
end

@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])

pre_explicit!(U, p, t_exp)
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)

Expand All @@ -70,7 +70,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
lim!(U, p, t_exp, u)
@. U += dt * a_exp[i, 1] * T_exp[1] + dt * a_exp[i, 2] * T_exp[2] + dt * a_imp[i, 2] * T_imp[2]
dss!(U, p, t_exp)
post_explicit!(U, p, t_exp)

@. temp = U # used in closures
let i = i
Expand All @@ -82,21 +81,22 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> begin
post_implicit!(Ui, p, t_imp)
call_pre_implicit! = Ui -> begin
pre_implicit!(Ui, p, t_imp)
end
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_pre_implicit!,
)
end

@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])

pre_explicit!(U, p, t_exp)
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)
i = 4
Expand All @@ -110,7 +110,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
dt * a_imp[i, 2] * T_imp[2] +
dt * a_imp[i, 3] * T_imp[3]
dss!(U, p, t_exp)
post_explicit!(U, p, t_exp)

@. temp = U # used in closures
let i = i
Expand All @@ -122,21 +121,22 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> begin
post_implicit!(Ui, p, t_imp)
call_pre_implicit! = Ui -> begin
pre_implicit!(Ui, p, t_imp)
end
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_pre_implicit!,
)
end

@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])

pre_explicit!(U, p, t_exp)
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)

Expand All @@ -155,6 +155,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
dt * b_imp[3] * T_imp[3] +
dt * b_imp[4] * T_imp[4]
dss!(u, p, t_final)
post_explicit!(u, p, t_final)
pre_explicit!(U, p, t_final)
return u
end
Loading
Loading