Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rudrakshnalbalwar committed Oct 14, 2024
1 parent 23fff3e commit a265f30
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
22 changes: 14 additions & 8 deletions tardis/io/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def yaml_load_file(filename, loader=yaml.Loader):
return yaml.load(stream, Loader=loader)


def parse_species_list(self, species_list, packets_mode, nelements=None):
def parse_species_list(sdec_plotter, data, species_list, packets_mode, nelements=None):
"""
Parse user requested species list and create list of species ids to be used.
Expand All @@ -159,14 +159,14 @@ def parse_species_list(self, species_list, packets_mode, nelements=None):
If species list contains invalid entries.
"""
self.sdec_plotter.parse_species_list(species_list)
self._species_list = self.sdec_plotter._species_list
self._species_mapped = self.sdec_plotter._species_mapped
self._keep_colour = self.sdec_plotter._keep_colour
sdec_plotter.parse_species_list(species_list)
_species_list = sdec_plotter._species_list
_species_mapped = sdec_plotter._species_mapped
_keep_colour = sdec_plotter._keep_colour

if nelements:
interaction_counts = (
self.data[packets_mode]
data[packets_mode]
.packets_df_line_interaction["last_line_interaction_species"]
.value_counts()
)
Expand All @@ -179,8 +179,14 @@ def parse_species_list(self, species_list, packets_mode, nelements=None):
atomic_number2element_symbol(element)
for element in top_elements
]
self.parse_species_list(top_species_list, packets_mode)

sub_species_list, sub_species_mapped, sub_keep_colour = parse_species_list(
sdec_plotter, data, top_species_list, packets_mode
)
_species_list = sub_species_list
_species_mapped = sub_species_mapped
_keep_colour = sub_keep_colour

return _species_list, _species_mapped, _keep_colour

def traverse_configs(base, other, func, *args):
"""
Expand Down
8 changes: 7 additions & 1 deletion tardis/visualization/tools/liv_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,13 @@ def _prepare_plot_data(
f"{atomic_number2element_symbol(specie // 100)}"
for specie in species_in_model
]
self._species_list, self._species_mapped = parse_species_list(species_list, packets_mode, nelements)
self._species_list, self._species_mapped, self._keep_colour = parse_species_list(
sdec_plotter=self.sdec_plotter,
data=self.data,
species_list=species_list,
packets_mode=packets_mode,
nelements=nelements,
)
species_in_model = np.unique(
self.data[packets_mode]
.packets_df_line_interaction["last_line_interaction_species"]
Expand Down
19 changes: 15 additions & 4 deletions tardis/visualization/tools/tests/test_liv_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,22 @@ def test_parse_species_list(
attribute: The attribute to test after parsing the species list.
"""
regression_data = RegressionData(request)
parse_species_list(
packets_mode=self.packets_mode[0],
species_list=self.species_list[0],
nelements=self.nelements[0],

packets_mode=self.packets_mode[0]
species_list=self.species_list[0]
nelements=self.nelements[0]

species_list_parsed, species_mapped, keep_colour = parse_species_list(
sdec_plotter=plotter.sdec_plotter,
data=plotter.data,
species_list=species_list,
packets_mode=packets_mode,
nelements=nelements,
)
plotter._species_list = species_list_parsed
plotter._species_mapped = species_mapped
plotter._keep_colour = keep_colour

if attribute == "_species_mapped":
plot_object = getattr(plotter, attribute)
plot_object = [
Expand Down

0 comments on commit a265f30

Please sign in to comment.