train.py
# Copyright 2024 Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
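"""Entry point for training: builds the policy model (and, for dpo/ipo/ct losses,
a reference model), then launches one or more trainer workers.

Illustrative invocation (the keys match config fields read below, but the values
are placeholders; the real keys and defaults live in config/config.yaml):

    python -u train.py loss=dpo exp_name=my_run trainer=FSDPTrainer
"""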
import torch
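# allow TF32 matmuls on Ampere+ GPUs (faster, with slightly reduced precision)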
torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn as nn
import transformers
from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed, get_open_port
import os
import hydra
import torch.multiprocessing as mp
from omegaconf import OmegaConf, DictConfig
import trainers
import wandb
import json
import socket
from typing import Optional, Set
import resource
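
# lets Hydra configs compute the run directory via a ${get_local_run_dir:...} interpolation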
OmegaConf.register_new_resolver("get_local_run_dir", lambda exp_name, local_dirs: get_local_run_dir(exp_name, local_dirs))


def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module, reference_model: Optional[nn.Module] = None):
    """Main function for each worker process (may be only 1 for BasicTrainer/TensorParallelTrainer)."""
    if 'FSDP' in config.trainer:
        init_distributed(rank, world_size, port=config.fsdp_port)

    if config.debug:
        # stub out wandb in debug mode so nothing gets logged remotely
        wandb.init = lambda *args, **kwargs: None
        wandb.log = lambda *args, **kwargs: None

    if rank == 0 and config.wandb.enabled:
        os.environ['WANDB_CACHE_DIR'] = get_local_dir(config.local_dirs)
        wandb.init(
            entity=config.wandb.entity,
            project=config.wandb.project,
            config=OmegaConf.to_container(config),
            dir=get_local_dir(config.local_dirs),
            name=config.exp_name,
        )

    TrainerClass = getattr(trainers, config.trainer)
    print(f'Creating trainer on process {rank} with world size {world_size}')
    trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size)

    trainer.train()
    trainer.save()


@hydra.main(version_base=None, config_path="config", config_name="config")
def main(config: DictConfig):
    """Main entry point for training. Validates config, creates/initializes model(s), and kicks off worker process(es)."""
    # Resolve hydra references, e.g. so we don't re-compute the run directory
    OmegaConf.resolve(config)

    missing_keys: Set[str] = OmegaConf.missing_keys(config)
    if missing_keys:
        raise ValueError(f"Got missing keys in config:\n{missing_keys}")

    if config.eval_every % config.batch_size != 0:
        print('WARNING: eval_every must be divisible by batch_size')
        print('Setting eval_every to', config.eval_every - config.eval_every % config.batch_size)
        config.eval_every = config.eval_every - config.eval_every % config.batch_size

    if 'FSDP' in config.trainer and config.fsdp_port is None:
        free_port = get_open_port()
        print('no FSDP port specified; using open port for FSDP:', free_port)
        config.fsdp_port = free_port

    print(OmegaConf.to_yaml(config))

    config_path = os.path.join(config.local_run_dir, 'config.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    print('=' * 80)
    print(f'Writing to {socket.gethostname()}:{config.local_run_dir}')
    print('=' * 80)
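
    # point the Hugging Face cache (which respects XDG_CACHE_HOME) at local storage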
    os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs)

    print('building policy')
    # BasicTrainer shards the model across visible GPUs via device_map='balanced';
    # the other trainers handle device placement themselves
    model_kwargs = {'device_map': 'balanced'} if config.trainer == 'BasicTrainer' else {}
    policy_dtype = getattr(torch, config.model.policy_dtype)
    policy = transformers.AutoModelForCausalLM.from_pretrained(
        config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=policy_dtype, **model_kwargs)
    disable_dropout(policy)

    if config.loss.name in {'dpo', 'ipo', 'ct'}:
        # these losses compare policy log-probs against a reference model
        print('building reference model')
        reference_model_dtype = getattr(torch, config.model.reference_dtype)
        reference_model = transformers.AutoModelForCausalLM.from_pretrained(
            config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=reference_model_dtype, **model_kwargs)
        disable_dropout(reference_model)
    else:
        reference_model = None

    if config.model.archive is not None:
        # warm-start from saved weights; the archive here is a raw state dict
        # (for trainer-saved archives, load state_dict['state'] instead, per the commented lines)
        state_dict = torch.load(config.model.archive, map_location='cpu')
        print(state_dict.keys())
        # step, metrics = state_dict['step_idx'], state_dict['metrics']
        # print(f'loading pre-trained weights at step {step} from {config.model.archive} with metrics {json.dumps(metrics, indent=2)}')
        policy.load_state_dict(state_dict)
        # policy.load_state_dict(state_dict['state'])
        if config.loss.name in {'dpo', 'ipo'}:
            reference_model.load_state_dict(state_dict)
            # reference_model.load_state_dict(state_dict['state'])
        print('loaded pre-trained weights')

    if 'FSDP' in config.trainer:
        world_size = torch.cuda.device_count()
        print('starting', world_size, 'processes for FSDP training')

        # raise the soft open-file limit to the hard limit before spawning workers
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
        print(f'setting RLIMIT_NOFILE soft limit to {hard} from {soft}')

        # mp.spawn passes the worker rank as the first argument to worker_main
        mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
    else:
        print('starting single-process worker')
        worker_main(0, 1, config, policy, reference_model)


if __name__ == '__main__':
    main()