Skip to content

Commit

Permalink
Fix bandstats issue
Browse files Browse the repository at this point in the history
  • Loading branch information
utf committed Dec 13, 2024
1 parent 9ec6740 commit 5c0d625
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 19 deletions.
2 changes: 1 addition & 1 deletion sumo/cli/bandplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def bandplot(
if code == "vasp":
for vr_file in filenames:
vr = BSVasprun(vr_file, parse_projected_eigen=parse_projected)
bs = vr.get_band_structure(line_mode=True)
bs = vr.get_band_structure(line_mode=True, efermi="smart")
bandstructures.append(bs)
bs = get_reconstructed_band_structure(bandstructures)
elif code == "castep":
Expand Down
37 changes: 28 additions & 9 deletions sumo/cli/bandstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,24 +94,26 @@ def bandstats(
bandstructures = []
for vr_file in filenames:
vr = BSVasprun(vr_file, parse_projected_eigen=False)
bs = vr.get_band_structure(line_mode=True)
bs = vr.get_band_structure(line_mode=True, efermi="smart")
bandstructures.append(bs)
bs = get_reconstructed_band_structure(bandstructures, force_kpath_branches=False)
bs, kpt_mapping = get_reconstructed_band_structure(
bandstructures, force_kpath_branches=True, return_forced_branch_kpt_map=True
)

if bs.is_metal():
logging.error("ERROR: System is metallic!")
sys.exit()

_log_band_gap_information(bs)
_log_band_gap_information(bs, kpt_mapping=kpt_mapping)

vbm_data = bs.get_vbm()
cbm_data = bs.get_cbm()

logging.info("\nValence band maximum:")
_log_band_edge_information(bs, vbm_data)
_log_band_edge_information(bs, vbm_data, kpt_mapping=kpt_mapping)

logging.info("\nConduction band minimum:")
_log_band_edge_information(bs, cbm_data)
_log_band_edge_information(bs, cbm_data, kpt_mapping=kpt_mapping)

if parabolic:
logging.info("\nUsing parabolic fitting of the band edges")
Expand Down Expand Up @@ -179,11 +181,14 @@ def bandstats(
return {"hole_data": hole_data, "electron_data": elec_data}


def _log_band_gap_information(bs):
def _log_band_gap_information(bs, kpt_mapping=None):
"""Log data about the direct and indirect band gaps.
Args:
bs (:obj:`~pymatgen.electronic_structure.bandstructure.BandStructureSymmLine`):
kpt_mapping (:obj:`dict`, optional): A mapping of k-point indicies from the
band structure with forced branches to the original band structure.
"""
bg_data = bs.get_band_gap()
if not bg_data["direct"]:
Expand All @@ -199,6 +204,7 @@ def _log_band_gap_information(bs):
direct_kpoint = bs.kpoints[direct_kindex].frac_coords
direct_kpoint = kpt_str.format(k=direct_kpoint)
eq_kpoints = bs.get_equivalent_kpoints(direct_kindex)
eq_kpoints = _map_kpoints(eq_kpoints, kpt_mapping)
k_indices = ", ".join(map(str, eq_kpoints))

# add 1 to band indices to be consistent with VASP band numbers.
Expand All @@ -215,7 +221,9 @@ def _log_band_gap_information(bs):

direct_kindex = direct_data[Spin.up]["kpoint_index"]
direct_kpoint = kpt_str.format(k=bs.kpoints[direct_kindex].frac_coords)
k_indices = ", ".join(map(str, bs.get_equivalent_kpoints(direct_kindex)))
eq_kpoints = bs.get_equivalent_kpoints(direct_kindex)
eq_kpoints = _map_kpoints(eq_kpoints, kpt_mapping)
k_indices = ", ".join(map(str, eq_kpoints))
b_indices = ", ".join(
[str(i + 1) for i in direct_data[Spin.up]["band_indices"]]
)
Expand All @@ -225,14 +233,16 @@ def _log_band_gap_information(bs):
logging.info(f" Band indices: {b_indices}")


def _log_band_edge_information(bs, edge_data):
def _log_band_edge_information(bs, edge_data, kpt_mapping=None):
"""Log data about the valence band maximum or conduction band minimum.
Args:
bs (:obj:`~pymatgen.electronic_structure.bandstructure.BandStructureSymmLine`):
The band structure.
edge_data (dict): The :obj:`dict` from ``bs.get_vbm()`` or
``bs.get_cbm()``
kpt_mapping (:obj:`dict`, optional): A mapping of k-point indicies from the
band structure with forced branches to the original band structure.
"""
if bs.is_spin_polarized:
spins = edge_data["band_index"].keys()
Expand All @@ -247,7 +257,9 @@ def _log_band_edge_information(bs, edge_data):

kpoint = edge_data["kpoint"]
kpoint_str = kpt_str.format(k=kpoint.frac_coords)
k_indices = ", ".join(map(str, edge_data["kpoint_index"]))
k_indices = ", ".join(
map(str, _map_kpoints(edge_data["kpoint_index"], kpt_mapping))
)
k_degen = bs.get_kpoint_degeneracy(kpoint=kpoint.frac_coords)

if kpoint.label:
Expand Down Expand Up @@ -311,6 +323,13 @@ def _log_effective_mass_data(data, is_spin_polarized, mass_type="m_e"):
logging.info(f" {mass_type}: {eff_mass:.3f} | {band_str} | {kpoint_str}")


def _map_kpoints(kpt_idxs, kpt_mapping):
"""Map k-point indices to the original band structure."""
if not kpt_mapping:
return kpt_idxs
return sorted(set([kpt_mapping.get(k, k) for k in kpt_idxs]))


def _get_parser():
parser = argparse.ArgumentParser(
description="""
Expand Down
37 changes: 29 additions & 8 deletions sumo/electronic_structure/bandstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def get_projections(bs, selection, normalise=None):
return spec_proj


def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=True):
def get_reconstructed_band_structure(
list_bs, efermi=None, force_kpath_branches=True, return_forced_branch_kpt_map=False
):
"""Combine a list of band structures into a single band structure.
This is typically very useful when you split non self consistent
Expand All @@ -210,12 +212,17 @@ def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=
across all band structures is used.
force_kpath_branches (bool): Force a linemode band structure to contain
branches by adding repeated high-symmetry k-points in the path.
return_forced_branch_kpt_map (bool): If True, return a mapping of the
the new k-points to the original k-points.
Returns:
:obj:`pymatgen.electronic_structure.bandstructure.BandStructure` or \
:obj:`pymatgen.electronic_structure.bandstructureBandStructureSymmLine`:
A band structure object. The type depends on the type of the band
structures in ``list_bs``.
If return_forced_branch_kpt_map is True, then a tuple is returned
containing the band structure and the mapping from the new k-points
to the original k-points.
"""
if efermi is None:
efermi = sum(b.efermi for b in list_bs) / len(list_bs)
Expand Down Expand Up @@ -244,13 +251,17 @@ def get_reconstructed_band_structure(list_bs, efermi=None, force_kpath_branches=
structure=list_bs[0].structure,
projections=projections,
)
if force_kpath_branches:
return force_branches(bs)
else:
return bs
branch_bs, mapping = force_branches(bs, return_mapping=True)
if force_kpath_branches and return_forced_branch_kpt_map:
return branch_bs, mapping
elif force_kpath_branches:
return branch_bs
elif return_forced_branch_kpt_map:
return bs, mapping
return bs


def force_branches(bandstructure):
def force_branches(bandstructure, return_mapping=False):
"""Force a linemode band structure to contain branches.
Branches give a specific portion of the path from one high-symmetry point
Expand All @@ -262,9 +273,14 @@ def force_branches(bandstructure):
Args:
bandstructure: A band structure object.
return_mapping: If True, return a mapping of the new k-points (with branches)
to the original k-points.
Returns:
A band structure with brnaches.
A band structure with branches.
If return_forced_branch_kpt_map is True, then a tuple is returned
containing the band structure and the mapping from the new k-points
to the original k-points.
"""
kpoints = np.array([k.frac_coords for k in bandstructure.kpoints])
labels_dict = {k: v.frac_coords for k, v in bandstructure.labels_dict.items()}
Expand All @@ -275,6 +291,7 @@ def force_branches(bandstructure):
# already.
dup_ids = []
high_sym_kpoints = tuple(map(tuple, labels_dict.values()))
mapping = {}
for i, k in enumerate(kpoints):
dup_ids.append(i)
if (
Expand All @@ -287,6 +304,7 @@ def force_branches(bandstructure):
)
):
dup_ids.append(i)
mapping[len(dup_ids) - 1] = i

kpoints = kpoints[dup_ids]

Expand All @@ -297,7 +315,7 @@ def force_branches(bandstructure):
if len(bandstructure.projections) != 0:
projections[spin] = bandstructure.projections[spin][:, dup_ids]

return type(bandstructure)(
bs = type(bandstructure)(
kpoints,
eigenvals,
bandstructure.lattice_rec,
Expand All @@ -306,6 +324,9 @@ def force_branches(bandstructure):
structure=bandstructure.structure,
projections=projections,
)
if return_mapping:
return bs, mapping
return bs


def string_to_spin(spin_string):
Expand Down
2 changes: 1 addition & 1 deletion sumo/electronic_structure/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def load_dos(
else:
vr = vasprun

band = vr.get_band_structure()
band = vr.get_band_structure(efermi="smart")
dos = vr.complete_dos

dos, band = _scissor_dos(dos, band, scissor)
Expand Down

0 comments on commit 5c0d625

Please sign in to comment.