Skip to content

Commit

Permalink
Use spawn context instead of default context for multiprocessing to f…
Browse files Browse the repository at this point in the history
…ix issue on Linux [release]
  • Loading branch information
AjayP13 committed Dec 29, 2023
1 parent 71cfa80 commit 926f6fa
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/datadreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from collections import UserDict, defaultdict
from multiprocessing import Process
from multiprocessing.context import SpawnProcess
from threading import Lock
from typing import TYPE_CHECKING, Any, Callable, cast

Expand Down Expand Up @@ -139,7 +139,7 @@ def _register_child_thread(parent_thread_id: tuple[int, int]):
DataDreamer.ctx.step_stack[get_thread_id()] = []

@staticmethod
def _add_process(process: Process):
def _add_process(process: SpawnProcess):
DataDreamer.ctx.background_processes.append(process)

@staticmethod
Expand Down
17 changes: 9 additions & 8 deletions src/utils/background_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from collections import UserDict, namedtuple
from functools import partial
from logging import Logger, StreamHandler
from multiprocessing import Process, Queue, get_start_method, set_start_method
from multiprocessing import get_context
from multiprocessing.context import SpawnProcess
from threading import Thread, get_ident
from time import sleep, time
from typing import Any, Callable, Generator
Expand All @@ -30,7 +31,7 @@ def get_parent_process_context() -> dict[str, Any]:
from .. import DataDreamer

# Remove non-picklable (and un-needed attrs from ctx)
ctx = UserDict()
ctx: Any = UserDict()
ctx.__dict__.update(DataDreamer.ctx.__dict__.copy())
if hasattr(ctx, "background_processes"):
del ctx.background_processes
Expand Down Expand Up @@ -146,11 +147,12 @@ def run_in_background_thread(func: Callable, *args, **kwargs) -> Thread:
return t


def run_in_background_process(func: Callable, *args, **kwargs) -> tuple[Process, Any]:
pipe: Any = Queue(1)
orig_start_method = get_start_method()
set_start_method("spawn", force=True)
p = Process(
def run_in_background_process(
func: Callable, *args, **kwargs
) -> tuple[SpawnProcess, Any]:
spawn_context = get_context(method="spawn")
pipe: Any = spawn_context.Queue(1)
p = spawn_context.Process(
target=partial(
_process_func_wrapper,
get_parent_process_context(),
Expand All @@ -161,7 +163,6 @@ def run_in_background_process(func: Callable, *args, **kwargs) -> tuple[Process,
)
p.daemon = False
p.start()
set_start_method(orig_start_method, force=True)
return p, pipe


Expand Down

0 comments on commit 926f6fa

Please sign in to comment.