From 4b51b68dc8a0c5859e772eaf2f83ee8a51411ab2 Mon Sep 17 00:00:00 2001
From: Alexander Holas <70367168+AlexHls@users.noreply.github.com>
Date: Wed, 6 Dec 2023 16:14:11 +0100
Subject: [PATCH] Fix viz sdec (#2490)
* Fix loading/new variable names
* Make blackbody optional
* Ruff compliance
* Rerun tests
---
tardis/visualization/tools/sdec_plot.py | 224 ++++++++++++++----------
1 file changed, 128 insertions(+), 96 deletions(-)
diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py
index 63806535a39..7509e95a5a1 100644
--- a/tardis/visualization/tools/sdec_plot.py
+++ b/tardis/visualization/tools/sdec_plot.py
@@ -4,28 +4,26 @@
This plot is a spectral diagnostics plot similar to those originally
proposed by M. Kromer (see, for example, Kromer et al. 2013, figure 4).
"""
-import numpy as np
-import pandas as pd
-import astropy.units as u
-from astropy.modeling.models import BlackBody
+import logging
-import matplotlib.pyplot as plt
+import astropy.units as u
import matplotlib.cm as cm
import matplotlib.colors as clr
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
import plotly.graph_objects as go
+from astropy.modeling.models import BlackBody
from tardis.util.base import (
atomic_number2element_symbol,
element_symbol2atomic_number,
- species_string_to_tuple,
- species_tuple_to_string,
- roman_to_int,
int_to_roman,
+ roman_to_int,
+ species_string_to_tuple,
)
from tardis.visualization import plot_util as pu
-import logging
-
logger = logging.getLogger(__name__)
@@ -249,9 +247,11 @@ def from_hdf(cls, hdf_fpath, packets_mode):
.set_index("line_id")
)
r_inner = u.Quantity(
- hdf["/simulation/model/r_inner"].to_numpy(), "cm"
+ hdf["/simulation/simulation_state/r_inner"].to_numpy(), "cm"
) # Convert pd.Series to np.array to construct quantity from it
- t_inner = u.Quantity(hdf["/simulation/model/scalars"].t_inner, "K")
+ t_inner = u.Quantity(
+ hdf["/simulation/simulation_state/scalars"].t_inner, "K"
+ )
time_of_simulation = u.Quantity(
hdf["/simulation/transport/scalars"].time_of_simulation, "s"
)
@@ -419,21 +419,21 @@ def from_simulation(cls, sim):
"""
if sim.transport.virt_logging:
return cls(
- dict(
- virtual=SDECData.from_simulation(sim, "virtual"),
- real=SDECData.from_simulation(sim, "real"),
- )
+ {
+ "virtual": SDECData.from_simulation(sim, "virtual"),
+ "real": SDECData.from_simulation(sim, "real"),
+ }
)
else:
return cls(
- dict(
- virtual=None,
- real=SDECData.from_simulation(sim, "real"),
- )
+ {
+ "virtual": None,
+ "real": SDECData.from_simulation(sim, "real"),
+ }
)
@classmethod
- def from_hdf(cls, hdf_fpath):
+ def from_hdf(cls, hdf_fpath, packets_mode=None):
"""
Create an instance of SDECPlotter from a simulation HDF file.
@@ -441,17 +441,39 @@ def from_hdf(cls, hdf_fpath):
----------
hdf_fpath : str
Valid path to the HDF file where simulation is saved
+ packets_mode : {'virtual', 'real'}, optional
+ Mode of packets to be considered, either real or virtual. If not
+ specified, both modes are returned
Returns
-------
SDECPlotter
"""
- return cls(
- dict(
- virtual=SDECData.from_hdf(hdf_fpath, "virtual"),
- real=SDECData.from_hdf(hdf_fpath, "real"),
- )
+ assert packets_mode in [None, "virtual", "real"], (
+ "Invalid value passed to packets_mode. Only "
+ "allowed values are 'virtual', 'real' or None"
)
+ if packets_mode == "virtual":
+ return cls(
+ {
+ "virtual": SDECData.from_hdf(hdf_fpath, "virtual"),
+ "real": None,
+ }
+ )
+ elif packets_mode == "real":
+ return cls(
+ {
+ "virtual": None,
+ "real": SDECData.from_hdf(hdf_fpath, "real"),
+ }
+ )
+ else:
+ return cls(
+ {
+ "virtual": SDECData.from_hdf(hdf_fpath, "virtual"),
+ "real": SDECData.from_hdf(hdf_fpath, "real"),
+ }
+ )
def _parse_species_list(self, species_list):
"""
@@ -465,12 +487,11 @@ def _parse_species_list(self, species_list):
(e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca])
"""
-
if species_list is not None:
# check if there are any digits in the species list. If there are, then exit.
# species_list should only contain species in the Roman numeral
# format, e.g. Si II, and each ion must contain a space
- if any(char.isdigit() for char in " ".join(species_list)) == True:
+ if any(char.isdigit() for char in " ".join(species_list)) is True:
raise ValueError(
"All species must be in Roman numeral form, e.g. Si II"
)
@@ -696,7 +717,7 @@ def _calculate_plotting_data(
# Repeat this for the emission and absorption dfs
# This will require creating a temporary list that includes 'noint' and 'escatter'
# packets, because you don't want them dropped or included in 'other'
- temp = [species for species in self._species_list]
+ temp = list(self._species_list)
temp.append("noint")
temp.append("escatter")
mask = np.in1d(
@@ -720,7 +741,7 @@ def _calculate_plotting_data(
axis=1,
)
- temp = [species for species in self._species_list]
+ temp = list(self._species_list)
mask = np.in1d(
np.array(list(self.absorption_luminosities_df.keys())), temp
)
@@ -1093,6 +1114,7 @@ def generate_plot_mpl(
cmapname="jet",
nelements=None,
species_list=None,
+ blackbody_photosphere=True,
):
"""
Generate Spectral element DEComposition (SDEC) Plot using matplotlib.
@@ -1136,13 +1158,14 @@ def generate_plot_mpl(
Must be given in Roman numeral format. Can include specific ions, a range of ions,
individual elements, or any combination of these:
e.g. ['Si II', 'Ca II', 'C', 'Fe I-V']
+ blackbody_photosphere: bool
+ Whether to include the blackbody photosphere in the plot. Default value is True
Returns
-------
matplotlib.axes._subplots.AxesSubplot
Axis on which SDEC Plot is created
"""
-
# If species_list and nelements requested, tell user that nelements is ignored
if species_list is not None and nelements is not None:
logger.info(
@@ -1193,7 +1216,7 @@ def generate_plot_mpl(
if distance is None:
raise ValueError(
"""
- Distance must be specified if an observed_spectrum is given
+ Distance must be specified if an observed_spectrum is given
so that the model spectrum can be converted into flux space correctly.
"""
)
@@ -1214,12 +1237,13 @@ def generate_plot_mpl(
)
# Plot photosphere
- self.ax.plot(
- self.plot_wavelength.value,
- self.photosphere_luminosity.value,
- "--r",
- label="Blackbody Photosphere",
- )
+ if blackbody_photosphere:
+ self.ax.plot(
+ self.plot_wavelength.value,
+ self.photosphere_luminosity.value,
+ "--r",
+ label="Blackbody Photosphere",
+ )
# Set legends and labels
self.ax.legend(fontsize=12)
@@ -1299,23 +1323,25 @@ def _plot_emission_mpl(self):
cmap=self.cmap,
linewidth=0,
)
- except:
+ except KeyError:
# Add notifications that this species was not in the emission df
if self._species_list is None:
- logger.info(
+ info_msg = (
f"{atomic_number2element_symbol(identifier)}"
f" is not in the emitted packets; skipping"
)
+ logger.info(info_msg)
else:
# Get the ion number and atomic number for each species
ion_number = identifier % 100
atomic_number = (identifier - ion_number) / 100
- logger.info(
+ info_msg = (
f"{atomic_number2element_symbol(atomic_number)}"
f"{int_to_roman(ion_number + 1)}"
f" is not in the emitted packets; skipping"
)
+ logger.info(info_msg)
def _plot_absorption_mpl(self):
"""Plot absorption part of the SDEC Plot using matplotlib."""
@@ -1355,27 +1381,28 @@ def _plot_absorption_mpl(self):
linewidth=0,
)
- except:
+ except KeyError:
# Add notifications that this species was not in the emission df
if self._species_list is None:
- logger.info(
+ info_msg = (
f"{atomic_number2element_symbol(identifier)}"
f" is not in the absorbed packets; skipping"
)
+ logger.info(info_msg)
else:
# Get the ion number and atomic number for each species
ion_number = identifier % 100
atomic_number = (identifier - ion_number) / 100
- logger.info(
+ info_msg = (
f"{atomic_number2element_symbol(atomic_number)}"
f"{int_to_roman(ion_number + 1)}"
f" is not in the absorbed packets; skipping"
)
+ logger.info(info_msg)
def _show_colorbar_mpl(self):
"""Show matplotlib colorbar with labels of elements mapped to colors."""
-
color_values = [
self.cmap(species_counter / len(self._species_name))
for species_counter in range(len(self._species_name))
@@ -1449,13 +1476,11 @@ def _make_colorbar_colors(self):
if previous_atomic_number == 0:
# If this is the first species being plotted, then take note of the atomic number
# don't update the colour index
- color_counter = color_counter
previous_atomic_number = atomic_number
elif previous_atomic_number in self._keep_colour:
# If the atomic number is in the list of elements that should all be plotted in the same colour
# then don't update the colour index if this element has been plotted already
if previous_atomic_number == atomic_number:
- color_counter = color_counter
previous_atomic_number = atomic_number
else:
# Otherwise, increase the colour counter by one, because this is a new element
@@ -1489,6 +1514,7 @@ def generate_plot_ply(
cmapname="jet",
nelements=None,
species_list=None,
+ blackbody_photosphere=True,
):
"""
Generate interactive Spectral element DEComposition (SDEC) Plot using plotly.
@@ -1532,12 +1558,14 @@ def generate_plot_ply(
Must be given in Roman numeral format. Can include specific ions, a range of ions,
individual elements, or any combination of these:
e.g. ['Si II', 'Ca II', 'C', 'Fe I-V']
+ blackbody_photosphere: bool
+ Whether to include the blackbody photosphere in the plot. Default value is True
+
Returns
-------
plotly.graph_objs._figure.Figure
Figure object on which SDEC Plot is created
"""
-
# If species_list and nelements requested, tell user that nelements is ignored
if species_list is not None and nelements is not None:
logger.info(
@@ -1579,13 +1607,13 @@ def generate_plot_ply(
x=self.plot_wavelength.value,
y=self.modeled_spectrum_luminosity.value,
mode="lines",
- line=dict(
- color="blue",
- width=1,
- ),
+ line={
+ "color": "blue",
+ "width": 1,
+ },
name=f"{packets_mode.capitalize()} Spectrum",
hovertemplate="(%{x:.2f}, %{y:.3g})",
- hoverlabel=dict(namelength=-1),
+ hoverlabel={"namelength": -1},
)
)
@@ -1594,7 +1622,7 @@ def generate_plot_ply(
if distance is None:
raise ValueError(
"""
- Distance must be specified if an observed_spectrum is given
+ Distance must be specified if an observed_spectrum is given
so that the model spectrum can be converted into flux space correctly.
"""
)
@@ -1611,22 +1639,23 @@ def generate_plot_ply(
y=observed_spectrum_flux.value,
name="Observed Spectrum",
line={"color": "black", "width": 1.2},
- hoverlabel=dict(namelength=-1),
+ hoverlabel={"namelength": -1},
hovertemplate="(%{x:.2f}, %{y:.3g})",
)
# Plot photosphere
- self.fig.add_trace(
- go.Scatter(
- x=self.plot_wavelength.value,
- y=self.photosphere_luminosity.value,
- mode="lines",
- line=dict(width=1.5, color="red", dash="dash"),
- name="Blackbody Photosphere",
- hoverlabel=dict(namelength=-1),
- hovertemplate="(%{x:.2f}, %{y:.3g})",
+ if blackbody_photosphere:
+ self.fig.add_trace(
+ go.Scatter(
+ x=self.plot_wavelength.value,
+ y=self.photosphere_luminosity.value,
+ mode="lines",
+ line={"width": 1.5, "color": "red", "dash": "dash"},
+ name="Blackbody Photosphere",
+ hoverlabel={"namelength": -1},
+ hovertemplate="(%{x:.2f}, %{y:.3g})",
+ )
)
- )
self._show_colorbar_ply()
@@ -1641,11 +1670,11 @@ def generate_plot_ply(
"L_{\\lambda}", u.Unit("erg/(s AA)"), only_text=False
)
self.fig.update_layout(
- xaxis=dict(
- title=xlabel,
- exponentformat="none",
- ),
- yaxis=dict(title=ylabel, exponentformat="e"),
+ xaxis={
+ "title": xlabel,
+ "exponentformat": "none",
+ },
+ yaxis={"title": ylabel, "exponentformat": "e"},
height=graph_height,
)
@@ -1694,7 +1723,7 @@ def _plot_emission_ply(self):
name="Electron Scatter Only",
fillcolor="#8F8F8F",
stackgroup="emission",
- hoverlabel=dict(namelength=-1),
+ hoverlabel={"namelength": -1},
hovertemplate="(%{x:.2f}, %{y:.3g})",
)
)
@@ -1724,37 +1753,38 @@ def _plot_emission_ply(self):
y=self.emission_luminosities_df[identifier],
mode="none",
name=species_name + " Emission",
- hovertemplate=f"{species_name} Emission "
+ hovertemplate=f"{species_name:s} Emission
" # noqa: ISC003
+ "(%{x:.2f}, %{y:.3g})",
fillcolor=self.to_rgb255_string(
self._color_list[species_counter]
),
stackgroup="emission",
showlegend=False,
- hoverlabel=dict(namelength=-1),
+ hoverlabel={"namelength": -1},
)
)
- except:
+ except KeyError:
# Add notifications that this species was not in the emission df
if self._species_list is None:
- logger.info(
+ info_msg = (
f"{atomic_number2element_symbol(identifier)}"
f" is not in the emitted packets; skipping"
)
+ logger.info(info_msg)
else:
# Get the ion number and atomic number for each species
ion_number = identifier % 100
atomic_number = (identifier - ion_number) / 100
- logger.info(
+ info_msg = (
f"{atomic_number2element_symbol(atomic_number)}"
f"{int_to_roman(ion_number + 1)}"
f" is not in the emitted packets; skipping"
)
+ logger.info(info_msg)
def _plot_absorption_ply(self):
"""Plot absorption part of the SDEC Plot using plotly."""
-
# If 'other' column exists then plot as silver
if "other" in self.absorption_luminosities_df.keys():
self.fig.add_trace(
@@ -1782,34 +1812,36 @@ def _plot_absorption_ply(self):
y=self.absorption_luminosities_df[identifier] * -1,
mode="none",
name=species_name + " Absorption",
- hovertemplate=f"{species_name} Absorption "
+ hovertemplate=f"{species_name:s} Absorption
" # noqa: ISC003
+ "(%{x:.2f}, %{y:.3g})",
fillcolor=self.to_rgb255_string(
self._color_list[species_counter]
),
stackgroup="absorption",
showlegend=False,
- hoverlabel=dict(namelength=-1),
+ hoverlabel={"namelength": -1},
)
)
- except:
+ except KeyError:
# Add notifications that this species was not in the emission df
if self._species_list is None:
- logger.info(
+ info_msg = (
f"{atomic_number2element_symbol(identifier)}"
f" is not in the absorbed packets; skipping"
)
+ logger.info(info_msg)
else:
# Get the ion number and atomic number for each species
ion_number = identifier % 100
atomic_number = (identifier - ion_number) / 100
- logger.info(
+ info_msg = (
f"{atomic_number2element_symbol(atomic_number)}"
f"{int_to_roman(ion_number + 1)}"
f" is not in the absorbed packets; skipping"
)
+ logger.info(info_msg)
def _show_colorbar_ply(self):
"""Show plotly colorbar with labels of elements mapped to colors."""
@@ -1831,21 +1863,21 @@ def _show_colorbar_ply(self):
(colorscale_bins[species_counter + 1], color)
)
- coloraxis_options = dict(
- colorscale=categorical_colorscale,
- showscale=True,
- cmin=0,
- cmax=len(self._species_name),
- colorbar=dict(
- title="Elements",
- tickvals=np.arange(0, len(self._species_name)) + 0.5,
- ticktext=self._species_name,
+ coloraxis_options = {
+ "colorscale": categorical_colorscale,
+ "showscale": True,
+ "cmin": 0,
+ "cmax": len(self._species_name),
+ "colorbar": {
+ "title": "Elements",
+ "tickvals": np.arange(0, len(self._species_name)) + 0.5,
+ "ticktext": self._species_name,
# to change length and position of colorbar
- len=0.75,
- yanchor="top",
- y=0.75,
- ),
- )
+ "len": 0.75,
+ "yanchor": "top",
+ "y": 0.75,
+ },
+ }
# Plot an invisible one point scatter trace, to make colorbar show up
scatter_point_idx = pu.get_mid_point_idx(self.plot_wavelength)