Skip to content

Commit

Permalink
Convergence plot correction: Use step plots (#2528)
Browse files Browse the repository at this point in the history
* Change convergence plots to step plots

* - Add dummy x-point while plotting convergence plots, to draw the last step correctly
- Rebase master
- Run black formatter on convergence_plot.py

* Fix convergence plot tests

* Run black formatter
  • Loading branch information
sarthak-dv authored Apr 24, 2024
1 parent b8788a1 commit 215ad61
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
10 changes: 8 additions & 2 deletions tardis/visualization/tools/convergence_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Convergence Plots to see the convergence of the simulation in real time."""

from collections import defaultdict
import matplotlib.cm as cm
import matplotlib.colors as clr
import numpy as np
import plotly.graph_objects as go
from IPython.display import display
import matplotlib as mpl
Expand Down Expand Up @@ -330,8 +332,11 @@ def update_plasma_plots(self):
# add a radiation temperature vs shell velocity trace to the plasma plot
self.plasma_plot.add_scatter(
x=velocity_km_s,
y=self.iterable_data["t_rad"],
y=np.append(
self.iterable_data["t_rad"], self.iterable_data["t_rad"][-1:]
),
line_color=self.plasma_colorscale[self.current_iteration - 1],
line_shape="hv",
row=1,
col=1,
name=self.current_iteration,
Expand All @@ -344,8 +349,9 @@ def update_plasma_plots(self):
# add a dilution factor vs shell velocity trace to the plasma plot
self.plasma_plot.add_scatter(
x=velocity_km_s,
y=self.iterable_data["w"],
y=np.append(self.iterable_data["w"], self.iterable_data["w"][-1:]),
line_color=self.plasma_colorscale[self.current_iteration - 1],
line_shape="hv",
row=1,
col=2,
legendgroup=f"group-{self.current_iteration}",
Expand Down
9 changes: 7 additions & 2 deletions tardis/visualization/tools/tests/test_convergence_plot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for Convergence Plots."""

from copy import deepcopy

import pytest
Expand Down Expand Up @@ -143,7 +144,9 @@ def test_update_plasma_plots(convergence_plots):
# check values for t_rad subplot
assert convergence_plots.plasma_plot.data[index].xaxis == "x"
assert convergence_plots.plasma_plot.data[index].yaxis == "y"
assert convergence_plots.plasma_plot.data[index].y == tuple(t_rad_val)
assert (
convergence_plots.plasma_plot.data[index].y[:-1] == tuple(t_rad_val)
).all()
assert convergence_plots.plasma_plot.data[index].x == tuple(
velocity.to(u.km / u.s).value
)
Expand All @@ -152,7 +155,9 @@ def test_update_plasma_plots(convergence_plots):
# check values for w subplot
assert convergence_plots.plasma_plot.data[index].xaxis == "x2"
assert convergence_plots.plasma_plot.data[index].yaxis == "y2"
assert convergence_plots.plasma_plot.data[index].y == tuple(w_val)
assert (
convergence_plots.plasma_plot.data[index].y[:-1] == tuple(w_val)
).all()
assert convergence_plots.plasma_plot.data[index].x == tuple(
velocity.to(u.km / u.s).value
)
Expand Down

0 comments on commit 215ad61

Please sign in to comment.