From 4704edc6f85e44a4432dface37d2ae4cef0c7871 Mon Sep 17 00:00:00 2001 From: Atharva Arya Date: Thu, 29 Jul 2021 11:49:17 +0530 Subject: [PATCH] Add option to enable/disble the progress bar --- tardis/base.py | 4 ++++ tardis/montecarlo/base.py | 2 ++ tardis/montecarlo/montecarlo_numba/base.py | 28 ++++++++++++++++------ tardis/simulation/base.py | 5 ++++ 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/tardis/base.py b/tardis/base.py index f05f63db7e7..f8df1d9400c 100644 --- a/tardis/base.py +++ b/tardis/base.py @@ -15,6 +15,7 @@ def run_tardis( show_cplots=True, log_level=None, specific=None, + show_progress_bar=True, **kwargs, ): """ @@ -52,6 +53,8 @@ def run_tardis( The default value None means that the `specific` specified in the configuration file will be used. show_cplots : bool, default: True, optional Option to enable tardis convergence plots. + show_progress_bar : bool, default: True, optional + Option to enable the progress bar. **kwargs : dict, optional Optional keyword arguments including those supported by :obj:`tardis.visualization.tools.convergence_plot.ConvergencePlots`. @@ -101,6 +104,7 @@ def run_tardis( atom_data=atom_data, virtual_packet_logging=virtual_packet_logging, show_cplots=show_cplots, + show_progress_bar=show_progress_bar, **kwargs, ) for cb in simulation_callbacks: diff --git a/tardis/montecarlo/base.py b/tardis/montecarlo/base.py index 401817fda76..1ab380fb0ab 100644 --- a/tardis/montecarlo/base.py +++ b/tardis/montecarlo/base.py @@ -270,6 +270,7 @@ def run( last_run=False, iteration=0, total_iterations=0, + show_progress_bar=True, ): """ Run the montecarlo calculation @@ -310,6 +311,7 @@ def run( iteration, total_packets, total_iterations, + show_progress_bar, self, ) self._integrator = FormalIntegrator(model, plasma, self) diff --git a/tardis/montecarlo/montecarlo_numba/base.py b/tardis/montecarlo/montecarlo_numba/base.py index 915f53d884e..e96367593b0 100644 --- a/tardis/montecarlo/montecarlo_numba/base.py +++ b/tardis/montecarlo/montecarlo_numba/base.py @@ -39,6 +39,7 @@ dynamic_ncols=True, bar_format="{bar}{percentage:3.0f}% of packets propagated, iteration 0/?", ) +packet_pbar.container.close() def update_packet_pbar(i, current_iteration, total_iterations, total_packets): @@ -62,7 +63,16 @@ def update_packet_pbar(i, current_iteration, total_iterations, total_packets): # set bar total when first called if packet_pbar.total == None: + packet_pbar.ncols = "100%" + packet_pbar.container = packet_pbar.status_printer( + packet_pbar.fp, + packet_pbar.total, + packet_pbar.desc, + packet_pbar.ncols, + ) + display(packet_pbar.container) packet_pbar.reset(total=total_packets) + packet_pbar.display() # display and reset progress bar when run_tardis is called again if bar_iteration > current_iteration: @@ -112,6 +122,7 @@ def montecarlo_radial1d( iteration, total_packets, total_iterations, + show_progress_bar, runner, ): packet_collection = PacketCollection( @@ -164,6 +175,7 @@ def montecarlo_radial1d( packet_seeds, iteration=iteration, total_iterations=total_iterations, + show_progress_bar=show_progress_bar, ) runner._montecarlo_virtual_luminosity.value[:] = v_packets_energy_hist @@ -211,6 +223,7 @@ def montecarlo_main_loop( packet_seeds, iteration, total_iterations, + show_progress_bar, ): """ This is the main loop of the MonteCarlo routine that generates packets @@ -269,13 +282,14 @@ def montecarlo_main_loop( virt_packet_last_line_interaction_out_id = [] for i in prange(len(output_nus)): - with objmode: - update_packet_pbar( - 1, - current_iteration=iteration, - total_iterations=total_iterations, - total_packets=total_packets, - ) + if show_progress_bar: + with objmode: + update_packet_pbar( + 1, + current_iteration=iteration, + total_iterations=total_iterations, + total_packets=total_packets, + ) if montecarlo_configuration.single_packet_seed != -1: seed = packet_seeds[montecarlo_configuration.single_packet_seed] diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index f783bef7bb4..4aab319eb9c 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -134,6 +134,7 @@ def __init__( convergence_strategy, nthreads, show_cplots, + show_progress_bar, cplots_kwargs, ): @@ -153,6 +154,7 @@ def __init__( self.luminosity_nu_end = luminosity_nu_end self.luminosity_requested = luminosity_requested self.nthreads = nthreads + self.show_progress_bar = show_progress_bar if convergence_strategy.type in ("damped"): self.convergence_strategy = convergence_strategy @@ -370,6 +372,7 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0, last_run=False): last_run=last_run, iteration=self.iterations_executed, total_iterations=self.iterations, + show_progress_bar=self.show_progress_bar, ) output_energy = self.runner.output_energy if np.sum(output_energy < 0) == len(output_energy): @@ -587,6 +590,7 @@ def from_config( packet_source=None, virtual_packet_logging=False, show_cplots=True, + show_progress_bar=True, **kwargs, ): """ @@ -683,4 +687,5 @@ def from_config( convergence_strategy=config.montecarlo.convergence_strategy, nthreads=config.montecarlo.nthreads, cplots_kwargs=cplots_kwargs, + show_progress_bar=show_progress_bar, )