diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index 5dc2221398f..4dd5ea3b0a0 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -996,29 +996,26 @@ def bs_plot_data(self) -> dict[str, Any]: "lattice": self._bs.lattice_rec.as_dict(), } - def get_plot_gs(self, ylim: float | None = None) -> Axes: - """Get a matplotlib object for the gruneisen bandstructure plot. + def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes: + """Get a matplotlib object for the Gruneisen bandstructure plot. Args: ylim: Specify the y-axis (gruneisen) limits; by default None let the code choose. + **kwargs: additional keywords passed to ax.plot(). """ ax = pretty_plot(12, 8) - # band_linewidth = 1 + kwargs.setdefault("linewidth", 2) + kwargs.setdefault("marker", "o") + kwargs.setdefault("markersize", 2) data = self.bs_plot_data() - for d in range(len(data["distances"])): - for i in range(self._nb_bands): - ax.plot( - data["distances"][d], - [data["gruneisen"][d][i][j] for j in range(len(data["distances"][d]))], - "b-", - # linewidth=band_linewidth) - marker="o", - markersize=2, - linewidth=2, - ) + for dist_idx in range(len(data["distances"])): + for band_idx in range(self._nb_bands): + ys = [data["gruneisen"][dist_idx][band_idx][idx] for idx in range(len(data["distances"][dist_idx]))] + + ax.plot(data["distances"][dist_idx], ys, "b-", **kwargs) self._make_ticks(ax)