-
Notifications
You must be signed in to change notification settings - Fork 301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: Add torchrun plugin #1576
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
117 changes: 117 additions & 0 deletions
117
plugins/flytekit-torch-elastic/flytekitplugins/torch_elastic/task.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Callable | ||
|
||
from flytekit import FlyteContextManager, PythonFunctionTask | ||
from flytekit.core.context_manager import ExecutionParameters | ||
from flytekit.extend import TaskPlugins | ||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch | ||
from torch.distributed.run import parse_min_max_nnodes | ||
|
||
|
||
@dataclass | ||
class Elastic(object): | ||
min_replicas: int | ||
max_replicas: int | ||
start_method: str = "spawn" | ||
|
||
|
||
def mp_helper(fn, kwargs): | ||
print("Using start method spawn") | ||
import dill | ||
|
||
fn = dill.loads(fn) | ||
|
||
return_val = fn(**kwargs) | ||
|
||
return return_val | ||
|
||
|
||
class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]): | ||
""" | ||
Actual Plugin that transforms the local python code for execution within a spark context | ||
""" | ||
|
||
_ELASTIC_TASK_TYPE = "torch-elastic" | ||
|
||
def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): | ||
super(PytorchElasticFunctionTask, self).__init__( | ||
task_config=task_config, | ||
task_type=self._ELASTIC_TASK_TYPE, | ||
task_function=task_function, | ||
**kwargs, | ||
) | ||
|
||
def execute(self, **kwargs) -> Any: | ||
""" | ||
This method will be invoked to execute the task. If you do decide to override this method you must also | ||
handle dynamic tasks or you will no longer be able to use the task as a dynamic task generator. | ||
""" | ||
min_nodes, max_nodes = 1, 1 | ||
|
||
""" | ||
All this is mocked. | ||
""" | ||
config = LaunchConfig( | ||
run_id="none", | ||
min_nodes=min_nodes, | ||
max_nodes=max_nodes, | ||
nproc_per_node=4, | ||
rdzv_backend="c10d", | ||
max_restarts=1, | ||
# rdzv_endpoint = "foo" | ||
start_method=self.task_config.start_method, | ||
) | ||
|
||
if self.task_config.start_method == "spawn": | ||
""" | ||
If the user wants to use spawn, we use dill to serialize the task function. | ||
We then tell the torch elastic launcher to launch the mp_helper function (which is pickleable) | ||
instead of the task function. This helper function, in the child-process, then deserializes | ||
the task function, again with dill, and executes it. | ||
|
||
Note from a few weeks later: | ||
We might be able to pass the string representation of the task which is used by the task | ||
resolver to the helper function. In the child process, the task resolver could retrieve the task. | ||
But we would need a way to not start further child processes from the child process but just execute | ||
the task function. But the idea basically is: pass task name to child process and use task resolver | ||
instead of dill serialisation. | ||
""" | ||
launcher_target_func = mp_helper | ||
|
||
import dill | ||
|
||
dumped_target_function = dill.dumps(self._task_function) | ||
launcher_args = (dumped_target_function, kwargs) | ||
elif self.task_config.start_method == "fork": | ||
""" | ||
If the user wants to do fork, we don't have to serialize the task function with dill. | ||
However, the torch elastic launcher doesn't support passing kwargs to the target function, | ||
only args. Flyte only works with kwargs. | ||
Thus, we create a closure which already has the task kwargs bound. We tell the torch elastic | ||
launcher to start this function in the child processes. | ||
""" | ||
|
||
def fn_partial(): | ||
print("Using start method fork") | ||
return self._task_function(**kwargs) | ||
|
||
launcher_target_func = fn_partial | ||
launcher_args = () | ||
|
||
else: | ||
raise Exception("Bad start method") | ||
|
||
# put the try catch here | ||
# https://github.com/flyteorg/flytekit/blob/47ac6ac2a547fdd2b46654db5163493b4f33dbb2/flytekit/core/python_function_task.py#L164 | ||
|
||
out = elastic_launch( | ||
config=config, | ||
entrypoint=launcher_target_func, | ||
)(*launcher_args) | ||
|
||
return out[0] | ||
|
||
# return super().execute(**kwargs) | ||
|
||
|
||
TaskPlugins.register_pythontask_plugin(Elastic, PytorchElasticFunctionTask) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kumare3 what do you think of this workaround?
Unfortunately we can't just pass
self._task_function
toelastic_launch
since it is not pickleable.