-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path_inline_calcs.py
90 lines (77 loc) · 3.07 KB
/
_inline_calcs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding: utf-8 -*-
# © 2017-2019, ETH Zurich, Institut für Theoretische Physik
# Author: Dominik Gresch <[email protected]>
"""
Defines helper InlineCalculations for the first-principles workflows.
"""
from past.builtins import basestring # pylint: disable=redefined-builtin,useless-suppression
import numpy as np
from aiida.orm import DataFactory
from aiida.orm.data.parameter import ParameterData
from aiida.orm.calculation.inline import make_inline
@make_inline
def merge_kpoints_inline(mesh_kpoints, band_kpoints):
    """
    Combine a k-point mesh and an explicit k-point list into one KpointsData.

    The points of ``mesh_kpoints`` come first, each with weight 1; the points
    of ``band_kpoints`` follow, each with weight 0.

    :param mesh_kpoints: KpointsData holding a mesh; its explicit point list
        is obtained via ``get_kpoints_mesh(print_list=True)``.
    :param band_kpoints: KpointsData holding an explicit k-point list
        (read via ``get_kpoints()``).
    :returns: ``{'kpoints': <new KpointsData>}`` with the merged, weighted
        points.
    """
    mesh_points = mesh_kpoints.get_kpoints_mesh(print_list=True)
    path_points = band_kpoints.get_kpoints()

    # Mesh points carry weight 1, the appended band points weight 0.
    point_weights = [1.] * len(mesh_points)
    point_weights.extend([0.] * len(path_points))

    merged = DataFactory('array.kpoints')()
    all_points = np.vstack([mesh_points, path_points])
    merged.set_kpoints(all_points, weights=point_weights)
    return {'kpoints': merged}
@make_inline
def flatten_bands_inline(bands):
    """
    Return a copy of ``bands`` whose band array keeps only its last two axes.

    The array from ``bands.get_bands()`` is reshaped to ``shape[-2:]``. This
    is a no-op for arrays of rank <= 2; for higher rank it only succeeds when
    the leading axes have total size 1 (presumably a single spin channel --
    confirm), otherwise numpy's ``reshape`` raises.

    :param bands: BandsData node to flatten; it is cloned, not modified.
    :returns: ``{'bands': <cloned BandsData with the reshaped band array>}``.
    """
    energies = bands.get_bands()
    target_shape = energies.shape[-2:]
    result = bands.clone()
    result.set_bands(energies.reshape(target_shape))
    return {'bands': result}
@make_inline
def crop_bands_inline(bands, kpoints):
    """
    Crop a BandsData to the given kpoints by removing entries from the front.

    The last ``len(kpoints)`` k-points of ``bands`` must coincide (within
    ``np.allclose`` tolerance) with the k-points of ``kpoints``. Only those
    k-points and the matching rows of the band array are kept.

    :param bands: BandsData whose k-point list ends with the points of
        ``kpoints``.
    :param kpoints: KpointsData giving the k-points to keep.
    :returns: ``{'bands': <new BandsData>}`` holding the cropped bands, with
        ``kpoints`` attached as its k-point data.
    :raises ValueError: if ``kpoints`` has more points than ``bands``, or if
        the trailing k-points of ``bands`` do not match ``kpoints``.
    """
    kpoints_array = kpoints.get_kpoints()
    bands_kpoints_array = bands.get_kpoints()

    num_kept = len(kpoints_array)
    num_dropped = len(bands_kpoints_array) - num_kept
    if num_dropped < 0:
        raise ValueError(
            "Cannot crop bands with {} k-points to {} k-points.".format(
                len(bands_kpoints_array), num_kept
            )
        )
    # Use an explicit start index rather than ``slice(-num_kept, None)``:
    # with ``num_kept == 0`` the latter becomes ``slice(0, None)`` and would
    # keep everything instead of nothing.
    band_slice = slice(num_dropped, None)

    # Check consistency of the kpoints. Raise explicitly instead of using
    # ``assert``, which is stripped when Python runs with ``-O``.
    cropped_bands_kpoints = bands_kpoints_array[band_slice]
    if not np.allclose(cropped_bands_kpoints, kpoints_array):
        raise ValueError(
            "The trailing k-points of 'bands' do not match 'kpoints'."
        )

    cropped_bands = DataFactory('array.bands')()
    cropped_bands.set_kpointsdata(kpoints)
    # NOTE(review): slicing the first axis assumes a 2-D
    # (nkpoints, nbands) band array; a leading spin axis would be sliced
    # instead -- confirm callers pass flattened bands.
    cropped_bands_array = bands.get_bands()[band_slice]
    cropped_bands.set_bands(cropped_bands_array)
    return {'bands': cropped_bands}
@make_inline
def reduce_num_wann_inline(wannier_parameters):
    """
    Reduce ``num_bands`` in a Wannier90 input by the number of bands listed
    in its ``exclude_bands`` parameter.

    Despite the function name, ``num_wann`` is not touched; only
    ``num_bands`` is changed.

    ``exclude_bands`` must be a string of comma-separated entries, each
    either a single band index (``'7'``) or an inclusive range (``'1-4'``).
    Surrounding whitespace and empty entries (e.g. from a trailing comma)
    are ignored. Each distinct band index is counted once, even if entries
    overlap.

    :param wannier_parameters: ParameterData with Wannier90 input parameters.
    :returns: a new ParameterData with the reduced ``num_bands`` if both
        ``exclude_bands`` and ``num_bands`` are present; otherwise
        ``wannier_parameters`` itself, unchanged.
    :raises ValueError: if ``exclude_bands`` is not a string, contains a
        non-integer entry, or contains a range whose upper bound is smaller
        than its lower bound.
    """
    # NOTE(review): the sibling inline functions in this module return a
    # ``{'name': node}`` dict, whereas this one returns a bare node (and, in
    # the no-op branch, the input node itself). Confirm this matches what
    # ``make_inline`` expects.
    wannier_param_dict = wannier_parameters.get_dict()
    if (
        'exclude_bands' not in wannier_param_dict
        or 'num_bands' not in wannier_param_dict
    ):
        return wannier_parameters

    exclude_bands_val = wannier_param_dict['exclude_bands']
    if not isinstance(exclude_bands_val, basestring):
        raise ValueError(
            "Invalid value for 'exclude_bands': '{}'".
            format(exclude_bands_val)
        )

    # Collect distinct band indices so overlapping entries such as
    # '1-4,3' are not double-counted.
    excluded_bands = set()
    for part in exclude_bands_val.split(','):
        part = part.strip()
        if not part:
            # Empty entry, e.g. from '' or a trailing comma: nothing excluded.
            continue
        if '-' in part:
            lower, upper = [int(x) for x in part.split('-')]
            if upper < lower:
                raise ValueError(
                    "Invalid range '{}' in 'exclude_bands': the upper bound "
                    "is smaller than the lower bound.".format(part)
                )
            excluded_bands.update(range(lower, upper + 1))
        else:
            excluded_bands.add(int(part))

    wannier_param_dict['num_bands'] = int(
        wannier_param_dict['num_bands']
    ) - len(excluded_bands)
    return ParameterData(dict=wannier_param_dict)