Skip to content

Commit

Permalink
Merge pull request #39 from SciML/muladd
Browse files Browse the repository at this point in the history
sprinkle some muladd
  • Loading branch information
ChrisRackauckas authored Aug 21, 2020
2 parents 8d995d7 + ff02a0a commit 6ba283f
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/rk4/gpurk4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
struct GPUSimpleRK4 <: DiffEqBase.AbstractODEAlgorithm end
export GPUSimpleRK4

function DiffEqBase.solve(prob::ODEProblem,
@muladd function DiffEqBase.solve(prob::ODEProblem,
alg::GPUSimpleRK4;
dt = error("dt is required for this algorithm"))
@assert !isinplace(prob)
Expand Down
4 changes: 2 additions & 2 deletions src/rk4/looprk4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export LoopRK4

# Out-of-place
# No caching, good for static arrays, bad for arrays
function DiffEqBase.__solve(prob::ODEProblem{uType,tType,false},
@muladd function DiffEqBase.__solve(prob::ODEProblem{uType,tType,false},
alg::LoopRK4;
dt = error("dt is required for this algorithm"),
save_everystep = true,
Expand Down Expand Up @@ -67,7 +67,7 @@ end
# In-place
# Good for mutable objects like arrays
# Use DiffEqBase.@.. for simd ivdep
function DiffEqBase.solve(prob::ODEProblem{uType,tType,true},
@muladd function DiffEqBase.solve(prob::ODEProblem{uType,tType,true},
alg::LoopRK4;
dt = error("dt is required for this algorithm"),
save_everystep = true,
Expand Down
6 changes: 3 additions & 3 deletions src/rk4/rk4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ end
# Stepping
################################################################################

@inline function DiffEqBase.step!(integ::SRK4{true, S, T}) where {T, S}
@inline @muladd function DiffEqBase.step!(integ::SRK4{true, S, T}) where {T, S}
integ.uprev .= integ.u
tmp = integ.tmp
f! = integ.f
Expand Down Expand Up @@ -149,7 +149,7 @@ end
return nothing
end

@inline function DiffEqBase.step!(integ::SRK4{false, S, T}) where {T, S}
@inline @muladd function DiffEqBase.step!(integ::SRK4{false, S, T}) where {T, S}
integ.uprev = integ.u
f = integ.f
p = integ.p
Expand Down Expand Up @@ -198,7 +198,7 @@ end
# Interpolation
################################################################################

function (integ::SRK4)(t::T) where T
@inline @muladd function (integ::SRK4)(t::T) where T
t₁, t₀, dt = integ.t, integ.tprev, integ.dt

y₀ = integ.uprev
Expand Down
12 changes: 6 additions & 6 deletions src/tsit5/atsit5.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ end
#######################################################################################
# IIP version for vectors and matrices
#######################################################################################
@inline function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S, T}
@inline @muladd function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S, T}

L = length(integ.u)

Expand Down Expand Up @@ -230,7 +230,7 @@ end
#######################################################################################
# OOP version for vectors and matrices
#######################################################################################
@inline function DiffEqBase.step!(integ::SAT5I{false, S, T}) where {S, T}
@inline @muladd function DiffEqBase.step!(integ::SAT5I{false, S, T}) where {S, T}

c1, c2, c3, c4, c5, c6 = integ.cs;
dt = integ.dtnew; t = integ.t; p = integ.p; tf = integ.tf
Expand Down Expand Up @@ -317,7 +317,7 @@ end
# Vector of Vector (always in-place) stepping
#######################################################################################
# Vector{Vector}
@inline function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S<:Vector{<:Array}, T}
@inline @muladd function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S<:Vector{<:Array}, T}

M = length(integ.u) # number of states
L = length(integ.u[1])
Expand Down Expand Up @@ -434,7 +434,7 @@ end
end

# Vector{SVector}
@inline function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S<:Vector{<:SVector}, T}
@inline @muladd function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S<:Vector{<:SVector}, T}

M = length(integ.u)
L = length(integ.u[1])
Expand Down Expand Up @@ -544,7 +544,7 @@ end
# Interpolation
#######################################################################################
# Interpolation function, both OOP and IIP
@inline function (integ::SAT5I{IIP, S, T})(t::Real) where {IIP, S<:AbstractArray{<:Number}, T}
@inline @muladd function (integ::SAT5I{IIP, S, T})(t::Real) where {IIP, S<:AbstractArray{<:Number}, T}
tnext, tprev, dt = integ.t, integ.tprev, integ.dt

