-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain_lora.py
348 lines (293 loc) · 12.2 KB
/
train_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import os
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from tqdm import tqdm, trange
from functools import partial
import jax
# When JAX reports a GPU backend, configure XLA through the XLA_FLAGS
# environment variable (Triton GEMM off, async collectives, latency-hiding
# scheduler, collective thresholds, fp-conversion simplification).
# NOTE: this assignment replaces any XLA_FLAGS value already present in the
# environment rather than appending to it.
# NOTE(review): the flags are written after jax.default_backend() has already
# been called -- confirm XLA still reads XLA_FLAGS at that point.
if jax.default_backend() == 'gpu':
    os.environ['XLA_FLAGS'] = (
        '--xla_gpu_triton_gemm_any=false '
        '--xla_gpu_enable_async_collectives=true '
        '--xla_gpu_enable_async_all_gather=true '
        '--xla_gpu_enable_async_reduce_scatter=true '
        '--xla_gpu_enable_latency_hiding_scheduler=true '
        '--xla_gpu_enable_highest_priority_async_stream=true '
        '--xla_gpu_collective_permute_decomposer_threshold=1024 '
        '--xla_gpu_all_reduce_combine_threshold_bytes=51200 '
        '--xla_gpu_simplify_all_fp_conversions=true '
    )
import jax.numpy as jnp
import optax
import flax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as PS
import datasets
from transformers import AutoTokenizer
from simple_parsing import ArgumentParser
from simple_parsing.helpers import list_field
import magix
import magix.models
import magix.lora
from magix import (
get_chckpoint_manager,
load_model_hub,
)
def apply_chat_template(turns: Iterable[Dict[str, str]], eos_token: Optional[str] = None):
    """Render a conversation as a single role-tagged string.

    Every turn becomes the role tag, a newline, the turn content and then
    ``eos_token``; the rendered turns are joined with a newline.

    Args:
        turns: Iterable of dicts, each with a ``'role'`` key (``'user'``,
            ``'assistant'`` or ``'system'``) and a ``'content'`` key.
        eos_token: Text appended after each turn's content. ``None`` (the
            default) appends nothing.

    Returns:
        The formatted conversation; an empty string when ``turns`` is empty.

    Raises:
        KeyError: If a turn lacks ``'role'``/``'content'`` or its role is not
            one of the known roles.
    """
    role_tags = {
        'user': '<|user|>',
        'assistant': '<|assistant|>',
        'system': '<|system|>',
    }
    # Interpolating None into the f-string would emit the literal text "None"
    # after every turn, so a missing EOS token is treated as an empty suffix.
    suffix = '' if eos_token is None else eos_token

    def _format(turn):
        role, content = turn['role'], turn['content']
        return f"{role_tags[role]}\n{content}{suffix}"

    return '\n'.join(_format(turn) for turn in turns)
class TrainDataset:
    """Indexable text dataset that tokenizes the requested rows on demand.

    ``train_data`` must support ``len()`` and indexing with a collection of
    indices, returning a mapping from field name to a list of values.
    """

    def __init__(
        self,
        train_data,
        tokenizer,
        field_name: str = 'text',
        max_len: int = 1024,
        use_chat_template: bool = False,
    ):
        self.data, self.tokenizer = train_data, tokenizer
        self.field_name, self.max_len = field_name, max_len
        self.use_chat_template = use_chat_template

    def __len__(self):
        return len(self.data)

    def get_batch(self, indices):
        """Tokenize the rows selected by ``indices`` and return a plain dict."""
        texts = self.data[indices][self.field_name]
        if self.use_chat_template:
            # Each entry is a list of turns; flatten it into one string.
            texts = [
                apply_chat_template(conversation, eos_token=self.tokenizer.eos_token)
                for conversation in texts
            ]
        # One token more than max_len is requested so that the caller can
        # shift inputs against targets by one position.
        encoded = self.tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=self.max_len + 1,
            return_tensors='np',
        )
        return dict(encoded)
