From 8325156bdb3fddd3e4186f3c72ed6be703ac4e22 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 4 Sep 2023 16:10:18 +0100 Subject: [PATCH] Castep parser: a little refactoring Using some features of Python 3.8+ and newer NumPy. I tried not to go too crazy: there are more places where assignment expressions (i.e. the walrus operator) could be used to save a line, but overall flow and clarity wouldn't really benefit. --- euphonic/readers/castep.py | 67 +++++++++++++------------------------- 1 file changed, 23 insertions(+), 44 deletions(-) diff --git a/euphonic/readers/castep.py b/euphonic/readers/castep.py index 387ac64ce..4690d2127 100644 --- a/euphonic/readers/castep.py +++ b/euphonic/readers/castep.py @@ -53,15 +53,11 @@ def read_phonon_dos_data( weights = np.empty((0,)) freqs = np.empty((0, n_branches)) mode_grads = np.empty((0, n_branches)) - while True: - frequency_block = _read_frequency_block( + while (frequency_block := _read_frequency_block( f, n_branches, extra_columns=[0], - terminator=' END GRADIENTS\n') - if frequency_block is None: - # We've reached 'END GRADIENTS' line - break + terminator=' END GRADIENTS\n')) is not None: qmode_grad = frequency_block.extra @@ -73,17 +69,11 @@ def read_phonon_dos_data( if 'BEGIN DOS' not in line: raise RuntimeError( f'Expected "BEGIN DOS" in {filename}, got {line}') - # max_rows arg not available until Numpy 1.16.0 - try: - dos_data = np.loadtxt(f, max_rows=n_bins) - except TypeError: - data = f.readlines() - dos_data = np.array([[float(elem) for elem in line.split()] - for line in data[:n_bins]]) + + dos_data = np.loadtxt(f, max_rows=n_bins) data_dict: Dict[str, Any] = {} - data_dict['crystal'] = {} - cry_dict = data_dict['crystal'] + cry_dict = data_dict['crystal'] = {} cry_dict['n_atoms'] = n_atoms cry_dict['cell_vectors'] = (cell_vectors*ureg('angstrom').to( cell_vectors_unit)).magnitude @@ -195,11 +185,8 @@ def read_phonon_data( repeated_qpt_ids = defaultdict(set) loto_split_indices = set() - while True: - 
frequency_block = _read_frequency_block(f, n_branches) - if frequency_block is None: - # Reached empty line, this should be end of file - break + while (frequency_block := _read_frequency_block(f, n_branches) + ) is not None: qpt_id = frequency_block.qpt_id @@ -243,9 +230,8 @@ def read_phonon_data( # Multiple qpts with same CASTEP q-pt index: correct weights for qpt_id, indices in repeated_qpt_ids.items(): - intersection = indices & loto_split_indices if (prefer_non_loto - and intersection + and (intersection := indices & loto_split_indices) and len(indices) > len(intersection)): # Repeated q-point has both split and un-split variations; # set weights of split points to zero @@ -396,11 +382,9 @@ def _read_frequency_block( qpt = np.array([float(x) for x in floats[:3]]) qweight = float(floats[3]) - direction: Optional[np.ndarray] + direction: Optional[np.ndarray] = None if len(floats) >= 6: direction = floats[4:7] - else: - direction = None freq_lines = [f.readline().split() for i in range(n_branches)] @@ -458,11 +442,9 @@ def read_interpolation_data( with open(filename, 'rb') as f: int_type = '>i4' float_type = '>f8' - header = '' first_cell_read = True - while header.strip() != b'END': - header = _read_entry(f) - if header.strip() == b'BEGIN_UNIT_CELL': + while (header := _read_entry(f).strip()) != b'END': + if header == b'BEGIN_UNIT_CELL': # CASTEP writes the cell twice: the first is the # geometry optimised cell, the second is the original # cell. We only want the geometry optimised cell. 
@@ -470,7 +452,7 @@ def read_interpolation_data( (n_atoms, cell_vectors, atom_r, atom_mass, atom_type) = _read_cell(f, int_type, float_type) first_cell_read = False - elif header.strip() == b'FORCE_CON': + elif header == b'FORCE_CON': sc_matrix = np.transpose(np.reshape( _read_entry(f, int_type), (3, 3))) n_cells_in_sc = int(np.rint(np.absolute( @@ -483,16 +465,15 @@ def read_interpolation_data( cell_origins = np.reshape( _read_entry(f, int_type), (n_cells_in_sc, 3)) _ = _read_entry(f, int_type) # FC row not used - elif header.strip() == b'BORN_CHGS': + elif header == b'BORN_CHGS': born = np.reshape( _read_entry(f, float_type), (n_atoms, 3, 3)) - elif header.strip() == b'DIELECTRIC': + elif header == b'DIELECTRIC': dielectric = np.transpose(np.reshape( _read_entry(f, float_type), (3, 3))) data_dict: Dict[str, Any] = {} - data_dict['crystal'] = {} - cry_dict = data_dict['crystal'] + cry_dict = data_dict['crystal'] = {} cry_dict['n_atoms'] = n_atoms cry_dict['cell_vectors'] = cell_vectors*ureg( 'bohr').to(cell_vectors_unit).magnitude @@ -563,29 +544,27 @@ def _read_cell(file_obj: BinaryIO, int_type: str, float_type: str Shape (n_atoms,) string ndarray. 
The chemical symbols of each atom in the unit cell """ - header = '' - while header.strip() != b'END_UNIT_CELL': - header = _read_entry(file_obj) - if header.strip() == b'CELL%NUM_IONS': + while (header := _read_entry(file_obj).strip()) != b'END_UNIT_CELL': + if header == b'CELL%NUM_IONS': n_atoms = _read_entry(file_obj, int_type) - elif header.strip() == b'CELL%REAL_LATTICE': + elif header == b'CELL%REAL_LATTICE': cell_vectors = np.transpose(np.reshape( _read_entry(file_obj, float_type), (3, 3))) - elif header.strip() == b'CELL%NUM_SPECIES': + elif header == b'CELL%NUM_SPECIES': n_species = _read_entry(file_obj, int_type) - elif header.strip() == b'CELL%NUM_IONS_IN_SPECIES': + elif header == b'CELL%NUM_IONS_IN_SPECIES': n_atoms_in_species = _read_entry(file_obj, int_type) if n_species == 1: n_atoms_in_species = np.array([n_atoms_in_species]) - elif header.strip() == b'CELL%IONIC_POSITIONS': + elif header == b'CELL%IONIC_POSITIONS': max_atoms_in_species = max(n_atoms_in_species) atom_r_tmp = np.reshape(_read_entry(file_obj, float_type), (n_species, max_atoms_in_species, 3)) - elif header.strip() == b'CELL%SPECIES_MASS': + elif header == b'CELL%SPECIES_MASS': atom_mass_tmp = _read_entry(file_obj, float_type) if n_species == 1: atom_mass_tmp = np.array([atom_mass_tmp]) - elif header.strip() == b'CELL%SPECIES_SYMBOL': + elif header == b'CELL%SPECIES_SYMBOL': # Need to decode binary string for Python 3 compatibility if n_species == 1: atom_type_tmp = [_read_entry(file_obj, 'S8')