diff --git a/docs/io/visualization/convergence_plot.ipynb b/docs/io/visualization/convergence_plot.ipynb
new file mode 100644
index 00000000000..7524fe90e96
--- /dev/null
+++ b/docs/io/visualization/convergence_plot.ipynb
@@ -0,0 +1,855 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d5f95feb",
+ "metadata": {},
+ "source": [
+ "# Convergence Plots"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dc1a0c1f",
+ "metadata": {},
+ "source": [
+ "The Convergence Plots consist of two Plotly FigureWidget Subplots, the `plasma_plot` and the `t_inner_luminosities_plot`. The plots are stored in the `cplots` attribute of the simulation object `sim` and can be accessed using `sim.cplots.plasma_plot` and `sim.cplots.t_inner_luminosities_plot`.\n",
+ "\n",
+ "The Convergence Plots are shown by default when you running TARDIS because `show_cplots` parameter of the `run_tardis()` function is set to `True`. If you don't want to do this, set it to `False`. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6db4edf2",
+ "metadata": {},
+ "source": [
+ "
\n",
+ " \n",
+ "You only need to include `export_cplots=True` in the `run_tardis` function when you want to share the notebook. The function shows the plot using the Plotly `notebook_connected` renderer, which helps display the plot online. You don't need to do it when running the notebook locally.\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "ba7cc7d2",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9e8c4cf955fe459a903c206b12dce519",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(FigureWidget({\n",
+ " 'data': [{'type': 'scatter', 'uid': '6300bc97-3326-433b-9702-37a6cef7ba32', …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from tardis import run_tardis\n",
+ "sim = run_tardis('tardis_example.yml', export_cplots=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ae1623d6",
+ "metadata": {},
+ "source": [
+ "## Displaying Convergence Plots\n",
+ "You can also call the plots outside of `run_tardis` function. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "48768471",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sim.cplots.plasma_plot.show(renderer=\"notebook_connected\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "9db6b395",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sim.cplots.t_inner_luminosities_plot.show(renderer=\"notebook_connected\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1bf44fa",
+ "metadata": {},
+ "source": [
+ "## Changing Line Colors\n",
+ "The default line-colors of the plasma plots can be changed by passing the name of the cmap in the `plasma_cmap` option. \n",
+ "\n",
+ "```py\n",
+ "sim = run_tardis(\"tardis_example.yml\",plasma_cmap= \"viridis\")\n",
+ "```\n",
+ "\n",
+ "Alongwith the cmap name, one can also provide a list of colors in rgb, hex or css-names format in the `t_inner_luminosities_colors` option to change the default colors of the luminosity and inner boundary temperature plots. \n",
+ "```py\n",
+ "# hex colors example list\n",
+ "colors = [\n",
+ " '#8c564b', # chestnut brown\n",
+ " '#e377c2', # raspberry yogurt pink\n",
+ " '#7f7f7f', # middle gray\n",
+ " '#bcbd22', # curry yellow-green\n",
+ " '#17becf' # blue-teal\n",
+ "]\n",
+ "\n",
+ "# rgb colors example list\n",
+ "colors = ['rgb(31, 119, 180)',\n",
+ " 'rgb(255, 127, 14)',\n",
+ " 'rgb(44, 160, 44)', \n",
+ " 'rgb(214, 39, 40)',\n",
+ " 'rgb(148, 103, 189)',]\n",
+ " \n",
+ "# css colors\n",
+ "colors = [\"indigo\",\"lightseagreen\", \"midnightblue\", \"pink\", \"teal\"]\n",
+ "```\n",
+ "For more css-names please see [this](https://www.w3schools.com/colors/colors_names.asp). "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "9109967e",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a5d948f93c5e46fba408974c41901e33",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(FigureWidget({\n",
+ " 'data': [{'type': 'scatter', 'uid': '708836fc-9431-4f4d-b65d-64e3ce7d217f', …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sim = run_tardis(\n",
+ " \"tardis_example.yml\",\n",
+ " plasma_cmap= \"viridis\", \n",
+ " t_inner_luminosities_colors = ['rgb(102, 197, 204)',\n",
+ " 'rgb(246, 207, 113)',\n",
+ " 'rgb(248, 156, 116)',\n",
+ " 'rgb(220, 176, 242)',\n",
+ " 'rgb(135, 197, 95)'],\n",
+ " export_cplots = True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1f63c60a",
+ "metadata": {},
+ "source": [
+ "## Changing the default layout"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3a3bdc2d",
+ "metadata": {},
+ "source": [
+ "You can override the default layout by passing dictionaries as arguments in `t_inner_luminosities_config` and `plasma_plot_config` in the `run_tardis` function. The dictionaries should have the format of `plotly.graph_objects.FigureWidget().to_dict()`. For more information on the structure of the dictionary, please see the [plotly documentation](https://plotly.com/python/figure-structure/). \n",
+ "\n",
+ "For sake of simplicity, all properties in the data dictionary are applied equally across all traces, meaning traces-specific properties can't be changed from the function. They however be changed after the simulation has finished, for example:\n",
+ "```py\n",
+ "sim.cplots.t_inner_luminosities_plot.data[0].line.dash = \"dashdot\"\n",
+ "```\n",
+ "\n",
+ "You can investigate more about the layout/data of any plots by calling `sim.cplots.t_inner_luminosities_plot.layout` or `sim.cplots.t_inner_luminosities_plot.data`. \n",
+ "\n",
+ "Here is an example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "548990c8",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0f661eefdb7b472b86353beeaae2693e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(FigureWidget({\n",
+ " 'data': [{'type': 'scatter', 'uid': 'ad54a6db-bbf4-4c68-9e0d-9cfb0ed1b9ad', …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sim = run_tardis(\n",
+ " \"tardis_example.yml\",\n",
+ " plasma_plot_config={\n",
+ " \"layout\": {\n",
+ " \"template\": \"ggplot2\",\n",
+ " \"xaxis1\": {\n",
+ " \"nticks\": 20\n",
+ " },\n",
+ " \"xaxis2\": {\n",
+ " \"title\": {\"text\": \"new changed title of x axis2\"},\n",
+ " \"nticks\": 20\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ " t_inner_luminosities_config={\n",
+ " \"data\": {\n",
+ " \"line\":{\n",
+ " \"dash\":\"dot\"\n",
+ " },\n",
+ " \"mode\": \"lines+markers\",\n",
+ " },\n",
+ " \"layout\": {\n",
+ " \"template\": \"plotly_dark\",\n",
+ " \"hovermode\":\"x\",\n",
+ " \"xaxis\":{\"showgrid\":False},\n",
+ " \"xaxis2\":{\"showgrid\":False},\n",
+ " \"xaxis3\":{\"showgrid\":False},\n",
+ " \n",
+ " },\n",
+ " },\n",
+ " export_cplots = True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "6ceaaa1e",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Help on class ConvergencePlots in module tardis.visualization.tools.convergence_plot:\n",
+ "\n",
+ "class ConvergencePlots(builtins.object)\n",
+ " | ConvergencePlots(iterations, **kwargs)\n",
+ " | \n",
+ " | Create and update convergence plots for visualizing convergence of the simulation.\n",
+ " | \n",
+ " | Parameters\n",
+ " | ----------\n",
+ " | iterations : int\n",
+ " | iteration number\n",
+ " | **kwargs : dict, optional\n",
+ " | Additional keyword arguments. These arguments are defined in the Other Parameters section.\n",
+ " | \n",
+ " | Other Parameters\n",
+ " | ----------------\n",
+ " | plasma_plot_config : dict, optional\n",
+ " | Dictionary used to override default plot properties of plasma plots.\n",
+ " | t_inner_luminosities_config : dict, optional\n",
+ " | Dictionary used to override default plot properties of the inner boundary temperature and luminosity plots.\n",
+ " | plasma_cmap : str, default: 'jet', optional\n",
+ " | String defining the cmap used in plasma plots.\n",
+ " | t_inner_luminosities_colors : str or list, optional\n",
+ " | String defining cmap for luminosity and inner boundary temperature plot.\n",
+ " | The list can be a list of colors in rgb, hex or css-names format as well.\n",
+ " | export_cplots : bool, default: False, optional\n",
+ " | If True, plots are displayed again using the ``notebook_connected`` renderer. This helps\n",
+ " | to display the plots in the documentation or in platforms like nbviewer.\n",
+ " | \n",
+ " | Notes\n",
+ " | -----\n",
+ " | When overriding plot's configuration using the ``plasma_plot_config`` and the\n",
+ " | ``t_inner_luminosities_config`` dictionaries, data related properties are\n",
+ " | applied equally accross all traces.\n",
+ " | The dictionary should have a structure like that of ``plotly.graph_objs.FigureWidget.to_dict()``,\n",
+ " | for more information please see https://plotly.com/python/figure-structure/\n",
+ " | \n",
+ " | Methods defined here:\n",
+ " | \n",
+ " | __init__(self, iterations, **kwargs)\n",
+ " | Initialize self. See help(type(self)) for accurate signature.\n",
+ " | \n",
+ " | build(self, display_plot=True)\n",
+ " | Create empty convergence plots and display them.\n",
+ " | \n",
+ " | Parameters\n",
+ " | ----------\n",
+ " | display_plot : bool, default: True, optional\n",
+ " | Displays empty plots.\n",
+ " | \n",
+ " | create_plasma_plot(self)\n",
+ " | Create an empty plasma plot.\n",
+ " | \n",
+ " | create_t_inner_luminosities_plot(self)\n",
+ " | Create an empty t_inner and luminosity plot.\n",
+ " | \n",
+ " | fetch_data(self, name=None, value=None, item_type=None)\n",
+ " | Fetch data from the Simulation class.\n",
+ " | \n",
+ " | Parameters\n",
+ " | ----------\n",
+ " | name : string\n",
+ " | name of the data\n",
+ " | value : string or array\n",
+ " | string or an array of quantities\n",
+ " | item_type : string\n",
+ " | either iterable or value\n",
+ " | \n",
+ " | override_plot_parameters(self, fig, parameters)\n",
+ " | Override default plot properties.\n",
+ " | \n",
+ " | Any property inside the data dictionary is however, applied equally across all traces.\n",
+ " | This means trace-specific data properties can't be changed using this function.\n",
+ " | \n",
+ " | Parameters\n",
+ " | ----------\n",
+ " | fig : go.FigureWidget\n",
+ " | FigureWidget object to be updated\n",
+ " | parameters : dict\n",
+ " | Dictionary used to update the default plot style.\n",
+ " | \n",
+ " | update(self, export_cplots=False, last=False)\n",
+ " | Update the convergence plots every iteration.\n",
+ " | \n",
+ " | Parameters\n",
+ " | ----------\n",
+ " | export_cplots : bool, default: False, optional\n",
+ " | Displays the convergence plots again using plotly's ``notebook_connected`` renderer.\n",
+ " | This helps to display the plots in notebooks when shared on platforms like nbviewer.\n",
+ " | Please see https://plotly.com/python/renderers/ for more information.\n",
+ " | last : bool\n",
+ " | True if it's last iteration.\n",
+ " | \n",
+ " | update_plasma_plots(self)\n",
+ " | Update plasma convergence plots every iteration.\n",
+ " | \n",
+ " | update_t_inner_luminosities_plot(self)\n",
+ " | Update the t_inner and luminosity convergence plots every iteration.\n",
+ " | \n",
+ " | ----------------------------------------------------------------------\n",
+ " | Data descriptors defined here:\n",
+ " | \n",
+ " | __dict__\n",
+ " | dictionary for instance variables (if defined)\n",
+ " | \n",
+ " | __weakref__\n",
+ " | list of weak references to the object (if defined)\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tardis.visualization import ConvergencePlots\n",
+ "help(ConvergencePlots)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tardis/base.py b/tardis/base.py
index 10a45fe432a..f05f63db7e7 100644
--- a/tardis/base.py
+++ b/tardis/base.py
@@ -12,8 +12,10 @@ def run_tardis(
packet_source=None,
simulation_callbacks=[],
virtual_packet_logging=False,
+ show_cplots=True,
log_level=None,
specific=None,
+ **kwargs,
):
"""
Run TARDIS from a given config object.
@@ -48,6 +50,12 @@ def run_tardis(
If True, only show the log messages from a particular log level, set by `log_level`.
If False, the logger shows log messages belonging to the level set and all levels above it in severity.
The default value None means that the `specific` specified in the configuration file will be used.
+ show_cplots : bool, default: True, optional
+ Option to enable tardis convergence plots.
+ **kwargs : dict, optional
+ Optional keyword arguments including those
+ supported by :obj:`tardis.visualization.tools.convergence_plot.ConvergencePlots`.
+
Returns
-------
@@ -73,6 +81,9 @@ def run_tardis(
)
tardis_config = Configuration.from_config_dict(config)
+ if not isinstance(show_cplots, bool):
+ raise TypeError("Expected bool in show_cplots argument")
+
logging_state(log_level, tardis_config, specific)
if atom_data is not None:
@@ -89,6 +100,8 @@ def run_tardis(
packet_source=packet_source,
atom_data=atom_data,
virtual_packet_logging=virtual_packet_logging,
+ show_cplots=show_cplots,
+ **kwargs,
)
for cb in simulation_callbacks:
simulation.add_callback(*cb)
diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py
index a64112b63f1..ce9daf63a33 100644
--- a/tardis/simulation/base.py
+++ b/tardis/simulation/base.py
@@ -4,6 +4,7 @@
import pandas as pd
from astropy import units as u, constants as const
from collections import OrderedDict
+from tardis import model
from tardis.montecarlo import MontecarloRunner
from tardis.model import Radial1DModel
@@ -12,6 +13,7 @@
from tardis.io.config_reader import ConfigurationError
from tardis.util.base import is_notebook
from tardis.montecarlo import montecarlo_configuration as mc_config_module
+from tardis.visualization import ConvergencePlots
from IPython.display import display
# Adding logging support
@@ -97,6 +99,7 @@ class Simulation(PlasmaStateStorerMixin, HDFWriterMixin):
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity
luminosity_requested : astropy.units.Quantity
+ cplots_kwargs: dict
nthreads : int
The number of threads to run montecarlo with
@@ -129,6 +132,8 @@ def __init__(
luminosity_requested,
convergence_strategy,
nthreads,
+ show_cplots,
+ cplots_kwargs,
):
super(Simulation, self).__init__(iterations, model.no_of_shells)
@@ -146,6 +151,7 @@ def __init__(
self.luminosity_nu_end = luminosity_nu_end
self.luminosity_requested = luminosity_requested
self.nthreads = nthreads
+
if convergence_strategy.type in ("damped"):
self.convergence_strategy = convergence_strategy
self.converged = False
@@ -162,6 +168,18 @@ def __init__(
f"- input is {convergence_strategy.type}"
)
+ if show_cplots:
+ self.cplots = ConvergencePlots(
+ iterations=self.iterations, **cplots_kwargs
+ )
+
+ if "export_cplots" in cplots_kwargs:
+ if not isinstance(cplots_kwargs["export_cplots"], bool):
+ raise TypeError("Expected bool in export_cplots argument")
+ self.export_cplots = cplots_kwargs["export_cplots"]
+ else:
+ self.export_cplots = False
+
self._callbacks = OrderedDict()
self._cb_next_id = 0
@@ -289,6 +307,22 @@ def advance_state(self):
else:
next_t_inner = self.model.t_inner
+ if hasattr(self, "cplots"):
+ self.cplots.fetch_data(
+ name="t_inner",
+ value=self.model.t_inner.value,
+ item_type="value",
+ )
+ self.cplots.fetch_data(
+ name="t_rad", value=self.model.t_rad, item_type="iterable"
+ )
+ self.cplots.fetch_data(
+ name="w", value=self.model.w, item_type="iterable"
+ )
+ self.cplots.fetch_data(
+ name="velocity", value=self.model.velocity, item_type="iterable"
+ )
+
self.log_plasma_state(
self.model.t_rad,
self.model.w,
@@ -343,6 +377,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0, last_run=False):
reabsorbed_luminosity = self.runner.calculate_reabsorbed_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
)
+ if hasattr(self, "cplots"):
+ self.cplots.fetch_data(
+ name="Emitted",
+ value=emitted_luminosity.value,
+ item_type="value",
+ )
+ self.cplots.fetch_data(
+ name="Absorbed",
+ value=reabsorbed_luminosity.value,
+ item_type="value",
+ )
+ self.cplots.fetch_data(
+ name="Requested",
+ value=self.luminosity_requested.value,
+ item_type="value",
+ )
+
self.log_run_results(emitted_luminosity, reabsorbed_luminosity)
self.iterations_executed += 1
@@ -362,6 +413,8 @@ def run(self):
)
self.iterate(self.no_of_packets)
self.converged = self.advance_state()
+ if hasattr(self, "cplots"):
+ self.cplots.update()
self._call_back()
if self.converged:
if self.convergence_strategy.stop_if_converged:
@@ -379,6 +432,13 @@ def run(self):
)
self.reshape_plasma_state_store(self.iterations_executed)
+ if hasattr(self, "cplots"):
+ self.cplots.fetch_data(
+ name="t_inner",
+ value=self.model.t_inner.value,
+ item_type="value",
+ )
+ self.cplots.update(export_cplots=self.export_cplots, last=True)
logger.info(
f"Simulation finished in {self.iterations_executed:d} iterations "
@@ -518,7 +578,12 @@ def remove_callback(self, id):
@classmethod
def from_config(
- cls, config, packet_source=None, virtual_packet_logging=False, **kwargs
+ cls,
+ config,
+ packet_source=None,
+ virtual_packet_logging=False,
+ show_cplots=True,
+ **kwargs,
):
"""
Create a new Simulation instance from a Configuration object.
@@ -564,6 +629,17 @@ def from_config(
virtual_packet_logging=virtual_packet_logging,
)
+ cplots_config_options = [
+ "plasma_plot_config",
+ "t_inner_luminosities_config",
+ "plasma_cmap",
+ "t_inner_luminosities_colors",
+ "export_cplots",
+ ]
+ cplots_kwargs = {}
+ for item in set(cplots_config_options).intersection(kwargs.keys()):
+ cplots_kwargs[item] = kwargs[item]
+
luminosity_nu_start = config.supernova.luminosity_wavelength_end.to(
u.Hz, u.spectral()
)
@@ -587,6 +663,7 @@ def from_config(
model=model,
plasma=plasma,
runner=runner,
+ show_cplots=show_cplots,
no_of_packets=int(config.montecarlo.no_of_packets),
no_of_virtual_packets=int(config.montecarlo.no_of_virtual_packets),
luminosity_nu_start=luminosity_nu_start,
@@ -595,4 +672,5 @@ def from_config(
luminosity_requested=config.supernova.luminosity_requested.cgs,
convergence_strategy=config.montecarlo.convergence_strategy,
nthreads=config.montecarlo.nthreads,
+ cplots_kwargs=cplots_kwargs,
)
diff --git a/tardis/visualization/__init__.py b/tardis/visualization/__init__.py
index ff2a5e96037..bad425bd8d9 100644
--- a/tardis/visualization/__init__.py
+++ b/tardis/visualization/__init__.py
@@ -1,5 +1,7 @@
"""Visualization tools and widgets for TARDIS."""
+from tardis.visualization.tools.convergence_plot import ConvergencePlots
+
from tardis.visualization.widgets.shell_info import (
shell_info_from_simulation,
shell_info_from_hdf,
diff --git a/tardis/visualization/tools/convergence_plot.py b/tardis/visualization/tools/convergence_plot.py
new file mode 100644
index 00000000000..e6f603cee7e
--- /dev/null
+++ b/tardis/visualization/tools/convergence_plot.py
@@ -0,0 +1,439 @@
+"""Convergence Plots to see the convergence of the simulation in real time."""
+from collections import defaultdict
+import matplotlib.cm as cm
+import matplotlib.colors as clr
+import plotly.graph_objects as go
+from IPython.display import display
+import matplotlib as mpl
+import ipywidgets as widgets
+from contextlib import suppress
+from traitlets import TraitError
+from astropy import units as u
+
+
+def transition_colors(length, name="jet"):
+ """
+ Create colorscale for convergence plots, returns a list of colors.
+
+ Parameters
+ ----------
+ length : int
+ The length of the colorscale.
+ name : string, default: 'jet', optional
+ Name of the colorscale.
+
+ Returns
+ -------
+ colors: list
+ """
+ cmap = mpl.cm.get_cmap(name, length)
+ colors = []
+ for i in range(cmap.N):
+ rgb = cmap(i)[:3]
+ colors.append(mpl.colors.rgb2hex(rgb))
+ return colors
+
+
+class ConvergencePlots(object):
+ """
+ Create and update convergence plots for visualizing convergence of the simulation.
+
+ Parameters
+ ----------
+ iterations : int
+ iteration number
+ **kwargs : dict, optional
+ Additional keyword arguments. These arguments are defined in the Other Parameters section.
+
+ Other Parameters
+ ----------------
+ plasma_plot_config : dict, optional
+ Dictionary used to override default plot properties of plasma plots.
+ t_inner_luminosities_config : dict, optional
+ Dictionary used to override default plot properties of the inner boundary temperature and luminosity plots.
+ plasma_cmap : str, default: 'jet', optional
+ String defining the cmap used in plasma plots.
+ t_inner_luminosities_colors : str or list, optional
+ String defining cmap for luminosity and inner boundary temperature plot.
+ The list can be a list of colors in rgb, hex or css-names format as well.
+ export_cplots : bool, default: False, optional
+ If True, plots are displayed again using the `notebook_connected` renderer. This helps
+ to display the plots in the documentation or in platforms like nbviewer.
+
+ Notes
+ -----
+ When overriding plot's configuration using the `plasma_plot_config` and the
+ `t_inner_luminosities_config` dictionaries, data related properties are
+ applied equally accross all traces.
+ The dictionary should have a structure like that of `plotly.graph_objs.FigureWidget.to_dict()`,
+ for more information please see https://plotly.com/python/figure-structure/
+ """
+
+ def __init__(self, iterations, **kwargs):
+ self.iterable_data = {}
+ self.value_data = defaultdict(list)
+ self.iterations = iterations
+ self.current_iteration = 1
+ self.luminosities = ["Emitted", "Absorbed", "Requested"]
+ self.plasma_plot = None
+ self.t_inner_luminosities_plot = None
+
+ if "plasma_plot_config" in kwargs:
+ self.plasma_plot_config = kwargs["plasma_plot_config"]
+
+ if "t_inner_luminosities_config" in kwargs:
+ self.t_inner_luminosities_config = kwargs[
+ "t_inner_luminosities_config"
+ ]
+
+ if "plasma_cmap" in kwargs:
+ self.plasma_colorscale = transition_colors(
+ name=kwargs["plasma_cmap"], length=self.iterations
+ )
+ else:
+ # default color scale is jet
+ self.plasma_colorscale = transition_colors(length=self.iterations)
+
+ if "t_inner_luminosities_colors" in kwargs:
+ # use cmap if string
+ if type(kwargs["t_inner_luminosities_colors"]) == str:
+ self.t_inner_luminosities_colors = transition_colors(
+ length=5,
+ name=kwargs["t_inner_luminosities_colors"],
+ )
+ else:
+ self.t_inner_luminosities_colors = kwargs[
+ "t_inner_luminosities_colors"
+ ]
+ else:
+ # using default plotly colors
+ self.t_inner_luminosities_colors = [None] * 5
+
+ def fetch_data(self, name=None, value=None, item_type=None):
+ """
+ Fetch data from the Simulation class.
+
+ Parameters
+ ----------
+ name : string
+ name of the data
+ value : string or array
+ string or an array of quantities
+ item_type : string
+ either iterable or value
+
+ """
+ # trace data for plasma plots is added in iterable data dictionary
+ if item_type == "iterable":
+ self.iterable_data[name] = value
+
+ # trace data for luminosity plots and inner boundary temperature plot is stored in value_data dictionary
+ if item_type == "value":
+ self.value_data[name].append(value)
+
+ def create_plasma_plot(self):
+ """Create an empty plasma plot."""
+ fig = go.FigureWidget().set_subplots(rows=1, cols=2, shared_xaxes=True)
+
+ # empty traces to build figure
+ fig.add_scatter(row=1, col=1)
+ fig.add_scatter(row=1, col=2)
+
+ # 2 y axes and 2 x axes correspond to the 2 subplots in the plasma plot
+ fig = fig.update_layout(
+ xaxis={
+ "tickformat": "g",
+ "title": r"$\text{Velocity}~[\text{km}~\text{s}^{-1}]$",
+ },
+ xaxis2={
+ "tickformat": "g",
+ "title": r"$\text{Velocity}~[\text{km}~\text{s}^{-1}]$",
+ "matches": "x",
+ },
+ yaxis={
+ "tickformat": "g",
+ "title": r"$T_{\text{rad}}\ [\text{K}]$",
+ "nticks": 15,
+ },
+ yaxis2={
+ "tickformat": "g",
+ "title": r"$W$",
+ "nticks": 15,
+ },
+ height=450,
+ legend_title_text="Iterations",
+ margin=dict(
+ l=10, r=135, b=25, t=25, pad=0
+ ), # reduce whitespace surrounding the plot and increase right indentation to align with the t_inner and luminosity plot
+ )
+
+ # allow overriding default layout
+ if hasattr(self, "plasma_plot_config"):
+ self.override_plot_parameters(fig, self.plasma_plot_config)
+ self.plasma_plot = fig
+
+ def create_t_inner_luminosities_plot(self):
+ """Create an empty t_inner and luminosity plot."""
+ fig = go.FigureWidget().set_subplots(
+ rows=3,
+ cols=1,
+ shared_xaxes=True,
+ vertical_spacing=0.08,
+ row_heights=[0.25, 0.5, 0.25],
+ )
+
+ # add inner boundary temperature vs iterations plot
+ fig.add_scatter(
+ name="Inner
Boundary
Temperature",
+ row=1,
+ col=1,
+ hovertext="text",
+ marker_color=self.t_inner_luminosities_colors[0],
+ mode="lines",
+ )
+
+ # add luminosity vs iterations plot
+ # has three traces for emitted, requested and absorbed luminosities
+ for luminosity, line_color in zip(
+ self.luminosities, self.t_inner_luminosities_colors[1:4]
+ ):
+ fig.add_scatter(
+ name=luminosity + "
Luminosity",
+ mode="lines",
+ row=2,
+ col=1,
+ marker_color=line_color,
+ )
+
+ # add residual luminosity vs iterations plot
+ fig.add_scatter(
+ name="Residual
Luminosity",
+ row=3,
+ col=1,
+ marker_color=self.t_inner_luminosities_colors[4],
+ mode="lines",
+ )
+
+ # 3 y axes and 3 x axes correspond to the 3 subplots in the t_inner and luminosity convergence plot
+ fig = fig.update_layout(
+ xaxis=dict(range=[0, self.iterations + 1], dtick=2),
+ xaxis2=dict(
+ matches="x",
+ range=[0, self.iterations + 1],
+ dtick=2,
+ ),
+ xaxis3=dict(
+ title=r"$\mbox{Iteration Number}$",
+ dtick=2,
+ ),
+ yaxis=dict(
+ title=r"$T_{\text{inner}}\ [\text{K}]$",
+ automargin=True,
+ tickformat="g",
+ exponentformat="e",
+ nticks=4,
+ ),
+ yaxis2=dict(
+ exponentformat="e",
+ title=r"$\text{Luminosity}~[\text{erg s}^{-1}]$",
+ title_font_size=13,
+ automargin=True,
+ nticks=7,
+ ),
+ yaxis3=dict(
+ title=r"$~~\text{Residual}\\\text{Luminosity[%]}$",
+ title_font_size=12,
+ automargin=True,
+ nticks=4,
+ ),
+ height=630,
+ hoverlabel_align="right",
+ margin=dict(
+ b=25, t=25, pad=0
+ ), # reduces whitespace surrounding the plot
+ )
+
+ # allow overriding default layout
+ if hasattr(self, "t_inner_luminosities_config"):
+ self.override_plot_parameters(fig, self.t_inner_luminosities_config)
+
+ self.t_inner_luminosities_plot = fig
+
+ def override_plot_parameters(self, fig, parameters):
+ """
+ Override default plot properties.
+
+ Any property inside the data dictionary is however, applied equally across all traces.
+ This means trace-specific data properties can't be changed using this function.
+
+ Parameters
+ ----------
+ fig : go.FigureWidget
+ FigureWidget object to be updated
+ parameters : dict
+ Dictionary used to update the default plot style.
+ """
+ # because fig.data is a tuple of traces, a property in the data dictionary is applied to all traces
+ # the fig is a nested dictionary, any property n levels deep is not changed until the value is a not dictionary
+ # fig["property_1"]["property_2"]...["property_n"] = "value"
+ for key, value in parameters.items():
+ if key == "data":
+ # all traces will have same data property
+ for trace in list(fig.data):
+ self.override_plot_parameters(trace, value)
+ else:
+ if type(value) == dict:
+ self.override_plot_parameters(fig[key], value)
+ else:
+ fig[key] = value
+
+ def build(self, display_plot=True):
+ """
+ Create empty convergence plots and display them.
+
+ Parameters
+ ----------
+ display_plot : bool, default: True, optional
+ Displays empty plots.
+ """
+ self.create_plasma_plot()
+ self.create_t_inner_luminosities_plot()
+
+ if display_plot:
+ display(
+ widgets.VBox(
+ [self.plasma_plot, self.t_inner_luminosities_plot],
+ )
+ )
+
+ def update_plasma_plots(self):
+ """Update plasma convergence plots every iteration."""
+ # convert velocity to km/s
+ x = self.iterable_data["velocity"].to(u.km / u.s).value.tolist()
+
+ # add luminosity data in hover data in plasma plots
+ customdata = len(x) * [
+ "
"
+ + "Emitted Luminosity: "
+ + f'{self.value_data["Absorbed"][-1]:.4g}'
+ + "
"
+ + "Requested Luminosity: "
+ + f'{self.value_data["Requested"][-1]:.4g}'
+ + "
"
+ + "Absorbed Luminosity: "
+ + f'{self.value_data["Requested"][-1]:.4g}'
+ ]
+
+ # add a radiation temperature vs shell velocity trace to the plasma plot
+ self.plasma_plot.add_scatter(
+ x=x,
+ y=self.iterable_data["t_rad"],
+ line_color=self.plasma_colorscale[self.current_iteration - 1],
+ row=1,
+ col=1,
+ name=self.current_iteration,
+ legendgroup=f"group-{self.current_iteration}",
+ showlegend=False,
+ customdata=customdata,
+ hovertemplate="Y: %{y:.3f} at X = %{x:,.0f}%{customdata}",
+ )
+
+ # add a dilution factor vs shell velocity trace to the plasma plot
+ self.plasma_plot.add_scatter(
+ x=x,
+ y=self.iterable_data["w"],
+ line_color=self.plasma_colorscale[self.current_iteration - 1],
+ row=1,
+ col=2,
+ legendgroup=f"group-{self.current_iteration}",
+ name=self.current_iteration,
+ customdata=customdata,
+ hovertemplate="Y: %{y:.3f} at X = %{x:,.0f}%{customdata}",
+ )
+
+ def update_t_inner_luminosities_plot(self):
+ """Update the t_inner and luminosity convergence plots every iteration."""
+ x = list(range(1, self.iterations + 1))
+
+ with self.t_inner_luminosities_plot.batch_update():
+ # traces are updated according to the order they were added
+ # the first trace is of the inner boundary temperature plot
+ self.t_inner_luminosities_plot.data[0].x = x
+ self.t_inner_luminosities_plot.data[0].y = self.value_data[
+ "t_inner"
+ ]
+ self.t_inner_luminosities_plot.data[
+ 0
+ ].hovertemplate = "%{y:.3f} at X = %{x:,.0f}Inner Boundary Temperature" # trace name in extra tag to avoid new lines in hoverdata
+
+ # the next three for emitted, absorbed and requested luminosities
+ for index, luminosity in zip(range(1, 4), self.luminosities):
+ self.t_inner_luminosities_plot.data[index].x = x
+ self.t_inner_luminosities_plot.data[index].y = self.value_data[
+ luminosity
+ ]
+ self.t_inner_luminosities_plot.data[index].hovertemplate = (
+ "%{y:.4g}" + "
at X = %{x}
"
+ )
+
+ # last is for the residual luminosity
+ y = [
+ ((emitted - requested) * 100) / requested
+ for emitted, requested in zip(
+ self.value_data["Emitted"], self.value_data["Requested"]
+ )
+ ]
+
+ self.t_inner_luminosities_plot.data[4].x = x
+ self.t_inner_luminosities_plot.data[4].y = y
+ self.t_inner_luminosities_plot.data[
+ 4
+ ].hovertemplate = "%{y:.2f}% at X = %{x:,.0f}"
+
+ def update(self, export_cplots=False, last=False):
+ """
+ Update the convergence plots every iteration.
+
+ Parameters
+ ----------
+ export_cplots : bool, default: False, optional
+ Displays the convergence plots again using plotly's `notebook_connected` renderer.
+ This helps to display the plots in notebooks when shared on platforms like nbviewer.
+ Please see https://plotly.com/python/renderers/ for more information.
+ last : bool, default: False, optional
+ True if it's last iteration.
+ """
+ if self.iterable_data != {}:
+ # build only at first iteration
+ if self.current_iteration == 1:
+ self.build()
+
+ self.update_plasma_plots()
+ self.update_t_inner_luminosities_plot()
+
+ # data property for plasma plots needs to be
+ # updated after the last iteration because new traces have been added
+ if hasattr(self, "plasma_plot_config") and last:
+ if "data" in self.plasma_plot_config:
+ self.override_plot_parameters(
+ self.plasma_plot, self.plasma_plot_config
+ )
+
+ self.current_iteration += 1
+
+ # the display function expects a Widget, while
+ # fig.show() returns None, which causes the TraitError.
+ if export_cplots:
+ with suppress(TraitError):
+ display(
+ widgets.VBox(
+ [
+ self.plasma_plot.show(
+ renderer="notebook_connected"
+ ),
+ self.t_inner_luminosities_plot.show(
+ renderer="notebook_connected"
+ ),
+ ]
+ )
+ )
diff --git a/tardis/visualization/tools/tests/__init__.py b/tardis/visualization/tools/tests/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tardis/visualization/tools/tests/test_convergence_plot.py b/tardis/visualization/tools/tests/test_convergence_plot.py
new file mode 100644
index 00000000000..a84aa9e4f4a
--- /dev/null
+++ b/tardis/visualization/tools/tests/test_convergence_plot.py
@@ -0,0 +1,207 @@
+"""Tests for Convergence Plots."""
+import pytest
+from tardis.visualization.tools.convergence_plot import (
+ ConvergencePlots,
+ transition_colors,
+)
+from collections import defaultdict
+import plotly.graph_objects as go
+from astropy import units as u
+
+
+@pytest.fixture(scope="module", params=[0, 1, 2])
+def convergence_plots(request):
+ """Initialize ConvergencePlots class and build empty plots."""
+ cplots = ConvergencePlots(iterations=request.param)
+ cplots.build(display_plot=False)
+ return cplots
+
+
+@pytest.fixture()
+def fetch_luminosity_data(convergence_plots):
+ """Prepare data for t_inner and luminosity plot."""
+ for item in [2] * convergence_plots.iterations:
+ convergence_plots.fetch_data(
+ name="t_inner", value=item, item_type="value"
+ )
+ for item2 in convergence_plots.luminosities:
+ convergence_plots.fetch_data(
+ name=item2, value=item, item_type="value"
+ )
+
+
+def test_transition_colors():
+ """Test whether the object returned by the transition_colors function is a list of appropriate length."""
+ iterations = 3
+ colors = transition_colors(length=iterations)
+ assert type(colors) == list
+ assert len(colors) == iterations
+
+
+def test_convergence_construction(convergence_plots):
+ """Test the construction of the ConvergencePlots class."""
+ assert convergence_plots.iterable_data == {}
+ assert convergence_plots.value_data == defaultdict(list)
+ assert convergence_plots.luminosities == [
+ "Emitted",
+ "Absorbed",
+ "Requested",
+ ]
+
+
+def test_fetch_data(convergence_plots):
+ """Test values of variables updated by fetch_data function."""
+ convergence_plots.fetch_data(
+ name="iterable", value=range(3), item_type="iterable"
+ )
+ convergence_plots.fetch_data(name="value", value=0, item_type="value")
+
+ assert convergence_plots.iterable_data["iterable"] == range(3)
+ assert convergence_plots.value_data["value"] == [0]
+
+
+def test_build(convergence_plots):
+ """Test if convergence plots are instances of plotly.graph_objs.FigureWidget() and have appropriate number of traces."""
+ assert type(convergence_plots.plasma_plot) == go.FigureWidget
+ assert type(convergence_plots.t_inner_luminosities_plot) == go.FigureWidget
+
+ # check number of traces
+ assert len(convergence_plots.t_inner_luminosities_plot.data) == 5
+ assert len(convergence_plots.plasma_plot.data) == 2
+
+
+@pytest.mark.usefixtures("fetch_luminosity_data")
+def test_update_t_inner_luminosities_plot(convergence_plots):
+ """Test the number of traces and length of x and y values."""
+ n_iterations = convergence_plots.iterations
+ convergence_plots.update_t_inner_luminosities_plot()
+
+ # check number of traces
+ assert len(convergence_plots.t_inner_luminosities_plot.data) == 5
+
+ for index in range(0, 5):
+ # check x and y values for all traces
+ assert (
+ len(convergence_plots.t_inner_luminosities_plot.data[index].x)
+ == n_iterations
+ )
+ assert (
+ len(convergence_plots.t_inner_luminosities_plot.data[index].y)
+ == n_iterations
+ )
+
+ # check range of x axes
+ for axis in ["xaxis", "xaxis2", "xaxis3"]:
+ convergence_plots.t_inner_luminosities_plot["layout"][axis]["range"] = [
+ 0,
+ convergence_plots.iterations + 1,
+ ]
+
+
+@pytest.mark.usefixtures("fetch_luminosity_data")
+def test_update_plasma_plots(convergence_plots):
+ """Test the state of plasma plots after updating."""
+ n_iterations = convergence_plots.iterations
+ expected_n_traces = 2 * n_iterations + 2
+ velocity = range(0, n_iterations) * u.m / u.s
+
+ convergence_plots.fetch_data(
+ name="velocity", value=velocity, item_type="iterable"
+ )
+
+ w_val = list(range(n_iterations))
+ t_rad_val = [item * 2 for item in w_val]
+
+ for _ in range(n_iterations):
+ convergence_plots.fetch_data(
+ name="t_rad",
+ value=t_rad_val,
+ item_type="iterable",
+ )
+ convergence_plots.fetch_data(
+ name="w",
+ value=w_val,
+ item_type="iterable",
+ )
+ convergence_plots.update_plasma_plots()
+ convergence_plots.current_iteration += 1
+
+ # check number of traces
+ assert len(convergence_plots.plasma_plot.data) == expected_n_traces
+
+ # traces are added alternatively
+ # trace 0 and 1 and empty
+ assert convergence_plots.plasma_plot.data[0].x == None
+ assert convergence_plots.plasma_plot.data[1].x == None
+
+ # check other traces
+ for index in list(range(expected_n_traces))[::2][1:]:
+ # check values for t_rad subplot
+ assert convergence_plots.plasma_plot.data[index].xaxis == "x"
+ assert convergence_plots.plasma_plot.data[index].yaxis == "y"
+ assert convergence_plots.plasma_plot.data[index].y == tuple(t_rad_val)
+ assert convergence_plots.plasma_plot.data[index].x == tuple(
+ velocity.to(u.km / u.s).value
+ )
+
+ for index in list(range(expected_n_traces))[1::2][1:]:
+ # check values for w subplot
+ assert convergence_plots.plasma_plot.data[index].xaxis == "x2"
+ assert convergence_plots.plasma_plot.data[index].yaxis == "y2"
+ assert convergence_plots.plasma_plot.data[index].y == tuple(w_val)
+ assert convergence_plots.plasma_plot.data[index].x == tuple(
+ velocity.to(u.km / u.s).value
+ )
+
+
+@pytest.mark.usefixtures("fetch_luminosity_data")
+def test_override_plot_parameters(convergence_plots):
+ """Test if default plot properties are overridden properly."""
+ parameters = {
+ "data": {
+ "line": {"dash": "dot"},
+ "mode": "lines+markers",
+ },
+ "layout": {"xaxis2": {"showgrid": False}},
+ }
+ convergence_plots.override_plot_parameters(
+ convergence_plots.plasma_plot, parameters=parameters
+ )
+ convergence_plots.override_plot_parameters(
+ convergence_plots.t_inner_luminosities_plot, parameters=parameters
+ )
+
+ # data properties will be applied across all traces equally
+ # testing plot parameters of t_inner and luminosity plot
+ for trace_index in range(5):
+ assert (
+ convergence_plots.t_inner_luminosities_plot.data[
+ trace_index
+ ].line.dash
+ == "dot"
+ )
+ assert (
+ convergence_plots.t_inner_luminosities_plot.data[trace_index].mode
+ == "lines+markers"
+ )
+ # checking layout
+ assert (
+ convergence_plots.t_inner_luminosities_plot["layout"]["xaxis2"][
+ "showgrid"
+ ]
+ == False
+ )
+
+ # testing plot parameters for plasma plot
+ for trace_index in range(2):
+ assert (
+ convergence_plots.plasma_plot.data[trace_index].line.dash == "dot"
+ )
+ assert (
+ convergence_plots.plasma_plot.data[trace_index].mode
+ == "lines+markers"
+ )
+ # checking layout for plasma plot
+ assert (
+ convergence_plots.plasma_plot["layout"]["xaxis2"]["showgrid"] == False
+ )
diff --git a/tardis/visualization/widgets/shell_info.py b/tardis/visualization/widgets/shell_info.py
index b7528527d05..336c8f580cf 100644
--- a/tardis/visualization/widgets/shell_info.py
+++ b/tardis/visualization/widgets/shell_info.py
@@ -4,7 +4,7 @@
atomic_number2element_symbol,
species_tuple_to_string,
)
-from tardis.simulation import Simulation
+
from tardis.visualization.widgets.util import create_table_widget
import pandas as pd