From 43e00c03d8e858f2114b924b8f2bdd10366e732f Mon Sep 17 00:00:00 2001 From: Alvin Noe Ladines Date: Mon, 22 Jan 2024 03:58:11 +0100 Subject: [PATCH] Fix matid import --- systemnormalizer/normalizer.py | 260 +++++++++++++++++++++------------ 1 file changed, 170 insertions(+), 90 deletions(-) diff --git a/systemnormalizer/normalizer.py b/systemnormalizer/normalizer.py index ffeb73d..0159ca9 100644 --- a/systemnormalizer/normalizer.py +++ b/systemnormalizer/normalizer.py @@ -23,7 +23,14 @@ import json import re from matid import SymmetryAnalyzer, Classifier # pylint: disable=import-error -from matid.classifications import Class0D, Atom, Class1D, Material2D, Surface, Class3D # pylint: disable=import-error +from matid.classification.classifications import ( + Class0D, + Atom, + Class1D, + Material2D, + Surface, + Class3D, +) # pylint: disable=import-error from nomad import atomutils from nomad.atomutils import Formula @@ -34,27 +41,31 @@ # use a regular expression to check atom labels; expression is build from list of # all labels sorted desc to find Br and not B when searching for Br. -atom_label_re = re.compile('|'.join( - sorted(ase.data.chemical_symbols, key=lambda x: len(x), reverse=True))) +atom_label_re = re.compile( + "|".join(sorted(ase.data.chemical_symbols, key=lambda x: len(x), reverse=True)) +) def normalized_atom_labels(atom_labels): - ''' + """ Normalizes the given atom labels: they either are labels right away, or contain additional numbers (to distinguish same species but different labels, see meta-info), or we replace them with ase placeholder atom for unknown elements 'X'. - ''' + """ return [ ase.data.chemical_symbols[0] if match is None else match.group(0) - for match in [re.search(atom_label_re, atom_label) for atom_label in atom_labels]] + for match in [ + re.search(atom_label_re, atom_label) for atom_label in atom_labels + ] + ] def formula_normalizer(atoms): - ''' + """ Reads the chemical symbols in ase.atoms and returns a normalized formula. Formula normalization is on the basis of atom counting, e.g., Tc -> Tc100, SZn -> S50Zn50, Co2Nb -> Co67Nb33 - ''' + """ # atoms_counter = atoms.symbols.formula.count() # dictionary atoms_total = sum(atoms_counter.values()) @@ -65,14 +76,15 @@ def formula_normalizer(atoms): atoms_normed.append(key + norm) # atoms_normed.sort() - return ''.join(atoms_normed) + return "".join(atoms_normed) class SystemNormalizer(SystemBasedNormalizer): - ''' + """ This normalizer performs all system (atoms, cells, etc.) related normalizations of the legacy NOMAD-coe *stats* normalizer. - ''' + """ + @staticmethod def atom_label_to_num(atom_label): # Take first three characters and make first letter capitalized. @@ -86,22 +98,24 @@ def atom_label_to_num(atom_label): return 0 def normalize_system(self, system, is_representative) -> bool: - ''' + """ The 'main' method of this :class:`SystemBasedNormalizer`. Normalizes the section with the given `index`. Normalizes geometry, classifies, system_type, and runs symmetry analysis. Returns: True, iff the normalization was successful - ''' + """ if self.section_run is None: - self.logger.error('section_run is not present.') + self.logger.error("section_run is not present.") return False - atoms_cls = system.m_def.all_sub_sections['atoms'].sub_section.section_cls + atoms_cls = system.m_def.all_sub_sections["atoms"].sub_section.section_cls if system.atoms is None: system.atoms = atoms_cls() - def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any = None) -> Any: + def get_value( + quantity_def, default: Any = None, numpy: bool = True, source: Any = None + ) -> Any: try: source = system if source is None else source value = source.m_get(quantity_def) @@ -125,15 +139,17 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any atom_species = get_value(atoms_cls.species, numpy=False, source=system.atoms) if atom_labels is None and atom_species is None: - self.logger.warn('system has neither atom species nor labels') + self.logger.warn("system has neither atom species nor labels") return False # If there are no atom labels we create them from atom species data. if atom_labels is None: try: - atom_labels = list(ase.data.chemical_symbols[species] for species in atom_species) + atom_labels = list( + ase.data.chemical_symbols[species] for species in atom_species + ) except IndexError: - self.logger.error('system has atom species that are out of range') + self.logger.error("system has atom species that are out of range") return False system.atoms.labels = atom_labels @@ -143,12 +159,17 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any atoms = ase.Atoms(symbols=atom_labels) chemical_symbols = list(atoms.get_chemical_symbols()) if atom_labels != chemical_symbols: - self.logger.error('atom labels are ambiguous', atom_labels=atom_labels[:10]) + self.logger.error( + "atom labels are ambiguous", atom_labels=atom_labels[:10] + ) atom_labels = chemical_symbols except Exception as e: self.logger.error( - 'cannot build ase atoms from atom labels', - atom_labels=atom_labels[:10], exc_info=e, error=str(e)) + "cannot build ase atoms from atom labels", + atom_labels=atom_labels[:10], + exc_info=e, + error=str(e), + ) raise e if atom_species is None: @@ -159,8 +180,10 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any atom_species = [atom_species] if atom_species != atoms.get_atomic_numbers().tolist(): self.logger.warning( - 'atom species do not match labels', - atom_labels=atom_labels[:10], atom_species=atom_species[:10]) + "atom species do not match labels", + atom_labels=atom_labels[:10], + atom_species=atom_species[:10], + ) atom_species = atoms.get_atomic_numbers().tolist() system.atoms.species = atom_species @@ -168,66 +191,82 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any pbc = get_value(atoms_cls.periodic, numpy=False, source=system.atoms) if pbc is None: pbc = [False, False, False] - self.logger.warning('missing configuration_periodic_dimensions') + self.logger.warning("missing configuration_periodic_dimensions") system.atoms.periodic = pbc try: atoms.set_pbc(pbc) except Exception as e: self.logger.error( - 'cannot use pbc with ase atoms', exc_info=e, pbc=pbc, error=str(e)) + "cannot use pbc with ase atoms", exc_info=e, pbc=pbc, error=str(e) + ) return False # formulas try: formula = Formula(atoms.get_chemical_formula()) - system.chemical_composition = atoms.get_chemical_formula(mode='all') - system.chemical_composition_reduced = formula.format('reduced') + system.chemical_composition = atoms.get_chemical_formula(mode="all") + system.chemical_composition_reduced = formula.format("reduced") system.chemical_composition_hill = formula.format("hill") except ValueError as e: - self.logger.error('could not extract chemical formula', exc_info=e, error=str(e)) + self.logger.error( + "could not extract chemical formula", exc_info=e, error=str(e) + ) # positions atom_positions = get_value(atoms_cls.positions, numpy=True, source=system.atoms) if atom_positions is None or len(atom_positions) == 0: - self.logger.warning('no atom positions, skip further system analysis') + self.logger.warning("no atom positions, skip further system analysis") return False if len(atom_positions) != len(atoms): self.logger.error( - 'len of atom position does not match number of atoms', - n_atom_positions=len(atom_positions), n_atoms=len(atoms)) + "len of atom position does not match number of atoms", + n_atom_positions=len(atom_positions), + n_atoms=len(atoms), + ) return False try: atoms.set_positions(1e10 * atom_positions.magnitude) except Exception as e: self.logger.error( - 'cannot use positions with ase atoms', exc_info=e, error=str(e)) + "cannot use positions with ase atoms", exc_info=e, error=str(e) + ) return False # lattice vectors - lattice_vectors = get_value(atoms_cls.lattice_vectors, numpy=True, source=system.atoms) + lattice_vectors = get_value( + atoms_cls.lattice_vectors, numpy=True, source=system.atoms + ) if lattice_vectors is None: if any(pbc): - self.logger.error('no lattice vectors but periodicity', pbc=pbc) + self.logger.error("no lattice vectors but periodicity", pbc=pbc) else: try: atoms.set_cell(1e10 * lattice_vectors.magnitude) except Exception as e: self.logger.error( - 'cannot use lattice_vectors with ase atoms', exc_info=e, error=str(e)) + "cannot use lattice_vectors with ase atoms", + exc_info=e, + error=str(e), + ) return False # reciprocal lattice vectors lattice_vectors_reciprocal = get_value( - atoms_cls.lattice_vectors_reciprocal, numpy=True, source=system.atoms) + atoms_cls.lattice_vectors_reciprocal, numpy=True, source=system.atoms + ) if lattice_vectors_reciprocal is None and lattice_vectors is not None: - system.atoms.lattice_vectors_reciprocal = 2 * np.pi * atomutils.reciprocal_cell(lattice_vectors.magnitude) # there is also a get_reciprocal_cell method in ase + system.atoms.lattice_vectors_reciprocal = ( + 2 * np.pi * atomutils.reciprocal_cell(lattice_vectors.magnitude) + ) # there is also a get_reciprocal_cell method in ase # configuration configuration = [ - atom_labels, atoms.positions.tolist(), + atom_labels, + atoms.positions.tolist(), atoms.cell.tolist() if atoms.cell is not None else None, - atoms.pbc.tolist()] - configuration_id = utils.hash(json.dumps(configuration).encode('utf-8')) + atoms.pbc.tolist(), + ] + configuration_id = utils.hash(json.dumps(configuration).encode("utf-8")) system.configuration_raw_gid = configuration_id if is_representative: @@ -237,50 +276,62 @@ def get_value(quantity_def, default: Any = None, numpy: bool = True, source: Any # System type analysis if atom_positions is not None: with utils.timer( - self.logger, 'system classification executed', - system_size=len(atoms)): + self.logger, + "system classification executed", + system_size=len(atoms), + ): self.system_type_analysis(atoms) # Symmetry analysis - if atom_positions is not None and (lattice_vectors is not None or not any(pbc)) and system.type == "bulk": + if ( + atom_positions is not None + and (lattice_vectors is not None or not any(pbc)) + and system.type == "bulk" + ): with utils.timer( - self.logger, 'symmetry analysis executed', - system_size=len(atoms)): + self.logger, "symmetry analysis executed", system_size=len(atoms) + ): self.symmetry_analysis(system, atoms) return True def system_type_analysis(self, atoms: ase.Atoms) -> None: - ''' + """ Determine the system type with MatID. Write the system type to the entry_archive. Args: atoms: The structure to analyse - ''' + """ system_type = config.services.unavailable_value classification = None if len(atoms) <= config.normalize.system_classification_with_clusters_threshold: try: - classifier = Classifier(radii="covalent", cluster_threshold=config.normalize.cluster_threshold) + classifier = Classifier( + radii="covalent", + cluster_threshold=config.normalize.cluster_threshold, + ) cls = classifier.classify(atoms) except Exception as e: self.logger.error( - 'matid project system classification failed', exc_info=e, error=str(e)) + "matid project system classification failed", + exc_info=e, + error=str(e), + ) else: classification = type(cls) if classification == Class3D: - system_type = 'bulk' + system_type = "bulk" elif classification == Atom: - system_type = 'atom' + system_type = "atom" elif classification == Class0D: - system_type = 'molecule / cluster' + system_type = "molecule / cluster" elif classification == Class1D: - system_type = '1D' + system_type = "1D" elif classification == Surface: - system_type = 'surface' + system_type = "surface" elif classification == Material2D: - system_type = '2D' + system_type = "2D" else: self.logger.info("system type analysis not run due to large system size") idx = self.section_run.m_cache["representative_system_idx"] @@ -289,7 +340,7 @@ def system_type_analysis(self, atoms: ase.Atoms) -> None: self.section_run.system[-1].type = system_type def symmetry_analysis(self, system, atoms: ase.Atoms) -> None: - '''Analyze the symmetry of the material being simulated. Only performed + """Analyze the symmetry of the material being simulated. Only performed for bulk materials. We feed in the parsed values in section_system to the the symmetry @@ -301,10 +352,12 @@ def symmetry_analysis(self, system, atoms: ase.Atoms) -> None: Returns: None: The method should write symmetry variables to the entry_archive which is member of this class. - ''' + """ # Try to use MatID's symmetry analyzer to analyze the ASE object. try: - symm = SymmetryAnalyzer(atoms, symmetry_tol=config.normalize.symmetry_tolerance) + symm = SymmetryAnalyzer( + atoms, symmetry_tol=config.normalize.symmetry_tolerance + ) space_group_number = symm.get_space_group_number() @@ -337,20 +390,24 @@ def symmetry_analysis(self, system, atoms: ase.Atoms) -> None: transform = symm._get_spglib_transformation_matrix() origin_shift = symm._get_spglib_origin_shift() except ValueError as e: - self.logger.debug('symmetry analysis is not available', details=str(e)) + self.logger.debug("symmetry analysis is not available", details=str(e)) return except Exception as e: - self.logger.error('matid symmetry analysis fails with exception', exc_info=e) + self.logger.error( + "matid symmetry analysis fails with exception", exc_info=e + ) return # Write data extracted from MatID's symmetry analysis to the # representative section_system. # symmetry_cls = system.m_def.all_sub_sections['symmetry'].sub_section.section_cls - sec_symmetry = system.m_def.all_sub_sections['symmetry'].sub_section.section_cls() + sec_symmetry = system.m_def.all_sub_sections[ + "symmetry" + ].sub_section.section_cls() system.symmetry.append(sec_symmetry) sec_symmetry.m_cache["symmetry_analyzer"] = symm - sec_symmetry.symmetry_method = 'MatID (spg)' + sec_symmetry.symmetry_method = "MatID (spg)" sec_symmetry.space_group_number = space_group_number sec_symmetry.hall_number = hall_number sec_symmetry.hall_symbol = hall_symbol @@ -361,7 +418,7 @@ def symmetry_analysis(self, system, atoms: ase.Atoms) -> None: sec_symmetry.origin_shift = origin_shift sec_symmetry.transformation_matrix = transform - atoms_cls = system.m_def.all_sub_sections['atoms'].sub_section.section_cls + atoms_cls = system.m_def.all_sub_sections["atoms"].sub_section.section_cls sec_std = atoms_cls() sec_symmetry.system_std.append(sec_std) sec_std.lattice_vectors = conv_cell * ureg.angstrom @@ -392,19 +449,23 @@ def springer_classification(self, atoms, space_group_number): idx = self.section_run.m_cache["representative_system_idx"] for material in springer_data.values(): - sec_springer_mat = self.section_run.system[idx].m_def.all_sub_sections['springer_material'].sub_section.section_cls() + sec_springer_mat = ( + self.section_run.system[idx] + .m_def.all_sub_sections["springer_material"] + .sub_section.section_cls() + ) self.section_run.system[idx].springer_material.append(sec_springer_mat) - sec_springer_mat.id = material['spr_id'] - sec_springer_mat.alphabetical_formula = material['spr_aformula'] - sec_springer_mat.url = material['spr_url'] + sec_springer_mat.id = material["spr_id"] + sec_springer_mat.alphabetical_formula = material["spr_aformula"] + sec_springer_mat.url = material["spr_url"] - compound_classes = material['spr_compound'] + compound_classes = material["spr_compound"] if compound_classes is None: compound_classes = [] sec_springer_mat.compound_class = compound_classes - classifications = material['spr_classification'] + classifications = material["spr_classification"] if classifications is None: classifications = [] sec_springer_mat.classification = classifications @@ -413,26 +474,33 @@ def springer_classification(self, atoms, space_group_number): # found is the same for all springer_id's springer_data_keys = list(springer_data.keys()) if len(springer_data_keys) != 0: - class_0 = springer_data[springer_data_keys[0]]['spr_classification'] - comp_0 = springer_data[springer_data_keys[0]]['spr_compound'] + class_0 = springer_data[springer_data_keys[0]]["spr_classification"] + comp_0 = springer_data[springer_data_keys[0]]["spr_compound"] # compare 'class_0' and 'comp_0' against the rest for ii in range(1, len(springer_data_keys)): - class_test = (class_0 == springer_data[springer_data_keys[ii]]['spr_classification']) - comp_test = (comp_0 == springer_data[springer_data_keys[ii]]['spr_compound']) + class_test = ( + class_0 + == springer_data[springer_data_keys[ii]]["spr_classification"] + ) + comp_test = ( + comp_0 == springer_data[springer_data_keys[ii]]["spr_compound"] + ) if (class_test or comp_test) is False: - self.logger.info('Mismatch in Springer classification or compounds') + self.logger.info("Mismatch in Springer classification or compounds") - def prototypes(self, system, atom_species: NDArray, wyckoffs: NDArray, spg_number: int) -> None: - '''Tries to match the material to an entry in the AFLOW prototype data. + def prototypes( + self, system, atom_species: NDArray, wyckoffs: NDArray, spg_number: int + ) -> None: + """Tries to match the material to an entry in the AFLOW prototype data. If a match is found, a section_prototype is added to section_system. Args: atomic_numbers: Array of atomic numbers. wyckoff_letters: Array of Wyckoff letters as strings. spg_number: Space group number. - ''' + """ norm_wyckoff = atomutils.get_normalized_wyckoff(atom_species, wyckoffs) protoDict = atomutils.search_aflow_prototype(spg_number, norm_wyckoff) if protoDict is not None: @@ -441,13 +509,17 @@ def prototypes(self, system, atom_species: NDArray, wyckoffs: NDArray, spg_numbe aflow_prototype_notes = protoDict["Notes"] aflow_prototype_name = protoDict["Prototype"] aflow_strukturbericht_designation = protoDict["Strukturbericht Designation"] - prototype_label = '%d-%s-%s' % ( + prototype_label = "%d-%s-%s" % ( spg_number, aflow_prototype_name, - protoDict.get("Pearsons Symbol", "-") + protoDict.get("Pearsons Symbol", "-"), ) idx = self.section_run.m_cache["representative_system_idx"] - sec_prototype = self.section_run.system[idx].m_def.all_sub_sections['prototype'].sub_section.section_cls() + sec_prototype = ( + self.section_run.system[idx] + .m_def.all_sub_sections["prototype"] + .sub_section.section_cls() + ) self.section_run.system[idx].prototype.append(sec_prototype) sec_prototype.label = prototype_label sec_prototype.aflow_id = aflow_prototype_id @@ -456,11 +528,15 @@ def prototypes(self, system, atom_species: NDArray, wyckoffs: NDArray, spg_numbe sec_prototype.m_cache["prototype_notes"] = aflow_prototype_notes sec_prototype.m_cache["prototype_name"] = aflow_prototype_name if aflow_strukturbericht_designation != "None": - sec_prototype.m_cache["strukturbericht_designation"] = aflow_strukturbericht_designation + sec_prototype.m_cache[ + "strukturbericht_designation" + ] = aflow_strukturbericht_designation -def query_springer_data(normalized_formula: str, space_group_number: int) -> Dict[str, Any]: - ''' Queries a msgpack database for springer-related quantities. ''' +def query_springer_data( + normalized_formula: str, space_group_number: int +) -> Dict[str, Any]: + """Queries a msgpack database for springer-related quantities.""" try: from nomad import archive except ModuleNotFoundError: @@ -470,15 +546,19 @@ def query_springer_data(normalized_formula: str, space_group_number: int) -> Dic if config.normalize.springer_db_path is None: return {} - entries = archive.query_archive(config.normalize.springer_db_path, {str(space_group_number): {normalized_formula: '*'}}) + entries = archive.query_archive( + config.normalize.springer_db_path, + {str(space_group_number): {normalized_formula: "*"}}, + ) db_dict = {} entries = entries.get(str(space_group_number), {}).get(normalized_formula, {}) for sp_id, entry in entries.items(): db_dict[sp_id] = { - 'spr_id': sp_id, - 'spr_aformula': entry['aformula'], - 'spr_url': entry['url'], - 'spr_compound': entry['compound'], - 'spr_classification': entry['classification']} + "spr_id": sp_id, + "spr_aformula": entry["aformula"], + "spr_url": entry["url"], + "spr_compound": entry["compound"], + "spr_classification": entry["classification"], + } return db_dict