From 4b51b68dc8a0c5859e772eaf2f83ee8a51411ab2 Mon Sep 17 00:00:00 2001 From: Alexander Holas <70367168+AlexHls@users.noreply.github.com> Date: Wed, 6 Dec 2023 16:14:11 +0100 Subject: [PATCH] Fix viz sdec (#2490) * Fix loading/new variable names * Make blackbody optional * Ruff compliance * Rerun tests --- tardis/visualization/tools/sdec_plot.py | 224 ++++++++++++++---------- 1 file changed, 128 insertions(+), 96 deletions(-) diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index 63806535a39..7509e95a5a1 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -4,28 +4,26 @@ This plot is a spectral diagnostics plot similar to those originally proposed by M. Kromer (see, for example, Kromer et al. 2013, figure 4). """ -import numpy as np -import pandas as pd -import astropy.units as u -from astropy.modeling.models import BlackBody +import logging -import matplotlib.pyplot as plt +import astropy.units as u import matplotlib.cm as cm import matplotlib.colors as clr +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd import plotly.graph_objects as go +from astropy.modeling.models import BlackBody from tardis.util.base import ( atomic_number2element_symbol, element_symbol2atomic_number, - species_string_to_tuple, - species_tuple_to_string, - roman_to_int, int_to_roman, + roman_to_int, + species_string_to_tuple, ) from tardis.visualization import plot_util as pu -import logging - logger = logging.getLogger(__name__) @@ -249,9 +247,11 @@ def from_hdf(cls, hdf_fpath, packets_mode): .set_index("line_id") ) r_inner = u.Quantity( - hdf["/simulation/model/r_inner"].to_numpy(), "cm" + hdf["/simulation/simulation_state/r_inner"].to_numpy(), "cm" ) # Convert pd.Series to np.array to construct quantity from it - t_inner = u.Quantity(hdf["/simulation/model/scalars"].t_inner, "K") + t_inner = u.Quantity( + hdf["/simulation/simulation_state/scalars"].t_inner, "K" + ) time_of_simulation = u.Quantity( hdf["/simulation/transport/scalars"].time_of_simulation, "s" ) @@ -419,21 +419,21 @@ def from_simulation(cls, sim): """ if sim.transport.virt_logging: return cls( - dict( - virtual=SDECData.from_simulation(sim, "virtual"), - real=SDECData.from_simulation(sim, "real"), - ) + { + "virtual": SDECData.from_simulation(sim, "virtual"), + "real": SDECData.from_simulation(sim, "real"), + } ) else: return cls( - dict( - virtual=None, - real=SDECData.from_simulation(sim, "real"), - ) + { + "virtual": None, + "real": SDECData.from_simulation(sim, "real"), + } ) @classmethod - def from_hdf(cls, hdf_fpath): + def from_hdf(cls, hdf_fpath, packets_mode=None): """ Create an instance of SDECPlotter from a simulation HDF file. @@ -441,17 +441,39 @@ def from_hdf(cls, hdf_fpath): ---------- hdf_fpath : str Valid path to the HDF file where simulation is saved + packets_mode : {'virtual', 'real'}, optional + Mode of packets to be considered, either real or virtual. If not + specified, both modes are returned Returns ------- SDECPlotter """ - return cls( - dict( - virtual=SDECData.from_hdf(hdf_fpath, "virtual"), - real=SDECData.from_hdf(hdf_fpath, "real"), - ) + assert packets_mode in [None, "virtual", "real"], ( + "Invalid value passed to packets_mode. Only " + "allowed values are 'virtual', 'real' or None" ) + if packets_mode == "virtual": + return cls( + { + "virtual": SDECData.from_hdf(hdf_fpath, "virtual"), + "real": None, + } + ) + elif packets_mode == "real": + return cls( + { + "virtual": None, + "real": SDECData.from_hdf(hdf_fpath, "real"), + } + ) + else: + return cls( + { + "virtual": SDECData.from_hdf(hdf_fpath, "virtual"), + "real": SDECData.from_hdf(hdf_fpath, "real"), + } + ) def _parse_species_list(self, species_list): """ @@ -465,12 +487,11 @@ def _parse_species_list(self, species_list): (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) """ - if species_list is not None: # check if there are any digits in the species list. If there are, then exit. # species_list should only contain species in the Roman numeral # format, e.g. Si II, and each ion must contain a space - if any(char.isdigit() for char in " ".join(species_list)) == True: + if any(char.isdigit() for char in " ".join(species_list)) is True: raise ValueError( "All species must be in Roman numeral form, e.g. Si II" ) @@ -696,7 +717,7 @@ def _calculate_plotting_data( # Repeat this for the emission and absorption dfs # This will require creating a temporary list that includes 'noint' and 'escatter' # packets, because you don't want them dropped or included in 'other' - temp = [species for species in self._species_list] + temp = list(self._species_list) temp.append("noint") temp.append("escatter") mask = np.in1d( @@ -720,7 +741,7 @@ def _calculate_plotting_data( axis=1, ) - temp = [species for species in self._species_list] + temp = list(self._species_list) mask = np.in1d( np.array(list(self.absorption_luminosities_df.keys())), temp ) @@ -1093,6 +1114,7 @@ def generate_plot_mpl( cmapname="jet", nelements=None, species_list=None, + blackbody_photosphere=True, ): """ Generate Spectral element DEComposition (SDEC) Plot using matplotlib. @@ -1136,13 +1158,14 @@ def generate_plot_mpl( Must be given in Roman numeral format. Can include specific ions, a range of ions, individual elements, or any combination of these: e.g. ['Si II', 'Ca II', 'C', 'Fe I-V'] + blackbody_photosphere: bool + Whether to include the blackbody photosphere in the plot. Default value is True Returns ------- matplotlib.axes._subplots.AxesSubplot Axis on which SDEC Plot is created """ - # If species_list and nelements requested, tell user that nelements is ignored if species_list is not None and nelements is not None: logger.info( @@ -1193,7 +1216,7 @@ def generate_plot_mpl( if distance is None: raise ValueError( """ - Distance must be specified if an observed_spectrum is given + Distance must be specified if an observed_spectrum is given so that the model spectrum can be converted into flux space correctly. """ ) @@ -1214,12 +1237,13 @@ def generate_plot_mpl( ) # Plot photosphere - self.ax.plot( - self.plot_wavelength.value, - self.photosphere_luminosity.value, - "--r", - label="Blackbody Photosphere", - ) + if blackbody_photosphere: + self.ax.plot( + self.plot_wavelength.value, + self.photosphere_luminosity.value, + "--r", + label="Blackbody Photosphere", + ) # Set legends and labels self.ax.legend(fontsize=12) @@ -1299,23 +1323,25 @@ def _plot_emission_mpl(self): cmap=self.cmap, linewidth=0, ) - except: + except KeyError: # Add notifications that this species was not in the emission df if self._species_list is None: - logger.info( + info_msg = ( f"{atomic_number2element_symbol(identifier)}" f" is not in the emitted packets; skipping" ) + logger.info(info_msg) else: # Get the ion number and atomic number for each species ion_number = identifier % 100 atomic_number = (identifier - ion_number) / 100 - logger.info( + info_msg = ( f"{atomic_number2element_symbol(atomic_number)}" f"{int_to_roman(ion_number + 1)}" f" is not in the emitted packets; skipping" ) + logger.info(info_msg) def _plot_absorption_mpl(self): """Plot absorption part of the SDEC Plot using matplotlib.""" @@ -1355,27 +1381,28 @@ def _plot_absorption_mpl(self): linewidth=0, ) - except: + except KeyError: # Add notifications that this species was not in the emission df if self._species_list is None: - logger.info( + info_msg = ( f"{atomic_number2element_symbol(identifier)}" f" is not in the absorbed packets; skipping" ) + logger.info(info_msg) else: # Get the ion number and atomic number for each species ion_number = identifier % 100 atomic_number = (identifier - ion_number) / 100 - logger.info( + info_msg = ( f"{atomic_number2element_symbol(atomic_number)}" f"{int_to_roman(ion_number + 1)}" f" is not in the absorbed packets; skipping" ) + logger.info(info_msg) def _show_colorbar_mpl(self): """Show matplotlib colorbar with labels of elements mapped to colors.""" - color_values = [ self.cmap(species_counter / len(self._species_name)) for species_counter in range(len(self._species_name)) @@ -1449,13 +1476,11 @@ def _make_colorbar_colors(self): if previous_atomic_number == 0: # If this is the first species being plotted, then take note of the atomic number # don't update the colour index - color_counter = color_counter previous_atomic_number = atomic_number elif previous_atomic_number in self._keep_colour: # If the atomic number is in the list of elements that should all be plotted in the same colour # then don't update the colour index if this element has been plotted already if previous_atomic_number == atomic_number: - color_counter = color_counter previous_atomic_number = atomic_number else: # Otherwise, increase the colour counter by one, because this is a new element @@ -1489,6 +1514,7 @@ def generate_plot_ply( cmapname="jet", nelements=None, species_list=None, + blackbody_photosphere=True, ): """ Generate interactive Spectral element DEComposition (SDEC) Plot using plotly. @@ -1532,12 +1558,14 @@ def generate_plot_ply( Must be given in Roman numeral format. Can include specific ions, a range of ions, individual elements, or any combination of these: e.g. ['Si II', 'Ca II', 'C', 'Fe I-V'] + blackbody_photosphere: bool + Whether to include the blackbody photosphere in the plot. Default value is True + Returns ------- plotly.graph_objs._figure.Figure Figure object on which SDEC Plot is created """ - # If species_list and nelements requested, tell user that nelements is ignored if species_list is not None and nelements is not None: logger.info( @@ -1579,13 +1607,13 @@ def generate_plot_ply( x=self.plot_wavelength.value, y=self.modeled_spectrum_luminosity.value, mode="lines", - line=dict( - color="blue", - width=1, - ), + line={ + "color": "blue", + "width": 1, + }, name=f"{packets_mode.capitalize()} Spectrum", hovertemplate="(%{x:.2f}, %{y:.3g})", - hoverlabel=dict(namelength=-1), + hoverlabel={"namelength": -1}, ) ) @@ -1594,7 +1622,7 @@ def generate_plot_ply( if distance is None: raise ValueError( """ - Distance must be specified if an observed_spectrum is given + Distance must be specified if an observed_spectrum is given so that the model spectrum can be converted into flux space correctly. """ ) @@ -1611,22 +1639,23 @@ def generate_plot_ply( y=observed_spectrum_flux.value, name="Observed Spectrum", line={"color": "black", "width": 1.2}, - hoverlabel=dict(namelength=-1), + hoverlabel={"namelength": -1}, hovertemplate="(%{x:.2f}, %{y:.3g})", ) # Plot photosphere - self.fig.add_trace( - go.Scatter( - x=self.plot_wavelength.value, - y=self.photosphere_luminosity.value, - mode="lines", - line=dict(width=1.5, color="red", dash="dash"), - name="Blackbody Photosphere", - hoverlabel=dict(namelength=-1), - hovertemplate="(%{x:.2f}, %{y:.3g})", + if blackbody_photosphere: + self.fig.add_trace( + go.Scatter( + x=self.plot_wavelength.value, + y=self.photosphere_luminosity.value, + mode="lines", + line={"width": 1.5, "color": "red", "dash": "dash"}, + name="Blackbody Photosphere", + hoverlabel={"namelength": -1}, + hovertemplate="(%{x:.2f}, %{y:.3g})", + ) ) - ) self._show_colorbar_ply() @@ -1641,11 +1670,11 @@ def generate_plot_ply( "L_{\\lambda}", u.Unit("erg/(s AA)"), only_text=False ) self.fig.update_layout( - xaxis=dict( - title=xlabel, - exponentformat="none", - ), - yaxis=dict(title=ylabel, exponentformat="e"), + xaxis={ + "title": xlabel, + "exponentformat": "none", + }, + yaxis={"title": ylabel, "exponentformat": "e"}, height=graph_height, ) @@ -1694,7 +1723,7 @@ def _plot_emission_ply(self): name="Electron Scatter Only", fillcolor="#8F8F8F", stackgroup="emission", - hoverlabel=dict(namelength=-1), + hoverlabel={"namelength": -1}, hovertemplate="(%{x:.2f}, %{y:.3g})", ) ) @@ -1724,37 +1753,38 @@ def _plot_emission_ply(self): y=self.emission_luminosities_df[identifier], mode="none", name=species_name + " Emission", - hovertemplate=f"{species_name} Emission " + hovertemplate=f"{species_name:s} Emission
" # noqa: ISC003 + "(%{x:.2f}, %{y:.3g})", fillcolor=self.to_rgb255_string( self._color_list[species_counter] ), stackgroup="emission", showlegend=False, - hoverlabel=dict(namelength=-1), + hoverlabel={"namelength": -1}, ) ) - except: + except KeyError: # Add notifications that this species was not in the emission df if self._species_list is None: - logger.info( + info_msg = ( f"{atomic_number2element_symbol(identifier)}" f" is not in the emitted packets; skipping" ) + logger.info(info_msg) else: # Get the ion number and atomic number for each species ion_number = identifier % 100 atomic_number = (identifier - ion_number) / 100 - logger.info( + info_msg = ( f"{atomic_number2element_symbol(atomic_number)}" f"{int_to_roman(ion_number + 1)}" f" is not in the emitted packets; skipping" ) + logger.info(info_msg) def _plot_absorption_ply(self): """Plot absorption part of the SDEC Plot using plotly.""" - # If 'other' column exists then plot as silver if "other" in self.absorption_luminosities_df.keys(): self.fig.add_trace( @@ -1782,34 +1812,36 @@ def _plot_absorption_ply(self): y=self.absorption_luminosities_df[identifier] * -1, mode="none", name=species_name + " Absorption", - hovertemplate=f"{species_name} Absorption " + hovertemplate=f"{species_name:s} Absorption
" # noqa: ISC003 + "(%{x:.2f}, %{y:.3g})", fillcolor=self.to_rgb255_string( self._color_list[species_counter] ), stackgroup="absorption", showlegend=False, - hoverlabel=dict(namelength=-1), + hoverlabel={"namelength": -1}, ) ) - except: + except KeyError: # Add notifications that this species was not in the emission df if self._species_list is None: - logger.info( + info_msg = ( f"{atomic_number2element_symbol(identifier)}" f" is not in the absorbed packets; skipping" ) + logger.info(info_msg) else: # Get the ion number and atomic number for each species ion_number = identifier % 100 atomic_number = (identifier - ion_number) / 100 - logger.info( + info_msg = ( f"{atomic_number2element_symbol(atomic_number)}" f"{int_to_roman(ion_number + 1)}" f" is not in the absorbed packets; skipping" ) + logger.info(info_msg) def _show_colorbar_ply(self): """Show plotly colorbar with labels of elements mapped to colors.""" @@ -1831,21 +1863,21 @@ def _show_colorbar_ply(self): (colorscale_bins[species_counter + 1], color) ) - coloraxis_options = dict( - colorscale=categorical_colorscale, - showscale=True, - cmin=0, - cmax=len(self._species_name), - colorbar=dict( - title="Elements", - tickvals=np.arange(0, len(self._species_name)) + 0.5, - ticktext=self._species_name, + coloraxis_options = { + "colorscale": categorical_colorscale, + "showscale": True, + "cmin": 0, + "cmax": len(self._species_name), + "colorbar": { + "title": "Elements", + "tickvals": np.arange(0, len(self._species_name)) + 0.5, + "ticktext": self._species_name, # to change length and position of colorbar - len=0.75, - yanchor="top", - y=0.75, - ), - ) + "len": 0.75, + "yanchor": "top", + "y": 0.75, + }, + } # Plot an invisible one point scatter trace, to make colorbar show up scatter_point_idx = pu.get_mid_point_idx(self.plot_wavelength)