Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

moved _parse_species_list function to tardis.io.util.py #2837

Closed
wants to merge 9 commits into from
2 changes: 2 additions & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,5 @@ Israel Roldan <[email protected]> AirvZxf <[email protected]
Israel Roldan <[email protected]> airv_zxf <[email protected]>

Michael Zingale <[email protected]>

Rudraksh Nalbalwar <[email protected]>
53 changes: 52 additions & 1 deletion tardis/io/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,57 @@ def yaml_load_file(filename, loader=yaml.Loader):
return yaml.load(stream, Loader=loader)


def parse_species_list(sdec_plotter, data, species_list, packets_mode, nelements=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does it take a plotter argument? What is the purpose of this function?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because it relies on some methods and attributes within sdec_plotter to parse and organize the species list.

"""
Parse user requested species list and create list of species ids to be used.

Parameters
----------
species_list : list of species to plot
List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours.
Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions
(e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca])
packets_mode : str, optional
Packet mode, either 'virtual' or 'real'. Default is 'virtual'.
nelements : int, optional
Number of elements to include in plot. The most interacting elements are included. If None, displays all elements.

Raises
------
ValueError
If species list contains invalid entries.

"""
from tardis.util.base import atomic_number2element_symbol
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this import defined here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is defined inside the function to break the circular dependency between tardis.io.util and tardis.util.base

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and in the below sdec_plot
The parse_species_list function in util.py builds on the sdec_plot.py version by not only parsing species but also filtering them based on interaction frequency in the data. This makes it more data-driven and adaptable for cases where top interacting elements are needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports should be defined at the start of the file right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but importing this statement at the start of the file creates a circular dependency. Specifically, atomic_number2element_symbol is imported from tardis.util.base, and within tardis.util.base, get_internal_data_path is imported from util.py, which triggers a circular loop. Refactoring the functions to avoid this would increase code complexity and could introduce new errors.

sdec_plotter.parse_species_list(species_list)
_species_list = sdec_plotter._species_list
_species_mapped = sdec_plotter._species_mapped
_keep_colour = sdec_plotter._keep_colour

if nelements:
interaction_counts = (
data[packets_mode]
.packets_df_line_interaction["last_line_interaction_species"]
.value_counts()
)
interaction_counts.index = interaction_counts.index // 100
element_counts = interaction_counts.groupby(
interaction_counts.index
).sum()
top_elements = element_counts.nlargest(nelements).index
top_species_list = [
atomic_number2element_symbol(element)
for element in top_elements
]
sub_species_list, sub_species_mapped, sub_keep_colour = parse_species_list(
sdec_plotter, data, top_species_list, packets_mode
)
_species_list = sub_species_list
_species_mapped = sub_species_mapped
_keep_colour = sub_keep_colour

return _species_list, _species_mapped, _keep_colour

def traverse_configs(base, other, func, *args):
"""
Recursively traverse a base dict or list along with another one
Expand All @@ -160,7 +211,7 @@ def traverse_configs(base, other, func, *args):
traverse_configs(base[k], other[k], func, *args)
elif (
isinstance(base, collections_abc.Iterable)
and not isinstance(base, basestring)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this changed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the change was made to ensure that function is compatible with python 3

and not isinstance(base, str)
and not hasattr(base, "shape")
):
for val1, val2 in zip(base, other):
Expand Down
52 changes: 8 additions & 44 deletions tardis/visualization/tools/liv_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
import tardis.visualization.tools.sdec_plot as sdec
from tardis.visualization import plot_util as pu
from tardis.io.util import parse_species_list

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,49 +100,6 @@ def from_hdf(cls, hdf_fpath):
velocity,
)

def _parse_species_list(self, species_list, packets_mode, nelements=None):
"""
Parse user requested species list and create list of species ids to be used.

Parameters
----------
species_list : list of species to plot
List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours.
Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions
(e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca])
packets_mode : str, optional
Packet mode, either 'virtual' or 'real'. Default is 'virtual'.
nelements : int, optional
Number of elements to include in plot. The most interacting elements are included. If None, displays all elements.

