Skip to content

Commit

Permalink
Breaking: remove single-use PolarizationLattice which inherited fro…
Browse files Browse the repository at this point in the history
…m `Structure` (antipattern) (#3585)

* breaking: remove PolarizationLattice which inherited from Structure (antipattern) and refactor get_same_branch_polarization_data()

* rename d|s vars
  • Loading branch information
janosh authored Jan 26, 2024
1 parent f345f2f commit 9a0eb81
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 88 deletions.
12 changes: 6 additions & 6 deletions pymatgen/analysis/elasticity/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,9 +1002,9 @@ def generate_pseudo(strain_states, order=3):
difference derivative of the stress with respect to the strain state
absent_syms: symbols of the tensor absent from the PI expression
"""
s = sp.Symbol("s")
symb = sp.Symbol("s")
nstates = len(strain_states)
ni = np.array(strain_states) * s
ni = np.array(strain_states) * symb
pseudo_inverses, absent_symbols = [], []
for degree in range(2, order + 1):
cvec, carr = get_symbol_list(degree)
Expand All @@ -1015,14 +1015,14 @@ def generate_pseudo(strain_states, order=3):
for _ in range(degree - 1):
exps = np.dot(exps, strain_v)
exps /= math.factorial(degree - 1)
sarr[n] = [sp.diff(exp, s, degree - 1) for exp in exps]
sarr[n] = [sp.diff(exp, symb, degree - 1) for exp in exps]
svec = sarr.ravel()
present_symbols = set.union(*(exp.atoms(sp.Symbol) for exp in svec))
absent_symbols += [set(cvec) - present_symbols]
m = np.zeros((6 * nstates, len(cvec)))
pseudo_mat = np.zeros((6 * nstates, len(cvec)))
for n, c in enumerate(cvec):
m[:, n] = v_diff(svec, c)
pseudo_inverses.append(np.linalg.pinv(m))
pseudo_mat[:, n] = v_diff(svec, c)
pseudo_inverses.append(np.linalg.pinv(pseudo_mat))
return pseudo_inverses, absent_symbols


Expand Down
64 changes: 34 additions & 30 deletions pymatgen/analysis/ferroelectricity/polarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,20 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from scipy.interpolate import UnivariateSpline

from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure

if TYPE_CHECKING:
from collections.abc import Sequence

from pymatgen.core.sites import PeriodicSite


__author__ = "Tess Smidt"
__copyright__ = "Copyright 2017, The Materials Project"
__version__ = "1.0"
Expand All @@ -73,7 +81,7 @@ def zval_dict_from_potcar(potcar):
return zval_dict


def calc_ionic(site, structure: Structure, zval):
def calc_ionic(site: PeriodicSite, structure: Structure, zval: float) -> np.ndarray:
"""
Calculate the ionic dipole moment using ZVAL from pseudopotential.
Expand Down Expand Up @@ -103,31 +111,27 @@ def get_total_ionic_dipole(structure, zval_dict):
return np.sum(tot_ionic, axis=0)


class PolarizationLattice(Structure):
"""TODO Why is a Lattice inheriting a structure? This is ridiculous."""

def get_nearest_site(self, coords, site, r=None):
"""
Given coords and a site, find closet site to coords.
def get_nearest_site(struct: Structure, coords: Sequence[float], site: PeriodicSite, r: float | None = None):
"""
Given coords and a site, find closet site to coords.
Args:
coords (3x1 array): Cartesian coords of center of sphere
site: site to find closest to coords
r: radius of sphere. Defaults to diagonal of unit cell
Args:
coords (3x1 array): Cartesian coords of center of sphere
site: site to find closest to coords
r (float): radius of sphere. Defaults to diagonal of unit cell
Returns:
Closest site and distance.
"""
index = self.index(site)
if r is None:
r = np.linalg.norm(np.sum(self.lattice.matrix, axis=0))
ns = self.get_sites_in_sphere(coords, r, include_index=True)
# Get sites with identical index to site
ns = [n for n in ns if n[2] == index]
# Sort by distance to coords
ns.sort(key=lambda x: x[1])
# Return PeriodicSite and distance of closest image
return ns[0][0:2]
Returns:
Closest site and distance.
"""
index = struct.index(site)
r = r or np.linalg.norm(np.sum(struct.lattice.matrix, axis=0))
ns = struct.get_sites_in_sphere(coords, r, include_index=True)
# Get sites with identical index to site
ns = [n for n in ns if n[2] == index]
# Sort by distance to coords
ns.sort(key=lambda x: x[1])
# Return PeriodicSite and distance of closest image
return ns[0][0:2]


class Polarization:
Expand Down Expand Up @@ -298,18 +302,18 @@ def get_same_branch_polarization_data(self, convert_to_muC_per_cm2=True, all_in_
for idx in range(n_elecs):
lattice = lattices[idx]
frac_coord = np.divide(np.array([p_tot[idx]]), np.array(lattice.lengths))
d = PolarizationLattice(lattice, ["C"], [np.array(frac_coord).ravel()])
d_structs.append(d)
site = d[0]
struct = Structure(lattice, ["C"], [np.array(frac_coord).ravel()])
d_structs.append(struct)
site = struct[0]
# Adjust nonpolar polarization to be closest to zero.
# This is compatible with both a polarization of zero or a half quantum.
prev_site = [0, 0, 0] if idx == 0 else sites[-1].coords
new_site = d.get_nearest_site(prev_site, site)
new_site = get_nearest_site(struct, prev_site, site)
sites.append(new_site[0])

adjust_pol = []
for site, d in zip(sites, d_structs):
adjust_pol.append(np.multiply(site.frac_coords, np.array(d.lattice.lengths)).ravel())
for site, struct in zip(sites, d_structs):
adjust_pol.append(np.multiply(site.frac_coords, np.array(struct.lattice.lengths)).ravel())
return np.array(adjust_pol)

def get_lattice_quanta(self, convert_to_muC_per_cm2=True, all_in_polar=True):
Expand Down
10 changes: 5 additions & 5 deletions pymatgen/analysis/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,11 +1290,11 @@ def __str__(self):
return out

def __repr__(self):
s = "Structure Graph"
s += f"\nStructure: \n{self.structure!r}"
s += f"\nGraph: {self.name}\n"
s += self._edges_to_str(self.graph)
return s
out = "Structure Graph"
out += f"\nStructure: \n{self.structure!r}"
out += f"\nGraph: {self.name}\n"
out += self._edges_to_str(self.graph)
return out

def __len__(self):
"""length of Structure / number of nodes in graph"""
Expand Down
24 changes: 13 additions & 11 deletions pymatgen/analysis/structure_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,15 @@ def __init__(self, structure: Structure, cutoff=10):
cutoff (float) Cutoff distance.
"""
self.cutoff = cutoff
self.s = structure
recp_len = np.array(self.s.lattice.reciprocal_lattice.abc)
i = np.ceil(cutoff * recp_len / (2 * pi))
offsets = np.mgrid[-i[0] : i[0] + 1, -i[1] : i[1] + 1, -i[2] : i[2] + 1].T
self.structure = structure
recip_vec = np.array(self.structure.lattice.reciprocal_lattice.abc)
cutoff_vec = np.ceil(cutoff * recip_vec / (2 * pi))
offsets = np.mgrid[
-cutoff_vec[0] : cutoff_vec[0] + 1, -cutoff_vec[1] : cutoff_vec[1] + 1, -cutoff_vec[2] : cutoff_vec[2] + 1
].T
self.offsets = np.reshape(offsets, (-1, 3))
# shape = [image, axis]
self.cart_offsets = self.s.lattice.get_cartesian_coords(self.offsets)
self.cart_offsets = self.structure.lattice.get_cartesian_coords(self.offsets)

@property
def connectivity_array(self):
Expand All @@ -271,12 +273,12 @@ def connectivity_array(self):
solid angle of polygon between atom_i and image_j of atom_j
"""
# shape = [site, axis]
cart_coords = np.array(self.s.cart_coords)
cart_coords = np.array(self.structure.cart_coords)
# shape = [site, image, axis]
all_sites = cart_coords[:, None, :] + self.cart_offsets[None, :, :]
vt = Voronoi(all_sites.reshape((-1, 3)))
n_images = all_sites.shape[1]
cs = (len(self.s), len(self.s), len(self.cart_offsets))
cs = (len(self.structure), len(self.structure), len(self.cart_offsets))
connectivity = np.zeros(cs)
vts = np.array(vt.vertices)
for (ki, kj), v in vt.ridge_dict.items():
Expand Down Expand Up @@ -321,7 +323,7 @@ def get_connections(self):
for ii in range(max_conn.shape[0]):
for jj in range(max_conn.shape[1]):
if max_conn[ii][jj] != 0:
dist = self.s.get_distance(ii, jj)
dist = self.structure.get_distance(ii, jj)
con.append([ii, jj, dist])
return con

Expand All @@ -335,9 +337,9 @@ def get_sitej(self, site_index, image_index):
site_index (int): index of the site (3 in the example)
image_index (int): index of the image (12 in the example)
"""
atoms_n_occu = self.s[site_index].species
lattice = self.s.lattice
coords = self.s[site_index].frac_coords + self.offsets[image_index]
atoms_n_occu = self.structure[site_index].species
lattice = self.structure.lattice
coords = self.structure[site_index].frac_coords + self.offsets[image_index]
return PeriodicSite(atoms_n_occu, coords, lattice)


Expand Down
4 changes: 2 additions & 2 deletions pymatgen/analysis/surface_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,8 +861,8 @@ def chempot_vs_gamma_plot_one(
# w.r.t. bulk. Label with formula if non-stoichiometric
ucell_comp = self.ucell_entry.composition.reduced_composition
if entry.adsorbates:
s = entry.cleaned_up_slab
clean_comp = s.composition.reduced_composition
struct = entry.cleaned_up_slab
clean_comp = struct.composition.reduced_composition
else:
clean_comp = entry.composition.reduced_composition

Expand Down
4 changes: 2 additions & 2 deletions pymatgen/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ def get_reduced_formula_and_factor(self, iupac_ordering: bool = False) -> tuple[
all_int = all(abs(x - round(x)) < Composition.amount_tolerance for x in self.values())
if not all_int:
return self.formula.replace(" ", ""), 1
d = {key: int(round(val)) for key, val in self.get_el_amt_dict().items()}
formula, factor = reduce_formula(d, iupac_ordering=iupac_ordering)
el_amt_dict = {key: int(round(val)) for key, val in self.get_el_amt_dict().items()}
formula, factor = reduce_formula(el_amt_dict, iupac_ordering=iupac_ordering)

if formula in Composition.special_formulas:
formula = Composition.special_formulas[formula]
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/core/ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def get_reduced_formula_and_factor(self, iupac_ordering: bool = False, hydrates:
nH2O = int(nO) if nH >= 2 * nO else int(nH) // 2
comp = self.composition - nH2O * Composition("H2O")

d = {k: int(round(v)) for k, v in comp.get_el_amt_dict().items()}
(formula, factor) = reduce_formula(d, iupac_ordering=iupac_ordering)
el_amt_dict = {k: int(round(v)) for k, v in comp.get_el_amt_dict().items()}
(formula, factor) = reduce_formula(el_amt_dict, iupac_ordering=iupac_ordering)

if self.composition.get("H") == self.composition.get("O") is not None:
formula = formula.replace("HO", "OH")
Expand Down
12 changes: 6 additions & 6 deletions pymatgen/io/abinit/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,18 +858,18 @@ def to_str(self, post=None, with_structure=True, with_pseudos=True, exclude=None
vname = name + post
app(str(InputVariable(vname, value)))

s = "\n".join(lines)
out = "\n".join(lines)
if not with_pseudos:
return s
return out

# Add JSON section with pseudo potentials.
ppinfo = ["\n\n\n#<JSON>"]
d = {"pseudos": [p.as_dict() for p in self.pseudos]}
ppinfo.extend(json.dumps(d, indent=4).splitlines())
psp_dict = {"pseudos": [p.as_dict() for p in self.pseudos]}
ppinfo.extend(json.dumps(psp_dict, indent=4).splitlines())
ppinfo.append("</JSON>")

s += "\n#".join(ppinfo)
return s
out += "\n#".join(ppinfo)
return out

@property
def comment(self):
Expand Down
8 changes: 4 additions & 4 deletions pymatgen/io/adf.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,18 +533,18 @@ def _setup_task(self, geo_subkeys):
self.geo.remove_subkey("Frequencies")

def __str__(self):
s = f"""TITLE {self.title}\n
out = f"""TITLE {self.title}\n
{self.units}
{self.xc}
{self.basis_set}
{self.scf}
{self.geo}"""
s += "\n"
out += "\n"
for block_key in self.other_directives:
if not isinstance(block_key, AdfKey):
raise ValueError(f"{block_key} is not an AdfKey!")
s += str(block_key) + "\n"
return s
out += str(block_key) + "\n"
return out

def as_dict(self):
"""A JSON-serializable dict representation of self."""
Expand Down
10 changes: 5 additions & 5 deletions pymatgen/io/lammps/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,10 @@ def disassemble(
}
ff_df = self.force_field[kw]
for t in ff_df.itertuples(index=True, name=None):
d = {"coeffs": list(t[1:]), "types": []}
coeffs_dict = {"coeffs": list(t[1:]), "types": []}
if class2_coeffs:
d.update({k: list(v[t[0] - 1]) for k, v in class2_coeffs.items()})
topo_coeffs[kw].append(d)
coeffs_dict.update({k: list(v[t[0] - 1]) for k, v in class2_coeffs.items()})
topo_coeffs[kw].append(coeffs_dict)

if self.topology:

Expand All @@ -591,8 +591,8 @@ def label_topo(t) -> tuple:

if any(topo_coeffs):
for v in topo_coeffs.values():
for d in v:
d["types"] = list(set(d["types"]))
for coeffs_dict in v:
coeffs_dict["types"] = list(set(coeffs_dict["types"]))

ff = ForceField(
mass_info=mass_info,
Expand Down
18 changes: 9 additions & 9 deletions pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,8 @@ def _parse(self, stream, parse_dos, parse_eigen, parse_projected_eigen):
elif tag == "varray" and elem.attrib.get("name") == "stress":
md_data[-1]["stress"] = _parse_vasp_array(elem)
elif tag == "energy":
d = {i.attrib["name"]: float(i.text) for i in elem.findall("i")}
if "kinetic" in d:
e_dict = {i.attrib["name"]: float(i.text) for i in elem.findall("i")}
if "kinetic" in e_dict:
md_data[-1]["energy"] = {i.attrib["name"]: float(i.text) for i in elem.findall("i")}
except ET.ParseError as exc:
if self.exception_on_bad_xml:
Expand Down Expand Up @@ -1296,17 +1296,17 @@ def _parse_chemical_shielding_calculation(self, elem):
calculation.append(istep)
for scstep in elem.findall("scstep"):
try:
d = {i.attrib["name"]: _vasprun_float(i.text) for i in scstep.find("energy").findall("i")}
cur_ene = d["e_fr_energy"]
e_steps_dict = {i.attrib["name"]: _vasprun_float(i.text) for i in scstep.find("energy").findall("i")}
cur_ene = e_steps_dict["e_fr_energy"]
min_steps = 1 if len(calculation) >= 1 else self.parameters.get("NELMIN", 5)
if len(calculation[-1]["electronic_steps"]) <= min_steps:
calculation[-1]["electronic_steps"].append(d)
calculation[-1]["electronic_steps"].append(e_steps_dict)
else:
last_ene = calculation[-1]["electronic_steps"][-1]["e_fr_energy"]
if abs(cur_ene - last_ene) < 1.0:
calculation[-1]["electronic_steps"].append(d)
calculation[-1]["electronic_steps"].append(e_steps_dict)
else:
calculation.append({"electronic_steps": [d]})
calculation.append({"electronic_steps": [e_steps_dict]})
except AttributeError: # not all calculations have an energy
pass
calculation[-1].update(calculation[-1]["electronic_steps"][-1])
Expand All @@ -1320,8 +1320,8 @@ def _parse_calculation(self, elem):
esteps = []
for scstep in elem.findall("scstep"):
try:
d = {i.attrib["name"]: _vasprun_float(i.text) for i in scstep.find("energy").findall("i")}
esteps.append(d)
e_step_dict = {i.attrib["name"]: _vasprun_float(i.text) for i in scstep.find("energy").findall("i")}
esteps.append(e_step_dict)
except AttributeError: # not all calculations have an energy
pass
try:
Expand Down
3 changes: 1 addition & 2 deletions tests/analysis/test_eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,10 +436,9 @@ def test_eos_func_call(self):
assert_allclose(self.num_eos_fit.func(0.0), self.num_eos_fit(0.0))

def test_summary_dict(self):
d = {
assert self.num_eos_fit.results == {
"e0": self.num_eos_fit.e0,
"b0": self.num_eos_fit.b0,
"b1": self.num_eos_fit.b1,
"v0": self.num_eos_fit.v0,
}
assert self.num_eos_fit.results == d
8 changes: 4 additions & 4 deletions tests/analysis/test_structure_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def test_ignore_species(self):
}

def test_get_supercell_size(self):
latt = Lattice.cubic(1)
l2 = Lattice.cubic(0.9)
s1 = Structure(latt, ["Mg", "Cu", "Ag", "Cu", "Ag"], [[0] * 3] * 5)
s2 = Structure(l2, ["Cu", "Cu", "Ag"], [[0] * 3] * 3)
latt1 = Lattice.cubic(1)
latt2 = Lattice.cubic(0.9)
s1 = Structure(latt1, ["Mg", "Cu", "Ag", "Cu", "Ag"], [[0] * 3] * 5)
s2 = Structure(latt2, ["Cu", "Cu", "Ag"], [[0] * 3] * 3)

sm = StructureMatcher(supercell_size="volume")
assert sm._get_supercell_size(s1, s2) == (1, True)
Expand Down

0 comments on commit 9a0eb81

Please sign in to comment.