1
00:00:00,304 --> 00:00:01,285
(空气呼啸)
(air whooshing)
2
00:00:01,285 --> 00:00:02,345
(空气爆裂声)
(air popping)
3
00:00:02,345 --> 00:00:05,698
(空气呼啸)
(air whooshing)
4
00:00:05,698 --> 00:00:06,548
- Trainer API。
- The Trainer API.
5
00:00:08,070 --> 00:00:10,040
Transformers Library 提供了一个 Trainer API
The Transformers library provides a Trainer API
6
00:00:10,040 --> 00:00:13,320
让你能够轻松地微调 Transformer 模型
that allows you to easily fine-tune transformer models
7
00:00:13,320 --> 00:00:14,193
在你自己的数据集上。
on your own dataset.
8
00:00:15,150 --> 00:00:17,250
Trainer 类接受你的数据集,
The Trainer class takes your datasets,
9
00:00:17,250 --> 00:00:19,900
你的模型以及训练超参数
your model as well as the training hyperparameters
10
00:00:20,820 --> 00:00:23,310
并且可以在任何类型的设置上进行训练,
and can perform the training on any kind of setup,
11
00:00:23,310 --> 00:00:26,654
(CPU、GPU、多个 GPU、TPUs)
(CPU, GPU, multiple GPUs, TPUs).
12
00:00:26,654 --> 00:00:28,680
也可以在任意数据集上进行推理
It can also compute predictions
13
00:00:28,680 --> 00:00:31,710
如果你提供了指标
on any dataset and, if you provide metrics,
14
00:00:31,710 --> 00:00:33,813
就可以在任意数据集上评估你的模型。
evaluate your model on any dataset.
15
00:00:34,950 --> 00:00:36,930
Trainer 也可以负责最后的数据处理
It can also handle final data processing
16
00:00:36,930 --> 00:00:38,670
比如动态填充,
such as dynamic padding,
17
00:00:38,670 --> 00:00:40,377
只要你提供 tokenizer
as long as you provide the tokenizer
18
00:00:40,377 --> 00:00:42,693
或给定 data collator。
or a given data collator.
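Here is a rough plain-Python sketch of what a padding collator does with one batch. It is not the library's implementation. `pad_batch`, the pad id of 0 and the list-of-dicts input layout are assumptions made for this example. In Transformers itself you would typically use `DataCollatorWithPadding(tokenizer=tokenizer)` for this step.

```python
from typing import Dict, List


def pad_batch(features: List[Dict[str, List[int]]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: pad every example to the longest one in *this* batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": []}
    for f in features:
        ids = f["input_ids"]
        n_pad = max_len - len(ids)
        # Real tokens get attention mask 1, padding positions get 0.
        batch["input_ids"].append(ids + [pad_id] * n_pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
    return batch


batch = pad_batch([{"input_ids": [101, 7592, 102]}, {"input_ids": [101, 102]}])
print(batch["input_ids"])       # [[101, 7592, 102], [101, 102, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
```

The collator computes the target length separately for each batch. A batch of short sentence pairs is therefore padded only to its own longest pair, not to a global maximum.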
19
00:00:43,572 --> 00:00:45,900
我们将在 MRPC 数据集上尝试这个 API,
We will try this API on the MRPC dataset,
20
00:00:45,900 --> 00:00:48,492
因为该数据集相对较小且易于预处理。
since it's relatively small and easy to preprocess.
21
00:00:48,492 --> 00:00:49,325
正如我们在 Datasets 概述视频中看到的那样,
As we saw in the Datasets overview video,
22
00:00:49,325 --> 00:00:54,325
我们可以像这样对其进行预处理。
here is how we can preprocess it.
23
00:00:54,511 --> 00:00:57,030
我们在预处理过程中不进行填充,
We do not apply padding during the preprocessing,
24
00:00:57,030 --> 00:00:58,590
因为我们将进行动态填充
as we will use dynamic padding
25
00:00:58,590 --> 00:01:00,083
使用我们的 DataCollatorWithPadding。
with our DataCollatorWithPadding.
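You can see why padding is left to the collator by counting the pad tokens each approach adds. The sequence lengths, the batch size of 2 and the fixed length of 128 below are made-up numbers. Both helper functions are hypothetical.

```python
from typing import List


def pad_tokens_fixed(lengths: List[int], max_length: int) -> int:
    """Pad tokens added when every example is padded to one fixed length up front."""
    return sum(max_length - n for n in lengths)


def pad_tokens_dynamic(lengths: List[int], batch_size: int) -> int:
    """Pad tokens added when each batch is padded only to its own longest example."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        chunk = lengths[start:start + batch_size]
        total += sum(max(chunk) - n for n in chunk)
    return total


lengths = [12, 15, 40, 38, 9, 11]  # hypothetical tokenized pair lengths
print(pad_tokens_fixed(lengths, max_length=128))  # 643
print(pad_tokens_dynamic(lengths, batch_size=2))  # 7
```

The gap grows when examples of similar length end up in the same batch. It also grows when the fixed length is far above the typical sequence length.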
26
00:01:01,170 --> 00:01:02,790
请注意,我们不执行一些最终的数据处理步骤
Note that we don't do the final steps
27
00:01:02,790 --> 00:01:04,830
像是重命名列/删除列
of renaming/removing columns
28
00:01:04,830 --> 00:01:06,873
或将格式设置为 torch 张量。
or setting the format to torch tensors.
29
00:01:07,710 --> 00:01:10,560
Trainer 会自动为我们做这一切
The Trainer will do all of this automatically for us
30
00:01:10,560 --> 00:01:12,633
通过分析模型签名。
by analyzing the model signature.
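The column clean-up depends on the names of the arguments that the model's `forward` method accepts. The sketch below shows a toy version of the idea using `inspect.signature`. `toy_forward` and `keep_model_columns` are hypothetical names. Keeping label-like columns such as `label` is an assumption meant to reflect the Trainer's behaviour, and the real logic covers more cases.

```python
import inspect
from typing import List, Optional


def toy_forward(input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
    """Hypothetical stand-in for a model's forward method."""
    return None


def keep_model_columns(columns: List[str], forward, extra: Optional[List[str]] = None) -> List[str]:
    """Return the dataset columns whose names match a parameter of `forward`."""
    accepted = set(inspect.signature(forward).parameters)
    # Assumption: label-like columns are kept even though the signature says `labels`.
    label_like = ["label", "label_ids"] if extra is None else extra
    accepted.update(label_like)
    return [c for c in columns if c in accepted]


columns = ["sentence1", "sentence2", "label", "idx", "input_ids", "token_type_ids", "attention_mask"]
print(keep_model_columns(columns, toy_forward))
# ['label', 'input_ids', 'token_type_ids', 'attention_mask']
```

The raw text columns (`sentence1`, `sentence2`) and the `idx` column are dropped because no `forward` parameter has those names.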
31
00:01:14,054 --> 00:01:16,650
实例化 Trainer 前的最后一步是
The last steps before creating the Trainer are
32
00:01:16,650 --> 00:01:17,940
定义模型
to define a model
33
00:01:17,940 --> 00:01:20,250
和一些训练超参数。
and some training hyperparameters.
34
00:01:20,250 --> 00:01:22,653
我们在 model API 视频中学会了如何定义模型。
We saw how to do the first in the model API video.
35
00:01:23,734 --> 00:01:26,790
对于第二点,我们使用 TrainingArguments 类。
For the second, we use the TrainingArguments class.
36
00:01:26,790 --> 00:01:28,710
TrainingArguments 只需要一个文件夹的路径
It only takes a path to a folder
37
00:01:28,710 --> 00:01:30,900
用以保存结果和检查点,
where results and checkpoints will be saved,
38
00:01:30,900 --> 00:01:33,060
但你也可以自定义你的 Trainer 会使用的
but you can also customize all the hyperparameters
39
00:01:33,060 --> 00:01:34,470
所有超参数
your Trainer will use,
40
00:01:34,470 --> 00:01:37,270
比如学习率,训练几个 epoch 等等。
such as the learning rate, the number of training epochs, etc.
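A small dataclass can mirror the fact that only the output folder is required. `ToyTrainingArguments` is hypothetical and copies just three hyperparameters. As far as we know, the defaults shown (learning rate 5e-5, 3 epochs, batch size 8 per device) match the documented `TrainingArguments` defaults, but check the version you have installed. The real call looks like `TrainingArguments("test-trainer", learning_rate=2e-5)`, and the folder name is arbitrary.

```python
from dataclasses import dataclass


@dataclass
class ToyTrainingArguments:
    """Hypothetical stand-in mirroring a few TrainingArguments fields."""

    output_dir: str  # the only required value: where results and checkpoints go
    learning_rate: float = 5e-5
    num_train_epochs: float = 3.0
    per_device_train_batch_size: int = 8


defaults = ToyTrainingArguments("test-trainer")
custom = ToyTrainingArguments("test-trainer", learning_rate=2e-5, num_train_epochs=5)
print(defaults)
print(custom.learning_rate, custom.num_train_epochs)
```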
41
00:01:38,190 --> 00:01:39,660
接下来实例化一个 Trainer 并开始训练
It's then very easy to create a Trainer
42
00:01:39,660 --> 00:01:41,400
就非常简单了。
and launch training.
43
00:01:41,400 --> 00:01:43,170
这会显示一个进度条
This should display a progress bar
44
00:01:43,170 --> 00:01:45,900
几分钟后(如果你在 GPU 上运行)
and after a few minutes (if you're running on a GPU)
45
00:01:45,900 --> 00:01:48,000
你会完成训练。
training should be finished.
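The progress bar counts optimization steps. Assume one device, no gradient accumulation and that the last partial batch is kept. Under those assumptions, the total is the number of batches per epoch times the number of epochs. The example uses the 3,668 sentence pairs of the MRPC training split, a batch size of 8 and 3 epochs. `total_training_steps` is a hypothetical helper.

```python
import math


def total_training_steps(num_examples: int, batch_size: int, num_epochs: int) -> int:
    """Optimizer steps for one device, no gradient accumulation, last partial batch kept."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs


print(total_training_steps(3668, batch_size=8, num_epochs=3))  # 1377
```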
46
00:01:48,000 --> 00:01:50,790
然而,结果将是相当虎头蛇尾,
The result will be rather anticlimactic however,
47
00:01:50,790 --> 00:01:52,710
因为你只会得到训练损失
as you will only get a training loss
48
00:01:52,710 --> 00:01:54,300
这并没有真正告诉你
which doesn't really tell you anything
49
00:01:54,300 --> 00:01:56,820
你的模型表现如何。
about how well your model is performing.
50
00:01:56,820 --> 00:01:58,977
这是因为我们没有指定任何指标
This is because we didn't specify any metric
51
00:01:58,977 --> 00:02:00,273
用于评估。
for the evaluation.
52
00:02:01,200 --> 00:02:02,160
要获得这些指标,
To get those metrics,
53
00:02:02,160 --> 00:02:03,810
我们将首先使用 predict 方法
we will first gather the predictions
54
00:02:03,810 --> 00:02:06,513
在整个评估集上收集预测结果。
on the whole evaluation set using the predict method.
55
00:02:07,440 --> 00:02:10,020
它返回一个包含三个字段的命名元组,
It returns a namedtuple with three fields,
56
00:02:10,020 --> 00:02:12,990
predictions (其中包含模型的预测),
predictions (which contains the model predictions),
57
00:02:12,990 --> 00:02:15,030
label_ids (其中包含标签
label_ids (which contains the labels
58
00:02:15,030 --> 00:02:16,800
如果你的数据集有的话)
if your dataset had them)
59
00:02:16,800 --> 00:02:18,570
和指标(在本示例中是空的)
and metrics (which is empty here).
60
00:02:18,570 --> 00:02:20,520
我们正在努力做到这一点。
That's what we're trying to compute.
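A `NamedTuple` can mimic the shape of what `predict` returns. `ToyPredictionOutput` and the numbers in it are hypothetical. In the library, the fields are named `predictions`, `label_ids` and `metrics`. Depending on the version, `metrics` may also hold the loss and timing entries.

```python
from typing import Dict, List, NamedTuple, Optional


class ToyPredictionOutput(NamedTuple):
    """Hypothetical mirror of the three fields returned by Trainer.predict."""

    predictions: List[List[float]]  # logits, one row per example
    label_ids: Optional[List[int]]  # None when the dataset has no labels
    metrics: Dict[str, float]  # only what was actually computed


out = ToyPredictionOutput(predictions=[[-1.2, 2.3], [0.8, -0.5]], label_ids=[1, 0], metrics={})
preds, labels, metrics = out  # a NamedTuple unpacks like a plain tuple
print(out._fields)  # ('predictions', 'label_ids', 'metrics')
print(len(preds), labels, metrics)  # 2 [1, 0] {}
```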
61
00:02:20,520 --> 00:02:22,470
预测结果是模型
The predictions are the logits of the model
62
00:02:22,470 --> 00:02:24,143
对于数据集中的所有句子所输出的 logits。
for all the sentences in the dataset.
63
00:02:24,143 --> 00:02:27,513
所以是一个形状为 408 x 2 的 NumPy 数组。
So it's a NumPy array of shape 408 by 2.
64
00:02:28,500 --> 00:02:30,270
为了将它们与我们的标签相匹配,
To match them with our labels,
65
00:02:30,270 --> 00:02:31,590
我们需要取最大的 logit
we need to take the maximum logit
66
00:02:31,590 --> 00:02:32,850
对于每个预测
for each prediction
67
00:02:32,850 --> 00:02:35,820
(知道两个类别中的哪一个类是所预测的结果)
(to know which of the two classes was predicted).
68
00:02:35,820 --> 00:02:37,683
我们使用 argmax 函数来做到这一点。
We do this with the argmax function.
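With NumPy, this step is usually written `np.argmax(predictions.predictions, axis=-1)`. The sketch below does the same row-wise argmax without any dependencies. `argmax_rows` is a hypothetical helper and the logits are made up.

```python
from typing import List, Sequence


def argmax_rows(logits: Sequence[Sequence[float]]) -> List[int]:
    """Index of the largest logit in each row; ties go to the first index, as in np.argmax."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


logits = [[-1.2, 2.3], [0.8, -0.5], [0.1, 0.4]]
print(argmax_rows(logits))  # [1, 0, 1]
```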
69
00:02:38,640 --> 00:02:41,550
然后我们可以使用 Datasets library 中的指标。
Then we can use a metric from the Datasets library.
70
00:02:41,550 --> 00:02:43,500
它可以像数据集一样被轻松地加载
It can be loaded as easily as a dataset
71
00:02:43,500 --> 00:02:45,360
使用 load_metric 函数
with the load_metric function
72
00:02:45,360 --> 00:02:49,500
并且返回用于该数据集的评估指标。
and it returns the evaluation metric used for the dataset.
73
00:02:49,500 --> 00:02:51,600
我们可以看到我们的模型确实学到了一些东西
We can see our model did learn something
74
00:02:51,600 --> 00:02:54,363
因为它有 85.7% 的准确率。
as it is 85.7% accurate.
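For MRPC, the GLUE metric reports both accuracy and F1. The sketch below computes both by hand and makes no claim about the library's internals. `accuracy_and_f1` and the toy predictions are hypothetical. `load_metric` has since been deprecated and removed from recent versions of Datasets. The equivalent is now `evaluate.load("glue", "mrpc")` from the separate Evaluate library.

```python
from typing import Dict, Sequence


def accuracy_and_f1(preds: Sequence[int], labels: Sequence[int]) -> Dict[str, float]:
    """Accuracy plus F1 for the positive class (label 1)."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    pairs = list(zip(preds, labels))
    correct = sum(p == y for p, y in pairs)
    tp = sum(p == 1 and y == 1 for p, y in pairs)  # true positives
    fp = sum(p == 1 and y == 0 for p, y in pairs)  # false positives
    fn = sum(p == 0 and y == 1 for p, y in pairs)  # false negatives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / len(labels), "f1": f1}


print(accuracy_and_f1([1, 0, 1, 1], [1, 0, 0, 1]))
```

The 85.7% quoted in the video corresponds to the `accuracy` entry. The F1 entry depends only on how the positive class (paraphrase) is handled.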
75
00:02:55,440 --> 00:02:57,870
为了在训练期间监控评估指标,
To monitor the evaluation metrics during training,
76
00:02:57,870 --> 00:02:59,829
我们需要定义一个 compute_metrics 函数
we need to define a compute_metrics function
77
00:02:59,829 --> 00:03:02,670
和以前一样的步骤。
that does the same steps as before.
78
00:03:02,670 --> 00:03:04,728
它接收一个带有预测和标签的命名元组
It takes a namedtuple with predictions and labels
79
00:03:04,728 --> 00:03:06,327
并且返回一个字典
and must return a dictionary
80
00:03:06,327 --> 00:03:08,427
包含我们想要跟踪的指标。
with the metrics we want to keep track of.
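In the course, this function would call the loaded metric's `compute`. In the sketch below, accuracy is computed by hand instead so the example runs without downloads. The function would be passed as `Trainer(..., compute_metrics=compute_metrics)`. The unpacking line assumes the object the Trainer passes in splits into logits and labels, as in the course's own example.

```python
from typing import Dict, Sequence, Tuple


def compute_metrics(eval_preds: Tuple[Sequence[Sequence[float]], Sequence[int]]) -> Dict[str, float]:
    """Sketch of a compute_metrics callback: (logits, labels) in, metric dict out."""
    logits, labels = eval_preds
    # Row-wise argmax turns logits into predicted class ids.
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)
    return {"accuracy": accuracy}


print(compute_metrics(([[0.2, 1.5], [2.0, -1.0], [0.3, 0.1]], [1, 0, 1])))
# {'accuracy': 0.6666666666666666}
```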
81
00:03:09,360 --> 00:03:11,490
通过将评估策略设置为 epoch
By passing the epoch evaluation strategy
82
00:03:11,490 --> 00:03:13,080
对于我们的 TrainingArguments,
to our TrainingArguments,
83
00:03:13,080 --> 00:03:14,490
我们告诉 Trainer 去进行评估
we tell the Trainer to evaluate
84
00:03:14,490 --> 00:03:15,903
在每个 epoch 结束的时候。
at the end of every epoch.
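With the epoch strategy, evaluation runs at the global step that closes each epoch. The hypothetical helper below only models that schedule for the `"no"`, `"steps"` and `"epoch"` strategies. It uses the same single-device, no-accumulation assumptions as a plain step count. The `eval_steps` default of 500 is an assumption for this sketch. In the library, the setting is `evaluation_strategy="epoch"`, spelled `eval_strategy` in recent versions.

```python
import math
from typing import List


def evaluation_steps(strategy: str, num_examples: int, batch_size: int,
                     num_epochs: int, eval_steps: int = 500) -> List[int]:
    """Global steps after which an evaluation would run (toy model of the schedule)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total = steps_per_epoch * num_epochs
    if strategy == "no":
        return []
    if strategy == "epoch":
        # One evaluation at the last step of every epoch.
        return [steps_per_epoch * (e + 1) for e in range(num_epochs)]
    if strategy == "steps":
        # One evaluation every `eval_steps` optimizer steps.
        return list(range(eval_steps, total + 1, eval_steps))
    raise ValueError(f"unknown strategy: {strategy!r}")


print(evaluation_steps("epoch", 3668, batch_size=8, num_epochs=3))  # [459, 918, 1377]
```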
85
00:03:17,280 --> 00:03:18,587
在 notebook 中启动训练
Launching training inside a notebook
86
00:03:18,587 --> 00:03:20,640
会显示一个进度条
will then display a progress bar
87
00:03:20,640 --> 00:03:23,643
并在你运行完每个 epoch 时将数据填到你看到的这个表格。
and fill in the table you see here as each epoch finishes.
88
00:03:25,400 --> 00:03:28,249
(空气呼啸)
(air whooshing)
89
00:03:28,249 --> 00:03:29,974
(空气渐弱)
(air decrescendos)