Skip to content

Commit

Permalink
refactored code and added functionality to change layout of convergen…
Browse files Browse the repository at this point in the history
…ce plots from run_tardis
  • Loading branch information
atharva-2001 committed Jun 15, 2021
1 parent c712b26 commit 3cb95ee
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 41 deletions.
2 changes: 2 additions & 0 deletions tardis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def run_tardis(
packet_source=None,
simulation_callbacks=[],
virtual_packet_logging=False,
**kwargs,
):
"""
This function is one of the core functions to run TARDIS from a given
Expand Down Expand Up @@ -54,6 +55,7 @@ def run_tardis(
packet_source=packet_source,
atom_data=atom_data,
virtual_packet_logging=virtual_packet_logging,
**kwargs,
)
for cb in simulation_callbacks:
simulation.add_callback(*cb)
Expand Down
19 changes: 16 additions & 3 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
from astropy import units as u, constants as const
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from tardis import model

from tardis.montecarlo import MontecarloRunner
Expand All @@ -12,7 +12,7 @@
from tardis.io.util import HDFWriterMixin
from tardis.io.config_reader import ConfigurationError
from tardis.montecarlo import montecarlo_configuration as mc_config_module
from tardis.visualization import UpdateCplots
from tardis.visualization import ConvergencePlots

# Adding logging support
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -129,6 +129,7 @@ def __init__(
luminosity_requested,
convergence_strategy,
nthreads,
cplots_kwargs,
):

super(Simulation, self).__init__(iterations, model.no_of_shells)
Expand Down Expand Up @@ -162,7 +163,9 @@ def __init__(
f"not damped or custom "
f"- input is {convergence_strategy.type}"
)
self.cplots = UpdateCplots()
self.cplots = ConvergencePlots(
iterations=self.iterations, **cplots_kwargs
)
self._callbacks = OrderedDict()
self._cb_next_id = 0

Expand Down Expand Up @@ -566,6 +569,15 @@ def from_config(
packet_source=packet_source,
virtual_packet_logging=virtual_packet_logging,
)
cplots_kwargs = defaultdict(dict)

if "plasma_plot_config" in kwargs:
cplots_kwargs["plasma_plot_config"] = kwargs["plasma_plot_config"]

if "luminosity_plot_config" in kwargs:
cplots_kwargs["luminosity_plot_config"] = kwargs[
"luminosity_plot_config"
]

luminosity_nu_start = config.supernova.luminosity_wavelength_end.to(
u.Hz, u.spectral()
Expand Down Expand Up @@ -598,4 +610,5 @@ def from_config(
luminosity_requested=config.supernova.luminosity_requested.cgs,
convergence_strategy=config.montecarlo.convergence_strategy,
nthreads=config.montecarlo.nthreads,
cplots_kwargs=cplots_kwargs,
)
6 changes: 1 addition & 5 deletions tardis/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
"""Visualization tools and widgets for TARDIS."""

from tardis.visualization.tools.convergence_plot import (
ConvergencePlots,
BuildCplots,
UpdateCplots,
)
from tardis.visualization.tools.convergence_plot import ConvergencePlots

from tardis.visualization.widgets.shell_info import (
shell_info_from_simulation,
Expand Down
64 changes: 31 additions & 33 deletions tardis/visualization/tools/convergence_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,25 @@ def transistion_colors(name="jet", iterations=20):


class ConvergencePlots(object):
def __init__(self):
def __init__(self, iterations, **kwargs):
self.iterable_data = {}
self.value_data = defaultdict(list)
self.rows = 4
self.cols = 2
self.specs = [
[{}, {}],
[{"colspan": 2}, None],
[{"colspan": 2}, None],
[{"colspan": 2}, None],
]
self.row_heights = [0.45, 0.1, 0.4, 0.1]
self.vertical_spacing = 0.07
self.iterations = 20
self.iterations = iterations
self.current_iteration = 1
self.luminosities = ["Emitted", "Absorbed", "Requested"]

if "plasma_plot_config" in kwargs:
if kwargs["plasma_plot_config"] != {}:
self.plasma_plot_config = kwargs["plasma_plot_config"]

if "luminosity_plot_config" in kwargs:
if kwargs["luminosity_plot_config"] != {}:
self.luminosity_plot_config = kwargs["luminosity_plot_config"]

def fetch_data(self, name=None, value=None, type=None):
"""
This allows user to fetch data from the Simulation class.
This data is stored and used when an iteration is completed.
Returns:
iterable_data: iterable data from the simulation, values like radiation temperature
value_data: values like luminosity
"""
if type == "iterable":
self.iterable_data[name] = value
Expand All @@ -56,12 +52,6 @@ def get_data(self):
"""
return self.iterable_data, self.value_data


