Skip to content

Commit

Permalink
Fix legend label order in PhononBSPlotter.plot_compare() (#3510)
Browse files Browse the repository at this point in the history
* fix label order in PhononBSPlotter plot_compare() method

* snake_case variable names in pymatgen/io/phonopy.py

* phonon/test_plotter.py check label order in test_plot_compare
  • Loading branch information
janosh authored Dec 12, 2023
1 parent 9268942 commit f0f530d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 47 deletions.
82 changes: 40 additions & 42 deletions pymatgen/io/phonopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_phonopy_structure(pmg_structure: Structure) -> PhonopyAtoms:
)


def get_structure_from_dict(d):
def get_structure_from_dict(dct):
"""
Extracts a structure from the dictionary extracted from the output
files of phonopy like phonopy.yaml or band.yaml.
Expand All @@ -71,20 +71,20 @@ def get_structure_from_dict(d):
species = []
frac_coords = []
masses = []
if "points" in d:
for p in d["points"]:
if "points" in dct:
for p in dct["points"]:
species.append(p["symbol"])
frac_coords.append(p["coordinates"])
masses.append(p["mass"])
elif "atoms" in d:
for p in d["atoms"]:
elif "atoms" in dct:
for p in dct["atoms"]:
species.append(p["symbol"])
frac_coords.append(p["position"])
masses.append(p["mass"])
else:
raise ValueError("The dict does not contain structural information")

return Structure(d["lattice"], species, frac_coords, site_properties={"phonopy_masses": masses})
return Structure(dct["lattice"], species, frac_coords, site_properties={"phonopy_masses": masses})


def eigvec_to_eigdispl(v, q, frac_coords, mass):
Expand Down Expand Up @@ -129,13 +129,13 @@ def get_ph_bs_symm_line_from_dict(bands_dict, has_nac=False, labels_dict=None):
"""
structure = get_structure_from_dict(bands_dict)

qpts = []
q_pts = []
frequencies = []
eigendisplacements = []
eigen_displacements = []
phonopy_labels_dict = {}
for p in bands_dict["phonon"]:
q = p["q-position"]
qpts.append(q)
q_pts.append(q)
bands = []
eig_q = []
for b in p["band"]:
Expand All @@ -159,26 +159,26 @@ def get_ph_bs_symm_line_from_dict(bands_dict, has_nac=False, labels_dict=None):
if "label" in p:
phonopy_labels_dict[p["label"]] = p["q-position"]
if eig_q:
eigendisplacements.append(eig_q)
eigen_displacements.append(eig_q)

qpts = np.array(qpts)
q_pts = np.array(q_pts)
# transpose to match the convention in PhononBandStructure
frequencies = np.transpose(frequencies)
if eigendisplacements:
eigendisplacements = np.transpose(eigendisplacements, (1, 0, 2, 3))
if eigen_displacements:
eigen_displacements = np.transpose(eigen_displacements, (1, 0, 2, 3))

rec_latt = Lattice(bands_dict["reciprocal_lattice"])
rec_lattice = Lattice(bands_dict["reciprocal_lattice"])

labels_dict = labels_dict or phonopy_labels_dict

return PhononBandStructureSymmLine(
qpts,
q_pts,
frequencies,
rec_latt,
rec_lattice,
has_nac=has_nac,
labels_dict=labels_dict,
structure=structure,
eigendisplacements=eigendisplacements,
eigendisplacements=eigen_displacements,
)


Expand All @@ -195,7 +195,7 @@ def get_ph_bs_symm_line(bands_path, has_nac=False, labels_dict=None):
bands_path: path to the band.yaml file
has_nac: True if the data have been obtained with the option
--nac option. Default False.
labels_dict: dict that links a qpoint in frac coords to a label.
labels_dict: dict that links a q-point in frac coords to a label.
"""
return get_ph_bs_symm_line_from_dict(loadfn(bands_path), has_nac, labels_dict)

Expand Down Expand Up @@ -229,11 +229,11 @@ def get_complete_ph_dos(partial_dos_path, phonopy_yaml_path):

total_dos = PhononDos(a[0], a[1:].sum(axis=0))

pdoss = {}
for site, pdos in zip(structure, a[1:]):
pdoss[site] = pdos.tolist()
partial_doses = {}
for site, p_dos in zip(structure, a[1:]):
partial_doses[site] = p_dos.tolist()

return CompletePhononDos(structure, total_dos, pdoss)
return CompletePhononDos(structure, total_dos, partial_doses)


@requires(Phonopy, "phonopy not installed!")
Expand All @@ -254,7 +254,7 @@ def get_displaced_structures(pmg_structure, atom_disp=0.01, supercell_matrix=Non
A list of symmetrically inequivalent structures with displacements, in
which the first element is the perfect supercell structure.
"""
is_plusminus = kwargs.get("is_plusminus", "auto")
is_plus_minus = kwargs.get("is_plusminus", "auto")
is_diagonal = kwargs.get("is_diagonal", True)
is_trigonal = kwargs.get("is_trigonal", False)

