Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sumo-phonon-bandplot #234

Merged
merged 2 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
exclude: ^tests/data/
repos:
- repo: https://github.com/myint/autoflake
rev: v1.4
rev: v2.2.1
hooks:
- id: autoflake
args: [--in-place, --remove-all-unused-imports, --remove-unused-variable, --ignore-init-module-imports]
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
rev: 7.0.0
hooks:
- id: flake8
args: [--max-line-length=125, "--extend-ignore=E203,W503,E402,F401"]
language_version: python3
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
Expand Down
2 changes: 0 additions & 2 deletions sumo/electronic_structure/bandstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def get_projections(bs, selection, normalise=None):

# store the projections for all elements and orbitals in a useable format
for spin, element, orbital in it.product(spins, elements, all_orbitals):

# convert data to [nb][nk]
el_orb_proj = [
[all_proj[spin][nb][nk][element][orbital] for nk in range(nkpts)]
Expand All @@ -164,7 +163,6 @@ def get_projections(bs, selection, normalise=None):
# now go through the selected orbitals and extract what's needed
spec_proj = []
for spec in selection:

if isinstance(spec, str):
# spec is just an element type, therefore sum all orbitals
element = spec
Expand Down
2 changes: 0 additions & 2 deletions sumo/electronic_structure/effective_mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def get_fitting_data(bs, spin, band_id, kpoint_id, num_sample_points=3):
# check to see if there are enough points to sample from first
# check in the forward direction
if kpoint_id + num_sample_points <= branch_data["end_index"]:

# calculate sampling limits
start_id = kpoint_id
end_id = kpoint_id + num_sample_points + 1
Expand Down Expand Up @@ -90,7 +89,6 @@ def get_fitting_data(bs, spin, band_id, kpoint_id, num_sample_points=3):

# check in the backward direction
if kpoint_id - num_sample_points >= branch_data["start_index"]:

# calculate sampling limits
start_id = kpoint_id - num_sample_points
end_id = kpoint_id + 1
Expand Down
1 change: 0 additions & 1 deletion sumo/io/castep.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def to_file(self, filename):

@classmethod
def from_file(cls, filename):

with zopen(filename, "rt") as f:
lines = [line.strip() for line in f]

Expand Down
1 change: 0 additions & 1 deletion sumo/io/questaal.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ def _get_structure_from_lattice(self):
def to_file(self, filename):
"""Write QuestaalInit object to init file"""
with open(filename, "w") as f:

f.write("LATTICE\n")
for key, value in self.lattice.items():
if key == "PLAT":
Expand Down
2 changes: 0 additions & 2 deletions sumo/plotting/bs_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,6 @@ def get_projected_plot(

# nd is branch index
for spin, nd in it.product(spins, range(nbranches)):

# mask data to reduce plotting load
bands = np.array(data["energy"][str(spin)][nd])
mask = np.where(
Expand Down Expand Up @@ -597,7 +596,6 @@ def get_projected_plot(
weights[weights < 0] = 0

if mode == "rgb":

# colours aren't used now but needed later for legend
colours = [color1, color2, color3]

Expand Down
5 changes: 2 additions & 3 deletions sumo/plotting/phonon_bs_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _plot_lines(data, ax, color=None, alpha=1, zorder=1):

# nd is branch index, nb is band index, nk is kpoint index
for nd, nb in itertools.product(
range(len(data["distances"])), range(self._nb_bands)
range(len(data["distances"])), range(self.n_bands)
):
f = freqs[nd][nb]

Expand All @@ -176,7 +176,7 @@ def _plot_lines(data, ax, color=None, alpha=1, zorder=1):
# raise Exception(bs.qpoints)
json_plotter = PhononBSPlotter(bs)
json_data = json_plotter.bs_plot_data()
if json_plotter._nb_bands != self._nb_bands:
if json_plotter.n_bands != self.n_bands:
raise Exception(
f"Number of bands in {bs_json} does not match main plot"
)
Expand Down Expand Up @@ -246,7 +246,6 @@ def _makeplot(
if dos is not None:
self._plot_phonon_dos(dos, ax=fig.axes[1], color=color, dashline=dashline)
else:

# keep correct aspect ratio; match axis to canvas
x0, x1 = ax.get_xlim()
y0, y1 = ax.get_ylim()
Expand Down
2 changes: 1 addition & 1 deletion sumo/symmetry/seekpath_kpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def kpath_from_seekpath(cls, seekpath, point_coords):
# convert from seekpath format e.g. [(l1, l2), (l2, l3), (l4, l5)]
# to our preferred representation [[l1, l2, l3], [l4, l5]]
path = [[seekpath[0][0]]]
for (k1, k2) in seekpath:
for k1, k2 in seekpath:
if path[-1] and path[-1][-1] == k1:
path[-1].append(k2)
else:
Expand Down
2 changes: 0 additions & 2 deletions tests/tests_plotting/test_band_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_sanitise_label(self):
("@X", None),
("@HEX", None),
):

self.assertEqual(SBSPlotter._sanitise_label(label_in), label_out)

def test_sanitise_label_group(self):
Expand All @@ -45,5 +44,4 @@ def test_sanitise_label_group(self):
(r"X@$\mid$@Y", r"X"),
(r"@X@$\mid$@Y", None),
):

self.assertEqual(SBSPlotter._sanitise_label_group(label_in), label_out)