Skip to content

Commit

Permalink
Merge pull request #119 from ErikQQY/qqy/fix_interp
Browse files Browse the repository at this point in the history
Fix interpolant evaluation error
  • Loading branch information
ChrisRackauckas authored Oct 12, 2023
2 parents 14f47d3 + 43c946c commit bdc87f1
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 48 deletions.
69 changes: 21 additions & 48 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ struct MIRKInterpolation{T1, T2} <: AbstractDiffEqInterpolation
cache
end

function DiffEqBase.interp_summary(interp::MIRKInterpolation)
return "MIRK Order $(interp.cache.order) Interpolation"
end

function (id::MIRKInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation(tvals, id, idxs, deriv, p, continuity)
end
Expand All @@ -12,15 +16,11 @@ function (id::MIRKInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol
interpolation!(val, tvals, id, idxs, deriv, p, continuity)
end

@inline function interpolation(tvals,
id::I,
idxs,
deriv::D,
p,
# FIXME: Fix the interpolation outside the tspan

@inline function interpolation(tvals, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
t = id.t
u = id.u
cache = id.cache
@unpack t, u, cache = id
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)

Expand All @@ -33,56 +33,29 @@ end
end

for j in idx
tval = tvals[j]
i = interval(t, tval)
dt = t[i + 1] - t[i]
θ = (tval - t[i]) / dt
weights, _ = interp_weights(θ, cache.alg)
z = zeros(cache.M)
sum_stages!(z, cache, weights, i)
vals[j] = copy(z)
z = similar(cache.fᵢ₂_cache)
interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt)
vals[j] = z
end
DiffEqArray(vals, tvals)
return DiffEqArray(vals, tvals)
end

@inline function interpolation!(vals,
tvals,
id::I,
idxs,
deriv::D,
p,
@inline function interpolation!(vals, tvals, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
t = id.t
cache = id.cache
@unpack t, cache = id
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)

for j in idx
tval = tvals[j]
i = interval(t, tval)
dt = t[i] - t[i - 1]
θ = (tval - t[i]) / dt
weights, _ = interp_weights(θ, cache.alg)
z = zeros(cache.M)
sum_stages!(z, cache, weights, i)
vals[j] = copy(z)
z = similar(cache.fᵢ₂_cache)
interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt)
vals[j] = z
end
end

@inline function interpolation(tval::Number,
id::I,
idxs,
deriv::D,
p,
@inline function interpolation(tval::Number, id::I, idxs, deriv::D, p,
continuity::Symbol = :left) where {I, D}
t = id.t
cache = id.cache
i = interval(t, tval)
dt = t[i] - t[i - 1]
θ = (tval - t[i]) / dt
weights, _ = interp_weights(θ, cache.alg)
z = zeros(cache.M)
sum_stages!(z, cache, weights, i)
val = copy(z)
val
z = similar(id.cache.fᵢ₂_cache)
interp_eval!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt)
return z
end
33 changes: 33 additions & 0 deletions test/interpolation_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using BoundaryValueDiffEq, DiffEqBase, DiffEqDevTools, LinearAlgebra, Test

λ = 1
function prob_bvp_linear_analytic(u, λ, t)
a = 1 / sqrt(λ)
[(exp(-a * t) - exp((t - 2) * a)) / (1 - exp(-2 * a)),
(-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a))]
end
function prob_bvp_linear_f!(du, u, p, t)
du[1] = u[2]
du[2] = 1 / p * u[1]
end
function prob_bvp_linear_bc!(res, u, p, t)
res[1] = u[1][1] - 1
res[2] = u[end][1]
end
prob_bvp_linear_function = ODEFunction(prob_bvp_linear_f!, analytic = prob_bvp_linear_analytic)
prob_bvp_linear_tspan = (0.0, 1.0)
prob_bvp_linear = BVProblem(prob_bvp_linear_function, prob_bvp_linear_bc!,
[1.0, 0.0], prob_bvp_linear_tspan, λ)
testTol = 1e-6

for order in (2, 3, 4, 5, 6)
s = Symbol("MIRK$(order)")
@eval mirk_solver(::Val{$order}) = $(s)()
end

@testset "Interpolation" begin
@testset "MIRK$order" for order in (2, 3, 4, 5, 6)
@time sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001)
@test sol(0.001) [0.998687464, -1.312035941] atol=testTol
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,10 @@ using Test, SafeTestsets
include("non_vector_inputs.jl")
end
end

@time @testset "Interpolation Tests" begin
@time @safetestset "MIRK Interpolation Test" begin
include("interpolation_test.jl")
end
end
end

0 comments on commit bdc87f1

Please sign in to comment.