Skip to content

Commit

Permalink
Bug fix: SpectrumPlotter.add_spectra (#3529)
Browse files Browse the repository at this point in the history
* update plotters

* update plotters

* update plotters

* add test & assert

* modify assertion

* pre-commit auto-fixes

* refactor test_plotters.py and test PDF writing

* rename wion_symbol to work_ion_symbol and update x-axis label

* drop superfluous img_format="eps" kwarg, specify image format via filename extension instead

* change Assertion to ValueError in SpectrumPlotter.add_spectrum

---------

Co-authored-by: Chiu Peter <[email protected]>
Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2023
1 parent e185bc1 commit f762851
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 46 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/bug_report.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ body:
attributes:
label: Current behavior
description: What bad behavior do you see?
render: python
render: Python
validations:
required: true

Expand All @@ -49,7 +49,7 @@ body:
attributes:
label: Minimal example
description: Please provide a minimal code snippet to reproduce this bug.
render: python
render: Python
validations:
required: false

Expand Down
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/feature_request.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Feature Reqest
name: Feature Request
description: Use this template to request a new feature
body:
- type: textarea
Expand Down
2 changes: 1 addition & 1 deletion docs/apidoc/conf.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 12 additions & 13 deletions pymatgen/apps/battery/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_plot(self, width=8, height=8, term_zero=True, ax: plt.Axes = None):
ax.plot(x, y, "-", linewidth=2, label=label)

ax.legend()
ax.set_xlabel(self._choose_best_x_label(formula=formula, wion_symbol=working_ion_symbols))
ax.set_xlabel(self._choose_best_x_label(formula=formula, work_ion_symbol=working_ion_symbols))
ax.set_ylabel("Voltage (V)")
plt.tight_layout()
return ax
Expand Down Expand Up @@ -155,7 +155,7 @@ def get_plotly_figure(
width=width,
height=height,
font=font_dict,
xaxis={"title": self._choose_best_x_label(formula=formula, wion_symbol=working_ion_symbols)},
xaxis={"title": self._choose_best_x_label(formula=formula, work_ion_symbol=working_ion_symbols)},
yaxis={"title": "Voltage (V)"},
**kwargs,
),
Expand All @@ -164,24 +164,24 @@ def get_plotly_figure(
fig.update_layout(template="plotly_white", title_x=0.5)
return fig

def _choose_best_x_label(self, formula, wion_symbol):
def _choose_best_x_label(self, formula, work_ion_symbol):
if self.xaxis in {"capacity", "capacity_grav"}:
return "Capacity (mAh/g)"
if self.xaxis == "capacity_vol":
return "Capacity (Ah/l)"

formula = formula.pop() if len(formula) == 1 else None

wion_symbol = wion_symbol.pop() if len(wion_symbol) == 1 else None
work_ion_symbol = work_ion_symbol.pop() if len(work_ion_symbol) == 1 else None

if self.xaxis == "x_form":
if formula and wion_symbol:
return f"x in {wion_symbol}<sub>x</sub>{formula}"
return "x Workion Ion per Host F.U."
if formula and work_ion_symbol:
return f"x in {work_ion_symbol}<sub>x</sub>{formula}"
return "x Work Ion per Host F.U."

if self.xaxis == "frac_x":
if wion_symbol:
return f"Atomic Fraction of {wion_symbol}"
if work_ion_symbol:
return f"Atomic Fraction of {work_ion_symbol}"
return "Atomic Fraction of Working Ion"
raise RuntimeError("No xaxis label can be determined")

Expand All @@ -194,13 +194,12 @@ def show(self, width=8, height=6):
"""
self.get_plot(width, height).show()

def save(self, filename, image_format="eps", width=8, height=6):
def save(self, filename: str, width: float = 8, height: float = 6) -> None:
"""Save the plot to an image file.
Args:
filename: Filename to save to.
image_format: Format to save to. Defaults to eps.
filename (str): Filename to save to. Must include extension to specify image format.
width: Width of the plot. Defaults to 8 in.
height: Height of the plot. Defaults to 6 in.
"""
self.get_plot(width, height).savefig(filename, format=image_format)
self.get_plot(width, height).savefig(filename)
21 changes: 9 additions & 12 deletions pymatgen/electronic_structure/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,11 @@ def get_plot(
plt.tight_layout()
return ax

def save_plot(self, filename, img_format="eps", xlim=None, ylim=None, invert_axes=False, beta_dashed=False) -> None:
def save_plot(self, filename: str, xlim=None, ylim=None, invert_axes=False, beta_dashed=False) -> None:
"""Save matplotlib plot to a file.
Args:
filename: Filename to write to.
img_format: Image format to use. Defaults to EPS.
filename (str): Filename to write to. Must include extension to specify image format.
xlim: Specifies the x-axis limits. Set to None for automatic
determination.
ylim: Specifies the y-axis limits.
Expand All @@ -253,7 +252,7 @@ def save_plot(self, filename, img_format="eps", xlim=None, ylim=None, invert_axe
beta_dashed (bool): Plots the beta spin channel with a dashed line. Defaults to False.
"""
self.get_plot(xlim, ylim, invert_axes, beta_dashed)
plt.savefig(filename, format=img_format)
plt.savefig(filename)

def show(self, xlim=None, ylim=None, invert_axes=False, beta_dashed=False) -> None:
"""Show the plot using matplotlib.
Expand Down Expand Up @@ -732,19 +731,18 @@ def show(self, zero_to_efermi=True, ylim=None, smooth=False, smooth_tol=None) ->
self.get_plot(zero_to_efermi, ylim, smooth)
plt.show()

def save_plot(self, filename, img_format="eps", ylim=None, zero_to_efermi=True, smooth=False) -> None:
def save_plot(self, filename: str, ylim=None, zero_to_efermi=True, smooth=False) -> None:
"""Save matplotlib plot to a file.
Args:
filename: Filename to write to.
img_format: Image format to use. Defaults to EPS.
filename (str): Filename to write to. Must include extension to specify image format.
ylim: Specifies the y-axis limits.
zero_to_efermi: Automatically set the Fermi level as the plot's origin (i.e. subtract E - E_f).
Defaults to True.
smooth: Cubic spline interpolation of the bands.
"""
self.get_plot(ylim=ylim, zero_to_efermi=zero_to_efermi, smooth=smooth)
plt.savefig(filename, format=img_format)
plt.savefig(filename)
plt.close()

def get_ticks(self):
Expand Down Expand Up @@ -3791,19 +3789,18 @@ def get_plot(
plt.tight_layout()
return ax

def save_plot(self, filename, img_format="eps", xlim=None, ylim=None) -> None:
def save_plot(self, filename: str, xlim=None, ylim=None) -> None:
"""Save matplotlib plot to a file.
Args:
filename: File name to write to.
img_format: Image format to use. Defaults to EPS.
filename (str): File name to write to. Must include extension to specify image format.
xlim: Specifies the x-axis limits. Defaults to None for
automatic determination.
ylim: Specifies the y-axis limits. Defaults to None for
automatic determination.
"""
self.get_plot(xlim, ylim)
plt.savefig(filename, format=img_format)
plt.savefig(filename)

def show(self, xlim=None, ylim=None) -> None:
"""Show the plot using matplotlib.
Expand Down
16 changes: 9 additions & 7 deletions pymatgen/vis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,24 @@ def add_spectrum(self, label, spectrum, color=None):
a dashed black line. If None, a color will be chosen based on
the default color cycle.
"""
for attribute in "xy":
if not hasattr(spectrum, attribute):
raise ValueError(f"spectrum of type={type(spectrum)} missing required {attribute=}")
self._spectra[label] = spectrum
self.colors.append(color or self.colors_cycle[len(self._spectra) % len(self.colors_cycle)])

def add_spectra(self, spectra_dict, key_sort_func=None):
"""
Add a dictionary of doses, with an optional sorting function for the
Add a dictionary of Spectrum, with an optional sorting function for the
keys.
Args:
dos_dict: dict of {label: Dos}
spectra_dict: dict of {label: Spectrum}
key_sort_func: function used to sort the dos_dict keys.
"""
keys = sorted(spectra_dict, key=key_sort_func) if key_sort_func else list(spectra_dict)
for label in keys:
self.add_spectra(label, spectra_dict[label])
self.add_spectrum(label, spectra_dict[label])

def get_plot(self, xlim=None, ylim=None):
"""
Expand Down Expand Up @@ -125,16 +128,15 @@ def get_plot(self, xlim=None, ylim=None):
plt.tight_layout()
return ax

def save_plot(self, filename, img_format="eps", **kwargs):
def save_plot(self, filename: str, **kwargs):
"""
Save matplotlib plot to a file.
Args:
filename: Filename to write to.
img_format: Image format to use. Defaults to EPS.
filename (str): Filename to write to. Must include extension to specify image format.
"""
self.get_plot(**kwargs)
plt.savefig(filename, format=img_format)
plt.savefig(filename)

def show(self, **kwargs):
"""Show the plot using matplotlib."""
Expand Down
14 changes: 7 additions & 7 deletions tests/apps/battery/test_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class TestVoltageProfilePlotter(unittest.TestCase):
def setUp(self):
entry_Li = ComputedEntry("Li", -1.90753119)

with open(f"{TEST_FILES_DIR}/LiTiO2_batt.json") as f:
entries_LTO = json.load(f, cls=MontyDecoder)
self.ie_LTO = InsertionElectrode.from_entries(entries_LTO, entry_Li)
with open(f"{TEST_FILES_DIR}/LiTiO2_batt.json") as file:
entries_LTO = json.load(file, cls=MontyDecoder)
self.ie_LTO = InsertionElectrode.from_entries(entries_LTO, entry_Li)

with open(f"{TEST_FILES_DIR}/FeF3_batt.json") as fid:
entries = json.load(fid, cls=MontyDecoder)
self.ce_FF = ConversionElectrode.from_composition_and_entries(Composition("FeF3"), entries)
with open(f"{TEST_FILES_DIR}/FeF3_batt.json") as file:
entries = json.load(file, cls=MontyDecoder)
self.ce_FF = ConversionElectrode.from_composition_and_entries(Composition("FeF3"), entries)

def test_name(self):
plotter = VoltageProfilePlotter(xaxis="frac_x")
Expand All @@ -45,4 +45,4 @@ def test_plotly(self):

plotter.add_electrode(self.ie_LTO, "LTO insertion")
fig = plotter.get_plotly_figure()
assert fig.layout.xaxis.title.text == "x Workion Ion per Host F.U."
assert fig.layout.xaxis.title.text == "x Work Ion per Host F.U."
20 changes: 17 additions & 3 deletions tests/vis/test_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest
from pymatgen.vis.plotters import SpectrumPlotter

test_dir = f"{TEST_FILES_DIR}/spectrum_test"

with open(f"{test_dir}/LiCoO2_k_xanes.json") as fp:
with open(f"{TEST_FILES_DIR}/spectrum_test/LiCoO2_k_xanes.json") as fp:
spect_data_dict = json.load(fp, cls=MontyDecoder)


Expand Down Expand Up @@ -43,3 +41,19 @@ def test_get_stacked_plot(self):
ax = self.plotter.get_plot()
assert isinstance(ax, plt.Axes)
assert len(ax.lines) == 0

def test_get_plot_with_add_spectrum(self):
# create spectra_dict
spectra_dict = {"LiCoO2": self.xanes}
xanes = self.xanes.copy()
xanes.y += np.random.randn(len(xanes.y)) * 0.005
spectra_dict["LiCoO2 + noise"] = spectra_dict["LiCoO2 - replot"] = xanes

self.plotter = SpectrumPlotter(yshift=0.2)
self.plotter.add_spectra(spectra_dict)
ax = self.plotter.get_plot()
assert isinstance(ax, plt.Axes)
assert len(ax.lines) == 3
img_path = f"{self.tmp_path}/spectrum_plotter_test2.pdf"
self.plotter.save_plot(img_path)
assert os.path.isfile(img_path)

0 comments on commit f762851

Please sign in to comment.