-
Notifications
You must be signed in to change notification settings - Fork 2.4k
/
Copy pathcommon.py
297 lines (252 loc) · 11 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# -*- coding: utf-8 -*-
# Copyright 2017, IBM.
#
# This source code is licensed under the Apache License, Version 2.0 found in
# the LICENSE.txt file in the root directory of this source tree.
"""Shared functionality and helpers for the unit tests."""
from enum import Enum
import functools
import inspect
import logging
import os
import unittest
from unittest.util import safe_repr
from qiskit import __path__ as qiskit_path
from qiskit.wrapper.defaultqiskitprovider import DefaultQISKitProvider
class Path(Enum):
    """Filesystem locations that the tests refer to.

    Each member's ``value`` is a path string; use ``Path.<MEMBER>.value`` to
    obtain it.
    """

    # Root of the SDK package: the first entry of ``qiskit.__path__``.
    SDK = qiskit_path[0]
    # Directory that contains this module.
    TEST = os.path.dirname(__file__)
    # Examples directory, addressed relative to the SDK package root
    # (the path is left un-normalized).
    EXAMPLES = os.path.join(SDK, '../examples')
    # Schemas directory inside the SDK package.
    SCHEMAS = os.path.join(SDK, 'schemas')
class QiskitTestCase(unittest.TestCase):
    """Base ``TestCase`` with logging, resource-path and assertion helpers."""

    @classmethod
    def setUpClass(cls):
        """Set per-class attributes and, optionally, file logging.

        ``cls.moduleName`` is the path of the file defining the class without
        its extension, and ``cls.log`` is a logger named after the class.
        When the ``LOG_LEVEL`` environment variable is set to a non-empty
        value, log records are also written to ``<moduleName>.log``.
        """
        cls.moduleName = os.path.splitext(inspect.getfile(cls))[0]
        cls.log = logging.getLogger(cls.__name__)

        # File logging is opt-in through the LOG_LEVEL environment variable.
        env_level = os.getenv('LOG_LEVEL')
        if env_level:
            # Set up formatter.
            log_fmt = ('{}.%(funcName)s:%(levelname)s:%(asctime)s:'
                       ' %(message)s'.format(cls.__name__))
            formatter = logging.Formatter(log_fmt)

            # Set up the file handler (only a file handler is installed).
            log_file_name = '%s.log' % cls.moduleName
            file_handler = logging.FileHandler(log_file_name)
            file_handler.setFormatter(formatter)
            cls.log.addHandler(file_handler)

            # Set the logging level from the environment variable, defaulting
            # to INFO if it is not a valid level name.
            level = logging._nameToLevel.get(env_level, logging.INFO)
            cls.log.setLevel(level)

    def tearDown(self):
        """Reset the default provider after every test."""
        # Reset the default provider, as in practice it acts as a singleton
        # due to importing the wrapper from qiskit.
        from qiskit.wrapper import _wrapper
        _wrapper._DEFAULT_PROVIDER = DefaultQISKitProvider()

    @staticmethod
    def _get_resource_path(filename, path=Path.TEST):
        """Get the normalized path to a resource.

        Args:
            filename (string): filename or relative path to the resource.
            path (Path): ``Path`` member whose value is joined in front of
                ``filename``.

        Returns:
            str: ``path.value`` joined with ``filename`` and normalized.
        """
        return os.path.normpath(os.path.join(path.value, filename))

    def assertNoLogs(self, logger=None, level=None):
        """
        Context manager to test that no message is sent to the specified
        logger and level (the opposite of TestCase.assertLogs()).

        Args:
            logger (logging.Logger or str): logger (or logger name) to watch.
            level (int or str): minimum level of the records to watch.

        Returns:
            _AssertNoLogsContext: the context manager.
        """
        # pylint: disable=invalid-name
        return _AssertNoLogsContext(self, logger, level)

    def assertDictAlmostEqual(self, dict1, dict2, delta=None, msg=None,
                              places=None, default_value=0):
        """
        Assert two dictionaries with numeric values are almost equal.

        Fail if the two dictionaries are unequal as determined by
        comparing that the difference between values with the same key are
        not greater than delta (default 1e-8), or that difference rounded
        to the given number of decimal places is not zero. If a key in one
        dictionary is not in the other the default_value keyword argument
        will be used for the missing value (default 0). If the two objects
        compare equal then they will automatically compare almost equal.

        Args:
            dict1 (dict): a dictionary.
            dict2 (dict): a dictionary.
            delta (number): threshold for comparison (defaults to 1e-8).
            msg (str): return a custom message on failure.
            places (int): number of decimal places for comparison.
            default_value (number): default value for missing keys.

        Raises:
            TypeError: if both ``delta`` and ``places`` are specified.
            AssertionError: ``self.failureException`` is raised if the
                dictionaries are not almost equal.
        """
        # pylint: disable=invalid-name
        if dict1 == dict2:
            # Shortcut
            return
        if delta is not None and places is not None:
            raise TypeError("specify delta or places not both")

        # Choose the comparison once; the key loop below is shared by both
        # modes.
        if places is not None:
            def differs(val1, val2):
                return round(abs(val1 - val2), places) != 0
            suffix = ' within %s places' % places
        else:
            if delta is None:
                delta = 1e-8  # default delta value

            def differs(val1, val2):
                return abs(val1 - val2) > delta
            suffix = ' within %s delta' % delta

        # Keys of dict1 first, then the keys that only dict2 has, so the
        # failure message lists mismatches in a reproducible order.
        all_keys = list(dict1) + [key for key in dict2 if key not in dict1]
        mismatches = []
        for key in all_keys:
            val1 = dict1.get(key, default_value)
            val2 = dict2.get(key, default_value)
            if differs(val1, val2):
                mismatches.append('(%s: %s != %s)' % (safe_repr(key),
                                                      safe_repr(val1),
                                                      safe_repr(val2)))
        if not mismatches:
            return

        standard_msg = ', '.join(mismatches) + suffix
        msg = self._formatMessage(msg, standard_msg)
        raise self.failureException(msg)
