Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

progress bars for EnsembleProblem #514

Merged
merged 3 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand All @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote"
[compat]
ADTypes = "0.1.3, 0.2"
ArrayInterface = "6, 7"
ChainRules = "1.57.0"
ChainRulesCore = "1.16"
CommonSolve = "0.2.4"
ConstructionBase = "1"
Expand Down
139 changes: 109 additions & 30 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,64 @@
reduce(merge, st)
end

"""
    AggregateLogger(logger::Logging.AbstractLogger)

Logger wrapper that condenses per-id progress records into a single `"Total"`
progress record before handing them to the wrapped `logger`. All bookkeeping
starts out empty/zero; the aggregation itself happens in `Logging.handle_message`.

# Fields
- `progress`: most recent progress value stored for each log-record id.
- `done_counter`: number of `"done"` progress records counted since the last reset.
- `total`: running mean of the values stored in `progress`.
- `print_time`: `time()` stamp of the last forwarded `"Total"` record (`0.0` initially).
- `lock`: guards the mutable fields above while a progress record is processed.
- `logger`: the wrapped logger that ultimately receives the records.
"""
mutable struct AggregateLogger{T <: Logging.AbstractLogger} <: Logging.AbstractLogger
    progress::Dict{Symbol, Float64}
    done_counter::Int
    total::Float64
    print_time::Float64
    lock::ReentrantLock
    logger::T
end

function AggregateLogger(logger::Logging.AbstractLogger)
    per_id_progress = Dict{Symbol, Float64}()
    state_lock = ReentrantLock()
    return AggregateLogger(per_id_progress, 0, 0.0, 0.0, state_lock, logger)
end

"""
    Logging.handle_message(l::AggregateLogger, level, message, _module, group, id,
        file, line; kwargs...)

Aggregate progress records and forward the result to the wrapped logger `l.logger`.

A record is treated as a progress record when its level equals `LogLevel(-1)` and it
carries a `progress` keyword. For such a record with a finite real value or the
string `"done"`:

- the value (`"done"` counts as `1.0`) is stored under `id` and the running mean
  `l.total` over all ids seen so far is updated incrementally;
- once at least as many `"done"` records were counted as ids are known, a record with
  id `:total`, message `"Total"` and `progress = "done"` is forwarded and the
  aggregation state is reset;
- otherwise a `"Total"` record carrying `l.total` is forwarded, but only if more than
  0.1 seconds passed since the previous one; in between, records are dropped.

Numeric updates arriving while another task holds the lock are dropped (best effort);
`"done"` records always wait for the lock so that completion is never missed.

Every other record — including progress records whose value cannot be averaged
(e.g. `nothing` or `NaN`) — is forwarded unchanged.

Returns whatever the wrapped logger's `handle_message` returns, or `nothing` when the
record is dropped.
"""
function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id,
        file, line; kwargs...)
    is_progress = convert(Logging.LogLevel, level) == Logging.LogLevel(-1) &&
                  haskey(kwargs, :progress)
    if is_progress
        pr = kwargs[:progress]
        # `isequal` keeps the test Boolean even for values such as `missing`.
        finished = isequal(pr, "done")
        # Values that cannot take part in the running mean (indeterminate `nothing`,
        # `NaN`, arbitrary objects) skip aggregation and are forwarded as they came in.
        if finished || (pr isa Real && isfinite(pr))
            # Never block a solver task for an intermediate update, but always wait
            # for "done": losing it would keep the total from ever completing.
            acquired = trylock(l.lock)
            if !acquired && finished
                lock(l.lock)
                acquired = true
            end
            acquired || return
            try
                if finished
                    pr = 1.0
                    l.done_counter += 1
                end
                # Incremental mean; invariant: l.total ≈ sum(values(l.progress)) / len.
                len = length(l.progress)
                if haskey(l.progress, id)
                    l.total += (pr - l.progress[id]) / len
                else
                    l.total = l.total * (len / (len + 1)) + pr / (len + 1)
                    len += 1
                end
                # NOTE(review): storing requires `id` to be convertible to `Symbol`.
                l.progress[id] = pr

                curr_time = time()
                if l.done_counter >= len
                    # Everything finished: report completion and reset for reuse.
                    tot = "done"
                    empty!(l.progress)
                    l.done_counter = 0
                    l.print_time = 0.0
                elseif curr_time - l.print_time > 0.1
                    # Throttle: at most one "Total" record per 0.1 s.
                    tot = l.total
                    l.print_time = curr_time
                else
                    return
                end
                # Rewrite the record so the wrapped logger only sees the aggregate.
                id = :total
                message = "Total"
                kwargs = merge(values(kwargs), (progress = tot,))
            finally
                unlock(l.lock)
            end
        end
    end
    return Logging.handle_message(l.logger, level, message, _module, group, id, file,
        line; kwargs...)
