Use DifferentialEquations.jl as ODE solver #92

Closed
bastikr opened this issue Mar 24, 2017 · 12 comments

@bastikr
Member

bastikr commented Mar 24, 2017

No description provided.

@ChrisRackauckas
Contributor

Hey, what problem are you looking to benchmark/profile? I would be willing to help out, and I'm looking to profile the v0.5 --> v0.6 changes anyway. I know there are a few regressions due to the internal switch to broadcasting and to @. removing muladds, but OrdinaryDiffEq.jl should still be in very good shape (and all of the issues I've found have been traced to Base compiler optimizations suddenly going missing... not much we can do there but wait...)

Instead of taking on the full dependency, you can use the minimal-dependency packages. If the user has to pass the algorithm to solve, you can depend directly on DiffEqBase only, which is a very lightweight library (in fact, you already depend on it through ODE). If you want to choose the algorithm yourself, for example DP5() or Tsit5(), you will need to depend on OrdinaryDiffEq.jl.
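A minimal sketch of why passing the algorithm keeps the dependency small, assuming plain Julia with no packages: an interface layer defines only an abstract type and a generic function, the library code calls that function with whatever algorithm the user hands it, and a separate layer adds a concrete solver. All names here (`AbstractAlg`, `integrate`, `evolve`, `FixedEuler`) are hypothetical stand-ins for DiffEqBase, QuantumOptics and OrdinaryDiffEq, not their real APIs.

```julia
# Interface layer (plays the role of DiffEqBase): an abstract type and a
# generic function with no solver methods of its own.
abstract type AbstractAlg end
function integrate end

# Library layer (plays the role of QuantumOptics): depends only on the
# interface, because the caller supplies the algorithm.
evolve(f!, u0, tspan, alg::AbstractAlg) = integrate(alg, f!, copy(u0), tspan)

# Algorithm layer (plays the role of OrdinaryDiffEq): one concrete solver.
struct FixedEuler <: AbstractAlg
    dt::Float64
end

function integrate(alg::FixedEuler, f!, u, tspan)
    t, tf = tspan
    du = similar(u)
    while t < tf
        h = min(alg.dt, tf - t)  # never step past the final time
        f!(t, u, du)             # (t, u, du) argument order used in this thread
        u .+= h .* du            # explicit Euler update, in place
        t += h
    end
    return u
end
```

Usage: with `decay!(t, u, du) = (du .= -u; nothing)`, calling `evolve(decay!, [1.0], (0.0, 1.0), FixedEuler(1e-3))` returns a value close to `exp(-1)`. The library function `evolve` never names a concrete solver, so only the caller needs the package that defines one.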

Let me know what you need and I'd be willing to profile and PR.

@bastikr
Member Author

bastikr commented Jun 6, 2017

Hey Chris, thank you for offering your help! I saw that you already fixed a few things in the profiling directory, but I used this code only while I developed the corresponding functionality and most of it is terribly out of date. All the real benchmarking is currently done in https://github.com/qojulia/QuantumOptics.jl-benchmarks, and the results are displayed on our website (which is not completely up to date since we are doing some restructuring). Most functionality is pretty well covered; only the (stochastic) mcwf solver is missing.

The dependence on the ODE package is also a little bit out of date and was mostly used to test that our own ode solver worked correctly.

It's great to hear that there are versions that don't have as many dependencies as the complete DifferentialEquations.jl package. Transitioning the non-stochastic methods should be pretty straightforward, since the actual ODE solver is called at only one point in the code. The Monte Carlo solver will probably be a little more work but also shouldn't be too hard. My plan is to release version 0.4 in the next two weeks and start working on this immediately afterwards. As soon as I have first benchmark results I will share them here. Of course, if anyone else is motivated to work on this, please go ahead; help is always welcome!

@ChrisRackauckas
Contributor

the actual ode solver is only called at one single point in the code

Can you point me to that spot?

@bastikr
Member Author

bastikr commented Jun 6, 2017

@ChrisRackauckas
Contributor

Using test_dopri.jl, I did:

using Base.Test
using BenchmarkTools
using QuantumOptics
using OrdinaryDiffEq   # ODEProblem, solve, Tsit5, DP5

ode_dopri = QuantumOptics.ode_dopri
ode_event = ode_dopri.ode_event
ode = ode_dopri.ode

ζ = 0.5
ω₀ = 10.0

y₀ = Float64[0., sqrt(1-ζ^2)*ω₀]
A = 1
ϕ = 0

# Analytic solution [x, p] of the damped oscillator, for comparison
function f(t::Float64)
    α = sqrt(1-ζ^2)*ω₀
    x = A*exp(-ζ*ω₀*t)*sin(α*t + ϕ)
    p = A*exp(-ζ*ω₀*t)*(-ζ*ω₀*sin(α*t + ϕ) + α*cos(α*t + ϕ))
    return [x,p]
end

# In-place right-hand side with (t, y, dy) argument order,
# used unchanged by both ode_dopri and ODEProblem below
function df(t::Float64, y::Vector{Float64}, dy::Vector{Float64})
    dy[1] = y[2]
    dy[2] = -2*ζ*ω₀*y[2] - ω₀^2*y[1]
    return nothing
end

T = [0.,10.]
# Run once for the test, then benchmark the call itself
tout, yout = ode(df, T, y₀; display_intermediatesteps=true)
@test length(tout)>2
maxstep = maximum(tout[2:end]-tout[1:end-1])
@benchmark ode(df, T, y₀; display_intermediatesteps=true)

prob = ODEProblem(df,y₀,(T[1],T[2]))
@benchmark solve(prob,Tsit5(),dense=false)
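For reference, reading the code above: df encodes the underdamped harmonic oscillator with y = (x, p), p = dx/dt, and f is its closed-form solution. With A = 1 and ϕ = 0 this gives y(0) = (0, α), which matches y₀. This is a worked restatement of the code, not part of the original test file.

```latex
\ddot{x} + 2\zeta\omega_0\,\dot{x} + \omega_0^2\,x = 0,
\qquad \alpha = \omega_0\sqrt{1-\zeta^2}

x(t) = A\,e^{-\zeta\omega_0 t}\sin(\alpha t + \phi),
\qquad
p(t) = \dot{x}(t) = A\,e^{-\zeta\omega_0 t}\bigl(-\zeta\omega_0\sin(\alpha t + \phi) + \alpha\cos(\alpha t + \phi)\bigr)
```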

and got for ode_dopri:

BenchmarkTools.Trial: 
  memory estimate:  191.59 KiB
  allocs estimate:  8266
  --------------
  minimum time:     196.164 μs (0.00% GC)
  median time:      210.089 μs (0.00% GC)
  mean time:        236.384 μs (9.92% GC)
  maximum time:     3.226 ms (92.46% GC)
  --------------
  samples:          10000
  evals/sample:     1

and for Tsit5():

BenchmarkTools.Trial: 
  memory estimate:  78.37 KiB
  allocs estimate:  3694
  --------------
  minimum time:     148.435 μs (0.00% GC)
  median time:      182.102 μs (0.00% GC)
  mean time:        203.711 μs (8.00% GC)
  maximum time:     7.028 ms (93.66% GC)
  --------------
  samples:          10000
  evals/sample:     1

then using @benchmark solve(prob,DP5(),dense=false), I get:

BenchmarkTools.Trial: 
  memory estimate:  74.58 KiB
  allocs estimate:  3688
  --------------
  minimum time:     127.853 μs (0.00% GC)
  median time:      141.000 μs (0.00% GC)
  mean time:        154.853 μs (8.48% GC)
  maximum time:     4.371 ms (94.67% GC)
  --------------
  samples:          10000
  evals/sample:     1

so it looks like DP5() does well here.

That's on v0.5.2. However, it looks like DiffEq has a very disgusting regression on v0.6:

@benchmark solve(prob,Tsit5(),dense=false)
BenchmarkTools.Trial: 
  memory estimate:  854.32 KiB
  allocs estimate:  37975
  --------------
  minimum time:     3.283 ms (0.00% GC)
  median time:      6.393 ms (0.00% GC)
  mean time:        6.645 ms (2.45% GC)
  maximum time:     57.946 ms (0.00% GC)
  --------------
  samples:          752
  evals/sample:     1

I'm going to take a look at that right now.

@ChrisRackauckas
Contributor

Yes, the regression on Julia v0.6 is due to a problem in Base with broadcasting:

JuliaLang/julia#22255 (comment)

Avoiding broadcasting, Tsit5() gives on v0.6:

BenchmarkTools.Trial:
  memory estimate:  87.43 KiB
  allocs estimate:  3757
  --------------
  minimum time:     281.032 μs (0.00% GC)
  median time:      526.934 μs (0.00% GC)
  mean time:        475.180 μs (2.88% GC)
  maximum time:     5.601 ms (87.13% GC)
  --------------
  samples:          10000
  evals/sample:     1

while ode_dopri on the same computer on v0.6 gives:

BenchmarkTools.Trial:
  memory estimate:  136.09 KiB
  allocs estimate:  7742
  --------------
  minimum time:     411.302 μs (0.00% GC)
  median time:      787.767 μs (0.00% GC)
  mean time:        862.358 μs (2.11% GC)
  maximum time:     28.656 ms (0.00% GC)
  --------------
  samples:          5736
  evals/sample:     1

So yes, disregarding Base bugs (which will be worked around), I would expect you to get a modest speedup of less than 2x by switching.
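As a quick check of the "<2x" estimate, the median times reported above can be turned into speedup ratios. This sketch only re-uses the numbers printed in this thread and compares within one Julia version; it yields roughly 1.15x to 1.5x.

```julia
# Median times in μs, copied from the BenchmarkTools output above.
medians = Dict(
    "v0.5 ode_dopri" => 210.089,
    "v0.5 Tsit5"     => 182.102,
    "v0.5 DP5"       => 141.000,
    "v0.6 ode_dopri" => 787.767,
    "v0.6 Tsit5"     => 526.934,  # with broadcasting avoided
)

# Speedup of `new` over `base`: ratio of median run times.
speedup(base, new) = medians[base] / medians[new]

pairs_to_compare = [("v0.5 ode_dopri", "v0.5 Tsit5"),
                    ("v0.5 ode_dopri", "v0.5 DP5"),
                    ("v0.6 ode_dopri", "v0.6 Tsit5")]

for (base, new) in pairs_to_compare
    println(rpad("$base -> $new", 32), round(speedup(base, new); digits = 2), "x")
end
```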

@bastikr
Member Author

bastikr commented Jun 7, 2017

I started the transition to DifferentialEquations.jl in the branch https://github.com/bastikr/QuantumOptics.jl/tree/diffeq. The first benchmarks on more realistic examples already look very promising:
[Figure: benchmark_diffeq. Green line: DifferentialEquations.jl; red line: ode_dopri.]
The other examples yield similar results. However, they don't yet measure exactly the same thing. For now I want to keep the same interface as before, which means I want to save output only at the points in time specified in tspan. Also, the output should not be stored in the solution object; instead, only the _fout function should be called. Is this possible using the callback functionality?

@ChrisRackauckas
Contributor

At the moment I want to keep the same interface as before, which means I want to save output only at points in time specified in tspan.

For explicit saving, pass the save times via saveat: sol = solve(prob,alg;saveat=ts). tspan only sets the start and end times.

Is this somehow possible using the callback functionality?

Yes, just make the callback call _fout.

@bastikr
Member Author

bastikr commented Jun 7, 2017

Is it somehow possible to call the function only at times that are given to saveat? The DiscreteCallback is called at every single ode step and I guess using ContinuousCallback would be less efficient.

@ChrisRackauckas
Contributor

Is it somehow possible to call the function only at times that are given to saveat? The DiscreteCallback is called at every single ode step and I guess using ContinuousCallback would be less efficient.

Yes, a DiscreteCallback is the tool for the job, and you can take over the entire saving procedure. Set condition(t,u,integrator) = true so that it applies after every step, and save_values = (false,false) to turn off the default saving. Then you have full control of saving and can call it directly:

function affect!(integrator)
  savevalues!(integrator)
end
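A minimal sketch of the mechanism described here, assuming plain Julia with toy types: after every step the loop checks `condition(t, u, integrator)` and, when it returns true, runs `affect!(integrator)`; nothing is saved by default, so whatever `affect!` records is the whole output. `ToyIntegrator`, `ToyDiscreteCallback` and `toy_solve!` are hypothetical and only mimic the shape of the DiffEq callback interface, with explicit Euler standing in for a real stepper.

```julia
# Hypothetical stand-ins, not the DiffEq API.
mutable struct ToyIntegrator
    t::Float64
    u::Vector{Float64}
    dt::Float64
end

struct ToyDiscreteCallback{C,A}
    condition::C   # condition(t, u, integrator) -> Bool
    affect!::A     # affect!(integrator)
end

function toy_solve!(integ::ToyIntegrator, f!, tf, cb::ToyDiscreteCallback)
    du = similar(integ.u)
    while integ.t < tf
        h = min(integ.dt, tf - integ.t)
        f!(integ.t, integ.u, du)
        integ.u .+= h .* du            # explicit Euler step, in place
        integ.t += h
        # No default saving: the callback alone decides what gets recorded.
        if cb.condition(integ.t, integ.u, integ)
            cb.affect!(integ)
        end
    end
    return integ
end
```

Usage: with `condition = (t, u, integ) -> true` and an `affect!` that pushes `(integ.t, copy(integ.u))` into a vector, one entry is recorded per step, which is why such a callback has to filter for the saveat times itself.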

With that, you're essentially doing exactly what OrdinaryDiffEq.jl is doing internally. In OrdinaryDiffEq.jl, this is implemented here:

https://github.com/JuliaDiffEq/OrdinaryDiffEq.jl/blob/master/src/integrators/integrator_utils.jl#L71

There are many complications there that probably don't apply to your use case, but the first part is how saveat is done:

https://github.com/JuliaDiffEq/OrdinaryDiffEq.jl/blob/master/src/integrators/integrator_utils.jl#L72

  while !isempty(integrator.opts.saveat) && integrator.tdir*top(integrator.opts.saveat) <= integrator.tdir*integrator.t # Perform saveat
    integrator.saveiter += 1
    curt = pop!(integrator.opts.saveat)
    if curt!=integrator.t # If <t, interpolate
      ode_addsteps!(integrator)
      Θ = (curt - integrator.tprev)/integrator.dt
      val = ode_interpolant(Θ,integrator,integrator.opts.save_idxs,Val{0}) # out of place, but no force copy later
      copyat_or_push!(integrator.sol.t,integrator.saveiter,curt)
      save_val = val
      copyat_or_push!(integrator.sol.u,integrator.saveiter,save_val,Val{false})
      if typeof(integrator.alg) <: OrdinaryDiffEqCompositeAlgorithm
        copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
      end
    else # ==t, just save
      copyat_or_push!(integrator.sol.t,integrator.saveiter,integrator.t)
      if integrator.opts.save_idxs ==nothing
        copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u)
      else
        copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u[integrator.opts.save_idxs],Val{false})
      end
      if typeof(integrator.alg) <: Discrete || integrator.opts.dense
        integrator.saveiter_dense +=1
        copyat_or_push!(integrator.notsaveat_idxs,integrator.saveiter_dense,integrator.saveiter)
        if integrator.opts.dense
          if integrator.opts.save_idxs ==nothing
            copyat_or_push!(integrator.sol.k,integrator.saveiter_dense,integrator.k)
          else
            copyat_or_push!(integrator.sol.k,integrator.saveiter_dense,[k[integrator.opts.save_idxs] for k in integrator.k],Val{false})
          end
        end
      end
      if typeof(integrator.alg) <: OrdinaryDiffEqCompositeAlgorithm
        copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
      end
    end
  end

You can see that, as Sundials does internally, the integrator steps over saveat points and then interpolates back to them (if you instead want to force the integrator to hit those time points, say because they are discontinuities, use tstops=ts instead). Instead of directly pushing the interpolated value:

      copyat_or_push!(integrator.sol.u,integrator.saveiter,save_val,Val{false})
      if typeof(integrator.alg) <: OrdinaryDiffEqCompositeAlgorithm
        copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
      end

a function call can be put in there to modify save_val before it's saved.
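A simplified sketch of that saveat loop, assuming forward integration only (no tdir) and linear interpolation between the previous and current state in place of the solver's own `ode_interpolant`. The quoted code reads saveat with `top`/`pop!`; a sorted `Vector` with `popfirst!` stands in here. `save_crossed!` and its `transform` keyword are hypothetical; `transform` marks where a function call could modify `save_val` before it is stored.

```julia
# Called after one step from `tprev` to `t`; records every saveat time crossed.
function save_crossed!(ts_out, us_out, saveat::Vector{Float64},
                       tprev, uprev, t, u; transform = identity)
    dt = t - tprev
    # `saveat` is sorted ascending, so the next time to save is saveat[1].
    while !isempty(saveat) && saveat[1] <= t
        curt = popfirst!(saveat)
        if curt != t                  # strictly inside the step: interpolate back
            Θ = (curt - tprev) / dt
            val = (1 - Θ) .* uprev .+ Θ .* u
        else                          # landed exactly on a save point
            val = copy(u)
        end
        push!(ts_out, curt)
        push!(us_out, transform(val)) # hook: modify the value before saving
    end
    return nothing
end
```

Usage: for steps 0.0 -> 0.4 -> 1.0 with `u = [t]` and `saveat = [0.25, 0.5, 1.0]`, the first call records 0.25 (interpolated), the second records 0.5 (interpolated) and 1.0 (exact), and `saveat` ends up empty.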

What's the interface for _fout? I can just do a general implementation for the callback in DiffEqCallbacks.jl which takes in an _fout and have it match what you're looking for. I actually have an open issue for that:

SciML/DiffEqCallbacks.jl#4

@ChrisRackauckas
Contributor

ChrisRackauckas commented Jan 9, 2018

We have a callback for this now in the callback library.

http://docs.juliadiffeq.org/latest/features/callback_library.html#SavingCallback-1

If you pass it to solve as a callback, it will save the output of your save_func (what you call _fout above). So let's revive this!

@david-pl
Member

Done with #191
