From 11a771a0290f270d67e7c08403b4cc6bf22db464 Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Wed, 2 Oct 2024 13:23:10 +0530 Subject: [PATCH 1/9] moved _parse_species_list function to util.py --- .mailmap | 2 ++ tardis/io/util.py | 45 ++++++++++++++++++++++++++ tardis/visualization/tools/liv_plot.py | 43 ------------------------ 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/.mailmap b/.mailmap index c10dde1a38c..bfc5b0172b9 100644 --- a/.mailmap +++ b/.mailmap @@ -283,3 +283,5 @@ Israel Roldan AirvZxf airv_zxf Michael Zingale + +Rudraksh Nalbalwar diff --git a/tardis/io/util.py b/tardis/io/util.py index e0cdc1079c0..8f0bad03aff 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -138,6 +138,51 @@ def yaml_load_file(filename, loader=yaml.Loader): return yaml.load(stream, Loader=loader) +def _parse_species_list(self, species_list, packets_mode, nelements=None): + """ + Parse user requested species list and create list of species ids to be used. + + Parameters + ---------- + species_list : list of species to plot + List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. + Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions + (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) + packets_mode : str, optional + Packet mode, either 'virtual' or 'real'. Default is 'virtual'. + nelements : int, optional + Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. + + Raises + ------ + ValueError + If species list contains invalid entries. + + """ + from tardis.util.base import atomic_number2element_symbol + self.sdec_plotter._parse_species_list(species_list) + self._species_list = self.sdec_plotter._species_list + self._species_mapped = self.sdec_plotter._species_mapped + self._keep_colour = self.sdec_plotter._keep_colour + + if nelements: + interaction_counts = ( + self.data[packets_mode] + .packets_df_line_interaction["last_line_interaction_species"] + .value_counts() + ) + interaction_counts.index = interaction_counts.index // 100 + element_counts = interaction_counts.groupby( + interaction_counts.index + ).sum() + top_elements = element_counts.nlargest(nelements).index + top_species_list = [ + atomic_number2element_symbol(element) + for element in top_elements + ] + self._parse_species_list(top_species_list, packets_mode) + + def traverse_configs(base, other, func, *args): """ Recursively traverse a base dict or list along with another one diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py index 0b88dd975cc..0b7d972055f 100644 --- a/tardis/visualization/tools/liv_plot.py +++ b/tardis/visualization/tools/liv_plot.py @@ -99,49 +99,6 @@ def from_hdf(cls, hdf_fpath): velocity, ) - def _parse_species_list(self, species_list, packets_mode, nelements=None): - """ - Parse user requested species list and create list of species ids to be used. - - Parameters - ---------- - species_list : list of species to plot - List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. - Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions - (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) - packets_mode : str, optional - Packet mode, either 'virtual' or 'real'. Default is 'virtual'. - nelements : int, optional - Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. - - Raises - ------ - ValueError - If species list contains invalid entries. - - """ - self.sdec_plotter._parse_species_list(species_list) - self._species_list = self.sdec_plotter._species_list - self._species_mapped = self.sdec_plotter._species_mapped - self._keep_colour = self.sdec_plotter._keep_colour - - if nelements: - interaction_counts = ( - self.data[packets_mode] - .packets_df_line_interaction["last_line_interaction_species"] - .value_counts() - ) - interaction_counts.index = interaction_counts.index // 100 - element_counts = interaction_counts.groupby( - interaction_counts.index - ).sum() - top_elements = element_counts.nlargest(nelements).index - top_species_list = [ - atomic_number2element_symbol(element) - for element in top_elements - ] - self._parse_species_list(top_species_list, packets_mode) - def _make_colorbar_labels(self): """ Generate labels for the colorbar based on species. From c7d0f5944e094f191d0299b3cc3ac8f5af20ad52 Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Wed, 2 Oct 2024 14:18:10 +0530 Subject: [PATCH 2/9] Updated parse_species_list method and other changes --- tardis/io/util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tardis/io/util.py b/tardis/io/util.py index 8f0bad03aff..1e4fc93fa2f 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -138,7 +138,7 @@ def yaml_load_file(filename, loader=yaml.Loader): return yaml.load(stream, Loader=loader) -def _parse_species_list(self, species_list, packets_mode, nelements=None): +def parse_species_list(self, species_list, packets_mode, nelements=None): """ Parse user requested species list and create list of species ids to be used. @@ -159,7 +159,6 @@ def _parse_species_list(self, species_list, packets_mode, nelements=None): If species list contains invalid entries. """ - from tardis.util.base import atomic_number2element_symbol self.sdec_plotter._parse_species_list(species_list) self._species_list = self.sdec_plotter._species_list self._species_mapped = self.sdec_plotter._species_mapped From 493380f81c14bf1c1e3a7ee0c152895af46136e8 Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Wed, 2 Oct 2024 14:22:22 +0530 Subject: [PATCH 3/9] updated --- tardis/io/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tardis/io/util.py b/tardis/io/util.py index 1e4fc93fa2f..0f9286a8e55 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -159,7 +159,7 @@ def parse_species_list(self, species_list, packets_mode, nelements=None): If species list contains invalid entries. """ - self.sdec_plotter._parse_species_list(species_list) + self.sdec_plotter.parse_species_list(species_list) self._species_list = self.sdec_plotter._species_list self._species_mapped = self.sdec_plotter._species_mapped self._keep_colour = self.sdec_plotter._keep_colour @@ -179,7 +179,7 @@ def parse_species_list(self, species_list, packets_mode, nelements=None): atomic_number2element_symbol(element) for element in top_elements ] - self._parse_species_list(top_species_list, packets_mode) + self.parse_species_list(top_species_list, packets_mode) def traverse_configs(base, other, func, *args): From 73a567877d0fb6c18f5fd92534ed3363f593d788 Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Tue, 8 Oct 2024 23:03:58 +0530 Subject: [PATCH 4/9] updated changes in function calls --- tardis/visualization/tools/liv_plot.py | 2 +- tardis/visualization/tools/sdec_plot.py | 6 +++--- tardis/visualization/tools/tests/test_liv_plot.py | 4 ++-- tardis/visualization/tools/tests/test_sdec_plot.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py index 0b7d972055f..ac082586224 100644 --- a/tardis/visualization/tools/liv_plot.py +++ b/tardis/visualization/tools/liv_plot.py @@ -253,7 +253,7 @@ def _prepare_plot_data( f"{atomic_number2element_symbol(specie // 100)}" for specie in species_in_model ] - self._parse_species_list(species_list, packets_mode, nelements) + self.parse_species_list(species_list, packets_mode, nelements) species_in_model = np.unique( self.data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index 2ca12fbb06f..8a25700417e 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -508,7 +508,7 @@ def from_hdf(cls, hdf_fpath, packets_mode=None): } ) - def _parse_species_list(self, species_list): + def parse_species_list(self, species_list): """ Parse user requested species list and create list of species ids to be used. @@ -1210,7 +1210,7 @@ def generate_plot_mpl( ) # Parse the requested species list - self._parse_species_list(species_list=species_list) + self.parse_species_list(species_list=species_list) # Calculate data attributes required for plotting # and save them in instance itself @@ -1610,7 +1610,7 @@ def generate_plot_ply( ) # Parse the requested species list - self._parse_species_list(species_list=species_list) + self.parse_species_list(species_list=species_list) # Calculate data attributes required for plotting # and save them in instance itself diff --git a/tardis/visualization/tools/tests/test_liv_plot.py b/tardis/visualization/tools/tests/test_liv_plot.py index 5a4f56897bc..621ecb3b78d 100644 --- a/tardis/visualization/tools/tests/test_liv_plot.py +++ b/tardis/visualization/tools/tests/test_liv_plot.py @@ -116,7 +116,7 @@ def test_parse_species_list( attribute, ): """ - Test for the _parse_species_list method in LIVPlotter. + Test for the parse_species_list method in LIVPlotter. Parameters: ----------- @@ -125,7 +125,7 @@ def test_parse_species_list( attribute: The attribute to test after parsing the species list. """ regression_data = RegressionData(request) - plotter._parse_species_list( + plotter.parse_species_list( packets_mode=self.packets_mode[0], species_list=self.species_list[0], nelements=self.nelements[0], diff --git a/tardis/visualization/tools/tests/test_sdec_plot.py b/tardis/visualization/tools/tests/test_sdec_plot.py index 1136ba7e148..6716fe9a2d8 100644 --- a/tardis/visualization/tools/tests/test_sdec_plot.py +++ b/tardis/visualization/tools/tests/test_sdec_plot.py @@ -163,7 +163,7 @@ def observed_spectrum(self): ) def test_parse_species_list(self, request, plotter, attribute): """ - Test _parse_species_list method. + Test parse_species_list method. Parameters ---------- @@ -172,7 +172,7 @@ def test_parse_species_list(self, request, plotter, attribute): species : list """ # THIS NEEDS TO BE RUN FIRST. NOT INDEPENDENT TESTS - plotter._parse_species_list(self.species_list[0]) + plotter.parse_species_list(self.species_list[0]) regression_data = RegressionData(request) data = regression_data.sync_ndarray(getattr(plotter, attribute)) if attribute == "_full_species_list": From d15a8e82341ab9969c994863685574dd49aa5377 Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Sat, 12 Oct 2024 15:38:51 +0530 Subject: [PATCH 5/9] corrected --- tardis/io/util.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tardis/io/util.py b/tardis/io/util.py index 0f9286a8e55..321ac8d62a0 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -17,6 +17,8 @@ from tardis import __path__ as TARDIS_PATH from tardis import constants as const +from tardis.util.base import atomic_number2element_symbol + logger = logging.getLogger(__name__) @@ -204,7 +206,7 @@ def traverse_configs(base, other, func, *args): traverse_configs(base[k], other[k], func, *args) elif ( isinstance(base, collections_abc.Iterable) - and not isinstance(base, basestring) + and not isinstance(base, str) and not hasattr(base, "shape") ): for val1, val2 in zip(base, other): From 23fff3e7be121882627e3ebeefe8df699478c0d6 Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Sat, 12 Oct 2024 16:36:59 +0530 Subject: [PATCH 6/9] fn-call --- tardis/io/util.py | 2 -- tardis/visualization/tools/liv_plot.py | 3 ++- tardis/visualization/tools/tests/test_liv_plot.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tardis/io/util.py b/tardis/io/util.py index 321ac8d62a0..6bbbd3b935e 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -17,8 +17,6 @@ from tardis import __path__ as TARDIS_PATH from tardis import constants as const -from tardis.util.base import atomic_number2element_symbol - logger = logging.getLogger(__name__) diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py index ac082586224..cd7b40a466d 100644 --- a/tardis/visualization/tools/liv_plot.py +++ b/tardis/visualization/tools/liv_plot.py @@ -12,6 +12,7 @@ ) import tardis.visualization.tools.sdec_plot as sdec from tardis.visualization import plot_util as pu +from tardis.io.util import parse_species_list logger = logging.getLogger(__name__) @@ -253,7 +254,7 @@ def _prepare_plot_data( f"{atomic_number2element_symbol(specie // 100)}" for specie in species_in_model ] - self.parse_species_list(species_list, packets_mode, nelements) + self._species_list, self._species_mapped = parse_species_list(species_list, packets_mode, nelements) species_in_model = np.unique( self.data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] diff --git a/tardis/visualization/tools/tests/test_liv_plot.py b/tardis/visualization/tools/tests/test_liv_plot.py index 621ecb3b78d..52948de9833 100644 --- a/tardis/visualization/tools/tests/test_liv_plot.py +++ b/tardis/visualization/tools/tests/test_liv_plot.py @@ -12,7 +12,7 @@ from tardis.io.util import HDFWriterMixin from tardis.visualization.tools.liv_plot import LIVPlotter from tardis.tests.fixtures.regression_data import RegressionData - +from tardis.io.util import parse_species_list class PlotDataHDF(HDFWriterMixin): """ @@ -125,7 +125,7 @@ def test_parse_species_list( attribute: The attribute to test after parsing the species list. """ regression_data = RegressionData(request) - plotter.parse_species_list( + parse_species_list( packets_mode=self.packets_mode[0], species_list=self.species_list[0], nelements=self.nelements[0], From a265f30e3f4e44d3259e9aeb05e8bfcc0fc6ee32 Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Mon, 14 Oct 2024 11:58:15 +0530 Subject: [PATCH 7/9] refactor --- tardis/io/util.py | 22 ++++++++++++------- tardis/visualization/tools/liv_plot.py | 8 ++++++- .../tools/tests/test_liv_plot.py | 19 ++++++++++++---- 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/tardis/io/util.py b/tardis/io/util.py index 6bbbd3b935e..319aebb4ca5 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -138,7 +138,7 @@ def yaml_load_file(filename, loader=yaml.Loader): return yaml.load(stream, Loader=loader) -def parse_species_list(self, species_list, packets_mode, nelements=None): +def parse_species_list(sdec_plotter, data, species_list, packets_mode, nelements=None): """ Parse user requested species list and create list of species ids to be used. @@ -159,14 +159,14 @@ def parse_species_list(self, species_list, packets_mode, nelements=None): If species list contains invalid entries. """ - self.sdec_plotter.parse_species_list(species_list) - self._species_list = self.sdec_plotter._species_list - self._species_mapped = self.sdec_plotter._species_mapped - self._keep_colour = self.sdec_plotter._keep_colour + sdec_plotter.parse_species_list(species_list) + _species_list = sdec_plotter._species_list + _species_mapped = sdec_plotter._species_mapped + _keep_colour = sdec_plotter._keep_colour if nelements: interaction_counts = ( - self.data[packets_mode] + data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] .value_counts() ) @@ -179,8 +179,14 @@ def parse_species_list(self, species_list, packets_mode, nelements=None): atomic_number2element_symbol(element) for element in top_elements ] - self.parse_species_list(top_species_list, packets_mode) - + sub_species_list, sub_species_mapped, sub_keep_colour = parse_species_list( + sdec_plotter, data, top_species_list, packets_mode + ) + _species_list = sub_species_list + _species_mapped = sub_species_mapped + _keep_colour = sub_keep_colour + + return _species_list, _species_mapped, _keep_colour def traverse_configs(base, other, func, *args): """ diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py index cd7b40a466d..723775d0893 100644 --- a/tardis/visualization/tools/liv_plot.py +++ b/tardis/visualization/tools/liv_plot.py @@ -254,7 +254,13 @@ def _prepare_plot_data( f"{atomic_number2element_symbol(specie // 100)}" for specie in species_in_model ] - self._species_list, self._species_mapped = parse_species_list(species_list, packets_mode, nelements) + self._species_list, self._species_mapped, self._keep_colour = parse_species_list( + sdec_plotter=self.sdec_plotter, + data=self.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, + ) species_in_model = np.unique( self.data[packets_mode] .packets_df_line_interaction["last_line_interaction_species"] diff --git a/tardis/visualization/tools/tests/test_liv_plot.py b/tardis/visualization/tools/tests/test_liv_plot.py index 52948de9833..b13de734167 100644 --- a/tardis/visualization/tools/tests/test_liv_plot.py +++ b/tardis/visualization/tools/tests/test_liv_plot.py @@ -125,11 +125,22 @@ def test_parse_species_list( attribute: The attribute to test after parsing the species list. """ regression_data = RegressionData(request) - parse_species_list( - packets_mode=self.packets_mode[0], - species_list=self.species_list[0], - nelements=self.nelements[0], + + packets_mode=self.packets_mode[0] + species_list=self.species_list[0] + nelements=self.nelements[0] + + species_list_parsed, species_mapped, keep_colour = parse_species_list( + sdec_plotter=plotter.sdec_plotter, + data=plotter.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, ) + plotter._species_list = species_list_parsed + plotter._species_mapped = species_mapped + plotter._keep_colour = keep_colour + if attribute == "_species_mapped": plot_object = getattr(plotter, attribute) plot_object = [ From f3b31ddf4ea1ebd3ec8e677e08c81a22f1e57948 Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Sat, 2 Nov 2024 18:26:33 +0530 Subject: [PATCH 8/9] - Moved 'atomic_number2element_symbol' import inside the parse_speciees_list function in `tardis/io/util.py` --- tardis/io/util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tardis/io/util.py b/tardis/io/util.py index 319aebb4ca5..cd269fd688b 100644 --- a/tardis/io/util.py +++ b/tardis/io/util.py @@ -159,6 +159,7 @@ def parse_species_list(sdec_plotter, data, species_list, packets_mode, nelements If species list contains invalid entries. """ + from tardis.util.base import atomic_number2element_symbol sdec_plotter.parse_species_list(species_list) _species_list = sdec_plotter._species_list _species_mapped = sdec_plotter._species_mapped From 8b00774ed53b81498e21f40e200c3207d947d10d Mon Sep 17 00:00:00 2001 From: Rudraksh Nalbalwar Date: Tue, 19 Nov 2024 00:44:01 +0530 Subject: [PATCH 9/9] refactored test_sdec_plot.py and sdec_plot.py to adapt for refactored parse_species_list --- tardis/visualization/tools/sdec_plot.py | 108 +++--------------- .../tools/tests/test_sdec_plot.py | 9 +- 2 files changed, 23 insertions(+), 94 deletions(-) diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index 8a25700417e..e772de5fcf1 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -24,7 +24,7 @@ species_string_to_tuple, ) from tardis.visualization import plot_util as pu - +from tardis.io.util import parse_species_list logger = logging.getLogger(__name__) @@ -508,95 +508,7 @@ def from_hdf(cls, hdf_fpath, packets_mode=None): } ) - def parse_species_list(self, species_list): - """ - Parse user requested species list and create list of species ids to be used. - - Parameters - ---------- - species_list : list of species to plot - List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. - Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions - (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) - - """ - if species_list is not None: - # check if there are any digits in the species list. If there are, then exit. - # species_list should only contain species in the Roman numeral - # format, e.g. Si II, and each ion must contain a space - if any(char.isdigit() for char in " ".join(species_list)) is True: - raise ValueError( - "All species must be in Roman numeral form, e.g. Si II" - ) - else: - full_species_list = [] - species_mapped = {} - for species in species_list: - # check if a hyphen is present. If it is, then it indicates a - # range of ions. Add each ion in that range to the list as a new entry - if "-" in species: - # split the string on spaces. First thing in the list is then the element - element = species.split(" ")[0] - # Next thing is the ion range - # convert the requested ions into numerals - first_ion_numeral = roman_to_int( - species.split(" ")[-1].split("-")[0] - ) - second_ion_numeral = roman_to_int( - species.split(" ")[-1].split("-")[-1] - ) - # add each ion between the two requested into the species list - for ion_number in np.arange( - first_ion_numeral, second_ion_numeral + 1 - ): - full_species_list.append( - f"{element} {int_to_roman(ion_number)}" - ) - else: - # Otherwise it's either an element or ion so just add to the list - full_species_list.append(species) - - # full_species_list is now a list containing each individual species requested - # e.g. it parses species_list = [Si I - V] into species_list = [Si I, Si II, Si III, Si IV, Si V] - self._full_species_list = full_species_list - requested_species_ids = [] - keep_colour = [] - - # go through each of the requested species. Check whether it is - # an element or ion (ions have spaces). If it is an element, - # add all possible ions to the ions list. Otherwise just add - # the requested ion - for species in full_species_list: - if " " in species: - species_id = ( - species_string_to_tuple(species)[0] * 100 - + species_string_to_tuple(species)[1] - ) - requested_species_ids.append([species_id]) - species_mapped[species_id] = [species_id] - else: - atomic_number = element_symbol2atomic_number(species) - species_ids = [ - atomic_number * 100 + ion_number - for ion_number in np.arange(atomic_number) - ] - requested_species_ids.append(species_ids) - species_mapped[atomic_number * 100] = species_ids - # add the atomic number to a list so you know that this element should - # have all species in the same colour, i.e. it was requested like - # species_list = [Si] - keep_colour.append(atomic_number) - requested_species_ids = [ - species_id - for temp_list in requested_species_ids - for species_id in temp_list - ] - self._species_mapped = species_mapped - self._species_list = requested_species_ids - self._keep_colour = keep_colour - else: - self._species_list = None def _calculate_plotting_data( self, packets_mode, packet_wvl_range, distance, nelements @@ -1210,8 +1122,13 @@ def generate_plot_mpl( ) # Parse the requested species list - self.parse_species_list(species_list=species_list) - + self._species_list, self._species_mapped, self._keep_colour = parse_species_list( + sdec_plotter=self, + data=self.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, + ) # Calculate data attributes required for plotting # and save them in instance itself self._calculate_plotting_data( @@ -1610,8 +1527,13 @@ def generate_plot_ply( ) # Parse the requested species list - self.parse_species_list(species_list=species_list) - + self._species_list, self._species_mapped, self._keep_colour = parse_species_list( + sdec_plotter=self, + data=self.data, + species_list=species_list, + packets_mode=packets_mode, + nelements=nelements, + ) # Calculate data attributes required for plotting # and save them in instance itself self._calculate_plotting_data( diff --git a/tardis/visualization/tools/tests/test_sdec_plot.py b/tardis/visualization/tools/tests/test_sdec_plot.py index 6716fe9a2d8..8aabf3ebcba 100644 --- a/tardis/visualization/tools/tests/test_sdec_plot.py +++ b/tardis/visualization/tools/tests/test_sdec_plot.py @@ -15,6 +15,7 @@ from tardis.io.util import HDFWriterMixin from tardis.tests.fixtures.regression_data import RegressionData from tardis.visualization.tools.sdec_plot import SDECPlotter +from tardis.io.util import parse_species_list class PlotDataHDF(HDFWriterMixin): @@ -172,7 +173,13 @@ def test_parse_species_list(self, request, plotter, attribute): species : list """ # THIS NEEDS TO BE RUN FIRST. NOT INDEPENDENT TESTS - plotter.parse_species_list(self.species_list[0]) + full_species_list, species_list, keep_colour = parse_species_list(self.species_list[0]) + + # Set the attributes manually on the plotter for testing + plotter._full_species_list = full_species_list + plotter._species_list = species_list + plotter._keep_colour = keep_colour + regression_data = RegressionData(request) data = regression_data.sync_ndarray(getattr(plotter, attribute)) if attribute == "_full_species_list":