diff --git a/pympipool/flux/executor.py b/pympipool/flux/executor.py index 31a724ed..ba4f3a17 100644 --- a/pympipool/flux/executor.py +++ b/pympipool/flux/executor.py @@ -120,6 +120,7 @@ def __init__( hostname_localhost=False, ): super().__init__() + cloudpickle_register(ind=3) self._process = RaisingThread( target=execute_parallel_tasks, kwargs={ @@ -128,6 +129,7 @@ def __init__( "cores": cores, "interface_class": FluxPythonInterface, "hostname_localhost": hostname_localhost, + "init_function": init_function, # Interface Arguments "threads_per_core": threads_per_core, "gpus_per_core": gpus_per_task, @@ -136,8 +138,6 @@ def __init__( }, ) self._process.start() - self._set_init_function(init_function=init_function) - cloudpickle_register(ind=3) class FluxPythonInterface(BaseInterface): diff --git a/pympipool/mpi/executor.py b/pympipool/mpi/executor.py index eefa2431..ad4ef916 100644 --- a/pympipool/mpi/executor.py +++ b/pympipool/mpi/executor.py @@ -107,6 +107,7 @@ def __init__( hostname_localhost=False, ): super().__init__() + cloudpickle_register(ind=3) self._process = RaisingThread( target=execute_parallel_tasks, kwargs={ @@ -114,6 +115,7 @@ def __init__( "future_queue": self._future_queue, "cores": cores, "interface_class": MpiExecInterface, + "init_function": init_function, # Interface Arguments "cwd": cwd, "oversubscribe": oversubscribe, @@ -121,5 +123,3 @@ def __init__( }, ) self._process.start() - self._set_init_function(init_function=init_function) - cloudpickle_register(ind=3) diff --git a/pympipool/shared/executorbase.py b/pympipool/shared/executorbase.py index b0bc3cd7..c429f616 100644 --- a/pympipool/shared/executorbase.py +++ b/pympipool/shared/executorbase.py @@ -67,12 +67,6 @@ def __del__(self): except (AttributeError, RuntimeError): pass - def _set_init_function(self, init_function): - if init_function is not None: - self._future_queue.put( - {"init": True, "fn": init_function, "args": (), "kwargs": {}} - ) - def cancel_items_in_queue(que): """ @@ -119,6 +113,7 @@ def execute_parallel_tasks( cores, interface_class, hostname_localhost=False, + init_function=None, **kwargs, ): """ @@ -143,10 +138,15 @@ def execute_parallel_tasks( hostname_localhost=hostname_localhost, ), future_queue=future_queue, + init_function=init_function, ) -def execute_parallel_tasks_loop(interface, future_queue): +def execute_parallel_tasks_loop(interface, future_queue, init_function=None): + if init_function is not None: + interface.send_dict( + input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}} + ) while True: task_dict = future_queue.get() if "shutdown" in task_dict.keys() and task_dict["shutdown"]: @@ -166,9 +166,6 @@ def execute_parallel_tasks_loop(interface, future_queue): raise thread_exception else: future_queue.task_done() - elif "fn" in task_dict.keys() and "init" in task_dict.keys(): - interface.send_dict(input_dict=task_dict) - future_queue.task_done() def executor_broker( diff --git a/pympipool/slurm/executor.py b/pympipool/slurm/executor.py index 299ab5ee..db11b204 100644 --- a/pympipool/slurm/executor.py +++ b/pympipool/slurm/executor.py @@ -116,6 +116,7 @@ def __init__( hostname_localhost=False, ): super().__init__() + cloudpickle_register(ind=3) self._process = RaisingThread( target=execute_parallel_tasks, kwargs={ @@ -123,6 +124,7 @@ def __init__( "future_queue": self._future_queue, "cores": cores, "interface_class": SrunInterface, + "init_function": init_function, # Interface Arguments "threads_per_core": threads_per_core, "gpus_per_core": gpus_per_task, @@ -132,5 +134,3 @@ def __init__( }, ) self._process.start() - self._set_init_function(init_function=init_function) - cloudpickle_register(ind=3) diff --git a/tests/test_worker_memory.py b/tests/test_worker_memory.py index 7230e0f5..909c496a 100644 --- a/tests/test_worker_memory.py +++ b/tests/test_worker_memory.py @@ -35,7 +35,6 @@ def test_call_funct(self): def test_execute_task(self): f = Future() q = Queue() - q.put({"init": True, "fn": set_global, "args": (), "kwargs": {}}) q.put({"fn": get_global, "args": (), "kwargs": {}, "future": f}) q.put({"shutdown": True, "wait": True}) cloudpickle_register(ind=1) @@ -45,6 +44,7 @@ def test_execute_task(self): oversubscribe=False, interface_class=MpiExecInterface, hostname_localhost=True, + init_function=set_global, ) self.assertEqual(f.result(), np.array([5])) q.join()