class Batches:
    """Per-epoch batch index table over a dataset.

    The dataset indices (optionally permuted with ``rng``) are cut into
    ``len(dataset) // batch_size`` full batches; leftover examples that do
    not fill a batch are dropped. Calling the object with a step number
    returns ``dataset.get_batch`` for that step's row of indices.
    """

    def __init__(
        self,
        rng: jax.random.PRNGKey,
        dataset: TrainDataset,
        batch_size: int,
        shuffle: bool = False
    ):
        num_examples = len(dataset)
        num_batches = num_examples // batch_size
        order = (
            jax.random.permutation(rng, num_examples)
            if shuffle
            else jnp.arange(num_examples)
        )
        # Keep only as many indices as fit into whole batches.
        usable = order[: num_batches * batch_size]
        self.dataset = dataset
        self.batch_idx = usable.reshape((num_batches, batch_size))

    def __call__(self, step):
        return self.dataset.get_batch(self.batch_idx[step])
def decay_mask_fn(params):
    """Return a pytree of booleans with the same structure as ``params``.

    A leaf is ``False`` when its last path component is ``"bias"`` or when
    its second-to-last path component contains ``'layernorm'``; every other
    leaf is ``True``. In this file the result is handed to ``optax.adamw``
    as its ``mask`` argument.
    """
    def _is_excluded(path):
        return path[-1] == "bias" or 'layernorm' in path[-2]

    mask = {}
    for path in flax.traverse_util.flatten_dict(params):
        mask[path] = not _is_excluded(path)
    return flax.traverse_util.unflatten_dict(mask)
@dataclass
class TrainArgs:
    """Command-line options controlling data, optimization and LoRA setup."""
    # Training data source: a path ending in '.jsonl' is loaded with the
    # 'json' loader, anything else is passed to datasets.load_dataset as-is.
    train_file: str = None
    # Second positional argument to datasets.load_dataset (non-.jsonl only).
    train_data_config: str = None
    # Name of the dataset field that holds the training text / chat turns.
    train_data_field: str = 'text'
    # Split selected for non-.jsonl sources; .jsonl files always use 'train'.
    split: str = 'train'
    # When set, each example is a list of turns rendered by
    # apply_chat_template, and the tokenizer is created with add_eos_token=False.
    use_chat_template: bool = False
    # Directory handed to the checkpoint manager.
    checkpoint_dir: str = None
    # Sequence length; examples are tokenized to max_length + 1 tokens.
    max_length: int = 1024
    num_epochs: int = 1
    batch_size: int = 16
    # The next three options are not referenced anywhere in main().
    num_target_passages: int = 16
    query_num_chunks: int = 4
    passage_num_chunks: int = 8
    # Peak value of the warmup + cosine-decay learning-rate schedule.
    learning_rate: float = 2e-6
    # AdamW hyper-parameters.
    weight_decay: float = 0.0001
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    # Threshold for optax.clip_by_global_norm.
    max_grad_norm: float = 1.0
    # Passed to the checkpoint manager (presumably the save interval in
    # steps -- confirm against magix.get_chckpoint_manager).
    save_steps: int = 200
    # Seed for the root JAX random key.
    seed: int = 42
    # LoRA scaling and rank, passed to magix.lora.Lora; the rank is applied
    # to parameters matching 'layers/.*/kernel'.
    lora_alpha: float = 32.0
    lora_rank: int = 8
@dataclass
class ModelArgs:
    """Command-line options selecting the base model, tokenizer and mesh."""
    # Key looked up in magix.models.CAUSAL_LM_MODEL_MAPPING.
    model_type: str = 'llama'
    # Model identifier passed to the magix model loaders.
    model_name: str = None
    # Identifier passed to AutoTokenizer.from_pretrained.
    tokenizer_name: str = None
    # If set and the path exists, the model is loaded from this local
    # directory instead of the hub.
    model_cache_dir: str = None
    # Shape passed to magix.create_device_mesh.
    mesh_shape: List[int] = list_field(-1, 1)
    # Forwarded as `half=` to load_model_hub.
    bf16_model_weights: bool = False
