Add convert_module to FSDP #20323
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@           Coverage Diff           @@
##           master   #20323   +/-  ##
=======================================
  Coverage      88%      88%
=======================================
  Files         267      267
  Lines       23293    23304    +11
=======================================
+ Hits        20396    20407    +11
  Misses       2897     2897
Force-pushed from 7a0c355 to baeb535
Thank you @tshu-w! As a sanity check, can you verify that the issues in #19721 are resolved? (i.e. memory goes back to what PyTorch uses, and no inconsistency errors are produced - these may be good tests to add btw, or at least a scaled-down version thereof). I'll be happy to run things on my end and dig deeper in parallel.
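(A scaled-down version of such a test could assert on the parameter dtype from inside the module during a short FSDP run; the sketch below is only an illustration, with BoringModel, the device count, and the Trainer arguments chosen as assumptions rather than taken from this PR.)

import torch

import lightning as L
from lightning.pytorch.demos.boring_classes import BoringModel


class DtypeCheckModel(BoringModel):
    def on_train_start(self):
        # With true bf16 precision, the (FSDP-sharded) parameters should be
        # bfloat16; with the behavior reported in #19721 they stayed float32,
        # inflating memory.
        assert all(p.dtype == torch.bfloat16 for p in self.parameters())


def test_fsdp_bf16_true_param_dtype():
    trainer = L.Trainer(
        accelerator="cuda",
        devices=2,
        strategy="fsdp",
        precision="bf16-true",
        max_steps=2,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(DtypeCheckModel())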
I indeed noticed a decrease in VRAM usage (which I will confirm again in the coming week), even when I initialize the LLM in configure_model:

def configure_model(self):
    if self.model is not None:
        return

    self.model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path)
    # suppress the open-end generation warning
    self.model.generation_config.pad_token_id = (
        self.model.generation_config.pad_token_id
        or self.model.generation_config.eos_token_id
    )
    if self.hparams.peft_config:
        peft_config = get_peft_config(self.hparams.peft_config)
        self.model = get_peft_model(self.model, peft_config)

    if self.tokenizer.chat_template is None:
        self.tokenizer.chat_template = (
            self.chatml_template
            if self.hparams.use_chatml_template
            else self.base_template
        )
    if self.hparams.use_chatml_template:
        self.tokenizer.add_tokens(
            ["<|im_start|>", "<|im_end|>"], special_tokens=True
        )
        self.model.resize_token_embeddings(len(self.tokenizer))

    if self.hparams.ckpt_path:
        checkpoint = torch.load(self.hparams.ckpt_path, weights_only=True)
        self.load_state_dict(checkpoint["state_dict"])
hey @tshu-w did you end up digging further?
Checking memory gains on my end
Hey @lantiga I apologize for my slow response, and I regret to inform you that I won't have time to complete this in the near future. I'm glad to see that you're interested in helping with further verification.
Thanks for the heads up @tshu-w, I'll take it from here
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Luca Antiga <[email protected]>

* make plugin type check more flexible
* Change signature and make the equivalent changes to Fabric connector

Co-authored-by: Jianing Yang <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
perfect, thank you @tshu-w!
Great find @tshu-w
To verify:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.demos import Transformer, WikiText2


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return
        with torch.device("meta"):
            self.model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

    def training_step(self, batch):
        print("MODEL DTYPE:", self.dtype)
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=4)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    trainer = L.Trainer(strategy="fsdp", max_steps=100, precision="bf16-true")
    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    train()
prior to this PR, the first line in training_step prints:

MODEL DTYPE: torch.float32

with this PR:

MODEL DTYPE: torch.bfloat16
We probably never noticed because it's not common to train in true precision; I'm still a bit puzzled tbh.
What does this PR do?
Adds convert_module for FSDP, as is already done for DeepSpeed.

Fixes #19721 (comment)
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20323.org.readthedocs.build/en/20323/