Skip to content

Commit

Permalink
Merge pull request #40 from SciML/inplaceem
Browse files Browse the repository at this point in the history
Inplace SimpleEM
  • Loading branch information
ChrisRackauckas authored Aug 22, 2020
2 parents b3d579d + fed8fc9 commit 0040d65
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
43 changes: 41 additions & 2 deletions src/euler_maruyama.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
struct SimpleEM <: DiffEqBase.AbstractSDEAlgorithm end
export SimpleEM

@muladd function DiffEqBase.solve(prob::SDEProblem,alg::SimpleEM,args...;
dt = error("dt required for SimpleEM"))
@muladd function DiffEqBase.solve(prob::SDEProblem{uType,tType,false},alg::SimpleEM,args...;
dt = error("dt required for SimpleEM")) where {uType,tType}

f = prob.f
g = prob.g
Expand Down Expand Up @@ -30,3 +30,42 @@ export SimpleEM
sol = DiffEqBase.build_solution(prob,alg,t,u,
calculate_error = false)
end

@muladd function DiffEqBase.solve(prob::SDEProblem{uType,tType,true},alg::SimpleEM,args...;
dt = error("dt required for SimpleEM")) where {uType,tType}

f = prob.f
g = prob.g
u0 = prob.u0
tspan = prob.tspan
p = prob.p
ftmp = zero(u0)
gtmp = DiffEqBase.is_diagonal_noise(prob) ? zero(u0) : zero(prob.noise_rate_prototype)
gtmp2 = DiffEqBase.is_diagonal_noise(prob) ? nothing : zero(u0)
dW = DiffEqBase.is_diagonal_noise(prob) ? zero(u0) : false .* prob.noise_rate_prototype[1,:]

@inbounds begin
n = Int((tspan[2] - tspan[1])/dt) + 1
u = [copy(u0) for i in 1:n]
t = [tspan[1] + i*dt for i in 0:n-1]
sqdt = sqrt(dt)
end

@inbounds for i in 2:n
uprev = u[i-1]
tprev = t[i-1]
f(ftmp,uprev,p,tprev)
g(gtmp,uprev,p,tprev)
@. dW = randn(eltype(dW))

if DiffEqBase.is_diagonal_noise(prob)
DiffEqBase.@.. u[i] = uprev + ftmp*dt + sqdt*gtmp*dW
else
mul!(gtmp2,gtmp,dW)
DiffEqBase.@.. u[i] = uprev + ftmp*dt + sqdt*gtmp2
end
end

sol = DiffEqBase.build_solution(prob,alg,t,u,
calculate_error = false)
end
7 changes: 7 additions & 0 deletions test/simpleem_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,10 @@ prob = SDEProblem(f,g,u0,tspan)
sol = solve(prob,SimpleEM(),dt=0.25)

@test typeof(sol.u) <: Vector{SVector{2,Float64}}

f(du,u,p,t) = du .= 2.0 * u
g(du,u,p,t) = du .= 1
u0 = 0.5ones(4)
tspan = (0.0,1.0)
prob = SDEProblem(f,g,u0,tspan)
sol = solve(prob,SimpleEM(),dt=0.25)

0 comments on commit 0040d65

Please sign in to comment.