import torch
import torch.nn as nn
from transformers import BertModel, XLMRobertaModel


class MyModel(nn.Module):
    """Classification head over a multilingual encoder, with mean pooling
    and an optional auxiliary language-identification head."""

    def __init__(self, model_name, bert_path, num_class, lang_id_task=False,
                 lang_class=4, requires_grad=False):
        super(MyModel, self).__init__()
        self.model_name = model_name
        self.bert_path = bert_path
        self.num_class = num_class
        self.lang_id_task = lang_id_task
        self.lang_class = lang_class
        self.requires_grad = requires_grad

        if self.model_name == "xlm-bert":
            self.bert = BertModel.from_pretrained(self.bert_path)
        elif self.model_name == "xlm-roberta":
            self.bert = XLMRobertaModel.from_pretrained(self.bert_path)
        else:
            raise NotImplementedError(f"unsupported model_name: {self.model_name}")

        # Keep the embedding layer trainable (FGM below perturbs it);
        # the rest of the encoder follows `requires_grad`.
        for name, params in self.bert.named_parameters():
            if "emb" in name:
                params.requires_grad = True
            else:
                params.requires_grad = self.requires_grad

        self.dropout = nn.Dropout(p=0.2)
        self.fc = nn.Linear(in_features=self.bert.config.hidden_size,
                            out_features=self.num_class)
        self.avg_pool_layer = nn.AdaptiveAvgPool1d(output_size=1)
        self.lang_fc = nn.Linear(in_features=self.bert.config.hidden_size,
                                 out_features=self.lang_class)

    def forward(self, input_ids, attention_mask, token_type_ids=None, **kwargs):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the token embeddings: (batch, seq, hidden) -> (batch, hidden).
        cls_embedding = self.avg_pool_layer(bert_output[0].permute(0, 2, 1)).squeeze(dim=-1)
        # cls_embedding = bert_output[0][:, 0, :]  # alternative: [CLS]-token pooling
        cls_embedding = self.dropout(cls_embedding)
        x = self.fc(cls_embedding)
        if self.lang_id_task:
            lang_logits = self.lang_fc(cls_embedding)
        else:
            lang_logits = None
        return x, lang_logits


class MyModel_origin(nn.Module):
    """Same as MyModel, but uses [CLS]-token pooling instead of mean pooling."""

    def __init__(self, model_name, bert_path, num_class, lang_id_task=False,
                 lang_class=4, requires_grad=False):
        super(MyModel_origin, self).__init__()
        self.model_name = model_name
        self.bert_path = bert_path
        self.num_class = num_class
        self.lang_id_task = lang_id_task
        self.lang_class = lang_class
        self.requires_grad = requires_grad

        if self.model_name == "xlm-bert":
            self.bert = BertModel.from_pretrained(self.bert_path)
        elif self.model_name == "xlm-roberta":
            self.bert = XLMRobertaModel.from_pretrained(self.bert_path)
        else:
            raise NotImplementedError(f"unsupported model_name: {self.model_name}")

        # Keep the embedding layer trainable; the rest follows `requires_grad`.
        for name, params in self.bert.named_parameters():
            if "emb" in name:
                params.requires_grad = True
            else:
                params.requires_grad = self.requires_grad

        self.dropout = nn.Dropout(p=0.2)
        self.fc = nn.Linear(in_features=self.bert.config.hidden_size,
                            out_features=self.num_class)
        self.lang_fc = nn.Linear(in_features=self.bert.config.hidden_size,
                                 out_features=self.lang_class)

    def forward(self, input_ids, attention_mask, token_type_ids=None, **kwargs):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # [CLS]-token pooling: the first token's hidden state, (batch, hidden).
        cls_embedding = bert_output[0][:, 0, :]
        cls_embedding = self.dropout(cls_embedding)
        x = self.fc(cls_embedding)
        if self.lang_id_task:
            lang_logits = self.lang_fc(cls_embedding)
        else:
            lang_logits = None
        return x, lang_logits


class FGM(object):
    """Fast Gradient Method (FGM) adversarial training: nudge the embedding
    weights along the gradient direction, then restore them afterwards."""

    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1.0, emb_name="emb"):
        for name, params in self.model.named_parameters():
            if params.requires_grad and emb_name in name:
                # Back up the clean weights so restore() can undo the attack.
                self.backup[name] = params.data.clone()
                norm = torch.norm(params.grad)
                if norm != 0:
                    r_at = epsilon * params.grad / norm
                    # Perturb the weights themselves (not the gradient),
                    # matching the data backup/restore above.
                    params.data.add_(r_at)

    def restore(self, emb_name="emb"):
        for name, params in self.model.named_parameters():
            if params.requires_grad and emb_name in name:
                assert name in self.backup
                params.data = self.backup[name]
        self.backup = {}
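
# A minimal sketch of how FGM is typically wired into a training loop
# (illustrative only; `train_loader`, `optimizer`, and `criterion` are
# assumed to be defined elsewhere):
#
#     fgm = FGM(my_model)
#     for input_ids, attention_mask, labels in train_loader:
#         logits, _ = my_model(input_ids, attention_mask)
#         loss = criterion(logits, labels)
#         loss.backward()                           # gradients on the clean batch
#         fgm.attack()                              # perturb embedding weights
#         logits_adv, _ = my_model(input_ids, attention_mask)
#         criterion(logits_adv, labels).backward()  # accumulate adversarial grads
#         fgm.restore()                             # undo the perturbation
#         optimizer.step()
#         my_model.zero_grad()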


if __name__ == "__main__":
    bert_path = "/Users/codewithzichao/Desktop/competitions/EACL2021/bert-base-multilingual-cased/"
    num_class = 2
    my_model = MyModel("xlm-bert", bert_path, num_class)

    # List the embedding parameters that remain trainable (the FGM targets).
    for name, params in my_model.named_parameters():
        if "emb" in name:
            print(name)
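
    # A minimal smoke test of the forward pass (a sketch checking shapes only;
    # assumes the checkpoint above is available locally):
    dummy_ids = torch.randint(0, my_model.bert.config.vocab_size, (2, 16))
    dummy_mask = torch.ones_like(dummy_ids)
    logits, lang_logits = my_model(dummy_ids, dummy_mask)
    print(logits.shape)  # expected: torch.Size([2, 2])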