Skip to content

Commit

Permalink
add **kwargs to get_plot_gs() to customize ax.plot()
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Dec 10, 2023
1 parent 90a26c5 commit 4d73985
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,29 +996,26 @@ def bs_plot_data(self) -> dict[str, Any]:
"lattice": self._bs.lattice_rec.as_dict(),
}

def get_plot_gs(self, ylim: float | None = None) -> Axes:
"""Get a matplotlib object for the gruneisen bandstructure plot.
def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes:
"""Get a matplotlib object for the Gruneisen bandstructure plot.
Args:
ylim: Specify the y-axis (gruneisen) limits; by default None let
the code choose.
**kwargs: additional keywords passed to ax.plot().
"""
ax = pretty_plot(12, 8)

# band_linewidth = 1
kwargs.setdefault("linewidth", 2)
kwargs.setdefault("marker", "o")
kwargs.setdefault("markersize", 2)

data = self.bs_plot_data()
for d in range(len(data["distances"])):
for i in range(self._nb_bands):
ax.plot(
data["distances"][d],
[data["gruneisen"][d][i][j] for j in range(len(data["distances"][d]))],
"b-",
# linewidth=band_linewidth)
marker="o",
markersize=2,
linewidth=2,
)
for dist_idx in range(len(data["distances"])):
for band_idx in range(self._nb_bands):
ys = [data["gruneisen"][dist_idx][band_idx][idx] for idx in range(len(data["distances"][dist_idx]))]

ax.plot(data["distances"][dist_idx], ys, "b-", **kwargs)

self._make_ticks(ax)

Expand Down

0 comments on commit 4d73985

Please sign in to comment.