Skip to content

Commit

Permalink
Merge pull request #70 from SciML/euler
Browse files Browse the repository at this point in the history
add a SimpleEuler method
  • Loading branch information
ChrisRackauckas authored Aug 14, 2023
2 parents 07645cc + f91225e commit 0c45459
Show file tree
Hide file tree
Showing 7 changed files with 465 additions and 9 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ julia = "1.6"

[extras]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["OrdinaryDiffEq", "Test"]
test = ["OrdinaryDiffEq", "SafeTestsets", "Test"]
3 changes: 3 additions & 0 deletions src/SimpleDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ include("euler_maruyama.jl")
include("rk4/rk4.jl")
include("rk4/gpurk4.jl")
include("rk4/looprk4.jl")
include("euler/euler.jl")
include("euler/gpueuler.jl")
include("euler/loopeuler.jl")
include("tsit5/atsit5_cache.jl")
include("tsit5/tsit5.jl")
include("tsit5/atsit5.jl")
Expand Down
151 changes: 151 additions & 0 deletions src/euler/euler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
#
# Euler solver
#
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

struct SimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end
export SimpleEuler

mutable struct SimpleEulerIntegrator{IIP, S, T, P, F} <:
DiffEqBase.AbstractODEIntegrator{SimpleEuler, IIP, S, T}
f::F # ..................................... Equations of motion
uprev::S # .......................................... Previous state
u::S # ........................................... Current state
tmp::S # Auxiliary variable similar to state to avoid allocations
tprev::T # ...................................... Previous time step
t::T # ....................................... Current time step
t0::T # ........... Initial time step, only for re-initialization
dt::T # ............................................... Step size
tdir::T # ...................................... Not used for Euler
p::P # .................................... Parameters container
u_modified::Bool # ..... If `true`, then the input of last step was modified
end

const SEI = SimpleEulerIntegrator

# If `true`, then the equation of motion format is `f!(du,u,p,t)` instead of
# `du = f(u,p,t)`.
DiffEqBase.isinplace(::SEI{IIP}) where {IIP} = IIP

################################################################################
# Initialization
################################################################################

function DiffEqBase.__init(prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm"))
simpleeuler_init(prob.f,
DiffEqBase.isinplace(prob),
prob.u0,
prob.tspan[1],
dt,
prob.p)
end

function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler;
dt = error("dt is required for this algorithm"))
u0 = prob.u0
tspan = prob.tspan
ts = Array(tspan[1]:dt:tspan[2])
n = length(ts)
us = Vector{typeof(u0)}(undef, n)

@inbounds us[1] = _copy(u0)

integ = simpleeuler_init(prob.f, DiffEqBase.isinplace(prob), prob.u0,
prob.tspan[1], dt, prob.p)

for i in 1:(n - 1)
step!(integ)
us[i + 1] = _copy(integ.u)
end

sol = DiffEqBase.build_solution(prob, alg, ts, us, calculate_error = false)

DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol;
timeseries_errors = true,
dense_errors = false)

return sol
end

@inline function simpleeuler_init(f::F, IIP::Bool, u0::S, t0::T, dt::T,
p::P) where
{F, P, T, S <: AbstractArray{T}}
integ = SEI{IIP, S, T, P, F}(f,
_copy(u0),
_copy(u0),
_copy(u0),
t0,
t0,
t0,
dt,
sign(dt),
p,
true)

return integ
end

################################################################################
# Stepping
################################################################################

@inline @muladd function DiffEqBase.step!(integ::SEI{true, S, T}) where {T, S}
integ.uprev .= integ.u
tmp = integ.tmp
f! = integ.f
p = integ.p
t = integ.t
dt = integ.dt
uprev = integ.uprev
u = integ.u

f!(u, uprev, p, t)
@. u = uprev + dt * u

integ.tprev = t
integ.t += dt

return nothing
end

@inline @muladd function DiffEqBase.step!(integ::SEI{false, S, T}) where {T, S}
integ.uprev = integ.u
f = integ.f
p = integ.p
t = integ.t
dt = integ.dt
uprev = integ.uprev

k = f(uprev, p, t)
integ.u = uprev + dt * k
integ.tprev = t
integ.t += dt

return nothing
end

################################################################################
# Interpolation
################################################################################

@inline @muladd function (integ::SEI)(t::T) where {T}
t₁, t₀, dt = integ.t, integ.tprev, integ.dt

y₀ = integ.uprev
y₁ = integ.u
Θ = (t - t₀) / dt

# Hermite interpolation.
@inbounds if !isinplace(integ)
u = (1 - Θ) * y₀ + Θ * y₁
return u
else
for i in 1:length(u)
u = @. (1 - Θ) * y₀ + Θ * y₁
end
return u
end
end
39 changes: 39 additions & 0 deletions src/euler/gpueuler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#######################################################################################
# GPU-crutch solve method
# Makes the simplest possible method for GPU-compatibility
# Out of place only
#######################################################################################
struct GPUSimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end
export GPUSimpleEuler

