diff --git a/Project.toml b/Project.toml index 6d1becd12..8d9b53c08 100644 --- a/Project.toml +++ b/Project.toml @@ -33,6 +33,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [weakdeps] +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote" [compat] ADTypes = "0.1.3, 0.2" ArrayInterface = "6, 7" +ChainRules = "1.57.0" ChainRulesCore = "1.16" CommonSolve = "0.2.4" ConstructionBase = "1" diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index db08eb9a8..ece01eb68 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -29,6 +29,64 @@ function merge_stats(us) reduce(merge, st) end +mutable struct AggregateLogger{T<:Logging.AbstractLogger} <: Logging.AbstractLogger + progress::Dict{Symbol, Float64} + done_counter::Int + total::Float64 + print_time::Float64 + lock::ReentrantLock + logger::T +end +AggregateLogger(logger::Logging.AbstractLogger) = AggregateLogger(Dict{Symbol, Float64}(),0 , 0.0, 0.0, ReentrantLock(), logger) + +function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id, file, line; kwargs...) + if convert(Logging.LogLevel, level) == Logging.LogLevel(-1) && haskey(kwargs, :progress) + pr = kwargs[:progress] + if trylock(l.lock) || (pr == "done" && lock(l.lock)===nothing) + try + if pr == "done" + pr = 1.0 + l.done_counter += 1 + end + len = length(l.progress) + if haskey(l.progress, id) + l.total += (pr-l.progress[id])/len + else + l.total = l.total*(len/(len+1)) + pr/(len+1) + len += 1 + end + l.progress[id] = pr + # validation check (slow) + # tot = sum(values(l.progress))/length(l.progress) + # @show tot l.total l.total ≈ tot + curr_time = time() + if l.done_counter >= len + tot="done" + empty!(l.progress) + l.done_counter = 0 + l.print_time = 0.0 + elseif curr_time-l.print_time > 0.1 + tot = l.total + l.print_time = curr_time + else + return + end + id=:total + message="Total" + kwargs=merge(values(kwargs), (progress=tot,)) + finally + unlock(l.lock) + end + else + return + end + end + Logging.handle_message(l.logger, level, message, _module, group, id, file, line; kwargs...) +end +Logging.shouldlog(l::AggregateLogger, args...) = Logging.shouldlog(l.logger, args...) +Logging.min_enabled_level(l::AggregateLogger) = Logging.min_enabled_level(l.logger) +Logging.catch_exceptions(l::AggregateLogger) = Logging.catch_exceptions(l.logger) + function __solve(prob::AbstractEnsembleProblem, alg::Union{AbstractDEAlgorithm, Nothing}; kwargs...) @@ -59,44 +117,57 @@ end function __solve(prob::AbstractEnsembleProblem, alg::A, ensemblealg::BasicEnsembleAlgorithm; - trajectories, batch_size = trajectories, + trajectories, batch_size = trajectories, progress_aggregate=true, pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A} - num_batches = trajectories ÷ batch_size - num_batches < 1 && - error("trajectories ÷ batch_size cannot be less than 1, got $num_batches") - num_batches * batch_size != trajectories && (num_batches += 1) - - if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION - elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories, - pmap_batch_size; kwargs...) - _u = tighten_container_eltype(u) - stats = merge_stats(_u) - return EnsembleSolution(_u, elapsed_time, true, stats) - end + logger = progress_aggregate ? AggregateLogger(Logging.current_logger()) : Logging.current_logger() + + Logging.with_logger(logger) do + num_batches = trajectories ÷ batch_size + num_batches < 1 && + error("trajectories ÷ batch_size cannot be less than 1, got $num_batches") + num_batches * batch_size != trajectories && (num_batches += 1) - converged::Bool = false - elapsed_time = @elapsed begin - i = 1 - II = (batch_size * (i - 1) + 1):(batch_size * i) + if get(kwargs, :progress, false) + name = get(kwargs, :progress_name, "Ensemble") + for i in 1:trajectories + msg = "$name #$i" + Logging.@logmsg(Logging.LogLevel(-1), msg, _id=Symbol("SciMLBase_$i"), progress=0) + end + end + - batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...) + if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION + elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories, + pmap_batch_size; kwargs...) + _u = tighten_container_eltype(u) + stats = merge_stats(_u) + return EnsembleSolution(_u, elapsed_time, true, stats) + end + + converged::Bool = false + elapsed_time = @elapsed begin + i = 1 + II = (batch_size * (i - 1) + 1):(batch_size * i) - u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init - u, converged = prob.reduction(u, batch_data, II) - for i in 2:num_batches - converged && break - if i == num_batches - II = (batch_size * (i - 1) + 1):trajectories - else - II = (batch_size * (i - 1) + 1):(batch_size * i) - end batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...) + + u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init u, converged = prob.reduction(u, batch_data, II) + for i in 2:num_batches + converged && break + if i == num_batches + II = (batch_size * (i - 1) + 1):trajectories + else + II = (batch_size * (i - 1) + 1):(batch_size * i) + end + batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...) + u, converged = prob.reduction(u, batch_data, II) + end end + _u = tighten_container_eltype(u) + stats = merge_stats(_u) + return EnsembleSolution(_u, elapsed_time, converged, stats) end - _u = tighten_container_eltype(u) - stats = merge_stats(_u) - return EnsembleSolution(_u, elapsed_time, converged, stats) end function batch_func(i, prob, alg; kwargs...) @@ -104,6 +175,14 @@ function batch_func(i, prob, alg; kwargs...) _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob new_prob = prob.prob_func(_prob, i, iter) rerun = true + + progress = get(kwargs, :progress, false) + if progress + name = get(kwargs, :progress_name, "Ensemble") + progress_name = "$name #$i" + progress_id = Symbol("SciMLBase_$i") + kwargs = (kwargs..., progress_name=progress_name, progress_id=progress_id) + end x = prob.output_func(solve(new_prob, alg; kwargs...), i) if !(x isa Tuple) rerun_warn()