-
Notifications
You must be signed in to change notification settings - Fork 4
/
checkpoint_utils.py
executable file
·803 lines (663 loc) · 25.8 KB
/
checkpoint_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
# Copyright 2024 The GPT-Accelera Team
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import re
import glob
from pathlib import Path
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
StateDictType,
FullStateDictConfig,
FullOptimStateDictConfig,
)
from models.model import Transformer
from models.reward_model import (
RewardModel,
apply_reward_modeling_head,
)
from models.tp import (
apply_tp,
apply_reward_head_tp,
get_model_parallel_rank,
get_model_parallel_world_size,
get_data_parallel_rank,
get_data_parallel_group,
)
from arguments import Arguments
from training_utils.fsdp_utils import (
fixed_full_optim_state_dict,
fixed_scatter_full_optim_state_dict,
)
def rank0_print(*args):
    """Print ``args`` only when the ``LOCAL_RANK`` environment variable is 0.

    ``LOCAL_RANK`` is read on every call and treated as 0 when it is unset,
    so a process started without that variable always prints.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) != 0:
        return
    print(*args)
def get_trainable_state_dict(
    model: Union[torch.nn.Module, FSDP],
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    use_fsdp: bool = False,
    trainable_param_names: Optional[list] = None,
    save_only_model: bool = False,
    compute_dtype: torch.dtype = torch.bfloat16,
) -> dict:
    """Collect the model, optimizer and scheduler state to be checkpointed.

    Args:
        model: The (optionally FSDP-wrapped) model to read the weights from.
        optimizer: The optimizer whose state is saved.
        scheduler: The learning-rate scheduler whose state is saved.
        use_fsdp: When True, full state dicts are gathered with the FSDP
            ``FULL_STATE_DICT`` configuration (CPU offload, rank 0 only).
        trainable_param_names: Names of the parameters to keep. Only honoured
            when ``use_fsdp`` is True; otherwise the names are derived from
            ``requires_grad``. ``None`` keeps the whole model state dict.
        save_only_model: FSDP only. Cast the gathered weights to
            ``compute_dtype`` and skip gathering the optimizer state.
        compute_dtype: Target dtype of the cast performed for
            ``save_only_model``.

    Returns:
        dict: ``{"model", "optimizer", "scheduler"}`` on data-parallel rank 0,
        ``None`` on every other data-parallel rank.

    Raises:
        ValueError: If a requested parameter name is absent from the model
            state dict.
    """
    scheduler_state_dict = scheduler.state_dict()

    if use_fsdp:
        full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        full_optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(
            module=model,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=full_cfg,
            optim_state_dict_config=full_optim_cfg,
        ):
            model_state_dict = model.state_dict()
            # NOTE(review): the optimizer gather is assumed to run inside the
            # state_dict_type context -- confirm against the original layout.
            if not save_only_model:
                optimizer_state_dict = fixed_full_optim_state_dict(
                    model=model,
                    optim=optimizer,
                    rank0_only=True,
                    group=get_data_parallel_group(),
                )
            else:
                # Weights-only checkpoints are stored in the compute dtype.
                for weight_name in model_state_dict:
                    model_state_dict[weight_name] = model_state_dict[weight_name].to(
                        dtype=compute_dtype
                    )
                optimizer_state_dict = None
    else:
        model_state_dict = model.state_dict()
        optimizer_state_dict = optimizer.state_dict()

    # Only data-parallel rank 0 hands a state dict back to the caller.
    if get_data_parallel_rank() != 0:
        return None

    if not use_fsdp:
        # Without FSDP the trainable set is read directly from the module.
        trainable_param_names = [
            name for name, param in model.named_parameters() if param.requires_grad
        ]

    if trainable_param_names is None:
        trainable_model_state_dict = model_state_dict
    else:
        requested = set(trainable_param_names)
        trainable_model_state_dict = {
            name: tensor
            for name, tensor in model_state_dict.items()
            if name in requested
        }
        # Whatever was requested but never found in the state dict.
        trainable_param_names = requested.difference(trainable_model_state_dict)
        if len(trainable_param_names) > 0:
            rank0_print(
                f"Missing {len(trainable_param_names)} parameters in model state dict:"
            )
            for name in trainable_param_names:
                rank0_print(name)
            raise ValueError("Missing parameters in model state dict")

    return {
        "model": trainable_model_state_dict,
        "optimizer": optimizer_state_dict,
        "scheduler": scheduler_state_dict,
    }
def save_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    epoch: int,
    global_step: int,
    rank: int,
    save_only_model: bool,
    save_total_limit: int,
    use_fsdp: bool = False,
    trainable_param_names: Optional[list] = None,
    prefix: str = "",
    metrics: Optional[dict] = None,
    compute_dtype: torch.dtype = torch.bfloat16,
) -> None:
    """Save a checkpoint, update the ``last_checkpoint`` pointer and prune one old file.

    Args:
        checkpoint_path (str): File the checkpoint is written to.
        model (torch.nn.Module): The model to save.
        optimizer (torch.optim.Optimizer): The optimizer to save.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The scheduler to save.
        epoch (int): Epoch stored in the checkpoint (may be ``None``).
        global_step (int): Global step stored in the checkpoint.
        rank (int): Rank used in the ``_rank_{rank}.pt`` file suffix; rank 0
            also writes the ``last_checkpoint`` pointer file.
        save_only_model (bool): Store only the model weights (plus epoch,
            step and metrics), dropping optimizer and scheduler state.
        save_total_limit (int): When not ``None`` and more checkpoints than
            this exist for ``rank``, the one with the lowest step is removed.
        use_fsdp (bool): Forwarded to ``get_trainable_state_dict``.
        trainable_param_names (list): Forwarded to ``get_trainable_state_dict``.
        prefix (str): File-name prefix of the checkpoints and pointer file.
        metrics (dict): Optional metrics stored under ``"metrics"``.
        compute_dtype (torch.dtype): Forwarded to ``get_trainable_state_dict``.
    """
    state_dict = get_trainable_state_dict(
        model,
        optimizer,
        scheduler,
        use_fsdp,
        trainable_param_names,
        save_only_model,
        compute_dtype,
    )

    if state_dict is not None:
        if save_only_model:
            state_dict = {"model": state_dict["model"]}
        # NOTE(review): `or` also maps epoch 0 to -1, not just None --
        # confirm whether callers that resume from this value expect that.
        state_dict["epoch"] = epoch or -1
        state_dict["global_step"] = global_step
        if metrics is not None:
            state_dict["metrics"] = metrics
        torch.save(state_dict, checkpoint_path)

    # Every process reaches the barrier, including those with nothing to save.
    if dist.is_initialized():
        dist.barrier()

    if state_dict is None:
        return

    # A bare file name has no directory part; fall back to the current
    # directory instead of building paths rooted at "/".
    save_dir = os.path.dirname(checkpoint_path) or "."
    checkpoint_name = os.path.basename(checkpoint_path)

    # write last_checkpoint
    if rank == 0:
        with open(os.path.join(save_dir, f"{prefix}last_checkpoint"), "w") as f:
            f.write(checkpoint_name.replace("_rank_0", "_rank*"))

    def extract_step_number(filename):
        # Parse the file name only, so a directory such as "step_5_run/"
        # cannot be mistaken for the checkpoint's step.
        match = re.search(r"step_(\d+)", os.path.basename(filename))
        if match:
            return int(match.group(1))
        raise ValueError(f"Invalid checkpoint filename: {filename}")

    if save_total_limit is not None:
        # Escape the literal parts so glob metacharacters in the directory or
        # prefix are not interpreted as patterns.
        pattern = os.path.join(
            glob.escape(save_dir), f"{glob.escape(prefix)}*_rank_{rank}.pt"
        )
        checkpoints = sorted(glob.glob(pattern), key=extract_step_number)
        if len(checkpoints) > save_total_limit:
            oldest = checkpoints[0]
            message = oldest.replace("_rank_0", "_rank*")
            rank0_print(f"Removing {message}")
            os.remove(oldest)
def load_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    use_fsdp: bool = False,
) -> dict:
    """Load a training checkpoint into the model, optimizer and scheduler.

    Args:
        checkpoint_path (str): Checkpoint file of the current model-parallel rank.
        model (torch.nn.Module): The model to load into (``strict=False``).
        optimizer (torch.optim.Optimizer): Optional optimizer to restore.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Optional scheduler
            to restore. When the checkpoint has no ``"scheduler"`` entry the
            scheduler is stepped to the stored global step instead.
        use_fsdp (bool): Load through the FSDP full-state-dict path and
            scatter the full optimizer state to the data-parallel group.

    Returns:
        dict: ``"epoch"`` and ``"global_step"`` from the checkpoint, plus
        ``"metrics"`` when the checkpoint contains them.

    Raises:
        ValueError: If the number of ``_rank_*`` files found for this
            checkpoint differs from the current model-parallel world size.
    """
    load_rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()

    pattern = checkpoint_path.replace(f"_rank_{load_rank}", "_rank_*")
    ckpt_world_size = len(glob.glob(pattern))

    if ckpt_world_size != world_size:
        # Both halves must be formatted; the second one carries the numbers.
        raise ValueError(
            "The model parallel setting of resuming training "
            f"does not match the current setting: {ckpt_world_size} (train) vs {world_size} (resume)"
        )

    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    if not use_fsdp:
        model.load_state_dict(
            state_dict["model"],
            strict=False,
        )
        if optimizer is not None:
            optimizer.load_state_dict(state_dict["optimizer"])
    else:
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(
            module=model,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=cfg,
        ):
            model.load_state_dict(
                state_dict["model"],
                strict=False,
            )

            # `optimizer` defaults to None; skip the scatter in that case,
            # mirroring the non-FSDP branch.
            # NOTE(review): the scatter is assumed to run inside the
            # state_dict_type context -- confirm against the original layout.
            if optimizer is not None:
                # Only data-parallel rank 0 provides the full optimizer state.
                if get_data_parallel_rank() == 0:
                    full_osd = state_dict["optimizer"]
                else:
                    full_osd = None
                sharded_osd = fixed_scatter_full_optim_state_dict(
                    full_optim_state_dict=full_osd,
                    model=model,
                    optim=optimizer,
                    group=get_data_parallel_group(),
                )
                optimizer.load_state_dict(sharded_osd)

    if scheduler is not None:
        try:
            scheduler.load_state_dict(state_dict["scheduler"])
        except KeyError:
            # No scheduler state stored: fast-forward to the stored step.
            scheduler.step(state_dict["global_step"])

    ret_dict = {
        "epoch": state_dict["epoch"],
        "global_step": state_dict["global_step"],
    }
    if "metrics" in state_dict:
        ret_dict["metrics"] = state_dict["metrics"]
    return ret_dict
def load_inference_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
) -> tuple:
    """Load the ``"model"`` weights of a checkpoint into ``model``.

    Outside of a distributed run the file is loaded as is. In a distributed
    run the number of ``_rank_*`` files is compared with the current
    model-parallel world size and the shards are loaded directly, merged
    (more files than ranks) or split (fewer files than ranks).

    A leading ``module.`` is stripped from all keys. Checkpoint entries whose
    shape differs from the model's are dropped with a warning, and entries
    the model does not have are left for ``load_state_dict(strict=False)``
    to ignore.

    Args:
        checkpoint_path (str): Checkpoint file of the current model-parallel rank.
        model (torch.nn.Module): The model to load into.

    Returns:
        tuple: Always ``(None, None)``.

    Raises:
        ValueError: In a distributed run, if ``checkpoint_path`` does not
            contain ``_rank_{rank}`` for the current model-parallel rank.
    """
    if not torch.distributed.is_initialized():
        state_dict = torch.load(checkpoint_path)["model"]
    else:
        load_rank = get_model_parallel_rank()
        world_size = get_model_parallel_world_size()

        if f"_rank_{load_rank}" not in checkpoint_path:
            raise ValueError(
                f"Invalid checkpoint path: {checkpoint_path} for rank {load_rank}"
            )

        pattern = checkpoint_path.replace(f"_rank_{load_rank}", "_rank_*")
        ckpt_world_size = len(glob.glob(pattern))
        assert ckpt_world_size > 0, f"No checkpoint found: '{pattern}'"

        if ckpt_world_size == world_size:
            state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)[
                "model"
            ]
        elif ckpt_world_size > world_size:
            # Several checkpoint shards map onto this rank: merge them.
            assert ckpt_world_size % world_size == 0
            dup_factor = ckpt_world_size // world_size
            all_state_dict = []
            for i in range(dup_factor):
                ckpt_file_path = pattern.replace("*", str(load_rank * dup_factor + i))
                model_state_dict = torch.load(
                    ckpt_file_path, map_location="cpu", mmap=True
                )["model"]
                for key in model_state_dict:
                    model_state_dict[key] = model_state_dict[key].cpu()
                all_state_dict.append(model_state_dict)
            state_dict = merge_tp_checkpoint(all_state_dict)
        else:
            # One checkpoint shard covers several ranks: take this rank's slice.
            assert world_size % ckpt_world_size == 0
            split_factor = world_size // ckpt_world_size
            split_idx = load_rank % split_factor
            ckpt_file_path = pattern.replace("*", str(load_rank // split_factor))
            full_state_dict = torch.load(ckpt_file_path, map_location="cpu", mmap=True)[
                "model"
            ]
            for key in full_state_dict:
                full_state_dict[key] = full_state_dict[key].cpu()
            state_dict = split_tp_checkpoint(full_state_dict, split_factor, split_idx)

    # An empty checkpoint has no first key to inspect.
    if state_dict and next(iter(state_dict)).startswith("module."):
        assert all([k.startswith("module.") for k in state_dict.keys()])
        state_dict = {k[len("module.") :]: v for k, v in state_dict.items()}

    # remove size un-matching parameters
    model_state_dict = model.state_dict()
    for key in list(state_dict.keys()):
        if key not in model_state_dict:
            # Unknown to the model: load_state_dict(strict=False) ignores it.
            continue
        if model_state_dict[key].size() != state_dict[key].size():
            rank0_print("Warning: removing", key, "from checkpoint.")
            del state_dict[key]

    model.load_state_dict(
        state_dict,
        strict=False,
    )
    return None, None
def merge_tp_checkpoint(
    all_state_dict: list,
):
    """Merge a list of per-shard state dicts into a single state dict.

    Keys are taken from the first shard and dispatched on their name:

    * ``wqkv``: each shard is split into its q/k/v parts, the parts are
      concatenated across shards, then stacked as q, k, v along dim 0. The
      head counts come from a fixed table keyed by ``size(1)``.
    * ``wo``, ``w2``, ``output.weight``: concatenated along dim 1.
    * ``norm``, ``tok_embeddings``: taken from the first shard.
    * anything else: concatenated along dim 0.

    The input dicts are emptied key by key as the merge proceeds.

    Raises:
        ValueError: If a ``wqkv`` tensor has a ``size(1)`` missing from the table.
    """
    # size(1) of the wqkv weight -> (n_heads, n_local_heads)
    heads_by_width = {
        4096: (32, 32),
        5120: (40, 40),
        6656: (52, 52),
        8192: (64, 8),
    }

    merged = {}
    first_shard = all_state_dict[0]
    for key in list(first_shard.keys()):
        if "wqkv" in key:
            width = first_shard[key].size(1)
            if width not in heads_by_width:
                raise ValueError(
                    f"Invalid size for {key}: {all_state_dict[0][key].size(1)}"
                )
            n_heads, n_local_heads = heads_by_width[width]
            head_dim = first_shard[key].size(0) // (n_heads + n_local_heads * 2)
            weight_splits = [
                head_dim * n_heads,
                head_dim * n_local_heads,
                head_dim * n_local_heads,
            ]
            per_shard_qkv = [
                shard[key].split(weight_splits, dim=0) for shard in all_state_dict
            ]
            q_parts, k_parts, v_parts = zip(*per_shard_qkv)
            merged[key] = torch.cat([*q_parts, *k_parts, *v_parts], dim=0)
        elif "wo" in key or "w2" in key or "output.weight" in key:
            merged[key] = torch.cat([shard[key] for shard in all_state_dict], dim=1)
        elif "norm" in key or "tok_embeddings" in key:
            merged[key] = first_shard[key]
        else:
            merged[key] = torch.cat([shard[key] for shard in all_state_dict], dim=0)

        # free the memory
        for shard in all_state_dict:
            del shard[key]
    return merged
def split_tp_checkpoint(
    full_state_dict: dict,
    split_factor: int,
    split_idx: int,
):
    """Return slice ``split_idx`` of ``split_factor`` of a state dict.

    Keys are dispatched on their name:

    * ``wqkv``: split into q/k/v along dim 0, each part is sliced, then the
      slices are stacked again as q, k, v. The head counts come from a fixed
      table keyed by ``size(1)``.
    * ``wo``, ``w2``, ``output.weight``: sliced along dim 1.
    * ``norm``, ``tok_embeddings``: kept whole.
    * anything else: sliced along dim 0.

    ``full_state_dict`` is emptied key by key as the split proceeds.

    Raises:
        ValueError: If a ``wqkv`` tensor has a ``size(1)`` missing from the table.
    """
    # size(1) of the wqkv weight -> (n_heads, n_local_heads)
    heads_by_width = {
        4096: (32, 32),
        5120: (40, 40),
        6656: (52, 52),
        8192: (64, 8),
    }

    def take_slice(tensor, dim):
        # The requested chunk of ``tensor`` along ``dim``.
        return torch.tensor_split(tensor, split_factor, dim=dim)[split_idx]

    sliced = {}
    for key in list(full_state_dict.keys()):
        weight = full_state_dict[key]
        if "wqkv" in key:
            if weight.size(1) not in heads_by_width:
                raise ValueError(
                    f"Invalid size for {key}: {full_state_dict[key].size(1)}"
                )
            n_heads, n_local_heads = heads_by_width[weight.size(1)]
            head_dim = weight.size(0) // (n_heads + n_local_heads * 2)
            weight_splits = [
                head_dim * n_heads,
                head_dim * n_local_heads,
                head_dim * n_local_heads,
            ]
            sliced[key] = torch.cat(
                [take_slice(part, 0) for part in weight.split(weight_splits, dim=0)],
                dim=0,
            )
        elif "wo" in key or "w2" in key or "output.weight" in key:
            sliced[key] = take_slice(weight, 1)
        elif "norm" in key or "tok_embeddings" in key:
            sliced[key] = weight
        else:
            sliced[key] = take_slice(weight, 0)

        # free the memory
        del full_state_dict[key]
    return sliced
def get_checkpoint_path(
    checkpoint_dir: str,
    epoch: Optional[int],
    global_step: int,
    tp_rank: int,
    prefix: str,
) -> str:
    """Build the checkpoint file path for one rank.

    Args:
        checkpoint_dir (str): Directory the checkpoint lives in.
        epoch (int): Epoch to encode in the name; omitted when ``None``.
        global_step (int): Global step to encode in the name.
        tp_rank (int): Rank written into the ``_rank_{tp_rank}`` suffix.
        prefix (str): Prefix placed in front of the file name.

    Returns:
        str: ``{checkpoint_dir}/{prefix}[epoch_{epoch}_]step_{global_step}_rank_{tp_rank}.pt``.
    """
    if epoch is not None:
        return f"{checkpoint_dir}/{prefix}epoch_{epoch}_step_{global_step}_rank_{tp_rank}.pt"
    return f"{checkpoint_dir}/{prefix}step_{global_step}_rank_{tp_rank}.pt"
def get_latest_checkpoint_path(
    checkpoint_dir: str, prefix: Optional[str] = None
) -> tuple:
    """Resolve the checkpoint named by the ``{prefix}last_checkpoint`` pointer file.

    The epoch and step are parsed from the checkpoint file name only, so
    ``epoch_``/``step_``/``_rank_0`` fragments in ``checkpoint_dir`` are
    never matched or rewritten.

    Args:
        checkpoint_dir (str): Directory holding the checkpoints. ``None`` is
            accepted and treated as "no checkpoint".
        prefix (str): The prefix of the checkpoint and of the pointer file.

    Returns:
        tuple: ``(path, epoch, global_step)``. ``path`` targets rank 0, or
        the current model-parallel rank when ``torch.distributed`` is
        initialized. ``epoch`` is ``None`` when the name has no ``epoch_``
        part. ``(None, 0, 0)`` is returned when there is no pointer file.

    Raises:
        ValueError: If the file name has ``epoch_`` without a number, or no
            ``step_<number>`` part.
    """
    if prefix is None:
        prefix = ""

    # Callers pass None when no checkpoint directory is configured.
    if checkpoint_dir is None:
        return None, 0, 0

    pointer_path = f"{checkpoint_dir}/{prefix}last_checkpoint"
    if not os.path.exists(pointer_path):
        return None, 0, 0

    with open(pointer_path, "r") as f:
        file_name = f.read().strip().split("/")[-1]

    # The pointer stores "_rank*"; turn it back into a concrete rank suffix.
    file_name = file_name.replace("_rank*", "_rank_0")
    if dist.is_initialized():
        rank = get_model_parallel_rank()
        file_name = file_name.replace("_rank_0", f"_rank_{rank}")

    last_checkpoint_file = os.path.join(checkpoint_dir, file_name)

    epoch = None
    if "epoch_" in file_name:
        epoch_match = re.search(r"epoch_(\d+)", file_name)
        if epoch_match is None:
            raise ValueError(f"Invalid checkpoint filename: {last_checkpoint_file}")
        epoch = int(epoch_match.group(1))

    # A name without a step cannot be resumed from; fail with a clear error
    # rather than returning an unbound value.
    step_match = re.search(r"step_(\d+)", file_name)
    if step_match is None:
        raise ValueError(f"Invalid checkpoint filename: {last_checkpoint_file}")
    global_step = int(step_match.group(1))

    return last_checkpoint_file, epoch, global_step
def checkpoint_hook(
    args: Arguments,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    epoch: int,
    global_step: int,
    epoch_length: int,
    use_fsdp: bool = False,
    trainable_param_names: Optional[list] = None,
    prefix: str = "",
    metrics: Optional[dict] = None,
):
    """Save a checkpoint when ``args.save_strategy`` says one is due.

    Args:
        args: Run configuration; reads ``save_strategy``, ``save_dir``,
            ``save_steps``, ``save_only_model``, ``save_total_limit`` and
            ``compute_dtype``.
        model: The model to save.
        optimizer: The optimizer to save.
        scheduler: The scheduler to save.
        epoch: Current epoch, used in the file name and stored in the checkpoint.
        global_step: Current global step; step 0 never triggers a save on
            its own.
        epoch_length: Number of steps per epoch, used by the ``"epoch"`` strategy.
        use_fsdp: Forwarded to ``save_checkpoint``.
        trainable_param_names: Forwarded to ``save_checkpoint``.
        prefix: File-name prefix of the checkpoint.
        metrics: Optional metrics forwarded to ``save_checkpoint``.

    Raises:
        ValueError: If ``args.save_strategy`` is not a known strategy.
    """
    strategy = args.save_strategy
    if strategy == "no":
        return

    # The model-parallel rank names the shard file; 0 outside distributed runs.
    rank = get_model_parallel_rank() if dist.is_initialized() else 0

    # make sure the checkpoint dir exists
    os.makedirs(args.save_dir, exist_ok=True)

    if strategy == "epoch":
        should_save = global_step % epoch_length == 0 and global_step != 0
    elif strategy == "steps":
        # NOTE(review): metrics-triggered saving is assumed to belong to the
        # "steps" strategy only -- confirm against the original layout.
        on_schedule = global_step % args.save_steps == 0 and global_step != 0
        should_save = on_schedule or metrics is not None
    else:
        raise ValueError(f"Invalid save strategy: {args.save_strategy}")

    if not should_save:
        return

    checkpoint_path = get_checkpoint_path(
        args.save_dir, epoch, global_step, rank, prefix
    )
    save_checkpoint(
        checkpoint_path,
        model,
        optimizer,
        scheduler,
        epoch,
        global_step,
        rank,
        args.save_only_model,
        args.save_total_limit,
        use_fsdp,
        trainable_param_names,
        prefix,
        metrics,
        args.compute_dtype,
    )
def load_model_from_from_ckpt(
    checkpoint_path: Path,
    sft_checkpoint_path: Optional[Path],
    device: torch.device,
    precision: torch.dtype,
    use_tp: bool,
    requires_grad: bool,
    skip_init: bool = False,
    sequence_parallel: bool = False,
    vocab_parallel: bool = False,
):
    """Build a ``Transformer`` from a base checkpoint and optionally overlay SFT weights.

    Args:
        checkpoint_path: Base weight file; the model name is taken from the
            name of its parent directory.
        sft_checkpoint_path: Directory searched for the latest fine-tuned
            checkpoint (ignored when ``skip_init`` is True).
        device: Device the model is moved to.
        precision: Dtype the model is cast to.
        use_tp: Apply tensor parallelism after the base weights are loaded.
        requires_grad: Sets ``requires_grad`` on the model and selects
            train/eval mode.
        skip_init: Skip the search for a fine-tuned checkpoint.
        sequence_parallel: Forwarded to ``apply_tp``.
        vocab_parallel: Forwarded to ``Transformer.from_name``.

    Returns:
        The model after ``model.train(requires_grad)``.

    Raises:
        NotImplementedError: If the checkpoint path contains ``int8`` or ``int4``.
    """
    # The module is created on the meta device; real tensors arrive through
    # load_state_dict(assign=True) below.
    with torch.device("meta"):
        model = Transformer.from_name(
            checkpoint_path.parent.name,
            freeze_tok_embeddings=True,
            freeze_output=True,
            freeze_norm=True,
            vocab_parallel=vocab_parallel,
        )

    base_path = str(checkpoint_path)
    if "int8" in base_path:
        raise NotImplementedError("int8 quantization cannot be used for finetuning!")
    if "int4" in base_path:
        raise NotImplementedError("int4 quantization cannot be used for finetuning!")

    base_weights = torch.load(base_path, mmap=True, weights_only=True)
    model.load_state_dict(base_weights, assign=True)

    if use_tp:
        rank0_print("Applying tensor parallel to model ...")
        apply_tp(
            model, requires_grad=requires_grad, sequence_parallel=sequence_parallel
        )

    if not skip_init:
        sft_checkpoint_path, _, _ = get_latest_checkpoint_path(sft_checkpoint_path)
        if sft_checkpoint_path is None:
            rank0_print(
                "Warning: no sft checkpoint found, using base checkpoint."
                " (OK for unwrapped policy / resuming training / train from scratch)."
            )
        else:
            shown_path = sft_checkpoint_path.replace("_rank_0", "_rank_*")
            rank0_print(f"Loading sft model from {shown_path} ...")
            load_inference_checkpoint(sft_checkpoint_path, model)

    model.requires_grad_(requires_grad)
    model = model.to(device=device, dtype=precision)
    return model.train(requires_grad)
def load_reward_model_from_ckpt(
    checkpoint_path: Path,
    rm_checkpoint_path: Optional[Path],
    device: torch.device,
    precision: torch.dtype,
    use_tp: bool,
    requires_grad: bool,
    skip_init: bool = False,
    sequence_parallel: bool = False,
    vocab_parallel: bool = False,
):
    """Build a ``RewardModel`` from a base checkpoint and optionally overlay reward-model weights.

    Args:
        checkpoint_path: Base weight file for the backbone; the model name is
            taken from the name of its parent directory.
        rm_checkpoint_path: Directory searched for the latest reward-model
            checkpoint (ignored when ``skip_init`` is True).
        device: Device the model is moved to.
        precision: Dtype the model is cast to.
        use_tp: Apply tensor parallelism to the backbone and the reward head.
        requires_grad: Sets ``requires_grad`` on the model and selects
            train/eval mode.
        skip_init: Skip the search for a reward-model checkpoint.
        sequence_parallel: Forwarded to ``apply_tp``.
        vocab_parallel: Forwarded to ``RewardModel.from_name``.

    Returns:
        The model after ``model.train(requires_grad)``.

    Raises:
        NotImplementedError: If the checkpoint path contains ``int8`` or ``int4``.
    """
    # The module is created on the meta device; real tensors arrive through
    # load_state_dict(assign=True) below.
    with torch.device("meta"):
        model = RewardModel.from_name(
            checkpoint_path.parent.name,
            freeze_tok_embeddings=True,
            freeze_norm=True,
            vocab_parallel=vocab_parallel,
        )

    base_path = str(checkpoint_path)
    if "int8" in base_path:
        raise NotImplementedError("int8 quantization cannot be used for finetuning!")
    if "int4" in base_path:
        raise NotImplementedError("int4 quantization cannot be used for finetuning!")

    base_weights = torch.load(base_path, mmap=True, weights_only=True)
    backbone = model.backbone_model
    backbone.load_state_dict(base_weights, assign=True)

    if use_tp:
        rank0_print("Applying tensor parallel to model ...")
        apply_tp(
            backbone,
            requires_grad=requires_grad,
            sequence_parallel=sequence_parallel,
        )

    # The reward head is attached after the base weights are loaded.
    apply_reward_modeling_head(backbone, requires_grad=requires_grad)

    if use_tp:
        rank0_print("Applying tensor parallel to reward head ...")
        apply_reward_head_tp(backbone, requires_grad=requires_grad)

    if not skip_init:
        rm_checkpoint_path, _, _ = get_latest_checkpoint_path(rm_checkpoint_path)
        if rm_checkpoint_path is None:
            rank0_print("Warning: no rm checkpoint found, using base checkpoint.")
        else:
            shown_path = rm_checkpoint_path.replace("_rank_0", "_rank_*")
            rank0_print(f"Loading reward model from {shown_path} ...")
            load_inference_checkpoint(rm_checkpoint_path, model)
            rank0_print("Reward head", model.backbone_model.output.weight)

    model.requires_grad_(requires_grad)
    model = model.to(device=device, dtype=precision)
    return model.train(requires_grad)
def load_reward_model_from_sft_ckpt(
    checkpoint_path: Path,
    sft_checkpoint_path: Optional[Path],
    device: torch.device,
    precision: torch.dtype,
    use_tp: bool,
    requires_grad: bool,
    reward_head_init_scheme: str = "zeros",
    sequence_parallel: bool = False,
    vocab_parallel: bool = False,
):
    """Build a ``RewardModel`` whose backbone starts from SFT weights and whose head is fresh.

    The backbone is loaded from the base checkpoint, overlaid with the latest
    SFT checkpoint when one is found, and only then given a newly initialised
    reward head.

    Args:
        checkpoint_path: Base weight file for the backbone; the model name is
            taken from the name of its parent directory.
        sft_checkpoint_path: Directory searched for the latest SFT checkpoint.
        device: Device the model is moved to.
        precision: Dtype the model is cast to.
        use_tp: Apply tensor parallelism to the backbone and the reward head.
        requires_grad: Sets ``requires_grad`` on the model and selects
            train/eval mode.
        reward_head_init_scheme: Initialisation scheme forwarded to
            ``apply_reward_modeling_head``.
        sequence_parallel: Forwarded to ``apply_tp``.
        vocab_parallel: Forwarded to ``RewardModel.from_name``.

    Returns:
        The model after ``model.train(requires_grad)``.

    Raises:
        NotImplementedError: If the checkpoint path contains ``int8`` or ``int4``.
    """
    # The module is created on the meta device; real tensors arrive through
    # load_state_dict(assign=True) below.
    with torch.device("meta"):
        model = RewardModel.from_name(
            checkpoint_path.parent.name,
            freeze_tok_embeddings=True,
            freeze_norm=True,
            vocab_parallel=vocab_parallel,
        )

    if "int8" in str(checkpoint_path):
        raise NotImplementedError("int8 quantization cannot be used for finetuning!")
    if "int4" in str(checkpoint_path):
        raise NotImplementedError("int4 quantization cannot be used for finetuning!")

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.backbone_model.load_state_dict(checkpoint, assign=True)

    # Progress messages go through rank0_print, as in the sibling loaders,
    # so they are not repeated by every process.
    if use_tp:
        rank0_print("Applying tensor parallel to model ...")
        apply_tp(
            model.backbone_model,
            requires_grad=True,
            sequence_parallel=sequence_parallel,
        )

    sft_checkpoint_path, _, _ = get_latest_checkpoint_path(sft_checkpoint_path)
    if sft_checkpoint_path is not None:
        rank0_print("Loading sft model ...")
        load_inference_checkpoint(sft_checkpoint_path, model.backbone_model)
    else:
        # Make the fallback visible instead of continuing silently.
        rank0_print("Warning: no sft checkpoint found, using base checkpoint.")

    # The keyword spelling `init_sceheme` is kept exactly as the existing
    # call passes it.
    apply_reward_modeling_head(
        model.backbone_model,
        requires_grad=requires_grad,
        init_sceheme=reward_head_init_scheme,
    )

    if use_tp:
        rank0_print("Applying tensor parallel to reward head ...")
        apply_reward_head_tp(model.backbone_model, requires_grad=requires_grad)

    model.requires_grad_(requires_grad)
    model = model.to(device=device, dtype=precision)
    return model.train(requires_grad)