Skip to content

Commit

Permalink
change default line colors of PhononDosPlotter and PhononBSPlotter to…
Browse files Browse the repository at this point in the history
… tab:10

tab:blue and tab:orange in particular
  • Loading branch information
janosh committed Dec 13, 2023
1 parent cb7f2bc commit 64cd6ce
Showing 1 changed file with 24 additions and 42 deletions.
66 changes: 24 additions & 42 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import matplotlib.pyplot as plt
import numpy as np
import palettable
import scipy.constants as const
from matplotlib.collections import LineCollection
from monty.json import jsanitize
Expand Down Expand Up @@ -159,8 +158,6 @@ def get_plot(
n_colors = max(3, len(self._doses))
n_colors = min(9, n_colors)

colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

y = None
all_densities = []
all_frequencies = []
Expand All @@ -186,17 +183,12 @@ def get_plot(
all_frequencies.reverse()
all_pts = []
for idx, (key, frequencies, densities) in enumerate(zip(keys, all_frequencies, all_densities)):
color = self._doses[key].get("color", plt.cm.tab10.colors[idx % n_colors])
all_pts.extend(list(zip(frequencies, densities)))
if self.stack:
ax.fill(frequencies, densities, color=colors[idx % n_colors], label=str(key))
ax.fill(frequencies, densities, color=color, label=str(key))
else:
ax.plot(
frequencies,
densities,
color=colors[idx % n_colors],
label=str(key),
linewidth=3,
)
ax.plot(frequencies, densities, color=color, label=str(key), linewidth=3)

if xlim:
ax.set_xlim(xlim)
Expand Down Expand Up @@ -296,13 +288,9 @@ def _make_ticks(self, ax: Axes) -> Axes:
ax.set_xticks(uniq_d)
ax.set_xticklabels(uniq_l)

for idx in range(len(ticks["label"])):
if ticks["label"][idx] is not None:
# don't print the same label twice
if idx != 0:
ax.axvline(ticks["distance"][idx], color="k")
else:
ax.axvline(ticks["distance"][idx], color="k")
for idx, label in enumerate(ticks["label"]):
if label is not None:
ax.axvline(ticks["distance"][idx], color="k")
return ax

def bs_plot_data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -355,14 +343,11 @@ def get_plot(
ax = pretty_plot(12, 8)

data = self.bs_plot_data()
for d in range(len(data["distances"])):
kwargs.setdefault("color", "tab:blue")
for dists, freqs in zip(data["distances"], data["frequency"]):
for idx in range(self._nb_bands):
ax.plot(
data["distances"][d],
[data["frequency"][d][idx][j] * u.factor for j in range(len(data["distances"][d]))],
"b-",
**kwargs,
)
ys = [freqs[idx][j] * u.factor for j in range(len(dists))]
ax.plot(dists, ys, **kwargs)

self._make_ticks(ax)

Expand Down Expand Up @@ -655,24 +640,21 @@ def plot_compare(
line_width = kwargs.setdefault("linewidth", 1)

ax = self.get_plot(units=units, **kwargs)
for band_idx in range(other_plotter._nb_bands):
for dist_idx in range(len(data_orig["distances"])):
ax.plot(
data_orig["distances"][dist_idx],
[
data["frequency"][dist_idx][band_idx][j] * unit.factor
for j in range(len(data_orig["distances"][dist_idx]))
],
"r-",
**kwargs,
)

# add legend showing which color correspond to which band structure
if labels is None and self._label and other_plotter._label:
labels = (self._label, other_plotter._label)
if labels:
ax.plot([], [], "b-", label=labels[0], linewidth=3 * line_width)
ax.plot([], [], "r-", label=labels[1], linewidth=3 * line_width)
kwargs.setdefault("color", "tab:orange") # don't move this line up! it would mess up self.get_plot color

for band_idx in range(other_plotter._nb_bands):
for dist_idx, dists in enumerate(data_orig["distances"]):
xs = dists
ys = [data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))]
ax.plot(xs, ys, **kwargs)

# add legend showing which color corresponds to which band structure
if labels or (self._label and other_plotter._label):
color_self, color_other = ax.lines[0].get_color(), ax.lines[-1].get_color()
label_self, label_other = labels or (self._label, other_plotter._label)
ax.plot([], [], label=label_self, linewidth=3 * line_width, color=color_self)
ax.plot([], [], label=label_other, linewidth=3 * line_width, color=color_other)
ax.legend(**legend_kwargs)

return ax
Expand Down

0 comments on commit 64cd6ce

Please sign in to comment.