end
# The remaining `AbstractLogger` interface methods are delegated unchanged to the
# wrapped logger `l.logger`.

# Forward the per-record filtering decision to the wrapped logger.
function Logging.shouldlog(l::AggregateLogger, args...)
    wrapped = l.logger
    return Logging.shouldlog(wrapped, args...)
end

# Report the wrapped logger's minimum enabled level.
function Logging.min_enabled_level(l::AggregateLogger)
    wrapped = l.logger
    return Logging.min_enabled_level(wrapped)
end

# Report the wrapped logger's exception-catching policy.
function Logging.catch_exceptions(l::AggregateLogger)
    wrapped = l.logger
    return Logging.catch_exceptions(wrapped)
end

Check warning on line 88 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L88

Added line #L88 was not covered by tests

function __solve(prob::AbstractEnsembleProblem,
alg::Union{AbstractDEAlgorithm, Nothing};
kwargs...)
Expand Down Expand Up @@ -59,51 +117,72 @@
function __solve(prob::AbstractEnsembleProblem,
alg::A,
ensemblealg::BasicEnsembleAlgorithm;
trajectories, batch_size = trajectories,
trajectories, batch_size = trajectories, progress_aggregate=true,
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A}
num_batches = trajectories ÷ batch_size
num_batches < 1 &&
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
num_batches * batch_size != trajectories && (num_batches += 1)

if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)
end
logger = progress_aggregate ? AggregateLogger(Logging.current_logger()) : Logging.current_logger()

Logging.with_logger(logger) do
num_batches = trajectories ÷ batch_size
num_batches < 1 &&
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
num_batches * batch_size != trajectories && (num_batches += 1)

converged::Bool = false
elapsed_time = @elapsed begin
i = 1
II = (batch_size * (i - 1) + 1):(batch_size * i)
if get(kwargs, :progress, false)
name = get(kwargs, :progress_name, "Ensemble")
for i in 1:trajectories
msg = "$name #$i"
Logging.@logmsg(Logging.LogLevel(-1), msg, _id=Symbol("SciMLBase_$i"), progress=0)
end

Check warning on line 135 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L131-L135

Added lines #L131 - L135 were not covered by tests
end


batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)
end

converged::Bool = false
elapsed_time = @elapsed begin
i = 1
II = (batch_size * (i - 1) + 1):(batch_size * i)

u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
u, converged = prob.reduction(u, batch_data, II)
for i in 2:num_batches
converged && break
if i == num_batches
II = (batch_size * (i - 1) + 1):trajectories
else
II = (batch_size * (i - 1) + 1):(batch_size * i)
end
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)

u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
u, converged = prob.reduction(u, batch_data, II)
for i in 2:num_batches
converged && break
if i == num_batches
II = (batch_size * (i - 1) + 1):trajectories

Check warning on line 159 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L157-L159

Added lines #L157 - L159 were not covered by tests
else
II = (batch_size * (i - 1) + 1):(batch_size * i)

Check warning on line 161 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L161

Added line #L161 was not covered by tests
end
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
u, converged = prob.reduction(u, batch_data, II)
end

Check warning on line 165 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L163-L165

Added lines #L163 - L165 were not covered by tests
end
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)
end
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)
end

function batch_func(i, prob, alg; kwargs...)
iter = 1
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
new_prob = prob.prob_func(_prob, i, iter)
rerun = true

progress = get(kwargs, :progress, false)
if progress
name = get(kwargs, :progress_name, "Ensemble")
progress_name = "$name #$i"
progress_id = Symbol("SciMLBase_$i")
kwargs = (kwargs..., progress_name=progress_name, progress_id=progress_id)

Check warning on line 184 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L181-L184

Added lines #L181 - L184 were not covered by tests
end
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
if !(x isa Tuple)
rerun_warn()
Expand Down
Loading