-
Notifications
You must be signed in to change notification settings - Fork 554
/
Copy pathconftest.py
228 lines (195 loc) · 8.87 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from typing import List, Optional
import common # TODO(zongheng): for some reason isort places it here.
import pytest
# Usage: use
#   @pytest.mark.slow
# to mark a test as slow and to skip by default.
# https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option
#
# By default, only run generic tests and cloud-specific tests for AWS and Azure,
# due to the cloud credit limit for the development account.
#
# A "generic test" tests a generic functionality (e.g., autostop) that
# should work on any cloud we support. The cloud used for such a test
# is controlled by `--generic-cloud` (typically you do not need to set it).
#
# To only run tests for a specific cloud (as well as generic tests), use
# --<cloud>, e.g. --aws, --gcp, --azure, or --lambda. One such flag is
# registered for every entry of `all_clouds_in_smoke_tests`.
#
# To only run tests for managed jobs (without generic tests), use
# --managed-jobs.
all_clouds_in_smoke_tests = [
    'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci', 'do',
    'kubernetes', 'vsphere', 'cudo', 'fluidstack', 'paperspace', 'runpod'
]
default_clouds_to_run = ['aws', 'azure']

# Clouds whose pytest marker name differs from the cloud name. We need this
# because @pytest.mark.lambda is not allowed (`lambda` is a Python keyword),
# so we use @pytest.mark.lambda_cloud instead.
_cloud_keyword_overrides = {'lambda': 'lambda_cloud'}

# Translate cloud name to pytest keyword. Derived from
# `all_clouds_in_smoke_tests` so that every cloud that gets a command-line
# flag is guaranteed to have a keyword; a missing entry would otherwise raise
# KeyError during marker registration and collection.
cloud_to_pytest_keyword = {
    cloud: _cloud_keyword_overrides.get(cloud, cloud)
    for cloud in all_clouds_in_smoke_tests
}
def pytest_addoption(parser):
    """Register the smoke-test command-line options with pytest.

    Args:
        parser: pytest's option parser, supplied by the hook machinery.
    """
    # Tests marked as `slow` will be skipped by default; use --runslow to run.
    parser.addoption('--runslow',
                     action='store_true',
                     default=False,
                     help='run slow tests.')
    # One selection flag per cloud, e.g. --aws, --lambda.
    for cloud in all_clouds_in_smoke_tests:
        parser.addoption(f'--{cloud}',
                         action='store_true',
                         default=False,
                         help=f'Only run {cloud.upper()} tests.')
    parser.addoption('--managed-jobs',
                     action='store_true',
                     default=False,
                     help='Only run tests for managed jobs.')
    parser.addoption('--serve',
                     action='store_true',
                     default=False,
                     help='Only run tests for sky serve.')
    parser.addoption('--tpu',
                     action='store_true',
                     default=False,
                     help='Only run tests for TPU.')
    # When given, the generic cloud is always used for generic tests and is
    # also added to the clouds whose specific tests run; when omitted, the
    # first of the clouds to be run is used.
    parser.addoption(
        '--generic-cloud',
        type=str,
        choices=all_clouds_in_smoke_tests,
        help='Cloud to use for generic tests. The generic cloud is also '
        'added to the clouds to be run. Defaults to the first cloud in the '
        'list of the clouds to be run.')
    # Both flags below write the same `terminate_on_failure` destination:
    # terminating is the default, and the last flag given on the command
    # line wins.
    parser.addoption('--terminate-on-failure',
                     dest='terminate_on_failure',
                     action='store_true',
                     default=True,
                     help='Terminate test VMs on failure.')
    parser.addoption('--no-terminate-on-failure',
                     dest='terminate_on_failure',
                     action='store_false',
                     help='Do not terminate test VMs on failure.')
def pytest_configure(config):
    """Declare the custom markers and publish the terminate-on-failure flag.

    Registers the `slow` marker plus one marker per cloud (using the pytest
    keyword from `cloud_to_pytest_keyword`), then stores the value of
    `--terminate-on-failure` as the attribute `pytest.terminate_on_failure`.
    """
    config.addinivalue_line('markers', 'slow: mark test as slow to run')
    for name in all_clouds_in_smoke_tests:
        keyword = cloud_to_pytest_keyword[name]
        description = f'{keyword}: mark test as {name} specific'
        config.addinivalue_line('markers', description)
    # Stashed on the pytest module so other test code can read it.
    terminate = config.getoption('--terminate-on-failure')
    pytest.terminate_on_failure = terminate
