From 0e2ed4dc4ce558539e37b1186af57c0ec66723b9 Mon Sep 17 00:00:00 2001 From: Ofer Koren Date: Wed, 17 Jun 2020 12:26:08 +0300 Subject: [PATCH] concurrency: fix arg handling in concurrent calls in recent feature update --- easypy/concurrency.py | 11 ++++++----- tests/test_concurrency.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/easypy/concurrency.py b/easypy/concurrency.py index d7781618..30ac15e0 100644 --- a/easypy/concurrency.py +++ b/easypy/concurrency.py @@ -974,8 +974,9 @@ def __repr__(self): flags += 'T' return "<%s[%s] '%s'>" % (self.__class__.__name__, self.threadname, flags) - def _logged_func(self, kwargs=None): - kwargs = {**self.kwargs, **(kwargs or {})} + def _logged_func(self, *args, **kwargs): + args = self.args + args + kwargs = {**self.kwargs, **kwargs} stack = ExitStack() self.exc = None self.timer = Timer() @@ -986,7 +987,7 @@ def _logged_func(self, kwargs=None): stack.enter_context(_logger.suppressed()) _logger.debug("%s - starting", self) while True: - self._result = self.func(*self.args, **kwargs) + self._result = self.func(*args, **kwargs) if not self.loop: return if self.wait(self.sleep): @@ -1046,10 +1047,10 @@ def paused(self): @contextmanager def _running(self, *args, **kwargs): - func = lambda *args, **kwargs: self._logged_func(*args, **kwargs) + func = lambda: self._logged_func(*args, **kwargs) if DISABLE_CONCURRENCY: - self._logged_func(*args, **kwargs) + func() yield self return diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 04392334..e3d63526 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -25,9 +25,18 @@ def test_thread_stacks(): def test_call_concurrent(): - func = lambda a, b: a + b + func = lambda a, b: a * 10 + b c = concurrent(func, 1, b=2, threadname='add') - assert c() == 3 + assert c() == 12 + + c = concurrent(func, 1, 2, threadname='add') + assert c() == 12 + + c = concurrent(func, 1, threadname='add') + assert c(2) == 12 + + c = concurrent(func, threadname='add') + assert c(1, 2) == 12 def test_call_concurrent_timeout():