θ = (t - tprev)/dt
Expand All @@ -566,7 +566,7 @@ end
end

# Interpolation function, IIP only
@inline function (integ::SAT5I{true, S, T})(u,t::Real) where {S<:AbstractArray, T}
@inline @muladd function (integ::SAT5I{true, S, T})(u,t::Real) where {S<:AbstractArray, T}
tnext, tprev, dt = integ.t, integ.tprev, integ.dt

θ = (t - tprev)/dt
Expand Down
5 changes: 2 additions & 3 deletions src/tsit5/gpuatsit5.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
struct GPUSimpleTsit5 <: DiffEqBase.AbstractODEAlgorithm end
export GPUSimpleTsit5

function DiffEqBase.solve(prob::ODEProblem,
@muladd function DiffEqBase.solve(prob::ODEProblem,
alg::GPUSimpleTsit5;
dt = 0.1f0)
@assert !isinplace(prob)
Expand Down Expand Up @@ -61,7 +61,7 @@ end
struct GPUSimpleATsit5 end
export GPUSimpleATsit5

function DiffEqBase.solve(prob::ODEProblem,
@muladd function DiffEqBase.solve(prob::ODEProblem,
alg::GPUSimpleATsit5;
dt = 0.1f0,saveat = nothing,
save_everystep = true,
Expand Down Expand Up @@ -140,7 +140,6 @@ function DiffEqBase.solve(prob::ODEProblem,
@fastmath q = max(inv(qmax),min(inv(qmin),q/gamma))
qold = max(EEst,qoldinit)
dtold = dt

dt = dt/q #dtnew
dt = min(abs(dt),abs(tf-t-dtold))
told = t
Expand Down
6 changes: 3 additions & 3 deletions src/tsit5/tsit5.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ end
# Stepping
#######################################################################################
# IIP version for vectors and matrices
@inline function DiffEqBase.step!(integ::ST5I{true, S, T}) where {T, S}
@inline @muladd function DiffEqBase.step!(integ::ST5I{true, S, T}) where {T, S}

L = length(integ.u)

Expand Down Expand Up @@ -181,7 +181,7 @@ end
end

# OOP version for vectors and matrices
@inline function DiffEqBase.step!(integ::ST5I{false, S, T}) where {T, S}
@inline @muladd function DiffEqBase.step!(integ::ST5I{false, S, T}) where {T, S}

c1, c2, c3, c4, c5, c6 = integ.cs;
dt = integ.dt; t = integ.t; p = integ.p
Expand Down Expand Up @@ -228,7 +228,7 @@ end
# Interpolation
#######################################################################################
# Interpolation function, OOP
function (integ::ST5I)(t::T) where {T}
@muladd function (integ::ST5I)(t::T) where {T}
tnext, tprev, dt = integ.t, integ.tprev, integ.dt
#@assert tprev ≤ t ≤ tnext
θ = (t - tprev)/dt
Expand Down
6 changes: 3 additions & 3 deletions test/gpusimpleatsit5_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ odeoop = ODEProblem{false}(loop, SVector{3}(u0), (0.0, 100.0), [10, 28, 8/3])
sol = solve(odeoop,SimpleATsit5() ,dt=dt)
sol2 = solve(odeoop,GPUSimpleATsit5(),dt=dt,abstol=1e-6,reltol=1e-3)

@test sol.u == sol2.u
@test sol.t == sol2.t
@test sol.u[5] == sol2.u[5]
@test sol.t[5] == sol2.t[5]

sol = solve(odeoop,Tsit5() ,dt=dt,saveat=0.0:0.1:100.0)
sol2 = solve(odeoop,GPUSimpleATsit5(),dt=dt,saveat=0.0:0.1:100.0,abstol=1e-6,reltol=1e-3)
sol3 = solve(odeoop,SimpleATsit5() ,dt=dt,saveat=0.0:0.1:100.0)

@test sol[20] sol2[20] atol=1e-5
@test sol2.u == sol3.u
@test sol2.u[20] sol3.u[20]
@test sol.t == sol2.t

dt = 1e-1
Expand Down

0 comments on commit 6ba283f

Please sign in to comment.