Skip to content

Commit

Permalink
add progress bar support
Browse files Browse the repository at this point in the history
Fixes #40
  • Loading branch information
ChrisRackauckas committed May 18, 2019
1 parent f465ff5 commit d60e81d
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 7 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"

[compat]
julia = "1"
Expand Down
5 changes: 2 additions & 3 deletions src/Sundials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ __precompile__()

module Sundials

using Reexport, DataStructures
using Reexport, DataStructures, Logging
@reexport using DiffEqBase
using DiffEqBase: check_error!
using SparseArrays, LinearAlgebra

const warnkeywords =
(:save_idxs, :d_discontinuities, :isoutofdomain, :unstable_check,
:calck, :progress, :timeseries_steps,
:internalnorm, :gamma, :beta1, :beta2, :qmax, :qmin, :qoldinit)
:calck, :internalnorm, :gamma, :beta1, :beta2, :qmax, :qmin, :qoldinit)

function __init__()
global warnlist = Set(warnkeywords)
Expand Down
5 changes: 4 additions & 1 deletion src/common_interface/integrator_types.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mutable struct DEOptions{SType,TstopType,CType,reltolType,abstolType}
mutable struct DEOptions{SType,TstopType,CType,reltolType,abstolType,F5}
saveat::SType
tstops::TstopType
save_everystep::Bool
Expand All @@ -13,6 +13,9 @@ mutable struct DEOptions{SType,TstopType,CType,reltolType,abstolType}
verbose::Bool
advance_to_tstop::Bool
stop_at_next_tstop::Bool
progress::Bool
progress_name::String
progress_message::F5
end

abstract type AbstractSundialsIntegrator{algType} <: DiffEqBase.AbstractODEIntegrator{algType,true,Vector{Float64},Float64} end
Expand Down
50 changes: 47 additions & 3 deletions src/common_interface/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ function DiffEqBase.__init(
save_start = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : prob.tspan[1] in saveat,
save_end = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : prob.tspan[2] in saveat,
dense = save_everystep && isempty(saveat),
progress=false,progress_name="ODE",
progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE,
save_timeseries = nothing,
advance_to_tstop = false,stop_at_next_tstop=false,
userdata=nothing,
Expand Down Expand Up @@ -58,6 +60,8 @@ function DiffEqBase.__init(
error("Sundials only allows scalar reltol.")
end

progress && @logmsg(-1,progress_name,_id=_id = :Sundials,progress=0)

callbacks_internal = CallbackSet(callback,prob.callback)

tspan = prob.tspan
Expand Down Expand Up @@ -265,7 +269,8 @@ function DiffEqBase.__init(
calculate_error = false)
opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense,
timeseries_errors,dense_errors,save_on,save_end,
callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop)
callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop,
progress,progress_name,progress_message)
integrator = CVODEIntegrator(u0,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts,
tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,0,0.)

Expand All @@ -288,6 +293,8 @@ function DiffEqBase.__init(
save_everystep=isempty(saveat), dense = save_everystep,
save_on = true, save_start = true, save_end = true,
save_timeseries = nothing,
progress=false,progress_name="ODE",
progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE,
advance_to_tstop = false,stop_at_next_tstop=false,
userdata=nothing,
alias_u0=false,
Expand All @@ -310,6 +317,8 @@ function DiffEqBase.__init(
error("Sundials only allows scalar reltol.")
end

progress && @logmsg(-1,progress_name,_id=_id = :Sundials,progress=0)

callbacks_internal = CallbackSet(callback,prob.callback)

tspan = prob.tspan
Expand Down Expand Up @@ -634,7 +643,8 @@ function DiffEqBase.__init(
calculate_error = false)
opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense,
timeseries_errors,dense_errors,save_on,save_end,
callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop)
callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop,
progress,progress_name,progress_message)
integrator = ARKODEIntegrator(utmp,prob.p,t0,t0,mem,_LS,_A,_MLS,_M,sol,alg,f!,userfun,jac,opts,
tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,0,0.)

Expand Down Expand Up @@ -693,6 +703,8 @@ function DiffEqBase.__init(
dense_errors = false,
save_everystep=isempty(saveat), dense=save_everystep,
save_timeseries=nothing, save_end = true,
progress=false,progress_name="ODE",
progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE,
advance_to_tstop = false, stop_at_next_tstop = false,
userdata=nothing,
kwargs...) where {uType, duType, tupType, isinplace, LinearSolver}
Expand All @@ -714,6 +726,8 @@ function DiffEqBase.__init(
error("Sundials only allows scalar reltol.")
end

progress && @logmsg(-1,progress_name,_id=_id = :Sundials,progress=0)

callbacks_internal = CallbackSet(callback,prob.callback)

tspan = prob.tspan
Expand Down Expand Up @@ -941,7 +955,8 @@ function DiffEqBase.__init(

opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense,
timeseries_errors,dense_errors,save_on,save_end,
callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop)
callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop,
progress,progress_name,progress_message)

integrator = IDAIntegrator(utmp,dutmp,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts,
tout,tdir,sizeu,sizedu,false,tmp,uprev,Cint(flag),false,0,0.)
Expand All @@ -962,13 +977,34 @@ end

function solver_step(integrator::CVODEIntegrator,tstop)
integrator.flag = CVode(integrator.mem, tstop, integrator.u, integrator.tout, CV_ONE_STEP)
if integrator.opts.progress
@logmsg(-1,
integrator.opts.progress_name,
_id = :Sundials,
message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t),
progress=integrator.t/integrator.sol.prob.tspan[2])
end
end
function solver_step(integrator::ARKODEIntegrator,tstop)
integrator.flag = ARKode(integrator.mem, tstop, integrator.u, integrator.tout, ARK_ONE_STEP)
if integrator.opts.progress
@logmsg(-1,
integrator.opts.progress_name,
_id = :Sundials,
message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t),
progress=integrator.t/integrator.sol.prob.tspan[2])
end
end
function solver_step(integrator::IDAIntegrator,tstop)
integrator.flag = IDASolve(integrator.mem, tstop, integrator.tout,
integrator.u, integrator.du, IDA_ONE_STEP)
if integrator.opts.progress
@logmsg(-1,
integrator.opts.progress_name,
_id = :Sundials,
message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t),
progress=integrator.t/integrator.sol.prob.tspan[2])
end
end

function set_stop_time(integrator::CVODEIntegrator,tstop)
Expand Down Expand Up @@ -1016,6 +1052,14 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator)
end
end

if integrator.opts.progress
@logmsg(-1,
integrator.opts.progress_name,
_id = :Sundials,
message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t),
progress="done")
end

fill_destats!(integrator)
empty!(integrator.mem)
integrator.A != nothing && empty!(integrator.A)
Expand Down

0 comments on commit d60e81d

Please sign in to comment.