diff --git a/docs/io/visualization/convergence_plot.ipynb b/docs/io/visualization/convergence_plot.ipynb new file mode 100644 index 00000000000..7524fe90e96 --- /dev/null +++ b/docs/io/visualization/convergence_plot.ipynb @@ -0,0 +1,855 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d5f95feb", + "metadata": {}, + "source": [ + "# Convergence Plots" + ] + }, + { + "cell_type": "markdown", + "id": "dc1a0c1f", + "metadata": {}, + "source": [ + "The Convergence Plots consist of two Plotly FigureWidget Subplots, the `plasma_plot` and the `t_inner_luminosities_plot`. The plots are stored in the `cplots` attribute of the simulation object `sim` and can be accessed using `sim.cplots.plasma_plot` and `sim.cplots.t_inner_luminosities_plot`.\n", + "\n", + "The Convergence Plots are shown by default when you running TARDIS because `show_cplots` parameter of the `run_tardis()` function is set to `True`. If you don't want to do this, set it to `False`. " + ] + }, + { + "cell_type": "markdown", + "id": "6db4edf2", + "metadata": {}, + "source": [ + "
\n", + " \n", + "You only need to include `export_cplots=True` in the `run_tardis` function when you want to share the notebook. The function shows the plot using the Plotly `notebook_connected` renderer, which helps display the plot online. You don't need to do it when running the notebook locally.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ba7cc7d2", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9e8c4cf955fe459a903c206b12dce519", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(FigureWidget({\n", + " 'data': [{'type': 'scatter', 'uid': '6300bc97-3326-433b-9702-37a6cef7ba32', …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from tardis import run_tardis\n", + "sim = run_tardis('tardis_example.yml', export_cplots=True)" + ] + }, + { + "cell_type": "markdown", + "id": "ae1623d6", + "metadata": {}, + "source": [ + "## Displaying Convergence Plots\n", + "You can also call the plots outside of `run_tardis` function. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "48768471", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sim.cplots.plasma_plot.show(renderer=\"notebook_connected\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9db6b395", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sim.cplots.t_inner_luminosities_plot.show(renderer=\"notebook_connected\")" + ] + }, + { + "cell_type": "markdown", + "id": "b1bf44fa", + "metadata": {}, + "source": [ + "## Changing Line Colors\n", + "The default line-colors of the plasma plots can be changed by passing the name of the cmap in the `plasma_cmap` option. \n", + "\n", + "```py\n", + "sim = run_tardis(\"tardis_example.yml\",plasma_cmap= \"viridis\")\n", + "```\n", + "\n", + "Alongwith the cmap name, one can also provide a list of colors in rgb, hex or css-names format in the `t_inner_luminosities_colors` option to change the default colors of the luminosity and inner boundary temperature plots. \n", + "```py\n", + "# hex colors example list\n", + "colors = [\n", + " '#8c564b', # chestnut brown\n", + " '#e377c2', # raspberry yogurt pink\n", + " '#7f7f7f', # middle gray\n", + " '#bcbd22', # curry yellow-green\n", + " '#17becf' # blue-teal\n", + "]\n", + "\n", + "# rgb colors example list\n", + "colors = ['rgb(31, 119, 180)',\n", + " 'rgb(255, 127, 14)',\n", + " 'rgb(44, 160, 44)', \n", + " 'rgb(214, 39, 40)',\n", + " 'rgb(148, 103, 189)',]\n", + " \n", + "# css colors\n", + "colors = [\"indigo\",\"lightseagreen\", \"midnightblue\", \"pink\", \"teal\"]\n", + "```\n", + "For more css-names please see [this](https://www.w3schools.com/colors/colors_names.asp). " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9109967e", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a5d948f93c5e46fba408974c41901e33", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(FigureWidget({\n", + " 'data': [{'type': 'scatter', 'uid': '708836fc-9431-4f4d-b65d-64e3ce7d217f', …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sim = run_tardis(\n", + " \"tardis_example.yml\",\n", + " plasma_cmap= \"viridis\", \n", + " t_inner_luminosities_colors = ['rgb(102, 197, 204)',\n", + " 'rgb(246, 207, 113)',\n", + " 'rgb(248, 156, 116)',\n", + " 'rgb(220, 176, 242)',\n", + " 'rgb(135, 197, 95)'],\n", + " export_cplots = True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1f63c60a", + "metadata": {}, + "source": [ + "## Changing the default layout" + ] + }, + { + "cell_type": "markdown", + "id": "3a3bdc2d", + "metadata": {}, + "source": [ + "You can override the default layout by passing dictionaries as arguments in `t_inner_luminosities_config` and `plasma_plot_config` in the `run_tardis` function. The dictionaries should have the format of `plotly.graph_objects.FigureWidget().to_dict()`. For more information on the structure of the dictionary, please see the [plotly documentation](https://plotly.com/python/figure-structure/). \n", + "\n", + "For sake of simplicity, all properties in the data dictionary are applied equally across all traces, meaning traces-specific properties can't be changed from the function. They however be changed after the simulation has finished, for example:\n", + "```py\n", + "sim.cplots.t_inner_luminosities_plot.data[0].line.dash = \"dashdot\"\n", + "```\n", + "\n", + "You can investigate more about the layout/data of any plots by calling `sim.cplots.t_inner_luminosities_plot.layout` or `sim.cplots.t_inner_luminosities_plot.data`. \n", + "\n", + "Here is an example:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "548990c8", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0f661eefdb7b472b86353beeaae2693e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(FigureWidget({\n", + " 'data': [{'type': 'scatter', 'uid': 'ad54a6db-bbf4-4c68-9e0d-9cfb0ed1b9ad', …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sim = run_tardis(\n", + " \"tardis_example.yml\",\n", + " plasma_plot_config={\n", + " \"layout\": {\n", + " \"template\": \"ggplot2\",\n", + " \"xaxis1\": {\n", + " \"nticks\": 20\n", + " },\n", + " \"xaxis2\": {\n", + " \"title\": {\"text\": \"new changed title of x axis2\"},\n", + " \"nticks\": 20\n", + " },\n", + " },\n", + " },\n", + " t_inner_luminosities_config={\n", + " \"data\": {\n", + " \"line\":{\n", + " \"dash\":\"dot\"\n", + " },\n", + " \"mode\": \"lines+markers\",\n", + " },\n", + " \"layout\": {\n", + " \"template\": \"plotly_dark\",\n", + " \"hovermode\":\"x\",\n", + " \"xaxis\":{\"showgrid\":False},\n", + " \"xaxis2\":{\"showgrid\":False},\n", + " \"xaxis3\":{\"showgrid\":False},\n", + " \n", + " },\n", + " },\n", + " export_cplots = True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6ceaaa1e", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Help on class ConvergencePlots in module tardis.visualization.tools.convergence_plot:\n", + "\n", + "class ConvergencePlots(builtins.object)\n", + " | ConvergencePlots(iterations, **kwargs)\n", + " | \n", + " | Create and update convergence plots for visualizing convergence of the simulation.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | iterations : int\n", + " | iteration number\n", + " | **kwargs : dict, optional\n", + " | Additional keyword arguments. These arguments are defined in the Other Parameters section.\n", + " | \n", + " | Other Parameters\n", + " | ----------------\n", + " | plasma_plot_config : dict, optional\n", + " | Dictionary used to override default plot properties of plasma plots.\n", + " | t_inner_luminosities_config : dict, optional\n", + " | Dictionary used to override default plot properties of the inner boundary temperature and luminosity plots.\n", + " | plasma_cmap : str, default: 'jet', optional\n", + " | String defining the cmap used in plasma plots.\n", + " | t_inner_luminosities_colors : str or list, optional\n", + " | String defining cmap for luminosity and inner boundary temperature plot.\n", + " | The list can be a list of colors in rgb, hex or css-names format as well.\n", + " | export_cplots : bool, default: False, optional\n", + " | If True, plots are displayed again using the ``notebook_connected`` renderer. This helps\n", + " | to display the plots in the documentation or in platforms like nbviewer.\n", + " | \n", + " | Notes\n", + " | -----\n", + " | When overriding plot's configuration using the ``plasma_plot_config`` and the\n", + " | ``t_inner_luminosities_config`` dictionaries, data related properties are\n", + " | applied equally accross all traces.\n", + " | The dictionary should have a structure like that of ``plotly.graph_objs.FigureWidget.to_dict()``,\n", + " | for more information please see https://plotly.com/python/figure-structure/\n", + " | \n", + " | Methods defined here:\n", + " | \n", + " | __init__(self, iterations, **kwargs)\n", + " | Initialize self. See help(type(self)) for accurate signature.\n", + " | \n", + " | build(self, display_plot=True)\n", + " | Create empty convergence plots and display them.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | display_plot : bool, default: True, optional\n", + " | Displays empty plots.\n", + " | \n", + " | create_plasma_plot(self)\n", + " | Create an empty plasma plot.\n", + " | \n", + " | create_t_inner_luminosities_plot(self)\n", + " | Create an empty t_inner and luminosity plot.\n", + " | \n", + " | fetch_data(self, name=None, value=None, item_type=None)\n", + " | Fetch data from the Simulation class.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | name : string\n", + " | name of the data\n", + " | value : string or array\n", + " | string or an array of quantities\n", + " | item_type : string\n", + " | either iterable or value\n", + " | \n", + " | override_plot_parameters(self, fig, parameters)\n", + " | Override default plot properties.\n", + " | \n", + " | Any property inside the data dictionary is however, applied equally across all traces.\n", + " | This means trace-specific data properties can't be changed using this function.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | fig : go.FigureWidget\n", + " | FigureWidget object to be updated\n", + " | parameters : dict\n", + " | Dictionary used to update the default plot style.\n", + " | \n", + " | update(self, export_cplots=False, last=False)\n", + " | Update the convergence plots every iteration.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | export_cplots : bool, default: False, optional\n", + " | Displays the convergence plots again using plotly's ``notebook_connected`` renderer.\n", + " | This helps to display the plots in notebooks when shared on platforms like nbviewer.\n", + " | Please see https://plotly.com/python/renderers/ for more information.\n", + " | last : bool\n", + " | True if it's last iteration.\n", + " | \n", + " | update_plasma_plots(self)\n", + " | Update plasma convergence plots every iteration.\n", + " | \n", + " | update_t_inner_luminosities_plot(self)\n", + " | Update the t_inner and luminosity convergence plots every iteration.\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Data descriptors defined here:\n", + " | \n", + " | __dict__\n", + " | dictionary for instance variables (if defined)\n", + " | \n", + " | __weakref__\n", + " | list of weak references to the object (if defined)\n", + "\n" + ] + } + ], + "source": [ + "from tardis.visualization import ConvergencePlots\n", + "help(ConvergencePlots)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tardis/base.py b/tardis/base.py index 10a45fe432a..f05f63db7e7 100644 --- a/tardis/base.py +++ b/tardis/base.py @@ -12,8 +12,10 @@ def run_tardis( packet_source=None, simulation_callbacks=[], virtual_packet_logging=False, + show_cplots=True, log_level=None, specific=None, + **kwargs, ): """ Run TARDIS from a given config object. @@ -48,6 +50,12 @@ def run_tardis( If True, only show the log messages from a particular log level, set by `log_level`. If False, the logger shows log messages belonging to the level set and all levels above it in severity. The default value None means that the `specific` specified in the configuration file will be used. + show_cplots : bool, default: True, optional + Option to enable tardis convergence plots. + **kwargs : dict, optional + Optional keyword arguments including those + supported by :obj:`tardis.visualization.tools.convergence_plot.ConvergencePlots`. + Returns ------- @@ -73,6 +81,9 @@ def run_tardis( ) tardis_config = Configuration.from_config_dict(config) + if not isinstance(show_cplots, bool): + raise TypeError("Expected bool in show_cplots argument") + logging_state(log_level, tardis_config, specific) if atom_data is not None: @@ -89,6 +100,8 @@ def run_tardis( packet_source=packet_source, atom_data=atom_data, virtual_packet_logging=virtual_packet_logging, + show_cplots=show_cplots, + **kwargs, ) for cb in simulation_callbacks: simulation.add_callback(*cb) diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index a64112b63f1..ce9daf63a33 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -4,6 +4,7 @@ import pandas as pd from astropy import units as u, constants as const from collections import OrderedDict +from tardis import model from tardis.montecarlo import MontecarloRunner from tardis.model import Radial1DModel @@ -12,6 +13,7 @@ from tardis.io.config_reader import ConfigurationError from tardis.util.base import is_notebook from tardis.montecarlo import montecarlo_configuration as mc_config_module +from tardis.visualization import ConvergencePlots from IPython.display import display # Adding logging support @@ -97,6 +99,7 @@ class Simulation(PlasmaStateStorerMixin, HDFWriterMixin): luminosity_nu_start : astropy.units.Quantity luminosity_nu_end : astropy.units.Quantity luminosity_requested : astropy.units.Quantity + cplots_kwargs: dict nthreads : int The number of threads to run montecarlo with @@ -129,6 +132,8 @@ def __init__( luminosity_requested, convergence_strategy, nthreads, + show_cplots, + cplots_kwargs, ): super(Simulation, self).__init__(iterations, model.no_of_shells) @@ -146,6 +151,7 @@ def __init__( self.luminosity_nu_end = luminosity_nu_end self.luminosity_requested = luminosity_requested self.nthreads = nthreads + if convergence_strategy.type in ("damped"): self.convergence_strategy = convergence_strategy self.converged = False @@ -162,6 +168,18 @@ def __init__( f"- input is {convergence_strategy.type}" ) + if show_cplots: + self.cplots = ConvergencePlots( + iterations=self.iterations, **cplots_kwargs + ) + + if "export_cplots" in cplots_kwargs: + if not isinstance(cplots_kwargs["export_cplots"], bool): + raise TypeError("Expected bool in export_cplots argument") + self.export_cplots = cplots_kwargs["export_cplots"] + else: + self.export_cplots = False + self._callbacks = OrderedDict() self._cb_next_id = 0 @@ -289,6 +307,22 @@ def advance_state(self): else: next_t_inner = self.model.t_inner + if hasattr(self, "cplots"): + self.cplots.fetch_data( + name="t_inner", + value=self.model.t_inner.value, + item_type="value", + ) + self.cplots.fetch_data( + name="t_rad", value=self.model.t_rad, item_type="iterable" + ) + self.cplots.fetch_data( + name="w", value=self.model.w, item_type="iterable" + ) + self.cplots.fetch_data( + name="velocity", value=self.model.velocity, item_type="iterable" + ) + self.log_plasma_state( self.model.t_rad, self.model.w, @@ -343,6 +377,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0, last_run=False): reabsorbed_luminosity = self.runner.calculate_reabsorbed_luminosity( self.luminosity_nu_start, self.luminosity_nu_end ) + if hasattr(self, "cplots"): + self.cplots.fetch_data( + name="Emitted", + value=emitted_luminosity.value, + item_type="value", + ) + self.cplots.fetch_data( + name="Absorbed", + value=reabsorbed_luminosity.value, + item_type="value", + ) + self.cplots.fetch_data( + name="Requested", + value=self.luminosity_requested.value, + item_type="value", + ) + self.log_run_results(emitted_luminosity, reabsorbed_luminosity) self.iterations_executed += 1 @@ -362,6 +413,8 @@ def run(self): ) self.iterate(self.no_of_packets) self.converged = self.advance_state() + if hasattr(self, "cplots"): + self.cplots.update() self._call_back() if self.converged: if self.convergence_strategy.stop_if_converged: @@ -379,6 +432,13 @@ def run(self): ) self.reshape_plasma_state_store(self.iterations_executed) + if hasattr(self, "cplots"): + self.cplots.fetch_data( + name="t_inner", + value=self.model.t_inner.value, + item_type="value", + ) + self.cplots.update(export_cplots=self.export_cplots, last=True) logger.info( f"Simulation finished in {self.iterations_executed:d} iterations " @@ -518,7 +578,12 @@ def remove_callback(self, id): @classmethod def from_config( - cls, config, packet_source=None, virtual_packet_logging=False, **kwargs + cls, + config, + packet_source=None, + virtual_packet_logging=False, + show_cplots=True, + **kwargs, ): """ Create a new Simulation instance from a Configuration object. @@ -564,6 +629,17 @@ def from_config( virtual_packet_logging=virtual_packet_logging, ) + cplots_config_options = [ + "plasma_plot_config", + "t_inner_luminosities_config", + "plasma_cmap", + "t_inner_luminosities_colors", + "export_cplots", + ] + cplots_kwargs = {} + for item in set(cplots_config_options).intersection(kwargs.keys()): + cplots_kwargs[item] = kwargs[item] + luminosity_nu_start = config.supernova.luminosity_wavelength_end.to( u.Hz, u.spectral() ) @@ -587,6 +663,7 @@ def from_config( model=model, plasma=plasma, runner=runner, + show_cplots=show_cplots, no_of_packets=int(config.montecarlo.no_of_packets), no_of_virtual_packets=int(config.montecarlo.no_of_virtual_packets), luminosity_nu_start=luminosity_nu_start, @@ -595,4 +672,5 @@ def from_config( luminosity_requested=config.supernova.luminosity_requested.cgs, convergence_strategy=config.montecarlo.convergence_strategy, nthreads=config.montecarlo.nthreads, + cplots_kwargs=cplots_kwargs, ) diff --git a/tardis/visualization/__init__.py b/tardis/visualization/__init__.py index ff2a5e96037..bad425bd8d9 100644 --- a/tardis/visualization/__init__.py +++ b/tardis/visualization/__init__.py @@ -1,5 +1,7 @@ """Visualization tools and widgets for TARDIS.""" +from tardis.visualization.tools.convergence_plot import ConvergencePlots + from tardis.visualization.widgets.shell_info import ( shell_info_from_simulation, shell_info_from_hdf, diff --git a/tardis/visualization/tools/convergence_plot.py b/tardis/visualization/tools/convergence_plot.py new file mode 100644 index 00000000000..e6f603cee7e --- /dev/null +++ b/tardis/visualization/tools/convergence_plot.py @@ -0,0 +1,439 @@ +"""Convergence Plots to see the convergence of the simulation in real time.""" +from collections import defaultdict +import matplotlib.cm as cm +import matplotlib.colors as clr +import plotly.graph_objects as go +from IPython.display import display +import matplotlib as mpl +import ipywidgets as widgets +from contextlib import suppress +from traitlets import TraitError +from astropy import units as u + + +def transition_colors(length, name="jet"): + """ + Create colorscale for convergence plots, returns a list of colors. + + Parameters + ---------- + length : int + The length of the colorscale. + name : string, default: 'jet', optional + Name of the colorscale. + + Returns + ------- + colors: list + """ + cmap = mpl.cm.get_cmap(name, length) + colors = [] + for i in range(cmap.N): + rgb = cmap(i)[:3] + colors.append(mpl.colors.rgb2hex(rgb)) + return colors + + +class ConvergencePlots(object): + """ + Create and update convergence plots for visualizing convergence of the simulation. + + Parameters + ---------- + iterations : int + iteration number + **kwargs : dict, optional + Additional keyword arguments. These arguments are defined in the Other Parameters section. + + Other Parameters + ---------------- + plasma_plot_config : dict, optional + Dictionary used to override default plot properties of plasma plots. + t_inner_luminosities_config : dict, optional + Dictionary used to override default plot properties of the inner boundary temperature and luminosity plots. + plasma_cmap : str, default: 'jet', optional + String defining the cmap used in plasma plots. + t_inner_luminosities_colors : str or list, optional + String defining cmap for luminosity and inner boundary temperature plot. + The list can be a list of colors in rgb, hex or css-names format as well. + export_cplots : bool, default: False, optional + If True, plots are displayed again using the `notebook_connected` renderer. This helps + to display the plots in the documentation or in platforms like nbviewer. + + Notes + ----- + When overriding plot's configuration using the `plasma_plot_config` and the + `t_inner_luminosities_config` dictionaries, data related properties are + applied equally accross all traces. + The dictionary should have a structure like that of `plotly.graph_objs.FigureWidget.to_dict()`, + for more information please see https://plotly.com/python/figure-structure/ + """ + + def __init__(self, iterations, **kwargs): + self.iterable_data = {} + self.value_data = defaultdict(list) + self.iterations = iterations + self.current_iteration = 1 + self.luminosities = ["Emitted", "Absorbed", "Requested"] + self.plasma_plot = None + self.t_inner_luminosities_plot = None + + if "plasma_plot_config" in kwargs: + self.plasma_plot_config = kwargs["plasma_plot_config"] + + if "t_inner_luminosities_config" in kwargs: + self.t_inner_luminosities_config = kwargs[ + "t_inner_luminosities_config" + ] + + if "plasma_cmap" in kwargs: + self.plasma_colorscale = transition_colors( + name=kwargs["plasma_cmap"], length=self.iterations + ) + else: + # default color scale is jet + self.plasma_colorscale = transition_colors(length=self.iterations) + + if "t_inner_luminosities_colors" in kwargs: + # use cmap if string + if type(kwargs["t_inner_luminosities_colors"]) == str: + self.t_inner_luminosities_colors = transition_colors( + length=5, + name=kwargs["t_inner_luminosities_colors"], + ) + else: + self.t_inner_luminosities_colors = kwargs[ + "t_inner_luminosities_colors" + ] + else: + # using default plotly colors + self.t_inner_luminosities_colors = [None] * 5 + + def fetch_data(self, name=None, value=None, item_type=None): + """ + Fetch data from the Simulation class. + + Parameters + ---------- + name : string + name of the data + value : string or array + string or an array of quantities + item_type : string + either iterable or value + + """ + # trace data for plasma plots is added in iterable data dictionary + if item_type == "iterable": + self.iterable_data[name] = value + + # trace data for luminosity plots and inner boundary temperature plot is stored in value_data dictionary + if item_type == "value": + self.value_data[name].append(value) + + def create_plasma_plot(self): + """Create an empty plasma plot.""" + fig = go.FigureWidget().set_subplots(rows=1, cols=2, shared_xaxes=True) + + # empty traces to build figure + fig.add_scatter(row=1, col=1) + fig.add_scatter(row=1, col=2) + + # 2 y axes and 2 x axes correspond to the 2 subplots in the plasma plot + fig = fig.update_layout( + xaxis={ + "tickformat": "g", + "title": r"$\text{Velocity}~[\text{km}~\text{s}^{-1}]$", + }, + xaxis2={ + "tickformat": "g", + "title": r"$\text{Velocity}~[\text{km}~\text{s}^{-1}]$", + "matches": "x", + }, + yaxis={ + "tickformat": "g", + "title": r"$T_{\text{rad}}\ [\text{K}]$", + "nticks": 15, + }, + yaxis2={ + "tickformat": "g", + "title": r"$W$", + "nticks": 15, + }, + height=450, + legend_title_text="Iterations", + margin=dict( + l=10, r=135, b=25, t=25, pad=0 + ), # reduce whitespace surrounding the plot and increase right indentation to align with the t_inner and luminosity plot + ) + + # allow overriding default layout + if hasattr(self, "plasma_plot_config"): + self.override_plot_parameters(fig, self.plasma_plot_config) + self.plasma_plot = fig + + def create_t_inner_luminosities_plot(self): + """Create an empty t_inner and luminosity plot.""" + fig = go.FigureWidget().set_subplots( + rows=3, + cols=1, + shared_xaxes=True, + vertical_spacing=0.08, + row_heights=[0.25, 0.5, 0.25], + ) + + # add inner boundary temperature vs iterations plot + fig.add_scatter( + name="Inner
Boundary
Temperature", + row=1, + col=1, + hovertext="text", + marker_color=self.t_inner_luminosities_colors[0], + mode="lines", + ) + + # add luminosity vs iterations plot + # has three traces for emitted, requested and absorbed luminosities + for luminosity, line_color in zip( + self.luminosities, self.t_inner_luminosities_colors[1:4] + ): + fig.add_scatter( + name=luminosity + "
Luminosity", + mode="lines", + row=2, + col=1, + marker_color=line_color, + ) + + # add residual luminosity vs iterations plot + fig.add_scatter( + name="Residual
Luminosity", + row=3, + col=1, + marker_color=self.t_inner_luminosities_colors[4], + mode="lines", + ) + + # 3 y axes and 3 x axes correspond to the 3 subplots in the t_inner and luminosity convergence plot + fig = fig.update_layout( + xaxis=dict(range=[0, self.iterations + 1], dtick=2), + xaxis2=dict( + matches="x", + range=[0, self.iterations + 1], + dtick=2, + ), + xaxis3=dict( + title=r"$\mbox{Iteration Number}$", + dtick=2, + ), + yaxis=dict( + title=r"$T_{\text{inner}}\ [\text{K}]$", + automargin=True, + tickformat="g", + exponentformat="e", + nticks=4, + ), + yaxis2=dict( + exponentformat="e", + title=r"$\text{Luminosity}~[\text{erg s}^{-1}]$", + title_font_size=13, + automargin=True, + nticks=7, + ), + yaxis3=dict( + title=r"$~~\text{Residual}\\\text{Luminosity[%]}$", + title_font_size=12, + automargin=True, + nticks=4, + ), + height=630, + hoverlabel_align="right", + margin=dict( + b=25, t=25, pad=0 + ), # reduces whitespace surrounding the plot + ) + + # allow overriding default layout + if hasattr(self, "t_inner_luminosities_config"): + self.override_plot_parameters(fig, self.t_inner_luminosities_config) + + self.t_inner_luminosities_plot = fig + + def override_plot_parameters(self, fig, parameters): + """ + Override default plot properties. + + Any property inside the data dictionary is however, applied equally across all traces. + This means trace-specific data properties can't be changed using this function. + + Parameters + ---------- + fig : go.FigureWidget + FigureWidget object to be updated + parameters : dict + Dictionary used to update the default plot style. + """ + # because fig.data is a tuple of traces, a property in the data dictionary is applied to all traces + # the fig is a nested dictionary, any property n levels deep is not changed until the value is a not dictionary + # fig["property_1"]["property_2"]...["property_n"] = "value" + for key, value in parameters.items(): + if key == "data": + # all traces will have same data property + for trace in list(fig.data): + self.override_plot_parameters(trace, value) + else: + if type(value) == dict: + self.override_plot_parameters(fig[key], value) + else: + fig[key] = value + + def build(self, display_plot=True): + """ + Create empty convergence plots and display them. + + Parameters + ---------- + display_plot : bool, default: True, optional + Displays empty plots. + """ + self.create_plasma_plot() + self.create_t_inner_luminosities_plot() + + if display_plot: + display( + widgets.VBox( + [self.plasma_plot, self.t_inner_luminosities_plot], + ) + ) + + def update_plasma_plots(self): + """Update plasma convergence plots every iteration.""" + # convert velocity to km/s + x = self.iterable_data["velocity"].to(u.km / u.s).value.tolist() + + # add luminosity data in hover data in plasma plots + customdata = len(x) * [ + "
" + + "Emitted Luminosity: " + + f'{self.value_data["Absorbed"][-1]:.4g}' + + "
" + + "Requested Luminosity: " + + f'{self.value_data["Requested"][-1]:.4g}' + + "
" + + "Absorbed Luminosity: " + + f'{self.value_data["Requested"][-1]:.4g}' + ] + + # add a radiation temperature vs shell velocity trace to the plasma plot + self.plasma_plot.add_scatter( + x=x, + y=self.iterable_data["t_rad"], + line_color=self.plasma_colorscale[self.current_iteration - 1], + row=1, + col=1, + name=self.current_iteration, + legendgroup=f"group-{self.current_iteration}", + showlegend=False, + customdata=customdata, + hovertemplate="Y: %{y:.3f} at X = %{x:,.0f}%{customdata}", + ) + + # add a dilution factor vs shell velocity trace to the plasma plot + self.plasma_plot.add_scatter( + x=x, + y=self.iterable_data["w"], + line_color=self.plasma_colorscale[self.current_iteration - 1], + row=1, + col=2, + legendgroup=f"group-{self.current_iteration}", + name=self.current_iteration, + customdata=customdata, + hovertemplate="Y: %{y:.3f} at X = %{x:,.0f}%{customdata}", + ) + + def update_t_inner_luminosities_plot(self): + """Update the t_inner and luminosity convergence plots every iteration.""" + x = list(range(1, self.iterations + 1)) + + with self.t_inner_luminosities_plot.batch_update(): + # traces are updated according to the order they were added + # the first trace is of the inner boundary temperature plot + self.t_inner_luminosities_plot.data[0].x = x + self.t_inner_luminosities_plot.data[0].y = self.value_data[ + "t_inner" + ] + self.t_inner_luminosities_plot.data[ + 0 + ].hovertemplate = "%{y:.3f} at X = %{x:,.0f}Inner Boundary Temperature" # trace name in extra tag to avoid new lines in hoverdata + + # the next three for emitted, absorbed and requested luminosities + for index, luminosity in zip(range(1, 4), self.luminosities): + self.t_inner_luminosities_plot.data[index].x = x + self.t_inner_luminosities_plot.data[index].y = self.value_data[ + luminosity + ] + self.t_inner_luminosities_plot.data[index].hovertemplate = ( + "%{y:.4g}" + "
at X = %{x}
" + ) + + # last is for the residual luminosity + y = [ + ((emitted - requested) * 100) / requested + for emitted, requested in zip( + self.value_data["Emitted"], self.value_data["Requested"] + ) + ] + + self.t_inner_luminosities_plot.data[4].x = x + self.t_inner_luminosities_plot.data[4].y = y + self.t_inner_luminosities_plot.data[ + 4 + ].hovertemplate = "%{y:.2f}% at X = %{x:,.0f}" + + def update(self, export_cplots=False, last=False): + """ + Update the convergence plots every iteration. + + Parameters + ---------- + export_cplots : bool, default: False, optional + Displays the convergence plots again using plotly's `notebook_connected` renderer. + This helps to display the plots in notebooks when shared on platforms like nbviewer. + Please see https://plotly.com/python/renderers/ for more information. + last : bool, default: False, optional + True if it's last iteration. + """ + if self.iterable_data != {}: + # build only at first iteration + if self.current_iteration == 1: + self.build() + + self.update_plasma_plots() + self.update_t_inner_luminosities_plot() + + # data property for plasma plots needs to be + # updated after the last iteration because new traces have been added + if hasattr(self, "plasma_plot_config") and last: + if "data" in self.plasma_plot_config: + self.override_plot_parameters( + self.plasma_plot, self.plasma_plot_config + ) + + self.current_iteration += 1 + + # the display function expects a Widget, while + # fig.show() returns None, which causes the TraitError. + if export_cplots: + with suppress(TraitError): + display( + widgets.VBox( + [ + self.plasma_plot.show( + renderer="notebook_connected" + ), + self.t_inner_luminosities_plot.show( + renderer="notebook_connected" + ), + ] + ) + ) diff --git a/tardis/visualization/tools/tests/__init__.py b/tardis/visualization/tools/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tardis/visualization/tools/tests/test_convergence_plot.py b/tardis/visualization/tools/tests/test_convergence_plot.py new file mode 100644 index 00000000000..a84aa9e4f4a --- /dev/null +++ b/tardis/visualization/tools/tests/test_convergence_plot.py @@ -0,0 +1,207 @@ +"""Tests for Convergence Plots.""" +import pytest +from tardis.visualization.tools.convergence_plot import ( + ConvergencePlots, + transition_colors, +) +from collections import defaultdict +import plotly.graph_objects as go +from astropy import units as u + + +@pytest.fixture(scope="module", params=[0, 1, 2]) +def convergence_plots(request): + """Initialize ConvergencePlots class and build empty plots.""" + cplots = ConvergencePlots(iterations=request.param) + cplots.build(display_plot=False) + return cplots + + +@pytest.fixture() +def fetch_luminosity_data(convergence_plots): + """Prepare data for t_inner and luminosity plot.""" + for item in [2] * convergence_plots.iterations: + convergence_plots.fetch_data( + name="t_inner", value=item, item_type="value" + ) + for item2 in convergence_plots.luminosities: + convergence_plots.fetch_data( + name=item2, value=item, item_type="value" + ) + + +def test_transition_colors(): + """Test whether the object returned by the transition_colors function is a list of appropriate length.""" + iterations = 3 + colors = transition_colors(length=iterations) + assert type(colors) == list + assert len(colors) == iterations + + +def test_convergence_construction(convergence_plots): + """Test the construction of the ConvergencePlots class.""" + assert convergence_plots.iterable_data == {} + assert convergence_plots.value_data == defaultdict(list) + assert convergence_plots.luminosities == [ + "Emitted", + "Absorbed", + "Requested", + ] + + +def test_fetch_data(convergence_plots): + """Test values of variables updated by fetch_data function.""" + convergence_plots.fetch_data( + name="iterable", value=range(3), item_type="iterable" + ) + convergence_plots.fetch_data(name="value", value=0, item_type="value") + + assert convergence_plots.iterable_data["iterable"] == range(3) + assert convergence_plots.value_data["value"] == [0] + + +def test_build(convergence_plots): + """Test if convergence plots are instances of plotly.graph_objs.FigureWidget() and have appropriate number of traces.""" + assert type(convergence_plots.plasma_plot) == go.FigureWidget + assert type(convergence_plots.t_inner_luminosities_plot) == go.FigureWidget + + # check number of traces + assert len(convergence_plots.t_inner_luminosities_plot.data) == 5 + assert len(convergence_plots.plasma_plot.data) == 2 + + +@pytest.mark.usefixtures("fetch_luminosity_data") +def test_update_t_inner_luminosities_plot(convergence_plots): + """Test the number of traces and length of x and y values.""" + n_iterations = convergence_plots.iterations + convergence_plots.update_t_inner_luminosities_plot() + + # check number of traces + assert len(convergence_plots.t_inner_luminosities_plot.data) == 5 + + for index in range(0, 5): + # check x and y values for all traces + assert ( + len(convergence_plots.t_inner_luminosities_plot.data[index].x) + == n_iterations + ) + assert ( + len(convergence_plots.t_inner_luminosities_plot.data[index].y) + == n_iterations + ) + + # check range of x axes + for axis in ["xaxis", "xaxis2", "xaxis3"]: + convergence_plots.t_inner_luminosities_plot["layout"][axis]["range"] = [ + 0, + convergence_plots.iterations + 1, + ] + + +@pytest.mark.usefixtures("fetch_luminosity_data") +def test_update_plasma_plots(convergence_plots): + """Test the state of plasma plots after updating.""" + n_iterations = convergence_plots.iterations + expected_n_traces = 2 * n_iterations + 2 + velocity = range(0, n_iterations) * u.m / u.s + + convergence_plots.fetch_data( + name="velocity", value=velocity, item_type="iterable" + ) + + w_val = list(range(n_iterations)) + t_rad_val = [item * 2 for item in w_val] + + for _ in range(n_iterations): + convergence_plots.fetch_data( + name="t_rad", + value=t_rad_val, + item_type="iterable", + ) + convergence_plots.fetch_data( + name="w", + value=w_val, + item_type="iterable", + ) + convergence_plots.update_plasma_plots() + convergence_plots.current_iteration += 1 + + # check number of traces + assert len(convergence_plots.plasma_plot.data) == expected_n_traces + + # traces are added alternatively + # trace 0 and 1 and empty + assert convergence_plots.plasma_plot.data[0].x == None + assert convergence_plots.plasma_plot.data[1].x == None + + # check other traces + for index in list(range(expected_n_traces))[::2][1:]: + # check values for t_rad subplot + assert convergence_plots.plasma_plot.data[index].xaxis == "x" + assert convergence_plots.plasma_plot.data[index].yaxis == "y" + assert convergence_plots.plasma_plot.data[index].y == tuple(t_rad_val) + assert convergence_plots.plasma_plot.data[index].x == tuple( + velocity.to(u.km / u.s).value + ) + + for index in list(range(expected_n_traces))[1::2][1:]: + # check values for w subplot + assert convergence_plots.plasma_plot.data[index].xaxis == "x2" + assert convergence_plots.plasma_plot.data[index].yaxis == "y2" + assert convergence_plots.plasma_plot.data[index].y == tuple(w_val) + assert convergence_plots.plasma_plot.data[index].x == tuple( + velocity.to(u.km / u.s).value + ) + + +@pytest.mark.usefixtures("fetch_luminosity_data") +def test_override_plot_parameters(convergence_plots): + """Test if default plot properties are overridden properly.""" + parameters = { + "data": { + "line": {"dash": "dot"}, + "mode": "lines+markers", + }, + "layout": {"xaxis2": {"showgrid": False}}, + } + convergence_plots.override_plot_parameters( + convergence_plots.plasma_plot, parameters=parameters + ) + convergence_plots.override_plot_parameters( + convergence_plots.t_inner_luminosities_plot, parameters=parameters + ) + + # data properties will be applied across all traces equally + # testing plot parameters of t_inner and luminosity plot + for trace_index in range(5): + assert ( + convergence_plots.t_inner_luminosities_plot.data[ + trace_index + ].line.dash + == "dot" + ) + assert ( + convergence_plots.t_inner_luminosities_plot.data[trace_index].mode + == "lines+markers" + ) + # checking layout + assert ( + convergence_plots.t_inner_luminosities_plot["layout"]["xaxis2"][ + "showgrid" + ] + == False + ) + + # testing plot parameters for plasma plot + for trace_index in range(2): + assert ( + convergence_plots.plasma_plot.data[trace_index].line.dash == "dot" + ) + assert ( + convergence_plots.plasma_plot.data[trace_index].mode + == "lines+markers" + ) + # checking layout for plasma plot + assert ( + convergence_plots.plasma_plot["layout"]["xaxis2"]["showgrid"] == False + ) diff --git a/tardis/visualization/widgets/shell_info.py b/tardis/visualization/widgets/shell_info.py index b7528527d05..336c8f580cf 100644 --- a/tardis/visualization/widgets/shell_info.py +++ b/tardis/visualization/widgets/shell_info.py @@ -4,7 +4,7 @@ atomic_number2element_symbol, species_tuple_to_string, ) -from tardis.simulation import Simulation + from tardis.visualization.widgets.util import create_table_widget import pandas as pd