From cbb44a2352549f6a28e125c841014bf2f4b441dd Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 17 Nov 2023 12:26:53 -0800 Subject: [PATCH] fix PhononDosPlotter.get_plot only ax.set_ylim if relevant_y is non-empty --- pymatgen/core/units.py | 24 ++++++++++++------------ pymatgen/phonon/plotter.py | 3 ++- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/pymatgen/core/units.py b/pymatgen/core/units.py index 94dc97fab8f..6cb3492e71c 100644 --- a/pymatgen/core/units.py +++ b/pymatgen/core/units.py @@ -207,10 +207,10 @@ def __pow__(self, i): def __iter__(self): return iter(self._unit) - def __getitem__(self, i): + def __getitem__(self, i) -> int: return self._unit[i] - def __len__(self): + def __len__(self) -> int: return len(self._unit) def __repr__(self): @@ -254,16 +254,16 @@ def get_conversion_factor(self, new_unit): Args: new_unit: The new unit. """ - uo_base, ofactor = self.as_base_units - un_base, nfactor = Unit(new_unit).as_base_units - units_new = sorted(un_base.items(), key=lambda d: _UNAME2UTYPE[d[0]]) - units_old = sorted(uo_base.items(), key=lambda d: _UNAME2UTYPE[d[0]]) - factor = ofactor / nfactor - for uo, un in zip(units_old, units_new): - if uo[1] != un[1]: - raise UnitError(f"Units {uo} and {un} are not compatible!") - c = ALL_UNITS[_UNAME2UTYPE[uo[0]]] - factor *= (c[uo[0]] / c[un[0]]) ** uo[1] + old_base, old_factor = self.as_base_units + new_base, new_factor = Unit(new_unit).as_base_units + units_new = sorted(new_base.items(), key=lambda d: _UNAME2UTYPE[d[0]]) + units_old = sorted(old_base.items(), key=lambda d: _UNAME2UTYPE[d[0]]) + factor = old_factor / new_factor + for old, new in zip(units_old, units_new): + if old[1] != new[1]: + raise UnitError(f"Units {old} and {new} are not compatible!") + c = ALL_UNITS[_UNAME2UTYPE[old[0]]] + factor *= (c[old[0]] / c[new[0]]) ** old[1] return factor diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index 9daa2d617b8..b032632cfde 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -191,7 +191,8 @@ def get_plot( else: _xlim = ax.set_xlim() relevant_y = [p[1] for p in all_pts if _xlim[0] < p[0] < _xlim[1]] - ax.set_ylim((min(relevant_y), max(relevant_y))) + if len(relevant_y) > 0: + ax.set_ylim((min(relevant_y), max(relevant_y))) ylim = ax.set_ylim() ax.plot([0, 0], ylim, "k--", linewidth=2)