Skip to content

Commit

Permalink
Add progress_bar in mcsolve, ssesolve and dsf_mcsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio committed Oct 3, 2024
1 parent 58d5d88 commit 4a44624
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 84 deletions.
12 changes: 7 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
Expand Down Expand Up @@ -37,23 +38,24 @@ CUDA = "5"
DiffEqBase = "6"
DiffEqCallbacks = "2 - 3.1, 3.8, 4"
DiffEqNoiseProcess = "5"
Distributed = "1"
FFTW = "1.5"
Graphs = "1.7"
IncompleteLU = "0.2"
LinearAlgebra = "<0.0.1, 1"
LinearAlgebra = "1"
LinearSolve = "2"
OrdinaryDiffEqCore = "1"
OrdinaryDiffEqTsit5 = "1"
Pkg = "<0.0.1, 1"
Random = "<0.0.1, 1"
Pkg = "1"
Random = "1"
Reexport = "1"
SciMLBase = "2"
SciMLOperators = "0.3"
SparseArrays = "<0.0.1, 1"
SparseArrays = "1"
SpecialFunctions = "2"
StaticArraysCore = "1"
StochasticDiffEq = "6"
Test = "<0.0.1, 1"
Test = "1"
julia = "1.10"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import DiffEqNoiseProcess: RealWienerProcess

# other dependencies (in alphabetical order)
import ArrayInterface: allowed_getindex, allowed_setindex!
import Distributed: RemoteChannel
import FFTW: fft, fftshift
import Graphs: connected_components, DiGraph
import IncompleteLU: ilu
Expand Down
2 changes: 1 addition & 1 deletion src/qobj/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ function tunneling(N::Int, m::Int = 1; sparse::Union{Bool,Val} = Val(false))
(m < 1) && throw(ArgumentError("The number of excitations (m) cannot be less than 1"))

data = ones(ComplexF64, N - m)
if getVal(makeVal(sparse))
if getVal(sparse)
return QuantumObject(spdiagm(m => data, -m => data); type = Operator, dims = N)
else
return QuantumObject(diagm(m => data, -m => data); type = Operator, dims = N)
Expand Down
4 changes: 2 additions & 2 deletions src/qobj/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ It is also possible to specify the list of dimensions `dims` if different subsys
If you want to keep type stability, it is recommended to use `fock(N, j, dims=dims, sparse=Val(sparse))` instead of `fock(N, j, dims=dims, sparse=sparse)`. Consider also to use `dims` as a `Tuple` or `SVector` instead of `Vector`. See [this link](https://docs.julialang.org/en/v1/manual/performance-tips/#man-performance-value-type) and the [related Section](@ref doc:Type-Stability) about type stability for more details.
"""
function fock(N::Int, j::Int = 0; dims::Union{Int,AbstractVector{Int},Tuple} = N, sparse::Union{Bool,Val} = Val(false))
if getVal(makeVal(sparse))
if getVal(sparse)
array = sparsevec([j + 1], [1.0 + 0im], N)
else
array = zeros(ComplexF64, N)
Expand Down Expand Up @@ -130,7 +130,7 @@ function thermal_dm(N::Int, n::Real; sparse::Union{Bool,Val} = Val(false))
β = log(1.0 / n + 1.0)
N_list = Array{Float64}(0:N-1)
data = exp.(-β .* N_list)
if getVal(makeVal(sparse))
if getVal(sparse)
return QuantumObject(spdiagm(0 => data ./ sum(data)), Operator, N)
else
return QuantumObject(diagm(0 => data ./ sum(data)), Operator, N)
Expand Down
76 changes: 50 additions & 26 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ end
function _mcsolve_output_func(sol, i)
resize!(sol.prob.p.jump_times, sol.prob.p.jump_times_which_idx[] - 1)
resize!(sol.prob.p.jump_which, sol.prob.p.jump_times_which_idx[] - 1)
put!(sol.prob.p.progr_channel, true)
return (sol, false)
end

Expand Down Expand Up @@ -204,7 +205,8 @@ function mcsolveProblem(
end

saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
# We disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat, progress_bar = Val(false))
kwargs2 = merge(default_values, kwargs)

cache_mc = similar(ψ0.data)
Expand Down Expand Up @@ -396,15 +398,20 @@ end
mcsolve(H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple}=nothing;
alg::OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
ntraj::Int=1,
ensemble_method=EnsembleThreads(),
jump_callback::TJC=ContinuousLindbladJumpCallback(),
kwargs...)
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
seeds::Union{Nothing,Vector{Int}} = nothing,
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_output_func,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
)
Time evolution of an open quantum system using quantum trajectories.
Expand Down Expand Up @@ -457,6 +464,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `prob_func::Function`: Function to use for generating the ODEProblem.
- `output_func::Function`: Function to use for generating the output of a single trajectory.
- `kwargs...`: Additional keyword arguments to pass to the solver.
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
# Notes
Expand Down Expand Up @@ -486,29 +494,42 @@ function mcsolve(
jump_callback::TJC = ContinuousLindbladJumpCallback(),
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_output_func,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
if !isnothing(seeds) && length(seeds) != ntraj
throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))"))
end

ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = params,
seeds = seeds,
jump_callback = jump_callback,
prob_func = prob_func,
output_func = output_func,
kwargs...,
)
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
@async while take!(progr_channel)
next!(progr)
end

return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
# Stop the async task if an error occurs
try
ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = merge(params, (progr_channel = progr_channel,)),
seeds = seeds,
jump_callback = jump_callback,
prob_func = prob_func,
output_func = output_func,
kwargs...,
)

return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
catch e
put!(progr_channel, false)
rethrow()
end
end

function mcsolve(
Expand All @@ -518,6 +539,9 @@ function mcsolve(
ensemble_method = EnsembleThreads(),
)
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)

put!(sol[:, 1].prob.p.progr_channel, false)

_sol_1 = sol[:, 1]

expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
Expand Down
7 changes: 3 additions & 4 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,13 @@ function mesolveProblem(
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type

t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

L = liouvillian(H, c_ops).data
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
progr = ProgressBar(length(t_l), enable = getVal(progress_bar))

if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, length(t_l))
Expand Down Expand Up @@ -158,7 +157,7 @@ function mesolveProblem(
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
kwargs2 = merge(default_values, kwargs)
kwargs3 = _generate_mesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)
kwargs3 = _generate_mesolve_kwargs(e_ops, makeVal(progress_bar), t_l, kwargs2)

dudt! = is_time_dependent ? mesolve_td_dudt! : mesolve_ti_dudt!

Expand Down Expand Up @@ -241,7 +240,7 @@ function mesolve(
e_ops = e_ops,
H_t = H_t,
params = params,
progress_bar = makeVal(progress_bar),
progress_bar = progress_bar,
kwargs...,
)

Expand Down
7 changes: 3 additions & 4 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,13 @@ function sesolveProblem(
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type

t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

U = -1im * get_data(H)
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
progr = ProgressBar(length(t_l), enable = getVal(progress_bar))

if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, length(t_l))
Expand All @@ -135,7 +134,7 @@ function sesolveProblem(
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
kwargs2 = merge(default_values, kwargs)
kwargs3 = _generate_sesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)
kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), t_l, kwargs2)

dudt! = is_time_dependent ? sesolve_td_dudt! : sesolve_ti_dudt!

Expand Down Expand Up @@ -203,7 +202,7 @@ function sesolve(
e_ops = e_ops,
H_t = H_t,
params = params,
progress_bar = makeVal(progress_bar),
progress_bar = progress_bar,
kwargs...,
)

Expand Down
Loading

0 comments on commit 4a44624

Please sign in to comment.