try:
    # Newer Python versions keep the context manager in a private helper
    # module instead of ``unittest.case``.
    from unittest._log import _AssertLogsContext
except ImportError:
    from unittest.case import _AssertLogsContext


class _AssertNoLogsContext(_AssertLogsContext):
    """A context manager used to implement TestCase.assertNoLogs()."""

    # pylint: disable=inconsistent-return-statements

    def __init__(self, test_case, logger_name, level):
        """Initialize the context with the three-argument signature.

        Args:
            test_case (unittest.TestCase): test case used to report failures.
            logger_name (logging.Logger or str): logger to watch.
            level (int or str): minimum level of the records to watch.
        """
        base_params = inspect.signature(_AssertLogsContext.__init__).parameters
        if 'no_logs' in base_params:
            # Some Python versions require an extra ``no_logs`` argument.
            # ``False`` keeps the base ``__enter__`` capturing records as
            # usual; the inverted check is done in ``__exit__`` below.
            super().__init__(test_case, logger_name, level, no_logs=False)
        else:
            super().__init__(test_case, logger_name, level)

    def __exit__(self, exc_type, exc_value, tb):
        """
        This is a modified version of TestCase._AssertLogsContext.__exit__(...)

        Restores the logger state saved on entry and fails the test if any
        record was captured while the context was active.

        Returns:
            bool: False when an exception is in flight, so it propagates;
                None otherwise.
        """
        self.logger.handlers = self.old_handlers
        self.logger.propagate = self.old_propagate
        self.logger.setLevel(self.old_level)

        if exc_type is not None:
            # let unexpected exceptions pass through
            return False

        if self.watcher.records:
            lines = ['logs of level {} or higher triggered on {}:\n'.format(
                logging.getLevelName(self.level), self.logger.name)]
            for record in self.watcher.records:
                lines.append('logger %s %s:%i: %s\n' % (record.name,
                                                        record.pathname,
                                                        record.lineno,
                                                        record.getMessage()))
            self._raiseFailure(''.join(lines))
