From b12e9ec3cff020983e3dde9b16f5ccc4fd0f4963 Mon Sep 17 00:00:00 2001 From: Jake Lishman Date: Wed, 15 May 2024 17:56:11 +0100 Subject: [PATCH] Fix unnecessary serialisation of `PassManager` in serial contexts (#12410) * Fix unnecessary serialisation of `PassManager` in serial contexts This exposes the interal decision in `parallel_map` of whether to actually run in serial or not. If not, there's no need for `PassManager` to side-car its `dill` serialisation onto the side of the IPC (we use `dill` because we need to pickle lambdas), which can be an unfortunately huge cost for certain IBM pulse-enabled backends. * Remove new function from public API This makes the patch series safe for backport to 1.1. --- qiskit/passmanager/passmanager.py | 22 +++++------ qiskit/utils/__init__.py | 5 ++- qiskit/utils/parallel.py | 39 ++++++++++++------- .../parallel-check-8186a8f074774a1f.yaml | 5 +++ 4 files changed, 43 insertions(+), 28 deletions(-) create mode 100644 releasenotes/notes/parallel-check-8186a8f074774a1f.yaml diff --git a/qiskit/passmanager/passmanager.py b/qiskit/passmanager/passmanager.py index ba416dfb063b..8d3a4e9aa693 100644 --- a/qiskit/passmanager/passmanager.py +++ b/qiskit/passmanager/passmanager.py @@ -21,7 +21,7 @@ import dill -from qiskit.utils.parallel import parallel_map +from qiskit.utils.parallel import parallel_map, should_run_in_parallel from .base_tasks import Task, PassManagerIR from .exceptions import PassManagerError from .flow_controllers import FlowControllerLinear @@ -225,16 +225,16 @@ def callback_func(**kwargs): in_programs = [in_programs] is_list = False - if len(in_programs) == 1: - out_program = _run_workflow( - program=in_programs[0], - pass_manager=self, - callback=callback, - **kwargs, - ) - if is_list: - return [out_program] - return out_program + # If we're not going to run in parallel, we want to avoid spending time `dill` serialising + # ourselves, since that can be quite expensive. + if len(in_programs) == 1 or not should_run_in_parallel(num_processes): + out = [ + _run_workflow(program=program, pass_manager=self, callback=callback, **kwargs) + for program in in_programs + ] + if len(in_programs) == 1 and not is_list: + return out[0] + return out del callback del kwargs diff --git a/qiskit/utils/__init__.py b/qiskit/utils/__init__.py index f5256f6f11ec..30935437ebf2 100644 --- a/qiskit/utils/__init__.py +++ b/qiskit/utils/__init__.py @@ -44,7 +44,7 @@ .. autofunction:: local_hardware_info .. autofunction:: is_main_process -A helper function for calling a custom function with python +A helper function for calling a custom function with Python :class:`~concurrent.futures.ProcessPoolExecutor`. Tasks can be executed in parallel using this function. .. autofunction:: parallel_map @@ -70,7 +70,7 @@ from . import optionals -from .parallel import parallel_map +from .parallel import parallel_map, should_run_in_parallel __all__ = [ "LazyDependencyManager", @@ -85,4 +85,5 @@ "is_main_process", "apply_prefix", "parallel_map", + "should_run_in_parallel", ] diff --git a/qiskit/utils/parallel.py b/qiskit/utils/parallel.py index d46036a478f9..f87eeb815967 100644 --- a/qiskit/utils/parallel.py +++ b/qiskit/utils/parallel.py @@ -48,6 +48,8 @@ from the multiprocessing library. """ +from __future__ import annotations + import os from concurrent.futures import ProcessPoolExecutor import sys @@ -101,6 +103,21 @@ def _task_wrapper(param): return task(value, *task_args, **task_kwargs) +def should_run_in_parallel(num_processes: int | None = None) -> bool: + """Return whether the current parallelisation configuration suggests that we should run things + like :func:`parallel_map` in parallel (``True``) or degrade to serial (``False``). + + Args: + num_processes: the number of processes requested for use (if given). + """ + num_processes = CPU_COUNT if num_processes is None else num_processes + return ( + num_processes > 1 + and os.getenv("QISKIT_IN_PARALLEL", "FALSE") == "FALSE" + and CONFIG.get("parallel_enabled", PARALLEL_DEFAULT) + ) + + def parallel_map( # pylint: disable=dangerous-default-value task, values, task_args=(), task_kwargs={}, num_processes=CPU_COUNT ): @@ -110,21 +127,20 @@ def parallel_map( # pylint: disable=dangerous-default-value result = [task(value, *task_args, **task_kwargs) for value in values] - On Windows this function defaults to a serial implementation to avoid the - overhead from spawning processes in Windows. + This will parallelise the results if the number of ``values`` is greater than one, and the + current system configuration permits parallelization. Args: task (func): Function that is to be called for each value in ``values``. - values (array_like): List or array of values for which the ``task`` - function is to be evaluated. + values (array_like): List or array of values for which the ``task`` function is to be + evaluated. task_args (list): Optional additional arguments to the ``task`` function. task_kwargs (dict): Optional additional keyword argument to the ``task`` function. num_processes (int): Number of processes to spawn. Returns: - result: The result list contains the value of - ``task(value, *task_args, **task_kwargs)`` for - each value in ``values``. + result: The result list contains the value of ``task(value, *task_args, **task_kwargs)`` for + each value in ``values``. Raises: QiskitError: If user interrupts via keyboard. @@ -147,12 +163,7 @@ def func(_): if len(values) == 1: return [task(values[0], *task_args, **task_kwargs)] - # Run in parallel if not Win and not in parallel already - if ( - num_processes > 1 - and os.getenv("QISKIT_IN_PARALLEL") == "FALSE" - and CONFIG.get("parallel_enabled", PARALLEL_DEFAULT) - ): + if should_run_in_parallel(num_processes): os.environ["QISKIT_IN_PARALLEL"] = "TRUE" try: results = [] @@ -173,8 +184,6 @@ def func(_): os.environ["QISKIT_IN_PARALLEL"] = "FALSE" return results - # Cannot do parallel on Windows , if another parallel_map is running in parallel, - # or len(values) == 1. results = [] for _, value in enumerate(values): result = task(value, *task_args, **task_kwargs) diff --git a/releasenotes/notes/parallel-check-8186a8f074774a1f.yaml b/releasenotes/notes/parallel-check-8186a8f074774a1f.yaml new file mode 100644 index 000000000000..d3266b2aa5f2 --- /dev/null +++ b/releasenotes/notes/parallel-check-8186a8f074774a1f.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + :meth:`.PassManager.run` will no longer waste time serializing itself when given multiple inputs + if it is only going to work in serial.