Expand All @@ -266,7 +266,7 @@ def get_displaced_structures(pmg_structure, atom_disp=0.01, supercell_matrix=Non
phonon = Phonopy(unitcell=ph_structure, supercell_matrix=supercell_matrix)
phonon.generate_displacements(
distance=atom_disp,
is_plusminus=is_plusminus,
is_plusminus=is_plus_minus,
is_diagonal=is_diagonal,
is_trigonal=is_trigonal,
)
Expand All @@ -286,9 +286,9 @@ def get_displaced_structures(pmg_structure, atom_disp=0.01, supercell_matrix=Non
# Structure list to be returned
structure_list = [get_pmg_structure(init_supercell)]

for c in disp_supercells:
if c is not None:
structure_list.append(get_pmg_structure(c))
for cell in disp_supercells:
if cell is not None:
structure_list.append(get_pmg_structure(cell))

return structure_list

Expand Down Expand Up @@ -337,10 +337,10 @@ def get_phonon_dos_from_fc(
phonon.run_projected_dos(freq_min=freq_min, freq_max=freq_max, freq_pitch=freq_pitch)

dos_raw = phonon.projected_dos.get_partial_dos()
pdoss = dict(zip(structure, dos_raw[1]))
p_doses = dict(zip(structure, dos_raw[1]))

total_dos = PhononDos(dos_raw[0], dos_raw[1].sum(axis=0))
return CompletePhononDos(structure, total_dos, pdoss)
return CompletePhononDos(structure, total_dos, p_doses)


@requires(Phonopy, "phonopy is required to calculate phonon band structures")
Expand Down Expand Up @@ -405,9 +405,9 @@ def get_phonon_band_structure_symm_line_from_fc(
phonon = Phonopy(structure_phonopy, supercell_matrix=supercell_matrix, symprec=symprec, **kwargs)
phonon.set_force_constants(force_constants)

kpath = HighSymmKpath(structure, symprec=symprec)
k_path = HighSymmKpath(structure, symprec=symprec)

kpoints, labels = kpath.get_kpoints(line_density=line_density, coords_are_cartesian=False)
kpoints, labels = k_path.get_kpoints(line_density=line_density, coords_are_cartesian=False)

phonon.run_qpoints(kpoints)
frequencies = phonon.qpoints.get_frequencies().T
Expand Down Expand Up @@ -442,12 +442,12 @@ def get_gruneisenparameter(gruneisen_path, structure=None, structure_path=None)
except ValueError as exc:
raise ValueError("Please provide a structure or structure path") from exc

qpts, multiplicities, frequencies, gruneisen = ([] for _ in range(4))
q_pts, multiplicities, frequencies, gruneisen = ([] for _ in range(4))
phonopy_labels_dict = {}

for p in gruneisen_dict["phonon"]:
q = p["q-position"]
qpts.append(q)
q_pts.append(q)
m = p.get("multiplicity", 1)
multiplicities.append(m)
bands, gruneisenband = [], []
Expand All @@ -460,15 +460,15 @@ def get_gruneisenparameter(gruneisen_path, structure=None, structure_path=None)
if "label" in p:
phonopy_labels_dict[p["label"]] = p["q-position"]

qpts_np = np.array(qpts)
q_pts_np = np.array(q_pts)
multiplicities_np = np.array(multiplicities)
# transpose to match the convention in PhononBandStructure
frequencies_np = np.transpose(frequencies)
gruneisen_np = np.transpose(gruneisen)

return GruneisenParameter(
gruneisen=gruneisen_np,
qpoints=qpts_np,
qpoints=q_pts_np,
multiplicities=multiplicities_np,
frequencies=frequencies_np,
structure=structure,
Expand Down Expand Up @@ -521,7 +521,7 @@ def get_gs_ph_bs_symm_line_from_dict(
end = pa["phonon"][-1]

if start["q-position"] == [0, 0, 0]: # Gamma at start of band
qpts_temp, frequencies_temp = [], []
q_pts_temp, frequencies_temp = [], []
gruneisen_temp: list[list[float]] = []
distance: list[float] = []
for i in range(pa["nqpoint"]):
Expand All @@ -533,7 +533,7 @@ def get_gs_ph_bs_symm_line_from_dict(
gruen = _extrapolate_grun(b, distance, gruneisen_temp, gruneisen_band, i, pa)
gruneisen_band.append(gruen)
q = phonon[pa["nqpoint"] - i - 1]["q-position"]
qpts_temp.append(q)
q_pts_temp.append(q)
d = phonon[pa["nqpoint"] - i - 1]["distance"]
distance.append(d)
frequencies_temp.append(bands)
Expand All @@ -543,7 +543,7 @@ def get_gs_ph_bs_symm_line_from_dict(
"q-position"
]

q_points.extend(list(reversed(qpts_temp)))
q_points.extend(list(reversed(q_pts_temp)))
frequencies.extend(list(reversed(frequencies_temp)))
gruneisen_params.extend(list(reversed(gruneisen_temp)))

Expand Down Expand Up @@ -594,14 +594,14 @@ def get_gs_ph_bs_symm_line_from_dict(
if "label" in p:
phonopy_labels_dict[p["label"]] = p["q-position"]

rec_latt = structure.lattice.reciprocal_lattice
rec_lattice = structure.lattice.reciprocal_lattice
labels_dict = labels_dict or phonopy_labels_dict
return GruneisenPhononBandStructureSymmLine(
qpoints=np.array(q_points),
# transpose to match the convention in PhononBandStructure
frequencies=np.transpose(frequencies),
gruneisenparameters=np.transpose(gruneisen_params),
lattice=rec_latt,
lattice=rec_lattice,
labels_dict=labels_dict,
structure=structure,
eigendisplacements=None,
Expand Down Expand Up @@ -666,8 +666,6 @@ def get_thermal_displacement_matrices(
"""
thermal_displacements_dict = loadfn(thermal_displacements_yaml)

if not structure_path:
raise ValueError("Please provide a structure_path")
structure = Structure.from_file(structure_path)

thermal_displacement_objects_list = []
Expand Down
8 changes: 5 additions & 3 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,9 @@ def plot_compare(
units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
Defaults to 'thz'.
labels: labels for the two band structures. Defaults to None, which will use the
legend of the two PhononBSPlotter objects.
label of the two PhononBSPlotter objects if present.
Label order is (self_label, other_label), i.e. the label of the PhononBSPlotter
on which plot_compare() is called must come first.
legend_kwargs: kwargs passed to ax.legend().
**kwargs: passed to ax.plot().
Expand Down Expand Up @@ -665,8 +667,8 @@ def plot_compare(
if labels is None and self._label and other_plotter._label:
labels = (self._label, other_plotter._label)
if labels:
ax.plot([], [], "r-", label=labels[0], linewidth=3 * line_width)
ax.plot([], [], "b-", label=labels[1], linewidth=3 * line_width)
ax.plot([], [], "b-", label=labels[0], linewidth=3 * line_width)
ax.plot([], [], "r-", label=labels[1], linewidth=3 * line_width)
ax.legend(**legend_kwargs)

return ax
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,7 +1661,7 @@ def test_relax_m3gnet(self):
pytest.importorskip("matgl")
struct = self.get_structure("Si")
relaxed = struct.relax()
assert relaxed.lattice.a == approx(3.867626620642243, abs=0.039) # 1% error
assert relaxed.lattice.a == approx(3.867626620642243, rel=0.01) # allow 1% error
assert hasattr(relaxed, "calc")
for key, val in {"type": "optimization", "optimizer": "FIRE"}.items():
actual = relaxed.dynamics[key]
Expand Down
8 changes: 7 additions & 1 deletion tests/phonon/test_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ def test_proj_plot(self):
self.plotter_sto.get_proj_plot(site_comb=[[0], [1], [2], [3, 4]])

def test_plot_compare(self):
self.plotter.plot_compare(self.plotter, units="mev")
labels = ("NaCl", "NaCl 2")
ax = self.plotter.plot_compare(self.plotter, units="mev", labels=labels)
assert isinstance(ax, axes.Axes)
assert ax.get_ylabel() == "$\\mathrm{Frequencies\\ (meV)}$"
assert ax.get_xlabel() == "$\\mathrm{Wave\\ Vector}$"
assert ax.get_title() == ""
assert [itm.get_text() for itm in ax.get_legend().get_texts()] == list(labels)


class TestThermoPlotter(unittest.TestCase):
Expand Down

0 comments on commit f0f530d

Please sign in to comment.