def slow_test(func):
    """
    Mark a test as slow (one that takes minutes to run).

    The returned wrapper raises ``unittest.SkipTest`` when the module-level
    ``SKIP_SLOW_TESTS`` flag is truthy at call time, and otherwise forwards
    the call to ``func``.

    Args:
        func (callable): test function to be decorated.

    Returns:
        callable: the decorated function.
    """
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        # The flag is looked up on every call, not at decoration time.
        if not SKIP_SLOW_TESTS:
            return func(*args, **kwargs)
        raise unittest.SkipTest('Skipping slow tests')

    return _wrapper
def requires_qe_access(func):
    """
    Mark a test as one that uses the online API.

    When the wrapped test is called:
        * ``unittest.SkipTest`` is raised if the module-level
          ``SKIP_ONLINE_TESTS`` flag is truthy.
        * otherwise the credentials are read from ``Qconfig.py`` or, if that
          module cannot be imported, from the ``QE_TOKEN``, ``QE_URL``,
          ``QE_HUB``, ``QE_GROUP`` and ``QE_PROJECT`` environment variables.
        * the values are passed to the test as the ``QE_TOKEN``, ``QE_URL``,
          ``hub``, ``group`` and ``project`` keyword arguments.

    Args:
        func (callable): test function to be decorated.

    Returns:
        callable: the decorated function.

    Raises:
        Exception: at call time, if no token or no URL could be found.
    """
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        if SKIP_ONLINE_TESTS:
            raise unittest.SkipTest('Skipping online tests')

        try:
            # First choice: a local Qconfig module.
            import Qconfig
            token = Qconfig.APItoken
            url = Qconfig.config['url']
            hub, group, project = (
                Qconfig.config.get(key)
                for key in ('hub', 'group', 'project'))
        except ImportError:
            # Fallback: environment variables (ie. Travis).
            token, url, hub, group, project = (
                os.getenv(name)
                for name in ('QE_TOKEN', 'QE_URL', 'QE_HUB', 'QE_GROUP',
                             'QE_PROJECT'))

        # Both the token and the URL are mandatory.
        if not (token and url):
            raise Exception(
                'Could not locate a valid "Qconfig.py" file nor read the QE '
                'values from the environment')

        kwargs.update({
            'QE_TOKEN': token,
            'QE_URL': url,
            'hub': hub,
            'group': group,
            'project': project,
        })
        return func(*args, **kwargs)

    return _wrapper
def _is_ci_fork_pull_request():
    """
    Tell whether the tests are running for a pull request on a CI service.

    Travis CI is detected through the ``TRAVIS`` environment variable and
    AppVeyor through ``APPVEYOR``; Travis takes precedence when both are set.
    A pull request is then recognized by a non-empty
    ``TRAVIS_PULL_REQUEST_BRANCH`` or ``APPVEYOR_PULL_REQUEST_NUMBER``
    respectively.

    Returns:
        bool: True if a supported CI service is detected and its pull request
            variable is set to a non-empty value, False otherwise.
    """
    if os.getenv('TRAVIS'):
        # Using Travis CI.
        pull_request_var = 'TRAVIS_PULL_REQUEST_BRANCH'
    elif os.getenv('APPVEYOR'):
        # Using AppVeyor CI.
        pull_request_var = 'APPVEYOR_PULL_REQUEST_NUMBER'
    else:
        # Not running on a known CI service.
        return False
    return bool(os.getenv(pull_request_var))
# Values of the SKIP_* environment variables that mean "do not skip".
_NO_SKIP_VALUES = ('false', 'False', '-1')

# Online tests: honour SKIP_ONLINE_TESTS when it is set; otherwise skip them
# only for pull requests built on a CI service.
_skip_online_env = os.getenv('SKIP_ONLINE_TESTS')
if _skip_online_env is None:
    SKIP_ONLINE_TESTS = _is_ci_fork_pull_request()
else:
    # An empty value means "do not skip", and so do the explicit negative
    # values, consistently with SKIP_SLOW_TESTS (a plain truthiness test of
    # the string would treat 'false' as "skip").
    SKIP_ONLINE_TESTS = (bool(_skip_online_env)
                         and _skip_online_env not in _NO_SKIP_VALUES)

# Slow tests are skipped unless SKIP_SLOW_TESTS is set to one of the
# "do not skip" values.
SKIP_SLOW_TESTS = os.getenv('SKIP_SLOW_TESTS', True) not in _NO_SKIP_VALUES