Skip to content

Commit

Permalink
Adding Convergence Plots (#1636)
Browse files Browse the repository at this point in the history
* initial commit

adding convergence plot notebook and python file

* updating class structure

* adding functions to create empty plots

* adding functions to update convergence plots

* fixing imports for convergence plots

* changes to get update convergence plots from tardis/simulation/base.py

* adding updated notebook

* allowing user to customize plot layout from run_tardis

* adding options to not show plots and change colorscale

* adding docstrings

* adding updated notebook

* fix typo

fixing swapped y-axis labels in plasma plot

* exporting convergence plots

* layout changes

fixing colors, removing marker points, fixing labels

* adding check to see if data is collected

necessary when running simulation with just one iteration

* moving convergence plots notebook

* adding tests for convergence class and transition_colors function

* fix typos and adding comments

* reformatted using black

* adding function to override default plot configuration

* showing how plots can be updated in the notebook

* fixing plot heights and tick labels

* code refactor

raising TypeErrors, fix typos

* adding docstrings

* fixing axes and legend, converting units using astropy

* fixing hover data and making axes titles romanized/upright in certain places

* use same colorscale for both plasma and luminosity plots

* add docstrings and code comments

* add option to change colorscale, format using black, add updated notebook

* remove unnecessary customizations, edit docstrings

* add documentation in the notebook

* test length of fig.data tuple after build

* add tests for update and override function

* edit docstrings and code comments

* renaming luminosity_plot to t_inner_luminosities_plot

* minor changes in documentation and docstrings

* [build docs]
  • Loading branch information
atharva-2001 authored Jul 23, 2021
1 parent 2010e81 commit f27fe3d
Show file tree
Hide file tree
Showing 8 changed files with 1,596 additions and 2 deletions.
855 changes: 855 additions & 0 deletions docs/io/visualization/convergence_plot.ipynb

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions tardis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ def run_tardis(
packet_source=None,
simulation_callbacks=[],
virtual_packet_logging=False,
show_cplots=True,
log_level=None,
specific=None,
**kwargs,
):
"""
Run TARDIS from a given config object.
Expand Down Expand Up @@ -48,6 +50,12 @@ def run_tardis(
If True, only show the log messages from a particular log level, set by `log_level`.
If False, the logger shows log messages belonging to the level set and all levels above it in severity.
The default value None means that the `specific` specified in the configuration file will be used.
show_cplots : bool, default: True, optional
Option to enable tardis convergence plots.
**kwargs : dict, optional
Optional keyword arguments including those
supported by :obj:`tardis.visualization.tools.convergence_plot.ConvergencePlots`.
Returns
-------
Expand All @@ -73,6 +81,9 @@ def run_tardis(
)
tardis_config = Configuration.from_config_dict(config)

if not isinstance(show_cplots, bool):
raise TypeError("Expected bool in show_cplots argument")

logging_state(log_level, tardis_config, specific)

if atom_data is not None:
Expand All @@ -89,6 +100,8 @@ def run_tardis(
packet_source=packet_source,
atom_data=atom_data,
virtual_packet_logging=virtual_packet_logging,
show_cplots=show_cplots,
**kwargs,
)
for cb in simulation_callbacks:
simulation.add_callback(*cb)
Expand Down
80 changes: 79 additions & 1 deletion tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
from astropy import units as u, constants as const
from collections import OrderedDict
from tardis import model

from tardis.montecarlo import MontecarloRunner
from tardis.model import Radial1DModel
Expand All @@ -12,6 +13,7 @@
from tardis.io.config_reader import ConfigurationError
from tardis.util.base import is_notebook
from tardis.montecarlo import montecarlo_configuration as mc_config_module
from tardis.visualization import ConvergencePlots
from IPython.display import display

# Adding logging support
Expand Down Expand Up @@ -97,6 +99,7 @@ class Simulation(PlasmaStateStorerMixin, HDFWriterMixin):
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity
luminosity_requested : astropy.units.Quantity
cplots_kwargs: dict
nthreads : int
The number of threads to run montecarlo with
Expand Down Expand Up @@ -129,6 +132,8 @@ def __init__(
luminosity_requested,
convergence_strategy,
nthreads,
show_cplots,
cplots_kwargs,
):

super(Simulation, self).__init__(iterations, model.no_of_shells)
Expand All @@ -146,6 +151,7 @@ def __init__(
self.luminosity_nu_end = luminosity_nu_end
self.luminosity_requested = luminosity_requested
self.nthreads = nthreads

if convergence_strategy.type in ("damped"):
self.convergence_strategy = convergence_strategy
self.converged = False
Expand All @@ -162,6 +168,18 @@ def __init__(
f"- input is {convergence_strategy.type}"
)

if show_cplots:
self.cplots = ConvergencePlots(
iterations=self.iterations, **cplots_kwargs
)

if "export_cplots" in cplots_kwargs:
if not isinstance(cplots_kwargs["export_cplots"], bool):
raise TypeError("Expected bool in export_cplots argument")
self.export_cplots = cplots_kwargs["export_cplots"]
else:
self.export_cplots = False

self._callbacks = OrderedDict()
self._cb_next_id = 0

Expand Down Expand Up @@ -289,6 +307,22 @@ def advance_state(self):
else:
next_t_inner = self.model.t_inner

if hasattr(self, "cplots"):
self.cplots.fetch_data(
name="t_inner",
value=self.model.t_inner.value,
item_type="value",
)
self.cplots.fetch_data(
name="t_rad", value=self.model.t_rad, item_type="iterable"
)
self.cplots.fetch_data(
name="w", value=self.model.w, item_type="iterable"
)
self.cplots.fetch_data(
name="velocity", value=self.model.velocity, item_type="iterable"
)

self.log_plasma_state(
self.model.t_rad,
self.model.w,
Expand Down Expand Up @@ -343,6 +377,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0, last_run=False):
reabsorbed_luminosity = self.runner.calculate_reabsorbed_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
)
if hasattr(self, "cplots"):
self.cplots.fetch_data(
name="Emitted",
value=emitted_luminosity.value,
item_type="value",
)
self.cplots.fetch_data(
name="Absorbed",
value=reabsorbed_luminosity.value,
item_type="value",
)
self.cplots.fetch_data(
name="Requested",
value=self.luminosity_requested.value,
item_type="value",
)

self.log_run_results(emitted_luminosity, reabsorbed_luminosity)
self.iterations_executed += 1

Expand All @@ -362,6 +413,8 @@ def run(self):
)
self.iterate(self.no_of_packets)
self.converged = self.advance_state()
if hasattr(self, "cplots"):
self.cplots.update()
self._call_back()
if self.converged:
if self.convergence_strategy.stop_if_converged:
Expand All @@ -379,6 +432,13 @@ def run(self):
)

self.reshape_plasma_state_store(self.iterations_executed)
if hasattr(self, "cplots"):
self.cplots.fetch_data(
name="t_inner",
value=self.model.t_inner.value,
item_type="value",
)
self.cplots.update(export_cplots=self.export_cplots, last=True)

logger.info(
f"Simulation finished in {self.iterations_executed:d} iterations "
Expand Down Expand Up @@ -518,7 +578,12 @@ def remove_callback(self, id):

@classmethod
def from_config(
cls, config, packet_source=None, virtual_packet_logging=False, **kwargs
cls,
config,
packet_source=None,
virtual_packet_logging=False,
show_cplots=True,
**kwargs,
):
"""
Create a new Simulation instance from a Configuration object.
Expand Down Expand Up @@ -564,6 +629,17 @@ def from_config(
virtual_packet_logging=virtual_packet_logging,
)

cplots_config_options = [
"plasma_plot_config",
"t_inner_luminosities_config",
"plasma_cmap",
"t_inner_luminosities_colors",
"export_cplots",
]
cplots_kwargs = {}
for item in set(cplots_config_options).intersection(kwargs.keys()):
cplots_kwargs[item] = kwargs[item]

luminosity_nu_start = config.supernova.luminosity_wavelength_end.to(
u.Hz, u.spectral()
)
Expand All @@ -587,6 +663,7 @@ def from_config(
model=model,
plasma=plasma,
runner=runner,
show_cplots=show_cplots,
no_of_packets=int(config.montecarlo.no_of_packets),
no_of_virtual_packets=int(config.montecarlo.no_of_virtual_packets),
luminosity_nu_start=luminosity_nu_start,
Expand All @@ -595,4 +672,5 @@ def from_config(
luminosity_requested=config.supernova.luminosity_requested.cgs,
convergence_strategy=config.montecarlo.convergence_strategy,
nthreads=config.montecarlo.nthreads,
cplots_kwargs=cplots_kwargs,
)
2 changes: 2 additions & 0 deletions tardis/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Visualization tools and widgets for TARDIS."""

from tardis.visualization.tools.convergence_plot import ConvergencePlots

from tardis.visualization.widgets.shell_info import (
shell_info_from_simulation,
shell_info_from_hdf,
Expand Down
Loading

0 comments on commit f27fe3d

Please sign in to comment.