Skip to content

Commit

Permalink
Fix solvers failed on OOP problem
Browse files Browse the repository at this point in the history
Signed-off-by: ErikQQY <[email protected]>
  • Loading branch information
ErikQQY committed Sep 17, 2023
1 parent 639ae0f commit 11fc315
Show file tree
Hide file tree
Showing 21 changed files with 446 additions and 435 deletions.
44 changes: 22 additions & 22 deletions src/euler/euler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ DiffEqBase.isinplace(::SEI{IIP}) where {IIP} = IIP
################################################################################

function DiffEqBase.__init(prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm"))
dt = error("dt is required for this algorithm"))
simpleeuler_init(prob.f,
DiffEqBase.isinplace(prob),
prob.u0,
prob.tspan[1],
dt,
prob.p)
DiffEqBase.isinplace(prob),
prob.u0,
prob.tspan[1],
dt,
prob.p)
end

function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm"))
dt = error("dt is required for this algorithm"))
u0 = prob.u0
tspan = prob.tspan
ts = Array(tspan[1]:dt:tspan[2])
Expand All @@ -53,7 +53,7 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler;
@inbounds us[1] = _copy(u0)

integ = simpleeuler_init(prob.f, DiffEqBase.isinplace(prob), prob.u0,
prob.tspan[1], dt, prob.p)
prob.tspan[1], dt, prob.p)

for i in 1:(n - 1)
step!(integ)
Expand All @@ -64,26 +64,26 @@ function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler;

DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol;
timeseries_errors = true,
dense_errors = false)
timeseries_errors = true,
dense_errors = false)

return sol
end

@inline function simpleeuler_init(f::F, IIP::Bool, u0::S, t0::T, dt::T,
p::P) where
{F, P, T, S <: AbstractArray{T}}
p::P) where
{F, P, T, S}
integ = SEI{IIP, S, T, P, F}(f,
_copy(u0),
_copy(u0),
_copy(u0),
t0,
t0,
t0,
dt,
sign(dt),
p,
true)
_copy(u0),
_copy(u0),
_copy(u0),
t0,
t0,
t0,
dt,
sign(dt),
p,
true)

return integ
end
Expand Down
10 changes: 5 additions & 5 deletions src/euler/gpueuler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ struct GPUSimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end
export GPUSimpleEuler

@muladd function DiffEqBase.solve(prob::ODEProblem,
alg::GPUSimpleEuler;
dt = error("dt is required for this algorithm"))
alg::GPUSimpleEuler;
dt = error("dt is required for this algorithm"))
@assert !isinplace(prob)
u0 = prob.u0
tspan = prob.tspan
Expand All @@ -30,10 +30,10 @@ export GPUSimpleEuler
end

sol = DiffEqBase.build_solution(prob, alg, ts, SArray(us),
k = nothing, stats = nothing,
calculate_error = false)
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
dense_errors = false)
sol
end
44 changes: 22 additions & 22 deletions src/euler/loopeuler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ export LoopEuler
# Out-of-place
# No caching, good for static arrays, bad for arrays
@muladd function DiffEqBase.__solve(prob::ODEProblem{uType, tType, false},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
@assert !adaptive
@assert !dense
u0 = prob.u0
Expand Down Expand Up @@ -51,26 +51,26 @@ export LoopEuler
!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false)
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
dense_errors = false)
sol
end

# In-place
# Good for mutable objects like arrays
# Use DiffEqBase.@.. for simd ivdep
@muladd function DiffEqBase.solve(prob::ODEProblem{uType, tType, true},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
@assert !adaptive
@assert !dense
u0 = prob.u0
Expand Down Expand Up @@ -106,10 +106,10 @@ end
!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false)
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
dense_errors = false)
sol
end
16 changes: 8 additions & 8 deletions src/euler_maruyama.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ struct SimpleEM <: DiffEqBase.AbstractSDEAlgorithm end
export SimpleEM

@muladd function DiffEqBase.solve(prob::SDEProblem{uType, tType, false}, alg::SimpleEM,
args...;
dt = error("dt required for SimpleEM")) where {uType,
tType}
args...;
dt = error("dt required for SimpleEM")) where {uType,
tType}
f = prob.f
g = prob.g
u0 = prob.u0
Expand Down Expand Up @@ -39,13 +39,13 @@ export SimpleEM
end

sol = DiffEqBase.build_solution(prob, alg, t, u,
calculate_error = false)
calculate_error = false)
end

@muladd function DiffEqBase.solve(prob::SDEProblem{uType, tType, true}, alg::SimpleEM,
args...;
dt = error("dt required for SimpleEM")) where {uType,
tType}
args...;
dt = error("dt required for SimpleEM")) where {uType,
tType}
f = prob.f
g = prob.g
u0 = prob.u0
Expand Down Expand Up @@ -80,5 +80,5 @@ end
end

