Skip to content

Commit

Permalink
Breaking: all plot methods return plt.Axes (#3749)
Browse files Browse the repository at this point in the history
* pymatgen/phonon/dos.py fix broken gaussian_filter1d import scipy.ndimage(.filters->'')

* ruff fix all FURB110

Replace ternary `if` expression with `or` operator

* snake_case natoms

* doc str return section fixes

* raise NotImplementedError in AbstractDiffractionPatternCalculator.get_pattern method

* breaking: change overlooked plotting methods still returning plt instead of ax: plt.Axes

e.g. WulffShape.get_plotly, BztPlotter.plot_props, several util.plotting funcs

* breaking: surface_analysis.py plot_(one|all)_stability_map rename plt keyword to ax and return ax instead of plt

* fix BSPlotterProjected.get_projected_plots_dots bad doc str return type

* fix TestWulffShape.test_get_plot

* fix test_plot_periodic_heatmap and test_van_arkel_triangle
  • Loading branch information
janosh authored Apr 12, 2024
1 parent 9337a4e commit 48860fb
Show file tree
Hide file tree
Showing 36 changed files with 296 additions and 257 deletions.
4 changes: 1 addition & 3 deletions pymatgen/analysis/chempot_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ def __init__(
renormalized_entries = []
for entry in entries:
comp_dict = entry.composition.as_dict()
renormalization_energy = sum(
[comp_dict[el] * _el_refs[Element(el)].energy_per_atom for el in comp_dict]
)
renormalization_energy = sum(comp_dict[el] * _el_refs[Element(el)].energy_per_atom for el in comp_dict)
renormalized_entries.append(_renormalize_entry(entry, renormalization_energy / sum(comp_dict.values())))

entries = renormalized_entries
Expand Down
3 changes: 2 additions & 1 deletion pymatgen/analysis/diffraction/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90)
sphere of radius 2 / wavelength.
Returns:
(DiffractionPattern)
DiffractionPattern
"""
raise NotImplementedError

def get_plot(
self,
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/diffraction/neutron.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90)
sphere of radius 2 / wavelength.
Returns:
(NDPattern)
DiffractionPattern: ND pattern
"""
if self.symprec:
finder = SpacegroupAnalyzer(structure, symprec=self.symprec)
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/diffraction/xrd.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90)
sphere of radius 2 / wavelength.
Returns:
(XRDPattern)
DiffractionPattern: XRD pattern
"""
if self.symprec:
finder = SpacegroupAnalyzer(structure, symprec=self.symprec)
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def get_site_energy(self, site_index):
site_index (int): Index of site
Returns:
(float) - Energy of that site
float: Energy of that site
"""
if not self._initialized:
self._calc_ewald_terms()
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def _is_in_targets(site, targets):
targets ([Element]) List of elements
Returns:
(boolean) Whether this site contains a certain list of elements
boolean: Whether this site contains a certain list of elements
"""
elems = _get_elements(site)
return all(elem in targets for elem in elems)
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/structure_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def sulfide_type(structure):
structure (Structure): Input structure.
Returns:
(str) sulfide/polysulfide or None if structure is a sulfate.
str: sulfide/polysulfide or None if structure is a sulfate.
"""
structure = structure.copy().remove_oxidation_states()
sulphur = Element("S")
Expand Down
54 changes: 31 additions & 23 deletions pymatgen/analysis/surface_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(
"""
self.miller_index = miller_index
self.label = label
self.adsorbates = adsorbates if adsorbates else []
self.adsorbates = adsorbates or []
self.clean_entry = clean_entry
self.ads_entries_dict = {str(next(iter(ads.composition.as_dict()))): ads for ads in self.adsorbates}
self.mark = marker
Expand Down Expand Up @@ -180,7 +180,7 @@ def surface_energy(self, ucell_entry, ref_entries=None):
float: The surface energy of the slab.
"""
# Set up
ref_entries = ref_entries if ref_entries else []
ref_entries = ref_entries or []

# Check if appropriate ref_entries are present if the slab is non-stoichiometric
# TODO: There should be a way to identify which specific species are
Expand Down Expand Up @@ -861,6 +861,7 @@ def chempot_vs_gamma_plot_one(
"""
delu_dict = delu_dict or {}
chempot_range = sorted(chempot_range)
ax = ax or plt.gca()

# use dashed lines for slabs that are not stoichiometric
# w.r.t. bulk. Label with formula if non-stoichiometric
Expand All @@ -884,9 +885,10 @@ def chempot_vs_gamma_plot_one(

se_range = np.array(gamma_range) * EV_PER_ANG2_TO_JOULES_PER_M2 if JPERM2 else gamma_range

mark = entry.mark if entry.mark else mark
c = entry.color if entry.color else self.color_dict[entry]
return plt.plot(chempot_range, se_range, mark, color=c, label=label)
mark = entry.mark or mark
color = entry.color or self.color_dict[entry]
ax.plot(chempot_range, se_range, mark, color=color, label=label)
return ax

def chempot_vs_gamma(
self,
Expand Down Expand Up @@ -941,7 +943,7 @@ def chempot_vs_gamma(
delu_dict = {}
chempot_range = sorted(chempot_range)

plt = plt if plt else pretty_plot(width=8, height=7)
plt = plt or pretty_plot(width=8, height=7)
axes = plt.gca()

for hkl in self.all_slab_entries:
Expand Down Expand Up @@ -1175,7 +1177,7 @@ def surface_chempot_range_map(
"""
# Set up
delu_dict = delu_dict or {}
ax = ax if ax else pretty_plot(12, 8)
ax = ax or pretty_plot(12, 8)
el1, el2 = str(elements[0]), str(elements[1])
delu1 = Symbol(f"delu_{elements[0]}")
delu2 = Symbol(f"delu_{elements[1]}")
Expand Down Expand Up @@ -1255,7 +1257,7 @@ def surface_chempot_range_map(
# Label the phases
x = np.mean([max(xvals), min(xvals)])
y = np.mean([max(yvals), min(yvals)])
label = entry.label if entry.label else entry.reduced_formula
label = entry.label or entry.reduced_formula
ax.annotate(label, xy=[x, y], xytext=[x, y], fontsize=fontsize)

# Label plot
Expand Down Expand Up @@ -1314,7 +1316,7 @@ def entry_dict_from_list(all_slab_entries):
hkl = tuple(entry.miller_index)
if hkl not in entry_dict:
entry_dict[hkl] = {}
clean = entry.clean_entry if entry.clean_entry else entry
clean = entry.clean_entry or entry
if clean not in entry_dict[hkl]:
entry_dict[hkl][clean] = []
if entry.adsorbates:
Expand Down Expand Up @@ -1425,7 +1427,7 @@ def get_locpot_along_slab_plot(self, label_energies=True, plt=None, label_fontsi
Returns plt of the locpot vs c axis
"""
plt = plt if plt else pretty_plot(width=6, height=4)
plt = plt or pretty_plot(width=6, height=4)

# plot the raw locpot signal along c
plt.plot(self.along_c, self.locpot_along_c, "b--")
Expand Down Expand Up @@ -1769,7 +1771,7 @@ def plot_one_stability_map(
label="",
increments=50,
delu_default=0,
plt=None,
ax=None,
from_sphere_area=False,
e_units="keV",
r_units="nanometers",
Expand Down Expand Up @@ -1797,8 +1799,11 @@ def plot_one_stability_map(
r_units (str): Can be nanometers or Angstrom
e_units (str): Can be keV or eV
normalize (str): Whether or not to normalize energy by volume
Returns:
plt.Axes: matplotlib Axes object
"""
plt = plt or pretty_plot(width=8, height=7)
ax = ax or pretty_plot(width=8, height=7)

wulff_shape = analyzer.wulff_from_chempot(delu_dict=delu_dict, delu_default=delu_default, symprec=self.symprec)

Expand All @@ -1818,21 +1823,21 @@ def plot_one_stability_map(
r_list.append(radius)

ru = "nm" if r_units == "nanometers" else r"\AA"
plt.xlabel(rf"Particle radius (${ru}$)")
ax.xlabel(rf"Particle radius (${ru}$)")
eu = f"${e_units}/{ru}^3$"
plt.ylabel(rf"$G_{{form}}$ ({eu})")
ax.ylabel(rf"$G_{{form}}$ ({eu})")

plt.plot(r_list, gform_list, label=label)
ax.plot(r_list, gform_list, label=label)

return plt
return ax

def plot_all_stability_map(
self,
max_r,
increments=50,
delu_dict=None,
delu_default=0,
plt=None,
ax=None,
labels=None,
from_sphere_area=False,
e_units="keV",
Expand All @@ -1857,17 +1862,20 @@ def plot_all_stability_map(
from_sphere_area (bool): There are two ways to calculate the bulk
formation energy. Either by treating the volume and thus surface
area of the particle as a perfect sphere, or as a Wulff shape.
Returns:
plt.Axes: matplotlib Axes object
"""
plt = plt or pretty_plot(width=8, height=7)
ax = ax or pretty_plot(width=8, height=7)

for i, analyzer in enumerate(self.se_analyzers):
label = labels[i] if labels else ""
plt = self.plot_one_stability_map(
for idx, analyzer in enumerate(self.se_analyzers):
label = labels[idx] if labels else ""
ax = self.plot_one_stability_map(
analyzer,
max_r,
delu_dict,
label=label,
plt=plt,
ax=ax,
increments=increments,
delu_default=delu_default,
from_sphere_area=from_sphere_area,
Expand All @@ -1877,7 +1885,7 @@ def plot_all_stability_map(
scale_per_atom=scale_per_atom,
)

return plt
return ax


def sub_chempots(gamma_dict, chempots):
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/analysis/wulff.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def get_plot(
Joules per square meter (True)
Returns:
(matplotlib.pyplot)
mpl_toolkits.mplot3d.Axes3D: 3D plot of the Wulff shape.
"""
from mpl_toolkits.mplot3d import art3d

Expand Down Expand Up @@ -472,7 +472,7 @@ def get_plot(
ax_3d.grid("off")
if axis_off:
ax_3d.axis("off")
return plt
return ax_3d

def get_plotly(
self,
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,7 @@ def reduce_formula(sym_amt, iupac_ordering: bool = False) -> tuple[str, float]:
the elements.
Returns:
(reduced_formula, factor).
tuple[str, float]: reduced formula and factor.
"""
syms = sorted(sym_amt, key=lambda x: [get_el_sp(x).X, x])

Expand Down
12 changes: 8 additions & 4 deletions pymatgen/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2185,7 +2185,8 @@ def reduce_mat(mat, mag, r_matrix):
mat (3 by 3 array): input matrix
mag (int): reduce times for the determinant
r_matrix (3 by 3 array): rotation matrix
Return:
Returns:
the reduced integer array
"""
max_j = abs(int(round(np.linalg.det(mat) / mag)))
Expand Down Expand Up @@ -2218,7 +2219,8 @@ def vec_to_surface(vec):
Args:
vec (1 by 3 array float vector): input float vector
Return:
Returns:
the surface miller index of the input vector.
"""
miller = [None] * 3
Expand Down Expand Up @@ -2257,7 +2259,8 @@ def fix_pbc(structure, matrix=None):
matrix (lattice matrix, 3 by 3 array/matrix): new structure's lattice matrix,
If None, use input structure's matrix.
Return:
Returns:
new structure with fixed frac_coords and lattice matrix
"""
spec = []
Expand Down Expand Up @@ -2285,7 +2288,8 @@ def symm_group_cubic(mat):
Args:
mat (n by 3 array/matrix): lattice matrix
Return:
Returns:
cubic symmetric equivalents of the list of vectors.
"""
sym_group = np.zeros([24, 3, 3])
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/core/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,8 +947,8 @@ def find_mapping(
Defaults to False.
Returns:
(aligned_lattice, rotation_matrix, scale_matrix) if a mapping is
found. aligned_lattice is a rotated version of other_lattice that
tuple[Lattice, np.ndarray, np.ndarray]: (aligned_lattice, rotation_matrix, scale_matrix)
if a mapping is found. aligned_lattice is a rotated version of other_lattice that
has the same lattice parameters, but which is aligned in the
coordinate system of this lattice so that translational points
match up in 3D. rotation_matrix is the rotation that has to be
Expand Down
4 changes: 3 additions & 1 deletion pymatgen/core/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def are_symmetrically_related_vectors(
tol (float): Absolute tolerance for checking distance.
Returns:
(are_related, is_reversed)
tuple[bool, bool]: First bool indicates if the vectors are related,
the second if the vectors are related but the starting and end point
are exchanged.
"""
from_c = self.operate(from_a)
to_c = self.operate(to_a)
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def from_sites(
ValueError: If sites is empty or sites do not have the same lattice.
Returns:
(Structure) Note that missing properties are set as None.
IStructure: Note that missing properties are set as None.
"""
if not sites:
raise ValueError(f"You need at least 1 site to construct a {cls.__name__}")
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/core/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ def as_base_units(self):
"""Converts all units to base SI units, including derived units.
Returns:
(base_units_dict, scaling factor). base_units_dict will not
contain any constants, which are gathered in the scaling factor.
tuple[dict, float]: (base_units_dict, scaling factor). base_units_dict will not
contain any constants, which are gathered in the scaling factor.
"""
b = collections.defaultdict(int)
factor = 1
Expand Down
12 changes: 6 additions & 6 deletions pymatgen/electronic_structure/boltztrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,7 @@ def parse_struct(path_dir):
path_dir: (str) dir containing the boltztrap.struct file
Returns:
(float) volume
float: volume of the structure in Angstrom^3
"""
with open(f"{path_dir}/boltztrap.struct") as file:
tokens = file.readlines()
Expand Down Expand Up @@ -2156,15 +2156,15 @@ def read_cube_file(filename):
Energy data.
"""
with open(filename) as file:
natoms = 0
n_atoms = 0
for idx, line in enumerate(file):
line = line.rstrip("\n")
if idx == 0 and "CUBE" not in line:
raise ValueError("CUBE file format not recognized")

if idx == 2:
tokens = line.split()
natoms = int(tokens[0])
n_atoms = int(tokens[0])
if idx == 3:
tokens = line.split()
n1 = int(tokens[0])
Expand All @@ -2178,12 +2178,12 @@ def read_cube_file(filename):
break

if "fort.30" in filename:
energy_data = np.genfromtxt(filename, skip_header=natoms + 6, skip_footer=1)
energy_data = np.genfromtxt(filename, skip_header=n_atoms + 6, skip_footer=1)
n_lines_data = len(energy_data)
last_line = np.genfromtxt(filename, skip_header=n_lines_data + natoms + 6)
last_line = np.genfromtxt(filename, skip_header=n_lines_data + n_atoms + 6)
energy_data = np.append(energy_data.flatten(), last_line).reshape(n1, n2, n3)
elif "boltztrap_BZ.cube" in filename:
energy_data = np.loadtxt(filename, skiprows=natoms + 6).reshape(n1, n2, n3)
energy_data = np.loadtxt(filename, skiprows=n_atoms + 6).reshape(n1, n2, n3)

energy_data /= Energy(1, "eV").to("Ry")

Expand Down
Loading

0 comments on commit 48860fb

Please sign in to comment.