diff --git a/Project.toml b/Project.toml index 667922e4..41e81c76 100644 --- a/Project.toml +++ b/Project.toml @@ -52,6 +52,7 @@ Logging = "1.10" NonlinearSolve = "3.8.1" ODEInterface = "0.5" OrdinaryDiffEq = "6.89.0" +Pkg = "1.10.0" PreallocationTools = "0.4.24" PrecompileTools = "1.2" Preferences = "1.4" @@ -74,6 +75,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -81,4 +83,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "DiffEqDevTools", "JET", "LinearSolve", "ODEInterface", "OrdinaryDiffEq", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"] +test = ["Aqua", "DiffEqDevTools", "JET", "LinearSolve", "ODEInterface", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"] diff --git a/lib/BoundaryValueDiffEqMIRK/LICENSE.md b/lib/BoundaryValueDiffEqMIRK/LICENSE.md new file mode 100644 index 00000000..b9c908a8 --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/LICENSE.md @@ -0,0 +1,21 @@ +The BoundaryValueDiffEq.jl package is licensed under the MIT "Expat" License: + +> Copyright (c) 2017: ChrisRackauckas. +> +> Permission is hereby granted, free of charge, to any person obtaining a copy +> of this software and associated documentation files (the "Software"), to deal +> in the Software without restriction, including without limitation the rights +> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +> copies of the Software, and to permit persons to whom the Software is +> furnished to do so, subject to the following conditions: +> +> The above copyright notice and this permission notice shall be included in all +> copies or substantial portions of the Software. +> +> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +> SOFTWARE. diff --git a/lib/BoundaryValueDiffEqMIRK/Project.toml b/lib/BoundaryValueDiffEqMIRK/Project.toml new file mode 100644 index 00000000..49c30736 --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/Project.toml @@ -0,0 +1,77 @@ +name = "BoundaryValueDiffEqMIRK" +uuid = "1a22d4ce-7765-49ea-b6f2-13c8438986a6" +version = "0.1.0" + +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +BandedMatrices = "aae01518-5342-5314-be14-df237901396f" +BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e" +FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" + +[compat] +ADTypes = "1.2" +Adapt = "4" +Aqua = "0.8.7" +ArrayInterface = "7.7" +BandedMatrices = "1.4" +BoundaryValueDiffEq = "5.10" +ConcreteStructs = "0.2.3" +DiffEqBase = "6.146" +DiffEqDevTools = "2.44" +FastAlmostBandedMatrices = "0.1.1" +FastClosures = "0.3" +ForwardDiff = "0.10.36" +JET = "0.8" +LinearAlgebra = "1.10" +LinearSolve = "2.21" +Logging = "1.10" +NonlinearSolve = "3.8.1" +OrdinaryDiffEq = "6.89.0" +PreallocationTools = "0.4.24" +PrecompileTools = "1.2" +Preferences = "1.4" +Random = "1.10" +ReTestItems = "1.23.1" +RecursiveArrayTools = "3.27.0" +Reexport = "1.2" +SciMLBase = "2.40" +Setfield = "1" +SparseArrays = "1.10" +SparseDiffTools = "2.14" +StaticArrays = "1.8.1" +Test = "1.10" +julia = "1.10" + +[extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Aqua", "DiffEqDevTools", "JET", "LinearSolve", "OrdinaryDiffEq", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"] diff --git a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl new file mode 100644 index 00000000..eae9772b --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl @@ -0,0 +1,160 @@ +module BoundaryValueDiffEqMIRK + +import PrecompileTools: @compile_workload, @setup_workload + +using ADTypes, Adapt, ArrayInterface, DiffEqBase, ForwardDiff, LinearAlgebra, + NonlinearSolve, Preferences, RecursiveArrayTools, Reexport, SciMLBase, Setfield, + SparseDiffTools + +using PreallocationTools: PreallocationTools, DiffCache + +# Special Matrix Types +using BandedMatrices, FastAlmostBandedMatrices, SparseArrays + +import BoundaryValueDiffEq: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorithm, + recursive_flatten, recursive_flatten!, recursive_unflatten!, + __concrete_nonlinearsolve_algorithm, diff!, + __FastShortcutBVPCompatibleNonlinearPolyalg, + __FastShortcutBVPCompatibleNLLSPolyalg, + concrete_jacobian_algorithm, eval_bc_residual, + eval_bc_residual!, get_tmp, __maybe_matmul!, __append_similar!, + __extract_problem_details, __initial_guess, + __maybe_allocate_diffcache, __get_bcresid_prototype, __similar, + __vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!, + recursive_flatten_twopoint!, __unsafe_nonlinearfunction, + __internal_nlsolve_problem, + __generate_sparse_jacobian_prototype, __extract_mesh, + __extract_u0, __has_initial_guess, __initial_guess_length, + __initial_guess_on_mesh, __flatten_initial_guess, + __build_solution, __Fix3, __sparse_jacobian_cache, + __sparsity_detection_alg, _sparse_like, ColoredMatrix + +import ADTypes: AbstractADType +import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing +import ConcreteStructs: @concrete +import DiffEqBase: solve +import FastClosures: @closure +import ForwardDiff: ForwardDiff, pickchunksize +import Logging +import RecursiveArrayTools: ArrayPartition, DiffEqArray +import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val + +@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase + +include("types.jl") +include("algorithms.jl") +include("mirk.jl") +include("adaptivity.jl") +include("alg_utils.jl") +include("collocation.jl") +include("interpolation.jl") +include("mirk_tableaus.jl") +include("sparse_jacobians.jl") + +@setup_workload begin + function f1!(du, u, p, t) + du[1] = u[2] + du[2] = 0 + end + f1 = (u, p, t) -> [u[2], 0] + + function bc1!(residual, u, p, t) + residual[1] = u[:, 1][1] - 5 + residual[2] = u[:, end][1] + end + + bc1 = (u, p, t) -> [u[:, 1][1] - 5, u[:, end][1]] + + bc1_a! = (residual, ua, p) -> (residual[1] = ua[1] - 5) + bc1_b! = (residual, ub, p) -> (residual[1] = ub[1]) + + bc1_a = (ua, p) -> [ua[1] - 5] + bc1_b = (ub, p) -> [ub[1]] + + tspan = (0.0, 5.0) + u0 = [5.0, -3.5] + bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1)) + + probs = [BVProblem(f1!, bc1!, u0, tspan; nlls = Val(false)), + BVProblem(f1, bc1, u0, tspan; nlls = Val(false)), + TwoPointBVProblem( + f1!, (bc1_a!, bc1_b!), u0, tspan; bcresid_prototype, nlls = Val(false)), + TwoPointBVProblem( + f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype, nlls = Val(false))] + + algs = [] + + jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)) + + if Preferences.@load_preference("PrecompileMIRK", true) + append!(algs, [MIRK2(; jac_alg), MIRK4(; jac_alg), MIRK6(; jac_alg)]) + end + + @compile_workload begin + @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2) + end + end + + f1_nlls! = (du, u, p, t) -> begin + du[1] = u[2] + du[2] = -u[1] + end + + f1_nlls = (u, p, t) -> [u[2], -u[1]] + + bc1_nlls! = (resid, sol, p, t) -> begin + solₜ₁ = sol[:, 1] + solₜ₂ = sol[:, end] + resid[1] = solₜ₁[1] + resid[2] = solₜ₂[1] - 1 + resid[3] = solₜ₂[2] + 1.729109 + return nothing + end + bc1_nlls = (sol, p, t) -> [sol[:, 1][1], sol[:, end][1] - 1, sol[:, end][2] + 1.729109] + + bc1_nlls_a! = (resid, ua, p) -> (resid[1] = ua[1]) + bc1_nlls_b! = (resid, ub, p) -> (resid[1] = ub[1] - 1; + resid[2] = ub[2] + 1.729109) + + bc1_nlls_a = (ua, p) -> [ua[1]] + bc1_nlls_b = (ub, p) -> [ub[1] - 1, ub[2] + 1.729109] + + tspan = (0.0, 100.0) + u0 = [0.0, 1.0] + bcresid_prototype1 = Array{Float64}(undef, 3) + bcresid_prototype2 = (Array{Float64}(undef, 1), Array{Float64}(undef, 2)) + + probs = [ + BVProblem(BVPFunction(f1_nlls!, bc1_nlls!; bcresid_prototype = bcresid_prototype1), + u0, tspan, nlls = Val(true)), + BVProblem(BVPFunction(f1_nlls, bc1_nlls; bcresid_prototype = bcresid_prototype1), + u0, tspan, nlls = Val(true)), + TwoPointBVProblem(f1_nlls!, (bc1_nlls_a!, bc1_nlls_b!), u0, tspan; + bcresid_prototype = bcresid_prototype2, nlls = Val(true)), + TwoPointBVProblem(f1_nlls, (bc1_nlls_a, bc1_nlls_b), u0, tspan; + bcresid_prototype = bcresid_prototype2, nlls = Val(true))] + + jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)) + + nlsolvers = [LevenbergMarquardt(; disable_geodesic = Val(true)), GaussNewton()] + + algs = [] + + if Preferences.@load_preference("PrecompileMIRKNLLS", false) + for nlsolve in nlsolvers + append!(algs, [MIRK2(; jac_alg, nlsolve), MIRK6(; jac_alg, nlsolve)]) + end + end + + @compile_workload begin + @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2) + end + end +end + +export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6 +export BVPJacobianAlgorithm + +end diff --git a/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl b/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl new file mode 100644 index 00000000..39a266c2 --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl @@ -0,0 +1,379 @@ +""" + interp_eval!(y::AbstractArray, cache::MIRKCache, t) + +After we construct an interpolant, we use interp_eval to evaluate it. +""" +@views function interp_eval!(y::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt) + i = interval(mesh, t) + dt = mesh_dt[i] + τ = (t - mesh[i]) / dt + w, w′ = interp_weights(τ, cache.alg) + sum_stages!(y, cache, w, i) + return y +end + +""" + interval(mesh, t) + +Find the interval that `t` belongs to in `mesh`. Assumes that `mesh` is sorted. +""" +function interval(mesh, t) + return clamp(searchsortedfirst(mesh, t) - 1, 1, length(mesh) - 1) +end + +""" + mesh_selector!(cache::MIRKCache) + +Generate new mesh based on the defect. +""" +@views function mesh_selector!(cache::MIRKCache{iip, T}) where {iip, T} + (; order, defect, mesh, mesh_dt) = cache + (abstol, _, _), kwargs = __split_mirk_kwargs(; cache.kwargs...) + N = length(mesh) + + safety_factor = T(1.3) + ρ = T(1.0) # Set rho=1 means mesh distribution will take place everytime. + Nsub_star = 0 + Nsub_star_ub = 4 * (N - 1) + Nsub_star_lb = N ÷ 2 + + info = ReturnCode.Success + + ŝ = [maximum(abs, d) for d in defect] # Broadcasting breaks GPU Compilation + ŝ .= (ŝ ./ abstol) .^ (T(1) / (order + 1)) + r₁ = maximum(ŝ) + r₂ = sum(ŝ) + r₃ = r₂ / (N - 1) + + n_predict = round(Int, (safety_factor * r₂) + 1) + n = N - 1 + n_ = T(0.1) * n + n_predict = ifelse(abs((n_predict - n)) < n_, round(Int, n + n_), n_predict) + + if r₁ ≤ ρ * r₂ + Nsub_star = 2 * (N - 1) + if Nsub_star > cache.alg.max_num_subintervals # Need to determine the too large threshold + info = ReturnCode.Failure + meshₒ = mesh + mesh_dt₀ = mesh_dt + else + meshₒ = copy(mesh) + mesh_dt₀ = copy(mesh_dt) + half_mesh!(cache) + end + else + Nsub_star = clamp(n_predict, Nsub_star_lb, Nsub_star_ub) + if Nsub_star > cache.alg.max_num_subintervals + # Mesh redistribution fails + info = ReturnCode.Failure + meshₒ = mesh + mesh_dt₀ = mesh_dt + else + ŝ ./= mesh_dt + meshₒ = copy(mesh) + mesh_dt₀ = copy(mesh_dt) + redistribute!(cache, Nsub_star, ŝ, meshₒ, mesh_dt₀) + end + end + return meshₒ, mesh_dt₀, Nsub_star, info +end + +""" + redistribute!(cache::MIRKCache, Nsub_star, ŝ, mesh, mesh_dt) + +Generate a new mesh based on the `ŝ`. +""" +function redistribute!( + cache::MIRKCache{iip, T}, Nsub_star, ŝ, mesh, mesh_dt) where {iip, T} + N = length(mesh) + ζ = sum(ŝ .* mesh_dt) / Nsub_star + k, i = 1, 0 + append!(cache.mesh, Nsub_star + 1 - N) + cache.mesh[1] = mesh[1] + t = mesh[1] + integral = T(0) + while k ≤ N - 1 + next_piece = ŝ[k] * (mesh[k + 1] - t) + _int_next = integral + next_piece + if _int_next > ζ + cache.mesh[i + 2] = (ζ - integral) / ŝ[k] + t + t = cache.mesh[i + 2] + i += 1 + integral = T(0) + else + integral = _int_next + t = mesh[k + 1] + k += 1 + end + end + cache.mesh[end] = mesh[end] + append!(cache.mesh_dt, Nsub_star - N) + diff!(cache.mesh_dt, cache.mesh) + return cache +end + +""" + half_mesh!(mesh, mesh_dt) + half_mesh!(cache::MIRKCache) + +The input mesh has length of `n + 1`. Divide the original subinterval into two equal length +subinterval. The `mesh` and `mesh_dt` are modified in place. +""" +function half_mesh!(mesh::Vector{T}, mesh_dt::Vector{T}) where {T} + n = length(mesh) - 1 + resize!(mesh, 2n + 1) + resize!(mesh_dt, 2n) + mesh[2n + 1] = mesh[n + 1] + for i in (2n - 1):-2:1 + mesh[i] = mesh[(i + 1) ÷ 2] + mesh_dt[i + 1] = mesh_dt[(i + 1) ÷ 2] / T(2) + end + @simd for i in (2n):-2:2 + mesh[i] = (mesh[i + 1] + mesh[i - 1]) / T(2) + mesh_dt[i - 1] = mesh_dt[i] + end + return mesh, mesh_dt +end +function half_mesh!(cache::MIRKCache) + half_mesh!(cache.mesh, cache.mesh_dt) +end + +""" + defect_estimate!(cache::MIRKCache) + +defect_estimate use the discrete solution approximation Y, plus stages of +the RK method in 'k_discrete', plus some new stages in 'k_interp' to construct +an interpolant +""" +@views function defect_estimate!(cache::MIRKCache{iip, T}) where {iip, T} + (; f, alg, mesh, mesh_dt, defect) = cache + (; τ_star) = cache.ITU + + # Evaluate at the first sample point + w₁, w₁′ = interp_weights(τ_star, alg) + # Evaluate at the second sample point + w₂, w₂′ = interp_weights(T(1) - τ_star, alg) + + interp_setup!(cache) + + for i in 1:(length(mesh) - 1) + dt = mesh_dt[i] + + z, z′ = sum_stages!(cache, w₁, w₁′, i) + if iip + yᵢ₁ = cache.y[i].du + f(yᵢ₁, z, cache.p, mesh[i] + τ_star * dt) + else + yᵢ₁ = f(z, cache.p, mesh[i] + τ_star * dt) + end + yᵢ₁ .= (z′ .- yᵢ₁) ./ (abs.(yᵢ₁) .+ T(1)) + est₁ = maximum(abs, yᵢ₁) + + z, z′ = sum_stages!(cache, w₂, w₂′, i) + if iip + yᵢ₂ = cache.y[i + 1].du + f(yᵢ₂, z, cache.p, mesh[i] + (T(1) - τ_star) * dt) + else + yᵢ₂ = f(z, cache.p, mesh[i] + (T(1) - τ_star) * dt) + end + yᵢ₂ .= (z′ .- yᵢ₂) ./ (abs.(yᵢ₂) .+ T(1)) + est₂ = maximum(abs, yᵢ₂) + + defect.u[i] .= est₁ > est₂ ? yᵢ₁ : yᵢ₂ + end + + return maximum(Base.Fix1(maximum, abs), defect.u) +end + +""" + interp_setup!(cache::MIRKCache) + +`interp_setup!` prepare the extra stages in ki_interp for interpolant construction. +Here, the ki_interp is the stages in one subinterval. +""" +@views function interp_setup!(cache::MIRKCache{iip, T}) where {iip, T} + (; x_star, s_star, c_star, v_star) = cache.ITU + (; k_interp, k_discrete, f, stage, new_stages, y, p, mesh, mesh_dt) = cache + + for r in 1:(s_star - stage) + idx₁ = ((1:stage) .- 1) .* (s_star - stage) .+ r + idx₂ = ((1:(r - 1)) .+ stage .- 1) .* (s_star - stage) .+ r + for j in eachindex(k_discrete) + __maybe_matmul!(new_stages.u[j], k_discrete[j].du[:, 1:stage], x_star[idx₁]) + end + if r > 1 + for j in eachindex(k_interp) + __maybe_matmul!( + new_stages.u[j], k_interp.u[j][:, 1:(r - 1)], x_star[idx₂], T(1), T(1)) + end + end + for i in eachindex(new_stages) + new_stages.u[i] .= new_stages.u[i] .* mesh_dt[i] .+ + (1 - v_star[r]) .* vec(y[i].du) .+ + v_star[r] .* vec(y[i + 1].du) + if iip + f(k_interp.u[i][:, r], new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) + else + k_interp.u[i][:, r] .= f( + new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) + end + end + end + + return k_interp +end + +""" + sum_stages!(cache::MIRKCache, w, w′, i::Int) + +sum_stages add the discrete solution, RK method stages and extra stages to construct interpolant. +""" +function sum_stages!(cache::MIRKCache, w, w′, i::Int, dt = cache.mesh_dt[i]) + sum_stages!(cache.fᵢ_cache.du, cache.fᵢ₂_cache, cache, w, w′, i, dt) +end + +function sum_stages!(z::AbstractArray, cache::MIRKCache, w, i::Int, dt = cache.mesh_dt[i]) + (; stage, k_discrete, k_interp) = cache + (; s_star) = cache.ITU + + z .= zero(z) + __maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage]) + __maybe_matmul!( + z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true) + z .= z .* dt .+ cache.y₀.u[i] + + return z +end + +@views function sum_stages!(z, z′, cache::MIRKCache, w, w′, i::Int, dt = cache.mesh_dt[i]) + (; stage, k_discrete, k_interp) = cache + (; s_star) = cache.ITU + + z .= zero(z) + __maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage]) + __maybe_matmul!( + z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true) + z′ .= zero(z′) + __maybe_matmul!(z′, k_discrete[i].du[:, 1:stage], w′[1:stage]) + __maybe_matmul!( + z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true) + z .= z .* dt[1] .+ cache.y₀.u[i] + + return z, z′ +end + +""" + interp_weights(τ, alg) + +interp_weights: solver-specified interpolation weights and its first derivative +""" +function interp_weights end + +for order in (2, 3, 4, 5, 6) + alg = Symbol("MIRK$(order)") + @eval begin + function interp_weights(τ::T, ::$(alg)) where {T} + if $(order == 2) + w = [0, τ * (1 - τ / 2), τ^2 / 2] + + # Derivative polynomials. + + wp = [0, 1 - τ, τ] + elseif $(order == 3) + w = [τ / 4.0 * (2.0 * τ^2 - 5.0 * τ + 4.0), + -3.0 / 4.0 * τ^2 * (2.0 * τ - 3.0), τ^2 * (τ - 1.0)] + + # Derivative polynomials. + + wp = [3.0 / 2.0 * (τ - 2.0 / 3.0) * (τ - 1.0), + -9.0 / 2.0 * τ * (τ - 1.0), 3.0 * τ * (τ - 2.0 / 3.0)] + elseif $(order == 4) + t2 = τ * τ + tm1 = τ - 1.0 + t4m3 = τ * 4.0 - 3.0 + t2m1 = τ * 2.0 - 1.0 + + w = [-τ * (2.0 * τ - 3.0) * (2.0 * t2 - 3.0 * τ + 2.0) / 6.0, + t2 * (12.0 * t2 - 20.0 * τ + 9.0) / 6.0, + 2.0 * t2 * (6.0 * t2 - 14.0 * τ + 9.0) / 3.0, + -16.0 * t2 * tm1 * tm1 / 3.0] + + # Derivative polynomials + + wp = [-tm1 * t4m3 * t2m1 / 3.0, τ * t2m1 * t4m3, + 4.0 * τ * t4m3 * tm1, -32.0 * τ * t2m1 * tm1 / 3.0] + elseif $(order == 5) + w = [ + τ * (22464.0 - 83910.0 * τ + 143041.0 * τ^2 - 113808.0 * τ^3 + + 33256.0 * τ^4) / 22464.0, + τ^2 * (-2418.0 + 12303.0 * τ - 19512.0 * τ^2 + 10904.0 * τ^3) / 3360.0, + -8 / 81 * τ^2 * (-78.0 + 209.0 * τ - 204.0 * τ^2 + 8.0 * τ^3), + -25 / 1134 * τ^2 * (-390.0 + 1045.0 * τ - 1020.0 * τ^2 + 328.0 * τ^3), + -25 / 5184 * τ^2 * (390.0 + 255.0 * τ - 1680.0 * τ^2 + 2072.0 * τ^3), + 279841 / 168480 * τ^2 * (-6.0 + 21.0 * τ - 24.0 * τ^2 + 8.0 * τ^3)] + + # Derivative polynomials + + wp = [ + 1.0 - 13985 // 1872 * τ + 143041 // 7488 * τ^2 - 2371 // 117 * τ^3 + + 20785 // 2808 * τ^4, + -403 // 280 * τ + 12303 // 1120 * τ^2 - 813 // 35 * τ^3 + + 1363 // 84 * τ^4, + 416 // 27 * τ - 1672 // 27 * τ^2 + 2176 // 27 * τ^3 - 320 // 81 * τ^4, + 3250 // 189 * τ - 26125 // 378 * τ^2 + 17000 // 189 * τ^3 - + 20500 // 567 * τ^4, + -1625 // 432 * τ - 2125 // 576 * τ^2 + 875 // 27 * τ^3 - + 32375 // 648 * τ^4, + -279841 // 14040 * τ + 1958887 // 18720 * τ^2 - 279841 // 1755 * τ^3 + + 279841 // 4212 * τ^4] + elseif $(order == 6) + w = [ + τ - 28607 // 7434 * τ^2 - 166210 // 33453 * τ^3 + + 334780 // 11151 * τ^4 - 1911296 // 55755 * τ^5 + 406528 // 33453 * τ^6, + 777 // 590 * τ^2 - 2534158 // 234171 * τ^3 + 2088580 // 78057 * τ^4 - + 10479104 // 390285 * τ^5 + 11328512 // 1170855 * τ^6, + -1008 // 59 * τ^2 + 222176 // 1593 * τ^3 - 180032 // 531 * τ^4 + + 876544 // 2655 * τ^5 - 180224 // 1593 * τ^6, + -1008 // 59 * τ^2 + 222176 // 1593 * τ^3 - 180032 // 531 * τ^4 + + 876544 // 2655 * τ^5 - 180224 // 1593 * τ^6, + -378 // 59 * τ^2 + 27772 // 531 * τ^3 - 22504 // 177 * τ^4 + + 109568 // 885 * τ^5 - 22528 // 531 * τ^6, + -95232 // 413 * τ^2 + 62384128 // 33453 * τ^3 - + 49429504 // 11151 * τ^4 + 46759936 // 11151 * τ^5 - + 46661632 // 33453 * τ^6, + 896 // 5 * τ^2 - 4352 // 3 * τ^3 + 3456 * τ^4 - 16384 // 5 * τ^5 + + 16384 // 15 * τ^6, + 50176 // 531 * τ^2 - 179554304 // 234171 * τ^3 + + 143363072 // 78057 * τ^4 - 136675328 // 78057 * τ^5 + + 137363456 // 234171 * τ^6, + 16384 // 441 * τ^3 - 16384 // 147 * τ^4 + 16384 // 147 * τ^5 - + 16384 // 441 * τ^6] + + # Derivative polynomials. + + wp = [ + 1 - 28607 // 3717 * τ - 166210 // 11151 * τ^2 + 1339120 // 11151 * τ^3 - + 1911296 // 11151 * τ^4 + 813056 // 11151 * τ^5, + 777 // 295 * τ - 2534158 // 78057 * τ^2 + 8354320 // 78057 * τ^3 - + 10479104 // 78057 * τ^4 + 22657024 // 390285 * τ^5, + -2016 // 59 * τ + 222176 // 531 * τ^2 - 720128 // 531 * τ^3 + + 876544 // 531 * τ^4 - 360448 // 531 * τ^5, + -2016 // 59 * τ + 222176 // 531 * τ^2 - 720128 // 531 * τ^3 + + 876544 // 531 * τ^4 - 360448 // 531 * τ^5, + -756 // 59 * τ + 27772 // 177 * τ^2 - 90016 // 177 * τ^3 + + 109568 // 177 * τ^4 - 45056 // 177 * τ^5, + -190464 // 413 * τ + 62384128 // 11151 * τ^2 - + 197718016 // 11151 * τ^3 + 233799680 // 11151 * τ^4 - + 93323264 // 11151 * τ^5, + 1792 // 5 * τ - 4352 * τ^2 + 13824 * τ^3 - 16384 * τ^4 + + 32768 // 5 * τ^5, + 100352 // 531 * τ - 179554304 // 78057 * τ^2 + + 573452288 // 78057 * τ^3 - 683376640 // 78057 * τ^4 + + 274726912 // 78057 * τ^5, + 16384 // 147 * τ^2 - 65536 // 147 * τ^3 + 81920 // 147 * τ^4 - + 32768 // 147 * τ^5] + end + return T.(w), T.(wp) + end + end +end diff --git a/lib/BoundaryValueDiffEqMIRK/src/alg_utils.jl b/lib/BoundaryValueDiffEqMIRK/src/alg_utils.jl new file mode 100644 index 00000000..707bb00a --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/src/alg_utils.jl @@ -0,0 +1,7 @@ +for order in (2, 3, 4, 5, 6) + alg = Symbol("MIRK$(order)") + @eval alg_order(::$(alg)) = $order + @eval alg_stage(::$(alg)) = $(order - 1) +end + +SciMLBase.isadaptive(alg::AbstractMIRK) = true diff --git a/lib/BoundaryValueDiffEqMIRK/src/algorithms.jl b/lib/BoundaryValueDiffEqMIRK/src/algorithms.jl new file mode 100644 index 00000000..4588b7c1 --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/src/algorithms.jl @@ -0,0 +1,55 @@ +# Algorithms +abstract type AbstractMIRK <: BoundaryValueDiffEqAlgorithm end + +for order in (2, 3, 4, 5, 6) + alg = Symbol("MIRK$(order)") + + @eval begin + """ + $($alg)(; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm(), + defect_threshold = 0.1, max_num_subintervals = 3000) + + $($order)th order Monotonic Implicit Runge Kutta method. + + ## Keyword Arguments + + - `nlsolve`: Internal Nonlinear solver. Any solver which conforms to the SciML + `NonlinearProblem` interface can be used. Note that any autodiff argument for + the solver will be ignored and a custom jacobian algorithm will be used. + - `jac_alg`: Jacobian Algorithm used for the nonlinear solver. Defaults to + `BVPJacobianAlgorithm()`, which automatically decides the best algorithm to + use based on the input types and problem type. + - For `TwoPointBVProblem`, only `diffmode` is used (defaults to + `AutoSparse(AutoForwardDiff())` if possible else `AutoSparse(AutoFiniteDiff())`). + - For `BVProblem`, `bc_diffmode` and `nonbc_diffmode` are used. For + `nonbc_diffmode` defaults to `AutoSparse(AutoForwardDiff())` if possible else + `AutoSparse(AutoFiniteDiff())`. For `bc_diffmode`, defaults to `AutoForwardDiff` if + possible else `AutoFiniteDiff`. + - `defect_threshold`: Threshold for defect control. + - `max_num_subintervals`: Number of maximal subintervals, default as 3000. + + !!! note + For type-stability, the chunksizes for ForwardDiff ADTypes in + `BVPJacobianAlgorithm` must be provided. + + ## References + + ```bibtex + @article{Enright1996RungeKuttaSW, + title={Runge-Kutta Software with Defect Control for Boundary Value ODEs}, + author={Wayne H. Enright and Paul H. Muir}, + journal={SIAM J. Sci. Comput.}, + year={1996}, + volume={17}, + pages={479-497} + } + ``` + """ + @kwdef struct $(alg){N, J <: BVPJacobianAlgorithm, T} <: AbstractMIRK + nlsolve::N = nothing + jac_alg::J = BVPJacobianAlgorithm() + defect_threshold::T = 0.1 + max_num_subintervals::Int = 3000 + end + end +end diff --git a/lib/BoundaryValueDiffEqMIRK/src/collocation.jl b/lib/BoundaryValueDiffEqMIRK/src/collocation.jl new file mode 100644 index 00000000..9c14fc3f --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/src/collocation.jl @@ -0,0 +1,63 @@ +function Φ!(residual, cache::MIRKCache, y, u, p = cache.p) + return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, + cache.TU, y, u, p, cache.mesh, cache.mesh_dt, cache.stage) +end + +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, + y, u, p, mesh, mesh_dt, stage::Int) + (; c, v, x, b) = TU + + tmp = get_tmp(fᵢ_cache, u) + T = eltype(u) + for i in eachindex(k_discrete) + K = get_tmp(k_discrete[i], u) + residᵢ = residual[i] + h = mesh_dt[i] + + yᵢ = get_tmp(y[i], u) + yᵢ₊₁ = get_tmp(y[i + 1], u) + + for r in 1:stage + @. tmp = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ + __maybe_matmul!(tmp, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) + f!(K[:, r], tmp, p, mesh[i] + c[r] * h) + end + + # Update residual + @. residᵢ = yᵢ₊₁ - yᵢ + __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) + end +end + +function Φ(cache::MIRKCache, y, u, p = cache.p) + return Φ(cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, + y, u, p, cache.mesh, cache.mesh_dt, cache.stage) +end + +@views function Φ( + fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int) + (; c, v, x, b) = TU + residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]] + tmp = get_tmp(fᵢ_cache, u) + T = eltype(u) + for i in eachindex(k_discrete) + K = get_tmp(k_discrete[i], u) + residᵢ = residuals[i] + h = mesh_dt[i] + + yᵢ = get_tmp(y[i], u) + yᵢ₊₁ = get_tmp(y[i + 1], u) + + for r in 1:stage + @. tmp = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ + __maybe_matmul!(tmp, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) + K[:, r] .= f(tmp, p, mesh[i] + c[r] * h) + end + + # Update residual + @. residᵢ = yᵢ₊₁ - yᵢ + __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) + end + + return residuals +end diff --git a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl new file mode 100644 index 00000000..d65b20e0 --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl @@ -0,0 +1,83 @@ +# MIRK Interpolation +@concrete struct MIRKInterpolation <: AbstractDiffEqInterpolation + t + u + cache +end + +function DiffEqBase.interp_summary(interp::MIRKInterpolation) + return "MIRK Order $(interp.cache.order) Interpolation" +end + +function (id::MIRKInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left) + return interpolation(tvals, id, idxs, deriv, p, continuity) +end + +function (id::MIRKInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left) + interpolation!(val, tvals, id, idxs, deriv, p, continuity) + return +end + +@inline function interpolation(tvals, id::MIRKInterpolation, idxs, deriv::D, + p, continuity::Symbol = :left) where {D} + (; t, u, cache) = id + tdir = sign(t[end] - t[1]) + idx = sortperm(tvals, rev = tdir < 0) + + if idxs isa Number + vals = Vector{eltype(first(u))}(undef, length(tvals)) + elseif idxs isa AbstractVector + vals = Vector{Vector{eltype(first(u))}}(undef, length(tvals)) + else + vals = Vector{eltype(u)}(undef, length(tvals)) + end + + for j in idx + z = similar(cache.fᵢ₂_cache) + interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv) + vals[j] = idxs !== nothing ? z[idxs] : z + end + return DiffEqArray(vals, tvals) +end + +@inline function interpolation!(vals, tvals, id::MIRKInterpolation, idxs, + deriv::D, p, continuity::Symbol = :left) where {D} + (; t, cache) = id + tdir = sign(t[end] - t[1]) + idx = sortperm(tvals, rev = tdir < 0) + + for j in idx + z = similar(cache.fᵢ₂_cache) + interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv) + vals[j] = z + end +end + +@inline function interpolation(tval::Number, id::MIRKInterpolation, idxs, + deriv::D, p, continuity::Symbol = :left) where {D} + z = similar(id.cache.fᵢ₂_cache) + interpolant!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv) + return idxs !== nothing ? z[idxs] : z +end + +@inline function interpolant!( + z::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{0}}) + i = interval(mesh, t) + dt = mesh_dt[i] + τ = (t - mesh[i]) / dt + w, w′ = interp_weights(τ, cache.alg) + sum_stages!(z, cache, w, i) +end + +@inline function interpolant!( + dz::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{1}}) + i = interval(mesh, t) + dt = mesh_dt[i] + τ = (t - mesh[i]) / dt + w, w′ = interp_weights(τ, cache.alg) + z = similar(dz) + sum_stages!(z, dz, cache, w, w′, i) +end + +@inline __build_interpolation(cache::MIRKCache, u::AbstractVector) = MIRKInterpolation( + cache.mesh, u, cache) diff --git a/src/solve/mirk.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl similarity index 92% rename from src/solve/mirk.jl rename to lib/BoundaryValueDiffEqMIRK/src/mirk.jl index 15291dc4..30ca870e 100644 --- a/src/solve/mirk.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl @@ -111,8 +111,6 @@ end """ __expand_cache!(cache::MIRKCache) - __expand_cache!(cache::FIRKCacheNested) - __expand_cache!(cache::MIRKCacheExpand) After redistributing or halving the mesh, this function expands the required vectors to match the length of the new mesh. @@ -129,21 +127,11 @@ function __expand_cache!(cache::MIRKCache) return cache end -function __expand_cache!(cache::FIRKCacheNested) - Nₙ = length(cache.mesh) - __append_similar!(cache.k_discrete, Nₙ - 1, cache.M) - __append_similar!(cache.y, Nₙ, cache.M) - __append_similar!(cache.y₀, Nₙ, cache.M) - __append_similar!(cache.residual, Nₙ, cache.M) - __append_similar!(cache.defect, Nₙ - 1, cache.M) - return cache -end - function __split_mirk_kwargs(; abstol, dt, adaptive = true, kwargs...) return ((abstol, adaptive, dt), (; abstol, adaptive, kwargs...)) end -function SciMLBase.solve!(cache::Union{MIRKCache, FIRKCacheNested}) +function SciMLBase.solve!(cache::MIRKCache) (abstol, adaptive, _), kwargs = __split_mirk_kwargs(; cache.kwargs...) info::ReturnCode.T = ReturnCode.Success @@ -168,8 +156,8 @@ function SciMLBase.solve!(cache::Union{MIRKCache, FIRKCacheNested}) return __build_solution(cache.prob, odesol, sol_nlprob) end -function __perform_mirk_iteration(cache::Union{MIRKCache, FIRKCacheNested}, abstol, - adaptive::Bool; nlsolve_kwargs = (;), kwargs...) +function __perform_mirk_iteration( + cache::MIRKCache, abstol, adaptive::Bool; nlsolve_kwargs = (;), kwargs...) nlprob = __construct_nlproblem(cache, vec(cache.y₀)) nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve) sol_nlprob = __solve( @@ -218,9 +206,7 @@ function __perform_mirk_iteration(cache::Union{MIRKCache, FIRKCacheNested}, abst end # Constructing the Nonlinear Problem -function __construct_nlproblem( - cache::Union{MIRKCache{iip}, FIRKCacheNested{iip}, FIRKCacheExpand{iip}}, - y::AbstractVector) where {iip} +function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {iip} pt = cache.problem_type loss_bc = if iip @@ -289,16 +275,15 @@ end return vcat(resid_bca, mapreduce(vec, vcat, resid_co), resid_bcb) end -@views function __mirk_loss_bc!(resid, u, p, pt, bc!::BC, y, mesh, - cache::Union{MIRKCache, FIRKCacheNested, FIRKCacheExpand}) where {BC} +@views function __mirk_loss_bc!( + resid, u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC} y_ = recursive_unflatten!(y, u) soly_ = VectorOfArray(y_) eval_bc_residual!(resid, pt, bc!, soly_, p, mesh) return nothing end -@views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, - cache::Union{MIRKCache, FIRKCacheNested, FIRKCacheExpand}) where {BC} +@views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC} y_ = recursive_unflatten!(y, u) soly_ = VectorOfArray(y_) return eval_bc_residual(pt, bc!, soly_, p, mesh) @@ -318,9 +303,8 @@ end return mapreduce(vec, vcat, resids) end -function __construct_nlproblem( - cache::Union{MIRKCache{iip}, FIRKCacheNested{iip}}, y, loss_bc::BC, - loss_collocation::C, loss::LF, ::StandardBVProblem) where {iip, BC, C, LF} +function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, + loss::LF, ::StandardBVProblem) where {iip, BC, C, LF} (; nlsolve, jac_alg) = cache.alg N = length(cache.mesh) @@ -424,9 +408,8 @@ function __mirk_mpoint_jacobian( return J end -function __construct_nlproblem( - cache::Union{MIRKCache{iip}, FIRKCacheNested{iip}}, y, loss_bc::BC, - loss_collocation::C, loss::LF, ::TwoPointBVProblem) where {iip, BC, C, LF} +function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, + loss::LF, ::TwoPointBVProblem) where {iip, BC, C, LF} (; nlsolve, jac_alg) = cache.alg N = length(cache.mesh) diff --git a/src/mirk_tableaus.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk_tableaus.jl similarity index 100% rename from src/mirk_tableaus.jl rename to lib/BoundaryValueDiffEqMIRK/src/mirk_tableaus.jl diff --git a/lib/BoundaryValueDiffEqMIRK/src/sparse_jacobians.jl b/lib/BoundaryValueDiffEqMIRK/src/sparse_jacobians.jl new file mode 100644 index 00000000..4b28ebcb --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/src/sparse_jacobians.jl @@ -0,0 +1,35 @@ +# For MIRK Methods +""" + __generate_sparse_jacobian_prototype(::MIRKCache, ya, yb, M, N) + __generate_sparse_jacobian_prototype(::MIRKCache, _, ya, yb, M, N) + __generate_sparse_jacobian_prototype(::MIRKCache, ::TwoPointBVProblem, ya, yb, M, N) + +Generate a prototype of the sparse Jacobian matrix for the BVP problem with row and column +coloring. + +If the problem is a TwoPointBVProblem, then this is the complete Jacobian, else it only +computes the sparse part excluding the contributions from the boundary conditions. +""" +function __generate_sparse_jacobian_prototype(cache::MIRKCache, ya, yb, M, N) + return __generate_sparse_jacobian_prototype(cache, cache.problem_type, ya, yb, M, N) +end + +function __generate_sparse_jacobian_prototype( + ::MIRKCache, ::StandardBVProblem, ya, yb, M, N) + fast_scalar_indexing(ya) || + error("Sparse Jacobians are only supported for Fast Scalar Index-able Arrays") + J_c = BandedMatrix(Ones{eltype(ya)}(M * (N - 1), M * N), (1, 2M - 1)) + return ColoredMatrix(J_c, matrix_colors(J_c'), matrix_colors(J_c)) +end + +function __generate_sparse_jacobian_prototype( + ::MIRKCache, ::TwoPointBVProblem, ya, yb, M, N) + fast_scalar_indexing(ya) || + error("Sparse Jacobians are only supported for Fast Scalar Index-able Arrays") + J₁ = length(ya) + length(yb) + M * (N - 1) + J₂ = M * N + J = BandedMatrix(Ones{eltype(ya)}(J₁, J₂), (M + 1, M + 1)) + # for underdetermined systems we don't have banded qr implemented. use sparse + J₁ < J₂ && return ColoredMatrix(sparse(J), matrix_colors(J'), matrix_colors(J)) + return ColoredMatrix(J, matrix_colors(J'), matrix_colors(J)) +end diff --git a/lib/BoundaryValueDiffEqMIRK/src/types.jl b/lib/BoundaryValueDiffEqMIRK/src/types.jl new file mode 100644 index 00000000..42337ea4 --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/src/types.jl @@ -0,0 +1,29 @@ +# MIRK Method Tableaus +struct MIRKTableau{sType, cType, vType, bType, xType} + """Discrete stages of MIRK formula""" + s::sType + c::cType + v::vType + b::bType + x::xType + + function MIRKTableau(s, c, v, b, x) + @assert eltype(c) == eltype(v) == eltype(b) == eltype(x) + return new{typeof(s), typeof(c), typeof(v), typeof(b), typeof(x)}(s, c, v, b, x) + end +end + +struct MIRKInterpTableau{s, c, v, x, τ} + s_star::s + c_star::c + v_star::v + x_star::x + τ_star::τ + + function MIRKInterpTableau(s_star, c_star, v_star, x_star, τ_star) + @assert eltype(c_star) == eltype(v_star) == eltype(x_star) + return new{ + typeof(s_star), typeof(c_star), typeof(v_star), typeof(x_star), typeof(τ_star)}( + s_star, c_star, v_star, x_star, τ_star) + end +end diff --git a/test/mirk/ensemble_tests.jl b/lib/BoundaryValueDiffEqMIRK/test/ensemble_tests.jl similarity index 96% rename from test/mirk/ensemble_tests.jl rename to lib/BoundaryValueDiffEqMIRK/test/ensemble_tests.jl index f979c837..56bb6147 100644 --- a/test/mirk/ensemble_tests.jl +++ b/lib/BoundaryValueDiffEqMIRK/test/ensemble_tests.jl @@ -1,4 +1,6 @@ + @testitem "EnsembleProblem" begin + using BoundaryValueDiffEqMIRK using Random function ode!(du, u, p, t) diff --git a/test/mirk/mirk_basic_tests.jl b/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl similarity index 97% rename from test/mirk/mirk_basic_tests.jl rename to lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl index 90c9e152..cdde769b 100644 --- a/test/mirk/mirk_basic_tests.jl +++ b/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl @@ -1,6 +1,6 @@ @testsetup module MIRKConvergenceTests -using BoundaryValueDiffEq +using BoundaryValueDiffEqMIRK for order in (2, 3, 4, 5, 6) s = Symbol("MIRK$(order)") @@ -95,8 +95,10 @@ end @testset "MIRK$order" for order in (2, 3, 4, 5, 6) solver = mirk_solver(Val(order); nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))) - @test_opt target_modules=(BoundaryValueDiffEq,) solve(prob, solver; dt = 0.2) - @test_call target_modules=(BoundaryValueDiffEq,) solve(prob, solver; dt = 0.2) + @test_opt target_modules=(BoundaryValueDiffEqMIRK,) solve( + prob, solver; dt = 0.2) + @test_call target_modules=(BoundaryValueDiffEqMIRK,) solve( + prob, solver; dt = 0.2) end end end diff --git a/test/mirk/nlls_tests.jl b/lib/BoundaryValueDiffEqMIRK/test/nlls_tests.jl similarity index 99% rename from test/mirk/nlls_tests.jl rename to lib/BoundaryValueDiffEqMIRK/test/nlls_tests.jl index 57926773..b7a3238d 100644 --- a/test/mirk/nlls_tests.jl +++ b/lib/BoundaryValueDiffEqMIRK/test/nlls_tests.jl @@ -1,5 +1,5 @@ @testitem "Overconstrained BVP" begin - using LinearAlgebra + using BoundaryValueDiffEqMIRK, LinearAlgebra SOLVERS = [mirk(; nlsolve) for mirk in (MIRK4, MIRK5, MIRK6), diff --git a/lib/BoundaryValueDiffEqMIRK/test/runtests.jl b/lib/BoundaryValueDiffEqMIRK/test/runtests.jl new file mode 100644 index 00000000..72d02637 --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/test/runtests.jl @@ -0,0 +1,12 @@ +using ReTestItems + +@time begin + if GROUP == "All" || GROUP == "MIRK" + @time "MIRK solvers" begin + ReTestItems.runtests("ensemble_tests.jl") + ReTestItems.runtests("mirk_basic_tests.jl") + ReTestItems.runtests("nlls_tests.jl") + ReTestItems.runtests("vectorofvector_initials_tests.jl") + end + end +end diff --git a/test/mirk/vectorofvector_initials_tests.jl b/lib/BoundaryValueDiffEqMIRK/test/vectorofvector_initials_tests.jl similarity index 97% rename from test/mirk/vectorofvector_initials_tests.jl rename to lib/BoundaryValueDiffEqMIRK/test/vectorofvector_initials_tests.jl index 301de3ee..8567140b 100644 --- a/test/mirk/vectorofvector_initials_tests.jl +++ b/lib/BoundaryValueDiffEqMIRK/test/vectorofvector_initials_tests.jl @@ -1,4 +1,5 @@ @testitem "VectorOfVector Initial Condition" begin + using BoundaryValueDiffEqMIRK, OrdinaryDiffEq #System Constants ss = 1 #excitatory parameter sj = 0 #inhibitory parameter diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index 6cbf6df9..f309db0f 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -29,14 +29,12 @@ include("utils.jl") include("algorithms.jl") include("alg_utils.jl") -include("mirk_tableaus.jl") include("lobatto_tableaus.jl") include("radau_tableaus.jl") include("solve/single_shooting.jl") include("solve/multiple_shooting.jl") include("solve/firk.jl") -include("solve/mirk.jl") include("collocation.jl") include("sparse_jacobians.jl") @@ -51,54 +49,12 @@ function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kw return solve!(cache) end -@setup_workload begin - function f1!(du, u, p, t) - du[1] = u[2] - du[2] = 0 - end - f1 = (u, p, t) -> [u[2], 0] +include("../lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl") +using ..BoundaryValueDiffEqMIRK - function bc1!(residual, u, p, t) - residual[1] = u[:, 1][1] - 5 - residual[2] = u[:, end][1] - end - - bc1 = (u, p, t) -> [u[:, 1][1] - 5, u[:, end][1]] - - bc1_a! = (residual, ua, p) -> (residual[1] = ua[1] - 5) - bc1_b! = (residual, ub, p) -> (residual[1] = ub[1]) - - bc1_a = (ua, p) -> [ua[1] - 5] - bc1_b = (ub, p) -> [ub[1]] - - tspan = (0.0, 5.0) - u0 = [5.0, -3.5] - bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1)) - - probs = [BVProblem(f1!, bc1!, u0, tspan; nlls = Val(false)), - BVProblem(f1, bc1, u0, tspan; nlls = Val(false)), - TwoPointBVProblem( - f1!, (bc1_a!, bc1_b!), u0, tspan; bcresid_prototype, nlls = Val(false)), - TwoPointBVProblem( - f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype, nlls = Val(false))] - - algs = [] - - jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)) - - if Preferences.@load_preference("PrecompileMIRK", true) - append!(algs, [MIRK2(; jac_alg), MIRK4(; jac_alg), MIRK6(; jac_alg)]) - end - - @compile_workload begin - @sync for prob in probs, alg in algs - Threads.@spawn solve(prob, alg; dt = 0.2) - end - end -end +export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6 export Shooting, MultipleShooting -export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6 export BVPM2, BVPSOL, COLNEW # From ODEInterface.jl export RadauIIa1, RadauIIa2, RadauIIa3, RadauIIa5, RadauIIa7 diff --git a/src/adaptivity.jl b/src/adaptivity.jl index aabfe5f4..033e942e 100644 --- a/src/adaptivity.jl +++ b/src/adaptivity.jl @@ -1,19 +1,9 @@ """ - interp_eval!(y::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt) interp_eval!(y::AbstractArray, cache::FIRKCacheExpand, t, mesh, mesh_dt) interp_eval!(y::AbstractArray, cache::FIRKCacheNested, t, mesh, mesh_dt) After we construct an interpolant, we use interp_eval to evaluate it. """ -@views function interp_eval!(y::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt) - i = interval(mesh, t) - dt = mesh_dt[i] - τ = (t - mesh[i]) / dt - w, w′ = interp_weights(τ, cache.alg) - sum_stages!(y, cache, w, i) - return y -end - @views function interp_eval!( y::AbstractArray, cache::FIRKCacheExpand{iip}, t, mesh, mesh_dt) where {iip} j = interval(mesh, t) @@ -162,14 +152,13 @@ function interval(mesh, t) end """ - mesh_selector!(cache::MIRKCache) mesh_selector!(cache::FIRKCacheExpand) mesh_selector!(cache::FIRKCacheNested) Generate new mesh based on the defect. """ @views function mesh_selector!(cache::Union{ - MIRKCache{iip, T}, FIRKCacheExpand{iip, T}, FIRKCacheNested{iip, T}}) where {iip, T} + FIRKCacheExpand{iip, T}, FIRKCacheNested{iip, T}}) where {iip, T} (; order, defect, mesh, mesh_dt) = cache (abstol, _, _), kwargs = __split_mirk_kwargs(; cache.kwargs...) N = length(mesh) @@ -222,14 +211,12 @@ Generate new mesh based on the defect. end """ - redistribute!(cache::MIRKCache, Nsub_star, ŝ, mesh, mesh_dt) redistribute!(cache::FIRKCacheExpand, Nsub_star, ŝ, mesh, mesh_dt) redistribute!(cache::FIRKCacheNested, Nsub_star, ŝ, mesh, mesh_dt) Generate a new mesh based on the `ŝ`. """ -function redistribute!( - cache::Union{MIRKCache{iip, T}, FIRKCacheExpand{iip, T}, FIRKCacheNested{iip, T}}, +function redistribute!(cache::Union{FIRKCacheExpand{iip, T}, FIRKCacheNested{iip, T}}, Nsub_star, ŝ, mesh, mesh_dt) where {iip, T} N = length(mesh) ζ = sum(ŝ .* mesh_dt) / Nsub_star @@ -260,7 +247,6 @@ end """ half_mesh!(mesh, mesh_dt) - half_mesh!(cache::MIRKCache) half_mesh!(cache::FIRKCacheExpand) half_mesh!(cache::FIRKCacheNested) @@ -282,12 +268,11 @@ function half_mesh!(mesh::Vector{T}, mesh_dt::Vector{T}) where {T} end return mesh, mesh_dt end -function half_mesh!(cache::Union{MIRKCache, FIRKCacheNested, FIRKCacheExpand}) +function half_mesh!(cache::Union{FIRKCacheNested, FIRKCacheExpand}) half_mesh!(cache.mesh, cache.mesh_dt) end """ - defect_estimate!(cache::MIRKCache) defect_estimate!(cache::FIRKCacheExpand) defect_estimate!(cache::FIRKCacheNested) @@ -295,46 +280,6 @@ defect_estimate use the discrete solution approximation Y, plus stages of the RK method in 'k_discrete', plus some new stages in 'k_interp' to construct an interpolant """ -@views function defect_estimate!(cache::MIRKCache{iip, T}) where {iip, T} - (; f, alg, mesh, mesh_dt, defect) = cache - (; τ_star) = cache.ITU - - # Evaluate at the first sample point - w₁, w₁′ = interp_weights(τ_star, alg) - # Evaluate at the second sample point - w₂, w₂′ = interp_weights(T(1) - τ_star, alg) - - interp_setup!(cache) - - for i in 1:(length(mesh) - 1) - dt = mesh_dt[i] - - z, z′ = sum_stages!(cache, w₁, w₁′, i) - if iip - yᵢ₁ = cache.y[i].du - f(yᵢ₁, z, cache.p, mesh[i] + τ_star * dt) - else - yᵢ₁ = f(z, cache.p, mesh[i] + τ_star * dt) - end - yᵢ₁ .= (z′ .- yᵢ₁) ./ (abs.(yᵢ₁) .+ T(1)) - est₁ = maximum(abs, yᵢ₁) - - z, z′ = sum_stages!(cache, w₂, w₂′, i) - if iip - yᵢ₂ = cache.y[i + 1].du - f(yᵢ₂, z, cache.p, mesh[i] + (T(1) - τ_star) * dt) - else - yᵢ₂ = f(z, cache.p, mesh[i] + (T(1) - τ_star) * dt) - end - yᵢ₂ .= (z′ .- yᵢ₂) ./ (abs.(yᵢ₂) .+ T(1)) - est₂ = maximum(abs, yᵢ₂) - - defect.u[i] .= est₁ > est₂ ? yᵢ₁ : yᵢ₂ - end - - return maximum(Base.Fix1(maximum, abs), defect.u) -end - @views function defect_estimate!(cache::FIRKCacheExpand{iip, T}) where {iip, T} (; f, M, stage, mesh, mesh_dt, defect, ITU) = cache (; q_coeff, τ_star) = ITU @@ -457,210 +402,3 @@ function eval_q(y_i, τ, h, A, K) end return q, q′ end - -""" - interp_setup!(cache::MIRKCache) - -`interp_setup!` prepare the extra stages in ki_interp for interpolant construction. -Here, the ki_interp is the stages in one subinterval. -""" -@views function interp_setup!(cache::MIRKCache{iip, T}) where {iip, T} - (; x_star, s_star, c_star, v_star) = cache.ITU - (; k_interp, k_discrete, f, stage, new_stages, y, p, mesh, mesh_dt) = cache - - for r in 1:(s_star - stage) - idx₁ = ((1:stage) .- 1) .* (s_star - stage) .+ r - idx₂ = ((1:(r - 1)) .+ stage .- 1) .* (s_star - stage) .+ r - for j in eachindex(k_discrete) - __maybe_matmul!(new_stages.u[j], k_discrete[j].du[:, 1:stage], x_star[idx₁]) - end - if r > 1 - for j in eachindex(k_interp) - __maybe_matmul!( - new_stages.u[j], k_interp.u[j][:, 1:(r - 1)], x_star[idx₂], T(1), T(1)) - end - end - for i in eachindex(new_stages) - new_stages.u[i] .= new_stages.u[i] .* mesh_dt[i] .+ - (1 - v_star[r]) .* vec(y[i].du) .+ - v_star[r] .* vec(y[i + 1].du) - if iip - f(k_interp.u[i][:, r], new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) - else - k_interp.u[i][:, r] .= f( - new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) - end - end - end - - return k_interp -end - -""" - sum_stages!(cache::MIRKCache, w, w′, i::Int) - -sum_stages add the discrete solution, RK method stages and extra stages to construct interpolant. -""" -function sum_stages!(cache::MIRKCache, w, w′, i::Int, dt = cache.mesh_dt[i]) - sum_stages!(cache.fᵢ_cache.du, cache.fᵢ₂_cache, cache, w, w′, i, dt) -end - -function sum_stages!(z::AbstractArray, cache::MIRKCache, w, i::Int, dt = cache.mesh_dt[i]) - (; stage, k_discrete, k_interp) = cache - (; s_star) = cache.ITU - - z .= zero(z) - __maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage]) - __maybe_matmul!( - z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true) - z .= z .* dt .+ cache.y₀.u[i] - - return z -end - -@views function sum_stages!(z, z′, cache::MIRKCache, w, w′, i::Int, dt = cache.mesh_dt[i]) - (; stage, k_discrete, k_interp) = cache - (; s_star) = cache.ITU - - z .= zero(z) - __maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage]) - __maybe_matmul!( - z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true) - z′ .= zero(z′) - __maybe_matmul!(z′, k_discrete[i].du[:, 1:stage], w′[1:stage]) - __maybe_matmul!( - z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true) - z .= z .* dt[1] .+ cache.y₀.u[i] - - return z, z′ -end - -""" - interp_weights(τ, alg) - -interp_weights: solver-specified interpolation weights and its first derivative -""" -function interp_weights end - -for order in (2, 3, 4, 5, 6) - alg = Symbol("MIRK$(order)") - @eval begin - function interp_weights(τ::T, ::$(alg)) where {T} - if $(order == 2) - w = [0, τ * (1 - τ / 2), τ^2 / 2] - - # Derivative polynomials. - - wp = [0, 1 - τ, τ] - elseif $(order == 3) - w = [τ / 4.0 * (2.0 * τ^2 - 5.0 * τ + 4.0), - -3.0 / 4.0 * τ^2 * (2.0 * τ - 3.0), τ^2 * (τ - 1.0)] - - # Derivative polynomials. - - wp = [3.0 / 2.0 * (τ - 2.0 / 3.0) * (τ - 1.0), - -9.0 / 2.0 * τ * (τ - 1.0), 3.0 * τ * (τ - 2.0 / 3.0)] - elseif $(order == 4) - t2 = τ * τ - tm1 = τ - 1.0 - t4m3 = τ * 4.0 - 3.0 - t2m1 = τ * 2.0 - 1.0 - - w = [-τ * (2.0 * τ - 3.0) * (2.0 * t2 - 3.0 * τ + 2.0) / 6.0, - t2 * (12.0 * t2 - 20.0 * τ + 9.0) / 6.0, - 2.0 * t2 * (6.0 * t2 - 14.0 * τ + 9.0) / 3.0, - -16.0 * t2 * tm1 * tm1 / 3.0] - - # Derivative polynomials - - wp = [-tm1 * t4m3 * t2m1 / 3.0, τ * t2m1 * t4m3, - 4.0 * τ * t4m3 * tm1, -32.0 * τ * t2m1 * tm1 / 3.0] - elseif $(order == 5) - w = [ - τ * (22464.0 - 83910.0 * τ + 143041.0 * τ^2 - 113808.0 * τ^3 + - 33256.0 * τ^4) / 22464.0, - τ^2 * (-2418.0 + 12303.0 * τ - 19512.0 * τ^2 + 10904.0 * τ^3) / 3360.0, - -8 / 81 * τ^2 * (-78.0 + 209.0 * τ - 204.0 * τ^2 + 8.0 * τ^3), - -25 / 1134 * τ^2 * (-390.0 + 1045.0 * τ - 1020.0 * τ^2 + 328.0 * τ^3), - -25 / 5184 * τ^2 * (390.0 + 255.0 * τ - 1680.0 * τ^2 + 2072.0 * τ^3), - 279841 / 168480 * τ^2 * (-6.0 + 21.0 * τ - 24.0 * τ^2 + 8.0 * τ^3)] - - # Derivative polynomials - - wp = [ - 1.0 - 13985 // 1872 * τ + 143041 // 7488 * τ^2 - 2371 // 117 * τ^3 + - 20785 // 2808 * τ^4, - -403 // 280 * τ + 12303 // 1120 * τ^2 - 813 // 35 * τ^3 + - 1363 // 84 * τ^4, - 416 // 27 * τ - 1672 // 27 * τ^2 + 2176 // 27 * τ^3 - 320 // 81 * τ^4, - 3250 // 189 * τ - 26125 // 378 * τ^2 + 17000 // 189 * τ^3 - - 20500 // 567 * τ^4, - -1625 // 432 * τ - 2125 // 576 * τ^2 + 875 // 27 * τ^3 - - 32375 // 648 * τ^4, - -279841 // 14040 * τ + 1958887 // 18720 * τ^2 - 279841 // 1755 * τ^3 + - 279841 // 4212 * τ^4] - elseif $(order == 6) - w = [ - τ - 28607 // 7434 * τ^2 - 166210 // 33453 * τ^3 + - 334780 // 11151 * τ^4 - 1911296 // 55755 * τ^5 + 406528 // 33453 * τ^6, - 777 // 590 * τ^2 - 2534158 // 234171 * τ^3 + 2088580 // 78057 * τ^4 - - 10479104 // 390285 * τ^5 + 11328512 // 1170855 * τ^6, - -1008 // 59 * τ^2 + 222176 // 1593 * τ^3 - 180032 // 531 * τ^4 + - 876544 // 2655 * τ^5 - 180224 // 1593 * τ^6, - -1008 // 59 * τ^2 + 222176 // 1593 * τ^3 - 180032 // 531 * τ^4 + - 876544 // 2655 * τ^5 - 180224 // 1593 * τ^6, - -378 // 59 * τ^2 + 27772 // 531 * τ^3 - 22504 // 177 * τ^4 + - 109568 // 885 * τ^5 - 22528 // 531 * τ^6, - -95232 // 413 * τ^2 + 62384128 // 33453 * τ^3 - - 49429504 // 11151 * τ^4 + 46759936 // 11151 * τ^5 - - 46661632 // 33453 * τ^6, - 896 // 5 * τ^2 - 4352 // 3 * τ^3 + 3456 * τ^4 - 16384 // 5 * τ^5 + - 16384 // 15 * τ^6, - 50176 // 531 * τ^2 - 179554304 // 234171 * τ^3 + - 143363072 // 78057 * τ^4 - 136675328 // 78057 * τ^5 + - 137363456 // 234171 * τ^6, - 16384 // 441 * τ^3 - 16384 // 147 * τ^4 + 16384 // 147 * τ^5 - - 16384 // 441 * τ^6] - - # Derivative polynomials. - - wp = [ - 1 - 28607 // 3717 * τ - 166210 // 11151 * τ^2 + 1339120 // 11151 * τ^3 - - 1911296 // 11151 * τ^4 + 813056 // 11151 * τ^5, - 777 // 295 * τ - 2534158 // 78057 * τ^2 + 8354320 // 78057 * τ^3 - - 10479104 // 78057 * τ^4 + 22657024 // 390285 * τ^5, - -2016 // 59 * τ + 222176 // 531 * τ^2 - 720128 // 531 * τ^3 + - 876544 // 531 * τ^4 - 360448 // 531 * τ^5, - -2016 // 59 * τ + 222176 // 531 * τ^2 - 720128 // 531 * τ^3 + - 876544 // 531 * τ^4 - 360448 // 531 * τ^5, - -756 // 59 * τ + 27772 // 177 * τ^2 - 90016 // 177 * τ^3 + - 109568 // 177 * τ^4 - 45056 // 177 * τ^5, - -190464 // 413 * τ + 62384128 // 11151 * τ^2 - - 197718016 // 11151 * τ^3 + 233799680 // 11151 * τ^4 - - 93323264 // 11151 * τ^5, - 1792 // 5 * τ - 4352 * τ^2 + 13824 * τ^3 - 16384 * τ^4 + - 32768 // 5 * τ^5, - 100352 // 531 * τ - 179554304 // 78057 * τ^2 + - 573452288 // 78057 * τ^3 - 683376640 // 78057 * τ^4 + - 274726912 // 78057 * τ^5, - 16384 // 147 * τ^2 - 65536 // 147 * τ^3 + 81920 // 147 * τ^4 - - 32768 // 147 * τ^5] - end - return T.(w), T.(wp) - end - end -end - -function sol_eval(cache::MIRKCache{T}, t::T) where {T} - (; M, mesh, mesh_dt, alg) = cache - - @assert mesh[1] ≤ t ≤ mesh[end] - i = interval(mesh, t) - dt = mesh_dt[i] - τ = (t - mesh[i]) / dt - weights, weights_prime = interp_weights(τ, alg) - z = zeros(M) - z_prime = zeros(M) - sum_stages!(z, z_prime, cache, weights, weights_prime, i, mesh_dt) - return z -end diff --git a/src/alg_utils.jl b/src/alg_utils.jl index c6b4dc79..7772544a 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -1,9 +1,3 @@ -for order in (2, 3, 4, 5, 6) - alg = Symbol("MIRK$(order)") - @eval alg_order(::$(alg)) = $order - @eval alg_stage(::$(alg)) = $(order - 1) -end - for stage in (1, 2, 3, 5, 7) alg = Symbol("RadauIIa$(stage)") @eval alg_order(::$(alg)) = $(2 * stage - 1) @@ -32,5 +26,4 @@ SciMLBase.isautodifferentiable(::BoundaryValueDiffEqAlgorithm) = true SciMLBase.allows_arbitrary_number_types(::BoundaryValueDiffEqAlgorithm) = true SciMLBase.allowscomplex(alg::BoundaryValueDiffEqAlgorithm) = true -SciMLBase.isadaptive(alg::AbstractMIRK) = true SciMLBase.isadaptive(alg::AbstractFIRK) = true diff --git a/src/algorithms.jl b/src/algorithms.jl index 4b953ccb..ab5bad85 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -1,7 +1,6 @@ # Algorithms abstract type BoundaryValueDiffEqAlgorithm <: SciMLBase.AbstractBVPAlgorithm end abstract type AbstractShooting <: BoundaryValueDiffEqAlgorithm end -abstract type AbstractMIRK <: BoundaryValueDiffEqAlgorithm end abstract type AbstractFIRK <: BoundaryValueDiffEqAlgorithm end ## Disable the ugly verbose printing by default @@ -143,59 +142,6 @@ end @inline MultipleShooting(nshoots::Int, ode_alg, nlsolve; kwargs...) = MultipleShooting(; nshoots, ode_alg, nlsolve, kwargs...) -for order in (2, 3, 4, 5, 6) - alg = Symbol("MIRK$(order)") - - @eval begin - """ - $($alg)(; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm(), - defect_threshold = 0.1, max_num_subintervals = 3000) - - $($order)th order Monotonic Implicit Runge Kutta method. - - ## Keyword Arguments - - - `nlsolve`: Internal Nonlinear solver. Any solver which conforms to the SciML - `NonlinearProblem` interface can be used. Note that any autodiff argument for - the solver will be ignored and a custom jacobian algorithm will be used. - - `jac_alg`: Jacobian Algorithm used for the nonlinear solver. Defaults to - `BVPJacobianAlgorithm()`, which automatically decides the best algorithm to - use based on the input types and problem type. - - For `TwoPointBVProblem`, only `diffmode` is used (defaults to - `AutoSparse(AutoForwardDiff())` if possible else `AutoSparse(AutoFiniteDiff())`). - - For `BVProblem`, `bc_diffmode` and `nonbc_diffmode` are used. For - `nonbc_diffmode` defaults to `AutoSparse(AutoForwardDiff())` if possible else - `AutoSparse(AutoFiniteDiff())`. For `bc_diffmode`, defaults to `AutoForwardDiff` if - possible else `AutoFiniteDiff`. - - `defect_threshold`: Threshold for defect control. - - `max_num_subintervals`: Number of maximal subintervals, default as 3000. - - !!! note - For type-stability, the chunksizes for ForwardDiff ADTypes in - `BVPJacobianAlgorithm` must be provided. - - ## References - - ```bibtex - @article{Enright1996RungeKuttaSW, - title={Runge-Kutta Software with Defect Control for Boundary Value ODEs}, - author={Wayne H. Enright and Paul H. Muir}, - journal={SIAM J. Sci. Comput.}, - year={1996}, - volume={17}, - pages={479-497} - } - ``` - """ - @kwdef struct $(alg){N, J <: BVPJacobianAlgorithm, T} <: AbstractMIRK - nlsolve::N = nothing - jac_alg::J = BVPJacobianAlgorithm() - defect_threshold::T = 0.1 - max_num_subintervals::Int = 3000 - end - end -end - for stage in (1, 2, 3, 5, 7) alg = Symbol("RadauIIa$(stage)") diff --git a/src/collocation.jl b/src/collocation.jl index 589db587..ead08220 100644 --- a/src/collocation.jl +++ b/src/collocation.jl @@ -1,4 +1,4 @@ -function Φ!(residual, cache::Union{MIRKCache, FIRKCacheExpand}, y, u, p = cache.p) +function Φ!(residual, cache::FIRKCacheExpand, y, u, p = cache.p) return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, y, u, p, cache.mesh, cache.mesh_dt, cache.stage) end @@ -8,32 +8,6 @@ function Φ!(residual, cache::FIRKCacheNested, y, u, p = cache.p) y, u, p, cache.mesh, cache.mesh_dt, cache.stage, cache) end -@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, - y, u, p, mesh, mesh_dt, stage::Int) - (; c, v, x, b) = TU - - tmp = get_tmp(fᵢ_cache, u) - T = eltype(u) - for i in eachindex(k_discrete) - K = get_tmp(k_discrete[i], u) - residᵢ = residual[i] - h = mesh_dt[i] - - yᵢ = get_tmp(y[i], u) - yᵢ₊₁ = get_tmp(y[i + 1], u) - - for r in 1:stage - @. tmp = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ - __maybe_matmul!(tmp, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) - f!(K[:, r], tmp, p, mesh[i] + c[r] * h) - end - - # Update residual - @. residᵢ = yᵢ₊₁ - yᵢ - __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) - end -end - @views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{false}, y, u, p, mesh, mesh_dt, stage::Int) (; c, a, b) = TU @@ -136,7 +110,7 @@ end end end -function Φ(cache::Union{MIRKCache, FIRKCacheExpand}, y, u, p = cache.p) +function Φ(cache::FIRKCacheExpand, y, u, p = cache.p) return Φ(cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, y, u, p, cache.mesh, cache.mesh_dt, cache.stage) end @@ -146,34 +120,6 @@ function Φ(cache::FIRKCacheNested, y, u, p = cache.p) u, p, cache.mesh, cache.mesh_dt, cache.stage, cache) end -@views function Φ( - fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt, stage::Int) - (; c, v, x, b) = TU - residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]] - tmp = get_tmp(fᵢ_cache, u) - T = eltype(u) - for i in eachindex(k_discrete) - K = get_tmp(k_discrete[i], u) - residᵢ = residuals[i] - h = mesh_dt[i] - - yᵢ = get_tmp(y[i], u) - yᵢ₊₁ = get_tmp(y[i + 1], u) - - for r in 1:stage - @. tmp = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ - __maybe_matmul!(tmp, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) - K[:, r] .= f(tmp, p, mesh[i] + c[r] * h) - end - - # Update residual - @. residᵢ = yᵢ₊₁ - yᵢ - __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) - end - - return residuals -end - @views function Φ( fᵢ_cache, k_discrete, f, TU::FIRKTableau{false}, y, u, p, mesh, mesh_dt, stage::Int) (; c, a, b) = TU diff --git a/src/interpolation.jl b/src/interpolation.jl index 876f9981..fac00915 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -1,23 +1,3 @@ -# MIRK Interpolation -@concrete struct MIRKInterpolation <: AbstractDiffEqInterpolation - t - u - cache -end - -function DiffEqBase.interp_summary(interp::MIRKInterpolation) - return "MIRK Order $(interp.cache.order) Interpolation" -end - -function (id::MIRKInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left) - return interpolation(tvals, id, idxs, deriv, p, continuity) -end - -function (id::MIRKInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left) - interpolation!(val, tvals, id, idxs, deriv, p, continuity) - return -end - # FIRK Expand Interpolation struct FIRKExpandInterpolation{T1, T2} <: AbstractDiffEqInterpolation t::T1 @@ -58,67 +38,6 @@ function (id::FIRKNestedInterpolation)( interpolation!(val, tvals, id, idxs, deriv, p, continuity) end -@inline function interpolation(tvals, id::MIRKInterpolation, idxs, deriv::D, - p, continuity::Symbol = :left) where {D} - (; t, u, cache) = id - tdir = sign(t[end] - t[1]) - idx = sortperm(tvals, rev = tdir < 0) - - if idxs isa Number - vals = Vector{eltype(first(u))}(undef, length(tvals)) - elseif idxs isa AbstractVector - vals = Vector{Vector{eltype(first(u))}}(undef, length(tvals)) - else - vals = Vector{eltype(u)}(undef, length(tvals)) - end - - for j in idx - z = similar(cache.fᵢ₂_cache) - interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv) - vals[j] = idxs !== nothing ? z[idxs] : z - end - return DiffEqArray(vals, tvals) -end - -@inline function interpolation!(vals, tvals, id::MIRKInterpolation, idxs, - deriv::D, p, continuity::Symbol = :left) where {D} - (; t, cache) = id - tdir = sign(t[end] - t[1]) - idx = sortperm(tvals, rev = tdir < 0) - - for j in idx - z = similar(cache.fᵢ₂_cache) - interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv) - vals[j] = z - end -end - -@inline function interpolation(tval::Number, id::MIRKInterpolation, idxs, - deriv::D, p, continuity::Symbol = :left) where {D} - z = similar(id.cache.fᵢ₂_cache) - interpolant!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv) - return idxs !== nothing ? z[idxs] : z -end - -@inline function interpolant!( - z::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{0}}) - i = interval(mesh, t) - dt = mesh_dt[i] - τ = (t - mesh[i]) / dt - w, w′ = interp_weights(τ, cache.alg) - sum_stages!(z, cache, w, i) -end - -@inline function interpolant!( - dz::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{1}}) - i = interval(mesh, t) - dt = mesh_dt[i] - τ = (t - mesh[i]) / dt - w, w′ = interp_weights(τ, cache.alg) - z = similar(dz) - sum_stages!(z, dz, cache, w, w′, i) -end - @inline function interpolation(tvals, id::FIRKNestedInterpolation, idxs, deriv::D, p, continuity::Symbol = :left) where {D} (; t, u, cache) = id @@ -204,8 +123,6 @@ end return idxs !== nothing ? z[idxs] : z end -@inline __build_interpolation(cache::MIRKCache, u::AbstractVector) = MIRKInterpolation( - cache.mesh, u, cache) @inline __build_interpolation(cache::FIRKCacheExpand, u::AbstractVector) = FIRKExpandInterpolation( cache.mesh, u, cache) @inline __build_interpolation(cache::FIRKCacheNested, u::AbstractVector) = FIRKNestedInterpolation( diff --git a/src/solve/firk.jl b/src/solve/firk.jl index 47e6d1a7..51c36f78 100644 --- a/src/solve/firk.jl +++ b/src/solve/firk.jl @@ -286,6 +286,20 @@ function __expand_cache!(cache::FIRKCacheExpand) return cache end +function __expand_cache!(cache::FIRKCacheNested) + Nₙ = length(cache.mesh) + __append_similar!(cache.k_discrete, Nₙ - 1, cache.M) + __append_similar!(cache.y, Nₙ, cache.M) + __append_similar!(cache.y₀, Nₙ, cache.M) + __append_similar!(cache.residual, Nₙ, cache.M) + __append_similar!(cache.defect, Nₙ - 1, cache.M) + return cache +end + +function __split_mirk_kwargs(; abstol, dt, adaptive = true, kwargs...) + return ((abstol, adaptive, dt), (; abstol, adaptive, kwargs...)) +end + function SciMLBase.solve!(cache::FIRKCacheExpand) (abstol, adaptive, _), kwargs = __split_mirk_kwargs(; cache.kwargs...) info::ReturnCode.T = ReturnCode.Success @@ -312,8 +326,33 @@ function SciMLBase.solve!(cache::FIRKCacheExpand) return __build_solution(cache.prob, odesol, sol_nlprob) end -function __perform_firk_iteration( - cache::FIRKCacheExpand, abstol, adaptive; nlsolve_kwargs = (;), kwargs...) +function SciMLBase.solve!(cache::FIRKCacheNested) + (abstol, adaptive, _), kwargs = __split_mirk_kwargs(; cache.kwargs...) + info::ReturnCode.T = ReturnCode.Success + + # We do the first iteration outside the loop to preserve type-stability of the + # `original` field of the solution + sol_nlprob, info, defect_norm = __perform_firk_iteration( + cache, abstol, adaptive; kwargs...) + + if adaptive + while SciMLBase.successful_retcode(info) && defect_norm > abstol + sol_nlprob, info, defect_norm = __perform_firk_iteration( + cache, abstol, adaptive; kwargs...) + end + end + + u = recursivecopy(cache.y₀) + + interpolation = __build_interpolation(cache, u.u) + + odesol = DiffEqBase.build_solution( + cache.prob, cache.alg, cache.mesh, u.u; interp = interpolation, retcode = info) + return __build_solution(cache.prob, odesol, sol_nlprob) +end + +function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested}, abstol, + adaptive::Bool; nlsolve_kwargs = (;), kwargs...) nlprob = __construct_nlproblem(cache, vec(cache.y₀)) nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve) sol_nlprob = __solve( @@ -325,7 +364,7 @@ function __perform_firk_iteration( # Early terminate if non-adaptive adaptive || return sol_nlprob, sol_nlprob.retcode, defect_norm - info = sol_nlprob.retcode + info::ReturnCode.T = sol_nlprob.retcode if info == ReturnCode.Success # Nonlinear Solve was successful defect_norm = defect_estimate!(cache) @@ -361,10 +400,40 @@ function __perform_firk_iteration( return sol_nlprob, info, defect_norm end +# Constructing the Nonlinear Problem +function __construct_nlproblem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{iip}}, + y::AbstractVector) where {iip} + pt = cache.problem_type + + loss_bc = if iip + @closure (du, u, p) -> __firk_loss_bc!( + du, u, p, pt, cache.bc, cache.y, cache.mesh, cache) + else + @closure (u, p) -> __firk_loss_bc(u, p, pt, cache.bc, cache.y, cache.mesh, cache) + end + + loss_collocation = if iip + @closure (du, u, p) -> __firk_loss_collocation!( + du, u, p, cache.y, cache.mesh, cache.residual, cache) + else + @closure (u, p) -> __firk_loss_collocation( + u, p, cache.y, cache.mesh, cache.residual, cache) + end + + loss = if iip + @closure (du, u, p) -> __firk_loss!( + du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache) + else + @closure (u, p) -> __firk_loss(u, p, cache.y, pt, cache.bc, cache.mesh, cache) + end + + return __construct_nlproblem(cache, y, loss_bc, loss_collocation, loss, pt) +end + function __construct_nlproblem( cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF, ::StandardBVProblem) where {iip, BC, C, LF} - (; nlsolve, jac_alg) = cache.alg + (; jac_alg) = cache.alg (; stage) = cache N = length(cache.mesh) @@ -417,11 +486,11 @@ function __construct_nlproblem( end jac = if iip - @closure (J, u, p) -> __mirk_mpoint_jacobian!( + @closure (J, u, p) -> __firk_mpoint_jacobian!( J, J_c, u, jac_alg.bc_diffmode, jac_alg.nonbc_diffmode, cache_bc, cache_collocation, loss_bcₚ, loss_collocationₚ, resid_bc, resid_collocation, L) else - @closure (u, p) -> __mirk_mpoint_jacobian( + @closure (u, p) -> __firk_mpoint_jacobian( jac_prototype, J_c, u, jac_alg.bc_diffmode, jac_alg.nonbc_diffmode, cache_bc, cache_collocation, loss_bcₚ, loss_collocationₚ, L) end @@ -437,7 +506,7 @@ end function __construct_nlproblem( cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF, ::TwoPointBVProblem) where {iip, BC, C, LF} - (; nlsolve, jac_alg) = cache.alg + (; jac_alg) = cache.alg (; stage) = cache N = length(cache.mesh) @@ -470,10 +539,10 @@ function __construct_nlproblem( jac_prototype = zero(init_jacobian(diffcache)) jac = if iip - @closure (J, u, p) -> __mirk_2point_jacobian!( + @closure (J, u, p) -> __firk_2point_jacobian!( J, u, jac_alg.diffmode, diffcache, lossₚ, resid) else - @closure (u, p) -> __mirk_2point_jacobian( + @closure (u, p) -> __firk_2point_jacobian( u, jac_prototype, jac_alg.diffmode, diffcache, lossₚ) end @@ -481,3 +550,227 @@ function __construct_nlproblem( nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) end + +function __construct_nlproblem( + cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C, + loss::LF, ::StandardBVProblem) where {iip, BC, C, LF} + (; jac_alg) = cache.alg + N = length(cache.mesh) + + resid_bc = cache.bcresid_prototype + L = length(resid_bc) + resid_collocation = __similar(y, cache.M * (N - 1)) + + loss_bcₚ = (iip ? __Fix3 : Base.Fix2)(loss_bc, cache.p) + loss_collocationₚ = (iip ? __Fix3 : Base.Fix2)(loss_collocation, cache.p) + + sd_bc = jac_alg.bc_diffmode isa AutoSparse ? SymbolicsSparsityDetection() : + NoSparsityDetection() + cache_bc = __sparse_jacobian_cache( + Val(iip), jac_alg.bc_diffmode, sd_bc, loss_bcₚ, resid_bc, y) + + sd_collocation = if jac_alg.nonbc_diffmode isa AutoSparse + if L < cache.M + # For underdetermined problems we use sparse since we don't have banded qr + colored_matrix = __generate_sparse_jacobian_prototype( + cache, cache.problem_type, y, y, cache.M, N) + J_full_band = nothing + __sparsity_detection_alg(ColoredMatrix( + sparse(colored_matrix.M), colored_matrix.row_colorvec, + colored_matrix.col_colorvec)) + else + J_full_band = BandedMatrix(Ones{eltype(y)}(L + cache.M * (N - 1), cache.M * N), + (L + 1, cache.M + max(cache.M - L, 0))) + __sparsity_detection_alg(__generate_sparse_jacobian_prototype( + cache, cache.problem_type, y, y, cache.M, N)) + end + else + J_full_band = nothing + NoSparsityDetection() + end + cache_collocation = __sparse_jacobian_cache( + Val(iip), jac_alg.nonbc_diffmode, sd_collocation, + loss_collocationₚ, resid_collocation, y) + + J_bc = zero(init_jacobian(cache_bc)) + J_c = zero(init_jacobian(cache_collocation)) + if J_full_band === nothing + jac_prototype = vcat(J_bc, J_c) + else + jac_prototype = AlmostBandedMatrix{eltype(cache)}(J_full_band, J_bc) + end + + jac = if iip + @closure (J, u, p) -> __firk_mpoint_jacobian!( + J, J_c, u, jac_alg.bc_diffmode, jac_alg.nonbc_diffmode, cache_bc, + cache_collocation, loss_bcₚ, loss_collocationₚ, resid_bc, resid_collocation, L) + else + @closure (u, p) -> __firk_mpoint_jacobian( + jac_prototype, J_c, u, jac_alg.bc_diffmode, jac_alg.nonbc_diffmode, + cache_bc, cache_collocation, loss_bcₚ, loss_collocationₚ, L) + end + + resid_prototype = vcat(resid_bc, resid_collocation) + nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + + return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) +end + +function __construct_nlproblem( + cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C, + loss::LF, ::TwoPointBVProblem) where {iip, BC, C, LF} + (; nlsolve, jac_alg) = cache.alg + N = length(cache.mesh) + + lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p)) + + resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), + __similar(y, cache.M * (N - 1)), + @view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])) + L = length(cache.bcresid_prototype) + + sd = if jac_alg.diffmode isa AutoSparse + __sparsity_detection_alg(__generate_sparse_jacobian_prototype( + cache, cache.problem_type, + @view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), + @view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]), + cache.M, N)) + else + NoSparsityDetection() + end + diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, lossₚ, resid, y) + jac_prototype = zero(init_jacobian(diffcache)) + + jac = if iip + @closure (J, u, p) -> __firk_2point_jacobian!( + J, u, jac_alg.diffmode, diffcache, lossₚ, resid) + else + @closure (u, p) -> __firk_2point_jacobian( + u, jac_prototype, jac_alg.diffmode, diffcache, lossₚ) + end + + resid_prototype = copy(resid) + nlf = __unsafe_nonlinearfunction{iip}(loss; resid_prototype, jac, jac_prototype) + return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p) +end + +@views function __firk_loss!( + resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, cache) where {BC} + y_ = recursive_unflatten!(y, u) + resids = [get_tmp(r, u) for r in residual] + soly_ = VectorOfArray(y_) + eval_bc_residual!(resids[1], pt, bc!, soly_, p, mesh) + Φ!(resids[2:end], cache, y_, u, p) + recursive_flatten!(resid, resids) + return nothing +end + +@views function __firk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, + residual, mesh, cache) where {BC1, BC2} + y_ = recursive_unflatten!(y, u) + soly_ = VectorOfArray(y_) + resids = [get_tmp(r, u) for r in residual] + resida = resids[1][1:prod(cache.resid_size[1])] + residb = resids[1][(prod(cache.resid_size[1]) + 1):end] + eval_bc_residual!((resida, residb), pt, bc!, soly_, p, mesh) + Φ!(resids[2:end], cache, y_, u, p) + recursive_flatten_twopoint!(resid, resids, cache.resid_size) + return nothing +end + +@views function __firk_loss(u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache) where {BC} + y_ = recursive_unflatten!(y, u) + soly_ = VectorOfArray(y_) + resid_bc = eval_bc_residual(pt, bc, soly_, p, mesh) + resid_co = Φ(cache, y_, u, p) + return vcat(resid_bc, mapreduce(vec, vcat, resid_co)) +end + +@views function __firk_loss( + u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2}, mesh, cache) where {BC1, BC2} + y_ = recursive_unflatten!(y, u) + soly_ = VectorOfArray(y_) + resid_bca, resid_bcb = eval_bc_residual(pt, bc, soly_, p, mesh) + resid_co = Φ(cache, y_, u, p) + return vcat(resid_bca, mapreduce(vec, vcat, resid_co), resid_bcb) +end + +@views function __firk_loss_bc!(resid, u, p, pt, bc!::BC, y, mesh, + cache::Union{FIRKCacheNested, FIRKCacheExpand}) where {BC} + y_ = recursive_unflatten!(y, u) + soly_ = VectorOfArray(y_) + eval_bc_residual!(resid, pt, bc!, soly_, p, mesh) + return nothing +end + +@views function __firk_loss_bc(u, p, pt, bc!::BC, y, mesh, + cache::Union{FIRKCacheNested, FIRKCacheExpand}) where {BC} + y_ = recursive_unflatten!(y, u) + soly_ = VectorOfArray(y_) + return eval_bc_residual(pt, bc!, soly_, p, mesh) +end + +@views function __firk_loss_collocation!(resid, u, p, y, mesh, residual, cache) + y_ = recursive_unflatten!(y, u) + resids = [get_tmp(r, u) for r in residual[2:end]] + Φ!(resids, cache, y_, u, p) + recursive_flatten!(resid, resids) + return nothing +end + +@views function __firk_loss_collocation(u, p, y, mesh, residual, cache) + y_ = recursive_unflatten!(y, u) + resids = Φ(cache, y_, u, p) + return mapreduce(vec, vcat, resids) +end + +function __firk_mpoint_jacobian!( + J, _, x, bc_diffmode, nonbc_diffmode, bc_diffcache, nonbc_diffcache, loss_bc::BC, + loss_collocation::C, resid_bc, resid_collocation, L::Int) where {BC, C} + sparse_jacobian!(@view(J[1:L, :]), bc_diffmode, bc_diffcache, loss_bc, resid_bc, x) + sparse_jacobian!(@view(J[(L + 1):end, :]), nonbc_diffmode, + nonbc_diffcache, loss_collocation, resid_collocation, x) + return nothing +end + +function __firk_mpoint_jacobian!(J::AlmostBandedMatrix, J_c, x, bc_diffmode, nonbc_diffmode, + bc_diffcache, nonbc_diffcache, loss_bc::BC, loss_collocation::C, + resid_bc, resid_collocation, L::Int) where {BC, C} + J_bc = fillpart(J) + sparse_jacobian!(J_bc, bc_diffmode, bc_diffcache, loss_bc, resid_bc, x) + sparse_jacobian!( + J_c, nonbc_diffmode, nonbc_diffcache, loss_collocation, resid_collocation, x) + exclusive_bandpart(J) .= J_c + finish_part_setindex!(J) + return nothing +end + +function __firk_mpoint_jacobian( + J, _, x, bc_diffmode, nonbc_diffmode, bc_diffcache, nonbc_diffcache, + loss_bc::BC, loss_collocation::C, L::Int) where {BC, C} + sparse_jacobian!(@view(J[1:L, :]), bc_diffmode, bc_diffcache, loss_bc, x) + sparse_jacobian!( + @view(J[(L + 1):end, :]), nonbc_diffmode, nonbc_diffcache, loss_collocation, x) + return J +end + +function __firk_mpoint_jacobian( + J::AlmostBandedMatrix, J_c, x, bc_diffmode, nonbc_diffmode, bc_diffcache, + nonbc_diffcache, loss_bc::BC, loss_collocation::C, L::Int) where {BC, C} + J_bc = fillpart(J) + sparse_jacobian!(J_bc, bc_diffmode, bc_diffcache, loss_bc, x) + sparse_jacobian!(J_c, nonbc_diffmode, nonbc_diffcache, loss_collocation, x) + exclusive_bandpart(J) .= J_c + finish_part_setindex!(J) + return J +end + +function __firk_2point_jacobian!(J, x, diffmode, diffcache, loss_fn::L, resid) where {L} + sparse_jacobian!(J, diffmode, diffcache, loss_fn, resid, x) + return J +end + +function __firk_2point_jacobian(x, J, diffmode, diffcache, loss_fn::L) where {L} + sparse_jacobian!(J, diffmode, diffcache, loss_fn, x) + return J +end diff --git a/src/sparse_jacobians.jl b/src/sparse_jacobians.jl index 7f8b0d88..17497bb7 100644 --- a/src/sparse_jacobians.jl +++ b/src/sparse_jacobians.jl @@ -37,11 +37,12 @@ function __sparsity_detection_alg(M::ColoredMatrix) end __sparsity_detection_alg(::ColoredMatrix{Nothing}) = NoSparsityDetection() -# For MIRK Methods +# For FIRK Methods """ - __generate_sparse_jacobian_prototype(::MIRKCache, ya, yb, M, N) - __generate_sparse_jacobian_prototype(::MIRKCache, _, ya, yb, M, N) - __generate_sparse_jacobian_prototype(::MIRKCache, ::TwoPointBVProblem, ya, yb, M, N) + __generate_sparse_jacobian_prototype(::FIRKCacheNested, ::StandardBVProblem, ya, yb, M, N) + __generate_sparse_jacobian_prototype(::FIRKCacheNested, ::TwoPointBVProblem, ya, yb, M, N) + __generate_sparse_jacobian_prototype(::FIRKCacheExpand, ::StandardBVProblem, ya, yb, M, N) + __generate_sparse_jacobian_prototype(::FIRKCacheExpand, ::TwoPointBVProblem, ya, yb, M, N) Generate a prototype of the sparse Jacobian matrix for the BVP problem with row and column coloring. @@ -49,12 +50,8 @@ coloring. If the problem is a TwoPointBVProblem, then this is the complete Jacobian, else it only computes the sparse part excluding the contributions from the boundary conditions. """ -function __generate_sparse_jacobian_prototype(cache::MIRKCache, ya, yb, M, N) - return __generate_sparse_jacobian_prototype(cache, cache.problem_type, ya, yb, M, N) -end - function __generate_sparse_jacobian_prototype( - ::Union{MIRKCache, FIRKCacheNested}, ::StandardBVProblem, ya, yb, M, N) + ::FIRKCacheNested, ::StandardBVProblem, ya, yb, M, N) fast_scalar_indexing(ya) || error("Sparse Jacobians are only supported for Fast Scalar Index-able Arrays") J_c = BandedMatrix(Ones{eltype(ya)}(M * (N - 1), M * N), (1, 2M - 1)) @@ -62,7 +59,7 @@ function __generate_sparse_jacobian_prototype( end function __generate_sparse_jacobian_prototype( - ::Union{MIRKCache, FIRKCacheNested}, ::TwoPointBVProblem, ya, yb, M, N) + ::FIRKCacheNested, ::TwoPointBVProblem, ya, yb, M, N) fast_scalar_indexing(ya) || error("Sparse Jacobians are only supported for Fast Scalar Index-able Arrays") J₁ = length(ya) + length(yb) + M * (N - 1) diff --git a/src/types.jl b/src/types.jl index 2000ea9b..f3506807 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,33 +1,3 @@ -# MIRK Method Tableaus -struct MIRKTableau{sType, cType, vType, bType, xType} - """Discrete stages of MIRK formula""" - s::sType - c::cType - v::vType - b::bType - x::xType - - function MIRKTableau(s, c, v, b, x) - @assert eltype(c) == eltype(v) == eltype(b) == eltype(x) - return new{typeof(s), typeof(c), typeof(v), typeof(b), typeof(x)}(s, c, v, b, x) - end -end - -struct MIRKInterpTableau{s, c, v, x, τ} - s_star::s - c_star::c - v_star::v - x_star::x - τ_star::τ - - function MIRKInterpTableau(s_star, c_star, v_star, x_star, τ_star) - @assert eltype(c_star) == eltype(v_star) == eltype(x_star) - return new{ - typeof(s_star), typeof(c_star), typeof(v_star), typeof(x_star), typeof(τ_star)}( - s_star, c_star, v_star, x_star, τ_star) - end -end - # FIRK Method Tableaus struct FIRKTableau{nested, sType, aType, cType, bType} """Discrete stages of RK formula""" diff --git a/test/firk/expanded/nlls_tests.jl b/test/firk/expanded/nlls_tests.jl index 5206d933..99040ac1 100644 --- a/test/firk/expanded/nlls_tests.jl +++ b/test/firk/expanded/nlls_tests.jl @@ -8,7 +8,7 @@ nlsolve in (LevenbergMarquardt(), GaussNewton())]#, TrustRegion())] SOLVERS_NAMES = ["$solver with $nlsolve" for solver in ["RadauIIa5", "LobattoIIIa4", "LobattoIIIb4", "LobattoIIIc4"], -nlsolve in ["LevenbergMarquardt", "GaussNewton", "TrustRegion"]] +nlsolve in ["LevenbergMarquardt", "GaussNewton"]] ### Overconstrained BVP ### diff --git a/test/runtests.jl b/test/runtests.jl index fb57f8c1..979da3ca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,22 @@ -using ReTestItems +using ReTestItems, Pkg const GROUP = get(ENV, "GROUP", "All") const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR") +function activate_mirk() + Pkg.activate("../lib/BoundaryValueDiffEqMIRK") + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + Pkg.instantiate() +end + @time begin if GROUP == "All" || GROUP == "MIRK" @time "MIRK solvers" begin - ReTestItems.runtests(joinpath(@__DIR__, "mirk/")) + activate_mirk() + ReTestItems.runtests("../lib/BoundaryValueDiffEqMIRK/test/ensemble_tests.jl") + ReTestItems.runtests("../lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl") + ReTestItems.runtests("../lib/BoundaryValueDiffEqMIRK/test/nlls_tests.jl") + ReTestItems.runtests("../lib/BoundaryValueDiffEqMIRK/test/vectorofvector_initials_tests.jl") end end