def _get_cloud_to_run(config) -> List[str]:
    """Return the clouds whose cloud-specific tests should run.

    Every `--<cloud>` flag that is set contributes its cloud, in the order of
    `all_clouds_in_smoke_tests`. `--cloudflare` contributes the first default
    cloud instead of `cloudflare` itself. A `--generic-cloud` choice is
    appended too. If nothing is selected, the default clouds are used.

    Args:
        config: the pytest config object.

    Returns:
        A newly built list without duplicates; callers may mutate it without
        affecting `default_clouds_to_run`.
    """
    cloud_to_run: List[str] = []
    for cloud in all_clouds_in_smoke_tests:
        if not config.getoption(f'--{cloud}'):
            continue
        # --cloudflare selects the first default cloud rather than itself.
        selected = default_clouds_to_run[0] if cloud == 'cloudflare' else cloud
        # Avoid duplicates, e.g. when both --aws and --cloudflare are given.
        if selected not in cloud_to_run:
            cloud_to_run.append(selected)
    generic_cloud_option = config.getoption('--generic-cloud')
    if (generic_cloud_option is not None and
            generic_cloud_option not in cloud_to_run):
        cloud_to_run.append(generic_cloud_option)
    if not cloud_to_run:
        # Copy: returning the module-level list would let callers mutate it.
        cloud_to_run = list(default_clouds_to_run)
    return cloud_to_run
def pytest_collection_modifyitems(config, items):
    """Attach skip and serial-execution markers to the collected tests.

    Skip rules, applied per item in this order:
      * `slow` tests are skipped unless `--runslow` is given;
      * generic tests carrying a `no_<generic cloud keyword>` marker are
        skipped;
      * cloud-specific tests are skipped unless their cloud is in the clouds
        to run. `cloudflare` tests are kept when `--cloudflare` is set,
        because the first default cloud is run in its place;
      * with `--managed-jobs`, `--tpu` or `--serve`, tests lacking the
        matching marker are skipped.

    Serial execution: when the generic cloud is Lambda, the generic tests
    (minus those marked `no_lambda_cloud`) and the Lambda-specific tests are
    put in one xdist group. When it is Kubernetes, only the
    Kubernetes-specific tests are.

    With `--collect-only`, prints each collected node id and its markers.
    """
    skip_marks = {
        'slow':
            pytest.mark.skip(reason='need --runslow option to run'),
        'managed_jobs':
            pytest.mark.skip(
                reason='skipped, because --managed-jobs option is set'),
        'serve':
            pytest.mark.skip(reason='skipped, because --serve option is set'),
        'tpu':
            pytest.mark.skip(reason='skipped, because --tpu option is set'),
    }
    for cloud in all_clouds_in_smoke_tests:
        skip_marks[cloud] = pytest.mark.skip(
            reason=f'tests for {cloud} is skipped, try setting --{cloud}')

    cloud_to_run = _get_cloud_to_run(config)
    generic_cloud = _generic_cloud(config)
    generic_cloud_keyword = cloud_to_pytest_keyword[generic_cloud]
    no_generic_cloud_keyword = f'no_{generic_cloud_keyword}'

    # Options are fixed for the session; read them once instead of per item.
    run_slow = config.getoption('--runslow')
    cloudflare_requested = config.getoption('--cloudflare')
    # (marker, skip mark) for every "only run X" mode that is switched on;
    # tests lacking the marker are skipped. Order: managed_jobs, tpu, serve.
    exclusive_modes = [
        (keyword, skip_marks[keyword])
        for keyword, option in (('managed_jobs', '--managed-jobs'),
                                ('tpu', '--tpu'), ('serve', '--serve'))
        if config.getoption(option)
    ]

    for item in items:
        if 'slow' in item.keywords and not run_slow:
            item.add_marker(skip_marks['slow'])
        if (_is_generic_test(item) and
                no_generic_cloud_keyword in item.keywords):
            item.add_marker(skip_marks[generic_cloud])
        for cloud in all_clouds_in_smoke_tests:
            cloud_keyword = cloud_to_pytest_keyword[cloud]
            if cloud_keyword not in item.keywords or cloud in cloud_to_run:
                continue
            # With --cloudflare, the first default cloud is put in
            # cloud_to_run instead of cloudflare, so cloudflare is never in
            # it; keep cloudflare tests anyway.
            if cloud == 'cloudflare' and cloudflare_requested:
                continue
            item.add_marker(skip_marks[cloud])
        for keyword, skip_mark in exclusive_modes:
            if keyword not in item.keywords:
                item.add_marker(skip_mark)

    # Check if tests need to be run serially for Kubernetes and Lambda Cloud.
    # We run Lambda Cloud tests serially because Lambda Cloud rate limits its
    # launch API to one launch every 10 seconds.
    # We run Kubernetes tests serially because the Kubernetes cluster may have
    # limited resources (e.g., just 8 cpus).
    serial_mark = pytest.mark.xdist_group(
        name=f'serial_{generic_cloud_keyword}')
    # Handle generic tests.
    if generic_cloud == 'lambda':
        for item in items:
            if (_is_generic_test(item) and
                    no_generic_cloud_keyword not in item.keywords):
                _add_serial_mark(item, serial_mark, generic_cloud_keyword)
    # Handle generic cloud specific tests.
    if generic_cloud in ('lambda', 'kubernetes'):
        for item in items:
            if generic_cloud_keyword in item.keywords:
                _add_serial_mark(item, serial_mark, generic_cloud_keyword)

    if config.option.collectonly:
        for item in items:
            marks = [mark.name for mark in item.iter_markers()]
            print(f'Collected {item.nodeid} with marks: {marks}')


