Feat: Add torchrun plugin #1576

Closed

@fg91 fg91 commented Apr 3, 2023

TL;DR

Work in progress

This plugin allows running torch elastic (torchrun) distributed training with Flyte.

from dataclasses import dataclass

import torch
from dataclasses_json import dataclass_json
from flytekit import dynamic, task, workflow
from flytekitplugins.kfpytorch import PyTorch

from .torch_elastic_task import Elastic


@dataclass_json
@dataclass
class Config:
    lr: float = 1e-5
    bs: int = 64
    name: str = "foo"


@task
def init_model() -> torch.nn.Module:
    model = torch.nn.Linear(11, 22)

    return model


"""
This doesn't start a Kubeflow PyTorch job yet, but a single Python task Pod which then
runs a local worker group in sub-processes.
The changes in the flyteidl protobuf definitions, the flytekit python api, and the
flytepropeller (operator) which we need to actually make this distributed on multiple nodes
are easy (see RFC document linked in PR description).
"""
@task(
    task_config=Elastic(
        min_replicas=1,
        max_replicas=1,
        start_method="spawn",
    )
)
def train(config: Config, model: torch.nn.Module) -> tuple[str, Config, torch.nn.Module]:
    import os

    import torch

    local_rank = os.environ["LOCAL_RANK"]

    out_model = torch.nn.Linear(1000, int(local_rank) * 2000 + 1)
    print(f"Training with config {config}")
    config.name = "modified"
    return f"result from local rank {local_rank}", config, out_model


@workflow
def wf(config: Config = Config()) -> tuple[str, Config, torch.nn.Module]:
    model = init_model()
    return train(config=config, model=model)


if __name__ == "__main__":
    print(wf(config=Config()))
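
For readers less familiar with torchrun: a "local worker group" is a set of subprocesses on one node, each of which gets its own LOCAL_RANK environment variable. The sketch below mimics only that behaviour with the standard library and does not use torch or the plugin. `run_local_worker_group` and `WORKER_CODE` are hypothetical names chosen for illustration.

```python
import os
import subprocess
import sys

# Hypothetical worker body: like `train` above, it only reads LOCAL_RANK.
WORKER_CODE = (
    "import os; "
    "print('result from local rank ' + os.environ['LOCAL_RANK'])"
)


def run_local_worker_group(nproc_per_node: int) -> list[str]:
    """Start one subprocess per local rank and collect each one's stdout."""
    procs = []
    for local_rank in range(nproc_per_node):
        # Each worker inherits the parent environment plus its own rank.
        env = dict(
            os.environ,
            LOCAL_RANK=str(local_rank),
            LOCAL_WORLD_SIZE=str(nproc_per_node),
        )
        procs.append(
            subprocess.Popen(
                [sys.executable, "-c", WORKER_CODE],
                env=env,
                stdout=subprocess.PIPE,
                text=True,
            )
        )
    # Wait for every worker and return the outputs ordered by local rank.
    return [proc.communicate()[0].strip() for proc in procs]


if __name__ == "__main__":
    print(run_local_worker_group(2))
```

With `min_replicas=1` and `max_replicas=1`, everything happens inside one Pod in roughly this manner. The real launcher additionally handles rendezvous, restarts, and returning results.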

Type

  • Bug Fix
  • Feature
  • Plugin

Are all requirements met?

  • Code completed
  • Smoke tested
  • Unit tests added
  • Code documentation added
  • Any pending items have an associated Issue

Complete description

How did you fix the bug, make the feature etc. Link to any design docs etc

Tracking Issue

https://github.com/flyteorg/flyte/issues/

Follow-up issue

NA
OR
https://github.com/flyteorg/flyte/issues/


codecov bot commented Apr 3, 2023

Codecov Report

Merging #1576 (ff64a6e) into master (9658b02) will not change coverage.
The diff coverage is n/a.

@@           Coverage Diff           @@
##           master    #1576   +/-   ##
=======================================
  Coverage   69.92%   69.92%           
=======================================
  Files         319      319           
  Lines       29525    29525           
  Branches     5317     5317           
=======================================
  Hits        20644    20644           
  Misses       8365     8365           
  Partials      516      516           


start_method=self.task_config.start_method,
)

if self.task_config.start_method == "spawn":

@kumare3 what do you think of this workaround?

Unfortunately, we can't just pass self._task_function to elastic_launch, since it is not picklable.
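
A minimal sketch of why this can happen, assuming the cause is that the decorator rebinds the function's module-level name to a wrapper object (an assumption; the comment above does not say why pickling fails). `TaskWrapper`, `task`, and `try_pickle` are hypothetical stand-ins, not flytekit APIs. Standard pickle stores functions by reference (module plus qualified name), so once that name points at a different object the lookup no longer matches.

```python
import pickle


class TaskWrapper:
    """Hypothetical stand-in for a task object that wraps a user function."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


def task(fn):
    # After decoration, the module-level name refers to the wrapper,
    # not to the original function object.
    return TaskWrapper(fn)


@task
def train(x):
    return x * 2


def try_pickle(obj) -> bool:
    """Return True if `obj` can be serialized with the standard pickle module."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


if __name__ == "__main__":
    # The wrapped function is looked up as `<module>.train`, which now
    # resolves to the TaskWrapper instance, so pickling is rejected.
    print(try_pickle(train.fn))
```

With the "spawn" start method the entrypoint has to be serialized and sent to the child processes, which is why this limitation matters here and why a workaround is needed.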


fg91 commented Apr 3, 2023

@kumare3 I also opened draft PRs in flyteidl and flyteplugins. Please note that these are very early WIP prototypes: I wanted to hack everything together first to de-risk potential deal breakers, such as the pickle limitations I then stumbled upon (see this comment).


I think this could be part of the existing kfpytorch plugin as an optional install, pip install flytekitplugins-kfpytorch[elastic] (since it will have torch as a dependency).

In flytekit, this could look like this:

@task(
    task_config=PyTorch(
        num_workers=2,
        elastic_policy=ElasticPolicy(  # <- optional
            n_proc_per_node=...,
            # ...
        ),
    ),
)
def train():
    ...

If done this way, in flyteplugins most of the existing logic could be reused. We would only need a check whether the user configured an elastic policy and if so, add this to the PytorchJob object. I quickly hard-coded this in this draft PR.
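
To illustrate the kind of check meant here, a rough Python sketch follows. The actual logic would live in the Go code of flyteplugins. `PyTorchConfig`, the `ElasticPolicy` fields, `build_job_spec`, and the dictionary keys are all assumptions made for illustration and do not claim to match the real PytorchJob schema.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ElasticPolicy:
    """Hypothetical elastic settings; the field names are assumptions."""

    min_replicas: int
    max_replicas: int
    n_proc_per_node: int


@dataclass
class PyTorchConfig:
    """Hypothetical stand-in for the task config proposed above."""

    num_workers: int
    elastic_policy: Optional[ElasticPolicy] = None


def build_job_spec(config: PyTorchConfig) -> Dict[str, Any]:
    """Build an illustrative job spec, adding the elastic part only if set."""
    spec: Dict[str, Any] = {"workers": config.num_workers}
    if config.elastic_policy is not None:
        policy = config.elastic_policy
        # Illustrative keys only, not the real PytorchJob field names.
        spec["elasticPolicy"] = {
            "minReplicas": policy.min_replicas,
            "maxReplicas": policy.max_replicas,
            "nProcPerNode": policy.n_proc_per_node,
        }
    return spec


if __name__ == "__main__":
    print(build_job_spec(PyTorchConfig(num_workers=2)))
    print(build_job_spec(PyTorchConfig(2, ElasticPolicy(1, 2, 4))))
```

Tasks without an elastic policy would produce the same spec as before, so existing users of the plugin would be unaffected.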


One last point: I think it would be amazing if this could start a local process group when running locally. What do you think about this?


ByronHsu commented Apr 4, 2023

Do you mind creating an issue in flyte mp and referencing it in all related PRs? That way, we can have a central place to track the feature.


fg91 commented Apr 5, 2023

Closed in favor of #1583

@fg91 fg91 closed this Apr 5, 2023