Add minLoRA #1

Merged · 1 commit · Feb 27, 2023
21 changes: 21 additions & 0 deletions config/finetune_shakespeare.py
@@ -1,4 +1,8 @@
import time
from functools import partial

import torch
from minlora import LoRAParametrization

out_dir = 'out-shakespeare'
eval_interval = 5
@@ -9,6 +13,8 @@

dataset = 'shakespeare'
init_from = 'gpt2-xl' # this is the largest GPT-2 model
init_from = 'gpt2-large' # use a smaller model for faster training
# gpt2-xl doesn't fit on a 24GB GPU, but with LoRA it does

# only save checkpoints if the validation loss improves
always_save_checkpoint = False
@@ -23,3 +29,18 @@
# finetune at constant LR
learning_rate = 3e-5
decay_lr = False


use_lora = True
learning_rate = 1e-3 # use a higher LR for LoRA
lora_dropout_p = 0.0
rank = 4
lora_alpha = 64
lora_config = {
    torch.nn.Embedding: {
        "weight": partial(LoRAParametrization.from_embedding, rank=rank, lora_alpha=lora_alpha),
    },
    torch.nn.Linear: {
        "weight": partial(LoRAParametrization.from_linear, rank=rank, lora_alpha=lora_alpha),
    },
}
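For readers unfamiliar with the parametrization-based approach, the sketch below shows what a LoRA weight parametrization of this shape conceptually does: it leaves the base weight frozen and adds a low-rank update scaled by lora_alpha / rank. This is an illustrative sketch, not minlora's actual implementation; the class name LoRAWeight and the commented registration line are hypothetical.

import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class LoRAWeight(nn.Module):
    """Illustrative only: effective weight = W + (lora_alpha / rank) * B @ A."""
    def __init__(self, fan_out, fan_in, rank=4, lora_alpha=64):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(rank, fan_in))
        self.lora_B = nn.Parameter(torch.zeros(fan_out, rank))
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)  # A random, B zero, so the update starts at 0
        self.scaling = lora_alpha / rank

    def forward(self, W):
        # called by torch's parametrize machinery whenever module.weight is read
        return W + self.scaling * (self.lora_B @ self.lora_A)

# roughly what add_lora would do for an nn.Linear under the lora_config above (hypothetical):
# parametrize.register_parametrization(linear, "weight", LoRAWeight(*linear.weight.shape, rank=rank, lora_alpha=lora_alpha))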
12 changes: 12 additions & 0 deletions sample.py
@@ -7,6 +7,7 @@
import torch
import tiktoken
from model import GPTConfig, GPT
import minlora

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
@@ -38,12 +39,23 @@
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
if use_lora:
    minlora.add_lora(model, lora_config)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
if use_lora:
    # the full state dict includes the LoRA state dict
    # so actually we don't need to load it separately again
    model.load_state_dict(checkpoint['lora'], strict=False)
    print('Loaded LoRA state dict')
    # sanity check
    #model.apply(minlora.apply_to_lora(lambda m: print((m.lora_A.sum(), m.lora_B.sum()))))
    # merge for zero-overhead inference
    minlora.merge_lora(model)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
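The "zero-overhead inference" comment above refers to folding the low-rank update into the base weight once before generation. A minimal sketch of the idea under the standard LoRA formulation (not minlora's own code; merge_lora_weight is a made-up helper name):

import torch

@torch.no_grad()
def merge_lora_weight(W, lora_A, lora_B, lora_alpha, rank):
    # after merging, a plain forward pass with the returned weight equals the LoRA forward
    # W @ x + (lora_alpha / rank) * B @ (A @ x), but without the extra matmuls at inference time
    return W + (lora_alpha / rank) * (lora_B @ lora_A)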
39 changes: 38 additions & 1 deletion train.py
@@ -20,6 +20,7 @@
import time
import math
import pickle
import inspect
from contextlib import nullcontext

import numpy as np
@@ -29,6 +30,7 @@

from model import GPTConfig, GPT

import minlora
# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
# I/O
@@ -180,13 +182,46 @@ def get_batch(split):
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size # so that the checkpoint will have the right value
if use_lora:
    minlora.add_lora(model, lora_config=lora_config)
    minlora.tie_weights(linear=model.lm_head, embedding=model.transformer.wte)
Review comment: Hi, can you explain why you tie weights here?
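One plausible reason (an assumption, not the author's reply): nanoGPT ties the token embedding and the LM head so they share a single weight tensor, so the LoRA parametrizations attached to transformer.wte and lm_head presumably need to share their low-rank factors as well; otherwise the two views of that tensor would diverge during fine-tuning. The relevant nanoGPT line (from model.py) is:

# nanoGPT, GPT.__init__: embedding and LM head share one weight tensor ("weight tying")
self.transformer.wte.weight = self.lm_head.weight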

model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
def configure_optimizers_lora(self, weight_decay, learning_rate, betas, device_type):
    # we apply weight decay to all lora params
    optim_groups = [
        # note: .get_lora_params() returns a generator
        # we need to wrap it in a list so we can consume it twice
        {"params": list(minlora.get_lora_params(self)), "weight_decay": weight_decay},
        # you can also add biases for fine-tuning,
        # but I want to make sure lora alone works
        # {"params": minlora.get_bias_params(self), "weight_decay": 0.0}, # bias params don't get weight decay
    ]

    def parameter_count(optim_groups):
        n = sum(p.numel() for group in optim_groups for p in group["params"])
        if n < 1e6:
            return f"{n/1e3:.1f}k"
        else:
            return f"{n/1e6:.1f}M"

    print(f"optimizing {parameter_count(optim_groups)} parameters")

    # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
    use_fused = (device_type == "cuda") and ("fused" in inspect.signature(torch.optim.AdamW).parameters)
    print(f"using fused AdamW: {use_fused}")
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)

    return optimizer

if use_lora:
    optimizer = configure_optimizers_lora(model, weight_decay, learning_rate, (beta1, beta2), device_type)
else:
    optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
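Beyond the count printed above, a quick cross-check against the full model size can help confirm that only the adapter is being trained. A sketch of that check (it assumes minlora.get_lora_params yields exactly the trainable LoRA tensors, as it is used above):

n_lora = sum(p.numel() for p in minlora.get_lora_params(model))
n_total = sum(p.numel() for p in model.parameters())
print(f"LoRA trains {n_lora:,} of {n_total:,} parameters ({100 * n_lora / n_total:.2f}%)")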

@@ -271,6 +306,8 @@ def get_lr(it):
    'best_val_loss': best_val_loss,
    'config': config,
}
if use_lora:
    checkpoint['lora'] = minlora.get_lora_state_dict(raw_model)
print(f"saving checkpoint to {out_dir}")
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
if iter_num == 0 and eval_only:
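Because get_lora_state_dict returns only the adapter tensors, the same hook also makes it easy to ship a small LoRA-only checkpoint next to the full one and reattach it later, mirroring the sample.py path. A sketch of that idea (not part of this PR; the file name lora_only.pt is made up):

# save just the adapter weights (a few MB, versus gigabytes for a full gpt2-xl checkpoint)
torch.save(minlora.get_lora_state_dict(raw_model), os.path.join(out_dir, 'lora_only.pt'))

# later: rebuild the base model, add the LoRA parametrizations, and load the adapter on top
minlora.add_lora(model, lora_config)
model.load_state_dict(torch.load(os.path.join(out_dir, 'lora_only.pt')), strict=False)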