@muladd function DiffEqBase.solve(prob::ODEProblem,
alg::GPUSimpleEuler;
dt = error("dt is required for this algorithm"))
@assert !isinplace(prob)
u0 = prob.u0
tspan = prob.tspan
f = prob.f
p = prob.p
t = tspan[1]
tf = prob.tspan[2]
ts = tspan[1]:dt:tspan[2]
us = MVector{length(ts), typeof(u0)}(undef)
us[1] = u0
u = u0

for i in 2:length(ts)
uprev = u
t = ts[i]
k1 = f(u, p, t)
u = uprev + dt * k1
us[i] = u
end

sol = DiffEqBase.build_solution(prob, alg, ts, SArray(us),
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end
115 changes: 115 additions & 0 deletions src/euler/loopeuler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#######################################################################################
# Simplest Loop method
# Makes the simplest possible method for teaching and performance testing
#######################################################################################
struct LoopEuler <: AbstractSimpleDiffEqODEAlgorithm end
export LoopEuler

# Out-of-place
# No caching, good for static arrays, bad for arrays
@muladd function DiffEqBase.__solve(prob::ODEProblem{uType, tType, false},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
@assert !adaptive
@assert !dense
u0 = prob.u0
tspan = prob.tspan
f = prob.f
p = prob.p
t = tspan[1]
tf = prob.tspan[2]
ts = tspan[1]:dt:tspan[2]

if save_everystep && save_start
us = Vector{typeof(u0)}(undef, length(ts))
us[1] = u0
elseif save_everystep
us = Vector{typeof(u0)}(undef, length(ts) - 1)
elseif save_start
us = Vector{typeof(u0)}(undef, 2)
us[1] = u0
else
us = Vector{typeof(u0)}(undef, 1) # for interface compatibility
end

u = u0

for i in 2:length(ts)
uprev = u
t = ts[i]
k = f(u, p, t)
u = uprev + dt * k
save_everystep && (us[i] = u)
end

!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end

# In-place
# Good for mutable objects like arrays
# Use DiffEqBase.@.. for simd ivdep
@muladd function DiffEqBase.solve(prob::ODEProblem{uType, tType, true},
alg::LoopEuler;
dt = error("dt is required for this algorithm"),
save_everystep = true,
save_start = true,
adaptive = false,
dense = false,
save_end = true,
kwargs...) where {uType, tType}
@assert !adaptive
@assert !dense
u0 = prob.u0
tspan = prob.tspan
f = prob.f
p = prob.p
t = tspan[1]
tf = prob.tspan[2]
ts = tspan[1]:dt:tspan[2]

if save_everystep && save_start
us = Vector{typeof(u0)}(undef, length(ts))
us[1] = u0
elseif save_everystep
us = Vector{typeof(u0)}(undef, length(ts) - 1)
elseif save_start
us = Vector{typeof(u0)}(undef, 2)
us[1] = u0
else
us = Vector{typeof(u0)}(undef, 1) # for interface compatibility
end

u = copy(u0)
k = zero(u0)

for i in 2:length(ts)
t = ts[i]
f(k, u, p, t)
DiffEqBase.@.. u = u + dt * k
save_everystep && (us[i] = copy(u))
end

!save_everystep && save_end && (us[end] = u)

sol = DiffEqBase.build_solution(prob, alg, ts, us,
k = nothing, stats = nothing,
calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end
17 changes: 9 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using SimpleDiffEq, Test
using SimpleDiffEq, SafeTestsets, Test

@time begin
@time @testset "Discrete Tests" begin include("discrete_tests.jl") end
@time @testset "SimpleEM Tests" begin include("simpleem_tests.jl") end
@time @testset "SimpleTsit5 Tests" begin include("simpletsit5_tests.jl") end
@time @testset "SimpleATsit5 Tests" begin include("simpleatsit5_tests.jl") end
@time @testset "GPUSimpleATsit5 Tests" begin include("gpusimpleatsit5_tests.jl") end
@time @testset "SimpleRK4 Tests" begin include("simplerk4_tests.jl") end
@time @testset "GPU Compatible ODE Tests" begin include("gpu_ode_regression.jl") end
@time @safetestset "Discrete Tests" include("discrete_tests.jl")
@time @safetestset "SimpleEM Tests" include("simpleem_tests.jl")
@time @safetestset "SimpleTsit5 Tests" include("simpletsit5_tests.jl")
@time @safetestset "SimpleATsit5 Tests" include("simpleatsit5_tests.jl")
@time @safetestset "GPUSimpleATsit5 Tests" include("gpusimpleatsit5_tests.jl")
@time @safetestset "SimpleRK4 Tests" include("simplerk4_tests.jl")
@time @safetestset "SimpleEuler Tests" include("simpleeuler_tests.jl")
@time @safetestset "GPU Compatible ODE Tests" include("gpu_ode_regression.jl")
end
Loading

0 comments on commit 0c45459

Please sign in to comment.