Skip to content

Commit

Permalink
Workaround numba#5973
Browse files Browse the repository at this point in the history
As title.
  • Loading branch information
stuartarchibald committed Sep 1, 2020
1 parent 900a2d8 commit 813c3ad
Showing 1 changed file with 38 additions and 22 deletions.
60 changes: 38 additions & 22 deletions numba/tests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@
from numba.core.dispatcher import Dispatcher
from numba.tests.support import (skip_parfors_unsupported, needs_lapack,
SerialMixin)

from numba.testing.main import _TIMEOUT as _RUNNER_TIMEOUT
import llvmlite.binding as ll
import unittest
from numba.parfors import parfor


_TEST_TIMEOUT = _RUNNER_TIMEOUT - 60.


try:
import jinja2
except ImportError:
Expand Down Expand Up @@ -1970,45 +1974,57 @@ def assign(out, x):
self.assertEqual(ct_bad, 1)


@njit
def add_y1(x, y=1):
return x + y


@njit
def add_ynone(x, y=None):
return x + (1 if y else 2)


@njit
def mult(x, y):
return x * y


@njit
def add_func(x, func=mult):
return x + func(x, x)


def _checker(f1, arg):
assert f1(arg) == f1.py_func(arg)


class TestMultiprocessingDefaultParameters(SerialMixin, unittest.TestCase):
def run_fc_multiproc(self, fc):
try:
ctx = multiprocessing.get_context('spawn')
except AttributeError:
ctx = multiprocessing
with ctx.Pool(1) as p:
self.assertEqual(p.map(fc, [1, 2, 3]), list(map(fc, [1, 2, 3])))

for a in [1, 2, 3]:
p = ctx.Process(target=_checker, args=(fc, a,))
p.start()
p.join(_TEST_TIMEOUT)
self.assertEqual(p.exitcode, 0)

def test_int_def_param(self):
""" Tests issue #4888"""

@njit
def add(x, y=1):
return x + y

self.run_fc_multiproc(add)
self.run_fc_multiproc(add_y1)

def test_none_def_param(self):
""" Tests None as a default parameter"""

@njit
def add(x, y=None):
return x + (1 if y else 2)

self.run_fc_multiproc(add)
self.run_fc_multiproc(add_func)

def test_function_def_param(self):
""" Tests a function as a default parameter"""

@njit
def mult(x, y):
return x * y

@njit
def add(x, func=mult):
return x + func(x, x)

self.run_fc_multiproc(add)
self.run_fc_multiproc(add_func)


if __name__ == '__main__':
Expand Down

0 comments on commit 813c3ad

Please sign in to comment.