diff --git a/src/SimpleDiffEq.jl b/src/SimpleDiffEq.jl index 6299574..be4af2a 100644 --- a/src/SimpleDiffEq.jl +++ b/src/SimpleDiffEq.jl @@ -6,6 +6,7 @@ using Reexport, MuladdMacro @reexport using DiffEqBase using StaticArrays using RecursiveArrayTools +const ^ = DiffEqBase.fastpow @inline _copy(a::SArray) = a @inline _copy(a) = copy(a) diff --git a/src/tsit5/atsit5.jl b/src/tsit5/atsit5.jl index f3b61c5..33e5c57 100644 --- a/src/tsit5/atsit5.jl +++ b/src/tsit5/atsit5.jl @@ -52,10 +52,10 @@ function DiffEqBase.__init(prob::ODEProblem,alg::SimpleATsit5; internalnorm) end -function DiffEqBase.__solve(prob::ODEProblem,alg::SimpleATsit5; +function DiffEqBase.__solve(prob::ODEProblem,alg::SimpleATsit5,args...; dt = 0.1, saveat = nothing, save_everystep = true, abstol = 1e-6, reltol = 1e-3, - internalnorm = DiffEqBase.ODE_DEFAULT_NORM) + internalnorm = DiffEqBase.ODE_DEFAULT_NORM,kwargs...) u0 = prob.u0 tspan = prob.tspan ts = Vector{eltype(dt)}(undef,1) @@ -78,6 +78,7 @@ function DiffEqBase.__solve(prob::ODEProblem,alg::SimpleATsit5; integ = simpleatsit5_init(prob.f,DiffEqBase.isinplace(prob),prob.u0, tspan[1], tspan[2], dt, prob.p, abstol, reltol, internalnorm) # FSAL + cur_t = 1 while integ.t < tspan[2] step!(integ) if saveat === nothing && save_everystep @@ -96,8 +97,8 @@ function DiffEqBase.__solve(prob::ODEProblem,alg::SimpleATsit5; end if saveat === nothing && !save_everystep - push!(us,recursivecopy(u)) - push!(ts,t) + push!(us,recursivecopy(integ.u)) + push!(ts,integ.t) end sol = DiffEqBase.build_solution(prob,alg,ts,us,