-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrainer.py
120 lines (93 loc) · 4.53 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from dataset import Dataset
from evaluator import Evaluator
from model import DpoWrapper, TutorialLLM
class Trainer:
    """
    Trainer for the model.

    This class provides methods to pretrain, finetune, and align the model.
    Each phase creates its own AdamW optimizer, runs its training loop while
    reporting every step to the evaluator, and finally saves the whole model
    object with ``torch.save``.
    """

    def __init__(self, model: TutorialLLM, dataset: Dataset, evaluator: Evaluator, device: str) -> None:
        """
        Initialize the trainer with the model, dataset, evaluator, and device.

        Args:
            model: The model to be trained.
            dataset: The dataset to provide training data.
            evaluator: The evaluator to evaluate the model performance.
            device: The device to run the model on ('cpu' or 'cuda').
        """
        self.model = model
        self.dataset = dataset
        self.evaluator = evaluator
        # NOTE(review): no method of this class reads self.device; batches are
        # presumably already placed on the right device by the dataset — confirm.
        self.device = device

    @staticmethod
    def _optimize(optimizer: torch.optim.Optimizer, loss: torch.Tensor) -> None:
        """Backpropagate ``loss`` and apply a single optimizer update."""
        # set_to_none releases the old gradient tensors instead of zero-filling them.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    def pretrain(
        self,
        iterations: int,
        *,
        learning_rate: float = 1e-3,
        save_path: str = 'model_pretrain.pth',
    ) -> None:
        """
        Pretrain the model for a certain number of iterations.

        Each iteration draws one fresh batch of pretrain data and performs one
        optimizer update on it.

        Args:
            iterations: The number of iterations (optimizer steps) to run.
            learning_rate: AdamW learning rate (defaults to 1e-3).
            save_path: Where the pretrained model is saved with ``torch.save``.
        """
        # Reset the evaluator to clear the loss history
        self.evaluator.reset()
        # Make sure training-time layer behavior is active even if the caller
        # left the model in eval mode.
        self.model.train()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        for i in range(iterations):
            # Get a batch of pretrain data
            inputs, labels = self.dataset.get_batch_pretrain('train')
            # Forward pass and calculate the loss
            _, loss = self.model(inputs, labels)
            # Report the current loss to the evaluator
            self.evaluator.evaluate_pretrain(self.model, i, loss.item())
            # Backward pass and update the model
            self._optimize(optimizer, loss)
        print('Save the pretrained model...')
        torch.save(self.model, save_path)

    def finetune(
        self,
        epochs: int,
        *,
        learning_rate: float = 1e-3,
        save_path: str = 'model_finetune.pth',
    ) -> None:
        """
        Finetune the model for a certain number of epochs.

        Each epoch iterates over every batch yielded by the dataset's finetune
        generator and performs one optimizer update per batch.

        Args:
            epochs: The number of passes over the finetune data.
            learning_rate: AdamW learning rate (defaults to 1e-3).
            save_path: Where the finetuned model is saved with ``torch.save``.
        """
        # Make sure training-time layer behavior is active even if the caller
        # left the model in eval mode.
        self.model.train()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        for epoch in range(epochs):
            # Reset the evaluator to clear the loss history for each epoch
            self.evaluator.reset()
            for i, (inputs, labels) in enumerate(self.dataset.get_batch_generator_finetune('train')):
                # Forward pass and calculate the loss
                _, loss = self.model(inputs, labels)
                # Report the current loss to the evaluator
                self.evaluator.evaluate_finetune(self.model, epoch, i, loss.item())
                # Backward pass and update the model
                self._optimize(optimizer, loss)
        print('Save the finetuned model...')
        torch.save(self.model, save_path)

    def align(
        self,
        epochs: int,
        *,
        learning_rate: float = 1e-5,
        save_path: str = 'model_aligned.pth',
    ) -> None:
        """
        Align the model with our preference (DPO) for a certain number of epochs.

        Each epoch iterates over every (positive, negative) batch pair yielded
        by the dataset's alignment generator and performs one optimizer update
        per pair. Only the parameters of ``self.model`` are optimized.

        Args:
            epochs: The number of passes over the alignment data.
            learning_rate: AdamW learning rate (defaults to 1e-5, much smaller
                than the other phases).
            save_path: Where the aligned model is saved with ``torch.save``.
        """
        # The alignment needs a reference model for DPO, we use a DpoWrapper to manage the 2 models
        dpo_wrapper = DpoWrapper(self.model)
        # Switch only the trained model to training mode, and do it after the
        # wrapper is built so any reference copy keeps the incoming mode.
        self.model.train()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        for epoch in range(epochs):
            # Reset the evaluator to clear the loss history for each epoch
            self.evaluator.reset()
            batches = self.dataset.get_batch_generator_alignment('train')
            for i, (positive_inputs, positive_labels, negative_inputs, negative_labels) in enumerate(batches):
                loss, reward_margin = dpo_wrapper.forward(positive_inputs, positive_labels, negative_inputs, negative_labels)
                # The evaluator is called on every step; presumably it decides
                # internally how often to run a full evaluation.
                self.evaluator.evaluate_alignment(dpo_wrapper, epoch, i, loss.item(), reward_margin.item())
                # Backward pass and update the model
                self._optimize(optimizer, loss)
        print('Save the aligned model...')
        torch.save(self.model, save_path)