From 4b6587cf7f9cacd20544079efb317a81072ec20e Mon Sep 17 00:00:00 2001 From: Andrew-S-Rosen Date: Sat, 13 Jan 2024 15:10:23 -0800 Subject: [PATCH 01/25] fix --- pymatgen/io/cif.py | 312 +++++++++++++++++++++++++++++++++---------- tests/io/test_cif.py | 15 +++ 2 files changed, 258 insertions(+), 69 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 56d40e4299d..3583b8fa0d6 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -121,7 +121,11 @@ def _format_field(self, v): # add quotes if necessary if v == "": return '""' - if (" " in v or v[0] == "_") and not (v[0] == "'" and v[-1] == "'") and not (v[0] == '"' and v[-1] == '"'): + if ( + (" " in v or v[0] == "_") + and not (v[0] == "'" and v[-1] == "'") + and not (v[0] == '"' and v[-1] == '"') + ): q = '"' if "'" in v else "'" v = q + v + q return v @@ -252,7 +256,9 @@ def from_str(cls, string) -> CifFile: """ dct = {} - for block_str in re.split(r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL)[1:]: + for block_str in re.split( + r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL + )[1:]: # Skip over Cif block that contains powder diffraction data. # Some elements in this block were missing from CIF files in # Springer materials/Pauling file DBs. @@ -323,7 +329,11 @@ def is_magcif() -> bool: """Checks to see if file appears to be a magCIF file (heuristic).""" # Doesn't seem to be a canonical way to test if file is magCIF or # not, so instead check for magnetic symmetry datanames - prefixes = ["_space_group_magn", "_atom_site_moment", "_space_group_symop_magn"] + prefixes = [ + "_space_group_magn", + "_atom_site_moment", + "_space_group_symop_magn", + ] for d in self._cif.data.values(): for k in d.data: for prefix in prefixes: @@ -397,7 +407,11 @@ def _sanitize_data(self, data): """ # check for implicit hydrogens, warn if any present if "_atom_site_attached_hydrogens" in data.data: - attached_hydrogens = [str2float(x) for x in data.data["_atom_site_attached_hydrogens"] if str2float(x) != 0] + attached_hydrogens = [ + str2float(x) + for x in data.data["_atom_site_attached_hydrogens"] + if str2float(x) != 0 + ] if len(attached_hydrogens) > 0: self.warnings.append( "Structure has implicit hydrogens defined, parsed structure unlikely to be " @@ -432,7 +446,9 @@ def _sanitize_data(self, data): # Below, we split the strings on ' + ' to # check if the length (or number of elements) in the label and # symbol are equal. - if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len(el_row.split(" + ")): + if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len( + el_row.split(" + ") + ): # Dictionary to hold extracted elements and occupancies els_occu = {} @@ -442,14 +458,23 @@ def _sanitize_data(self, data): symbol_str_lst = symbol_str.split(" + ") for elocc_idx, sym in enumerate(symbol_str_lst): # Remove any bracketed items in the string - symbol_str_lst[elocc_idx] = re.sub(r"\([0-9]*\)", "", sym.strip()) + symbol_str_lst[elocc_idx] = re.sub( + r"\([0-9]*\)", "", sym.strip() + ) # Extract element name and its occupancy from the # string, and store it as a # key-value pair in "els_occ". els_occu[ - str(re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1]).replace("", "") - ] = float("0" + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[1]) + str( + re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1] + ).replace("", "") + ] = float( + "0" + + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[ + 1 + ] + ) x = str2float(data["_atom_site_fract_x"][idx]) y = str2float(data["_atom_site_fract_y"][idx]) @@ -457,7 +482,9 @@ def _sanitize_data(self, data): for et, occu in els_occu.items(): # new atom site labels have 'fix' appended - new_atom_site_label.append(f"{et}_fix{len(new_atom_site_label)}") + new_atom_site_label.append( + f"{et}_fix{len(new_atom_site_label)}" + ) new_atom_site_type_symbol.append(et) new_atom_site_occupancy.append(str(occu)) new_fract_x.append(str(x)) @@ -593,7 +620,9 @@ def _unique_coords( ) else: magmom = Magmom(tmp_magmom) - if not in_coord_list_pbc(coords_out, coord, atol=self._site_tolerance): + if not in_coord_list_pbc( + coords_out, coord, atol=self._site_tolerance + ): coords_out.append(coord) magmoms_out.append(magmom) labels_out.append(labels.get(tmp_coord)) @@ -624,20 +653,32 @@ def get_lattice( """ try: return self.get_lattice_no_exception( - data=data, angle_strings=angle_strings, lattice_type=lattice_type, length_strings=length_strings + data=data, + angle_strings=angle_strings, + lattice_type=lattice_type, + length_strings=length_strings, ) except KeyError: # Missing Key search for cell setting - for lattice_label in ["_symmetry_cell_setting", "_space_group_crystal_system"]: + for lattice_label in [ + "_symmetry_cell_setting", + "_space_group_crystal_system", + ]: if data.data.get(lattice_label): lattice_type = data.data.get(lattice_label).lower() try: required_args = getargspec(getattr(Lattice, lattice_type)).args - lengths = (length for length in length_strings if length in required_args) + lengths = ( + length + for length in length_strings + if length in required_args + ) angles = (a for a in angle_strings if a in required_args) - return self.get_lattice(data, lengths, angles, lattice_type=lattice_type) + return self.get_lattice( + data, lengths, angles, lattice_type=lattice_type + ) except AttributeError as exc: self.warnings.append(str(exc)) warnings.warn(exc) @@ -648,7 +689,10 @@ def get_lattice( @staticmethod def get_lattice_no_exception( - data, length_strings=("a", "b", "c"), angle_strings=("alpha", "beta", "gamma"), lattice_type=None + data, + length_strings=("a", "b", "c"), + angle_strings=("alpha", "beta", "gamma"), + lattice_type=None, ): """ Take a dictionary of CIF data and returns a pymatgen Lattice object. @@ -728,7 +772,11 @@ def get_symops(self, data): try: cod_data = loadfn( - os.path.join(os.path.dirname(os.path.dirname(__file__)), "symmetry", "symm_ops.json") + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "symmetry", + "symm_ops.json", + ) ) for d in cod_data: if sg == re.sub(r"\s+", "", d["hermann_mauguin"]): @@ -774,8 +822,12 @@ def get_magsymops(self, data): (which changes magnetic moments on sites) needs to be returned. """ mag_symm_ops = [] - bns_name = data.data.get("_space_group_magn.name_BNS") # get BNS label for MagneticSpaceGroup() - bns_num = data.data.get("_space_group_magn.number_BNS") # get BNS number for MagneticSpaceGroup() + bns_name = data.data.get( + "_space_group_magn.name_BNS" + ) # get BNS label for MagneticSpaceGroup() + bns_num = data.data.get( + "_space_group_magn.number_BNS" + ) # get BNS number for MagneticSpaceGroup() # check to see if magCIF file explicitly contains magnetic symmetry operations if xyzt := data.data.get("_space_group_symop_magn_operation.xyz"): @@ -793,9 +845,13 @@ def get_magsymops(self, data): for op in mag_symm_ops: for centering_op in centering_symops: new_translation = [ - i - np.floor(i) for i in op.translation_vector + centering_op.translation_vector + i - np.floor(i) + for i in op.translation_vector + + centering_op.translation_vector ] - new_time_reversal = op.time_reversal * centering_op.time_reversal + new_time_reversal = ( + op.time_reversal * centering_op.time_reversal + ) all_ops.append( MagSymmOp.from_rotation_and_translation_and_time_reversal( rotation_matrix=op.rotation_matrix, @@ -833,13 +889,17 @@ def parse_oxi_states(data): """Parse oxidation states from data dictionary.""" try: oxi_states = { - data["_atom_type_symbol"][i]: str2float(data["_atom_type_oxidation_number"][i]) + data["_atom_type_symbol"][i]: str2float( + data["_atom_type_oxidation_number"][i] + ) for i in range(len(data["_atom_type_symbol"])) } # attempt to strip oxidation state from _atom_type_symbol # in case the label does not contain an oxidation state for i, symbol in enumerate(data["_atom_type_symbol"]): - oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float(data["_atom_type_oxidation_number"][i]) + oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float( + data["_atom_type_oxidation_number"][i] + ) except (ValueError, KeyError): oxi_states = None @@ -911,7 +971,11 @@ def _parse_symbol(self, sym): return parsed_sym def _get_structure( - self, data: dict[str, Any], primitive: bool, symmetrized: bool, check_occu: bool = False + self, + data: dict[str, Any], + primitive: bool, + symmetrized: bool, + check_occu: bool = False, ) -> Structure | None: """Generate structure from part of the cif.""" @@ -924,7 +988,9 @@ def get_num_implicit_hydrogens(sym): # if magCIF, get magnetic symmetry moments and magmoms # else standard CIF, and use empty magmom dict if self.feature_flags["magcif_incommensurate"]: - raise NotImplementedError("Incommensurate structures not currently supported.") + raise NotImplementedError( + "Incommensurate structures not currently supported." + ) if self.feature_flags["magcif"]: self.symmetry_operations = self.get_magsymops(data) magmoms = self.parse_magmoms(data, lattice=lattice) @@ -943,7 +1009,9 @@ def get_matching_coord(coord): coords = np.array(keys) for op in self.symmetry_operations: frac_coord = op.operate(coord) - indices = find_in_coord_list_pbc(coords, frac_coord, atol=self._site_tolerance) + indices = find_in_coord_list_pbc( + coords, frac_coord, atol=self._site_tolerance + ) if len(indices) > 0: return keys[indices[0]] return False @@ -1005,7 +1073,9 @@ def get_matching_coord(coord): coord_to_magmoms[match] = None labels[match] = label sum_occu = [ - sum(c.values()) for c in coord_to_species.values() if set(c.elements) != {Element("O"), Element("H")} + sum(c.values()) + for c in coord_to_species.values() + if set(c.elements) != {Element("O"), Element("H")} ] if any(occu > 1 for occu in sum_occu): msg = ( @@ -1032,7 +1102,9 @@ def get_matching_coord(coord): # property, but this introduces ambiguities for end user # (such as unintended use of `spin` and Species will have # fictitious oxidation state). - raise NotImplementedError("Disordered magnetic structures not currently supported.") + raise NotImplementedError( + "Disordered magnetic structures not currently supported." + ) if coord_to_species.items(): for idx, (comp, group) in enumerate( @@ -1049,7 +1121,9 @@ def get_matching_coord(coord): tmp_coords, magmoms=tmp_magmom, labels=labels, lattice=lattice ) else: - coords, magmoms, new_labels = self._unique_coords(tmp_coords, labels=labels) + coords, magmoms, new_labels = self._unique_coords( + tmp_coords, labels=labels + ) if set(comp.elements) == {Element("O"), Element("H")}: # O with implicit hydrogens @@ -1077,13 +1151,19 @@ def get_matching_coord(coord): all_labels.extend(new_labels) # rescale occupancies if necessary - all_species_noedit = all_species.copy() # save copy before scaling in case of check_occu=False, used below + all_species_noedit = ( + all_species.copy() + ) # save copy before scaling in case of check_occu=False, used below for idx, species in enumerate(all_species): total_occu = sum(species.values()) if 1 < total_occu <= self._occupancy_tolerance: all_species[idx] = species / total_occu - if all_species and len(all_species) == len(all_coords) and len(all_species) == len(all_magmoms): + if ( + all_species + and len(all_species) == len(all_coords) + and len(all_species) == len(all_magmoms) + ): site_properties = {} if any(all_hydrogens): assert len(all_hydrogens) == len(all_coords) @@ -1100,7 +1180,13 @@ def get_matching_coord(coord): else: all_labels = None # type: ignore - struct = Structure(lattice, all_species, all_coords, site_properties=site_properties, labels=all_labels) + struct = Structure( + lattice, + all_species, + all_coords, + site_properties=site_properties, + labels=all_labels, + ) if symmetrized: # Wyckoff labels not currently parsed, note that not all CIFs will contain Wyckoff labels @@ -1116,7 +1202,11 @@ def get_matching_coord(coord): if not check_occu: for idx in range(len(struct)): struct[idx] = PeriodicSite( - all_species_noedit[idx], all_coords[idx], lattice, properties=site_properties, skip_checks=True + all_species_noedit[idx], + all_coords[idx], + lattice, + properties=site_properties, + skip_checks=True, ) if symmetrized or not check_occu: @@ -1181,8 +1271,12 @@ def parse_structures( Returns: list[Structure]: All structures in CIF file. """ - if os.getenv("CI") and datetime.now() > datetime(2024, 3, 1): # March 2024 seems long enough # pragma: no cover - raise RuntimeError("remove the change of default primitive=True to False made on 2023-10-24") + if os.getenv("CI") and datetime.now() > datetime( + 2024, 3, 1 + ): # March 2024 seems long enough # pragma: no cover + raise RuntimeError( + "remove the change of default primitive=True to False made on 2023-10-24" + ) if primitive is None: primitive = False warnings.warn( @@ -1191,8 +1285,12 @@ def parse_structures( "in the CIF file as is. If you want the primitive cell, please set primitive=True explicitly.", UserWarning, ) - if not check_occu: # added in https://github.com/materialsproject/pymatgen/pull/2836 - warnings.warn("Structures with unphysical site occupancies are not compatible with many pymatgen features.") + if ( + not check_occu + ): # added in https://github.com/materialsproject/pymatgen/pull/2836 + warnings.warn( + "Structures with unphysical site occupancies are not compatible with many pymatgen features." + ) if primitive and symmetrized: raise ValueError( "Using both 'primitive' and 'symmetrized' arguments is not currently supported " @@ -1202,7 +1300,9 @@ def parse_structures( structures = [] for idx, dct in enumerate(self._cif.data.values()): try: - struct = self._get_structure(dct, primitive, symmetrized, check_occu=check_occu) + struct = self._get_structure( + dct, primitive, symmetrized, check_occu=check_occu + ) if struct: structures.append(struct) except (KeyError, ValueError) as exc: @@ -1218,7 +1318,9 @@ def parse_structures( # if on_error == "raise" we don't get to here so no need to check if self.warnings and on_error == "warn": - warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings)) + warnings.warn( + "Issues encountered while parsing CIF: " + "\n".join(self.warnings) + ) if len(structures) == 0: raise ValueError("Invalid CIF file with no structures!") @@ -1276,7 +1378,10 @@ def get_bibtex_string(self): # convert to bibtex author format ('and' delimited) if "author" in bibtex_entry: # separate out semicolon authors - if isinstance(bibtex_entry["author"], str) and ";" in bibtex_entry["author"]: + if ( + isinstance(bibtex_entry["author"], str) + and ";" in bibtex_entry["author"] + ): bibtex_entry["author"] = bibtex_entry["author"].split(";") if isinstance(bibtex_entry["author"], list): @@ -1284,8 +1389,14 @@ def get_bibtex_string(self): # convert to bibtex page range format, use empty string if not specified if ("page_first" in bibtex_entry) or ("page_last" in bibtex_entry): - bibtex_entry["pages"] = bibtex_entry.get("page_first", "") + "--" + bibtex_entry.get("page_last", "") - bibtex_entry.pop("page_first", None) # and remove page_first, page_list if present + bibtex_entry["pages"] = ( + bibtex_entry.get("page_first", "") + + "--" + + bibtex_entry.get("page_last", "") + ) + bibtex_entry.pop( + "page_first", None + ) # and remove page_first, page_list if present bibtex_entry.pop("page_last", None) # cite keys are given as cif-reference-idx in order they are found @@ -1319,6 +1430,7 @@ def __init__( significant_figures=8, angle_tolerance=5.0, refine_struct=True, + write_site_properties=False, ): """ Args: @@ -1335,9 +1447,13 @@ def __init__( is not None. refine_struct: Used only if symprec is not None. If True, get_refined_structure is invoked to convert input structure from primitive to conventional. + write_site_properties (bool): Whether to write the `Structure.site_properties` + to the CIF as _atom_site_{property name}. Defaults to False. """ if write_magmoms and symprec: - warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.") + warnings.warn( + "Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection." + ) symprec = None format_str = f"{{:.{significant_figures}f}}" @@ -1346,8 +1462,13 @@ def __init__( loops = [] spacegroup = ("P 1", 1) if symprec is not None: - spg_analyzer = SpacegroupAnalyzer(struct, symprec, angle_tolerance=angle_tolerance) - spacegroup = (spg_analyzer.get_space_group_symbol(), spg_analyzer.get_space_group_number()) + spg_analyzer = SpacegroupAnalyzer( + struct, symprec, angle_tolerance=angle_tolerance + ) + spacegroup = ( + spg_analyzer.get_space_group_symbol(), + spg_analyzer.get_space_group_number(), + ) if refine_struct: # Needs the refined structure when using symprec. This converts @@ -1359,15 +1480,19 @@ def __init__( no_oxi_comp = comp.element_composition block["_symmetry_space_group_name_H-M"] = spacegroup[0] for cell_attr in ["a", "b", "c"]: - block["_cell_length_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) + block["_cell_length_" + cell_attr] = format_str.format( + getattr(lattice, cell_attr) + ) for cell_attr in ["alpha", "beta", "gamma"]: - block["_cell_angle_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) + block["_cell_angle_" + cell_attr] = format_str.format( + getattr(lattice, cell_attr) + ) block["_symmetry_Int_Tables_number"] = spacegroup[1] block["_chemical_formula_structural"] = no_oxi_comp.reduced_formula block["_chemical_formula_sum"] = no_oxi_comp.formula block["_cell_volume"] = format_str.format(lattice.volume) - _reduced_comp, fu = no_oxi_comp.get_reduced_composition_and_factor() + _, fu = no_oxi_comp.get_reduced_composition_and_factor() block["_cell_formula_units_Z"] = str(int(fu)) if symprec is None: @@ -1379,16 +1504,22 @@ def __init__( symm_ops = [] for op in spg_analyzer.get_symmetry_operations(): v = op.translation_vector - symm_ops.append(SymmOp.from_rotation_and_translation(op.rotation_matrix, v)) + symm_ops.append( + SymmOp.from_rotation_and_translation(op.rotation_matrix, v) + ) ops = [op.as_xyz_string() for op in symm_ops] - block["_symmetry_equiv_pos_site_id"] = [f"{i}" for i in range(1, len(ops) + 1)] + block["_symmetry_equiv_pos_site_id"] = [ + f"{i}" for i in range(1, len(ops) + 1) + ] block["_symmetry_equiv_pos_as_xyz"] = ops loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"]) try: - symbol_to_oxinum = {str(el): float(el.oxi_state) for el in sorted(comp.elements)} + symbol_to_oxinum = { + str(el): float(el.oxi_state) for el in sorted(comp.elements) + } block["_atom_type_symbol"] = list(symbol_to_oxinum) block["_atom_type_oxidation_number"] = symbol_to_oxinum.values() loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"]) @@ -1406,6 +1537,7 @@ def __init__( atom_site_moment_crystalaxis_x = [] atom_site_moment_crystalaxis_y = [] atom_site_moment_crystalaxis_z = [] + atom_site_properties = {k: [] for k in struct.site_properties} count = 0 if symprec is None: for site in struct: @@ -1424,25 +1556,48 @@ def __init__( mag = sp.spin else: # Use site label if available for regular sites - site_label = site.label if site.label != site.species_string else site_label + site_label = ( + site.label + if site.label != site.species_string + else site_label + ) mag = 0 atom_site_label.append(site_label) magmom = Magmom(mag) if write_magmoms and abs(magmom) > 0: - moment = Magmom.get_moment_relative_to_crystal_axes(magmom, lattice) + moment = Magmom.get_moment_relative_to_crystal_axes( + magmom, lattice + ) atom_site_moment_label.append(f"{sp.symbol}{count}") - atom_site_moment_crystalaxis_x.append(format_str.format(moment[0])) - atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) - atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) + atom_site_moment_crystalaxis_x.append( + format_str.format(moment[0]) + ) + atom_site_moment_crystalaxis_y.append( + format_str.format(moment[1]) + ) + atom_site_moment_crystalaxis_z.append( + format_str.format(moment[2]) + ) + + if write_site_properties: + for ( + property_key, + property_vals, + ) in struct.site_properties.items(): + atom_site_properties[property_key].append( + property_vals[count] + ) count += 1 else: # The following just presents a deterministic ordering. unique_sites = [ ( - sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[0], + sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[ + 0 + ], len(sites), ) for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites @@ -1463,11 +1618,24 @@ def __init__( atom_site_fract_x.append(format_str.format(site.a)) atom_site_fract_y.append(format_str.format(site.b)) atom_site_fract_z.append(format_str.format(site.c)) - site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" + site_label = ( + site.label + if site.label != site.species_string + else f"{sp.symbol}{count}" + ) atom_site_label.append(site_label) atom_site_occupancy.append(str(occu)) count += 1 + if write_site_properties: + for ( + property_key, + property_vals, + ) in struct.site_properties.items(): + atom_site_properties[property_key].append( + property_vals[count] + ) + block["_atom_site_type_symbol"] = atom_site_type_symbol block["_atom_site_label"] = atom_site_label block["_atom_site_symmetry_multiplicity"] = atom_site_symmetry_multiplicity @@ -1475,17 +1643,21 @@ def __init__( block["_atom_site_fract_y"] = atom_site_fract_y block["_atom_site_fract_z"] = atom_site_fract_z block["_atom_site_occupancy"] = atom_site_occupancy - loops.append( - [ - "_atom_site_type_symbol", - "_atom_site_label", - "_atom_site_symmetry_multiplicity", - "_atom_site_fract_x", - "_atom_site_fract_y", - "_atom_site_fract_z", - "_atom_site_occupancy", - ] - ) + loop_labels = [ + "_atom_site_type_symbol", + "_atom_site_label", + "_atom_site_symmetry_multiplicity", + "_atom_site_fract_x", + "_atom_site_fract_y", + "_atom_site_fract_z", + "_atom_site_occupancy", + ] + if write_site_properties: + for property_key, property_vals in atom_site_properties.items(): + block[f"_atom_site_{property_key}"] = property_vals + loop_labels.append(f"_atom_site_{property_key}") + loops.append(loop_labels) + if write_magmoms: block["_atom_site_moment_label"] = atom_site_moment_label block["_atom_site_moment_crystalaxis_x"] = atom_site_moment_crystalaxis_x @@ -1512,7 +1684,9 @@ def __str__(self): """Returns the CIF as a string.""" return str(self._cf) - def write_file(self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w") -> None: + def write_file( + self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w" + ) -> None: """Write the CIF file.""" with zopen(filename, mode=mode) as file: file.write(str(self)) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 399c4e595d5..c10a9529d85 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -870,6 +870,21 @@ def test_cif_writer_write_file(self): assert len(read_structs) == 2 assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"] + def test_cif_writer_site_properties(self): + struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") + site_props = {"hello": [1.0] * len(struct1), "world": [2.0] * len(struct1)} + site_props["hello"][-1] = -1.0 + struct1.add_site_property("hello", site_props["hello"]) + struct1.add_site_property("world", site_props["world"]) + out_path = f"{self.tmp_path}/test.cif" + CifWriter(struct1, write_site_properties=True).write_file(out_path) + with open(out_path) as f: + lines = f.readlines() + cif_str = "".join(lines) + assert "_atom_site_occupancy\n _atom_site_hello\n _atom_site_world\n" in cif_str + assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0 2.0" in cif_str + assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0 2.0" in cif_str + class TestMagCif(PymatgenTest): def setUp(self): From bbd264d1a086d407e3fee6920332a9b0ef633035 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:13:42 +0000 Subject: [PATCH 02/25] pre-commit auto-fixes --- pymatgen/io/cif.py | 218 +++++++++++---------------------------------- 1 file changed, 50 insertions(+), 168 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 3583b8fa0d6..5d72c9800b7 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -121,11 +121,7 @@ def _format_field(self, v): # add quotes if necessary if v == "": return '""' - if ( - (" " in v or v[0] == "_") - and not (v[0] == "'" and v[-1] == "'") - and not (v[0] == '"' and v[-1] == '"') - ): + if (" " in v or v[0] == "_") and not (v[0] == "'" and v[-1] == "'") and not (v[0] == '"' and v[-1] == '"'): q = '"' if "'" in v else "'" v = q + v + q return v @@ -256,9 +252,7 @@ def from_str(cls, string) -> CifFile: """ dct = {} - for block_str in re.split( - r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL - )[1:]: + for block_str in re.split(r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL)[1:]: # Skip over Cif block that contains powder diffraction data. # Some elements in this block were missing from CIF files in # Springer materials/Pauling file DBs. @@ -407,11 +401,7 @@ def _sanitize_data(self, data): """ # check for implicit hydrogens, warn if any present if "_atom_site_attached_hydrogens" in data.data: - attached_hydrogens = [ - str2float(x) - for x in data.data["_atom_site_attached_hydrogens"] - if str2float(x) != 0 - ] + attached_hydrogens = [str2float(x) for x in data.data["_atom_site_attached_hydrogens"] if str2float(x) != 0] if len(attached_hydrogens) > 0: self.warnings.append( "Structure has implicit hydrogens defined, parsed structure unlikely to be " @@ -446,9 +436,7 @@ def _sanitize_data(self, data): # Below, we split the strings on ' + ' to # check if the length (or number of elements) in the label and # symbol are equal. - if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len( - el_row.split(" + ") - ): + if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len(el_row.split(" + ")): # Dictionary to hold extracted elements and occupancies els_occu = {} @@ -458,23 +446,14 @@ def _sanitize_data(self, data): symbol_str_lst = symbol_str.split(" + ") for elocc_idx, sym in enumerate(symbol_str_lst): # Remove any bracketed items in the string - symbol_str_lst[elocc_idx] = re.sub( - r"\([0-9]*\)", "", sym.strip() - ) + symbol_str_lst[elocc_idx] = re.sub(r"\([0-9]*\)", "", sym.strip()) # Extract element name and its occupancy from the # string, and store it as a # key-value pair in "els_occ". els_occu[ - str( - re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1] - ).replace("", "") - ] = float( - "0" - + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[ - 1 - ] - ) + str(re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1]).replace("", "") + ] = float("0" + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[1]) x = str2float(data["_atom_site_fract_x"][idx]) y = str2float(data["_atom_site_fract_y"][idx]) @@ -482,9 +461,7 @@ def _sanitize_data(self, data): for et, occu in els_occu.items(): # new atom site labels have 'fix' appended - new_atom_site_label.append( - f"{et}_fix{len(new_atom_site_label)}" - ) + new_atom_site_label.append(f"{et}_fix{len(new_atom_site_label)}") new_atom_site_type_symbol.append(et) new_atom_site_occupancy.append(str(occu)) new_fract_x.append(str(x)) @@ -620,9 +597,7 @@ def _unique_coords( ) else: magmom = Magmom(tmp_magmom) - if not in_coord_list_pbc( - coords_out, coord, atol=self._site_tolerance - ): + if not in_coord_list_pbc(coords_out, coord, atol=self._site_tolerance): coords_out.append(coord) magmoms_out.append(magmom) labels_out.append(labels.get(tmp_coord)) @@ -670,15 +645,9 @@ def get_lattice( try: required_args = getargspec(getattr(Lattice, lattice_type)).args - lengths = ( - length - for length in length_strings - if length in required_args - ) + lengths = (length for length in length_strings if length in required_args) angles = (a for a in angle_strings if a in required_args) - return self.get_lattice( - data, lengths, angles, lattice_type=lattice_type - ) + return self.get_lattice(data, lengths, angles, lattice_type=lattice_type) except AttributeError as exc: self.warnings.append(str(exc)) warnings.warn(exc) @@ -822,12 +791,8 @@ def get_magsymops(self, data): (which changes magnetic moments on sites) needs to be returned. """ mag_symm_ops = [] - bns_name = data.data.get( - "_space_group_magn.name_BNS" - ) # get BNS label for MagneticSpaceGroup() - bns_num = data.data.get( - "_space_group_magn.number_BNS" - ) # get BNS number for MagneticSpaceGroup() + bns_name = data.data.get("_space_group_magn.name_BNS") # get BNS label for MagneticSpaceGroup() + bns_num = data.data.get("_space_group_magn.number_BNS") # get BNS number for MagneticSpaceGroup() # check to see if magCIF file explicitly contains magnetic symmetry operations if xyzt := data.data.get("_space_group_symop_magn_operation.xyz"): @@ -845,13 +810,9 @@ def get_magsymops(self, data): for op in mag_symm_ops: for centering_op in centering_symops: new_translation = [ - i - np.floor(i) - for i in op.translation_vector - + centering_op.translation_vector + i - np.floor(i) for i in op.translation_vector + centering_op.translation_vector ] - new_time_reversal = ( - op.time_reversal * centering_op.time_reversal - ) + new_time_reversal = op.time_reversal * centering_op.time_reversal all_ops.append( MagSymmOp.from_rotation_and_translation_and_time_reversal( rotation_matrix=op.rotation_matrix, @@ -889,17 +850,13 @@ def parse_oxi_states(data): """Parse oxidation states from data dictionary.""" try: oxi_states = { - data["_atom_type_symbol"][i]: str2float( - data["_atom_type_oxidation_number"][i] - ) + data["_atom_type_symbol"][i]: str2float(data["_atom_type_oxidation_number"][i]) for i in range(len(data["_atom_type_symbol"])) } # attempt to strip oxidation state from _atom_type_symbol # in case the label does not contain an oxidation state for i, symbol in enumerate(data["_atom_type_symbol"]): - oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float( - data["_atom_type_oxidation_number"][i] - ) + oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float(data["_atom_type_oxidation_number"][i]) except (ValueError, KeyError): oxi_states = None @@ -988,9 +945,7 @@ def get_num_implicit_hydrogens(sym): # if magCIF, get magnetic symmetry moments and magmoms # else standard CIF, and use empty magmom dict if self.feature_flags["magcif_incommensurate"]: - raise NotImplementedError( - "Incommensurate structures not currently supported." - ) + raise NotImplementedError("Incommensurate structures not currently supported.") if self.feature_flags["magcif"]: self.symmetry_operations = self.get_magsymops(data) magmoms = self.parse_magmoms(data, lattice=lattice) @@ -1009,9 +964,7 @@ def get_matching_coord(coord): coords = np.array(keys) for op in self.symmetry_operations: frac_coord = op.operate(coord) - indices = find_in_coord_list_pbc( - coords, frac_coord, atol=self._site_tolerance - ) + indices = find_in_coord_list_pbc(coords, frac_coord, atol=self._site_tolerance) if len(indices) > 0: return keys[indices[0]] return False @@ -1073,9 +1026,7 @@ def get_matching_coord(coord): coord_to_magmoms[match] = None labels[match] = label sum_occu = [ - sum(c.values()) - for c in coord_to_species.values() - if set(c.elements) != {Element("O"), Element("H")} + sum(c.values()) for c in coord_to_species.values() if set(c.elements) != {Element("O"), Element("H")} ] if any(occu > 1 for occu in sum_occu): msg = ( @@ -1102,9 +1053,7 @@ def get_matching_coord(coord): # property, but this introduces ambiguities for end user # (such as unintended use of `spin` and Species will have # fictitious oxidation state). - raise NotImplementedError( - "Disordered magnetic structures not currently supported." - ) + raise NotImplementedError("Disordered magnetic structures not currently supported.") if coord_to_species.items(): for idx, (comp, group) in enumerate( @@ -1121,9 +1070,7 @@ def get_matching_coord(coord): tmp_coords, magmoms=tmp_magmom, labels=labels, lattice=lattice ) else: - coords, magmoms, new_labels = self._unique_coords( - tmp_coords, labels=labels - ) + coords, magmoms, new_labels = self._unique_coords(tmp_coords, labels=labels) if set(comp.elements) == {Element("O"), Element("H")}: # O with implicit hydrogens @@ -1151,19 +1098,13 @@ def get_matching_coord(coord): all_labels.extend(new_labels) # rescale occupancies if necessary - all_species_noedit = ( - all_species.copy() - ) # save copy before scaling in case of check_occu=False, used below + all_species_noedit = all_species.copy() # save copy before scaling in case of check_occu=False, used below for idx, species in enumerate(all_species): total_occu = sum(species.values()) if 1 < total_occu <= self._occupancy_tolerance: all_species[idx] = species / total_occu - if ( - all_species - and len(all_species) == len(all_coords) - and len(all_species) == len(all_magmoms) - ): + if all_species and len(all_species) == len(all_coords) and len(all_species) == len(all_magmoms): site_properties = {} if any(all_hydrogens): assert len(all_hydrogens) == len(all_coords) @@ -1271,12 +1212,8 @@ def parse_structures( Returns: list[Structure]: All structures in CIF file. """ - if os.getenv("CI") and datetime.now() > datetime( - 2024, 3, 1 - ): # March 2024 seems long enough # pragma: no cover - raise RuntimeError( - "remove the change of default primitive=True to False made on 2023-10-24" - ) + if os.getenv("CI") and datetime.now() > datetime(2024, 3, 1): # March 2024 seems long enough # pragma: no cover + raise RuntimeError("remove the change of default primitive=True to False made on 2023-10-24") if primitive is None: primitive = False warnings.warn( @@ -1285,12 +1222,8 @@ def parse_structures( "in the CIF file as is. If you want the primitive cell, please set primitive=True explicitly.", UserWarning, ) - if ( - not check_occu - ): # added in https://github.com/materialsproject/pymatgen/pull/2836 - warnings.warn( - "Structures with unphysical site occupancies are not compatible with many pymatgen features." - ) + if not check_occu: # added in https://github.com/materialsproject/pymatgen/pull/2836 + warnings.warn("Structures with unphysical site occupancies are not compatible with many pymatgen features.") if primitive and symmetrized: raise ValueError( "Using both 'primitive' and 'symmetrized' arguments is not currently supported " @@ -1300,9 +1233,7 @@ def parse_structures( structures = [] for idx, dct in enumerate(self._cif.data.values()): try: - struct = self._get_structure( - dct, primitive, symmetrized, check_occu=check_occu - ) + struct = self._get_structure(dct, primitive, symmetrized, check_occu=check_occu) if struct: structures.append(struct) except (KeyError, ValueError) as exc: @@ -1318,9 +1249,7 @@ def parse_structures( # if on_error == "raise" we don't get to here so no need to check if self.warnings and on_error == "warn": - warnings.warn( - "Issues encountered while parsing CIF: " + "\n".join(self.warnings) - ) + warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings)) if len(structures) == 0: raise ValueError("Invalid CIF file with no structures!") @@ -1378,10 +1307,7 @@ def get_bibtex_string(self): # convert to bibtex author format ('and' delimited) if "author" in bibtex_entry: # separate out semicolon authors - if ( - isinstance(bibtex_entry["author"], str) - and ";" in bibtex_entry["author"] - ): + if isinstance(bibtex_entry["author"], str) and ";" in bibtex_entry["author"]: bibtex_entry["author"] = bibtex_entry["author"].split(";") if isinstance(bibtex_entry["author"], list): @@ -1389,14 +1315,8 @@ def get_bibtex_string(self): # convert to bibtex page range format, use empty string if not specified if ("page_first" in bibtex_entry) or ("page_last" in bibtex_entry): - bibtex_entry["pages"] = ( - bibtex_entry.get("page_first", "") - + "--" - + bibtex_entry.get("page_last", "") - ) - bibtex_entry.pop( - "page_first", None - ) # and remove page_first, page_list if present + bibtex_entry["pages"] = bibtex_entry.get("page_first", "") + "--" + bibtex_entry.get("page_last", "") + bibtex_entry.pop("page_first", None) # and remove page_first, page_list if present bibtex_entry.pop("page_last", None) # cite keys are given as cif-reference-idx in order they are found @@ -1451,9 +1371,7 @@ def __init__( to the CIF as _atom_site_{property name}. Defaults to False. """ if write_magmoms and symprec: - warnings.warn( - "Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection." - ) + warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.") symprec = None format_str = f"{{:.{significant_figures}f}}" @@ -1462,9 +1380,7 @@ def __init__( loops = [] spacegroup = ("P 1", 1) if symprec is not None: - spg_analyzer = SpacegroupAnalyzer( - struct, symprec, angle_tolerance=angle_tolerance - ) + spg_analyzer = SpacegroupAnalyzer(struct, symprec, angle_tolerance=angle_tolerance) spacegroup = ( spg_analyzer.get_space_group_symbol(), spg_analyzer.get_space_group_number(), @@ -1480,13 +1396,9 @@ def __init__( no_oxi_comp = comp.element_composition block["_symmetry_space_group_name_H-M"] = spacegroup[0] for cell_attr in ["a", "b", "c"]: - block["_cell_length_" + cell_attr] = format_str.format( - getattr(lattice, cell_attr) - ) + block["_cell_length_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) for cell_attr in ["alpha", "beta", "gamma"]: - block["_cell_angle_" + cell_attr] = format_str.format( - getattr(lattice, cell_attr) - ) + block["_cell_angle_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) block["_symmetry_Int_Tables_number"] = spacegroup[1] block["_chemical_formula_structural"] = no_oxi_comp.reduced_formula block["_chemical_formula_sum"] = no_oxi_comp.formula @@ -1504,22 +1416,16 @@ def __init__( symm_ops = [] for op in spg_analyzer.get_symmetry_operations(): v = op.translation_vector - symm_ops.append( - SymmOp.from_rotation_and_translation(op.rotation_matrix, v) - ) + symm_ops.append(SymmOp.from_rotation_and_translation(op.rotation_matrix, v)) ops = [op.as_xyz_string() for op in symm_ops] - block["_symmetry_equiv_pos_site_id"] = [ - f"{i}" for i in range(1, len(ops) + 1) - ] + block["_symmetry_equiv_pos_site_id"] = [f"{i}" for i in range(1, len(ops) + 1)] block["_symmetry_equiv_pos_as_xyz"] = ops loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"]) try: - symbol_to_oxinum = { - str(el): float(el.oxi_state) for el in sorted(comp.elements) - } + symbol_to_oxinum = {str(el): float(el.oxi_state) for el in sorted(comp.elements)} block["_atom_type_symbol"] = list(symbol_to_oxinum) block["_atom_type_oxidation_number"] = symbol_to_oxinum.values() loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"]) @@ -1556,48 +1462,32 @@ def __init__( mag = sp.spin else: # Use site label if available for regular sites - site_label = ( - site.label - if site.label != site.species_string - else site_label - ) + site_label = site.label if site.label != site.species_string else site_label mag = 0 atom_site_label.append(site_label) magmom = Magmom(mag) if write_magmoms and abs(magmom) > 0: - moment = Magmom.get_moment_relative_to_crystal_axes( - magmom, lattice - ) + moment = Magmom.get_moment_relative_to_crystal_axes(magmom, lattice) atom_site_moment_label.append(f"{sp.symbol}{count}") - atom_site_moment_crystalaxis_x.append( - format_str.format(moment[0]) - ) - atom_site_moment_crystalaxis_y.append( - format_str.format(moment[1]) - ) - atom_site_moment_crystalaxis_z.append( - format_str.format(moment[2]) - ) + atom_site_moment_crystalaxis_x.append(format_str.format(moment[0])) + atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) + atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) if write_site_properties: for ( property_key, property_vals, ) in struct.site_properties.items(): - atom_site_properties[property_key].append( - property_vals[count] - ) + atom_site_properties[property_key].append(property_vals[count]) count += 1 else: # The following just presents a deterministic ordering. unique_sites = [ ( - sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[ - 0 - ], + sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[0], len(sites), ) for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites @@ -1618,11 +1508,7 @@ def __init__( atom_site_fract_x.append(format_str.format(site.a)) atom_site_fract_y.append(format_str.format(site.b)) atom_site_fract_z.append(format_str.format(site.c)) - site_label = ( - site.label - if site.label != site.species_string - else f"{sp.symbol}{count}" - ) + site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" atom_site_label.append(site_label) atom_site_occupancy.append(str(occu)) count += 1 @@ -1632,9 +1518,7 @@ def __init__( property_key, property_vals, ) in struct.site_properties.items(): - atom_site_properties[property_key].append( - property_vals[count] - ) + atom_site_properties[property_key].append(property_vals[count]) block["_atom_site_type_symbol"] = atom_site_type_symbol block["_atom_site_label"] = atom_site_label @@ -1684,9 +1568,7 @@ def __str__(self): """Returns the CIF as a string.""" return str(self._cf) - def write_file( - self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w" - ) -> None: + def write_file(self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w") -> None: """Write the CIF file.""" with zopen(filename, mode=mode) as file: file.write(str(self)) From a0629896bf0f180ebd187350300dc39111ee3145 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:19:02 -0800 Subject: [PATCH 03/25] fix --- pymatgen/io/cif.py | 235 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 183 insertions(+), 52 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 5d72c9800b7..1d10b202507 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -20,7 +20,16 @@ from monty.io import zopen from monty.serialization import loadfn -from pymatgen.core import Composition, DummySpecies, Element, Lattice, PeriodicSite, Species, Structure, get_el_sp +from pymatgen.core import ( + Composition, + DummySpecies, + Element, + Lattice, + PeriodicSite, + Species, + Structure, + get_el_sp, +) from pymatgen.core.operations import MagSymmOp, SymmOp from pymatgen.electronic_structure.core import Magmom from pymatgen.symmetry.analyzer import SpacegroupAnalyzer, SpacegroupOperations @@ -68,7 +77,11 @@ def __init__(self, data, loops, header): def __eq__(self, other: object) -> bool: if not isinstance(other, CifBlock): return NotImplemented - return self.loops == other.loops and self.data == other.data and self.header == other.header + return ( + self.loops == other.loops + and self.data == other.data + and self.header == other.header + ) def __getitem__(self, key): return self.data[key] @@ -121,7 +134,11 @@ def _format_field(self, v): # add quotes if necessary if v == "": return '""' - if (" " in v or v[0] == "_") and not (v[0] == "'" and v[-1] == "'") and not (v[0] == '"' and v[-1] == '"'): + if ( + (" " in v or v[0] == "_") + and not (v[0] == "'" and v[-1] == "'") + and not (v[0] == '"' and v[-1] == '"') + ): q = '"' if "'" in v else "'" v = q + v + q return v @@ -252,7 +269,9 @@ def from_str(cls, string) -> CifFile: """ dct = {} - for block_str in re.split(r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL)[1:]: + for block_str in re.split( + r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL + )[1:]: # Skip over Cif block that contains powder diffraction data. # Some elements in this block were missing from CIF files in # Springer materials/Pauling file DBs. @@ -401,7 +420,11 @@ def _sanitize_data(self, data): """ # check for implicit hydrogens, warn if any present if "_atom_site_attached_hydrogens" in data.data: - attached_hydrogens = [str2float(x) for x in data.data["_atom_site_attached_hydrogens"] if str2float(x) != 0] + attached_hydrogens = [ + str2float(x) + for x in data.data["_atom_site_attached_hydrogens"] + if str2float(x) != 0 + ] if len(attached_hydrogens) > 0: self.warnings.append( "Structure has implicit hydrogens defined, parsed structure unlikely to be " @@ -436,7 +459,9 @@ def _sanitize_data(self, data): # Below, we split the strings on ' + ' to # check if the length (or number of elements) in the label and # symbol are equal. - if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len(el_row.split(" + ")): + if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len( + el_row.split(" + ") + ): # Dictionary to hold extracted elements and occupancies els_occu = {} @@ -446,14 +471,23 @@ def _sanitize_data(self, data): symbol_str_lst = symbol_str.split(" + ") for elocc_idx, sym in enumerate(symbol_str_lst): # Remove any bracketed items in the string - symbol_str_lst[elocc_idx] = re.sub(r"\([0-9]*\)", "", sym.strip()) + symbol_str_lst[elocc_idx] = re.sub( + r"\([0-9]*\)", "", sym.strip() + ) # Extract element name and its occupancy from the # string, and store it as a # key-value pair in "els_occ". els_occu[ - str(re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1]).replace("", "") - ] = float("0" + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[1]) + str( + re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1] + ).replace("", "") + ] = float( + "0" + + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[ + 1 + ] + ) x = str2float(data["_atom_site_fract_x"][idx]) y = str2float(data["_atom_site_fract_y"][idx]) @@ -461,7 +495,9 @@ def _sanitize_data(self, data): for et, occu in els_occu.items(): # new atom site labels have 'fix' appended - new_atom_site_label.append(f"{et}_fix{len(new_atom_site_label)}") + new_atom_site_label.append( + f"{et}_fix{len(new_atom_site_label)}" + ) new_atom_site_type_symbol.append(et) new_atom_site_occupancy.append(str(occu)) new_fract_x.append(str(x)) @@ -597,7 +633,9 @@ def _unique_coords( ) else: magmom = Magmom(tmp_magmom) - if not in_coord_list_pbc(coords_out, coord, atol=self._site_tolerance): + if not in_coord_list_pbc( + coords_out, coord, atol=self._site_tolerance + ): coords_out.append(coord) magmoms_out.append(magmom) labels_out.append(labels.get(tmp_coord)) @@ -645,9 +683,15 @@ def get_lattice( try: required_args = getargspec(getattr(Lattice, lattice_type)).args - lengths = (length for length in length_strings if length in required_args) + lengths = ( + length + for length in length_strings + if length in required_args + ) angles = (a for a in angle_strings if a in required_args) - return self.get_lattice(data, lengths, angles, lattice_type=lattice_type) + return self.get_lattice( + data, lengths, angles, lattice_type=lattice_type + ) except AttributeError as exc: self.warnings.append(str(exc)) warnings.warn(exc) @@ -791,8 +835,12 @@ def get_magsymops(self, data): (which changes magnetic moments on sites) needs to be returned. """ mag_symm_ops = [] - bns_name = data.data.get("_space_group_magn.name_BNS") # get BNS label for MagneticSpaceGroup() - bns_num = data.data.get("_space_group_magn.number_BNS") # get BNS number for MagneticSpaceGroup() + bns_name = data.data.get( + "_space_group_magn.name_BNS" + ) # get BNS label for MagneticSpaceGroup() + bns_num = data.data.get( + "_space_group_magn.number_BNS" + ) # get BNS number for MagneticSpaceGroup() # check to see if magCIF file explicitly contains magnetic symmetry operations if xyzt := data.data.get("_space_group_symop_magn_operation.xyz"): @@ -810,9 +858,13 @@ def get_magsymops(self, data): for op in mag_symm_ops: for centering_op in centering_symops: new_translation = [ - i - np.floor(i) for i in op.translation_vector + centering_op.translation_vector + i - np.floor(i) + for i in op.translation_vector + + centering_op.translation_vector ] - new_time_reversal = op.time_reversal * centering_op.time_reversal + new_time_reversal = ( + op.time_reversal * centering_op.time_reversal + ) all_ops.append( MagSymmOp.from_rotation_and_translation_and_time_reversal( rotation_matrix=op.rotation_matrix, @@ -850,13 +902,17 @@ def parse_oxi_states(data): """Parse oxidation states from data dictionary.""" try: oxi_states = { - data["_atom_type_symbol"][i]: str2float(data["_atom_type_oxidation_number"][i]) + data["_atom_type_symbol"][i]: str2float( + data["_atom_type_oxidation_number"][i] + ) for i in range(len(data["_atom_type_symbol"])) } # attempt to strip oxidation state from _atom_type_symbol # in case the label does not contain an oxidation state for i, symbol in enumerate(data["_atom_type_symbol"]): - oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float(data["_atom_type_oxidation_number"][i]) + oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float( + data["_atom_type_oxidation_number"][i] + ) except (ValueError, KeyError): oxi_states = None @@ -945,7 +1001,9 @@ def get_num_implicit_hydrogens(sym): # if magCIF, get magnetic symmetry moments and magmoms # else standard CIF, and use empty magmom dict if self.feature_flags["magcif_incommensurate"]: - raise NotImplementedError("Incommensurate structures not currently supported.") + raise NotImplementedError( + "Incommensurate structures not currently supported." + ) if self.feature_flags["magcif"]: self.symmetry_operations = self.get_magsymops(data) magmoms = self.parse_magmoms(data, lattice=lattice) @@ -964,7 +1022,9 @@ def get_matching_coord(coord): coords = np.array(keys) for op in self.symmetry_operations: frac_coord = op.operate(coord) - indices = find_in_coord_list_pbc(coords, frac_coord, atol=self._site_tolerance) + indices = find_in_coord_list_pbc( + coords, frac_coord, atol=self._site_tolerance + ) if len(indices) > 0: return keys[indices[0]] return False @@ -1026,7 +1086,9 @@ def get_matching_coord(coord): coord_to_magmoms[match] = None labels[match] = label sum_occu = [ - sum(c.values()) for c in coord_to_species.values() if set(c.elements) != {Element("O"), Element("H")} + sum(c.values()) + for c in coord_to_species.values() + if set(c.elements) != {Element("O"), Element("H")} ] if any(occu > 1 for occu in sum_occu): msg = ( @@ -1053,7 +1115,9 @@ def get_matching_coord(coord): # property, but this introduces ambiguities for end user # (such as unintended use of `spin` and Species will have # fictitious oxidation state). - raise NotImplementedError("Disordered magnetic structures not currently supported.") + raise NotImplementedError( + "Disordered magnetic structures not currently supported." + ) if coord_to_species.items(): for idx, (comp, group) in enumerate( @@ -1070,7 +1134,9 @@ def get_matching_coord(coord): tmp_coords, magmoms=tmp_magmom, labels=labels, lattice=lattice ) else: - coords, magmoms, new_labels = self._unique_coords(tmp_coords, labels=labels) + coords, magmoms, new_labels = self._unique_coords( + tmp_coords, labels=labels + ) if set(comp.elements) == {Element("O"), Element("H")}: # O with implicit hydrogens @@ -1098,13 +1164,19 @@ def get_matching_coord(coord): all_labels.extend(new_labels) # rescale occupancies if necessary - all_species_noedit = all_species.copy() # save copy before scaling in case of check_occu=False, used below + all_species_noedit = ( + all_species.copy() + ) # save copy before scaling in case of check_occu=False, used below for idx, species in enumerate(all_species): total_occu = sum(species.values()) if 1 < total_occu <= self._occupancy_tolerance: all_species[idx] = species / total_occu - if all_species and len(all_species) == len(all_coords) and len(all_species) == len(all_magmoms): + if ( + all_species + and len(all_species) == len(all_coords) + and len(all_species) == len(all_magmoms) + ): site_properties = {} if any(all_hydrogens): assert len(all_hydrogens) == len(all_coords) @@ -1212,8 +1284,12 @@ def parse_structures( Returns: list[Structure]: All structures in CIF file. """ - if os.getenv("CI") and datetime.now() > datetime(2024, 3, 1): # March 2024 seems long enough # pragma: no cover - raise RuntimeError("remove the change of default primitive=True to False made on 2023-10-24") + if os.getenv("CI") and datetime.now() > datetime( + 2024, 3, 1 + ): # March 2024 seems long enough # pragma: no cover + raise RuntimeError( + "remove the change of default primitive=True to False made on 2023-10-24" + ) if primitive is None: primitive = False warnings.warn( @@ -1222,8 +1298,12 @@ def parse_structures( "in the CIF file as is. If you want the primitive cell, please set primitive=True explicitly.", UserWarning, ) - if not check_occu: # added in https://github.com/materialsproject/pymatgen/pull/2836 - warnings.warn("Structures with unphysical site occupancies are not compatible with many pymatgen features.") + if ( + not check_occu + ): # added in https://github.com/materialsproject/pymatgen/pull/2836 + warnings.warn( + "Structures with unphysical site occupancies are not compatible with many pymatgen features." + ) if primitive and symmetrized: raise ValueError( "Using both 'primitive' and 'symmetrized' arguments is not currently supported " @@ -1233,7 +1313,9 @@ def parse_structures( structures = [] for idx, dct in enumerate(self._cif.data.values()): try: - struct = self._get_structure(dct, primitive, symmetrized, check_occu=check_occu) + struct = self._get_structure( + dct, primitive, symmetrized, check_occu=check_occu + ) if struct: structures.append(struct) except (KeyError, ValueError) as exc: @@ -1249,7 +1331,9 @@ def parse_structures( # if on_error == "raise" we don't get to here so no need to check if self.warnings and on_error == "warn": - warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings)) + warnings.warn( + "Issues encountered while parsing CIF: " + "\n".join(self.warnings) + ) if len(structures) == 0: raise ValueError("Invalid CIF file with no structures!") @@ -1307,7 +1391,10 @@ def get_bibtex_string(self): # convert to bibtex author format ('and' delimited) if "author" in bibtex_entry: # separate out semicolon authors - if isinstance(bibtex_entry["author"], str) and ";" in bibtex_entry["author"]: + if ( + isinstance(bibtex_entry["author"], str) + and ";" in bibtex_entry["author"] + ): bibtex_entry["author"] = bibtex_entry["author"].split(";") if isinstance(bibtex_entry["author"], list): @@ -1315,8 +1402,14 @@ def get_bibtex_string(self): # convert to bibtex page range format, use empty string if not specified if ("page_first" in bibtex_entry) or ("page_last" in bibtex_entry): - bibtex_entry["pages"] = bibtex_entry.get("page_first", "") + "--" + bibtex_entry.get("page_last", "") - bibtex_entry.pop("page_first", None) # and remove page_first, page_list if present + bibtex_entry["pages"] = ( + bibtex_entry.get("page_first", "") + + "--" + + bibtex_entry.get("page_last", "") + ) + bibtex_entry.pop( + "page_first", None + ) # and remove page_first, page_list if present bibtex_entry.pop("page_last", None) # cite keys are given as cif-reference-idx in order they are found @@ -1371,7 +1464,9 @@ def __init__( to the CIF as _atom_site_{property name}. Defaults to False. """ if write_magmoms and symprec: - warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.") + warnings.warn( + "Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection." + ) symprec = None format_str = f"{{:.{significant_figures}f}}" @@ -1380,7 +1475,9 @@ def __init__( loops = [] spacegroup = ("P 1", 1) if symprec is not None: - spg_analyzer = SpacegroupAnalyzer(struct, symprec, angle_tolerance=angle_tolerance) + spg_analyzer = SpacegroupAnalyzer( + struct, symprec, angle_tolerance=angle_tolerance + ) spacegroup = ( spg_analyzer.get_space_group_symbol(), spg_analyzer.get_space_group_number(), @@ -1396,9 +1493,13 @@ def __init__( no_oxi_comp = comp.element_composition block["_symmetry_space_group_name_H-M"] = spacegroup[0] for cell_attr in ["a", "b", "c"]: - block["_cell_length_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) + block["_cell_length_" + cell_attr] = format_str.format( + getattr(lattice, cell_attr) + ) for cell_attr in ["alpha", "beta", "gamma"]: - block["_cell_angle_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) + block["_cell_angle_" + cell_attr] = format_str.format( + getattr(lattice, cell_attr) + ) block["_symmetry_Int_Tables_number"] = spacegroup[1] block["_chemical_formula_structural"] = no_oxi_comp.reduced_formula block["_chemical_formula_sum"] = no_oxi_comp.formula @@ -1416,16 +1517,22 @@ def __init__( symm_ops = [] for op in spg_analyzer.get_symmetry_operations(): v = op.translation_vector - symm_ops.append(SymmOp.from_rotation_and_translation(op.rotation_matrix, v)) + symm_ops.append( + SymmOp.from_rotation_and_translation(op.rotation_matrix, v) + ) ops = [op.as_xyz_string() for op in symm_ops] - block["_symmetry_equiv_pos_site_id"] = [f"{i}" for i in range(1, len(ops) + 1)] + block["_symmetry_equiv_pos_site_id"] = [ + f"{i}" for i in range(1, len(ops) + 1) + ] block["_symmetry_equiv_pos_as_xyz"] = ops loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"]) try: - symbol_to_oxinum = {str(el): float(el.oxi_state) for el in sorted(comp.elements)} + symbol_to_oxinum = { + str(el): float(el.oxi_state) for el in sorted(comp.elements) + } block["_atom_type_symbol"] = list(symbol_to_oxinum) block["_atom_type_oxidation_number"] = symbol_to_oxinum.values() loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"]) @@ -1462,32 +1569,48 @@ def __init__( mag = sp.spin else: # Use site label if available for regular sites - site_label = site.label if site.label != site.species_string else site_label + site_label = ( + site.label + if site.label != site.species_string + else site_label + ) mag = 0 atom_site_label.append(site_label) magmom = Magmom(mag) if write_magmoms and abs(magmom) > 0: - moment = Magmom.get_moment_relative_to_crystal_axes(magmom, lattice) + moment = Magmom.get_moment_relative_to_crystal_axes( + magmom, lattice + ) atom_site_moment_label.append(f"{sp.symbol}{count}") - atom_site_moment_crystalaxis_x.append(format_str.format(moment[0])) - atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) - atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) + atom_site_moment_crystalaxis_x.append( + format_str.format(moment[0]) + ) + atom_site_moment_crystalaxis_y.append( + format_str.format(moment[1]) + ) + atom_site_moment_crystalaxis_z.append( + format_str.format(moment[2]) + ) if write_site_properties: for ( property_key, property_vals, ) in struct.site_properties.items(): - atom_site_properties[property_key].append(property_vals[count]) + atom_site_properties[property_key].append( + property_vals[count] + ) count += 1 else: # The following just presents a deterministic ordering. unique_sites = [ ( - sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[0], + sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[ + 0 + ], len(sites), ) for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites @@ -1508,7 +1631,11 @@ def __init__( atom_site_fract_x.append(format_str.format(site.a)) atom_site_fract_y.append(format_str.format(site.b)) atom_site_fract_z.append(format_str.format(site.c)) - site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" + site_label = ( + site.label + if site.label != site.species_string + else f"{sp.symbol}{count}" + ) atom_site_label.append(site_label) atom_site_occupancy.append(str(occu)) count += 1 @@ -1518,7 +1645,9 @@ def __init__( property_key, property_vals, ) in struct.site_properties.items(): - atom_site_properties[property_key].append(property_vals[count]) + atom_site_properties[property_key].append( + property_vals[count] + ) block["_atom_site_type_symbol"] = atom_site_type_symbol block["_atom_site_label"] = atom_site_label @@ -1568,7 +1697,9 @@ def __str__(self): """Returns the CIF as a string.""" return str(self._cf) - def write_file(self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w") -> None: + def write_file( + self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w" + ) -> None: """Write the CIF file.""" with zopen(filename, mode=mode) as file: file.write(str(self)) From 329fa835b3923ae6faf4b9c5a1ca101739b40d11 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:19:42 +0000 Subject: [PATCH 04/25] pre-commit auto-fixes --- pymatgen/io/cif.py | 235 ++++++++++----------------------------------- 1 file changed, 52 insertions(+), 183 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 1d10b202507..5d72c9800b7 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -20,16 +20,7 @@ from monty.io import zopen from monty.serialization import loadfn -from pymatgen.core import ( - Composition, - DummySpecies, - Element, - Lattice, - PeriodicSite, - Species, - Structure, - get_el_sp, -) +from pymatgen.core import Composition, DummySpecies, Element, Lattice, PeriodicSite, Species, Structure, get_el_sp from pymatgen.core.operations import MagSymmOp, SymmOp from pymatgen.electronic_structure.core import Magmom from pymatgen.symmetry.analyzer import SpacegroupAnalyzer, SpacegroupOperations @@ -77,11 +68,7 @@ def __init__(self, data, loops, header): def __eq__(self, other: object) -> bool: if not isinstance(other, CifBlock): return NotImplemented - return ( - self.loops == other.loops - and self.data == other.data - and self.header == other.header - ) + return self.loops == other.loops and self.data == other.data and self.header == other.header def __getitem__(self, key): return self.data[key] @@ -134,11 +121,7 @@ def _format_field(self, v): # add quotes if necessary if v == "": return '""' - if ( - (" " in v or v[0] == "_") - and not (v[0] == "'" and v[-1] == "'") - and not (v[0] == '"' and v[-1] == '"') - ): + if (" " in v or v[0] == "_") and not (v[0] == "'" and v[-1] == "'") and not (v[0] == '"' and v[-1] == '"'): q = '"' if "'" in v else "'" v = q + v + q return v @@ -269,9 +252,7 @@ def from_str(cls, string) -> CifFile: """ dct = {} - for block_str in re.split( - r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL - )[1:]: + for block_str in re.split(r"^\s*data_", f"x\n{string}", flags=re.MULTILINE | re.DOTALL)[1:]: # Skip over Cif block that contains powder diffraction data. # Some elements in this block were missing from CIF files in # Springer materials/Pauling file DBs. @@ -420,11 +401,7 @@ def _sanitize_data(self, data): """ # check for implicit hydrogens, warn if any present if "_atom_site_attached_hydrogens" in data.data: - attached_hydrogens = [ - str2float(x) - for x in data.data["_atom_site_attached_hydrogens"] - if str2float(x) != 0 - ] + attached_hydrogens = [str2float(x) for x in data.data["_atom_site_attached_hydrogens"] if str2float(x) != 0] if len(attached_hydrogens) > 0: self.warnings.append( "Structure has implicit hydrogens defined, parsed structure unlikely to be " @@ -459,9 +436,7 @@ def _sanitize_data(self, data): # Below, we split the strings on ' + ' to # check if the length (or number of elements) in the label and # symbol are equal. - if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len( - el_row.split(" + ") - ): + if len(data["_atom_site_type_symbol"][idx].split(" + ")) > len(el_row.split(" + ")): # Dictionary to hold extracted elements and occupancies els_occu = {} @@ -471,23 +446,14 @@ def _sanitize_data(self, data): symbol_str_lst = symbol_str.split(" + ") for elocc_idx, sym in enumerate(symbol_str_lst): # Remove any bracketed items in the string - symbol_str_lst[elocc_idx] = re.sub( - r"\([0-9]*\)", "", sym.strip() - ) + symbol_str_lst[elocc_idx] = re.sub(r"\([0-9]*\)", "", sym.strip()) # Extract element name and its occupancy from the # string, and store it as a # key-value pair in "els_occ". els_occu[ - str( - re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1] - ).replace("", "") - ] = float( - "0" - + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[ - 1 - ] - ) + str(re.findall(r"\D+", symbol_str_lst[elocc_idx].strip())[1]).replace("", "") + ] = float("0" + re.findall(r"\.?\d+", symbol_str_lst[elocc_idx].strip())[1]) x = str2float(data["_atom_site_fract_x"][idx]) y = str2float(data["_atom_site_fract_y"][idx]) @@ -495,9 +461,7 @@ def _sanitize_data(self, data): for et, occu in els_occu.items(): # new atom site labels have 'fix' appended - new_atom_site_label.append( - f"{et}_fix{len(new_atom_site_label)}" - ) + new_atom_site_label.append(f"{et}_fix{len(new_atom_site_label)}") new_atom_site_type_symbol.append(et) new_atom_site_occupancy.append(str(occu)) new_fract_x.append(str(x)) @@ -633,9 +597,7 @@ def _unique_coords( ) else: magmom = Magmom(tmp_magmom) - if not in_coord_list_pbc( - coords_out, coord, atol=self._site_tolerance - ): + if not in_coord_list_pbc(coords_out, coord, atol=self._site_tolerance): coords_out.append(coord) magmoms_out.append(magmom) labels_out.append(labels.get(tmp_coord)) @@ -683,15 +645,9 @@ def get_lattice( try: required_args = getargspec(getattr(Lattice, lattice_type)).args - lengths = ( - length - for length in length_strings - if length in required_args - ) + lengths = (length for length in length_strings if length in required_args) angles = (a for a in angle_strings if a in required_args) - return self.get_lattice( - data, lengths, angles, lattice_type=lattice_type - ) + return self.get_lattice(data, lengths, angles, lattice_type=lattice_type) except AttributeError as exc: self.warnings.append(str(exc)) warnings.warn(exc) @@ -835,12 +791,8 @@ def get_magsymops(self, data): (which changes magnetic moments on sites) needs to be returned. """ mag_symm_ops = [] - bns_name = data.data.get( - "_space_group_magn.name_BNS" - ) # get BNS label for MagneticSpaceGroup() - bns_num = data.data.get( - "_space_group_magn.number_BNS" - ) # get BNS number for MagneticSpaceGroup() + bns_name = data.data.get("_space_group_magn.name_BNS") # get BNS label for MagneticSpaceGroup() + bns_num = data.data.get("_space_group_magn.number_BNS") # get BNS number for MagneticSpaceGroup() # check to see if magCIF file explicitly contains magnetic symmetry operations if xyzt := data.data.get("_space_group_symop_magn_operation.xyz"): @@ -858,13 +810,9 @@ def get_magsymops(self, data): for op in mag_symm_ops: for centering_op in centering_symops: new_translation = [ - i - np.floor(i) - for i in op.translation_vector - + centering_op.translation_vector + i - np.floor(i) for i in op.translation_vector + centering_op.translation_vector ] - new_time_reversal = ( - op.time_reversal * centering_op.time_reversal - ) + new_time_reversal = op.time_reversal * centering_op.time_reversal all_ops.append( MagSymmOp.from_rotation_and_translation_and_time_reversal( rotation_matrix=op.rotation_matrix, @@ -902,17 +850,13 @@ def parse_oxi_states(data): """Parse oxidation states from data dictionary.""" try: oxi_states = { - data["_atom_type_symbol"][i]: str2float( - data["_atom_type_oxidation_number"][i] - ) + data["_atom_type_symbol"][i]: str2float(data["_atom_type_oxidation_number"][i]) for i in range(len(data["_atom_type_symbol"])) } # attempt to strip oxidation state from _atom_type_symbol # in case the label does not contain an oxidation state for i, symbol in enumerate(data["_atom_type_symbol"]): - oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float( - data["_atom_type_oxidation_number"][i] - ) + oxi_states[re.sub(r"\d?[\+,\-]?$", "", symbol)] = str2float(data["_atom_type_oxidation_number"][i]) except (ValueError, KeyError): oxi_states = None @@ -1001,9 +945,7 @@ def get_num_implicit_hydrogens(sym): # if magCIF, get magnetic symmetry moments and magmoms # else standard CIF, and use empty magmom dict if self.feature_flags["magcif_incommensurate"]: - raise NotImplementedError( - "Incommensurate structures not currently supported." - ) + raise NotImplementedError("Incommensurate structures not currently supported.") if self.feature_flags["magcif"]: self.symmetry_operations = self.get_magsymops(data) magmoms = self.parse_magmoms(data, lattice=lattice) @@ -1022,9 +964,7 @@ def get_matching_coord(coord): coords = np.array(keys) for op in self.symmetry_operations: frac_coord = op.operate(coord) - indices = find_in_coord_list_pbc( - coords, frac_coord, atol=self._site_tolerance - ) + indices = find_in_coord_list_pbc(coords, frac_coord, atol=self._site_tolerance) if len(indices) > 0: return keys[indices[0]] return False @@ -1086,9 +1026,7 @@ def get_matching_coord(coord): coord_to_magmoms[match] = None labels[match] = label sum_occu = [ - sum(c.values()) - for c in coord_to_species.values() - if set(c.elements) != {Element("O"), Element("H")} + sum(c.values()) for c in coord_to_species.values() if set(c.elements) != {Element("O"), Element("H")} ] if any(occu > 1 for occu in sum_occu): msg = ( @@ -1115,9 +1053,7 @@ def get_matching_coord(coord): # property, but this introduces ambiguities for end user # (such as unintended use of `spin` and Species will have # fictitious oxidation state). - raise NotImplementedError( - "Disordered magnetic structures not currently supported." - ) + raise NotImplementedError("Disordered magnetic structures not currently supported.") if coord_to_species.items(): for idx, (comp, group) in enumerate( @@ -1134,9 +1070,7 @@ def get_matching_coord(coord): tmp_coords, magmoms=tmp_magmom, labels=labels, lattice=lattice ) else: - coords, magmoms, new_labels = self._unique_coords( - tmp_coords, labels=labels - ) + coords, magmoms, new_labels = self._unique_coords(tmp_coords, labels=labels) if set(comp.elements) == {Element("O"), Element("H")}: # O with implicit hydrogens @@ -1164,19 +1098,13 @@ def get_matching_coord(coord): all_labels.extend(new_labels) # rescale occupancies if necessary - all_species_noedit = ( - all_species.copy() - ) # save copy before scaling in case of check_occu=False, used below + all_species_noedit = all_species.copy() # save copy before scaling in case of check_occu=False, used below for idx, species in enumerate(all_species): total_occu = sum(species.values()) if 1 < total_occu <= self._occupancy_tolerance: all_species[idx] = species / total_occu - if ( - all_species - and len(all_species) == len(all_coords) - and len(all_species) == len(all_magmoms) - ): + if all_species and len(all_species) == len(all_coords) and len(all_species) == len(all_magmoms): site_properties = {} if any(all_hydrogens): assert len(all_hydrogens) == len(all_coords) @@ -1284,12 +1212,8 @@ def parse_structures( Returns: list[Structure]: All structures in CIF file. """ - if os.getenv("CI") and datetime.now() > datetime( - 2024, 3, 1 - ): # March 2024 seems long enough # pragma: no cover - raise RuntimeError( - "remove the change of default primitive=True to False made on 2023-10-24" - ) + if os.getenv("CI") and datetime.now() > datetime(2024, 3, 1): # March 2024 seems long enough # pragma: no cover + raise RuntimeError("remove the change of default primitive=True to False made on 2023-10-24") if primitive is None: primitive = False warnings.warn( @@ -1298,12 +1222,8 @@ def parse_structures( "in the CIF file as is. If you want the primitive cell, please set primitive=True explicitly.", UserWarning, ) - if ( - not check_occu - ): # added in https://github.com/materialsproject/pymatgen/pull/2836 - warnings.warn( - "Structures with unphysical site occupancies are not compatible with many pymatgen features." - ) + if not check_occu: # added in https://github.com/materialsproject/pymatgen/pull/2836 + warnings.warn("Structures with unphysical site occupancies are not compatible with many pymatgen features.") if primitive and symmetrized: raise ValueError( "Using both 'primitive' and 'symmetrized' arguments is not currently supported " @@ -1313,9 +1233,7 @@ def parse_structures( structures = [] for idx, dct in enumerate(self._cif.data.values()): try: - struct = self._get_structure( - dct, primitive, symmetrized, check_occu=check_occu - ) + struct = self._get_structure(dct, primitive, symmetrized, check_occu=check_occu) if struct: structures.append(struct) except (KeyError, ValueError) as exc: @@ -1331,9 +1249,7 @@ def parse_structures( # if on_error == "raise" we don't get to here so no need to check if self.warnings and on_error == "warn": - warnings.warn( - "Issues encountered while parsing CIF: " + "\n".join(self.warnings) - ) + warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings)) if len(structures) == 0: raise ValueError("Invalid CIF file with no structures!") @@ -1391,10 +1307,7 @@ def get_bibtex_string(self): # convert to bibtex author format ('and' delimited) if "author" in bibtex_entry: # separate out semicolon authors - if ( - isinstance(bibtex_entry["author"], str) - and ";" in bibtex_entry["author"] - ): + if isinstance(bibtex_entry["author"], str) and ";" in bibtex_entry["author"]: bibtex_entry["author"] = bibtex_entry["author"].split(";") if isinstance(bibtex_entry["author"], list): @@ -1402,14 +1315,8 @@ def get_bibtex_string(self): # convert to bibtex page range format, use empty string if not specified if ("page_first" in bibtex_entry) or ("page_last" in bibtex_entry): - bibtex_entry["pages"] = ( - bibtex_entry.get("page_first", "") - + "--" - + bibtex_entry.get("page_last", "") - ) - bibtex_entry.pop( - "page_first", None - ) # and remove page_first, page_list if present + bibtex_entry["pages"] = bibtex_entry.get("page_first", "") + "--" + bibtex_entry.get("page_last", "") + bibtex_entry.pop("page_first", None) # and remove page_first, page_list if present bibtex_entry.pop("page_last", None) # cite keys are given as cif-reference-idx in order they are found @@ -1464,9 +1371,7 @@ def __init__( to the CIF as _atom_site_{property name}. Defaults to False. """ if write_magmoms and symprec: - warnings.warn( - "Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection." - ) + warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.") symprec = None format_str = f"{{:.{significant_figures}f}}" @@ -1475,9 +1380,7 @@ def __init__( loops = [] spacegroup = ("P 1", 1) if symprec is not None: - spg_analyzer = SpacegroupAnalyzer( - struct, symprec, angle_tolerance=angle_tolerance - ) + spg_analyzer = SpacegroupAnalyzer(struct, symprec, angle_tolerance=angle_tolerance) spacegroup = ( spg_analyzer.get_space_group_symbol(), spg_analyzer.get_space_group_number(), @@ -1493,13 +1396,9 @@ def __init__( no_oxi_comp = comp.element_composition block["_symmetry_space_group_name_H-M"] = spacegroup[0] for cell_attr in ["a", "b", "c"]: - block["_cell_length_" + cell_attr] = format_str.format( - getattr(lattice, cell_attr) - ) + block["_cell_length_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) for cell_attr in ["alpha", "beta", "gamma"]: - block["_cell_angle_" + cell_attr] = format_str.format( - getattr(lattice, cell_attr) - ) + block["_cell_angle_" + cell_attr] = format_str.format(getattr(lattice, cell_attr)) block["_symmetry_Int_Tables_number"] = spacegroup[1] block["_chemical_formula_structural"] = no_oxi_comp.reduced_formula block["_chemical_formula_sum"] = no_oxi_comp.formula @@ -1517,22 +1416,16 @@ def __init__( symm_ops = [] for op in spg_analyzer.get_symmetry_operations(): v = op.translation_vector - symm_ops.append( - SymmOp.from_rotation_and_translation(op.rotation_matrix, v) - ) + symm_ops.append(SymmOp.from_rotation_and_translation(op.rotation_matrix, v)) ops = [op.as_xyz_string() for op in symm_ops] - block["_symmetry_equiv_pos_site_id"] = [ - f"{i}" for i in range(1, len(ops) + 1) - ] + block["_symmetry_equiv_pos_site_id"] = [f"{i}" for i in range(1, len(ops) + 1)] block["_symmetry_equiv_pos_as_xyz"] = ops loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"]) try: - symbol_to_oxinum = { - str(el): float(el.oxi_state) for el in sorted(comp.elements) - } + symbol_to_oxinum = {str(el): float(el.oxi_state) for el in sorted(comp.elements)} block["_atom_type_symbol"] = list(symbol_to_oxinum) block["_atom_type_oxidation_number"] = symbol_to_oxinum.values() loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"]) @@ -1569,48 +1462,32 @@ def __init__( mag = sp.spin else: # Use site label if available for regular sites - site_label = ( - site.label - if site.label != site.species_string - else site_label - ) + site_label = site.label if site.label != site.species_string else site_label mag = 0 atom_site_label.append(site_label) magmom = Magmom(mag) if write_magmoms and abs(magmom) > 0: - moment = Magmom.get_moment_relative_to_crystal_axes( - magmom, lattice - ) + moment = Magmom.get_moment_relative_to_crystal_axes(magmom, lattice) atom_site_moment_label.append(f"{sp.symbol}{count}") - atom_site_moment_crystalaxis_x.append( - format_str.format(moment[0]) - ) - atom_site_moment_crystalaxis_y.append( - format_str.format(moment[1]) - ) - atom_site_moment_crystalaxis_z.append( - format_str.format(moment[2]) - ) + atom_site_moment_crystalaxis_x.append(format_str.format(moment[0])) + atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) + atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) if write_site_properties: for ( property_key, property_vals, ) in struct.site_properties.items(): - atom_site_properties[property_key].append( - property_vals[count] - ) + atom_site_properties[property_key].append(property_vals[count]) count += 1 else: # The following just presents a deterministic ordering. unique_sites = [ ( - sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[ - 0 - ], + sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[0], len(sites), ) for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites @@ -1631,11 +1508,7 @@ def __init__( atom_site_fract_x.append(format_str.format(site.a)) atom_site_fract_y.append(format_str.format(site.b)) atom_site_fract_z.append(format_str.format(site.c)) - site_label = ( - site.label - if site.label != site.species_string - else f"{sp.symbol}{count}" - ) + site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" atom_site_label.append(site_label) atom_site_occupancy.append(str(occu)) count += 1 @@ -1645,9 +1518,7 @@ def __init__( property_key, property_vals, ) in struct.site_properties.items(): - atom_site_properties[property_key].append( - property_vals[count] - ) + atom_site_properties[property_key].append(property_vals[count]) block["_atom_site_type_symbol"] = atom_site_type_symbol block["_atom_site_label"] = atom_site_label @@ -1697,9 +1568,7 @@ def __str__(self): """Returns the CIF as a string.""" return str(self._cf) - def write_file( - self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w" - ) -> None: + def write_file(self, filename: str | Path, mode: Literal["w", "a", "wt", "at"] = "w") -> None: """Write the CIF file.""" with zopen(filename, mode=mode) as file: file.write(str(self)) From 74c60026134f5a471da4451810856a31fd224813 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:23:41 -0800 Subject: [PATCH 05/25] Fix abinit --- pymatgen/io/abinit/abitimer.py | 4 ++-- pymatgen/io/abinit/netcdf.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymatgen/io/abinit/abitimer.py b/pymatgen/io/abinit/abitimer.py index ffcdb43c9d3..301174153f9 100644 --- a/pymatgen/io/abinit/abitimer.py +++ b/pymatgen/io/abinit/abitimer.py @@ -888,11 +888,11 @@ def scatter_hist(self, ax: plt.Axes = None, **kwargs): # axHistx.axis["bottom"].major_ticklabels.set_visible(False) axHistx.set_yticks([0, 50, 100]) for tl in axHistx.get_xticklabels(): - tl.set_visible(False) # noqa: FBT003 + tl.set_visible(False) # axHisty.axis["left"].major_ticklabels.set_visible(False) for tl in axHisty.get_yticklabels(): - tl.set_visible(False) # noqa: FBT003 + tl.set_visible(False) axHisty.set_xticks([0, 50, 100]) # plt.draw() diff --git a/pymatgen/io/abinit/netcdf.py b/pymatgen/io/abinit/netcdf.py index d1936fb5f99..499b6db4045 100644 --- a/pymatgen/io/abinit/netcdf.py +++ b/pymatgen/io/abinit/netcdf.py @@ -91,7 +91,7 @@ def __init__(self, path): # Slicing a ncvar returns a MaskedArrray and this is really annoying # because it can lead to unexpected behavior in e.g. calls to np.matmul! # See also https://github.com/Unidata/netcdf4-python/issues/785 - self.rootgrp.set_auto_mask(False) # noqa: FBT003 + self.rootgrp.set_auto_mask(False) def __enter__(self): """Activated when used in the with statement.""" From b8be501745365171702eac7438e58880d95dc5dd Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:27:53 -0800 Subject: [PATCH 06/25] formatting --- pymatgen/io/cif.py | 54 +++++++++------------------------------------- 1 file changed, 10 insertions(+), 44 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 5d72c9800b7..0e5aad650be 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -323,11 +323,7 @@ def is_magcif() -> bool: """Checks to see if file appears to be a magCIF file (heuristic).""" # Doesn't seem to be a canonical way to test if file is magCIF or # not, so instead check for magnetic symmetry datanames - prefixes = [ - "_space_group_magn", - "_atom_site_moment", - "_space_group_symop_magn", - ] + prefixes = ["_space_group_magn", "_atom_site_moment", "_space_group_symop_magn"] for d in self._cif.data.values(): for k in d.data: for prefix in prefixes: @@ -628,18 +624,12 @@ def get_lattice( """ try: return self.get_lattice_no_exception( - data=data, - angle_strings=angle_strings, - lattice_type=lattice_type, - length_strings=length_strings, + data=data, angle_strings=angle_strings, lattice_type=lattice_type, length_strings=length_strings ) + except KeyError: - # Missing Key search for cell setting - for lattice_label in [ - "_symmetry_cell_setting", - "_space_group_crystal_system", - ]: + for lattice_label in ["_symmetry_cell_setting", "_space_group_crystal_system"]: if data.data.get(lattice_label): lattice_type = data.data.get(lattice_label).lower() try: @@ -658,10 +648,7 @@ def get_lattice( @staticmethod def get_lattice_no_exception( - data, - length_strings=("a", "b", "c"), - angle_strings=("alpha", "beta", "gamma"), - lattice_type=None, + data, length_strings=("a", "b", "c"), angle_strings=("alpha", "beta", "gamma"), lattice_type=None ): """ Take a dictionary of CIF data and returns a pymatgen Lattice object. @@ -741,11 +728,7 @@ def get_symops(self, data): try: cod_data = loadfn( - os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "symmetry", - "symm_ops.json", - ) + os.path.join(os.path.dirname(os.path.dirname(__file__)), "symmetry", "symm_ops.json") ) for d in cod_data: if sg == re.sub(r"\s+", "", d["hermann_mauguin"]): @@ -928,11 +911,7 @@ def _parse_symbol(self, sym): return parsed_sym def _get_structure( - self, - data: dict[str, Any], - primitive: bool, - symmetrized: bool, - check_occu: bool = False, + self, data: dict[str, Any], primitive: bool, symmetrized: bool, check_occu: bool = False ) -> Structure | None: """Generate structure from part of the cif.""" @@ -1121,13 +1100,7 @@ def get_matching_coord(coord): else: all_labels = None # type: ignore - struct = Structure( - lattice, - all_species, - all_coords, - site_properties=site_properties, - labels=all_labels, - ) + struct = Structure(lattice, all_species, all_coords, site_properties=site_properties, labels=all_labels) if symmetrized: # Wyckoff labels not currently parsed, note that not all CIFs will contain Wyckoff labels @@ -1143,11 +1116,7 @@ def get_matching_coord(coord): if not check_occu: for idx in range(len(struct)): struct[idx] = PeriodicSite( - all_species_noedit[idx], - all_coords[idx], - lattice, - properties=site_properties, - skip_checks=True, + all_species_noedit[idx], all_coords[idx], lattice, properties=site_properties, skip_checks=True ) if symmetrized or not check_occu: @@ -1381,10 +1350,7 @@ def __init__( spacegroup = ("P 1", 1) if symprec is not None: spg_analyzer = SpacegroupAnalyzer(struct, symprec, angle_tolerance=angle_tolerance) - spacegroup = ( - spg_analyzer.get_space_group_symbol(), - spg_analyzer.get_space_group_number(), - ) + spacegroup = (spg_analyzer.get_space_group_symbol(), spg_analyzer.get_space_group_number()) if refine_struct: # Needs the refined structure when using symprec. This converts From 5644c8e390d82a561974e0dc315f5eb739afdac3 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:28:23 -0800 Subject: [PATCH 07/25] formatting --- pymatgen/io/cif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 0e5aad650be..e6751ed4fd5 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -627,8 +627,8 @@ def get_lattice( data=data, angle_strings=angle_strings, lattice_type=lattice_type, length_strings=length_strings ) - except KeyError: + # Missing Key search for cell setting for lattice_label in ["_symmetry_cell_setting", "_space_group_crystal_system"]: if data.data.get(lattice_label): lattice_type = data.data.get(lattice_label).lower() From ecbd3a6a85fb9ea02b411eb0568fc8cc79494c3d Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:30:06 -0800 Subject: [PATCH 08/25] type hints --- pymatgen/io/cif.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index e6751ed4fd5..07c4b08d550 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -1313,14 +1313,14 @@ class CifWriter: def __init__( self, - struct, - symprec=None, - write_magmoms=False, - significant_figures=8, - angle_tolerance=5.0, - refine_struct=True, - write_site_properties=False, - ): + struct: Structure, + symprec: float | None = None, + write_magmoms: bool = False, + significant_figures: int = 8, + angle_tolerance: float = 5.0, + refine_struct: bool = True, + write_site_properties: bool =False, + ) -> None: """ Args: struct (Structure): structure to write From 950101f5b40cb7106c1216c5debe66f053b5e753 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:30:41 -0800 Subject: [PATCH 09/25] formatting --- pymatgen/io/cif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 07c4b08d550..572813b56c2 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -1319,7 +1319,7 @@ def __init__( significant_figures: int = 8, angle_tolerance: float = 5.0, refine_struct: bool = True, - write_site_properties: bool =False, + write_site_properties: bool = False, ) -> None: """ Args: From aad6b0064806a7162f3f4be60d0a5423ad089c9d Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:35:40 -0800 Subject: [PATCH 10/25] fix --- pymatgen/io/abinit/abitimer.py | 4 ++-- pymatgen/io/abinit/netcdf.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymatgen/io/abinit/abitimer.py b/pymatgen/io/abinit/abitimer.py index 301174153f9..ffcdb43c9d3 100644 --- a/pymatgen/io/abinit/abitimer.py +++ b/pymatgen/io/abinit/abitimer.py @@ -888,11 +888,11 @@ def scatter_hist(self, ax: plt.Axes = None, **kwargs): # axHistx.axis["bottom"].major_ticklabels.set_visible(False) axHistx.set_yticks([0, 50, 100]) for tl in axHistx.get_xticklabels(): - tl.set_visible(False) + tl.set_visible(False) # noqa: FBT003 # axHisty.axis["left"].major_ticklabels.set_visible(False) for tl in axHisty.get_yticklabels(): - tl.set_visible(False) + tl.set_visible(False) # noqa: FBT003 axHisty.set_xticks([0, 50, 100]) # plt.draw() diff --git a/pymatgen/io/abinit/netcdf.py b/pymatgen/io/abinit/netcdf.py index 499b6db4045..d1936fb5f99 100644 --- a/pymatgen/io/abinit/netcdf.py +++ b/pymatgen/io/abinit/netcdf.py @@ -91,7 +91,7 @@ def __init__(self, path): # Slicing a ncvar returns a MaskedArrray and this is really annoying # because it can lead to unexpected behavior in e.g. calls to np.matmul! # See also https://github.com/Unidata/netcdf4-python/issues/785 - self.rootgrp.set_auto_mask(False) + self.rootgrp.set_auto_mask(False) # noqa: FBT003 def __enter__(self): """Activated when used in the with statement.""" From c7337e0fd0e7a9ec0bca01a9a38381a930e62666 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:38:54 -0800 Subject: [PATCH 11/25] fix --- tests/io/test_cif.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index c10a9529d85..4266fc2caa4 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -854,6 +854,21 @@ def test_no_check_occu(self): structs = parser.parse_structures(primitive=False, check_occu=False)[0] assert structs[0].species.as_dict()["Te"] == 1.5 + def test_cif_writer_site_properties(self): + struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") + site_props = {"hello": [1.0] * len(struct1), "world": [2.0] * len(struct1)} + site_props["hello"][-1] = -1.0 + struct1.add_site_property("hello", site_props["hello"]) + struct1.add_site_property("world", site_props["world"]) + out_path = f"{self.tmp_path}/test.cif" + CifWriter(struct1, write_site_properties=True).write_file(out_path) + with open(out_path) as f: + lines = f.readlines() + cif_str = "".join(lines) + assert "_atom_site_occupancy\n _atom_site_hello\n _atom_site_world\n" in cif_str + assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0 2.0" in cif_str + assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0 2.0" in cif_str + def test_cif_writer_write_file(self): struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") out_path = f"{self.tmp_path}/test.cif" @@ -870,21 +885,6 @@ def test_cif_writer_write_file(self): assert len(read_structs) == 2 assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"] - def test_cif_writer_site_properties(self): - struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") - site_props = {"hello": [1.0] * len(struct1), "world": [2.0] * len(struct1)} - site_props["hello"][-1] = -1.0 - struct1.add_site_property("hello", site_props["hello"]) - struct1.add_site_property("world", site_props["world"]) - out_path = f"{self.tmp_path}/test.cif" - CifWriter(struct1, write_site_properties=True).write_file(out_path) - with open(out_path) as f: - lines = f.readlines() - cif_str = "".join(lines) - assert "_atom_site_occupancy\n _atom_site_hello\n _atom_site_world\n" in cif_str - assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0 2.0" in cif_str - assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0 2.0" in cif_str - class TestMagCif(PymatgenTest): def setUp(self): From 9628d313c7ddf483ac50b9828466230129893e85 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:39:34 -0800 Subject: [PATCH 12/25] fixname --- tests/io/test_cif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 4266fc2caa4..7fb7e4a6cac 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -860,7 +860,7 @@ def test_cif_writer_site_properties(self): site_props["hello"][-1] = -1.0 struct1.add_site_property("hello", site_props["hello"]) struct1.add_site_property("world", site_props["world"]) - out_path = f"{self.tmp_path}/test.cif" + out_path = f"{self.tmp_path}/test_siteprops.cif" CifWriter(struct1, write_site_properties=True).write_file(out_path) with open(out_path) as f: lines = f.readlines() From eac9269b6511e51c63a9940f5e25bd4bb26a84d8 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:41:33 -0800 Subject: [PATCH 13/25] fix --- tests/io/test_cif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 7fb7e4a6cac..4266fc2caa4 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -860,7 +860,7 @@ def test_cif_writer_site_properties(self): site_props["hello"][-1] = -1.0 struct1.add_site_property("hello", site_props["hello"]) struct1.add_site_property("world", site_props["world"]) - out_path = f"{self.tmp_path}/test_siteprops.cif" + out_path = f"{self.tmp_path}/test.cif" CifWriter(struct1, write_site_properties=True).write_file(out_path) with open(out_path) as f: lines = f.readlines() From 9db149bde2f5753b3578750d4a6b1667e59862bc Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:44:38 -0800 Subject: [PATCH 14/25] fix --- tests/io/test_cif.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 4266fc2caa4..c10a9529d85 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -854,21 +854,6 @@ def test_no_check_occu(self): structs = parser.parse_structures(primitive=False, check_occu=False)[0] assert structs[0].species.as_dict()["Te"] == 1.5 - def test_cif_writer_site_properties(self): - struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") - site_props = {"hello": [1.0] * len(struct1), "world": [2.0] * len(struct1)} - site_props["hello"][-1] = -1.0 - struct1.add_site_property("hello", site_props["hello"]) - struct1.add_site_property("world", site_props["world"]) - out_path = f"{self.tmp_path}/test.cif" - CifWriter(struct1, write_site_properties=True).write_file(out_path) - with open(out_path) as f: - lines = f.readlines() - cif_str = "".join(lines) - assert "_atom_site_occupancy\n _atom_site_hello\n _atom_site_world\n" in cif_str - assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0 2.0" in cif_str - assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0 2.0" in cif_str - def test_cif_writer_write_file(self): struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") out_path = f"{self.tmp_path}/test.cif" @@ -885,6 +870,21 @@ def test_cif_writer_write_file(self): assert len(read_structs) == 2 assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"] + def test_cif_writer_site_properties(self): + struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") + site_props = {"hello": [1.0] * len(struct1), "world": [2.0] * len(struct1)} + site_props["hello"][-1] = -1.0 + struct1.add_site_property("hello", site_props["hello"]) + struct1.add_site_property("world", site_props["world"]) + out_path = f"{self.tmp_path}/test.cif" + CifWriter(struct1, write_site_properties=True).write_file(out_path) + with open(out_path) as f: + lines = f.readlines() + cif_str = "".join(lines) + assert "_atom_site_occupancy\n _atom_site_hello\n _atom_site_world\n" in cif_str + assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0 2.0" in cif_str + assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0 2.0" in cif_str + class TestMagCif(PymatgenTest): def setUp(self): From 23157e1699846a7dfda885cbdd8f843943ef5920 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:47:05 -0800 Subject: [PATCH 15/25] fix --- tests/io/test_cif.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index c10a9529d85..95e788007f7 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -879,8 +879,7 @@ def test_cif_writer_site_properties(self): out_path = f"{self.tmp_path}/test.cif" CifWriter(struct1, write_site_properties=True).write_file(out_path) with open(out_path) as f: - lines = f.readlines() - cif_str = "".join(lines) + cif_str = f.read() assert "_atom_site_occupancy\n _atom_site_hello\n _atom_site_world\n" in cif_str assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0 2.0" in cif_str assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0 2.0" in cif_str From 17bd41a49022f9a8e3c3c3d056dd677d19f430f9 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 15:54:30 -0800 Subject: [PATCH 16/25] formatting fix --- pymatgen/io/cif.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 572813b56c2..5c8aa1b6c2f 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -1442,10 +1442,7 @@ def __init__( atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) if write_site_properties: - for ( - property_key, - property_vals, - ) in struct.site_properties.items(): + for property_key, property_vals in struct.site_properties.items(): atom_site_properties[property_key].append(property_vals[count]) count += 1 @@ -1477,15 +1474,13 @@ def __init__( site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" atom_site_label.append(site_label) atom_site_occupancy.append(str(occu)) - count += 1 if write_site_properties: - for ( - property_key, - property_vals, - ) in struct.site_properties.items(): + for property_key, property_vals in struct.site_properties.items(): atom_site_properties[property_key].append(property_vals[count]) + count += 1 + block["_atom_site_type_symbol"] = atom_site_type_symbol block["_atom_site_label"] = atom_site_label block["_atom_site_symmetry_multiplicity"] = atom_site_symmetry_multiplicity From e80ab30f033941f6da632c3b8571ddee4bde42b8 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 16:02:23 -0800 Subject: [PATCH 17/25] fix --- pymatgen/io/cif.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 5c8aa1b6c2f..cfb223c2bc9 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -1442,8 +1442,8 @@ def __init__( atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) if write_site_properties: - for property_key, property_vals in struct.site_properties.items(): - atom_site_properties[property_key].append(property_vals[count]) + for property_key, property_vals in site.properties.items(): + atom_site_properties[property_key].append(property_vals) count += 1 else: @@ -1476,8 +1476,8 @@ def __init__( atom_site_occupancy.append(str(occu)) if write_site_properties: - for property_key, property_vals in struct.site_properties.items(): - atom_site_properties[property_key].append(property_vals[count]) + for property_key, property_vals in site.properties.items(): + atom_site_properties[property_key].append(property_vals) count += 1 From 11f8d3016a99ebcb374572752b0f8608db14e509 Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 16:04:18 -0800 Subject: [PATCH 18/25] fix --- pymatgen/io/cif.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index cfb223c2bc9..fda5c18bd03 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -1441,9 +1441,8 @@ def __init__( atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) - if write_site_properties: - for property_key, property_vals in site.properties.items(): - atom_site_properties[property_key].append(property_vals) + for property_key, property_vals in site.properties.items(): + atom_site_properties[property_key].append(property_vals) count += 1 else: @@ -1475,9 +1474,8 @@ def __init__( atom_site_label.append(site_label) atom_site_occupancy.append(str(occu)) - if write_site_properties: - for property_key, property_vals in site.properties.items(): - atom_site_properties[property_key].append(property_vals) + for property_key, property_vals in site.properties.items(): + atom_site_properties[property_key].append(property_vals) count += 1 From 6fea0cbdea99e0148cbfb1d665fd59c40aaf3afe Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 16:06:23 -0800 Subject: [PATCH 19/25] fix --- tests/io/test_cif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 95e788007f7..f1098b0e4a9 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -876,7 +876,7 @@ def test_cif_writer_site_properties(self): site_props["hello"][-1] = -1.0 struct1.add_site_property("hello", site_props["hello"]) struct1.add_site_property("world", site_props["world"]) - out_path = f"{self.tmp_path}/test.cif" + out_path = f"{self.tmp_path}/test2.cif" CifWriter(struct1, write_site_properties=True).write_file(out_path) with open(out_path) as f: cif_str = f.read() From ffe78ca8b99f18ddacee2a42f884dcd04e82b65d Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 16:35:23 -0800 Subject: [PATCH 20/25] fix --- pymatgen/io/cif.py | 8 +++----- tests/io/test_cif.py | 15 --------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index fda5c18bd03..30612fdd00c 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -1441,8 +1441,9 @@ def __init__( atom_site_moment_crystalaxis_y.append(format_str.format(moment[1])) atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) - for property_key, property_vals in site.properties.items(): - atom_site_properties[property_key].append(property_vals) + if write_site_properties: + for property_key, property_val in site.properties.items(): + atom_site_properties[property_key].append(format_str.format(property_val)) count += 1 else: @@ -1474,9 +1475,6 @@ def __init__( atom_site_label.append(site_label) atom_site_occupancy.append(str(occu)) - for property_key, property_vals in site.properties.items(): - atom_site_properties[property_key].append(property_vals) - count += 1 block["_atom_site_type_symbol"] = atom_site_type_symbol diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index f1098b0e4a9..9c9237b52f4 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -870,21 +870,6 @@ def test_cif_writer_write_file(self): assert len(read_structs) == 2 assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"] - def test_cif_writer_site_properties(self): - struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") - site_props = {"hello": [1.0] * len(struct1), "world": [2.0] * len(struct1)} - site_props["hello"][-1] = -1.0 - struct1.add_site_property("hello", site_props["hello"]) - struct1.add_site_property("world", site_props["world"]) - out_path = f"{self.tmp_path}/test2.cif" - CifWriter(struct1, write_site_properties=True).write_file(out_path) - with open(out_path) as f: - cif_str = f.read() - assert "_atom_site_occupancy\n _atom_site_hello\n _atom_site_world\n" in cif_str - assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0 2.0" in cif_str - assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0 2.0" in cif_str - - class TestMagCif(PymatgenTest): def setUp(self): self.mcif = CifParser(f"{TEST_FILES_DIR}/magnetic.example.NiO.mcif") From 71dba725cf995b184a26a29f17b94edb843ba50f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:36:00 +0000 Subject: [PATCH 21/25] pre-commit auto-fixes --- tests/io/test_cif.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 9c9237b52f4..399c4e595d5 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -870,6 +870,7 @@ def test_cif_writer_write_file(self): assert len(read_structs) == 2 assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"] + class TestMagCif(PymatgenTest): def setUp(self): self.mcif = CifParser(f"{TEST_FILES_DIR}/magnetic.example.NiO.mcif") From b99e316176c864289c7a7feddc02262be1bdce3d Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 16:36:34 -0800 Subject: [PATCH 22/25] fix --- pymatgen/io/cif.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 30612fdd00c..910fc2d65c6 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -1474,7 +1474,6 @@ def __init__( site_label = site.label if site.label != site.species_string else f"{sp.symbol}{count}" atom_site_label.append(site_label) atom_site_occupancy.append(str(occu)) - count += 1 block["_atom_site_type_symbol"] = atom_site_type_symbol From b8555b7bc3637d96761be665b8e5287fe79a0d8d Mon Sep 17 00:00:00 2001 From: Andrew Rosen Date: Sat, 13 Jan 2024 16:37:30 -0800 Subject: [PATCH 23/25] re-add test --- tests/io/test_cif.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 399c4e595d5..a8c7da26c74 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -870,6 +870,18 @@ def test_cif_writer_write_file(self): assert len(read_structs) == 2 assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"] + def test_cif_writer_site_properties(self): + struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") + site_props = {"hello": [1.0] * len(struct1)} + site_props["hello"][-1] = -1.0 + struct1.add_site_property("hello", site_props["hello"]) + out_path = f"{self.tmp_path}/test2.cif" + CifWriter(struct1, write_site_properties=True).write_file(out_path) + with open(out_path) as f: + cif_str = f.read() + assert "_atom_site_occupancy\n _atom_site_hello\n" in cif_str + assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0" in cif_str + assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0" in cif_str class TestMagCif(PymatgenTest): def setUp(self): From 616901a1ac82d3b03e81ada5571476c8fef3aafc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 14 Jan 2024 00:38:15 +0000 Subject: [PATCH 24/25] pre-commit auto-fixes --- tests/io/test_cif.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index a8c7da26c74..ee058b09f01 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -883,6 +883,7 @@ def test_cif_writer_site_properties(self): assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0" in cif_str assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0" in cif_str + class TestMagCif(PymatgenTest): def setUp(self): self.mcif = CifParser(f"{TEST_FILES_DIR}/magnetic.example.NiO.mcif") From 5490f56f160dd9a544e2dc3c9c021278f4b1ab6d Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 14 Jan 2024 16:36:52 +0100 Subject: [PATCH 25/25] fix mypy, fix ruff, tweak test_cif_writer_site_properties --- .pre-commit-config.yaml | 2 +- pymatgen/io/abinit/abitimer.py | 4 ++-- pymatgen/io/abinit/netcdf.py | 2 +- pymatgen/io/cif.py | 28 ++++++++++++++-------------- tests/io/test_cif.py | 14 ++++++-------- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ddc36a1769f..2c46bcadc8f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.11 + rev: v0.1.13 hooks: - id: ruff args: [--fix, --unsafe-fixes] diff --git a/pymatgen/io/abinit/abitimer.py b/pymatgen/io/abinit/abitimer.py index ffcdb43c9d3..301174153f9 100644 --- a/pymatgen/io/abinit/abitimer.py +++ b/pymatgen/io/abinit/abitimer.py @@ -888,11 +888,11 @@ def scatter_hist(self, ax: plt.Axes = None, **kwargs): # axHistx.axis["bottom"].major_ticklabels.set_visible(False) axHistx.set_yticks([0, 50, 100]) for tl in axHistx.get_xticklabels(): - tl.set_visible(False) # noqa: FBT003 + tl.set_visible(False) # axHisty.axis["left"].major_ticklabels.set_visible(False) for tl in axHisty.get_yticklabels(): - tl.set_visible(False) # noqa: FBT003 + tl.set_visible(False) axHisty.set_xticks([0, 50, 100]) # plt.draw() diff --git a/pymatgen/io/abinit/netcdf.py b/pymatgen/io/abinit/netcdf.py index d1936fb5f99..499b6db4045 100644 --- a/pymatgen/io/abinit/netcdf.py +++ b/pymatgen/io/abinit/netcdf.py @@ -91,7 +91,7 @@ def __init__(self, path): # Slicing a ncvar returns a MaskedArrray and this is really annoying # because it can lead to unexpected behavior in e.g. calls to np.matmul! # See also https://github.com/Unidata/netcdf4-python/issues/785 - self.rootgrp.set_auto_mask(False) # noqa: FBT003 + self.rootgrp.set_auto_mask(False) def __enter__(self): """Activated when used in the with statement.""" diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 910fc2d65c6..26a8c809a6f 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -7,7 +7,7 @@ import re import textwrap import warnings -from collections import deque +from collections import defaultdict, deque from datetime import datetime from functools import partial from inspect import getfullargspec as getargspec @@ -1317,7 +1317,7 @@ def __init__( symprec: float | None = None, write_magmoms: bool = False, significant_figures: int = 8, - angle_tolerance: float = 5.0, + angle_tolerance: float = 5, refine_struct: bool = True, write_site_properties: bool = False, ) -> None: @@ -1336,7 +1336,7 @@ def __init__( is not None. refine_struct: Used only if symprec is not None. If True, get_refined_structure is invoked to convert input structure from primitive to conventional. - write_site_properties (bool): Whether to write the `Structure.site_properties` + write_site_properties (bool): Whether to write the Structure.site_properties to the CIF as _atom_site_{property name}. Defaults to False. """ if write_magmoms and symprec: @@ -1345,7 +1345,7 @@ def __init__( format_str = f"{{:.{significant_figures}f}}" - block = {} + block: dict[str, Any] = {} loops = [] spacegroup = ("P 1", 1) if symprec is not None: @@ -1391,12 +1391,12 @@ def __init__( loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"]) try: - symbol_to_oxinum = {str(el): float(el.oxi_state) for el in sorted(comp.elements)} - block["_atom_type_symbol"] = list(symbol_to_oxinum) - block["_atom_type_oxidation_number"] = symbol_to_oxinum.values() + symbol_to_oxi_num = {str(el): float(el.oxi_state or 0) for el in sorted(comp.elements)} + block["_atom_type_symbol"] = list(symbol_to_oxi_num) + block["_atom_type_oxidation_number"] = symbol_to_oxi_num.values() loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"]) except (TypeError, AttributeError): - symbol_to_oxinum = {el.symbol: 0 for el in sorted(comp.elements)} + symbol_to_oxi_num = {el.symbol: 0 for el in sorted(comp.elements)} atom_site_type_symbol = [] atom_site_symmetry_multiplicity = [] @@ -1409,7 +1409,7 @@ def __init__( atom_site_moment_crystalaxis_x = [] atom_site_moment_crystalaxis_y = [] atom_site_moment_crystalaxis_z = [] - atom_site_properties = {k: [] for k in struct.site_properties} + atom_site_properties: dict[str, list] = defaultdict(list) count = 0 if symprec is None: for site in struct: @@ -1442,8 +1442,8 @@ def __init__( atom_site_moment_crystalaxis_z.append(format_str.format(moment[2])) if write_site_properties: - for property_key, property_val in site.properties.items(): - atom_site_properties[property_key].append(format_str.format(property_val)) + for key, val in site.properties.items(): + atom_site_properties[key].append(format_str.format(val)) count += 1 else: @@ -1493,9 +1493,9 @@ def __init__( "_atom_site_occupancy", ] if write_site_properties: - for property_key, property_vals in atom_site_properties.items(): - block[f"_atom_site_{property_key}"] = property_vals - loop_labels.append(f"_atom_site_{property_key}") + for key, vals in atom_site_properties.items(): + block[f"_atom_site_{key}"] = vals + loop_labels += [f"_atom_site_{key}"] loops.append(loop_labels) if write_magmoms: diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index ee058b09f01..5a4c037b2dc 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -871,15 +871,13 @@ def test_cif_writer_write_file(self): assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"] def test_cif_writer_site_properties(self): - struct1 = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") - site_props = {"hello": [1.0] * len(struct1)} - site_props["hello"][-1] = -1.0 - struct1.add_site_property("hello", site_props["hello"]) + struct = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR") + struct.add_site_property(label := "hello", [1.0] * (len(struct) - 1) + [-1.0]) out_path = f"{self.tmp_path}/test2.cif" - CifWriter(struct1, write_site_properties=True).write_file(out_path) - with open(out_path) as f: - cif_str = f.read() - assert "_atom_site_occupancy\n _atom_site_hello\n" in cif_str + CifWriter(struct, write_site_properties=True).write_file(out_path) + with open(out_path) as file: + cif_str = file.read() + assert f"_atom_site_occupancy\n _atom_site_{label}\n" in cif_str assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0" in cif_str assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0" in cif_str