From 3e565d2c67a69569042a909036bbe471614c63d1 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> Date: Mon, 26 Apr 2021 14:30:33 -0400 Subject: [PATCH] enable multiprocess function to use kwargs (#137) * enable multiprocess function to use kwargs * fix docstring of only public function * need keyword arguments in cpus and verbose now --- btk/draw_blends.py | 4 ++-- btk/measure.py | 4 ++-- btk/multiprocess.py | 41 +++++++++++++++++++++++++++++++++++------ 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/btk/draw_blends.py b/btk/draw_blends.py index f1f42c2b2..8c1383ba6 100644 --- a/btk/draw_blends.py +++ b/btk/draw_blends.py @@ -260,8 +260,8 @@ def __next__(self): mini_batch_results = multiprocess( self.render_mini_batch, input_args, - self.cpus, - self.verbose, + cpus=self.cpus, + verbose=self.verbose, ) # join results across mini-batches. diff --git a/btk/measure.py b/btk/measure.py index 91c0d88fa..c90d2faa2 100644 --- a/btk/measure.py +++ b/btk/measure.py @@ -227,8 +227,8 @@ def __next__(self): measure_results = multiprocess( self.run_batch, input_args, - self.cpus, - self.verbose, + cpus=self.cpus, + verbose=self.verbose, ) if self.verbose: print("Measurement performed on batch") diff --git a/btk/multiprocess.py b/btk/multiprocess.py index 9920c4a6a..7c92f3d65 100644 --- a/btk/multiprocess.py +++ b/btk/multiprocess.py @@ -1,20 +1,49 @@ """Tools for multiprocessing in BTK.""" import multiprocessing as mp +from itertools import repeat from itertools import starmap -def multiprocess(func, input_args, cpus, verbose=False): - """Sole Function that implements multiprocessing across mini-batches for BTK.""" +def _apply_args_and_kwargs(fn, args, kwargs): + return fn(*args, **kwargs) + + +def _pool_starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): + args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) + return pool.starmap(_apply_args_and_kwargs, args_for_starmap) + + +def _starmap_with_kwargs(fn, 
args_iter, kwargs_iter): + args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) + return starmap(_apply_args_and_kwargs, args_for_starmap) + + +def multiprocess(fn, args_iter, kwargs_iter=None, cpus=1, verbose=False): + """Sole function that implements multiprocessing across mini-batches/batches for BTK. + + Args: + fn (function): Function to run in parallel on each positional arguments returned by + `args_iter` and each keyword arguments returned by `kwargs_iter`. + args_iter (iter): Iterator returning positional arguments to be passed in to function for + multiprocessing. This iterator must have a `__len__` method implemented. Each + argument returned by the iterator must be unpackable like: `*args`. + kwargs_iter (iter): Iterator returning keyword arguments to be passed in to + function for multiprocessing. Default value `None` means that no keyword arguments + are passed in. Each element returned by the iterator must be a `dict`. + cpus (int): # of cpus to use for multiprocessing. + verbose (bool): Whether to print information related to multiprocessing + """ + kwargs_iter = repeat({}) if kwargs_iter is None else kwargs_iter if cpus > 1: if verbose: print( - f"Running mini-batch of size {len(input_args)} with multiprocessing with " + f"Running mini-batch of size {len(args_iter)} with multiprocessing with " f"pool {cpus}" ) with mp.Pool(processes=cpus) as pool: - results = pool.starmap(func, input_args) + results = _pool_starmap_with_kwargs(pool, fn, args_iter, kwargs_iter) else: if verbose: - print(f"Running mini-batch of size {len(input_args)} serial {cpus} times") - results = list(starmap(func, input_args)) + print(f"Running mini-batch of size {len(args_iter)} serial {cpus} times") + results = list(_starmap_with_kwargs(fn, args_iter, kwargs_iter)) return results