[multiprocessing] - Improve launch subprocesses (fairinternal/xformers#960)

* testing ProcessPoolExecutor singleton pattern

* rebasing branch 'improve_launch_subprocesses' on '804f6300'

* better PyTorch memory cleaning

* added tests mix issue

* one single dtype during tests

* added get_global_pool_allocator according to dtype and world_size

* removed pytest session cleanup, fixed linters, used the correct context enter/exit pattern, removed the executor initializer, removed lru_cache

* restored all tests

* removed the context manager's submit/exit if-statement mechanisms and the dictionary presence check for the executor instance, and hacked around process termination

---------

Co-authored-by: Valeriu Lacatusu <[email protected]>

__original_commit__ = fairinternal/xformers@cbafbdb
lvaleriu authored and xFormers Bot committed Jan 26, 2024
1 parent 342de87 commit 94582e8
Showing 1 changed file with 110 additions and 14 deletions.
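The bullets above describe the heart of the change: instead of creating a fresh ProcessPoolExecutor for every test, the pool (and its rendezvous file) is cached per world size and reused. A minimal standalone sketch of that singleton pattern, with illustrative names that do not appear in the commit:

import concurrent.futures
import multiprocessing
from typing import Dict

# Hypothetical cache: one spawned pool per world size, reused across tests.
_POOLS: Dict[int, concurrent.futures.ProcessPoolExecutor] = {}


def get_pool(world_size: int) -> concurrent.futures.ProcessPoolExecutor:
    # Create the pool on the first request for this world size; later
    # requests reuse it and skip the expensive "spawn" start-up cost.
    if world_size not in _POOLS:
        ctx = multiprocessing.get_context("spawn")
        _POOLS[world_size] = concurrent.futures.ProcessPoolExecutor(
            max_workers=world_size, mp_context=ctx
        )
    return _POOLS[world_size]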
tests/multiprocessing_utils.py (110 additions, 14 deletions)

@@ -3,16 +3,17 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-import concurrent.futures
+import concurrent
+import gc
 import multiprocessing
 import signal
-import tempfile
-from typing import List
+from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
+from typing import Dict, List, Tuple
 
 import torch
 
 
-class SafeMpContext:
+class SafeMpContext(multiprocessing.context.BaseContext):
     def __init__(self) -> None:
         self.mp_context = multiprocessing.get_context("spawn")
         self.processes: List[multiprocessing.context.SpawnProcess] = []
@@ -27,10 +28,17 @@ def kill_all_processes(self):
         for p in self.processes:
             p.terminate()
             p.join(1)
-            if p.exitcode is None:
+
+            # (https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Process.exitcode)
+            # Even though the Python documentation seems to say that after joining the exitcode
+            # should become set, this is not what we have observed in practice. We therefore loop
+            # until it becomes set.
+            while p.exitcode is None:
                 p.kill()
                 p.join()
 
+            assert p.exitcode is not None, f"{p} is still alive"
+
     def log_bad_exit_codes(self):
         for rank, p in enumerate(self.processes):
             if p.exitcode == 0:
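For reference, the escalation pattern introduced above (terminate, short join, then kill in a loop until exitcode is set) works standalone; a sketch under the same assumption the commit states, namely that a join() after kill() may still leave exitcode unset:

import multiprocessing


def stop_process(p: multiprocessing.Process) -> None:
    p.terminate()  # ask politely first (SIGTERM on POSIX)
    p.join(1)      # give it one second to exit
    # Escalate to SIGKILL and retry: as the commit notes, exitcode has been
    # observed to stay None even after a join() that follows kill().
    while p.exitcode is None:
        p.kill()
        p.join()
    assert p.exitcode is not None, f"{p} is still alive"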
@@ -57,9 +65,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.log_bad_exit_codes()
 
 
-def _launch_subprocesses_fn_wrapper(
-    init_method: str, rank: int, world_size: int, user_fn, args, kwargs
-):
+def init_process_group(init_method: str, rank: int, world_size: int):
     torch._C._set_print_stack_traces_on_fatal_signal(True)
 
     if torch.cuda.device_count() >= world_size:
@@ -68,23 +74,113 @@ def _launch_subprocesses_fn_wrapper(
     else:
         # Use Gloo instead of NCCL so that we can run on a single GPU
         backend = "gloo"
+
     torch.distributed.init_process_group(
         backend=backend,
         world_size=world_size,
         rank=rank,
         init_method=init_method,
     )
-    return user_fn(*args, **kwargs)
 
 
+def _launch_subprocesses_fn_wrapper(
+    init_method: str, rank: int, world_size: int, user_fn, args, kwargs
+):
+    # Check if the process group is already initialized
+    if not torch.distributed.is_initialized():
+        init_process_group(init_method, rank, world_size)
+    try:
+        return user_fn(*args, **kwargs)
+    finally:
+        # Should free all memory used by PyTorch in the subprocesses
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
+# Global dictionary to keep track of executors and temporary files
+EXECUTORS_AND_FILES: Dict[
+    int, Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]
+] = {}
+
+
+def get_global_pool_allocator(
+    world_size: int,
+) -> Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]:
+    global EXECUTORS_AND_FILES
+
+    if world_size not in EXECUTORS_AND_FILES:
+        rdv = NamedTemporaryFile(mode="w+b", buffering=-1, delete=False)
+        mp_context = SafeMpContext()
+
+        executor = concurrent.futures.ProcessPoolExecutor(
+            max_workers=world_size, mp_context=mp_context
+        )
+
+        # Add the executor and temporary file to the global dictionary
+        EXECUTORS_AND_FILES[world_size] = (rdv, executor)
+    else:
+        rdv, executor = EXECUTORS_AND_FILES[world_size]
+
+    return rdv, executor
+
+
+class ProcessPoolExecutorManager:
+    def __init__(self, world_size: int):
+        self.world_size = world_size
+
+    def __enter__(self):
+        # When starting subprocesses, free the memory used by PyTorch in the main
+        # process first, so that the subprocesses have memory available.
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        self.rdv, self.executor = get_global_pool_allocator(self.world_size)
+        return self
+
+    def submit(self, fn, *args, **kwargs):
+        return self.executor.submit(fn, *args, **kwargs)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # One of the subprocess jobs has failed
+        if exc_val:
+            # We want to avoid killing the processes while the executor still thinks they are
+            # up and healthy (as this may have unintended consequences, such as the executor
+            # restarting the processes, or reporting spurious errors).
+            # Set the internal state of the executor and call cancel() on each issued task that is
+            # not executing
+            self.executor.shutdown(wait=False, cancel_futures=True)
+
+            # Kill all remaining subprocesses
+            mp_context = self.executor._mp_context
+            mp_context.kill_all_processes()
+            mp_context.log_bad_exit_codes()
+
+            # We want to wait for all the futures to complete, so we need to shut down twice
+            self.executor.shutdown(wait=True)
+
+            # Close the temporary file
+            self.rdv.close()
+
+            # Remove the executor from the global dictionary.
+            # It will be recreated the next time a test requires this world_size
+            assert self.world_size in EXECUTORS_AND_FILES
+            del EXECUTORS_AND_FILES[self.world_size]
+
+            print(
+                f"Shutdown and remove the executor after subprocesses error. Executors cnt: {len(EXECUTORS_AND_FILES)}"
+            )
+
+
 def launch_subprocesses(world_size: int, fn, *args, **kwargs):
-    with SafeMpContext() as mp_context, concurrent.futures.ProcessPoolExecutor(
-        max_workers=world_size, mp_context=mp_context
-    ) as e, tempfile.NamedTemporaryFile(mode="w+b", buffering=-1, delete=True) as rdv:
+    # This custom manager lets each test execution enter/exit the following context.
+    # On entry, it creates a new ProcessPoolExecutor with the given world size or reuses an existing one.
+    # On exit, it detects whether an exception occurred, in which case it kills all spawned processes
+    # and deletes the manager; the next request recreates the manager and respawns the processes.
+    with ProcessPoolExecutorManager(world_size) as manager:
         futures = [
-            e.submit(
+            manager.submit(
                 _launch_subprocesses_fn_wrapper,
-                init_method=f"file://{rdv.name}",
+                init_method=f"file://{manager.rdv.name}",
                 rank=rank,
                 world_size=world_size,
                 user_fn=fn,
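Since workers are now long-lived, the new _launch_subprocesses_fn_wrapper above has to behave well when the same process runs many tests: torch.distributed is initialized only on the first task, and GPU memory is released after every task. Reduced to its essentials (a sketch with an illustrative name, not the commit's exact code):

import gc

import torch


def run_in_worker(user_fn, init_method: str, rank: int, world_size: int, *args, **kwargs):
    # The first task in a (reused) worker sets up the process group;
    # later tasks find it already initialized and skip the setup.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend="gloo", init_method=init_method, rank=rank, world_size=world_size
        )
    try:
        return user_fn(*args, **kwargs)
    finally:
        # Free PyTorch allocations so the next test starts from a clean slate.
        gc.collect()
        torch.cuda.empty_cache()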
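Assuming the truncated tail of launch_subprocesses still gathers one future per rank and returns their results, a test would drive it like this (the user function below is hypothetical):

import torch


def _all_reduce_smoke_test() -> int:
    # Runs inside each worker; the process group is already initialized.
    t = torch.ones(1)
    torch.distributed.all_reduce(t)
    return int(t.item())


# Spawns (or reuses) a 2-worker pool and runs the function on both ranks.
launch_subprocesses(2, _all_reduce_smoke_test)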
