From 813c3ad8eb79ee0110d009ca0d6af488678c3d8c Mon Sep 17 00:00:00 2001 From: Stuart Archibald Date: Tue, 1 Sep 2020 16:01:56 +0100 Subject: [PATCH] Workaround #5973 As title. --- numba/tests/test_dispatcher.py | 60 +++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/numba/tests/test_dispatcher.py b/numba/tests/test_dispatcher.py index 30a8e081485..65aff170cfc 100644 --- a/numba/tests/test_dispatcher.py +++ b/numba/tests/test_dispatcher.py @@ -28,11 +28,15 @@ from numba.core.dispatcher import Dispatcher from numba.tests.support import (skip_parfors_unsupported, needs_lapack, SerialMixin) - +from numba.testing.main import _TIMEOUT as _RUNNER_TIMEOUT import llvmlite.binding as ll import unittest from numba.parfors import parfor + +_TEST_TIMEOUT = _RUNNER_TIMEOUT - 60. + + try: import jinja2 except ImportError: @@ -1970,45 +1974,57 @@ def assign(out, x): self.assertEqual(ct_bad, 1) +@njit +def add_y1(x, y=1): + return x + y + + +@njit +def add_ynone(x, y=None): + return x + (1 if y else 2) + + +@njit +def mult(x, y): + return x * y + + +@njit +def add_func(x, func=mult): + return x + func(x, x) + + +def _checker(f1, arg): + assert f1(arg) == f1.py_func(arg) + + class TestMultiprocessingDefaultParameters(SerialMixin, unittest.TestCase): def run_fc_multiproc(self, fc): try: ctx = multiprocessing.get_context('spawn') except AttributeError: ctx = multiprocessing - with ctx.Pool(1) as p: - self.assertEqual(p.map(fc, [1, 2, 3]), list(map(fc, [1, 2, 3]))) + + for a in [1, 2, 3]: + p = ctx.Process(target=_checker, args=(fc, a,)) + p.start() + p.join(_TEST_TIMEOUT) + self.assertEqual(p.exitcode, 0) def test_int_def_param(self): """ Tests issue #4888""" - @njit - def add(x, y=1): - return x + y - - self.run_fc_multiproc(add) + self.run_fc_multiproc(add_y1) def test_none_def_param(self): """ Tests None as a default parameter""" - @njit - def add(x, y=None): - return x + (1 if y else 2) - - self.run_fc_multiproc(add) + self.run_fc_multiproc(add_func) def test_function_def_param(self): """ Tests a function as a default parameter""" - @njit - def mult(x, y): - return x * y - - @njit - def add(x, func=mult): - return x + func(x, x) - - self.run_fc_multiproc(add) + self.run_fc_multiproc(add_func) if __name__ == '__main__':