-
-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #70 from SciML/euler
add a SimpleEuler method
- Loading branch information
Showing
7 changed files
with
465 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # | ||
# | ||
# Euler solver | ||
# | ||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # | ||
|
||
struct SimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end | ||
export SimpleEuler | ||
|
||
mutable struct SimpleEulerIntegrator{IIP, S, T, P, F} <: | ||
DiffEqBase.AbstractODEIntegrator{SimpleEuler, IIP, S, T} | ||
f::F # ..................................... Equations of motion | ||
uprev::S # .......................................... Previous state | ||
u::S # ........................................... Current state | ||
tmp::S # Auxiliary variable similar to state to avoid allocations | ||
tprev::T # ...................................... Previous time step | ||
t::T # ....................................... Current time step | ||
t0::T # ........... Initial time step, only for re-initialization | ||
dt::T # ............................................... Step size | ||
tdir::T # ...................................... Not used for Euler | ||
p::P # .................................... Parameters container | ||
u_modified::Bool # ..... If `true`, then the input of last step was modified | ||
end | ||
|
||
const SEI = SimpleEulerIntegrator | ||
|
||
# If `true`, then the equation of motion format is `f!(du,u,p,t)` instead of | ||
# `du = f(u,p,t)`. | ||
DiffEqBase.isinplace(::SEI{IIP}) where {IIP} = IIP | ||
|
||
################################################################################ | ||
# Initialization | ||
################################################################################ | ||
|
||
function DiffEqBase.__init(prob::ODEProblem, alg::SimpleEuler; | ||
dt = error("dt is required for this algorithm")) | ||
simpleeuler_init(prob.f, | ||
DiffEqBase.isinplace(prob), | ||
prob.u0, | ||
prob.tspan[1], | ||
dt, | ||
prob.p) | ||
end | ||
|
||
function DiffEqBase.__solve(prob::ODEProblem, alg::SimpleEuler; | ||
dt = error("dt is required for this algorithm")) | ||
u0 = prob.u0 | ||
tspan = prob.tspan | ||
ts = Array(tspan[1]:dt:tspan[2]) | ||
n = length(ts) | ||
us = Vector{typeof(u0)}(undef, n) | ||
|
||
@inbounds us[1] = _copy(u0) | ||
|
||
integ = simpleeuler_init(prob.f, DiffEqBase.isinplace(prob), prob.u0, | ||
prob.tspan[1], dt, prob.p) | ||
|
||
for i in 1:(n - 1) | ||
step!(integ) | ||
us[i + 1] = _copy(integ.u) | ||
end | ||
|
||
sol = DiffEqBase.build_solution(prob, alg, ts, us, calculate_error = false) | ||
|
||
DiffEqBase.has_analytic(prob.f) && | ||
DiffEqBase.calculate_solution_errors!(sol; | ||
timeseries_errors = true, | ||
dense_errors = false) | ||
|
||
return sol | ||
end | ||
|
||
@inline function simpleeuler_init(f::F, IIP::Bool, u0::S, t0::T, dt::T, | ||
p::P) where | ||
{F, P, T, S <: AbstractArray{T}} | ||
integ = SEI{IIP, S, T, P, F}(f, | ||
_copy(u0), | ||
_copy(u0), | ||
_copy(u0), | ||
t0, | ||
t0, | ||
t0, | ||
dt, | ||
sign(dt), | ||
p, | ||
true) | ||
|
||
return integ | ||
end | ||
|
||
################################################################################ | ||
# Stepping | ||
################################################################################ | ||
|
||
@inline @muladd function DiffEqBase.step!(integ::SEI{true, S, T}) where {T, S} | ||
integ.uprev .= integ.u | ||
tmp = integ.tmp | ||
f! = integ.f | ||
p = integ.p | ||
t = integ.t | ||
dt = integ.dt | ||
uprev = integ.uprev | ||
u = integ.u | ||
|
||
f!(u, uprev, p, t) | ||
@. u = uprev + dt * u | ||
|
||
integ.tprev = t | ||
integ.t += dt | ||
|
||
return nothing | ||
end | ||
|
||
@inline @muladd function DiffEqBase.step!(integ::SEI{false, S, T}) where {T, S} | ||
integ.uprev = integ.u | ||
f = integ.f | ||
p = integ.p | ||
t = integ.t | ||
dt = integ.dt | ||
uprev = integ.uprev | ||
|
||
k = f(uprev, p, t) | ||
integ.u = uprev + dt * k | ||
integ.tprev = t | ||
integ.t += dt | ||
|
||
return nothing | ||
end | ||
|
||
################################################################################ | ||
# Interpolation | ||
################################################################################ | ||
|
||
@inline @muladd function (integ::SEI)(t::T) where {T} | ||
t₁, t₀, dt = integ.t, integ.tprev, integ.dt | ||
|
||
y₀ = integ.uprev | ||
y₁ = integ.u | ||
Θ = (t - t₀) / dt | ||
|
||
# Hermite interpolation. | ||
@inbounds if !isinplace(integ) | ||
u = (1 - Θ) * y₀ + Θ * y₁ | ||
return u | ||
else | ||
for i in 1:length(u) | ||
u = @. (1 - Θ) * y₀ + Θ * y₁ | ||
end | ||
return u | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
####################################################################################### | ||
# GPU-crutch solve method | ||
# Makes the simplest possible method for GPU-compatibility | ||
# Out of place only | ||
####################################################################################### | ||
struct GPUSimpleEuler <: AbstractSimpleDiffEqODEAlgorithm end | ||
export GPUSimpleEuler | ||
|
||
@muladd function DiffEqBase.solve(prob::ODEProblem, | ||
alg::GPUSimpleEuler; | ||
dt = error("dt is required for this algorithm")) | ||
@assert !isinplace(prob) | ||
u0 = prob.u0 | ||
tspan = prob.tspan | ||
f = prob.f | ||
p = prob.p | ||
t = tspan[1] | ||
tf = prob.tspan[2] | ||
ts = tspan[1]:dt:tspan[2] | ||
us = MVector{length(ts), typeof(u0)}(undef) | ||
us[1] = u0 | ||
u = u0 | ||
|
||
for i in 2:length(ts) | ||
uprev = u | ||
t = ts[i] | ||
k1 = f(u, p, t) | ||
u = uprev + dt * k1 | ||
us[i] = u | ||
end | ||
|
||
sol = DiffEqBase.build_solution(prob, alg, ts, SArray(us), | ||
k = nothing, stats = nothing, | ||
calculate_error = false) | ||
DiffEqBase.has_analytic(prob.f) && | ||
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, | ||
dense_errors = false) | ||
sol | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
####################################################################################### | ||
# Simplest Loop method | ||
# Makes the simplest possible method for teaching and performance testing | ||
####################################################################################### | ||
struct LoopEuler <: AbstractSimpleDiffEqODEAlgorithm end | ||
export LoopEuler | ||
|
||
# Out-of-place | ||
# No caching, good for static arrays, bad for arrays | ||
@muladd function DiffEqBase.__solve(prob::ODEProblem{uType, tType, false}, | ||
alg::LoopEuler; | ||
dt = error("dt is required for this algorithm"), | ||
save_everystep = true, | ||
save_start = true, | ||
adaptive = false, | ||
dense = false, | ||
save_end = true, | ||
kwargs...) where {uType, tType} | ||
@assert !adaptive | ||
@assert !dense | ||
u0 = prob.u0 | ||
tspan = prob.tspan | ||
f = prob.f | ||
p = prob.p | ||
t = tspan[1] | ||
tf = prob.tspan[2] | ||
ts = tspan[1]:dt:tspan[2] | ||
|
||
if save_everystep && save_start | ||
us = Vector{typeof(u0)}(undef, length(ts)) | ||
us[1] = u0 | ||
elseif save_everystep | ||
us = Vector{typeof(u0)}(undef, length(ts) - 1) | ||
elseif save_start | ||
us = Vector{typeof(u0)}(undef, 2) | ||
us[1] = u0 | ||
else | ||
us = Vector{typeof(u0)}(undef, 1) # for interface compatibility | ||
end | ||
|
||
u = u0 | ||
|
||
for i in 2:length(ts) | ||
uprev = u | ||
t = ts[i] | ||
k = f(u, p, t) | ||
u = uprev + dt * k | ||
save_everystep && (us[i] = u) | ||
end | ||
|
||
!save_everystep && save_end && (us[end] = u) | ||
|
||
sol = DiffEqBase.build_solution(prob, alg, ts, us, | ||
k = nothing, stats = nothing, | ||
calculate_error = false) | ||
DiffEqBase.has_analytic(prob.f) && | ||
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, | ||
dense_errors = false) | ||
sol | ||
end | ||
|
||
# In-place | ||
# Good for mutable objects like arrays | ||
# Use DiffEqBase.@.. for simd ivdep | ||
@muladd function DiffEqBase.solve(prob::ODEProblem{uType, tType, true}, | ||
alg::LoopEuler; | ||
dt = error("dt is required for this algorithm"), | ||
save_everystep = true, | ||
save_start = true, | ||
adaptive = false, | ||
dense = false, | ||
save_end = true, | ||
kwargs...) where {uType, tType} | ||
@assert !adaptive | ||
@assert !dense | ||
u0 = prob.u0 | ||
tspan = prob.tspan | ||
f = prob.f | ||
p = prob.p | ||
t = tspan[1] | ||
tf = prob.tspan[2] | ||
ts = tspan[1]:dt:tspan[2] | ||
|
||
if save_everystep && save_start | ||
us = Vector{typeof(u0)}(undef, length(ts)) | ||
us[1] = u0 | ||
elseif save_everystep | ||
us = Vector{typeof(u0)}(undef, length(ts) - 1) | ||
elseif save_start | ||
us = Vector{typeof(u0)}(undef, 2) | ||
us[1] = u0 | ||
else | ||
us = Vector{typeof(u0)}(undef, 1) # for interface compatibility | ||
end | ||
|
||
u = copy(u0) | ||
k = zero(u0) | ||
|
||
for i in 2:length(ts) | ||
t = ts[i] | ||
f(k, u, p, t) | ||
DiffEqBase.@.. u = u + dt * k | ||
save_everystep && (us[i] = copy(u)) | ||
end | ||
|
||
!save_everystep && save_end && (us[end] = u) | ||
|
||
sol = DiffEqBase.build_solution(prob, alg, ts, us, | ||
k = nothing, stats = nothing, | ||
calculate_error = false) | ||
DiffEqBase.has_analytic(prob.f) && | ||
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, | ||
dense_errors = false) | ||
sol | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
using SimpleDiffEq, Test | ||
using SimpleDiffEq, SafeTestsets, Test | ||
|
||
@time begin | ||
@time @testset "Discrete Tests" begin include("discrete_tests.jl") end | ||
@time @testset "SimpleEM Tests" begin include("simpleem_tests.jl") end | ||
@time @testset "SimpleTsit5 Tests" begin include("simpletsit5_tests.jl") end | ||
@time @testset "SimpleATsit5 Tests" begin include("simpleatsit5_tests.jl") end | ||
@time @testset "GPUSimpleATsit5 Tests" begin include("gpusimpleatsit5_tests.jl") end | ||
@time @testset "SimpleRK4 Tests" begin include("simplerk4_tests.jl") end | ||
@time @testset "GPU Compatible ODE Tests" begin include("gpu_ode_regression.jl") end | ||
@time @safetestset "Discrete Tests" include("discrete_tests.jl") | ||
@time @safetestset "SimpleEM Tests" include("simpleem_tests.jl") | ||
@time @safetestset "SimpleTsit5 Tests" include("simpletsit5_tests.jl") | ||
@time @safetestset "SimpleATsit5 Tests" include("simpleatsit5_tests.jl") | ||
@time @safetestset "GPUSimpleATsit5 Tests" include("gpusimpleatsit5_tests.jl") | ||
@time @safetestset "SimpleRK4 Tests" include("simplerk4_tests.jl") | ||
@time @safetestset "SimpleEuler Tests" include("simpleeuler_tests.jl") | ||
@time @safetestset "GPU Compatible ODE Tests" include("gpu_ode_regression.jl") | ||
end |
Oops, something went wrong.