Skip to content

Commit

Permalink
Merge pull request #266 from pyiron/init_function
Browse files Browse the repository at this point in the history
Call init function during interface creation
  • Loading branch information
jan-janssen authored Feb 18, 2024
2 parents d9edd2d + f5dfa73 commit 36090f6
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 17 deletions.
4 changes: 2 additions & 2 deletions pympipool/flux/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
hostname_localhost=False,
):
super().__init__()
cloudpickle_register(ind=3)
self._process = RaisingThread(
target=execute_parallel_tasks,
kwargs={
Expand All @@ -128,6 +129,7 @@ def __init__(
"cores": cores,
"interface_class": FluxPythonInterface,
"hostname_localhost": hostname_localhost,
"init_function": init_function,
# Interface Arguments
"threads_per_core": threads_per_core,
"gpus_per_core": gpus_per_task,
Expand All @@ -136,8 +138,6 @@ def __init__(
},
)
self._process.start()
self._set_init_function(init_function=init_function)
cloudpickle_register(ind=3)


class FluxPythonInterface(BaseInterface):
Expand Down
4 changes: 2 additions & 2 deletions pympipool/mpi/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,19 @@ def __init__(
hostname_localhost=False,
):
super().__init__()
cloudpickle_register(ind=3)
self._process = RaisingThread(
target=execute_parallel_tasks,
kwargs={
# Executor Arguments
"future_queue": self._future_queue,
"cores": cores,
"interface_class": MpiExecInterface,
"init_function": init_function,
# Interface Arguments
"cwd": cwd,
"oversubscribe": oversubscribe,
"hostname_localhost": hostname_localhost,
},
)
self._process.start()
self._set_init_function(init_function=init_function)
cloudpickle_register(ind=3)
17 changes: 7 additions & 10 deletions pympipool/shared/executorbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ def __del__(self):
except (AttributeError, RuntimeError):
pass

def _set_init_function(self, init_function):
if init_function is not None:
self._future_queue.put(
{"init": True, "fn": init_function, "args": (), "kwargs": {}}
)


def cancel_items_in_queue(que):
"""
Expand Down Expand Up @@ -119,6 +113,7 @@ def execute_parallel_tasks(
cores,
interface_class,
hostname_localhost=False,
init_function=None,
**kwargs,
):
"""
Expand All @@ -143,10 +138,15 @@ def execute_parallel_tasks(
hostname_localhost=hostname_localhost,
),
future_queue=future_queue,
init_function=init_function,
)


def execute_parallel_tasks_loop(interface, future_queue):
def execute_parallel_tasks_loop(interface, future_queue, init_function=None):
if init_function is not None:
interface.send_dict(
input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}}
)
while True:
task_dict = future_queue.get()
if "shutdown" in task_dict.keys() and task_dict["shutdown"]:
Expand All @@ -166,9 +166,6 @@ def execute_parallel_tasks_loop(interface, future_queue):
raise thread_exception
else:
future_queue.task_done()
elif "fn" in task_dict.keys() and "init" in task_dict.keys():
interface.send_dict(input_dict=task_dict)
future_queue.task_done()


def executor_broker(
Expand Down
4 changes: 2 additions & 2 deletions pympipool/slurm/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,15 @@ def __init__(
hostname_localhost=False,
):
super().__init__()
cloudpickle_register(ind=3)
self._process = RaisingThread(
target=execute_parallel_tasks,
kwargs={
# Executor Arguments
"future_queue": self._future_queue,
"cores": cores,
"interface_class": SrunInterface,
"init_function": init_function,
# Interface Arguments
"threads_per_core": threads_per_core,
"gpus_per_core": gpus_per_task,
Expand All @@ -132,5 +134,3 @@ def __init__(
},
)
self._process.start()
self._set_init_function(init_function=init_function)
cloudpickle_register(ind=3)
2 changes: 1 addition & 1 deletion tests/test_worker_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_call_funct(self):
def test_execute_task(self):
f = Future()
q = Queue()
q.put({"init": True, "fn": set_global, "args": (), "kwargs": {}})
q.put({"fn": get_global, "args": (), "kwargs": {}, "future": f})
q.put({"shutdown": True, "wait": True})
cloudpickle_register(ind=1)
Expand All @@ -45,6 +44,7 @@ def test_execute_task(self):
oversubscribe=False,
interface_class=MpiExecInterface,
hostname_localhost=True,
init_function=set_global,
)
self.assertEqual(f.result(), np.array([5]))
q.join()

0 comments on commit 36090f6

Please sign in to comment.