Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convert_module to FSDP #20323

Merged
merged 12 commits into from
Dec 10, 2024
Merged

Add convert_module to FSDP #20323

merged 12 commits into from
Dec 10, 2024

Conversation

tshu-w
Copy link
Contributor

@tshu-w tshu-w commented Oct 6, 2024

What does this PR do?

Add convert_module for FSDP as DeepSpeed.

Fixes #19721 (comment)

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--20323.org.readthedocs.build/en/20323/

@github-actions github-actions bot added fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package labels Oct 6, 2024
Copy link

codecov bot commented Oct 6, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 88%. Comparing base (030f36b) to head (9cbe473).
Report is 2 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #20323   +/-   ##
=======================================
  Coverage      88%      88%           
=======================================
  Files         267      267           
  Lines       23293    23304   +11     
=======================================
+ Hits        20396    20407   +11     
  Misses       2897     2897           

@tshu-w tshu-w force-pushed the FSDP branch 2 times, most recently from 7a0c355 to baeb535 Compare October 7, 2024 08:31
@lantiga
Copy link
Collaborator

lantiga commented Oct 7, 2024

Thank you @tshu-w!
Looks good in general, FSDP relies on contexts, but this may not cleanly apply when recomputations are involved.

As a sanity check, can you verify that the issues in #19721 are resolved? (i.e. memory goes back to what PyTorch uses, and no inconsistency errors are produced - these may be good tests to add btw, or at least a scaled-down version thereof).

I'll be happy to run things on my end and dig deeper in parallel.

@tshu-w
Copy link
Contributor Author

tshu-w commented Oct 7, 2024

I indeed noticed a decrease in VRAM usage (which I will confirm again in the coming week), even when I initialize the LLM in def configure_model as follows. However, I cannot guarantee that this PR resolves the original issue, as the author has manually set the LLM torch_dtype to torch.bfloat16. Nevertheless, I believe this might be able to solve part of the problem.

def configure_model(self):
    if self.model is not None:
        return

    self.model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path)
    # suppress the open-end generation warning
    self.model.generation_config.pad_token_id = (
        self.model.generation_config.pad_token_id
        or self.model.generation_config.eos_token_id
    )

    if self.hparams.peft_config:
        peft_config = get_peft_config(self.hparams.peft_config)
        self.model = get_peft_model(self.model, peft_config)

    if self.tokenizer.chat_template is None:
        self.tokenizer.chat_template = (
            self.chatml_template
            if self.hparams.use_chatml_template
            else self.base_template
        )
        if self.hparams.use_chatml_template:
            self.tokenizer.add_tokens(
                ["<|im_start|>", "<|im_end|>"], special_tokens=True
            )
            self.model.resize_token_embeddings(len(self.tokenizer))

    if self.hparams.ckpt_path:
        checkpoint = torch.load(self.hparams.ckpt_path, weights_only=True)
        self.load_state_dict(checkpoint["state_dict"])

@lantiga
Copy link
Collaborator

lantiga commented Nov 12, 2024

hey @tshu-w did you end up digging further?

@lantiga lantiga added the waiting on author Waiting on user action, correction, or update label Nov 12, 2024
@lantiga
Copy link
Collaborator

lantiga commented Nov 26, 2024

Checking memory gains on my end

@tshu-w
Copy link
Contributor Author

tshu-w commented Nov 29, 2024

Hey @lantiga I apologize for my slow response, and I regret to inform you that I won't have time to complete this in the near future. I'm glad to see that you're interested in helping with further verification.

@lantiga
Copy link
Collaborator

lantiga commented Dec 3, 2024

Thanks for the heads up @tshu-w, I'll take it from here
Can you allow me to write to your fork so I can just push changes?

@tshu-w
Copy link
Contributor Author

tshu-w commented Dec 3, 2024

Can you allow me to write to your fork so I can just push changes?

This PR is editable by maintainers

Screenshot 2024-12-03 at 21 26 52

and I've invited you as a collaborator in my fork, are there any other steps I need to take?

jedyang97 and others added 3 commits December 4, 2024 16:20
* make plugin type check more flexible

* Change signature and make the equivalent changes to Fabric connector

---------

Co-authored-by: Jianing Yang <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
@lantiga
Copy link
Collaborator

lantiga commented Dec 4, 2024

perfect, thank you @tshu-w !

@github-actions github-actions bot added the ci Continuous Integration label Dec 10, 2024
@github-actions github-actions bot added the docs Documentation related label Dec 10, 2024
Copy link
Collaborator

@lantiga lantiga left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great find @tshu-w

To verify:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.demos import Transformer, WikiText2


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return

        with torch.device("meta"):
            self.model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

    def training_step(self, batch):
        print("MODEL DTYPE:", self.dtype)
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)



def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=4)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    trainer = L.Trainer(strategy="fsdp", max_steps=100, precision="bf16-true")
    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()

prior to this PR, the first line in training_step prints

MODEL DTYPE: torch.float32

with this PR

MODEL DTYPE: torch.bfloat16

We probably never noticed because it's not common to train in true precision, I'm still a bit puzzled tbh.

@lantiga lantiga merged commit 601c060 into Lightning-AI:master Dec 10, 2024
102 of 103 checks passed
@lantiga lantiga deleted the FSDP branch December 10, 2024 23:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci Continuous Integration docs Documentation related fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package waiting on author Waiting on user action, correction, or update
Projects
None yet
Development

Successfully merging this pull request may close these issues.

PyTorch Lightning FSDP takes more memory than PyTorch FSDP
3 participants