def _add_serial_mark(item, serial_mark, generic_cloud_keyword) -> None:
    """Add `serial_mark` to `item` and tag its node id with the group name."""
    item.add_marker(serial_mark)
    # Adding the serial mark does not update item.nodeid, but item.nodeid is
    # important for pytest.xdist_group, e.g.
    # https://github.com/pytest-dev/pytest-xdist/blob/master/src/xdist/scheduler/loadgroup.py
    # This is a hack to update item.nodeid (via the private `_nodeid`).
    item._nodeid = f'{item.nodeid}@serial_{generic_cloud_keyword}'
def _is_generic_test(item) -> bool:
    """Return True when `item` carries none of the cloud-specific keywords.

    Such "generic" tests are run on the cloud selected by `--generic-cloud`.
    """
    cloud_keywords = (cloud_to_pytest_keyword[name]
                      for name in all_clouds_in_smoke_tests)
    return not any(keyword in item.keywords for keyword in cloud_keywords)
def _generic_cloud(config) -> str:
    """Return the cloud that generic tests should run on.

    This is the `--generic-cloud` choice when one is given; otherwise it is
    the first entry of `_get_cloud_to_run(config)`.
    """
    chosen = config.getoption('--generic-cloud')
    if chosen is None:
        return _get_cloud_to_run(config)[0]
    return chosen
@pytest.fixture
def generic_cloud(request) -> str:
    """Fixture: name of the cloud that generic tests run on this session."""
    config = request.config
    return _generic_cloud(config)
@pytest.fixture
def enable_all_clouds(monkeypatch: pytest.MonkeyPatch) -> None:
    """Fixture: enable every cloud for the duration of a test.

    Delegates to `common.enable_all_clouds_in_monkeypatch`, handing it the
    test's `monkeypatch`. Presumably this means the patches are undone at
    teardown — TODO confirm in `common`.
    """
    enable = common.enable_all_clouds_in_monkeypatch
    enable(monkeypatch)
@pytest.fixture
def aws_config_region(monkeypatch: pytest.MonkeyPatch) -> str:
    """Fixture: AWS region for tests to use.

    If a SkyPilot config is loaded and its `aws.ssh_proxy_command` is a
    non-empty dict, the dict's first key is returned. The code treats keys as
    region names (presumably a region -> proxy-command mapping; verify against
    the config schema). Otherwise returns 'us-east-2'.

    `monkeypatch` is accepted but not used by this fixture.
    """
    from sky import skypilot_config
    region = 'us-east-2'
    if skypilot_config.loaded():
        ssh_proxy_command = skypilot_config.get_nested(
            ('aws', 'ssh_proxy_command'), None)
        if isinstance(ssh_proxy_command, dict) and ssh_proxy_command:
            # First key in insertion order, without materialising all keys.
            region = next(iter(ssh_proxy_command))
    return region