Raises
------
ValueError
If species list contains invalid entries.

"""
self.sdec_plotter._parse_species_list(species_list)
self._species_list = self.sdec_plotter._species_list
self._species_mapped = self.sdec_plotter._species_mapped
self._keep_colour = self.sdec_plotter._keep_colour

if nelements:
interaction_counts = (
self.data[packets_mode]
.packets_df_line_interaction["last_line_interaction_species"]
.value_counts()
)
interaction_counts.index = interaction_counts.index // 100
element_counts = interaction_counts.groupby(
interaction_counts.index
).sum()
top_elements = element_counts.nlargest(nelements).index
top_species_list = [
atomic_number2element_symbol(element)
for element in top_elements
]
self._parse_species_list(top_species_list, packets_mode)

def _make_colorbar_labels(self):
"""
Generate labels for the colorbar based on species.
Expand Down Expand Up @@ -296,7 +254,13 @@ def _prepare_plot_data(
f"{atomic_number2element_symbol(specie // 100)}"
for specie in species_in_model
]
self._parse_species_list(species_list, packets_mode, nelements)
self._species_list, self._species_mapped, self._keep_colour = parse_species_list(
sdec_plotter=self.sdec_plotter,
data=self.data,
species_list=species_list,
packets_mode=packets_mode,
nelements=nelements,
)
species_in_model = np.unique(
self.data[packets_mode]
.packets_df_line_interaction["last_line_interaction_species"]
Expand Down
108 changes: 15 additions & 93 deletions tardis/visualization/tools/sdec_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
species_string_to_tuple,
)
from tardis.visualization import plot_util as pu

from tardis.io.util import parse_species_list
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -508,95 +508,7 @@ def from_hdf(cls, hdf_fpath, packets_mode=None):
}
)

def _parse_species_list(self, species_list):
"""
Parse user requested species list and create list of species ids to be used.

Parameters
----------
species_list : list of species to plot
List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours.
Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions
(e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca])

"""
if species_list is not None:
# check if there are any digits in the species list. If there are, then exit.
# species_list should only contain species in the Roman numeral
# format, e.g. Si II, and each ion must contain a space
if any(char.isdigit() for char in " ".join(species_list)) is True:
raise ValueError(
"All species must be in Roman numeral form, e.g. Si II"
)
else:
full_species_list = []
species_mapped = {}
for species in species_list:
# check if a hyphen is present. If it is, then it indicates a
# range of ions. Add each ion in that range to the list as a new entry
if "-" in species:
# split the string on spaces. First thing in the list is then the element
element = species.split(" ")[0]
# Next thing is the ion range
# convert the requested ions into numerals
first_ion_numeral = roman_to_int(
species.split(" ")[-1].split("-")[0]
)
second_ion_numeral = roman_to_int(
species.split(" ")[-1].split("-")[-1]
)
# add each ion between the two requested into the species list
for ion_number in np.arange(
first_ion_numeral, second_ion_numeral + 1
):
full_species_list.append(
f"{element} {int_to_roman(ion_number)}"
)
else:
# Otherwise it's either an element or ion so just add to the list
full_species_list.append(species)

# full_species_list is now a list containing each individual species requested
# e.g. it parses species_list = [Si I - V] into species_list = [Si I, Si II, Si III, Si IV, Si V]
self._full_species_list = full_species_list
requested_species_ids = []
keep_colour = []

# go through each of the requested species. Check whether it is
# an element or ion (ions have spaces). If it is an element,
# add all possible ions to the ions list. Otherwise just add
# the requested ion
for species in full_species_list:
if " " in species:
species_id = (
species_string_to_tuple(species)[0] * 100
+ species_string_to_tuple(species)[1]
)
requested_species_ids.append([species_id])
species_mapped[species_id] = [species_id]
else:
atomic_number = element_symbol2atomic_number(species)
species_ids = [
atomic_number * 100 + ion_number
for ion_number in np.arange(atomic_number)
]
requested_species_ids.append(species_ids)
species_mapped[atomic_number * 100] = species_ids
# add the atomic number to a list so you know that this element should
# have all species in the same colour, i.e. it was requested like
# species_list = [Si]
keep_colour.append(atomic_number)
requested_species_ids = [
species_id
for temp_list in requested_species_ids
for species_id in temp_list
]

self._species_mapped = species_mapped
self._species_list = requested_species_ids
self._keep_colour = keep_colour
else:
self._species_list = None

def _calculate_plotting_data(
self, packets_mode, packet_wvl_range, distance, nelements
Expand Down Expand Up @@ -1210,8 +1122,13 @@ def generate_plot_mpl(
)

# Parse the requested species list
self._parse_species_list(species_list=species_list)

self._species_list, self._species_mapped, self._keep_colour = parse_species_list(
sdec_plotter=self,
data=self.data,
species_list=species_list,
packets_mode=packets_mode,
nelements=nelements,
)
# Calculate data attributes required for plotting
# and save them in instance itself
self._calculate_plotting_data(
Expand Down Expand Up @@ -1610,8 +1527,13 @@ def generate_plot_ply(
)

# Parse the requested species list
self._parse_species_list(species_list=species_list)

self._species_list, self._species_mapped, self._keep_colour = parse_species_list(
sdec_plotter=self,
data=self.data,
species_list=species_list,
packets_mode=packets_mode,
nelements=nelements,
)
# Calculate data attributes required for plotting
# and save them in instance itself
self._calculate_plotting_data(
Expand Down
23 changes: 17 additions & 6 deletions tardis/visualization/tools/tests/test_liv_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tardis.io.util import HDFWriterMixin
from tardis.visualization.tools.liv_plot import LIVPlotter
from tardis.tests.fixtures.regression_data import RegressionData

from tardis.io.util import parse_species_list

class PlotDataHDF(HDFWriterMixin):
"""
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_parse_species_list(
attribute,
):
"""
Test for the _parse_species_list method in LIVPlotter.
Test for the parse_species_list method in LIVPlotter.

Parameters:
-----------
Expand All @@ -125,11 +125,22 @@ def test_parse_species_list(
attribute: The attribute to test after parsing the species list.
"""
regression_data = RegressionData(request)
plotter._parse_species_list(
packets_mode=self.packets_mode[0],
species_list=self.species_list[0],
nelements=self.nelements[0],

packets_mode=self.packets_mode[0]
species_list=self.species_list[0]
nelements=self.nelements[0]

species_list_parsed, species_mapped, keep_colour = parse_species_list(
sdec_plotter=plotter.sdec_plotter,
data=plotter.data,
species_list=species_list,
packets_mode=packets_mode,
nelements=nelements,
)
plotter._species_list = species_list_parsed
plotter._species_mapped = species_mapped
plotter._keep_colour = keep_colour

if attribute == "_species_mapped":
plot_object = getattr(plotter, attribute)
plot_object = [
Expand Down
11 changes: 9 additions & 2 deletions tardis/visualization/tools/tests/test_sdec_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tardis.io.util import HDFWriterMixin
from tardis.tests.fixtures.regression_data import RegressionData
from tardis.visualization.tools.sdec_plot import SDECPlotter
from tardis.io.util import parse_species_list


class PlotDataHDF(HDFWriterMixin):
Expand Down Expand Up @@ -163,7 +164,7 @@ def observed_spectrum(self):
)
def test_parse_species_list(self, request, plotter, attribute):
"""
Test _parse_species_list method.
Test parse_species_list method.

Parameters
----------
Expand All @@ -172,7 +173,13 @@ def test_parse_species_list(self, request, plotter, attribute):
species : list
"""
# THIS NEEDS TO BE RUN FIRST. NOT INDEPENDENT TESTS
plotter._parse_species_list(self.species_list[0])
full_species_list, species_list, keep_colour = parse_species_list(self.species_list[0])

# Set the attributes manually on the plotter for testing
plotter._full_species_list = full_species_list
plotter._species_list = species_list
plotter._keep_colour = keep_colour

regression_data = RegressionData(request)
data = regression_data.sync_ndarray(getattr(plotter, attribute))
if attribute == "_full_species_list":
Expand Down
Loading