"""
Fine tuning GPT-2 on the TinyStories dataset.
To train (tested on 8x A100 40GB):
$ python gpt2_tinystories.py
Final checkpoint is saved in ./gpt2-tinystories-final.
Use `gpt2_val_tinystories.ipyn` to evaluate the model.
Rowel Atienza
2024
References:
1) GPT2 - https://huggingface.co/openai-community/gpt2
2) TinyStories - https://huggingface.co/datasets/roneneldan/TinyStories
"""
import math

import torch
from datasets import load_dataset
from transformers import (
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
# Load dataset
dataset = load_dataset("roneneldan/TinyStories")
# Initialize tokenizer and model
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# GPT-2 has no padding token by default; reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
    """Tokenize dataset examples."""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
        padding="max_length"
    )
# Tokenize datasets
tokenized_train = dataset["train"].map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names
)
tokenized_test = dataset["validation"].map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["validation"].column_names
)
# Create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
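# With mlm=False the collator builds causal-LM labels by copying input_ids and
# replacing padded positions with -100 so they are ignored by the loss.
# Since pad_token is the EOS token here, genuine EOS tokens are masked as well.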
def compute_metrics(eval_pred):
    """Compute perplexity from the evaluation logits and labels."""
    logits, labels = eval_pred
    # The Trainer passes numpy arrays; convert to tensors before slicing
    logits = torch.from_numpy(logits)
    labels = torch.from_numpy(labels)
    # Shift so that tokens < n predict token n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = torch.nn.functional.cross_entropy(
        shift_logits.reshape(-1, shift_logits.shape[-1]),
        shift_labels.reshape(-1)
    )
    try:
        perplexity = math.exp(loss.item())
    except OverflowError:
        perplexity = float("inf")
    return {"perplexity": perplexity}
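# compute_metrics is only invoked when evaluation is enabled; with the eval
# arguments below left commented out, it is unused during this training run.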
# Define training arguments
training_args = TrainingArguments(
    output_dir="./gpt2-tinystories",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    # evaluation_strategy="steps",
    # eval_steps=500,
    save_steps=1000,
    warmup_steps=500,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    # load_best_model_at_end=True,
    # metric_for_best_model="perplexity",
    # greater_is_better=False
)
# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    # eval_dataset=tokenized_test,
    data_collator=data_collator,
    # compute_metrics=compute_metrics
)
# Train model
trainer.train()
# Save model
trainer.save_model("./gpt2-tinystories-final")
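# Optional sanity check (a minimal sketch, not part of the training run above):
# reload the saved checkpoint and sample one story. The referenced notebook does
# the full evaluation; the prompt and sampling settings here are illustrative
# assumptions.
gen_model = GPT2LMHeadModel.from_pretrained("./gpt2-tinystories-final")
gen_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
prompt = gen_tokenizer("Once upon a time", return_tensors="pt")
story_ids = gen_model.generate(
    **prompt,
    max_new_tokens=128,
    do_sample=True,
    top_p=0.95,
    temperature=0.8,
    pad_token_id=gen_tokenizer.eos_token_id,
)
print(gen_tokenizer.decode(story_ids[0], skip_special_tokens=True))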