From 59e2abe36532deef4adb55c58dd787dc91bb6ae4 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 18 May 2019 16:17:16 -0400 Subject: [PATCH] add progress bar support Fixes #40 --- Project.toml | 1 + src/Sundials.jl | 5 +-- src/common_interface/integrator_types.jl | 5 ++- src/common_interface/solve.jl | 50 ++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 5b086349..05524968 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" [compat] julia = "1" diff --git a/src/Sundials.jl b/src/Sundials.jl index f2949346..0a93c3a1 100644 --- a/src/Sundials.jl +++ b/src/Sundials.jl @@ -2,15 +2,14 @@ __precompile__() module Sundials -using Reexport, DataStructures +using Reexport, DataStructures, Logging @reexport using DiffEqBase using DiffEqBase: check_error! using SparseArrays, LinearAlgebra const warnkeywords = (:save_idxs, :d_discontinuities, :isoutofdomain, :unstable_check, - :calck, :progress, :timeseries_steps, - :internalnorm, :gamma, :beta1, :beta2, :qmax, :qmin, :qoldinit) + :calck, :internalnorm, :gamma, :beta1, :beta2, :qmax, :qmin, :qoldinit) function __init__() global warnlist = Set(warnkeywords) diff --git a/src/common_interface/integrator_types.jl b/src/common_interface/integrator_types.jl index 16cc5bd3..a31647c0 100644 --- a/src/common_interface/integrator_types.jl +++ b/src/common_interface/integrator_types.jl @@ -1,4 +1,4 @@ -mutable struct DEOptions{SType,TstopType,CType,reltolType,abstolType} +mutable struct DEOptions{SType,TstopType,CType,reltolType,abstolType,F5} saveat::SType tstops::TstopType save_everystep::Bool @@ -13,6 +13,9 @@ mutable struct DEOptions{SType,TstopType,CType,reltolType,abstolType} verbose::Bool advance_to_tstop::Bool stop_at_next_tstop::Bool + progress::Bool + progress_name::String + progress_message::F5 end abstract type AbstractSundialsIntegrator{algType} <: DiffEqBase.AbstractODEIntegrator{algType,true,Vector{Float64},Float64} end diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index 6aa673ef..4faf2c2f 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -31,6 +31,8 @@ function DiffEqBase.__init( save_start = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : prob.tspan[1] in saveat, save_end = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : prob.tspan[2] in saveat, dense = save_everystep && isempty(saveat), + progress=false,progress_name="ODE", + progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, save_timeseries = nothing, advance_to_tstop = false,stop_at_next_tstop=false, userdata=nothing, @@ -58,6 +60,8 @@ function DiffEqBase.__init( error("Sundials only allows scalar reltol.") end + progress && @logmsg(-1,progress_name,_id=_id = :Sundials,progress=0) + callbacks_internal = CallbackSet(callback,prob.callback) tspan = prob.tspan @@ -265,7 +269,8 @@ function DiffEqBase.__init( calculate_error = false) opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense, timeseries_errors,dense_errors,save_on,save_end, - callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop) + callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop, + progress,progress_name,progress_message) integrator = CVODEIntegrator(u0,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts, tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,0,0.) @@ -288,6 +293,8 @@ function DiffEqBase.__init( save_everystep=isempty(saveat), dense = save_everystep, save_on = true, save_start = true, save_end = true, save_timeseries = nothing, + progress=false,progress_name="ODE", + progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, advance_to_tstop = false,stop_at_next_tstop=false, userdata=nothing, alias_u0=false, @@ -310,6 +317,8 @@ function DiffEqBase.__init( error("Sundials only allows scalar reltol.") end + progress && @logmsg(-1,progress_name,_id=_id = :Sundials,progress=0) + callbacks_internal = CallbackSet(callback,prob.callback) tspan = prob.tspan @@ -634,7 +643,8 @@ function DiffEqBase.__init( calculate_error = false) opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense, timeseries_errors,dense_errors,save_on,save_end, - callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop) + callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop, + progress,progress_name,progress_message) integrator = ARKODEIntegrator(utmp,prob.p,t0,t0,mem,_LS,_A,_MLS,_M,sol,alg,f!,userfun,jac,opts, tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,0,0.) @@ -693,6 +703,8 @@ function DiffEqBase.__init( dense_errors = false, save_everystep=isempty(saveat), dense=save_everystep, save_timeseries=nothing, save_end = true, + progress=false,progress_name="ODE", + progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, advance_to_tstop = false, stop_at_next_tstop = false, userdata=nothing, kwargs...) where {uType, duType, tupType, isinplace, LinearSolver} @@ -714,6 +726,8 @@ function DiffEqBase.__init( error("Sundials only allows scalar reltol.") end + progress && @logmsg(-1,progress_name,_id=_id = :Sundials,progress=0) + callbacks_internal = CallbackSet(callback,prob.callback) tspan = prob.tspan @@ -941,7 +955,8 @@ function DiffEqBase.__init( opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense, timeseries_errors,dense_errors,save_on,save_end, - callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop) + callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop, + progress,progress_name,progress_message) integrator = IDAIntegrator(utmp,dutmp,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts, tout,tdir,sizeu,sizedu,false,tmp,uprev,Cint(flag),false,0,0.) @@ -962,13 +977,34 @@ end function solver_step(integrator::CVODEIntegrator,tstop) integrator.flag = CVode(integrator.mem, tstop, integrator.u, integrator.tout, CV_ONE_STEP) + if integrator.opts.progress + @logmsg(-1, + integrator.opts.progress_name, + _id = :Sundials, + message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t), + progress=integrator.t/integrator.sol.prob.tspan[2]) + end end function solver_step(integrator::ARKODEIntegrator,tstop) integrator.flag = ARKode(integrator.mem, tstop, integrator.u, integrator.tout, ARK_ONE_STEP) + if integrator.opts.progress + @logmsg(-1, + integrator.opts.progress_name, + _id = :Sundials, + message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t), + progress=integrator.t/integrator.sol.prob.tspan[2]) + end end function solver_step(integrator::IDAIntegrator,tstop) integrator.flag = IDASolve(integrator.mem, tstop, integrator.tout, integrator.u, integrator.du, IDA_ONE_STEP) + if integrator.opts.progress + @logmsg(-1, + integrator.opts.progress_name, + _id = :Sundials, + message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t), + progress=integrator.t/integrator.sol.prob.tspan[2]) + end end function set_stop_time(integrator::CVODEIntegrator,tstop) @@ -1016,6 +1052,14 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator) end end + if integrator.opts.progress + @logmsg(-1, + integrator.opts.progress_name, + _id = :Sundials, + message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t), + progress="done") + end + fill_destats!(integrator) empty!(integrator.mem) integrator.A != nothing && empty!(integrator.A)