diff --git a/docs/io/images/progress_bars_demo.gif b/docs/io/images/progress_bars_demo.gif
new file mode 100644
index 00000000000..c672b8916fa
Binary files /dev/null and b/docs/io/images/progress_bars_demo.gif differ
diff --git a/docs/io/output/index.rst b/docs/io/output/index.rst
index 791f0a01eba..46036596261 100644
--- a/docs/io/output/index.rst
+++ b/docs/io/output/index.rst
@@ -11,4 +11,5 @@ In addition to the widgets, TARDIS can output information in several other forms
access_iterations
to_hdf
callback
- vpacket_logging
\ No newline at end of file
+ vpacket_logging
+ progress_bars
\ No newline at end of file
diff --git a/docs/io/output/progress_bars.ipynb b/docs/io/output/progress_bars.ipynb
new file mode 100644
index 00000000000..55c838d532a
--- /dev/null
+++ b/docs/io/output/progress_bars.ipynb
@@ -0,0 +1,63 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "cc613a48",
+ "metadata": {},
+ "source": [
+ "### Progress Bars for Simulation Run ###\n",
+ "TARDIS displays progress bars by default to track the simulation. The progress bars are not displayed in the documentation but show up when you run the notebook.\n",
+ "\n",
+ "![TARDIS Progress Bars](../images/progress_bars_demo.gif)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6e6b8b8f",
+ "metadata": {},
+ "source": [
+ "You can disable the progress bars as well."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "319759fd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tardis import run_tardis\n",
+ "sim = run_tardis('tardis_example.yml', show_progress_bars = False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e51d94c9",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/quickstart/quickstart.ipynb b/docs/quickstart/quickstart.ipynb
index 4f914d79e2e..28739a24c13 100644
--- a/docs/quickstart/quickstart.ipynb
+++ b/docs/quickstart/quickstart.ipynb
@@ -70,6 +70,19 @@
"#### Running the simulation (long output) ####"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
\n",
+ "\n",
+ "Note\n",
+ " \n",
+ "The progress of the simulation can be tracked using progress bars which are displayed when the notebook is run, but are not displayed in the documentation. \n",
+ " \n",
+ "
"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
diff --git a/tardis/base.py b/tardis/base.py
index 3bcdd2aa025..b404145f296 100644
--- a/tardis/base.py
+++ b/tardis/base.py
@@ -15,6 +15,7 @@ def run_tardis(
show_convergence_plots=True,
log_level=None,
specific_log_level=None,
+ show_progress_bars=True,
**kwargs,
):
"""
@@ -52,6 +53,8 @@ def run_tardis(
The default value None means that the `specific_log_level` specified in the configuration file will be used.
show_convergence_plots : bool, default: True, optional
Option to enable tardis convergence plots.
+ show_progress_bars : bool, default: True, optional
+ Option to enable the progress bar.
**kwargs : dict, optional
Optional keyword arguments including those
supported by :obj:`tardis.visualization.tools.convergence_plot.ConvergencePlots`.
@@ -101,6 +104,7 @@ def run_tardis(
atom_data=atom_data,
virtual_packet_logging=virtual_packet_logging,
show_convergence_plots=show_convergence_plots,
+ show_progress_bars=show_progress_bars,
**kwargs,
)
for cb in simulation_callbacks:
diff --git a/tardis/montecarlo/base.py b/tardis/montecarlo/base.py
index 9ab1fbada27..7dfa116ca8f 100644
--- a/tardis/montecarlo/base.py
+++ b/tardis/montecarlo/base.py
@@ -268,6 +268,8 @@ def run(
nthreads=1,
last_run=False,
iteration=0,
+ total_iterations=0,
+ show_progress_bars=True,
):
"""
Run the montecarlo calculation
@@ -280,6 +282,8 @@ def run(
no_of_virtual_packets : int
nthreads : int
last_run : bool
+ total_iterations : int
+ The total number of iterations in the simulation.
Returns
-------
@@ -301,7 +305,15 @@ def run(
)
configuration_initialize(self, no_of_virtual_packets)
- montecarlo_radial1d(model, plasma, self)
+ montecarlo_radial1d(
+ model,
+ plasma,
+ iteration,
+ no_of_packets,
+ total_iterations,
+ show_progress_bars,
+ self,
+ )
self._integrator = FormalIntegrator(model, plasma, self)
# montecarlo.montecarlo_radial1d(
# model, plasma, self,
diff --git a/tardis/montecarlo/montecarlo_numba/base.py b/tardis/montecarlo/montecarlo_numba/base.py
index ba107d6292a..dbac64d17e1 100644
--- a/tardis/montecarlo/montecarlo_numba/base.py
+++ b/tardis/montecarlo/montecarlo_numba/base.py
@@ -1,4 +1,4 @@
-from numba import prange, njit, jit
+from numba import prange, njit, jit, objmode
import logging
import numpy as np
@@ -26,9 +26,18 @@
)
from tardis.montecarlo.montecarlo_numba import njit_dict
from numba.typed import List
+from tardis.util.base import update_iterations_pbar, update_packet_pbar
-def montecarlo_radial1d(model, plasma, runner):
+def montecarlo_radial1d(
+ model,
+ plasma,
+ iteration,
+ no_of_packets,
+ total_iterations,
+ show_progress_bars,
+ runner,
+):
packet_collection = PacketCollection(
runner.input_r,
runner.input_nu,
@@ -76,6 +85,10 @@ def montecarlo_radial1d(model, plasma, runner):
runner.spectrum_frequency.value,
number_of_vpackets,
packet_seeds,
+ iteration=iteration,
+ show_progress_bars=show_progress_bars,
+ no_of_packets=no_of_packets,
+ total_iterations=total_iterations,
)
runner._montecarlo_virtual_luminosity.value[:] = v_packets_energy_hist
@@ -109,6 +122,7 @@ def montecarlo_radial1d(model, plasma, runner):
runner.virt_packet_last_line_interaction_out_id = np.concatenate(
np.array(virt_packet_last_line_interaction_out_id)
).ravel()
+ update_iterations_pbar(1)
@njit(**njit_dict)
@@ -120,6 +134,10 @@ def montecarlo_main_loop(
spectrum_frequency,
number_of_vpackets,
packet_seeds,
+ iteration,
+ show_progress_bars,
+ no_of_packets,
+ total_iterations,
):
"""
This is the main loop of the MonteCarlo routine that generates packets
@@ -178,6 +196,16 @@ def montecarlo_main_loop(
virt_packet_last_line_interaction_out_id = []
for i in prange(len(output_nus)):
+ if show_progress_bars:
+ with objmode:
+ update_amount = 1
+ update_packet_pbar(
+ update_amount,
+ current_iteration=iteration,
+ no_of_packets=no_of_packets,
+ total_iterations=total_iterations,
+ )
+
if montecarlo_configuration.single_packet_seed != -1:
seed = packet_seeds[montecarlo_configuration.single_packet_seed]
np.random.seed(seed)
@@ -214,8 +242,12 @@ def montecarlo_main_loop(
vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx]
vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx]
- vpackets_initial_mu = vpacket_collection.initial_mus[: vpacket_collection.idx]
- vpackets_initial_r = vpacket_collection.initial_rs[: vpacket_collection.idx]
+ vpackets_initial_mu = vpacket_collection.initial_mus[
+ : vpacket_collection.idx
+ ]
+ vpackets_initial_r = vpacket_collection.initial_rs[
+ : vpacket_collection.idx
+ ]
v_packets_idx = np.floor(
(vpackets_nu - spectrum_frequency[0]) / delta_nu
@@ -233,17 +265,29 @@ def montecarlo_main_loop(
if montecarlo_configuration.VPACKET_LOGGING:
for vpacket_collection in vpacket_collections:
vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx]
- vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx]
- vpackets_initial_mu = vpacket_collection.initial_mus[: vpacket_collection.idx]
- vpackets_initial_r = vpacket_collection.initial_rs[: vpacket_collection.idx]
+ vpackets_energy = vpacket_collection.energies[
+ : vpacket_collection.idx
+ ]
+ vpackets_initial_mu = vpacket_collection.initial_mus[
+ : vpacket_collection.idx
+ ]
+ vpackets_initial_r = vpacket_collection.initial_rs[
+ : vpacket_collection.idx
+ ]
virt_packet_nus.append(np.ascontiguousarray(vpackets_nu))
virt_packet_energies.append(np.ascontiguousarray(vpackets_energy))
- virt_packet_initial_mus.append(np.ascontiguousarray(vpackets_initial_mu))
- virt_packet_initial_rs.append(np.ascontiguousarray(vpackets_initial_r))
- virt_packet_last_interaction_in_nu.append(np.ascontiguousarray(
- vpacket_collection.last_interaction_in_nu[
- : vpacket_collection.idx
- ])
+ virt_packet_initial_mus.append(
+ np.ascontiguousarray(vpackets_initial_mu)
+ )
+ virt_packet_initial_rs.append(
+ np.ascontiguousarray(vpackets_initial_r)
+ )
+ virt_packet_last_interaction_in_nu.append(
+ np.ascontiguousarray(
+ vpacket_collection.last_interaction_in_nu[
+ : vpacket_collection.idx
+ ]
+ )
)
virt_packet_last_interaction_type.append(
np.ascontiguousarray(
diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py
index 77edbd628e6..df0d3427f6a 100644
--- a/tardis/simulation/base.py
+++ b/tardis/simulation/base.py
@@ -134,6 +134,7 @@ def __init__(
nthreads,
show_convergence_plots,
convergence_plots_kwargs,
+ show_progress_bars,
):
super(Simulation, self).__init__(iterations, model.no_of_shells)
@@ -151,6 +152,7 @@ def __init__(
self.luminosity_nu_end = luminosity_nu_end
self.luminosity_requested = luminosity_requested
self.nthreads = nthreads
+ self.show_progress_bars = show_progress_bars
if convergence_strategy.type in ("damped"):
self.convergence_strategy = convergence_strategy
@@ -372,6 +374,8 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0, last_run=False):
nthreads=self.nthreads,
last_run=last_run,
iteration=self.iterations_executed,
+ total_iterations=self.iterations,
+ show_progress_bars=self.show_progress_bars,
)
output_energy = self.runner.output_energy
if np.sum(output_energy < 0) == len(output_energy):
@@ -592,6 +596,7 @@ def from_config(
packet_source=None,
virtual_packet_logging=False,
show_convergence_plots=True,
+ show_progress_bars=True,
**kwargs,
):
"""
@@ -684,4 +689,5 @@ def from_config(
convergence_strategy=config.montecarlo.convergence_strategy,
nthreads=config.montecarlo.nthreads,
convergence_plots_kwargs=convergence_plots_kwargs,
+ show_progress_bars=show_progress_bars,
)
diff --git a/tardis/util/base.py b/tardis/util/base.py
index 78714be9331..89aa8b2fa75 100644
--- a/tardis/util/base.py
+++ b/tardis/util/base.py
@@ -13,7 +13,8 @@
import tardis
from tardis.io.util import get_internal_data_path
-from IPython import get_ipython
+from IPython import get_ipython, display
+import tqdm
k_B_cgs = constants.k_B.cgs.value
c_cgs = constants.c.cgs.value
@@ -596,3 +597,130 @@ def is_notebook():
# All other shell instances are returned False
else:
return False
+
+
+if is_notebook():
+ iterations_pbar = tqdm.notebook.tqdm(
+ desc="Iterations:",
+ bar_format="{desc:<}{bar}{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
+ )
+ iterations_pbar.container.close()
+ packet_pbar = tqdm.notebook.tqdm(
+ desc="Packets: ",
+ postfix="0",
+ bar_format="{desc:<}{bar}{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
+ )
+ packet_pbar.container.close()
+
+else:
+ iterations_pbar = tqdm.tqdm(
+ desc="Iterations:",
+ bar_format="{desc:<}{bar:80}{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
+ )
+ packet_pbar = tqdm.tqdm(
+ desc="Packets: ",
+ postfix="0",
+ bar_format="{desc:<}{bar:80}{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
+ )
+
+
+def update_packet_pbar(i, current_iteration, no_of_packets, total_iterations):
+ """
+ Update progress bars as each packet is propagated.
+
+ Parameters
+ ----------
+ i : int
+ Amount by which the progress bar needs to be updated.
+ current_iteration : int
+ Current iteration number.
+ no_of_packets : int
+ Total number of packets in one iteration.
+ total_iterations : int
+ Total number of iterations.
+ """
+ if packet_pbar.postfix == "":
+ packet_pbar.postfix = "0"
+ bar_iteration = int(packet_pbar.postfix) - 1
+
+ # fix bar layout when run_tardis is called for the first time
+ if iterations_pbar.total == None:
+ fix_bar_layout(iterations_pbar, total_iterations=total_iterations)
+ if packet_pbar.total == None:
+ fix_bar_layout(packet_pbar, no_of_packets=no_of_packets)
+
+ # display and reset progress bar when run_tardis is called again
+ if iterations_pbar.n == total_iterations:
+ if type(iterations_pbar).__name__ == "tqdm_notebook":
+ iterations_pbar.container.close()
+ fix_bar_layout(iterations_pbar, total_iterations=total_iterations)
+
+ if bar_iteration > current_iteration:
+ packet_pbar.postfix = current_iteration
+ if type(packet_pbar).__name__ == "tqdm_notebook":
+ # stop displaying last container
+ packet_pbar.container.close()
+ fix_bar_layout(packet_pbar, no_of_packets=no_of_packets)
+
+ # reset progress bar with each iteration
+ if bar_iteration < current_iteration:
+ packet_pbar.reset(total=no_of_packets)
+ packet_pbar.postfix = str(current_iteration + 1)
+
+ packet_pbar.update(i)
+
+
+def update_iterations_pbar(i):
+ """
+ Update progress bar for each iteration.
+
+ Parameters
+ ----------
+ i : int
+ Amount by which the progress bar needs to be updated.
+ """
+ iterations_pbar.update(i)
+
+
+def fix_bar_layout(bar, no_of_packets=None, total_iterations=None):
+ """
+ Fix the layout of progress bars.
+
+ Parameters
+ ----------
+ bar : tqdm instance
+ Progress bar to change the layout of.
+ no_of_packets : int, optional
+ Number of packets to be propagated.
+ total_iterations : int, optional
+ Total number of iterations.
+ """
+ if type(bar).__name__ == "tqdm_notebook":
+ bar.container = bar.status_printer(
+ bar.fp,
+ bar.total,
+ bar.desc,
+ bar.ncols,
+ )
+ if no_of_packets is not None:
+ bar.reset(total=no_of_packets)
+ if total_iterations is not None:
+ bar.reset(total=total_iterations)
+
+ # change the amount of space the prefix string of the bar takes
+ # here, either packets or iterations
+ bar.container.children[0].layout.width = "6%"
+
+ # change the length of the bar
+ bar.container.children[1].layout.width = "60%"
+
+ # display the progress bar
+ display.display(bar.container)
+
+ else:
+ if no_of_packets is not None:
+ bar.reset(total=no_of_packets)
+ if total_iterations is not None:
+ bar.reset(total=total_iterations)
+ else:
+ pass