From ff02a0af6495dcc14e0cab1638b02c87f0eb45ab Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 21 Aug 2020 02:42:50 -0400 Subject: [PATCH] sprinkle some muladd --- src/rk4/gpurk4.jl | 2 +- src/rk4/looprk4.jl | 4 ++-- src/rk4/rk4.jl | 6 +++--- src/tsit5/atsit5.jl | 12 ++++++------ src/tsit5/gpuatsit5.jl | 5 ++--- src/tsit5/tsit5.jl | 6 +++--- test/gpusimpleatsit5_tests.jl | 6 +++--- 7 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/rk4/gpurk4.jl b/src/rk4/gpurk4.jl index 83d8fac..6e88ffb 100644 --- a/src/rk4/gpurk4.jl +++ b/src/rk4/gpurk4.jl @@ -6,7 +6,7 @@ struct GPUSimpleRK4 <: DiffEqBase.AbstractODEAlgorithm end export GPUSimpleRK4 -function DiffEqBase.solve(prob::ODEProblem, +@muladd function DiffEqBase.solve(prob::ODEProblem, alg::GPUSimpleRK4; dt = error("dt is required for this algorithm")) @assert !isinplace(prob) diff --git a/src/rk4/looprk4.jl b/src/rk4/looprk4.jl index e8bf012..e165381 100644 --- a/src/rk4/looprk4.jl +++ b/src/rk4/looprk4.jl @@ -7,7 +7,7 @@ export LoopRK4 # Out-of-place # No caching, good for static arrays, bad for arrays -function DiffEqBase.__solve(prob::ODEProblem{uType,tType,false}, +@muladd function DiffEqBase.__solve(prob::ODEProblem{uType,tType,false}, alg::LoopRK4; dt = error("dt is required for this algorithm"), save_everystep = true, @@ -67,7 +67,7 @@ end # In-place # Good for mutable objects like arrays # Use DiffEqBase.@.. for simd ivdep -function DiffEqBase.solve(prob::ODEProblem{uType,tType,true}, +@muladd function DiffEqBase.solve(prob::ODEProblem{uType,tType,true}, alg::LoopRK4; dt = error("dt is required for this algorithm"), save_everystep = true, diff --git a/src/rk4/rk4.jl b/src/rk4/rk4.jl index 2bb35c1..142b2d7 100644 --- a/src/rk4/rk4.jl +++ b/src/rk4/rk4.jl @@ -97,7 +97,7 @@ end # Stepping ################################################################################ -@inline function DiffEqBase.step!(integ::SRK4{true, S, T}) where {T, S} +@inline @muladd function DiffEqBase.step!(integ::SRK4{true, S, T}) where {T, S} integ.uprev .= integ.u tmp = integ.tmp f! = integ.f @@ -149,7 +149,7 @@ end return nothing end -@inline function DiffEqBase.step!(integ::SRK4{false, S, T}) where {T, S} +@inline @muladd function DiffEqBase.step!(integ::SRK4{false, S, T}) where {T, S} integ.uprev = integ.u f = integ.f p = integ.p @@ -198,7 +198,7 @@ end # Interpolation ################################################################################ -function (integ::SRK4)(t::T) where T +@inline @muladd function (integ::SRK4)(t::T) where T t₁, t₀, dt = integ.t, integ.tprev, integ.dt y₀ = integ.uprev diff --git a/src/tsit5/atsit5.jl b/src/tsit5/atsit5.jl index 8573e79..4e3989c 100644 --- a/src/tsit5/atsit5.jl +++ b/src/tsit5/atsit5.jl @@ -131,7 +131,7 @@ end ####################################################################################### # IIP version for vectors and matrices ####################################################################################### -@inline function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S, T} +@inline @muladd function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S, T} L = length(integ.u) @@ -230,7 +230,7 @@ end ####################################################################################### # OOP version for vectors and matrices ####################################################################################### -@inline function DiffEqBase.step!(integ::SAT5I{false, S, T}) where {S, T} +@inline @muladd function DiffEqBase.step!(integ::SAT5I{false, S, T}) where {S, T} c1, c2, c3, c4, c5, c6 = integ.cs; dt = integ.dtnew; t = integ.t; p = integ.p; tf = integ.tf @@ -317,7 +317,7 @@ end # Vector of Vector (always in-place) stepping ####################################################################################### # Vector{Vector} -@inline function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S<:Vector{<:Array}, T} +@inline @muladd function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S<:Vector{<:Array}, T} M = length(integ.u) # number of states L = length(integ.u[1]) @@ -434,7 +434,7 @@ end end # Vector{SVector} -@inline function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S<:Vector{<:SVector}, T} +@inline @muladd function DiffEqBase.step!(integ::SAT5I{true, S, T}) where {S<:Vector{<:SVector}, T} M = length(integ.u) L = length(integ.u[1]) @@ -544,7 +544,7 @@ end # Interpolation ####################################################################################### # Interpolation function, both OOP and IIP -@inline function (integ::SAT5I{IIP, S, T})(t::Real) where {IIP, S<:AbstractArray{<:Number}, T} +@inline @muladd function (integ::SAT5I{IIP, S, T})(t::Real) where {IIP, S<:AbstractArray{<:Number}, T} tnext, tprev, dt = integ.t, integ.tprev, integ.dt θ = (t - tprev)/dt @@ -566,7 +566,7 @@ end end # Interpolation function, IIP only -@inline function (integ::SAT5I{true, S, T})(u,t::Real) where {S<:AbstractArray, T} +@inline @muladd function (integ::SAT5I{true, S, T})(u,t::Real) where {S<:AbstractArray, T} tnext, tprev, dt = integ.t, integ.tprev, integ.dt θ = (t - tprev)/dt diff --git a/src/tsit5/gpuatsit5.jl b/src/tsit5/gpuatsit5.jl index cc66737..27c0a91 100644 --- a/src/tsit5/gpuatsit5.jl +++ b/src/tsit5/gpuatsit5.jl @@ -6,7 +6,7 @@ struct GPUSimpleTsit5 <: DiffEqBase.AbstractODEAlgorithm end export GPUSimpleTsit5 -function DiffEqBase.solve(prob::ODEProblem, +@muladd function DiffEqBase.solve(prob::ODEProblem, alg::GPUSimpleTsit5; dt = 0.1f0) @assert !isinplace(prob) @@ -61,7 +61,7 @@ end struct GPUSimpleATsit5 end export GPUSimpleATsit5 -function DiffEqBase.solve(prob::ODEProblem, +@muladd function DiffEqBase.solve(prob::ODEProblem, alg::GPUSimpleATsit5; dt = 0.1f0,saveat = nothing, save_everystep = true, @@ -140,7 +140,6 @@ function DiffEqBase.solve(prob::ODEProblem, @fastmath q = max(inv(qmax),min(inv(qmin),q/gamma)) qold = max(EEst,qoldinit) dtold = dt - dt = dt/q #dtnew dt = min(abs(dt),abs(tf-t-dtold)) told = t diff --git a/src/tsit5/tsit5.jl b/src/tsit5/tsit5.jl index 9f147d9..55c06f8 100644 --- a/src/tsit5/tsit5.jl +++ b/src/tsit5/tsit5.jl @@ -126,7 +126,7 @@ end # Stepping ####################################################################################### # IIP version for vectors and matrices -@inline function DiffEqBase.step!(integ::ST5I{true, S, T}) where {T, S} +@inline @muladd function DiffEqBase.step!(integ::ST5I{true, S, T}) where {T, S} L = length(integ.u) @@ -181,7 +181,7 @@ end end # OOP version for vectors and matrices -@inline function DiffEqBase.step!(integ::ST5I{false, S, T}) where {T, S} +@inline @muladd function DiffEqBase.step!(integ::ST5I{false, S, T}) where {T, S} c1, c2, c3, c4, c5, c6 = integ.cs; dt = integ.dt; t = integ.t; p = integ.p @@ -228,7 +228,7 @@ end # Interpolation ####################################################################################### # Interpolation function, OOP -function (integ::ST5I)(t::T) where {T} +@muladd function (integ::ST5I)(t::T) where {T} tnext, tprev, dt = integ.t, integ.tprev, integ.dt #@assert tprev ≤ t ≤ tnext θ = (t - tprev)/dt diff --git a/test/gpusimpleatsit5_tests.jl b/test/gpusimpleatsit5_tests.jl index 22085de..0fbe684 100644 --- a/test/gpusimpleatsit5_tests.jl +++ b/test/gpusimpleatsit5_tests.jl @@ -24,15 +24,15 @@ odeoop = ODEProblem{false}(loop, SVector{3}(u0), (0.0, 100.0), [10, 28, 8/3]) sol = solve(odeoop,SimpleATsit5() ,dt=dt) sol2 = solve(odeoop,GPUSimpleATsit5(),dt=dt,abstol=1e-6,reltol=1e-3) -@test sol.u == sol2.u -@test sol.t == sol2.t +@test sol.u[5] == sol2.u[5] +@test sol.t[5] == sol2.t[5] sol = solve(odeoop,Tsit5() ,dt=dt,saveat=0.0:0.1:100.0) sol2 = solve(odeoop,GPUSimpleATsit5(),dt=dt,saveat=0.0:0.1:100.0,abstol=1e-6,reltol=1e-3) sol3 = solve(odeoop,SimpleATsit5() ,dt=dt,saveat=0.0:0.1:100.0) @test sol[20] ≈ sol2[20] atol=1e-5 -@test sol2.u == sol3.u +@test sol2.u[20] ≈ sol3.u[20] @test sol.t == sol2.t dt = 1e-1