From a688a85446b5669b615c78dabe9499e1c3e8983c Mon Sep 17 00:00:00 2001
From: Charles Kawczynski <kawczynski.charles@gmail.com>
Date: Tue, 3 Oct 2023 17:05:00 -0700
Subject: [PATCH 1/2] callback fixes

---
 src/nl_solvers/newtons_method.jl | 53 ++++++++++++++++++++++----------
 src/solvers/imex_ark.jl          | 40 +++++++++++++++---------
 src/solvers/imex_ssprk.jl        | 35 ++++++++++++++-------
 3 files changed, 86 insertions(+), 42 deletions(-)

diff --git a/src/nl_solvers/newtons_method.jl b/src/nl_solvers/newtons_method.jl
index d60cc443..89f354eb 100644
--- a/src/nl_solvers/newtons_method.jl
+++ b/src/nl_solvers/newtons_method.jl
@@ -130,10 +130,10 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end
 Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov
 method without directly using the Jacobian `j(x[n])`, and instead only using
 `x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by
-calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f)`. The `jΔx` passed to
-a Jacobian-free JVP is modified in-place. The `cache` can be obtained with
-`allocate_cache(::JacobianFreeJVP, x_prototype)`, where `x_prototype` is
-`similar` to `x` (and also to `Δx` and `f`).
+calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`.
+The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can
+be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where
+`x_prototype` is `similar` to `x` (and also to `Δx` and `f`).
 """
 abstract type JacobianFreeJVP end
 
@@ -151,12 +151,13 @@ end
 
 allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = similar(x_prototype), f2 = similar(x_prototype))
 
-function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f)
+function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)
     (; default_step, step_adjustment) = alg
     (; x2, f2) = cache
     FT = eltype(x)
     ε = FT(step_adjustment) * default_step(Δx, x)
     @. x2 = x + ε * Δx
+    isnothing(post_implicit!) || post_implicit!(x2)
     f!(f2, x2)
     @. jΔx = (f2 - f) / ε
 end
@@ -342,10 +343,10 @@ end
 Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such
 that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the
 value of the forcing term on iteration `n`. This is done by calling
-`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)`, where `f` is
-`f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an approximation
-of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place. The
-`cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
+`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`,
+where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an
+approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place.
+The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
 where `x_prototype` is `similar` to `x` (and also to `Δx` and `f`).
 
 This is primarily a wrapper for a `Krylov.KrylovSolver` from `Krylov.jl`. In
@@ -427,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype)
     )
 end
 
-function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)
+function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)
     (; jacobian_free_jvp, forcing_term, solve_kwargs) = alg
     (; disable_preconditioner, debugger) = alg
     type = solver_type(alg)
     (; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache
     jΔx!(jΔx, Δx) =
         isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) :
-        jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f)
+        jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!)
     opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!)
     M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j
     print_debug!(debugger, debugger_cache, opj, M)
@@ -566,9 +567,25 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing)
     )
 end
 
-solve_newton!(alg::NewtonsMethod, cache::Nothing, x, f!, j! = nothing, post_implicit! = nothing) = nothing
-
-function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, post_implicit! = nothing)
+solve_newton!(
+    alg::NewtonsMethod,
+    cache::Nothing,
+    x,
+    f!,
+    j! = nothing,
+    post_implicit! = nothing,
+    post_implicit_last! = nothing,
+) = nothing
+
+function solve_newton!(
+    alg::NewtonsMethod,
+    cache,
+    x,
+    f!,
+    j! = nothing,
+    post_implicit! = nothing,
+    post_implicit_last! = nothing,
+)
     (; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg
     (; krylov_method_cache, convergence_checker_cache) = cache
     (; Δx, f, j) = cache
@@ -588,16 +605,20 @@ function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, post_impl
                 ldiv!(Δx, j, f)
             end
         else
-            solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, j)
+            solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j)
         end
         is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
 
         x .-= Δx
-        isnothing(post_implicit!) || post_implicit!(x)
         # Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed.
         # Check for convergence if necessary.
         if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n)
+            isnothing(post_implicit_last!) || post_implicit_last!(x)
             break
+        elseif n == max_iters
+            isnothing(post_implicit_last!) || post_implicit_last!(x)
+        else
+            isnothing(post_implicit!) || post_implicit!(x)
         end
         if is_verbose(verbose) && n == max_iters
             @warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl
index d0da9264..5250f2cb 100644
--- a/src/solvers/imex_ark.jl
+++ b/src/solvers/imex_ark.jl
@@ -47,10 +47,11 @@ end
 
 step_u!(integrator, cache::IMEXARKCache) = step_u!(integrator, cache, integrator.sol.prob.f, integrator.alg.name)
 
-include("hard_coded_ars343.jl")
+# include("hard_coded_ars343.jl")
 # generic fallback
 function step_u!(integrator, cache::IMEXARKCache, f, name)
     (; u, p, t, dt, alg) = integrator
+    (; post_explicit!, post_implicit!) = f
     (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
     (; tableau, newtons_method) = alg
     (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
@@ -114,11 +115,14 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
                 dss!(U, p, t_exp)
             end
 
-            if !isnothing(T_imp!) && !iszero(a_imp[i, i]) # Implicit solve
+            if !(!isnothing(T_imp!) && !iszero(a_imp[i, i])) # Implicit solve
+                post_explicit!(U, p, t_imp)
+            else
                 @assert !isnothing(newtons_method)
                 NVTX.@range "temp = U" color = colorant"yellow" begin
                     @. temp = U
                 end
+                post_explicit!(U, p, t_imp)
                 # TODO: can/should we remove these closures?
                 implicit_equation_residual! =
                     (residual, Ui) -> begin
@@ -130,6 +134,18 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
                         end
                     end
                 implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
+                call_post_implicit! = Ui -> begin
+                    post_implicit!(Ui, p, t_imp)
+                end
+                call_post_implicit_last! =
+                    Ui -> begin
+                        if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !iszero(a_imp[i, i])
+                            # If T_imp[i] is being treated implicitly, ensure that it
+                            # exactly satisfies the implicit equation.
+                            @. T_imp[i] = (Ui - temp) / (dt * a_imp[i, i])
+                        end
+                        post_implicit!(Ui, p, t_imp)
+                    end
 
                 NVTX.@range "solve_newton!" color = colorant"yellow" begin
                     solve_newton!(
@@ -138,6 +154,8 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
                         U,
                         implicit_equation_residual!,
                         implicit_equation_jacobian!,
+                        call_post_implicit!,
+                        call_post_implicit_last!,
                     )
                 end
             end
@@ -147,19 +165,11 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
             # tendency only acts in the vertical direction).
 
             if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
-                if !isnothing(T_imp!)
-                    if iszero(a_imp[i, i])
-                        # If its coefficient is 0, T_imp[i] is effectively being
-                        # treated explicitly.
-                        NVTX.@range "T_imp!" color = colorant"yellow" begin
-                            T_imp!(T_imp[i], U, p, t_imp)
-                        end
-                    else
-                        # If T_imp[i] is being treated implicitly, ensure that it
-                        # exactly satisfies the implicit equation.
-                        NVTX.@range "T_imp=(U-temp)/(dt*a_imp)" color = colorant"yellow" begin
-                            @. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
-                        end
+                if iszero(a_imp[i, i]) && !isnothing(T_imp!)
+                    # If its coefficient is 0, T_imp[i] is effectively being
+                    # treated explicitly.
+                    NVTX.@range "T_imp!" color = colorant"yellow" begin
+                        T_imp!(T_imp[i], U, p, t_imp)
                     end
                 end
             end
diff --git a/src/solvers/imex_ssprk.jl b/src/solvers/imex_ssprk.jl
index 3a473416..427ac1bc 100644
--- a/src/solvers/imex_ssprk.jl
+++ b/src/solvers/imex_ssprk.jl
@@ -56,6 +56,7 @@ step_u!(integrator, cache::IMEXSSPRKCache) = step_u!(integrator, cache, integrat
 
 function step_u!(integrator, cache::IMEXSSPRKCache, f, name)
     (; u, p, t, dt, alg) = integrator
+    (; post_explicit!, post_implicit!) = f
     (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
     (; tableau, newtons_method) = alg
     (; a_imp, b_imp, c_exp, c_imp) = tableau
@@ -104,21 +105,39 @@ function step_u!(integrator, cache::IMEXSSPRKCache, f, name)
             end
         end
 
-        if !isnothing(T_imp!) && !iszero(a_imp[i, i]) # Implicit solve
+        if !(!isnothing(T_imp!) && !iszero(a_imp[i, i])) # Implicit solve
+            post_explicit!(U, p, t_imp)
+        else
             @assert !isnothing(newtons_method)
             @. temp = U
+            post_explicit!(U, p, t_imp)
             # TODO: can/should we remove these closures?
             implicit_equation_residual! = (residual, Ui) -> begin
                 T_imp!(residual, Ui, p, t_imp)
                 @. residual = temp + dt * a_imp[i, i] * residual - Ui
             end
             implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
+            call_post_implicit! = Ui -> begin
+                post_implicit!(Ui, p, t_imp)
+            end
+            call_post_implicit_last! =
+                Ui -> begin
+                    if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !iszero(a_imp[i, i])
+                        # If T_imp[i] is being treated implicitly, ensure that it
+                        # exactly satisfies the implicit equation.
+                        @. T_imp[i] = (Ui - temp) / (dt * a_imp[i, i])
+                    end
+                    post_implicit!(Ui, p, t_imp)
+                end
+
             solve_newton!(
                 newtons_method,
                 newtons_method_cache,
                 U,
                 implicit_equation_residual!,
                 implicit_equation_jacobian!,
+                call_post_implicit!,
+                call_post_implicit_last!,
             )
         end
 
@@ -127,16 +146,10 @@ function step_u!(integrator, cache::IMEXSSPRKCache, f, name)
         # tendency only acts in the vertical direction).
 
         if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
-            if !isnothing(T_imp!)
-                if iszero(a_imp[i, i])
-                    # If its coefficient is 0, T_imp[i] is effectively being
-                    # treated explicitly.
-                    T_imp!(T_imp[i], U, p, t_imp)
-                else
-                    # If T_imp[i] is being treated implicitly, ensure that it
-                    # exactly satisfies the implicit equation.
-                    @. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
-                end
+            if iszero(a_imp[i, i]) && !isnothing(T_imp!)
+                # If its coefficient is 0, T_imp[i] is effectively being
+                # treated explicitly.
+                T_imp!(T_imp[i], U, p, t_imp)
             end
         end
 

From f497ddee29aeb7af86a49d71f9e628fbbc707200 Mon Sep 17 00:00:00 2001
From: Charles Kawczynski <kawczynski.charles@gmail.com>
Date: Tue, 3 Oct 2023 21:00:47 -0700
Subject: [PATCH 2/2] Bump patch version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index e0414979..07396b8c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ClimaTimeSteppers"
 uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
 authors = ["Climate Modeling Alliance"]
-version = "0.7.10"
+version = "0.7.11"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"