diff --git a/pandas/util/testing.py b/pandas/util/testing.py index ba869efbc5837..ebd1f7d2c17f8 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -1,9 +1,6 @@ from __future__ import division # pylint: disable-msg=W0402 -# flake8: noqa - -import random import re import string import sys @@ -25,10 +22,10 @@ import numpy as np import pandas as pd -from pandas.core.common import (is_sequence, array_equivalent, is_list_like, is_number, - is_datetimelike_v_numeric, is_datetimelike_v_object, - is_number, pprint_thing, take_1d, - needs_i8_conversion) +from pandas.core.common import (is_sequence, array_equivalent, + is_list_like, is_datetimelike_v_numeric, + is_datetimelike_v_object, is_number, + pprint_thing, take_1d, needs_i8_conversion) import pandas.compat as compat import pandas.lib as lib @@ -50,21 +47,24 @@ K = 4 _RAISE_NETWORK_ERROR_DEFAULT = False + # set testing_mode def set_testing_mode(): # set the testing mode filters - testing_mode = os.environ.get('PANDAS_TESTING_MODE','None') + testing_mode = os.environ.get('PANDAS_TESTING_MODE', 'None') if 'deprecate' in testing_mode: warnings.simplefilter('always', DeprecationWarning) + def reset_testing_mode(): # reset the testing mode filters - testing_mode = os.environ.get('PANDAS_TESTING_MODE','None') + testing_mode = os.environ.get('PANDAS_TESTING_MODE', 'None') if 'deprecate' in testing_mode: warnings.simplefilter('ignore', DeprecationWarning) set_testing_mode() + class TestCase(unittest.TestCase): @classmethod @@ -88,19 +88,24 @@ def round_trip_pickle(self, obj, path=None): # https://docs.python.org/3/library/unittest.html#deprecated-aliases def assertEquals(self, *args, **kwargs): - return deprecate('assertEquals', self.assertEqual)(*args, **kwargs) + return deprecate('assertEquals', + self.assertEqual)(*args, **kwargs) def assertNotEquals(self, *args, **kwargs): - return deprecate('assertNotEquals', self.assertNotEqual)(*args, **kwargs) + return deprecate('assertNotEquals', + self.assertNotEqual)(*args, **kwargs) def assert_(self, *args, **kwargs): - return deprecate('assert_', self.assertTrue)(*args, **kwargs) + return deprecate('assert_', + self.assertTrue)(*args, **kwargs) def assertAlmostEquals(self, *args, **kwargs): - return deprecate('assertAlmostEquals', self.assertAlmostEqual)(*args, **kwargs) + return deprecate('assertAlmostEquals', + self.assertAlmostEqual)(*args, **kwargs) def assertNotAlmostEquals(self, *args, **kwargs): - return deprecate('assertNotAlmostEquals', self.assertNotAlmostEqual)(*args, **kwargs) + return deprecate('assertNotAlmostEquals', + self.assertNotAlmostEqual)(*args, **kwargs) def assert_almost_equal(left, right, check_exact=False, **kwargs): @@ -121,6 +126,7 @@ def assert_almost_equal(left, right, check_exact=False, **kwargs): assert_dict_equal = _testing.assert_dict_equal + def randbool(size=(), p=0.5): return rand(*size) <= p @@ -168,7 +174,7 @@ def randu(nchars): See `randu_array` if you want to create an array of random unicode strings. """ - return ''.join(choice(RANDU_CHARS, nchars)) + return ''.join(np.random.choice(RANDU_CHARS, nchars)) def close(fignum=None): @@ -186,6 +192,7 @@ def _skip_if_32bit(): if is_platform_32bit(): raise nose.SkipTest("skipping for 32 bit") + def mplskip(cls): """Skip a TestCase instance if matplotlib isn't installed""" @@ -201,13 +208,15 @@ def setUpClass(cls): cls.setUpClass = setUpClass return cls + def _skip_if_no_mpl(): try: - import matplotlib + import matplotlib # noqa except ImportError: import nose raise nose.SkipTest("matplotlib not installed") + def _skip_if_mpl_1_5(): import matplotlib v = matplotlib.__version__ @@ -215,18 +224,20 @@ def _skip_if_mpl_1_5(): import nose raise nose.SkipTest("matplotlib 1.5") + def _skip_if_no_scipy(): try: - import scipy.stats + import scipy.stats # noqa except ImportError: import nose raise nose.SkipTest("no scipy.stats module") try: - import scipy.interpolate + import scipy.interpolate # noqa except ImportError: import nose raise nose.SkipTest('scipy.interpolate missing') + def _skip_if_scipy_0_17(): import scipy v = scipy.__version__ @@ -234,6 +245,7 @@ def _skip_if_scipy_0_17(): import nose raise nose.SkipTest("scipy 0.17") + def _skip_if_no_xarray(): try: import xarray @@ -246,9 +258,10 @@ def _skip_if_no_xarray(): import nose raise nose.SkipTest("xarray not version is too low: {0}".format(v)) + def _skip_if_no_pytz(): try: - import pytz + import pytz # noqa except ImportError: import nose raise nose.SkipTest("pytz not installed") @@ -256,7 +269,7 @@ def _skip_if_no_pytz(): def _skip_if_no_dateutil(): try: - import dateutil + import dateutil # noqa except ImportError: import nose raise nose.SkipTest("dateutil not installed") @@ -267,14 +280,16 @@ def _skip_if_windows_python_3(): import nose raise nose.SkipTest("not used on python 3/win32") + def _skip_if_windows(): if is_platform_windows(): import nose raise nose.SkipTest("Running on Windows") + def _skip_if_no_pathlib(): try: - from pathlib import Path + from pathlib import Path # noqa except ImportError: import nose raise nose.SkipTest("pathlib not available") @@ -282,7 +297,7 @@ def _skip_if_no_pathlib(): def _skip_if_no_localpath(): try: - from py.path import local as LocalPath + from py.path import local as LocalPath # noqa except ImportError: import nose raise nose.SkipTest("py.path not installed") @@ -294,7 +309,7 @@ def _incompat_bottleneck_version(method): as we don't match the nansum/nanprod behavior for all-nan ops, see GH9422 """ - if method not in ['sum','prod']: + if method not in ['sum', 'prod']: return False try: import bottleneck as bn @@ -302,6 +317,7 @@ def _incompat_bottleneck_version(method): except ImportError: return False + def skip_if_no_ne(engine='numexpr'): import nose _USE_NUMEXPR = pd.computation.expressions._USE_NUMEXPR @@ -354,7 +370,7 @@ def check_output(*popenargs, **kwargs): """ if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE,stderr=subprocess.PIPE, + process = subprocess.Popen(stdout=subprocess.PIPE, stderr=subprocess.PIPE, *popenargs, **kwargs) output, unused_err = process.communicate() retcode = process.poll() @@ -410,14 +426,15 @@ def get_locales(prefix=None, normalize=True, return None try: - # raw_locales is "\n" seperated list of locales + # raw_locales is "\n" separated list of locales # it may contain non-decodable parts, so split # extract what we can and then rejoin. raw_locales = raw_locales.split(b'\n') out_locales = [] for x in raw_locales: if compat.PY3: - out_locales.append(str(x, encoding=pd.options.display.encoding)) + out_locales.append(str( + x, encoding=pd.options.display.encoding)) else: out_locales.append(str(x)) @@ -512,7 +529,7 @@ def _valid_locales(locales, normalize): return list(filter(_can_set_locale, map(normalizer, locales))) -#------------------------------------------------------------------------------ +# ----------------------------------------------------------------------------- # Console debugging tools def debug(f, *args, **kwargs): @@ -540,7 +557,7 @@ def set_trace(): from pdb import Pdb as OldPdb OldPdb().set_trace(sys._getframe().f_back) -#------------------------------------------------------------------------------ +# ----------------------------------------------------------------------------- # contextmanager to ensure the file cleanup @@ -584,13 +601,14 @@ def ensure_clean(filename=None, return_filelike=False): os.close(fd) except Exception as e: print("Couldn't close file descriptor: %d (file: %s)" % - (fd, filename)) + (fd, filename)) try: if os.path.exists(filename): os.remove(filename) except Exception as e: print("Exception on removing file: %s" % e) + def get_data_path(f=''): """Return the path of a data file, these are relative to the current test directory. @@ -600,7 +618,7 @@ def get_data_path(f=''): base_dir = os.path.abspath(os.path.dirname(filename)) return os.path.join(base_dir, 'data', f) -#------------------------------------------------------------------------------ +# ----------------------------------------------------------------------------- # Comparators @@ -611,8 +629,10 @@ def equalContents(arr1, arr2): def assert_equal(a, b, msg=""): - """asserts that a equals b, like nose's assert_equal, but allows custom message to start. - Passes a and b to format string as well. So you can use '{0}' and '{1}' to display a and b. + """asserts that a equals b, like nose's assert_equal, + but allows custom message to start. Passes a and b to + format string as well. So you can use '{0}' and '{1}' + to display a and b. Examples -------- @@ -700,8 +720,8 @@ def _get_ilevel_values(index, level): # length comparison if len(left) != len(right): raise_assert_detail(obj, '{0} length are different'.format(obj), - '{0}, {1}'.format(len(left), left), - '{0}, {1}'.format(len(right), right)) + '{0}, {1}'.format(len(left), left), + '{0}, {1}'.format(len(right), right)) # MultiIndex special comparison for little-friendly error messages if left.nlevels > 1: @@ -720,8 +740,10 @@ def _get_ilevel_values(index, level): if check_exact: if not left.equals(right): - diff = np.sum((left.values != right.values).astype(int)) * 100.0 / len(left) - msg = '{0} values are different ({1} %)'.format(obj, np.round(diff, 5)) + diff = np.sum((left.values != right.values) + .astype(int)) * 100.0 / len(left) + msg = '{0} values are different ({1} %)'\ + .format(obj, np.round(diff, 5)) raise_assert_detail(obj, msg, left, right) else: _testing.assert_almost_equal(left.values, right.values, @@ -769,7 +791,7 @@ def assert_attr_equal(attr, left, right, obj='Attributes'): return True else: raise_assert_detail(obj, 'Attribute "{0}" are different'.format(attr), - left_attr, right_attr) + left_attr, right_attr) def assert_is_valid_plot_return_object(objs): @@ -790,6 +812,7 @@ def assert_is_valid_plot_return_object(objs): def isiterable(obj): return hasattr(obj, '__iter__') + def is_sorted(seq): if isinstance(seq, (Index, Series)): seq = seq.values @@ -839,8 +862,10 @@ def assertIsInstance(obj, cls, msg=''): "%sExpected object to be of type %r, found %r instead" % ( msg, cls, type(obj))) + def assert_isinstance(obj, class_type_or_tuple, msg=''): - return deprecate('assert_isinstance', assertIsInstance)(obj, class_type_or_tuple, msg=msg) + return deprecate('assert_isinstance', assertIsInstance)( + obj, class_type_or_tuple, msg=msg) def assertNotIsInstance(obj, cls, msg=''): @@ -907,8 +932,8 @@ def assert_numpy_array_equal(left, right, right = np.array(right) if left.shape != right.shape: - raise_assert_detail(obj, '{0} shapes are different'.format(obj), - left.shape, right.shape) + raise_assert_detail(obj, '{0} shapes are different' + .format(obj), left.shape, right.shape) diff = 0 for l, r in zip(left, right): @@ -917,7 +942,8 @@ def assert_numpy_array_equal(left, right, diff += 1 diff = diff * 100.0 / left.size - msg = '{0} values are different ({1} %)'.format(obj, np.round(diff, 5)) + msg = '{0} values are different ({1} %)'\ + .format(obj, np.round(diff, 5)) raise_assert_detail(obj, msg, left, right) elif is_list_like(left): msg = "First object is iterable, second isn't" @@ -982,7 +1008,8 @@ def assert_series_equal(left, right, check_dtype=True, # index comparison assert_index_equal(left.index, right.index, exact=check_index_type, check_names=check_names, - check_less_precise=check_less_precise, check_exact=check_exact, + check_less_precise=check_less_precise, + check_exact=check_exact, obj='{0}.index'.format(obj)) if check_dtype: @@ -998,7 +1025,7 @@ def assert_series_equal(left, right, check_dtype=True, if (is_datetimelike_v_numeric(left, right) or is_datetimelike_v_object(left, right) or needs_i8_conversion(left) or - needs_i8_conversion(right)): + needs_i8_conversion(right)): # datetimelike may have different objects (e.g. datetime.datetime # vs Timestamp) but will compare equal @@ -1093,13 +1120,15 @@ def assert_frame_equal(left, right, check_dtype=True, # index comparison assert_index_equal(left.index, right.index, exact=check_index_type, check_names=check_names, - check_less_precise=check_less_precise, check_exact=check_exact, + check_less_precise=check_less_precise, + check_exact=check_exact, obj='{0}.index'.format(obj)) # column comparison assert_index_equal(left.columns, right.columns, exact=check_column_type, check_names=check_names, - check_less_precise=check_less_precise, check_exact=check_exact, + check_less_precise=check_less_precise, + check_exact=check_exact, obj='{0}.columns'.format(obj)) # compare by blocks @@ -1118,14 +1147,13 @@ def assert_frame_equal(left, right, check_dtype=True, assert col in right lcol = left.iloc[:, i] rcol = right.iloc[:, i] - assert_series_equal(lcol, rcol, - check_dtype=check_dtype, - check_index_type=check_index_type, - check_less_precise=check_less_precise, - check_exact=check_exact, - check_names=check_names, - check_datetimelike_compat=check_datetimelike_compat, - obj='DataFrame.iloc[:, {0}]'.format(i)) + assert_series_equal( + lcol, rcol, check_dtype=check_dtype, + check_index_type=check_index_type, + check_less_precise=check_less_precise, + check_exact=check_exact, check_names=check_names, + check_datetimelike_compat=check_datetimelike_compat, + obj='DataFrame.iloc[:, {0}]'.format(i)) def assert_panelnd_equal(left, right, @@ -1165,15 +1193,19 @@ def assert_contains_all(iterable, dic): def assert_copy(iter1, iter2, **eql_kwargs): """ - iter1, iter2: iterables that produce elements comparable with assert_almost_equal + iter1, iter2: iterables that produce elements + comparable with assert_almost_equal - Checks that the elements are equal, but not the same object. (Does not - check that items in sequences are also not the same object) + Checks that the elements are equal, but not + the same object. (Does not check that items + in sequences are also not the same object) """ for elem1, elem2 in zip(iter1, iter2): assert_almost_equal(elem1, elem2, **eql_kwargs) - assert elem1 is not elem2, "Expected object %r and object %r to be different objects, were same." % ( - type(elem1), type(elem2)) + assert elem1 is not elem2, ("Expected object %r and " + "object %r to be different " + "objects, were same." + % (type(elem1), type(elem2))) def getCols(k): @@ -1307,6 +1339,7 @@ def getTimeSeriesData(nper=None, freq='B'): def getPeriodData(nper=None): return dict((c, makePeriodSeries(nper)) for c in getCols(K)) + # make frame def makeTimeDataFrame(nper=None, freq='B'): data = getTimeSeriesData(nper, freq) @@ -1330,9 +1363,11 @@ def getMixedTypeDict(): return index, data + def makeMixedDataFrame(): return DataFrame(getMixedTypeDict()[1]) + def makePeriodFrame(nper=None): data = getPeriodData(nper) return DataFrame(data) @@ -1362,9 +1397,9 @@ def makeCustomIndex(nentries, nlevels, prefix='#', names=False, ndupe_l=None, nentries - number of entries in index nlevels - number of levels (> 1 produces multindex) prefix - a string prefix for labels - names - (Optional), bool or list of strings. if True will use default names, - if false will use no names, if a list is given, the name of each level - in the index will be taken from the list. + names - (Optional), bool or list of strings. if True will use default + names, if false will use no names, if a list is given, the name of + each level in the index will be taken from the list. ndupe_l - (Optional), list of ints, the number of rows for which the label will repeated at the corresponding level, you can specify just the first few, the rest will use the default ndupe_l of 1. @@ -1382,8 +1417,8 @@ def makeCustomIndex(nentries, nlevels, prefix='#', names=False, ndupe_l=None, if ndupe_l is None: ndupe_l = [1] * nlevels assert (is_sequence(ndupe_l) and len(ndupe_l) <= nlevels) - assert (names is None or names is False - or names is True or len(names) is nlevels) + assert (names is None or names is False or + names is True or len(names) is nlevels) assert idx_type is None or \ (idx_type in ('i', 'f', 's', 'u', 'dt', 'p', 'td') and nlevels == 1) @@ -1399,8 +1434,9 @@ def makeCustomIndex(nentries, nlevels, prefix='#', names=False, ndupe_l=None, names = [names] # specific 1D index type requested? - idx_func = dict(i=makeIntIndex, f=makeFloatIndex, s=makeStringIndex, - u=makeUnicodeIndex, dt=makeDateIndex, td=makeTimedeltaIndex, + idx_func = dict(i=makeIntIndex, f=makeFloatIndex, + s=makeStringIndex, u=makeUnicodeIndex, + dt=makeDateIndex, td=makeTimedeltaIndex, p=makePeriodIndex).get(idx_type) if idx_func: idx = idx_func(nentries) @@ -1457,14 +1493,15 @@ def makeCustomDataframe(nrows, ncols, c_idx_names=True, r_idx_names=True, c_idx_nlevels ==1. c_idx_nlevels - number of levels in columns index. > 1 will yield MultiIndex r_idx_nlevels - number of levels in rows index. > 1 will yield MultiIndex - data_gen_f - a function f(row,col) which return the data value at that position, - the default generator used yields values of the form "RxCy" based on position. + data_gen_f - a function f(row,col) which return the data value + at that position, the default generator used yields values of the form + "RxCy" based on position. c_ndupe_l, r_ndupe_l - list of integers, determines the number - of duplicates for each label at a given level of the corresponding index. - The default `None` value produces a multiplicity of 1 across - all levels, i.e. a unique index. Will accept a partial list of - length N < idx_nlevels, for just the first N levels. If ndupe - doesn't divide nrows/ncol, the last label might have lower multiplicity. + of duplicates for each label at a given level of the corresponding + index. The default `None` value produces a multiplicity of 1 across + all levels, i.e. a unique index. Will accept a partial list of length + N < idx_nlevels, for just the first N levels. If ndupe doesn't divide + nrows/ncol, the last label might have lower multiplicity. dtype - passed to the DataFrame constructor as is, in case you wish to have more control in conjuncion with a custom `data_gen_f` r_idx_type, c_idx_type - "i"/"f"/"s"/"u"/"dt"/"td". @@ -1484,8 +1521,9 @@ def makeCustomDataframe(nrows, ncols, c_idx_names=True, r_idx_names=True, # make the data a random int between 1 and 100 >> mkdf(5,3,data_gen_f=lambda r,c:randint(1,100)) - # 2-level multiindex on rows with each label duplicated twice on first level, - # default names on both axis, single index on both axis + # 2-level multiindex on rows with each label duplicated + # twice on first level, default names on both axis, single + # index on both axis >> a=makeCustomDataframe(5,3,r_idx_nlevels=2,r_ndupe_l=[2]) # DatetimeIndex on row, index with unicode labels on columns @@ -1505,9 +1543,11 @@ def makeCustomDataframe(nrows, ncols, c_idx_names=True, r_idx_names=True, assert c_idx_nlevels > 0 assert r_idx_nlevels > 0 assert r_idx_type is None or \ - (r_idx_type in ('i', 'f', 's', 'u', 'dt', 'p', 'td') and r_idx_nlevels == 1) + (r_idx_type in ('i', 'f', 's', + 'u', 'dt', 'p', 'td') and r_idx_nlevels == 1) assert c_idx_type is None or \ - (c_idx_type in ('i', 'f', 's', 'u', 'dt', 'p', 'td') and c_idx_nlevels == 1) + (c_idx_type in ('i', 'f', 's', + 'u', 'dt', 'p', 'td') and c_idx_nlevels == 1) columns = makeCustomIndex(ncols, nlevels=c_idx_nlevels, prefix='C', names=c_idx_names, ndupe_l=c_ndupe_l, @@ -1599,12 +1639,14 @@ def add_nans(panel): dm[col][:i + j] = np.NaN return panel + def add_nans_panel4d(panel4d): for l, label in enumerate(panel4d.labels): panel = panel4d[label] add_nans(panel) return panel4d + class TestSubDict(dict): def __init__(self, *args, **kwargs): @@ -1676,6 +1718,7 @@ def skip_if_no_package(*args, **kwargs): exc_failed_check=SkipTest, *args, **kwargs) + def skip_if_no_package_deco(pkg_name, version=None, app='pandas'): from nose import SkipTest @@ -1683,11 +1726,12 @@ def deco(func): @wraps(func) def wrapper(*args, **kwargs): package_check(pkg_name, version=version, app=app, - exc_failed_import=SkipTest, exc_failed_check=SkipTest) + exc_failed_import=SkipTest, + exc_failed_check=SkipTest) return func(*args, **kwargs) return wrapper return deco - # +# # Additional tags decorators for nose # @@ -1738,13 +1782,13 @@ def dec(f): # or this e.errno/e.reason.errno _network_errno_vals = ( - 101, # Network is unreachable - 111, # Connection refused - 110, # Connection timed out - 104, # Connection reset Error - 54, # Connection reset by peer - 60, # urllib.error.URLError: [Errno 60] Connection timed out - ) + 101, # Network is unreachable + 111, # Connection refused + 110, # Connection timed out + 104, # Connection reset Error + 54, # Connection reset by peer + 60, # urllib.error.URLError: [Errno 60] Connection timed out +) # Both of the above shouldn't mask real issues such as 404's # or refused connections (changed DNS). @@ -1755,7 +1799,8 @@ def dec(f): _network_error_classes = (IOError, httplib.HTTPException) if sys.version_info >= (3, 3): - _network_error_classes += (TimeoutError,) + _network_error_classes += (TimeoutError,) # noqa + def can_connect(url, error_classes=_network_error_classes): """Try to connect to the given url. True if succeeds, False if IOError @@ -1805,8 +1850,8 @@ def network(t, url="http://www.google.com", t : callable The test requiring network connectivity. url : path - The url to test via ``pandas.io.common.urlopen`` to check for connectivity. - Defaults to 'http://www.google.com'. + The url to test via ``pandas.io.common.urlopen`` to check + for connectivity. Defaults to 'http://www.google.com'. raise_on_error : bool If True, never catches errors. check_before_test : bool @@ -1897,8 +1942,8 @@ def wrapper(*args, **kwargs): e_str = str(e) if any([m.lower() in e_str.lower() for m in _skip_on_messages]): - raise SkipTest("Skipping test because exception message is known" - " and error %s" % e) + raise SkipTest("Skipping test because exception " + "message is known and error %s" % e) if not isinstance(e, error_classes): raise @@ -2006,16 +2051,19 @@ def assertRaises(_exception, _callable=None, *args, **kwargs): else: return manager + def assertRaisesRegexp(_exception, _regexp, _callable=None, *args, **kwargs): - """ Port of assertRaisesRegexp from unittest in Python 2.7 - used in with statement. + """ Port of assertRaisesRegexp from unittest in + Python 2.7 - used in with statement. Explanation from standard library: - Like assertRaises() but also tests that regexp matches on the string - representation of the raised exception. regexp may be a regular expression - object or a string containing a regular expression suitable for use by - re.search(). + Like assertRaises() but also tests that regexp matches on the + string representation of the raised exception. regexp may be a + regular expression object or a string containing a regular + expression suitable for use by re.search(). - You can pass either a regular expression or a compiled regular expression object. + You can pass either a regular expression + or a compiled regular expression object. >>> assertRaisesRegexp(ValueError, 'invalid literal for.*XYZ', ... int, 'XYZ') >>> import re @@ -2052,7 +2100,10 @@ def assertRaisesRegexp(_exception, _regexp, _callable=None, *args, **kwargs): class _AssertRaisesContextmanager(object): - """handles the behind the scenes work for assertRaises and assertRaisesRegexp""" + """ + Handles the behind the scenes work + for assertRaises and assertRaisesRegexp + """ def __init__(self, exception, regexp=None, *args, **kwargs): self.exception = exception if regexp is not None and not hasattr(regexp, "search"): @@ -2084,6 +2135,7 @@ def handle_success(self, exc_type, exc_value, traceback): raise_with_traceback(e, traceback) return True + @contextmanager def assert_produces_warning(expected_warning=Warning, filter_level="always", clear=None, check_stacklevel=True): @@ -2119,7 +2171,7 @@ def assert_produces_warning(expected_warning=Warning, filter_level="always", # if they have happened before # to guarantee that we will catch them if not is_list_like(clear): - clear = [ clear ] + clear = [clear] for m in clear: try: m.__warningregistry__.clear() @@ -2137,11 +2189,13 @@ def assert_produces_warning(expected_warning=Warning, filter_level="always", saw_warning = True if check_stacklevel and issubclass(actual_warning.category, - (FutureWarning, DeprecationWarning)): + (FutureWarning, + DeprecationWarning)): from inspect import getframeinfo, stack caller = getframeinfo(stack()[2][0]) - msg = ("Warning not set with correct stacklevel. File were warning" - " is raised: {0} != {1}. Warning message: {2}".format( + msg = ("Warning not set with correct stacklevel. " + "File where warning is raised: {0} != {1}. " + "Warning message: {2}".format( actual_warning.filename, caller.filename, actual_warning.message)) assert actual_warning.filename == caller.filename, msg @@ -2214,12 +2268,15 @@ def test_parallel(num_threads=2, kwargs_list=None): num_threads : int, optional The number of times the function is run in parallel. kwargs_list : list of dicts, optional - The list of kwargs to update original function kwargs on different threads. + The list of kwargs to update original + function kwargs on different threads. Notes ----- This decorator does not pass the return value of the decorated function. - Original from scikit-image: https://github.com/scikit-image/scikit-image/pull/1519 + Original from scikit-image: + + https://github.com/scikit-image/scikit-image/pull/1519 """