sol = DiffEqBase.build_solution(prob, alg, t, u,
calculate_error = false)
calculate_error = false)
end
22 changes: 11 additions & 11 deletions src/functionmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ SciMLBase.isdiscrete(alg::SimpleFunctionMap) = true

# ConstantCache version
function DiffEqBase.__solve(prob::DiffEqBase.DiscreteProblem{uType, tupType, false},
alg::SimpleFunctionMap;
calculate_values = true) where {uType, tupType}
alg::SimpleFunctionMap;
calculate_values = true) where {uType, tupType}
tType = eltype(tupType)
tspan = prob.tspan
f = prob.f
Expand All @@ -22,14 +22,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.DiscreteProblem{uType, tupType, fal
end
end
sol = DiffEqBase.build_solution(prob, alg, t, u, dense = false,
interp = DiffEqBase.ConstantInterpolation(t, u),
calculate_error = false)
interp = DiffEqBase.ConstantInterpolation(t, u),
calculate_error = false)
end

# Cache version
function DiffEqBase.__solve(prob::DiscreteProblem{uType, tupType, true},
alg::SimpleFunctionMap;
calculate_values = true) where {uType, tupType}
alg::SimpleFunctionMap;
calculate_values = true) where {uType, tupType}
tType = eltype(tupType)
tspan = prob.tspan
f = prob.f
Expand All @@ -47,8 +47,8 @@ function DiffEqBase.__solve(prob::DiscreteProblem{uType, tupType, true},
end
end
sol = DiffEqBase.build_solution(prob, alg, t, u, dense = false,
interp = DiffEqBase.ConstantInterpolation(t, u),
calculate_error = false)
interp = DiffEqBase.ConstantInterpolation(t, u),
calculate_error = false)
end

##################################################
Expand All @@ -67,7 +67,7 @@ mutable struct DiscreteIntegrator{F, IIP, uType, tType, P, S} <:
end

function DiffEqBase.__init(prob::DiscreteProblem,
alg::SimpleFunctionMap)
alg::SimpleFunctionMap)
sol = solve(prob, alg; calculate_values = false)
F = typeof(prob.f)
IIP = isinplace(prob)
Expand All @@ -76,8 +76,8 @@ function DiffEqBase.__init(prob::DiscreteProblem,
P = typeof(prob.p)
S = typeof(sol)
DiscreteIntegrator{F, IIP, uType, tType, P, S}(prob.f, prob.u0, prob.tspan[1],
copy(prob.u0), prob.p, sol, 1,
one(tType))
copy(prob.u0), prob.p, sol, 1,
one(tType))
end

function DiffEqBase.step!(integrator::DiscreteIntegrator)
Expand Down
10 changes: 5 additions & 5 deletions src/rk4/gpurk4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ struct GPUSimpleRK4 <: AbstractSimpleDiffEqODEAlgorithm end
export GPUSimpleRK4

@muladd function DiffEqBase.solve(prob::ODEProblem,
alg::GPUSimpleRK4;
dt = error("dt is required for this algorithm"))
alg::GPUSimpleRK4;
dt = error("dt is required for this algorithm"))
@assert !isinplace(prob)
u0 = prob.u0
tspan = prob.tspan
Expand Down Expand Up @@ -38,10 +38,10 @@ export GPUSimpleRK4
end

sol = DiffEqBase.build_solution(prob, alg, ts, SArray(us),
k = nothing, stats = nothing,
calculate_error = false)
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
dense_errors = false)
sol
end
44 changes: 22 additions & 22 deletions src/rk4/looprk4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ export LoopRK4
# Out-of-place
# No caching, good for static arrays, bad for arrays
@muladd function DiffEqBase.__solve(prob::ODEProblem{uType, tType, false},
alg::LoopRK4;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
alg::LoopRK4;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
@assert !adaptive
@assert !dense
u0 = prob.u0
Expand Down Expand Up @@ -59,26 +59,26 @@ export LoopRK4
!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false)
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
dense_errors = false)
sol
end

# In-place
# Good for mutable objects like arrays
# Use DiffEqBase.@.. for simd ivdep
@muladd function DiffEqBase.solve(prob::ODEProblem{uType, tType, true},
alg::LoopRK4;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
alg::LoopRK4;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
@assert !adaptive
@assert !dense
u0 = prob.u0
Expand Down Expand Up @@ -127,10 +127,10 @@ end
!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false)
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
dense_errors = false)
sol
end
Loading

0 comments on commit 11fc315

Please sign in to comment.