From 3cb95ee9434f37c6ae1c549eb1e0b765d422968d Mon Sep 17 00:00:00 2001 From: Atharva Arya Date: Tue, 15 Jun 2021 17:19:12 +0530 Subject: [PATCH] refactored code and added functionality to change layout of convergence plots from run_tardis --- tardis/base.py | 2 + tardis/simulation/base.py | 19 +++++- tardis/visualization/__init__.py | 6 +- .../visualization/tools/convergence_plot.py | 64 +++++++++---------- 4 files changed, 50 insertions(+), 41 deletions(-) diff --git a/tardis/base.py b/tardis/base.py index 73766f7fe63..6257e7eb900 100644 --- a/tardis/base.py +++ b/tardis/base.py @@ -7,6 +7,7 @@ def run_tardis( packet_source=None, simulation_callbacks=[], virtual_packet_logging=False, + **kwargs, ): """ This function is one of the core functions to run TARDIS from a given @@ -54,6 +55,7 @@ def run_tardis( packet_source=packet_source, atom_data=atom_data, virtual_packet_logging=virtual_packet_logging, + **kwargs, ) for cb in simulation_callbacks: simulation.add_callback(*cb) diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 36918ba3ebd..cd0aaffbb6e 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd from astropy import units as u, constants as const -from collections import OrderedDict +from collections import OrderedDict, defaultdict from tardis import model from tardis.montecarlo import MontecarloRunner @@ -12,7 +12,7 @@ from tardis.io.util import HDFWriterMixin from tardis.io.config_reader import ConfigurationError from tardis.montecarlo import montecarlo_configuration as mc_config_module -from tardis.visualization import UpdateCplots +from tardis.visualization import ConvergencePlots # Adding logging support logger = logging.getLogger(__name__) @@ -129,6 +129,7 @@ def __init__( luminosity_requested, convergence_strategy, nthreads, + cplots_kwargs, ): super(Simulation, self).__init__(iterations, model.no_of_shells) @@ -162,7 +163,9 @@ def __init__( f"not damped or custom " f"- input is {convergence_strategy.type}" ) - self.cplots = UpdateCplots() + self.cplots = ConvergencePlots( + iterations=self.iterations, **cplots_kwargs + ) self._callbacks = OrderedDict() self._cb_next_id = 0 @@ -566,6 +569,15 @@ def from_config( packet_source=packet_source, virtual_packet_logging=virtual_packet_logging, ) + cplots_kwargs = defaultdict(dict) + + if "plasma_plot_config" in kwargs: + cplots_kwargs["plasma_plot_config"] = kwargs["plasma_plot_config"] + + if "luminosity_plot_config" in kwargs: + cplots_kwargs["luminosity_plot_config"] = kwargs[ + "luminosity_plot_config" + ] luminosity_nu_start = config.supernova.luminosity_wavelength_end.to( u.Hz, u.spectral() @@ -598,4 +610,5 @@ def from_config( luminosity_requested=config.supernova.luminosity_requested.cgs, convergence_strategy=config.montecarlo.convergence_strategy, nthreads=config.montecarlo.nthreads, + cplots_kwargs=cplots_kwargs, ) diff --git a/tardis/visualization/__init__.py b/tardis/visualization/__init__.py index 206fa766919..bad425bd8d9 100644 --- a/tardis/visualization/__init__.py +++ b/tardis/visualization/__init__.py @@ -1,10 +1,6 @@ """Visualization tools and widgets for TARDIS.""" -from tardis.visualization.tools.convergence_plot import ( - ConvergencePlots, - BuildCplots, - UpdateCplots, -) +from tardis.visualization.tools.convergence_plot import ConvergencePlots from tardis.visualization.widgets.shell_info import ( shell_info_from_simulation, diff --git a/tardis/visualization/tools/convergence_plot.py b/tardis/visualization/tools/convergence_plot.py index 57bf744f59a..305eeb99494 100644 --- a/tardis/visualization/tools/convergence_plot.py +++ b/tardis/visualization/tools/convergence_plot.py @@ -17,29 +17,25 @@ def transistion_colors(name="jet", iterations=20): class ConvergencePlots(object): - def __init__(self): + def __init__(self, iterations, **kwargs): self.iterable_data = {} self.value_data = defaultdict(list) - self.rows = 4 - self.cols = 2 - self.specs = [ - [{}, {}], - [{"colspan": 2}, None], - [{"colspan": 2}, None], - [{"colspan": 2}, None], - ] - self.row_heights = [0.45, 0.1, 0.4, 0.1] - self.vertical_spacing = 0.07 - self.iterations = 20 + self.iterations = iterations self.current_iteration = 1 + self.luminosities = ["Emitted", "Absorbed", "Requested"] + + if "plasma_plot_config" in kwargs: + if kwargs["plasma_plot_config"] != {}: + self.plasma_plot_config = kwargs["plasma_plot_config"] + + if "luminosity_plot_config" in kwargs: + if kwargs["luminosity_plot_config"] != {}: + self.luminosity_plot_config = kwargs["luminosity_plot_config"] def fetch_data(self, name=None, value=None, type=None): """ This allows user to fetch data from the Simulation class. This data is stored and used when an iteration is completed. - Returns: - iterable_data: iterable data from the simulation, values like radiation temperature - value_data: values like luminosity """ if type == "iterable": self.iterable_data[name] = value @@ -56,12 +52,6 @@ def get_data(self): """ return self.iterable_data, self.value_data - -class BuildCplots(ConvergencePlots): - def __init__(self): - super().__init__() - self.use_vbox = True - def create_plasma_plot(self): """ creates empty plasma plot @@ -81,20 +71,27 @@ def create_plasma_plot(self): }, yaxis={ "tickformat": "g", - "title": r"$T_{rad}\ [K]$", + "title": r"$W$", "range": [9000, 14000], }, - yaxis2={"tickformat": "g", "title": r"$W$"}, + yaxis2={"tickformat": "g", "title": r"$T_{rad}\ [K]$"}, height=580, ) + # allows overriding default layout + if hasattr(self, "plasma_plot_config"): + for key in self.plasma_plot_config: + fig["layout"][key] = self.plasma_plot_config[key] + self.plasma_plot = fig def create_luminosity_plot(self): + """ + creates empty luminosity plot + """ marker_colors = ["#958aff", "#ff8b85", "#5cff74"] marker_line_colors = ["#27006b", "#800000", "#00801c"] marker_colors = ["#636EFA", "#EF553B", "#00CC96"] - self.luminosities = ["Emitted", "Absorbed", "Requested"] fig = go.FigureWidget().set_subplots( 3, @@ -132,7 +129,7 @@ def create_luminosity_plot(self): ) fig.add_scatter( - name="Next Inner
Boundary Temperature", + name="Inner
Boundary Temperature", row=1, col=1, hovertext="text", @@ -144,10 +141,10 @@ def create_luminosity_plot(self): ) fig = fig.update_layout( - xaxis=dict(range=[0, 21], dtick=2), + xaxis=dict(range=[0, self.iterations + 1], dtick=2), xaxis2=dict( matches="x", - range=[0, 21], + range=[0, self.iterations + 1], dtick=2, ), xaxis3=dict( @@ -179,6 +176,12 @@ def create_luminosity_plot(self): hoverlabel_align="right", legend_title_text="Luminosity", ) + + # allows overriding default layout + if hasattr(self, "luminosity_plot_config"): + for key in self.luminosity_plot_config: + fig["layout"][key] = self.luminosity_plot_config[key] + self.luminosity_plot = fig def build(self): @@ -186,11 +189,6 @@ def build(self): self.create_luminosity_plot() display(widgets.VBox([self.plasma_plot, self.luminosity_plot])) - -class UpdateCplots(BuildCplots): - def __init__(self): - super().__init__() - def update_plasma_plots(self): x = self.iterable_data["velocity"].value.tolist() customdata = len(x) * [ @@ -255,7 +253,7 @@ def update_luminosity_plot(self): self.luminosity_plot.data[-1].y = self.value_data["t_inner"] self.luminosity_plot.data[ -1 - ].hovertemplate = "Next Inner Body Temperature: %{y:.2f} at X = %{x:,.0f}" + ].hovertemplate = "Inner Body Temperature: %{y:.2f} at X = %{x:,.0f}" def update(self): if self.current_iteration == 1: