Commit existing prototype
Fabio Grätz committed Apr 3, 2023
1 parent 9658b02 commit ff64a6e
Showing 1 changed file with 117 additions and 0 deletions.
117 changes: 117 additions & 0 deletions plugins/flytekit-torch-elastic/flytekitplugins/torch_elastic/task.py
@@ -0,0 +1,117 @@
from dataclasses import dataclass
from typing import Any, Callable

from flytekit import PythonFunctionTask
from flytekit.extend import TaskPlugins
from torch.distributed.launcher.api import LaunchConfig, elastic_launch


@dataclass
class Elastic(object):
    """Task config for torch elastic training: min/max number of worker-group
    nodes and the multiprocessing start method ("spawn" or "fork")."""

    min_replicas: int
    max_replicas: int
    start_method: str = "spawn"


def mp_helper(fn: bytes, kwargs: dict) -> Any:
    """Deserialize the dill-serialized task function and execute it.

    This is the entrypoint launched in each child process when the "spawn"
    start method is used (see PytorchElasticFunctionTask.execute).
    """
    print("Using start method spawn")
    import dill

    fn = dill.loads(fn)

    return fn(**kwargs)


class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]):
    """
    Plugin that transforms the local Python code for execution within a torch
    elastic context.
    """

    _ELASTIC_TASK_TYPE = "torch-elastic"

    def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):
        super(PytorchElasticFunctionTask, self).__init__(
            task_config=task_config,
            task_type=self._ELASTIC_TASK_TYPE,
            task_function=task_function,
            **kwargs,
        )

    def execute(self, **kwargs) -> Any:
        """
        This method will be invoked to execute the task. If you do decide to
        override this method, you must also handle dynamic tasks, or you will no
        longer be able to use the task as a dynamic task generator.
        """
        min_nodes = self.task_config.min_replicas
        max_nodes = self.task_config.max_replicas

        # NOTE: run_id, nproc_per_node, rdzv_backend, and max_restarts are still
        # mocked in this prototype.
        config = LaunchConfig(
            run_id="none",
            min_nodes=min_nodes,
            max_nodes=max_nodes,
            nproc_per_node=4,
            rdzv_backend="c10d",
            max_restarts=1,
            # rdzv_endpoint = "foo"
            start_method=self.task_config.start_method,
        )

        if self.task_config.start_method == "spawn":
            """
            With the spawn start method, we serialize the task function with dill
            and tell the torch elastic launcher to launch the (pickleable)
            mp_helper function instead of the task function itself. In the child
            process, mp_helper deserializes the task function, again with dill,
            and executes it.

            Note from a few weeks later: we might instead pass the string
            representation of the task, which the task resolver uses, to the
            helper function; the task resolver could then retrieve the task in
            the child process. We would need a way to only execute the task
            function there without starting further child processes. The idea,
            in short: pass the task name to the child process and use the task
            resolver instead of dill serialization.
            """
            launcher_target_func = mp_helper

            import dill

            dumped_target_function = dill.dumps(self._task_function)
            launcher_args = (dumped_target_function, kwargs)
        elif self.task_config.start_method == "fork":
            """
            With the fork start method, we don't have to serialize the task
            function with dill. However, the torch elastic launcher passes only
            args, not kwargs, to the target function, while Flyte works
            exclusively with kwargs. We therefore create a closure that already
            has the task kwargs bound and tell the torch elastic launcher to
            start it in the child processes.
            """

            def fn_partial():
                print("Using start method fork")
                return self._task_function(**kwargs)

            launcher_target_func = fn_partial
            launcher_args = ()

        else:
            raise ValueError(f"Unsupported start method: {self.task_config.start_method}")

        # TODO: wrap the launch in try/except handling analogous to
        # https://github.com/flyteorg/flytekit/blob/47ac6ac2a547fdd2b46654db5163493b4f33dbb2/flytekit/core/python_function_task.py#L164

        out = elastic_launch(
            config=config,
            entrypoint=launcher_target_func,
        )(*launcher_args)
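        # elastic_launch returns a dict mapping each worker's local rank to its
        # return value; rank 0's result is returned as the task output.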

        return out[0]

        # return super().execute(**kwargs)


TaskPlugins.register_pythontask_plugin(Elastic, PytorchElasticFunctionTask)
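For context, a task using this prototype plugin would presumably be declared as follows. This is a minimal sketch, not part of the commit: the import path follows the file location above, and the train function and its body are hypothetical.

from flytekit import task
from flytekitplugins.torch_elastic.task import Elastic  # path assumed from the file location above

# Hypothetical task: passing Elastic as task_config makes flytekit execute it
# via PytorchElasticFunctionTask, i.e. through torch's elastic_launch.
@task(task_config=Elastic(min_replicas=1, max_replicas=2, start_method="spawn"))
def train() -> str:
    import os

    # The elastic agent sets RANK/WORLD_SIZE/LOCAL_RANK in each worker process.
    return f"trained on rank {os.environ.get('RANK')}"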
