support generator for LocalJob & ThreadJob & DaskJob #6

Merged · 4 commits · Sep 27, 2024
README.md: 12 additions, 0 deletions
@@ -77,6 +77,11 @@ from executor.engine import Engine, ProcessJob
def add(a, b):
return a + b

async def stream():
for i in range(5):
await asyncio.sleep(0.5)
yield i

with Engine() as engine:
# job1 and job2 will be executed in parallel
job1 = ProcessJob(add, args=(1, 2))
@@ -86,6 +91,13 @@ with Engine() as engine:
engine.submit(job1, job2, job3)
engine.wait_job(job3) # wait for job3 done
print(job3.result()) # 10

# generator
job4 = ProcessJob(stream)
# do not use engine.wait() here: a generator job's future only resolves on StopIteration
await engine.submit_async(job4)
async for x in job4.result():
print(x)
```

Async mode example:
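Separate from the README's (truncated) async-mode example above, here is a minimal sketch of the new generator support using a synchronous generator and a `ThreadJob`. The top-level `ThreadJob` import, the plain `for` loop over `job.result()`, and the assumption that the engine routes generator functions to `run_generator` are inferred from the `thread.py` and `utils.py` diffs below, not stated in the README.

```python
import asyncio

# ThreadJob is defined in executor/engine/job/thread.py; the top-level export is assumed here
from executor.engine import Engine, ThreadJob


def count_up(n: int):
    # a plain synchronous generator function
    for i in range(n):
        yield i


async def main():
    with Engine() as engine:
        job = ThreadJob(count_up, args=(3,))
        # as in the README snippet: submit, but do not wait on the job,
        # because a generator job's future only resolves on StopIteration
        await engine.submit_async(job)
        for x in job.result():  # job.result() is a GeneratorWrapper, so plain iteration works
            print(x)            # 0, 1, 2


asyncio.run(main())
```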
executor/engine/job/dask.py: 12 additions, 2 deletions
@@ -3,6 +3,7 @@
from dask.distributed import Client, LocalCluster

from .base import Job
from .utils import GeneratorWrapper
from ..utils import PortManager


@@ -56,12 +57,21 @@ def release_resource(self) -> bool:
async def run_function(self):
"""Run job with Dask."""
client = self.engine.dask_client
func = functools.partial(self.func, **self.kwargs)
fut = client.submit(func, *self.args)
func = functools.partial(self.func, *self.args, **self.kwargs)
fut = client.submit(func)
self._executor = fut
result = await fut
return result

async def run_generator(self):
"""Run job as a generator."""
client = self.engine.dask_client
func = functools.partial(self.func, *self.args, **self.kwargs)
fut = client.submit(func)
self._executor = client.get_executor(pure=False)
result = GeneratorWrapper(self, fut)
return result

async def cancel(self):
"""Cancel job."""""
if self.status == "running":
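The new `run_generator` relies on two Dask behaviours: a result produced by `client.submit` stays in worker memory until it is gathered, and a `Future` passed as an argument to a later submission is resolved back to that in-memory object on the worker; `pure=False` on `client.get_executor` keeps the repeated `next()` submissions from being deduplicated as identical tasks. Below is a stand-alone sketch of that mechanism (plain Dask, not the library's classes; the single-process `LocalCluster` keeps the generator object local, which this toy relies on).

```python
from dask.distributed import Client, LocalCluster


def make_gen(n):
    # stands in for calling the user's generator function on a worker
    return iter(range(n))


def step(gen):
    # same role as _gen_next(fut) in utils.py: advance the generator held on the worker
    return next(gen)


if __name__ == "__main__":
    with LocalCluster(n_workers=1, processes=False) as cluster, Client(cluster) as client:
        gen_future = client.submit(make_gen, 3)             # the iterator object stays on the worker
        executor = client.get_executor(pure=False)          # pure=False: do not cache repeated next() calls
        print(executor.submit(step, gen_future).result())   # 0
        print(executor.submit(step, gen_future).result())   # 1
```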
executor/engine/job/local.py: 5 additions, 0 deletions
@@ -6,3 +6,8 @@ async def run_function(self):
"""Run job in local thread."""
res = self.func(*self.args, **self.kwargs)
return res

async def run_generator(self):
"""Run job as a generator."""
res = self.func(*self.args, **self.kwargs)
return res
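For `LocalJob` there is no pool boundary to cross, so `run_generator` simply calls the function and returns the resulting generator object as-is; iteration then happens directly in the engine's own thread. A tiny illustration of what the caller receives (`ticker` is only an example, not part of the library):

```python
def ticker(n):
    for i in range(n):
        yield i


gen = ticker(3)   # what LocalJob.run_generator hands back: the generator object itself
print(list(gen))  # [0, 1, 2] -- no GeneratorWrapper, no executor round-trips
```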
executor/engine/job/thread.py: 11 additions, 2 deletions
@@ -3,6 +3,7 @@
from concurrent.futures import ThreadPoolExecutor

from .base import Job
from .utils import _gen_initializer, GeneratorWrapper


class ThreadJob(Job):
@@ -42,13 +43,21 @@ def release_resource(self) -> bool:

async def run_function(self):
"""Run job in thread pool."""
func = functools.partial(self.func, *self.args, **self.kwargs)
self._executor = ThreadPoolExecutor(1)
loop = asyncio.get_running_loop()
func = functools.partial(self.func, **self.kwargs)
fut = loop.run_in_executor(self._executor, func, *self.args)
fut = loop.run_in_executor(self._executor, func)
result = await fut
return result

async def run_generator(self):
"""Run job as a generator."""
func = functools.partial(self.func, *self.args, **self.kwargs)
self._executor = ThreadPoolExecutor(
1, initializer=_gen_initializer, initargs=(func,))
result = GeneratorWrapper(self)
return result

async def cancel(self):
"""Cancel job."""
if self.status == "running":
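The `ThreadJob` generator path builds a single-worker `ThreadPoolExecutor` whose initializer constructs the generator inside the worker thread and parks it in thread-local storage (see `_gen_initializer` in the `utils.py` diff below); each item is then fetched by submitting one `next()` call to that same thread. A stand-alone sketch of the pattern using only the standard library (`_init`, `_next` and `numbers` are placeholders, not the library's helpers):

```python
import functools
import threading
from concurrent.futures import ThreadPoolExecutor

_locals = threading.local()  # counterpart of _thread_locals in utils.py


def _init(gen_func):
    # runs once, inside the pool's only worker thread
    _locals.generator = gen_func()


def _next():
    # advance the generator that lives in the worker thread
    return next(_locals.generator)


def numbers(n):
    for i in range(n):
        yield i


pool = ThreadPoolExecutor(1, initializer=_init, initargs=(functools.partial(numbers, 3),))
try:
    while True:
        print(pool.submit(_next).result())  # 0, 1, 2, then the stored StopIteration is re-raised
except StopIteration:
    pass
finally:
    pool.shutdown()
```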
executor/engine/job/utils.py: 25 additions, 12 deletions
@@ -1,14 +1,15 @@
import typing as T
import asyncio
from datetime import datetime
from concurrent.futures import Future
import threading

from ..utils import CheckAttrRange, ExecutorError


if T.TYPE_CHECKING:
from .base import Job


JobStatusType = T.Literal['pending', 'running', 'failed', 'done', 'cancelled']
valid_job_statuses: T.List[JobStatusType] = [
'pending', 'running', 'failed', 'done', 'cancelled']
@@ -38,36 +39,48 @@ def __init__(self, job: "Job", valid_status: T.List[JobStatusType]):


_T = T.TypeVar("_T")
_thread_locals = threading.local()


def _gen_initializer(gen_func, args=tuple(), kwargs={}): # pragma: no cover
global _generator
_generator = gen_func(*args, **kwargs)
global _thread_locals
if "_thread_locals" not in globals():
# avoid conflict for ThreadJob
_thread_locals = threading.local()
_thread_locals._generator = gen_func(*args, **kwargs)


def _gen_next(): # pragma: no cover
global _generator
return next(_generator)
def _gen_next(fut=None): # pragma: no cover
global _thread_locals
if fut is None:
return next(_thread_locals._generator)
else:
return next(fut)


def _gen_anext(): # pragma: no cover
global _generator
return asyncio.run(_generator.__anext__())
def _gen_anext(fut=None): # pragma: no cover
global _thread_locals
if fut is None:
return asyncio.run(_thread_locals._generator.__anext__())
else:
return asyncio.run(fut.__anext__())


class GeneratorWrapper(T.Generic[_T]):
"""
wrap a generator in executor pool
"""
def __init__(self, job: "Job"):

def __init__(self, job: "Job", fut: T.Optional[Future] = None):
self._job = job
self._fut = fut

def __iter__(self):
return self

def __next__(self) -> _T:
try:
return self._job._executor.submit(_gen_next).result()
return self._job._executor.submit(_gen_next, self._fut).result()
except Exception as e:
engine = self._job.engine
if engine is None:
@@ -87,7 +100,7 @@ def __aiter__(self):

async def __anext__(self) -> _T:
try:
fut = self._job._executor.submit(_gen_anext)
fut = self._job._executor.submit(_gen_anext, self._fut)
res = await asyncio.wrap_future(fut)
return res
except Exception as e:
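The async path in `GeneratorWrapper.__anext__` works the same way, except that each step is `asyncio.run(generator.__anext__())` executed inside the worker thread (as `_gen_anext` does) and the resulting `concurrent.futures.Future` is awaited through `asyncio.wrap_future`. A stand-alone sketch with the standard library only (a module-level `_gen` is used here for brevity; the real code keeps the generator on `threading.local()` so concurrent jobs do not collide):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

_gen = None  # the real code stores this on a threading.local() instead


async def astream(n):
    # an async generator, like the README's stream() example
    for i in range(n):
        await asyncio.sleep(0.01)
        yield i


def _init():
    global _gen
    _gen = astream(3)  # built inside the worker thread


def _anext():
    # one short-lived event loop per item, the same trick _gen_anext uses
    return asyncio.run(_gen.__anext__())


async def main():
    pool = ThreadPoolExecutor(1, initializer=_init)
    try:
        while True:
            fut = pool.submit(_anext)
            print(await asyncio.wrap_future(fut))  # 0, 1, 2, then StopAsyncIteration
    except StopAsyncIteration:
        pass
    finally:
        pool.shutdown()


asyncio.run(main())
```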