import os
import argparse
from typing import Dict

import torch
import torch.nn as nn
from torch.nn.utils.clip_grad import clip_grad_norm_
from flambe.dataset import TabularDataset
from flambe.field import TextField, LabelField
from tensorboardX import SummaryWriter

from sampler import BaseSampler, EpisodicSampler
from model import PrototypicalTextClassifier
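

# Few-shot text classification with a prototypical network: training runs on
# sampled episodes; evaluation builds class prototypes from the full training
# set and scores the validation data against them.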
def train(args):
    """Run training."""
    global_step = 0
    best_metric = None
    best_model: Dict[str, torch.Tensor] = dict()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    writer = SummaryWriter(log_dir=args.output_dir)

    # We use flambe to do the data preprocessing
    # More info at https://flambe.ai
    print("Performing preprocessing (possibly downloading embeddings).")
    embeddings = args.embeddings if args.use_pretrained_embeddings else None
    text_field = TextField(lower=args.lowercase, embeddings=embeddings, embeddings_format='gensim')
    label_field = LabelField()
    transforms = {'text': text_field, 'label': label_field}
    dataset = TabularDataset.from_path(args.train_path,
                                       args.val_path,
                                       sep=',' if args.file_type == 'csv' else '\t',
                                       transform=transforms)

    # Create samplers
    train_sampler = EpisodicSampler(dataset.train,
                                    n_support=args.n_support,
                                    n_query=args.n_query,
                                    n_episodes=args.n_episodes,
                                    n_classes=args.n_classes)
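    # Each episode samples n_classes classes, with n_support labelled examples per
    # class to build prototypes and n_query query examples to compute the loss.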
    # The train_eval_sampler is used to compute prototypes over the full dataset
    train_eval_sampler = BaseSampler(dataset.train, batch_size=args.eval_batch_size)
    val_sampler = BaseSampler(dataset.val, batch_size=args.eval_batch_size)

    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    # Build model, criterion and optimizer
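    # Query examples are classified by their distance (euclidean or hyperbolic)
    # to per-class prototypes computed from the encoded support examples.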
    model = PrototypicalTextClassifier(vocab_size=dataset.text.vocab_size,
                                       distance=args.distance,
                                       embedding_dim=args.embedding_dim,
                                       pretrained_embeddings=dataset.text.embedding_matrix,
                                       rnn_type='sru',
                                       n_layers=args.n_layers,
                                       hidden_dim=args.hidden_dim,
                                       freeze_pretrained_embeddings=True)
    # Move the model to the selected device (batches are moved there in the loops below)
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    parameters = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(parameters, lr=args.learning_rate)

    print("Beginning training.")
    for epoch in range(args.num_epochs):

        ######################
        #       TRAIN        #
        ######################

        print(f'Epoch: {epoch}')
        model.train()

        with torch.enable_grad():
            for batch in train_sampler:
                # Zero the gradients and clear the accumulated loss
                optimizer.zero_grad()

                # Move to device
                batch = tuple(t.to(device) for t in batch)
                query, query_label, support, support_label = batch
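                # Each batch is a single episode: the support set defines the class
                # prototypes and the query set provides the classification loss.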
                # Compute loss
                pred = model(query, support, support_label)
                loss = loss_fn(pred, query_label)
                loss.backward()

                # Clip gradients if necessary
                if args.max_grad_norm is not None:
                    clip_grad_norm_(model.parameters(), args.max_grad_norm)

                writer.add_scalar('Training/Loss', loss.item(), global_step)

                # Optimize
                optimizer.step()
                global_step += 1

            # Zero the gradients when exiting a train step
            optimizer.zero_grad()

        #########################
        #       EVALUATE        #
        #########################

        model.eval()
        with torch.no_grad():
            # First compute prototypes over the training data
            encodings, labels = [], []
            for text, label in train_eval_sampler:
                text = text.to(device)
                padding_mask = (text != model.padding_idx).byte()
                text_embeddings = model.embedding_dropout(model.embedding(text))
                text_encoding = model.encoder(text_embeddings, padding_mask=padding_mask)
                labels.append(label.cpu())
                encodings.append(text_encoding.cpu())

            # Compute prototypes
            encodings = torch.cat(encodings, dim=0)
            labels = torch.cat(labels, dim=0)
            prototypes = model.compute_prototypes(encodings, labels).to(device)
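            # Score every validation example against prototypes built from the
            # full training set rather than from a single sampled episode.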
            _preds, _targets = [], []
            for batch in val_sampler:
                # Move to device
                source, target = tuple(t.to(device) for t in batch)
                pred = model(source, prototypes=prototypes)
                _preds.append(pred.cpu())
                _targets.append(target.cpu())

            preds = torch.cat(_preds, dim=0)
            targets = torch.cat(_targets, dim=0)
            val_loss = loss_fn(preds, targets).item()
            val_metric = (preds.argmax(dim=1) == targets).float().mean().item()

            # Update best model
            if best_metric is None or val_metric > best_metric:
                best_metric = val_metric
                best_model_state = model.state_dict()
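                # Copy the weights to CPU so the checkpoint can be loaded without a GPU.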
                for k, t in best_model_state.items():
                    best_model_state[k] = t.cpu().detach()
                best_model = best_model_state

            # Log metrics
            print(f'Validation loss: {val_loss}')
            print(f'Validation accuracy: {val_metric}')
            writer.add_scalar('Validation/Loss', val_loss, epoch)
            writer.add_scalar('Validation/Accuracy', val_metric, epoch)

    # Save the best model
    print("Finished training.")
    torch.save(best_model, os.path.join(args.output_dir, 'model.pt'))
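

# Example invocation (the file paths below are illustrative):
#   python train.py --train_path data/train.csv --val_path data/val.csv --output_dir runs/protonet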
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Required
    parser.add_argument('--train_path', type=str, required=True,
                        help="Path to training data: a CSV/TSV file with the text in "
                             "the first column and the label in the second.")
    parser.add_argument('--val_path', type=str, required=True,
                        help="Path to validation data: a CSV/TSV file with the text in "
                             "the first column and the label in the second.")
    parser.add_argument('--output_dir', type=str, required=True,
                        help="Path to the output directory.")

    # Optional
    parser.add_argument('--file_type', type=str, default='csv', choices=['csv', 'tsv'],
                        help="Handle CSV or TSV inputs.")
    parser.add_argument('--distance', type=str, choices=['euclidean', 'hyperbolic'],
                        default='euclidean', help="Distance metric to use.")
    parser.add_argument('--device', type=str, default=None, help="Device to use.")
    # argparse's type=bool treats any non-empty string as True, so parse booleans explicitly
    parser.add_argument('--use_pretrained_embeddings', type=lambda s: s.lower() in ('true', '1', 'yes'),
                        default=True, help="Whether to use pretrained embeddings.")
    parser.add_argument('--lowercase', type=lambda s: s.lower() in ('true', '1', 'yes'),
                        default=False, help="Whether to lowercase the text.")
    parser.add_argument('--embeddings', type=str, default='glove-wiki-gigaword-300',
                        help="Gensim embeddings to use.")
    parser.add_argument('--n_layers', type=int, default=2, help="Number of layers in the RNN.")
    parser.add_argument('--hidden_dim', type=int, default=128, help="Hidden dimension of the RNN.")
    parser.add_argument('--embedding_dim', type=int, default=128, help="Dimension of the token embeddings.")
    parser.add_argument('--n_support', type=int, default=1, help="Number of support points per class.")
    parser.add_argument('--n_query', type=int, default=64, help="Total number of query points (not per class).")
    parser.add_argument('--n_classes', type=int, default=None, help="Number of classes per episode.")
    parser.add_argument('--n_episodes', type=int, default=100, help="Number of episodes per 'epoch'.")
    parser.add_argument('--num_epochs', type=int, default=100, help="Number of training epochs (each followed by evaluation).")
    parser.add_argument('--eval_batch_size', type=int, default=128, help="Batch size used during evaluation.")
    parser.add_argument('--learning_rate', type=float, default=0.001, help="The learning rate.")
    parser.add_argument('--max_grad_norm', type=float, default=None, help="Maximum grad norm to clip at.")

    args = parser.parse_args()
    train(args)