Skip to content

Commit

Permalink
Merge pull request #234 from SMTG-Bham/fix-phonon-bs
Browse files Browse the repository at this point in the history
Fix sumo-phonon-bandplot
  • Loading branch information
utf authored Jan 10, 2024
2 parents 0d7361f + bd9c52f commit b7fe5af
Show file tree
Hide file tree
Showing 9 changed files with 8 additions and 19 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
exclude: ^tests/data/
repos:
- repo: https://github.com/myint/autoflake
rev: v1.4
rev: v2.2.1
hooks:
- id: autoflake
args: [--in-place, --remove-all-unused-imports, --remove-unused-variable, --ignore-init-module-imports]
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
rev: 7.0.0
hooks:
- id: flake8
args: [--max-line-length=125, "--extend-ignore=E203,W503,E402,F401"]
language_version: python3
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
Expand Down
2 changes: 0 additions & 2 deletions sumo/electronic_structure/bandstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def get_projections(bs, selection, normalise=None):

# store the projections for all elements and orbitals in a useable format
for spin, element, orbital in it.product(spins, elements, all_orbitals):

# convert data to [nb][nk]
el_orb_proj = [
[all_proj[spin][nb][nk][element][orbital] for nk in range(nkpts)]
Expand All @@ -164,7 +163,6 @@ def get_projections(bs, selection, normalise=None):
# now go through the selected orbitals and extract what's needed
spec_proj = []
for spec in selection:

if isinstance(spec, str):
# spec is just an element type, therefore sum all orbitals
element = spec
Expand Down
2 changes: 0 additions & 2 deletions sumo/electronic_structure/effective_mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def get_fitting_data(bs, spin, band_id, kpoint_id, num_sample_points=3):
# check to see if there are enough points to sample from first
# check in the forward direction
if kpoint_id + num_sample_points <= branch_data["end_index"]:

# calculate sampling limits
start_id = kpoint_id
end_id = kpoint_id + num_sample_points + 1
Expand Down Expand Up @@ -90,7 +89,6 @@ def get_fitting_data(bs, spin, band_id, kpoint_id, num_sample_points=3):

# check in the backward direction
if kpoint_id - num_sample_points >= branch_data["start_index"]:

# calculate sampling limits
start_id = kpoint_id - num_sample_points
end_id = kpoint_id + 1
Expand Down
1 change: 0 additions & 1 deletion sumo/io/castep.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def to_file(self, filename):

@classmethod
def from_file(cls, filename):

with zopen(filename, "rt") as f:
lines = [line.strip() for line in f]

Expand Down
1 change: 0 additions & 1 deletion sumo/io/questaal.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ def _get_structure_from_lattice(self):
def to_file(self, filename):
"""Write QuestaalInit object to init file"""
with open(filename, "w") as f:

f.write("LATTICE\n")
for key, value in self.lattice.items():
if key == "PLAT":
Expand Down
2 changes: 0 additions & 2 deletions sumo/plotting/bs_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,6 @@ def get_projected_plot(

# nd is branch index
for spin, nd in it.product(spins, range(nbranches)):

# mask data to reduce plotting load
bands = np.array(data["energy"][str(spin)][nd])
mask = np.where(
Expand Down Expand Up @@ -597,7 +596,6 @@ def get_projected_plot(
weights[weights < 0] = 0

if mode == "rgb":

# colours aren't used now but needed later for legend
colours = [color1, color2, color3]

Expand Down
5 changes: 2 additions & 3 deletions sumo/plotting/phonon_bs_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _plot_lines(data, ax, color=None, alpha=1, zorder=1):

# nd is branch index, nb is band index, nk is kpoint index
for nd, nb in itertools.product(
range(len(data["distances"])), range(self._nb_bands)
range(len(data["distances"])), range(self.n_bands)
):
f = freqs[nd][nb]

Expand All @@ -176,7 +176,7 @@ def _plot_lines(data, ax, color=None, alpha=1, zorder=1):
# raise Exception(bs.qpoints)
json_plotter = PhononBSPlotter(bs)
json_data = json_plotter.bs_plot_data()
if json_plotter._nb_bands != self._nb_bands:
if json_plotter.n_bands != self.n_bands:
raise Exception(
f"Number of bands in {bs_json} does not match main plot"
)
Expand Down Expand Up @@ -246,7 +246,6 @@ def _makeplot(
if dos is not None:
self._plot_phonon_dos(dos, ax=fig.axes[1], color=color, dashline=dashline)
else:

# keep correct aspect ratio; match axis to canvas
x0, x1 = ax.get_xlim()
y0, y1 = ax.get_ylim()
Expand Down
2 changes: 1 addition & 1 deletion sumo/symmetry/seekpath_kpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def kpath_from_seekpath(cls, seekpath, point_coords):
# convert from seekpath format e.g. [(l1, l2), (l2, l3), (l4, l5)]
# to our preferred representation [[l1, l2, l3], [l4, l5]]
path = [[seekpath[0][0]]]
for (k1, k2) in seekpath:
for k1, k2 in seekpath:
if path[-1] and path[-1][-1] == k1:
path[-1].append(k2)
else:
Expand Down
2 changes: 0 additions & 2 deletions tests/tests_plotting/test_band_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_sanitise_label(self):
("@X", None),
("@HEX", None),
):

self.assertEqual(SBSPlotter._sanitise_label(label_in), label_out)

def test_sanitise_label_group(self):
Expand All @@ -45,5 +44,4 @@ def test_sanitise_label_group(self):
(r"X@$\mid$@Y", r"X"),
(r"@X@$\mid$@Y", None),
):

self.assertEqual(SBSPlotter._sanitise_label_group(label_in), label_out)

0 comments on commit b7fe5af

Please sign in to comment.