class BuildCplots(ConvergencePlots):
def __init__(self):
super().__init__()
self.use_vbox = True

def create_plasma_plot(self):
"""
creates empty plasma plot
Expand All @@ -81,20 +71,27 @@ def create_plasma_plot(self):
},
yaxis={
"tickformat": "g",
"title": r"$T_{rad}\ [K]$",
"title": r"$W$",
"range": [9000, 14000],
},
yaxis2={"tickformat": "g", "title": r"$W$"},
yaxis2={"tickformat": "g", "title": r"$T_{rad}\ [K]$"},
height=580,
)

# allows overriding default layout
if hasattr(self, "plasma_plot_config"):
for key in self.plasma_plot_config:
fig["layout"][key] = self.plasma_plot_config[key]

self.plasma_plot = fig

def create_luminosity_plot(self):
"""
creates empty luminosity plot
"""
marker_colors = ["#958aff", "#ff8b85", "#5cff74"]
marker_line_colors = ["#27006b", "#800000", "#00801c"]
marker_colors = ["#636EFA", "#EF553B", "#00CC96"]
self.luminosities = ["Emitted", "Absorbed", "Requested"]

fig = go.FigureWidget().set_subplots(
3,
Expand Down Expand Up @@ -132,7 +129,7 @@ def create_luminosity_plot(self):
)

fig.add_scatter(
name="Next Inner<br>Boundary Temperature",
name="Inner<br>Boundary Temperature",
row=1,
col=1,
hovertext="text",
Expand All @@ -144,10 +141,10 @@ def create_luminosity_plot(self):
)

fig = fig.update_layout(
xaxis=dict(range=[0, 21], dtick=2),
xaxis=dict(range=[0, self.iterations + 1], dtick=2),
xaxis2=dict(
matches="x",
range=[0, 21],
range=[0, self.iterations + 1],
dtick=2,
),
xaxis3=dict(
Expand Down Expand Up @@ -179,18 +176,19 @@ def create_luminosity_plot(self):
hoverlabel_align="right",
legend_title_text="Luminosity",
)

# allows overriding default layout
if hasattr(self, "luminosity_plot_config"):
for key in self.luminosity_plot_config:
fig["layout"][key] = self.luminosity_plot_config[key]

self.luminosity_plot = fig

def build(self):
self.create_plasma_plot()
self.create_luminosity_plot()
display(widgets.VBox([self.plasma_plot, self.luminosity_plot]))


class UpdateCplots(BuildCplots):
def __init__(self):
super().__init__()

def update_plasma_plots(self):
x = self.iterable_data["velocity"].value.tolist()
customdata = len(x) * [
Expand Down Expand Up @@ -255,7 +253,7 @@ def update_luminosity_plot(self):
self.luminosity_plot.data[-1].y = self.value_data["t_inner"]
self.luminosity_plot.data[
-1
].hovertemplate = "Next Inner Body Temperature: %{y:.2f} at X = %{x:,.0f}<extra></extra>"
].hovertemplate = "Inner Body Temperature: %{y:.2f} at X = %{x:,.0f}<extra></extra>"

def update(self):
if self.current_iteration == 1:
Expand Down

0 comments on commit 3cb95ee

Please sign in to comment.