-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
31 lines (24 loc) · 1.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
from eval import evaluate_data
from utils import word_to_onehot, word_to_indexes
def train(model, data, dev, epochs=3, learning_rate=0.001, verbose=True):
    """Train a recurrent model with plain SGD, one word per optimizer step.

    For every epoch, each (stripped, reference) word pair in ``data`` is
    encoded with ``word_to_onehot`` / ``word_to_indexes`` and fed through
    ``model`` one position at a time. The hidden state is reset with
    ``model.init_hidden()`` before every word. The cross-entropy losses of
    all positions are summed into a single loss that drives one SGD step.

    Args:
        model: Module exposing ``parameters()``, ``train()``, ``init_hidden()``
            and a ``model(input, hidden) -> (output, hidden)`` call.
        data: Mapping with parallel ``"stripped"`` and ``"reference"``
            sequences of words (presumably model inputs and their target
            forms; TODO confirm against the data loader).
        dev: Mapping with the same keys. It is evaluated via ``evaluate_data``
            after each epoch, but only when ``verbose`` is true.
        epochs: Number of passes over ``data``.
        learning_rate: SGD learning rate.
        verbose: If true, print the epoch number and the dev-set results.

    Raises:
        ValueError: If ``data["stripped"]`` and ``data["reference"]`` have
            different lengths.
    """
    stripped_words = data["stripped"]
    reference_words = data["reference"]
    # Fail fast on misaligned data instead of silently truncating or
    # raising an IndexError halfway through an epoch.
    if len(stripped_words) != len(reference_words):
        raise ValueError(
            f"data['stripped'] has {len(stripped_words)} items but "
            f"data['reference'] has {len(reference_words)}"
        )

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        # Re-enter training mode every epoch; the dev evaluation at the end of
        # the previous epoch may have switched the model to eval mode.
        model.train()
        if verbose:
            print("epoch:", epoch)
        # zip pairs items positionally (also correct for pandas Series with a
        # non-default index, where ``series[idx]`` would be a label lookup).
        for stripped_word, reference_word in zip(stripped_words, reference_words):
            hidden = model.init_hidden()
            example = word_to_onehot(stripped_word)
            reference = word_to_indexes(reference_word)
            length = example.size(0)
            if length == 0:
                # Nothing to predict: the summed loss would stay the int 0
                # and ``loss.backward()`` would raise AttributeError.
                continue
            optimizer.zero_grad()
            loss = 0
            for i in range(length):
                output, hidden = model(example[i], hidden)
                # Give the target a leading batch dimension of 1; assumes
                # ``output`` is (1, num_classes) (TODO confirm).
                loss += loss_fn(output, reference[i].unsqueeze(dim=0))
            loss.backward()
            optimizer.step()
        if verbose:
            results, _ = evaluate_data(dev["stripped"], dev["reference"], model=model)
            print("dev:", results)