Skip to content

Commit

Permalink
fix PhononDosPlotter.get_plot only ax.set_ylim if relevant_y is non-e…
Browse files Browse the repository at this point in the history
…mpty
  • Loading branch information
janosh committed Nov 17, 2023
1 parent fc7d8db commit cbb44a2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
24 changes: 12 additions & 12 deletions pymatgen/core/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ def __pow__(self, i):
def __iter__(self):
return iter(self._unit)

def __getitem__(self, i):
def __getitem__(self, i) -> int:
return self._unit[i]

def __len__(self):
def __len__(self) -> int:
return len(self._unit)

def __repr__(self):
Expand Down Expand Up @@ -254,16 +254,16 @@ def get_conversion_factor(self, new_unit):
Args:
new_unit: The new unit.
"""
uo_base, ofactor = self.as_base_units
un_base, nfactor = Unit(new_unit).as_base_units
units_new = sorted(un_base.items(), key=lambda d: _UNAME2UTYPE[d[0]])
units_old = sorted(uo_base.items(), key=lambda d: _UNAME2UTYPE[d[0]])
factor = ofactor / nfactor
for uo, un in zip(units_old, units_new):
if uo[1] != un[1]:
raise UnitError(f"Units {uo} and {un} are not compatible!")
c = ALL_UNITS[_UNAME2UTYPE[uo[0]]]
factor *= (c[uo[0]] / c[un[0]]) ** uo[1]
old_base, old_factor = self.as_base_units
new_base, new_factor = Unit(new_unit).as_base_units
units_new = sorted(new_base.items(), key=lambda d: _UNAME2UTYPE[d[0]])
units_old = sorted(old_base.items(), key=lambda d: _UNAME2UTYPE[d[0]])
factor = old_factor / new_factor
for old, new in zip(units_old, units_new):
if old[1] != new[1]:
raise UnitError(f"Units {old} and {new} are not compatible!")
c = ALL_UNITS[_UNAME2UTYPE[old[0]]]
factor *= (c[old[0]] / c[new[0]]) ** old[1]
return factor


Expand Down
3 changes: 2 additions & 1 deletion pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def get_plot(
else:
_xlim = ax.set_xlim()
relevant_y = [p[1] for p in all_pts if _xlim[0] < p[0] < _xlim[1]]
ax.set_ylim((min(relevant_y), max(relevant_y)))
if len(relevant_y) > 0:
ax.set_ylim((min(relevant_y), max(relevant_y)))

ylim = ax.set_ylim()
ax.plot([0, 0], ylim, "k--", linewidth=2)
Expand Down

0 comments on commit cbb44a2

Please sign in to comment.