def main():
    """Fine-tune a causal language model with LoRA adapters.

    Parses ``TrainArgs`` and ``ModelArgs`` from the command line, loads and
    tokenizes the training data, builds an AdamW optimizer with gradient
    clipping and a warmup + cosine learning-rate schedule, loads the base
    model, creates (or restores from the latest checkpoint) the LoRA
    parameters and optimizer state, and runs the training loop. Only the
    LoRA parameters receive gradients; the base parameters stay fixed.

    Raises:
        ValueError: If no training file / dataset name was given.
        NotImplementedError: If ``model_type`` is not a known causal LM type.
    """
    parser = ArgumentParser()
    parser.add_arguments(TrainArgs, dest="train_args")
    parser.add_arguments(ModelArgs, dest="model_args")
    args = parser.parse_args()
    train_args: TrainArgs = args.train_args
    model_args: ModelArgs = args.model_args

    # logger with date and time
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO
    )
    logger = logging.getLogger(__name__)

    # dataset setup
    if train_args.train_file is None:
        # Fail with a clear message instead of AttributeError on None.endswith.
        raise ValueError("train_file must be set to a .jsonl path or a dataset name")
    if train_args.train_file.endswith('.jsonl'):
        train_data = datasets.load_dataset('json', data_files=train_args.train_file)['train']
    else:
        train_data = datasets.load_dataset(
            train_args.train_file,
            train_args.train_data_config
        )[train_args.split]

    # With a chat template every turn already carries the EOS token, so the
    # tokenizer must not add another one.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name,
        add_eos_token=not train_args.use_chat_template,
        use_fast=True, padding_side='right', legacy=False)
    tokenizer.pad_token = tokenizer.eos_token
    train_dataset = TrainDataset(
        train_data,
        tokenizer,
        train_args.train_data_field,
        train_args.max_length,
        train_args.use_chat_template,
    )

    # optimizer setup
    total_train_steps = len(train_dataset) // train_args.batch_size * train_args.num_epochs
    # NOTE(review): decay_steps is 0.9 * total steps; if optax counts the
    # warmup inside decay_steps, the final 10% of training runs at the end
    # value of the schedule -- confirm this is intended.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        0, train_args.learning_rate, int(total_train_steps*0.1), int(total_train_steps*0.9))
    optimizer = optax.adamw(
        lr_schedule,
        mask=decay_mask_fn,
        b1=train_args.adam_beta1,
        b2=train_args.adam_beta2,
        weight_decay=train_args.weight_decay,
    )
    optimizer = optax.chain(
        optax.clip_by_global_norm(train_args.max_grad_norm),
        optimizer
    )
    optimizer = optax.apply_if_finite(optimizer, 10)

    lora = magix.lora.Lora(
        alpha=train_args.lora_alpha,
        rules={
            'layers/.*/kernel': train_args.lora_rank,
        }
    )

    # initialize model parameters and optimizer state
    mesh = magix.create_device_mesh(model_args.mesh_shape)
    checkpoint_manager = get_chckpoint_manager(
        train_args.checkpoint_dir, train_args.save_steps, items=['lora', 'optimizer'])
    is_new_train = checkpoint_manager.latest_step() is None

    _model_cls = magix.models.CAUSAL_LM_MODEL_MAPPING.get(model_args.model_type, None)
    if _model_cls is None:
        raise NotImplementedError(f"Model type {model_args.model_type} is not implemented")
    sharding_config = _model_cls.partition_rules

    logger.info("Loading model from hub")
    if model_args.model_cache_dir and os.path.exists(model_args.model_cache_dir):
        model, params = magix.checkpoint_utils.load_model_local(
            _model_cls,
            model_args.model_cache_dir,
            sharding_config,
            mesh,
            model_name=model_args.model_name,
        )
    else:
        model, params = load_model_hub(
            _model_cls, model_args.model_name, sharding_config, mesh,
            half=model_args.bf16_model_weights)
        # magix.checkpoint_utils.save_model_local(params, model_args.model_cache_dir)

    rng = jax.random.key(train_args.seed)
    dropout_rng, data_rng, lora_rng = jax.random.split(rng, 3)

    def create_lora_and_opt_states(rng, params):
        lora_state = lora.init_params(rng, params)
        opt_state = optimizer.init(lora_state)
        return lora_state, opt_state

    # Evaluate shapes only (no allocation) to derive the shardings first.
    lora_state_shapes, opt_shapes = jax.eval_shape(create_lora_and_opt_states, lora_rng, params)
    lora_sharding = magix.lora.create_lora_sharding(sharding_config, mesh, lora_state_shapes)
    opt_sharding = magix.lora.create_lora_sharding(sharding_config, mesh, opt_shapes)

    if is_new_train:
        lora_state = jax.jit(lora.init_params, out_shardings=lora_sharding)(lora_rng, params)
        opt_state = jax.jit(optimizer.init, out_shardings=opt_sharding)(lora_state)
    else:
        loaded = magix.checkpoint_utils.load_by_sharding(
            checkpoint_manager,
            items=['lora', 'optimizer'],
            dummies=[lora_state_shapes, opt_shapes],
            shardings=[lora_sharding, opt_sharding]
        )
        lora_state, opt_state = loaded['lora'], loaded['optimizer']

    def train_step(params, lora_state, opt_state, batch, dropout_rng):
        def compute_loss(params, lora_state, batch, dropout_rng):
            params = lora.apply(params, lora_state)
            input_ids = batch['input_ids']
            # A position counts only if both the input token and the token
            # it has to predict are unmasked.
            attention_mask = jnp.logical_and(
                batch['attention_mask'][:, :-1], batch['attention_mask'][:, 1:]).astype('bool')
            logits = model(
                input_ids=input_ids[:, :-1], attention_mask=attention_mask,
                params=params, train=True, dropout_rng=dropout_rng)[0]
            target_ids = input_ids[:, 1:]
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, target_ids)
            # Mean of the per-token loss over the unmasked positions.
            loss = loss * attention_mask / attention_mask.sum()
            loss = loss.sum()
            return loss

        # argnums=1: differentiate with respect to the LoRA state only.
        loss, grads = jax.value_and_grad(compute_loss, argnums=1)(
            params, lora_state, batch, dropout_rng)
        metrics = {"loss": loss}

        updates, new_opt_state = optimizer.update(grads, opt_state, lora_state)
        new_lora_state = optax.apply_updates(lora_state, updates)

        return new_lora_state, new_opt_state, metrics

    # Donated positions 1, 2, 3 are lora_state, opt_state and batch; params
    # (position 0) is not donated.
    p_train_step = jax.jit(
        train_step,
        donate_argnums=(1, 2, 3),
        out_shardings=(
            magix.item_sharding(lora_state),
            magix.item_sharding(opt_state),
            None
        )
    )
    p_train_step = partial(p_train_step, params)  # safeguard params in a closure

    # train loop
    latest_step = checkpoint_manager.latest_step()
    if latest_step is None:
        latest_step = -1

    train_metrics = []

    def combine_metrics(list_of_dicts):
        return {key: jnp.array([d[key] for d in list_of_dicts]) for key in list_of_dicts[0]}

    epochs = tqdm(range(train_args.num_epochs), desc=f"Epoch ... (1/{train_args.num_epochs})", position=0)

    logger.info("Starting training loop...")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", train_args.num_epochs)
    logger.info("  Instantaneous batch size = %d", train_args.batch_size)

    with mesh:
        for epoch in epochs:
            # Create sampling rng
            input_rng = jax.random.fold_in(data_rng, epoch)
            batch_loader = Batches(
                input_rng, train_dataset, train_args.batch_size, shuffle=True)
            steps_per_epoch = len(train_dataset) // train_args.batch_size

            # train
            for step in trange(steps_per_epoch):
                cur_step = epoch * (len(train_dataset) // train_args.batch_size) + step
                # Steps up to and including the checkpointed one are done.
                if cur_step <= latest_step:
                    continue
                # First step actually executed after restoring a checkpoint.
                if latest_step >= 0 and cur_step == latest_step + 1:
                    logger.info('Resuming training from step %d', cur_step)

                batch = batch_loader(step)

                dropout_rngs = jax.random.fold_in(dropout_rng, cur_step)
                lora_state, opt_state, metrics = p_train_step(
                    lora_state, opt_state, batch, dropout_rngs)

                is_last_step = (cur_step + 1) == total_train_steps
                checkpoint_manager.save(
                    cur_step,
                    items={'lora': lora_state, 'optimizer': opt_state},
                    force=is_last_step
                )
                train_metrics.append(metrics)

                if cur_step % 100 == 0 and cur_step > 0:
                    print(
                        f"Step... ({cur_step} | Loss: {combine_metrics(train_metrics)['loss'].mean()}, Learning Rate: {lr_schedule(cur_step)})",
                        flush=True,
                    )
                    train_metrics = []

            epochs.write(
                f"Epoch... ({epoch + 1}/{train_args.num_epochs})"
            )

    checkpoint_manager.wait_until_finished()
# Run the training entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()