diff --git a/tardis/visualization/tools/convergence_plot.py b/tardis/visualization/tools/convergence_plot.py index 16e588f5e76..baac3ea0722 100644 --- a/tardis/visualization/tools/convergence_plot.py +++ b/tardis/visualization/tools/convergence_plot.py @@ -1,7 +1,9 @@ """Convergence Plots to see the convergence of the simulation in real time.""" + from collections import defaultdict import matplotlib.cm as cm import matplotlib.colors as clr +import numpy as np import plotly.graph_objects as go from IPython.display import display import matplotlib as mpl @@ -330,8 +332,11 @@ def update_plasma_plots(self): # add a radiation temperature vs shell velocity trace to the plasma plot self.plasma_plot.add_scatter( x=velocity_km_s, - y=self.iterable_data["t_rad"], + y=np.append( + self.iterable_data["t_rad"], self.iterable_data["t_rad"][-1:] + ), line_color=self.plasma_colorscale[self.current_iteration - 1], + line_shape="hv", row=1, col=1, name=self.current_iteration, @@ -344,8 +349,9 @@ def update_plasma_plots(self): # add a dilution factor vs shell velocity trace to the plasma plot self.plasma_plot.add_scatter( x=velocity_km_s, - y=self.iterable_data["w"], + y=np.append(self.iterable_data["w"], self.iterable_data["w"][-1:]), line_color=self.plasma_colorscale[self.current_iteration - 1], + line_shape="hv", row=1, col=2, legendgroup=f"group-{self.current_iteration}", diff --git a/tardis/visualization/tools/tests/test_convergence_plot.py b/tardis/visualization/tools/tests/test_convergence_plot.py index 79751b15c72..bb69e23e68b 100644 --- a/tardis/visualization/tools/tests/test_convergence_plot.py +++ b/tardis/visualization/tools/tests/test_convergence_plot.py @@ -1,4 +1,5 @@ """Tests for Convergence Plots.""" + from copy import deepcopy import pytest @@ -143,7 +144,9 @@ def test_update_plasma_plots(convergence_plots): # check values for t_rad subplot assert convergence_plots.plasma_plot.data[index].xaxis == "x" assert convergence_plots.plasma_plot.data[index].yaxis == "y" - assert convergence_plots.plasma_plot.data[index].y == tuple(t_rad_val) + assert ( + convergence_plots.plasma_plot.data[index].y[:-1] == tuple(t_rad_val) + ).all() assert convergence_plots.plasma_plot.data[index].x == tuple( velocity.to(u.km / u.s).value ) @@ -152,7 +155,9 @@ def test_update_plasma_plots(convergence_plots): # check values for w subplot assert convergence_plots.plasma_plot.data[index].xaxis == "x2" assert convergence_plots.plasma_plot.data[index].yaxis == "y2" - assert convergence_plots.plasma_plot.data[index].y == tuple(w_val) + assert ( + convergence_plots.plasma_plot.data[index].y[:-1] == tuple(w_val) + ).all() assert convergence_plots.plasma_plot.data[index].x == tuple( velocity.to(u.km / u.s).value )