-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
61 lines (53 loc) · 2.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
class LabelSmoothingLoss(nn.Module):
    """Label-smoothed KL-divergence loss.

    With label smoothing, the KL-divergence between
    q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w)
    is minimized.

    Args:
        label_smoothing: total probability mass taken away from the true
            class; must satisfy ``0.0 < label_smoothing <= 1.0``.
        tgt_vocab_size: number of classes. The smoothing mass is divided by
            ``tgt_vocab_size - 2`` (every class except the true one and
            ``ignore_index``), so a value of 2 raises ZeroDivisionError.
        ignore_index: class index that never receives probability mass;
            rows whose target equals it contribute nothing to the loss.
        device: stored on the instance for backward compatibility only.
            ``forward`` allocates its working tensors on the device of
            ``output``.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=1, device='cuda'):
        assert 0.0 < label_smoothing <= 1.0
        # Set up nn.Module bookkeeping before assigning any attribute.
        super(LabelSmoothingLoss, self).__init__()
        self.ignore_index = ignore_index
        self.smoothing_value = label_smoothing / (tgt_vocab_size - 2)
        self.confidence = 1.0 - label_smoothing
        self.device = device

    def forward(self, output, target):
        """Return the summed KL-divergence against the smoothed targets.

        output (FloatTensor): batch_size x n_classes; passed to
            ``F.kl_div`` as its input, which expects log-probabilities.
        target (LongTensor): batch_size
        """
        # Build the smoothed row on the same device as the model output so
        # the loss works on CPU or any GPU, regardless of ``self.device``.
        smoothed_row = torch.full(
            (output.shape[1],), self.smoothing_value, device=output.device
        )
        smoothed_row[self.ignore_index] = 0
        model_prob = smoothed_row.repeat(target.size(0), 1)
        # Put the remaining confidence mass on the true class of each row.
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        # Rows whose target is the ignored index become all-zero, so they
        # add nothing to the summed divergence.
        model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
        return F.kl_div(output, model_prob, reduction='sum')
def accuracy(out, target, ignored_index, unk_index):
    """Compute accuracy over the positions whose target is not ignored.

    Positions where ``target == ignored_index`` are dropped from both
    ``out`` and ``target``. Any remaining prediction equal to ``unk_index``
    is replaced by -1 before scoring, so it only counts as correct if the
    target at that position is itself -1.

    Both inputs must support comparison with a scalar and boolean-mask
    indexing, and be convertible with ``np.array``.

    Returns:
        The value of ``accuracy_score`` on the filtered arrays.
    """
    keep = target != ignored_index
    predictions = np.array(out[keep])
    references = np.array(target[keep])
    # Overwrite unknown-token predictions with a sentinel value.
    predictions[predictions == unk_index] = -1
    return accuracy_score(predictions, references)
class _DotDictMissingKey(AttributeError, KeyError):
    """Raised when a missing key is read from a DotDict as an attribute.

    It is an AttributeError so that ``hasattr``, ``getattr`` with a default,
    ``copy`` and ``pickle`` treat the name as absent, and a KeyError so that
    callers which catch KeyError around attribute access keep working.
    """


class DotDict(dict):
    """A dictionary that supports dot notation
    as well as dictionary access notation
    usage: d = DotDict() or d = DotDict({'val1':'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']

    Values passed to the constructor that expose a ``keys`` attribute
    (nested mappings) are wrapped in DotDict recursively. Values assigned
    after construction are stored as given.
    """

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __init__(self, dct=None):
        """Populate the dictionary from ``dct``; ``None`` gives an empty one."""
        if dct is None:
            return
        for key, value in dct.items():
            if hasattr(value, 'keys'):
                value = DotDict(value)
            self[key] = value

    def __getattr__(self, name):
        """Return ``self[name]``; called only when normal lookup fails."""
        try:
            return self[name]
        except KeyError:
            # A bare KeyError here would escape from hasattr() and from the
            # dunder probing done by copy/pickle.
            raise _DotDictMissingKey(name) from None
def adjust_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``.

    No schedule is computed here: the caller supplies the value and it is
    written to the ``'lr'`` entry of each item in ``optimizer.param_groups``.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr