diff --git a/tardis/io/util.py b/tardis/io/util.py index 6bbbd3b935e..319aebb4ca5 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -138,7 +138,7 @@ def yaml_load_file(filename, loader=yaml.Loader): return yaml.load(stream, Loader=loader) -def parse_species_list(self, species_list, packets_mode, nelements=None): +def parse_species_list(sdec_plotter, data, species_list, packets_mode, nelements=None): """ Parse user requested species list and create list of species ids to be used. @@ -159,14 +159,14 @@ def parse_species_list(self, species_list, packets_mode, nelements=None): If species list contains invalid entries. """ - self.sdec_plotter.parse_species_list(species_list) - self._species_list = self.sdec_plotter._species_list - self._species_mapped = self.sdec_plotter._species_mapped - self._keep_colour = self.sdec_plotter._keep_colour + sdec_plotter.parse_species_list(species_list) + _species_list = sdec_plotter._species_list + _species_mapped = sdec_plotter._species_mapped + _keep_colour = sdec_plotter._keep_colour if nelements: interaction_counts = ( - self.data[packets_mode] + data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] .value_counts() ) @@ -179,8 +179,14 @@ def parse_species_list(self, species_list, packets_mode, nelements=None): atomic_number2element_symbol(element) for element in top_elements ] - self.parse_species_list(top_species_list, packets_mode) - + sub_species_list, sub_species_mapped, sub_keep_colour = parse_species_list( + sdec_plotter, data, top_species_list, packets_mode + ) + _species_list = sub_species_list + _species_mapped = sub_species_mapped + _keep_colour = sub_keep_colour + + return _species_list, _species_mapped, _keep_colour def traverse_configs(base, other, func, *args): """ diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py index cd7b40a466d..723775d0893 100644 --- a/tardis/visualization/tools/liv_plot.py +++ b/tardis/visualization/tools/liv_plot.py @@ -254,7 +254,13 @@ def _prepare_plot_data( f"{atomic_number2element_symbol(specie // 100)}" for specie in species_in_model ] - self._species_list, self._species_mapped = parse_species_list(species_list, packets_mode, nelements) + self._species_list, self._species_mapped, self._keep_colour = parse_species_list( + sdec_plotter=self.sdec_plotter, + data=self.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, + ) species_in_model = np.unique( self.data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] diff --git a/tardis/visualization/tools/tests/test_liv_plot.py b/tardis/visualization/tools/tests/test_liv_plot.py index 52948de9833..b13de734167 100644 --- a/tardis/visualization/tools/tests/test_liv_plot.py +++ b/tardis/visualization/tools/tests/test_liv_plot.py @@ -125,11 +125,22 @@ def test_parse_species_list( attribute: The attribute to test after parsing the species list. """ regression_data = RegressionData(request) - parse_species_list( - packets_mode=self.packets_mode[0], - species_list=self.species_list[0], - nelements=self.nelements[0], + + packets_mode=self.packets_mode[0] + species_list=self.species_list[0] + nelements=self.nelements[0] + + species_list_parsed, species_mapped, keep_colour = parse_species_list( + sdec_plotter=plotter.sdec_plotter, + data=plotter.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, ) + plotter._species_list = species_list_parsed + plotter._species_mapped = species_mapped + plotter._keep_colour = keep_colour + if attribute == "_species_mapped": plot_object = getattr(plotter, attribute) plot_object = [