diff --git a/docs/io/images/progress_bars_demo.gif b/docs/io/images/progress_bars_demo.gif new file mode 100644 index 00000000000..c672b8916fa Binary files /dev/null and b/docs/io/images/progress_bars_demo.gif differ diff --git a/docs/io/output/index.rst b/docs/io/output/index.rst index 791f0a01eba..46036596261 100644 --- a/docs/io/output/index.rst +++ b/docs/io/output/index.rst @@ -11,4 +11,5 @@ In addition to the widgets, TARDIS can output information in several other forms access_iterations to_hdf callback - vpacket_logging \ No newline at end of file + vpacket_logging + progress_bars \ No newline at end of file diff --git a/docs/io/output/progress_bars.ipynb b/docs/io/output/progress_bars.ipynb new file mode 100644 index 00000000000..55c838d532a --- /dev/null +++ b/docs/io/output/progress_bars.ipynb @@ -0,0 +1,63 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cc613a48", + "metadata": {}, + "source": [ + "### Progress Bars for Simulation Run ###\n", + "TARDIS displays progress bars by default to track the simulation. The progress bars are not displayed in the documentation but show up when you run the notebook.\n", + "\n", + "![TARDIS Progress Bars](../images/progress_bars_demo.gif)" + ] + }, + { + "cell_type": "markdown", + "id": "6e6b8b8f", + "metadata": {}, + "source": [ + "You can disable the progress bars as well." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "319759fd", + "metadata": {}, + "outputs": [], + "source": [ + "from tardis import run_tardis\n", + "sim = run_tardis('tardis_example.yml', show_progress_bars = False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e51d94c9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/quickstart/quickstart.ipynb b/docs/quickstart/quickstart.ipynb index 4f914d79e2e..28739a24c13 100644 --- a/docs/quickstart/quickstart.ipynb +++ b/docs/quickstart/quickstart.ipynb @@ -70,6 +70,19 @@ "#### Running the simulation (long output) ####" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "Note\n", + " \n", + "The progress of the simulation can be tracked using progress bars which are displayed when the notebook is run, but are not displayed in the documentation. \n", + " \n", + "
" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/tardis/base.py b/tardis/base.py index 3bcdd2aa025..b404145f296 100644 --- a/tardis/base.py +++ b/tardis/base.py @@ -15,6 +15,7 @@ def run_tardis( show_convergence_plots=True, log_level=None, specific_log_level=None, + show_progress_bars=True, **kwargs, ): """ @@ -52,6 +53,8 @@ def run_tardis( The default value None means that the `specific_log_level` specified in the configuration file will be used. show_convergence_plots : bool, default: True, optional Option to enable tardis convergence plots. + show_progress_bars : bool, default: True, optional + Option to enable the progress bar. **kwargs : dict, optional Optional keyword arguments including those supported by :obj:`tardis.visualization.tools.convergence_plot.ConvergencePlots`. @@ -101,6 +104,7 @@ def run_tardis( atom_data=atom_data, virtual_packet_logging=virtual_packet_logging, show_convergence_plots=show_convergence_plots, + show_progress_bars=show_progress_bars, **kwargs, ) for cb in simulation_callbacks: diff --git a/tardis/montecarlo/base.py b/tardis/montecarlo/base.py index 9ab1fbada27..7dfa116ca8f 100644 --- a/tardis/montecarlo/base.py +++ b/tardis/montecarlo/base.py @@ -268,6 +268,8 @@ def run( nthreads=1, last_run=False, iteration=0, + total_iterations=0, + show_progress_bars=True, ): """ Run the montecarlo calculation @@ -280,6 +282,8 @@ def run( no_of_virtual_packets : int nthreads : int last_run : bool + total_iterations : int + The total number of iterations in the simulation. Returns ------- @@ -301,7 +305,15 @@ def run( ) configuration_initialize(self, no_of_virtual_packets) - montecarlo_radial1d(model, plasma, self) + montecarlo_radial1d( + model, + plasma, + iteration, + no_of_packets, + total_iterations, + show_progress_bars, + self, + ) self._integrator = FormalIntegrator(model, plasma, self) # montecarlo.montecarlo_radial1d( # model, plasma, self, diff --git a/tardis/montecarlo/montecarlo_numba/base.py b/tardis/montecarlo/montecarlo_numba/base.py index ba107d6292a..dbac64d17e1 100644 --- a/tardis/montecarlo/montecarlo_numba/base.py +++ b/tardis/montecarlo/montecarlo_numba/base.py @@ -1,4 +1,4 @@ -from numba import prange, njit, jit +from numba import prange, njit, jit, objmode import logging import numpy as np @@ -26,9 +26,18 @@ ) from tardis.montecarlo.montecarlo_numba import njit_dict from numba.typed import List +from tardis.util.base import update_iterations_pbar, update_packet_pbar -def montecarlo_radial1d(model, plasma, runner): +def montecarlo_radial1d( + model, + plasma, + iteration, + no_of_packets, + total_iterations, + show_progress_bars, + runner, +): packet_collection = PacketCollection( runner.input_r, runner.input_nu, @@ -76,6 +85,10 @@ def montecarlo_radial1d(model, plasma, runner): runner.spectrum_frequency.value, number_of_vpackets, packet_seeds, + iteration=iteration, + show_progress_bars=show_progress_bars, + no_of_packets=no_of_packets, + total_iterations=total_iterations, ) runner._montecarlo_virtual_luminosity.value[:] = v_packets_energy_hist @@ -109,6 +122,7 @@ def montecarlo_radial1d(model, plasma, runner): runner.virt_packet_last_line_interaction_out_id = np.concatenate( np.array(virt_packet_last_line_interaction_out_id) ).ravel() + update_iterations_pbar(1) @njit(**njit_dict) @@ -120,6 +134,10 @@ def montecarlo_main_loop( spectrum_frequency, number_of_vpackets, packet_seeds, + iteration, + show_progress_bars, + no_of_packets, + total_iterations, ): """ This is the main loop of the MonteCarlo routine that generates packets @@ -178,6 +196,16 @@ def montecarlo_main_loop( virt_packet_last_line_interaction_out_id = [] for i in prange(len(output_nus)): + if show_progress_bars: + with objmode: + update_amount = 1 + update_packet_pbar( + update_amount, + current_iteration=iteration, + no_of_packets=no_of_packets, + total_iterations=total_iterations, + ) + if montecarlo_configuration.single_packet_seed != -1: seed = packet_seeds[montecarlo_configuration.single_packet_seed] np.random.seed(seed) @@ -214,8 +242,12 @@ def montecarlo_main_loop( vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx] vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx] - vpackets_initial_mu = vpacket_collection.initial_mus[: vpacket_collection.idx] - vpackets_initial_r = vpacket_collection.initial_rs[: vpacket_collection.idx] + vpackets_initial_mu = vpacket_collection.initial_mus[ + : vpacket_collection.idx + ] + vpackets_initial_r = vpacket_collection.initial_rs[ + : vpacket_collection.idx + ] v_packets_idx = np.floor( (vpackets_nu - spectrum_frequency[0]) / delta_nu @@ -233,17 +265,29 @@ def montecarlo_main_loop( if montecarlo_configuration.VPACKET_LOGGING: for vpacket_collection in vpacket_collections: vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx] - vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx] - vpackets_initial_mu = vpacket_collection.initial_mus[: vpacket_collection.idx] - vpackets_initial_r = vpacket_collection.initial_rs[: vpacket_collection.idx] + vpackets_energy = vpacket_collection.energies[ + : vpacket_collection.idx + ] + vpackets_initial_mu = vpacket_collection.initial_mus[ + : vpacket_collection.idx + ] + vpackets_initial_r = vpacket_collection.initial_rs[ + : vpacket_collection.idx + ] virt_packet_nus.append(np.ascontiguousarray(vpackets_nu)) virt_packet_energies.append(np.ascontiguousarray(vpackets_energy)) - virt_packet_initial_mus.append(np.ascontiguousarray(vpackets_initial_mu)) - virt_packet_initial_rs.append(np.ascontiguousarray(vpackets_initial_r)) - virt_packet_last_interaction_in_nu.append(np.ascontiguousarray( - vpacket_collection.last_interaction_in_nu[ - : vpacket_collection.idx - ]) + virt_packet_initial_mus.append( + np.ascontiguousarray(vpackets_initial_mu) + ) + virt_packet_initial_rs.append( + np.ascontiguousarray(vpackets_initial_r) + ) + virt_packet_last_interaction_in_nu.append( + np.ascontiguousarray( + vpacket_collection.last_interaction_in_nu[ + : vpacket_collection.idx + ] + ) ) virt_packet_last_interaction_type.append( np.ascontiguousarray( diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 77edbd628e6..df0d3427f6a 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -134,6 +134,7 @@ def __init__( nthreads, show_convergence_plots, convergence_plots_kwargs, + show_progress_bars, ): super(Simulation, self).__init__(iterations, model.no_of_shells) @@ -151,6 +152,7 @@ def __init__( self.luminosity_nu_end = luminosity_nu_end self.luminosity_requested = luminosity_requested self.nthreads = nthreads + self.show_progress_bars = show_progress_bars if convergence_strategy.type in ("damped"): self.convergence_strategy = convergence_strategy @@ -372,6 +374,8 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0, last_run=False): nthreads=self.nthreads, last_run=last_run, iteration=self.iterations_executed, + total_iterations=self.iterations, + show_progress_bars=self.show_progress_bars, ) output_energy = self.runner.output_energy if np.sum(output_energy < 0) == len(output_energy): @@ -592,6 +596,7 @@ def from_config( packet_source=None, virtual_packet_logging=False, show_convergence_plots=True, + show_progress_bars=True, **kwargs, ): """ @@ -684,4 +689,5 @@ def from_config( convergence_strategy=config.montecarlo.convergence_strategy, nthreads=config.montecarlo.nthreads, convergence_plots_kwargs=convergence_plots_kwargs, + show_progress_bars=show_progress_bars, ) diff --git a/tardis/util/base.py b/tardis/util/base.py index 78714be9331..89aa8b2fa75 100644 --- a/tardis/util/base.py +++ b/tardis/util/base.py @@ -13,7 +13,8 @@ import tardis from tardis.io.util import get_internal_data_path -from IPython import get_ipython +from IPython import get_ipython, display +import tqdm k_B_cgs = constants.k_B.cgs.value c_cgs = constants.c.cgs.value @@ -596,3 +597,130 @@ def is_notebook(): # All other shell instances are returned False else: return False + + +if is_notebook(): + iterations_pbar = tqdm.notebook.tqdm( + desc="Iterations:", + bar_format="{desc:<}{bar}{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) + iterations_pbar.container.close() + packet_pbar = tqdm.notebook.tqdm( + desc="Packets: ", + postfix="0", + bar_format="{desc:<}{bar}{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) + packet_pbar.container.close() + +else: + iterations_pbar = tqdm.tqdm( + desc="Iterations:", + bar_format="{desc:<}{bar:80}{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) + packet_pbar = tqdm.tqdm( + desc="Packets: ", + postfix="0", + bar_format="{desc:<}{bar:80}{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", + ) + + +def update_packet_pbar(i, current_iteration, no_of_packets, total_iterations): + """ + Update progress bars as each packet is propagated. + + Parameters + ---------- + i : int + Amount by which the progress bar needs to be updated. + current_iteration : int + Current iteration number. + no_of_packets : int + Total number of packets in one iteration. + total_iterations : int + Total number of iterations. + """ + if packet_pbar.postfix == "": + packet_pbar.postfix = "0" + bar_iteration = int(packet_pbar.postfix) - 1 + + # fix bar layout when run_tardis is called for the first time + if iterations_pbar.total == None: + fix_bar_layout(iterations_pbar, total_iterations=total_iterations) + if packet_pbar.total == None: + fix_bar_layout(packet_pbar, no_of_packets=no_of_packets) + + # display and reset progress bar when run_tardis is called again + if iterations_pbar.n == total_iterations: + if type(iterations_pbar).__name__ == "tqdm_notebook": + iterations_pbar.container.close() + fix_bar_layout(iterations_pbar, total_iterations=total_iterations) + + if bar_iteration > current_iteration: + packet_pbar.postfix = current_iteration + if type(packet_pbar).__name__ == "tqdm_notebook": + # stop displaying last container + packet_pbar.container.close() + fix_bar_layout(packet_pbar, no_of_packets=no_of_packets) + + # reset progress bar with each iteration + if bar_iteration < current_iteration: + packet_pbar.reset(total=no_of_packets) + packet_pbar.postfix = str(current_iteration + 1) + + packet_pbar.update(i) + + +def update_iterations_pbar(i): + """ + Update progress bar for each iteration. + + Parameters + ---------- + i : int + Amount by which the progress bar needs to be updated. + """ + iterations_pbar.update(i) + + +def fix_bar_layout(bar, no_of_packets=None, total_iterations=None): + """ + Fix the layout of progress bars. + + Parameters + ---------- + bar : tqdm instance + Progress bar to change the layout of. + no_of_packets : int, optional + Number of packets to be propagated. + total_iterations : int, optional + Total number of iterations. + """ + if type(bar).__name__ == "tqdm_notebook": + bar.container = bar.status_printer( + bar.fp, + bar.total, + bar.desc, + bar.ncols, + ) + if no_of_packets is not None: + bar.reset(total=no_of_packets) + if total_iterations is not None: + bar.reset(total=total_iterations) + + # change the amount of space the prefix string of the bar takes + # here, either packets or iterations + bar.container.children[0].layout.width = "6%" + + # change the length of the bar + bar.container.children[1].layout.width = "60%" + + # display the progress bar + display.display(bar.container) + + else: + if no_of_packets is not None: + bar.reset(total=no_of_packets) + if total_iterations is not None: + bar.reset(total=total_iterations) + else: + pass