progress bars for EnsembleProblem #514

Merged
merged 3 commits into from
Nov 9, 2023

Conversation

pepijndevos

This is the first part of a series of PRs that adds progress support to EnsembleProblem.

If progress is enabled, this part initializes a progress bar for every trajectory and gives every solve a unique ID.

For very large ensembles the expectation is that the logger shows this in a sensible way, for example, show the overall progress:

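using Logging, LoggingExtras, TerminalLoggers

# Collapse all per-trajectory progress records into a single "Total" bar.
sum_logger = let progress = Dict{Symbol, Float64}()   # latest fraction per progress id
    TransformerLogger(TerminalLogger()) do log
        # Progress records are emitted at LogLevel(-1) with a `progress` keyword.
        if log.level == LogLevel(-1) && haskey(log.kwargs, :progress)
            pr = log.kwargs[:progress]
            if pr isa Number
                progress[log.id] = pr
            elseif pr == "done"
                progress[log.id] = 1.0
            end
            # Overall progress is the mean over every trajectory seen so far.
            tot = sum(values(progress))/length(progress)
            if tot>=1.0
                tot="done"
                empty!(progress)
            end
            merge(log, (;id=:total, message="Total", kwargs=Dict(:progress=>tot)))
        else
            log
        end
    end
end
global_logger(sum_logger)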

@devmotion
Member

For very large ensembles the expectation is that the logger shows this in a sensible way, for example, show the overall progress:

Last time I checked, none of the common loggers (the default logger in VSCode and TerminalLoggers) summarized large numbers of progress bars. Everything would be displayed, which led to a terrible user experience IMO.

@pepijndevos
Author

Yea that's a fair criticism. I did it like this for a few reasons:

  • The alternative would be some complicated thing with channels
  • For small ensembles the current default is reasonable, and gives more insight than a single progress bar
  • It's fairly easy to aggregate the logs if desired, as the snippet above demonstrates

Maybe it'd be a good idea to contribute this log aggregation somewhere.
We could even have that as the default where in __solve we just do with_logger(sum_logger(current_logger()))
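
A minimal sketch of what such a composable wrapper might look like, written against the standard `Logging` library only so that it has no extra dependencies. The type `SumProgressLogger` and the helper `aggregate!` are hypothetical names chosen for illustration; they are not part of SciMLBase or LoggingExtras, and the sketch ignores thread safety.

```julia
using Logging

# Hypothetical wrapper: forwards ordinary records to `parent` unchanged and
# replaces per-trajectory progress records with one averaged "Total" record.
struct SumProgressLogger{L<:AbstractLogger} <: AbstractLogger
    parent::L
    progress::Dict{Any, Float64}   # latest fraction per progress id
end
SumProgressLogger(parent::AbstractLogger) = SumProgressLogger(parent, Dict{Any, Float64}())

# Filtering decisions are delegated to the wrapped logger.
Logging.min_enabled_level(l::SumProgressLogger) = Logging.min_enabled_level(l.parent)
Logging.shouldlog(l::SumProgressLogger, args...) = Logging.shouldlog(l.parent, args...)
Logging.catch_exceptions(l::SumProgressLogger) = Logging.catch_exceptions(l.parent)

# Fold one progress value into the table and return the aggregate to report.
function aggregate!(progress::Dict, id, pr)
    if pr isa Number
        progress[id] = pr
    elseif pr == "done"
        progress[id] = 1.0
    else
        return pr                  # pass through anything we do not understand
    end
    tot = sum(values(progress)) / length(progress)
    if tot >= 1.0
        empty!(progress)           # reset so the next ensemble starts from zero
        return "done"
    end
    return tot
end

function Logging.handle_message(l::SumProgressLogger, level, message, _module, group, id,
                                file, line; kwargs...)
    if level == LogLevel(-1) && haskey(kwargs, :progress)
        tot = aggregate!(l.progress, id, kwargs[:progress])
        Logging.handle_message(l.parent, level, "Total", _module, group, :total,
                               file, line; progress = tot)
    else
        Logging.handle_message(l.parent, level, message, _module, group, id,
                               file, line; kwargs...)
    end
end
```

Under these assumptions the call site would read `with_logger(SumProgressLogger(current_logger())) do ... end`, so the aggregation wraps whatever logger the user already has installed.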

@devmotion
Member

I'd like to have better progress bars for ensemble simulations, not only for SciML but also in Turing, where we currently use a single progress bar for multi-chain sampling that is updated only when a full chain has been sampled, which doesn't lead to a good user experience either. But I think it's crucial to first add better support for such summaries, ideally not only in SciML but more generally in the most common loggers such as the one in VSCode or TerminalLoggers (see e.g. julia-vscode/julia-vscode#3297, and possibly also something like julia-vscode/julia-vscode#3317 and JuliaLogging/ProgressLogging.jl#23).

@oscardssmith
Contributor

A possible alternative approach here would be to make the ensembleprob's progress bar be purely based on the number of solutions in the ensemble that have finished.

@devmotion
Member

Yeah that's what we do in Turing currently. But the user experience is a bit suboptimal, in particular if you only sample a few chains in parallel - then it can happen that the progress bar is at 0 all the time and jumps to 1 when all chains are sampled at approx the same time.
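
The completion-count approach discussed here can be sketched as follows. This is an illustration under stated assumptions, not the Turing or SciMLBase code: `completed_fraction_progress` and `run_trajectory` are hypothetical names, and `report` stands in for whatever emits the progress record.

```julia
using Base.Threads

# Report ensemble progress as (finished trajectories) / (total trajectories).
# `run_trajectory` is a hypothetical stand-in for solving one problem.
function completed_fraction_progress(run_trajectory, n; report = identity)
    finished = Atomic{Int}(0)
    results = Vector{Any}(undef, n)
    @threads for i in 1:n
        results[i] = run_trajectory(i)
        done = atomic_add!(finished, 1) + 1   # atomic_add! returns the old value
        report(done / n)                      # called once per finished trajectory
    end
    return results
end
```

With only a handful of trajectories running in parallel, `report` is called just a few times and all of those calls may arrive at nearly the same moment, which is the coarse 0-then-1 behaviour described above.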

@codecov

codecov bot commented Sep 29, 2023

Codecov Report

Merging #514 (a1370a0) into master (06d5c2c) will decrease coverage by 0.31%.
The diff coverage is 34.24%.

@@            Coverage Diff             @@
##           master     #514      +/-   ##
==========================================
- Coverage   42.19%   41.88%   -0.31%     
==========================================
  Files          53       53              
  Lines        4072     4121      +49     
==========================================
+ Hits         1718     1726       +8     
- Misses       2354     2395      +41     
Files Coverage Δ
src/ensemble/basic_ensemble_solve.jl 51.87% <34.24%> (-15.70%) ⬇️


@pepijndevos
Author

I proposed for Cedar that we'd do the progress of completed simulations, but indeed the granularity is quite low if you have a lot of cores or not that many problems. So IMO the correct solution is for the loggers to handle this better, or to aggregate the results internally like I suggested.

@pepijndevos
Author

pepijndevos commented Oct 2, 2023

I've added an internal aggregation step that is enabled by default. This makes sure the current UX isn't horrible, and then we can take it from there.

image

@pepijndevos
Author

It turns out that if you have many cores and you set progress_steps too low it can cause quite a lot of lock contention, so I made the aggregator use trylock to mitigate that.
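
A minimal sketch of the `trylock` idea, assuming the aggregator stores the latest fraction per trajectory so that a dropped intermediate update is simply overwritten by the next one. `ProgressTable` and `report!` are hypothetical names, not the identifiers used in the PR.

```julia
# Hypothetical shared state: the latest progress fraction per trajectory.
struct ProgressTable
    lock::ReentrantLock
    latest::Vector{Float64}
end
ProgressTable(n::Integer) = ProgressTable(ReentrantLock(), zeros(n))

# Store `fraction` for trajectory `i`. Intermediate updates use `trylock` and
# are dropped when the lock is busy; critical updates (such as the final one)
# wait for the lock. Returns `true` if the value was stored.
function report!(tbl::ProgressTable, i, fraction; critical = false)
    if critical
        lock(tbl.lock)        # must not be lost, so block until available
    elseif !trylock(tbl.lock)
        return false          # busy: skip this non-critical update
    end
    try
        tbl.latest[i] = fraction
    finally
        unlock(tbl.lock)
    end
    return true
end
```

The trade-off is that the displayed progress can lag slightly behind the true state, in exchange for solver threads never waiting on the lock for routine updates.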

@pepijndevos
Author

FTR we're still investigating how to reduce overhead from progress logging, so some changes are still expected.

@pepijndevos
Author

Alright, performance overhead is now negligible using several strategies:

  • use trylock to avoid contention on non-critical updates
  • use time to limit updates to ten per second to limit downstream logging overhead
  • use math to avoid summing 20k items
  • teach Sundials to respect progress_steps
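
Two of these strategies, the time-based rate limit and the incremental total, can be sketched together. This is an illustration under assumptions rather than the code in the PR: `RunningProgress` and `update!` are hypothetical names, the 0.1 s default mirrors the "ten per second" figure above, and locking is left out for brevity.

```julia
mutable struct RunningProgress
    latest::Vector{Float64}  # last reported fraction per trajectory
    total::Float64           # running sum of `latest`, updated incrementally
    last_emit::Float64       # time of the last emitted message, in seconds
    min_interval::Float64    # minimum spacing between messages, in seconds
end
RunningProgress(n; min_interval = 0.1) = RunningProgress(zeros(n), 0.0, -Inf, min_interval)

# Record a new fraction for trajectory `i`. Returns the overall fraction when a
# message should be emitted and `nothing` when the update is throttled.
function update!(rp::RunningProgress, i, fraction; now = time())
    rp.total += fraction - rp.latest[i]   # O(1) instead of re-summing everything
    rp.latest[i] = fraction
    overall = rp.total / length(rp.latest)
    # Incremental sums can drift by rounding, so a production version might
    # count completed trajectories separately to detect the end reliably.
    finished = overall >= 1.0
    if finished || now - rp.last_emit >= rp.min_interval
        rp.last_emit = now
        return overall
    end
    return nothing
end
```

Passing `now` explicitly is only a convenience for testing; in normal use the default `time()` is taken at each call.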

Profile:
image

I've also added the requested weakdep on https://github.com/SciML/DiffEqBase.jl/releases/tag/v6.132.0

@ChrisRackauckas
Member

As long as https://github.com/SciML/DiffEqBase.jl/blob/master/test/downstream/inference.jl passes this should generally be good.

@pepijndevos
Author

I think that particular test seems to pass. Which of the other failures out of 70 test runs are my fault is hard for me to tell.

@ChrisRackauckas
Member

I rebased. We'll see how tests go.

@ChrisRackauckas
Member

It looks like a backwards compat dispatch is missing: https://github.com/SciML/SciMLBase.jl/actions/runs/6479781336/job/17731842462?pr=512

@pepijndevos
Author

Wait I'm confused what's going on here. I think my confusion cancelled out and maybe I did the right thing? Your comment seems to be about the stats PR, so I pushed an update there.

45efdc5

Is there anything that needs doing for this PR?

@ChrisRackauckas
Member

Can you show it on something simple like the Lorenz equation with SimpleTsit5?

@pepijndevos
Author

pepijndevos commented Oct 30, 2023

I'm looking at the docs and all I can find is how to define a rrule? https://fluxml.ai/Zygote.jl/latest/limitations/#Solutions-1
Feels like this should go in ChainRules then?
Or do you mean something else by

just set ignore derivatives needs to be set on the logging calls so it ignores everything in the function.

Ah! https://juliadiff.org/ChainRulesCore.jl/stable/api.html#ChainRulesCore.@ignore_derivatives
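
For context, a small sketch of how `@ignore_derivatives` from ChainRulesCore can wrap a logging call so that AD backends built on ChainRules do not try to differentiate through it. The function `loss_with_logging` is a hypothetical example and is not taken from the PR.

```julia
using ChainRulesCore: @ignore_derivatives

# Hypothetical function that logs an intermediate value while computing a result.
function loss_with_logging(x)
    y = x^2
    # The log call does not affect the result, so it is marked as
    # non-differentiable and AD can skip it entirely.
    @ignore_derivatives @info "intermediate value" y
    return y
end
```

Without an AD backend loaded the macro just runs the wrapped expression, so the function behaves like the plain version.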

@pepijndevos
Author

pepijndevos commented Nov 2, 2023

Keeping track of all the PRs:

Well that's not what I was expecting.. @ChrisRackauckas what did you have in mind here? I think a simple solution would be to only pass down progress_id if progress=true so problems without progress support don't have to care about the kwargs
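
The suggestion of passing `progress_id` only when progress is enabled can be sketched like this. All names here (`solve_new`, `solve_old`, `solve_trajectory`, and the `trajectory_` prefix) are hypothetical stand-ins, not the SciMLBase internals.

```julia
# Hypothetical stand-ins: one solver that understands `progress_id` and one
# that predates it and would throw a MethodError if it were passed.
solve_new(prob; progress = false, progress_id = nothing) = (prob, progress, progress_id)
solve_old(prob; progress = false) = (prob, progress)

# Only add `progress_id` to the keyword arguments when progress is enabled, so
# solvers without progress support never see the extra keyword.
function solve_trajectory(solver, prob, i; progress = false)
    extra = progress ? (; progress_id = Symbol("trajectory_", i)) : NamedTuple()
    return solver(prob; progress = progress, extra...)
end
```

With this shape, `solve_trajectory(solve_old, prob, i)` keeps working as long as the user leaves `progress` at its default.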

@pepijndevos
Author

FTR I'm getting weird precompile errors about the adjoint so maybe that's still not quite right?

┌ DiffEqBase → DiffEqBaseZygoteExt
│  WARNING: Method definition adjoint(ZygoteRules.AContext, typeof(ZygoteRules.literal_getproperty), SciMLBase.EnsembleSolution{T, N, S} where S where N where T, Base.Val{:u}) in module DiffEqBase overwritten in module SciMLBaseZygoteExt.
│  ┌ Error: Error during loading of extension SciMLBaseZygoteExt of SciMLBase, use `Base.retry_load_extensions()` to retry.
│  │   exception =
│  │    1-element ExceptionStack:
│  │    Method overwriting is not permitted during Module precompile.

@pepijndevos
Author

I went back to the alternative logging workaround and did some more import fixes that were hidden by my repl state I guess. This passes the downstream AD tests for me, but I think it requires a new ChainRules release before it'll work on CI.

@oxinabox

oxinabox commented Nov 7, 2023

ok new chainrules release is out, so restarting CI and it should pass

@oscardssmith
Contributor

this pr needs to bump the chainrules version requirement

@pepijndevos
Author

Last time Chris asked for a version requirement I added it and then he removed it again so idk

@ChrisRackauckas
Member

I don't think this PR ever had a version requirement on ChainRules/ChainRulesCore?

@pepijndevos
Author

No, I added one for DiffEqBase because I thought you wanted that to guarantee progress_id is supported, but then you removed it again. I'm fine adding a version constraint if it's clear that that's actually what you want.

@ChrisRackauckas
Member

I don't see why that's related. @oscardssmith is saying you need a version bound on ChainRules/Core since you're implicitly relying on @oxinabox's latest release for the Zygote fix on logging.

@pepijndevos
Author

Done. Just wanted to make sure.

I'm still getting that weird DiffEqBaseZygoteExt precompile error btw

@pepijndevos
Author

Here are your performance numbers btw

using OrdinaryDiffEq, BenchmarkTools
function lorenz(u, p, t)
    σ = p[1]
    ρ = p[2]
    β = p[3]
    du1 = σ * (u[2] - u[1])
    du2 = u[1] * (ρ - u[3]) - u[2]
    du3 = u[1] * u[2] - β * u[3]
    return [du1, du2, du3]
end

u0 = [1.0f0; 0.0f0; 0.0f0]
tspan = (0.0f0, 10.0f0)
p = [10.0f0, 28.0f0, 8 / 3.0f0]
prob = ODEProblem{false}(lorenz, u0, tspan, p)

prob_func = (prob, i, repeat) -> remake(prob, p = rand(Float32, 3) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)
@benchmark sol = solve(monteprob, Tsit5(), EnsembleThreads(), trajectories = 10_000, saveat = 1.0f0, progress=false)
@benchmark sol = solve(monteprob, Tsit5(), EnsembleThreads(), trajectories = 10_000, saveat = 1.0f0, progress=true)
BenchmarkTools.Trial: 12 samples with 1 evaluation.
 Range (min … max):  186.853 ms … 720.231 ms  ┊ GC (min … max):  0.00% … 73.95%
 Time  (median):     500.877 ms               ┊ GC (median):    62.15%
 Time  (mean ± σ):   446.853 ms ± 206.921 ms  ┊ GC (mean ± σ):  57.52% ± 33.01%

  ██                       ▁         █ ▁    ▁             ▁ ▁ ▁  
  ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁█▁█▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁█▁█▁█ ▁
  187 ms           Histogram: frequency by time          720 ms <

 Memory estimate: 3.40 GiB, allocs estimate: 45697823.

BenchmarkTools.Trial: 10 samples with 1 evaluation.
 Range (min … max):  265.698 ms … 743.625 ms  ┊ GC (min … max):  0.00% … 64.56%
 Time  (median):     558.434 ms               ┊ GC (median):    52.23%
 Time  (mean ± σ):   516.208 ms ± 154.807 ms  ┊ GC (mean ± σ):  49.51% ± 23.63%

  █                    ▁         ▁    ▁▁▁   ▁      ▁          ▁  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁█▁▁▁▁███▁▁▁█▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁█ ▁
  266 ms           Histogram: frequency by time          744 ms <

 Memory estimate: 3.55 GiB, allocs estimate: 47195479.

Logging output of 10k simulation runs after all these optimizations is just a handful of lines

┌ LogLevel(-1): Total
│   progress = 0.0
└ @ SciMLBase ~/code/SciMLBase.jl/src/ensemble/basic_ensemble_solve.jl:134
┌ LogLevel(-1): Total
│   message = "dt=0.07805729\nt=10.0\nmax u=25.905659"
│   progress = 0.211799999999993
└ @ OrdinaryDiffEq ~/code/OrdinaryDiffEq.jl/src/integrators/integrator_utils.jl:153
┌ LogLevel(-1): Total
│   progress = 0.756299999999933
└ @ OrdinaryDiffEq ~/code/OrdinaryDiffEq.jl/src/solve.jl:103
┌ LogLevel(-1): Total
│   message = "dt=0.046221733\nt=10.0\nmax u=23.776852"
│   progress = "done"
└ @ OrdinaryDiffEq ~/code/OrdinaryDiffEq.jl/src/integrators/integrator_utils.jl:153

@pepijndevos
Author

For reference, on master

BenchmarkTools.Trial: 12 samples with 1 evaluation.
 Range (min … max):  189.615 ms … 701.238 ms  ┊ GC (min … max):  0.00% … 72.88%
 Time  (median):     493.090 ms               ┊ GC (median):    61.75%
 Time  (mean ± σ):   457.536 ms ± 160.567 ms  ┊ GC (mean ± σ):  58.25% ± 25.18%

  ██           █        █           ███  ██  █          █     █  
  ██▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁███▁▁██▁▁█▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁█ ▁
  190 ms           Histogram: frequency by time          701 ms <

 Memory estimate: 3.40 GiB, allocs estimate: 45635653.

BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took 6.611 s (3.17% GC) to evaluate,
 with a memory estimate of 3.77 GiB, over 51008757 allocations.

thousands of log messages omitted for brevity ;)

Pepijn de Vos added 2 commits November 8, 2023 10:01
ODE->Ensemble

optionally aggregate progress bars

handle integer loglevel (sundials)

avoid lock contention

only log significant progress

improve progress performance

add version constraint

Update Project.toml

Update Project.toml

ignore derivatives of logging

more AD fixing attempts

only pass progress_id if needed

use ignore_deriviative and fix rrule for with_logger

delete rules moved to ChainRules

remove using

import as opposed to using

more import fixes

more missing Logging qualifiers
@ChrisRackauckas
Member

This most likely got lost in a rebase
@pepijndevos
Author

pepijndevos commented Nov 8, 2023

Ahhh I thought that must have been some CI glitch from the stats PR, but turns out that in one of the many rebase conflicts of this PR over its long lifetime the stats got lost.

@pepijndevos
Author

Seems fine now? Could it be. Could it really be done?

@ChrisRackauckas
Member

Before

#=
BenchmarkTools.Trial: 198 samples with 1 evaluation.
Range (min … max): 22.291 ms … 43.583 ms ┊ GC (min … max): 0.00% … 43.89%
Time (median): 23.324 ms ┊ GC (median): 0.00%
Time (mean ± σ): 25.246 ms ± 5.548 ms ┊ GC (mean ± σ): 6.91% ± 12.90%

▆█▇▇▅▃▂
███████▆▁▆▅▅▅▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▅▁▁▁▁▁▁▁▁▅▁▁▅▆▆▆▆▁▇█ ▅
22.3 ms Histogram: log(frequency) by time 42.7 ms <

Memory estimate: 24.42 MiB, allocs estimate: 320057.

BenchmarkTools.Trial: 89 samples with 1 evaluation.
Range (min … max): 41.775 ms … 83.893 ms ┊ GC (min … max): 0.00% … 34.04%
Time (median): 44.861 ms ┊ GC (median): 0.00%
Time (mean ± σ): 56.267 ms ± 15.007 ms ┊ GC (mean ± σ): 22.08% ± 19.77%

█▄
████▇▄▃▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▄▅▃▄▅▆▇▆▄▄▅▁▁▁▁▃▃ ▁
41.8 ms Histogram: frequency by time 80.1 ms <

Memory estimate: 152.59 MiB, allocs estimate: 1190057.
=#

After

#=
BenchmarkTools.Trial: 221 samples with 1 evaluation.
Range (min … max): 19.745 ms … 51.209 ms ┊ GC (min … max): 0.00% … 56.97%
Time (median): 20.502 ms ┊ GC (median): 0.00%
Time (mean ± σ): 22.704 ms ± 7.274 ms ┊ GC (mean ± σ): 8.31% ± 14.29%

▇█▅▂▁
█████▄▅▁▁▁▄▁▄▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▄▇▄▅▁▆ ▅
19.7 ms Histogram: log(frequency) by time 50.6 ms <

Memory estimate: 24.42 MiB, allocs estimate: 320076.

BenchmarkTools.Trial: 31 samples with 1 evaluation.
Range (min … max): 141.556 ms … 179.966 ms ┊ GC (min … max): 4.08% … 11.54%
Time (median): 166.417 ms ┊ GC (median): 11.92%
Time (mean ± σ): 166.345 ms ± 8.579 ms ┊ GC (mean ± σ): 12.01% ± 2.93%

                       ▃         █  ▃  ▃                  ▃  

▇▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇█▁▁▁▁▇▇▁▇▇█▁▇█▇▇█▁▇▇▇▇▇▁▁▁▇▇▁▇▁▇▇▁▁█ ▁
142 ms Histogram: frequency by time 180 ms <

Memory estimate: 220.28 MiB, allocs estimate: 1785509.
=#

@ChrisRackauckas ChrisRackauckas merged commit 5797257 into SciML:master Nov 9, 2023
34 of 39 checks passed