modeling.py
import torch
from torch.nn.utils.rnn import pad_sequence
from os.path import join
import re
from transformers import AutoTokenizer
from LAMA.p_tuning.models import get_embedding_layer, create_model
from LAMA.data_utils.vocab import get_vocab_by_strategy, token_wrapper
from LAMA.data_utils.dataset import load_file
from LAMA.p_tuning.prompt_encoder import PromptEncoder
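

# PTuneForLAMA wraps a pre-trained LM (frozen unless args.use_lm_finetune is set)
# with a trainable continuous prompt: pseudo prompt tokens in each query are
# replaced by embeddings produced by PromptEncoder, and the wrapper is trained and
# evaluated on LAMA-style (head entity, relation, tail entity) fact queries.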
class PTuneForLAMA(torch.nn.Module):

    def __init__(self, args, device, template):
        super().__init__()
        self.args = args
        self.device = device

        # load relation templates
        self.relation_templates = dict(
            (d['relation'], d['template']) for d in load_file(join(self.args.data_dir, 'relations.jsonl')))

        # load tokenizer
        tokenizer_src = 'roberta-large' if 'megatron' in self.args.model_name else self.args.model_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_src, use_fast=False)

        # load pre-trained model
        if 'megatron' in self.args.model_name and self.args.use_lm_finetune:
            raise RuntimeError("Can not apply args.use_lm_finetune=True on MegatronLM 11B.")
        self.model = create_model(self.args)
        self.model = self.model.to(self.device)
        for param in self.model.parameters():
            param.requires_grad = self.args.use_lm_finetune
        self.embeddings = get_embedding_layer(self.args, self.model)

        # set allowed vocab set
        self.vocab = self.tokenizer.get_vocab()
        self.allowed_vocab_ids = set(self.vocab[k] for k in get_vocab_by_strategy(self.args, self.tokenizer))

        if 'gpt' in self.args.model_name or 'megatron' in self.args.model_name:
            template = (template[0], template[1], 0)
        self.template = template

        # load prompt encoder
        self.hidden_size = self.embeddings.embedding_dim
        self.tokenizer.add_special_tokens({'additional_special_tokens': [self.args.pseudo_token]})
        self.pseudo_token_id = self.tokenizer.get_vocab()[self.args.pseudo_token]
        self.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.unk_token_id

        self.spell_length = sum(self.template)
        self.prompt_encoder = PromptEncoder(self.template, self.hidden_size, self.tokenizer, self.device, args)
        self.prompt_encoder = self.prompt_encoder.to(self.device)
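
    # embed_input: looks up embeddings for a batch of query ids, then overwrites the
    # positions holding the pseudo prompt token with the trained prompt-encoder
    # outputs (skipped when handcrafted templates are used).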
    def embed_input(self, queries):
        bz = queries.shape[0]
        queries_for_embedding = queries.clone()
        queries_for_embedding[(queries == self.pseudo_token_id)] = self.tokenizer.unk_token_id
        raw_embeds = self.embeddings(queries_for_embedding)

        # For using handcrafted prompts
        if self.args.use_original_template:
            return raw_embeds

        blocked_indices = (queries == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1]  # bz * spell_length
        replace_embeds = self.prompt_encoder()
        for bidx in range(bz):
            for i in range(self.prompt_encoder.spell_length):
                raw_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]
        return raw_embeds
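
    # get_query: builds the token-id sequence for one example, either by filling the
    # handcrafted relation template or by interleaving pseudo prompt tokens with the
    # head entity ([MASK] marks the tail slot for BERT-style models; GPT-style models
    # append the tail tokens directly when provided).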
    def get_query(self, x_h, prompt_tokens, x_t=None):
        # For using handcrafted prompts
        if self.args.use_original_template:
            if 'gpt' in self.args.model_name or 'megatron' in self.args.model_name:
                query = re.sub(r'\[Y\].*', '', self.relation_templates[self.args.relation_id].replace('[X]', x_h))
                return self.tokenizer(' ' + query)['input_ids']
            else:
                query = self.relation_templates[self.args.relation_id].replace('[X]', x_h).replace('[Y]', self.tokenizer.mask_token)
                return self.tokenizer(' ' + query)['input_ids']

        # For P-tuning
        if 'gpt' not in self.args.model_name and 'megatron' not in self.args.model_name:
            # BERT-style model
            return [[self.tokenizer.cls_token_id]  # [CLS]
                    + prompt_tokens * self.template[0]
                    + [self.tokenizer.mask_token_id]  # [MASK] (tail entity slot)
                    + prompt_tokens * self.template[1]
                    + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(' ' + x_h))  # head entity
                    + (prompt_tokens * self.template[2] if self.template[2] > 0
                       else self.tokenizer.convert_tokens_to_ids(['.']))
                    + [self.tokenizer.sep_token_id]]
        elif 'gpt' in self.args.model_name or 'megatron' in self.args.model_name:
            # GPT-style models
            return [prompt_tokens * self.template[0]
                    + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(' ' + x_h))  # head entity
                    + prompt_tokens * self.template[1]
                    + (self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(' ' + x_t)) if x_t is not None else [])]
        else:
            raise NotImplementedError("The query template for {} has not been defined.".format(self.args.model_name))
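
    # forward: encodes a batch of (head, tail) pairs, injects the continuous prompt,
    # and dispatches to a model-family-specific helper that returns the LM loss and
    # the Hits@1 count for the batch.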
    def forward(self, x_hs, x_ts, return_candidates=False):
        bz = len(x_hs)

        # construct query ids
        prompt_tokens = [self.pseudo_token_id]
        x_ts = [token_wrapper(self.args, x_t) for x_t in x_ts]
        queries = [torch.LongTensor(self.get_query(x_hs[i], prompt_tokens)).squeeze(0) for i in range(bz)]
        queries = pad_sequence(queries, True, padding_value=self.pad_token_id).long().to(self.device)

        # construct label ids
        label_ids = torch.LongTensor(self.tokenizer.convert_tokens_to_ids(x_ts)).reshape((bz, -1)).to(self.device)
        attention_mask = queries != self.pad_token_id

        # get embedded input
        inputs_embeds = self.embed_input(queries)
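
        # bert_out: masked-LM path. The tail label is scattered onto the [MASK]
        # position; Hits@1 counts cases where the highest-ranked allowed-vocabulary
        # prediction at that position equals the gold tail id. Note: top10 is
        # returned but left empty on this path.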
        def bert_out():
            label_mask = (queries == self.tokenizer.mask_token_id).nonzero().reshape(bz, -1)[:, 1].unsqueeze(1).to(self.device)  # bz * 1
            labels = torch.empty_like(queries).fill_(-100).long().to(self.device)  # bz * seq_len
            labels = labels.scatter_(1, label_mask, label_ids)
            output = self.model(inputs_embeds=inputs_embeds.to(self.device),
                                attention_mask=attention_mask.to(self.device).bool(),
                                labels=labels.to(self.device))
            loss, logits = output.loss, output.logits

            pred_ids = torch.argsort(logits, dim=2, descending=True)
            hit1 = 0
            top10 = []
            for i in range(bz):
                pred_seq = pred_ids[i, label_mask[i, 0]].tolist()
                for pred in pred_seq:
                    if pred in self.allowed_vocab_ids:
                        break
                if pred == label_ids[i, 0]:
                    hit1 += 1

            if return_candidates:
                return loss, hit1, top10
            return loss, hit1
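
        # gpt_out: causal-LM path (fp16 inputs). The label sits at the last
        # non-padding position; the ten highest-ranked allowed-vocabulary ids are
        # collected as candidates and the first one is used for Hits@1.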
        def gpt_out():
            labels = torch.empty_like(queries).fill_(-100).long().to(self.device)  # bz * seq_len
            label_mask = (attention_mask.long().sum(dim=1) - 1).unsqueeze(1).to(self.device)
            labels = labels.scatter_(1, label_mask, label_ids)
            output = self.model(inputs_embeds=inputs_embeds.to(self.device).half(),
                                attention_mask=attention_mask.to(self.device).half(),
                                labels=labels.to(self.device))
            loss, logits = output.loss, output.logits

            pred_ids = torch.argsort(logits, dim=2, descending=True)
            hit1 = 0
            top10 = []
            for i in range(bz):
                top10.append([])
                pred_seq = pred_ids[i, label_mask[i, 0]].tolist()
                for pred in pred_seq:
                    if pred in self.allowed_vocab_ids:
                        top10[-1].append(pred)
                        if len(top10[-1]) >= 10:
                            break
                pred = top10[-1][0]
                if pred == label_ids[i, 0]:
                    hit1 += 1

            if return_candidates:
                return loss, hit1, top10
            return loss, hit1
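
        # megatron_out: MegatronLM path. Same labeling scheme as gpt_out, but the
        # call goes through self.model.decoder.predict, and inputs stay in fp16
        # unless the LM itself is being fine-tuned.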
        def megatron_out():
            labels = torch.empty_like(queries).fill_(-100).long().to(self.device)  # bz * seq_len
            label_mask = (attention_mask.long().sum(dim=1) - 1).unsqueeze(1).to(self.device)
            labels = labels.scatter_(1, label_mask, label_ids)

            if not self.args.use_lm_finetune:
                _attention_mask = attention_mask.float().half()
                _input_embeds = inputs_embeds.float().half()
            else:
                _attention_mask = attention_mask.float()
                _input_embeds = inputs_embeds.float()

            output = self.model.decoder.predict(prev_output_tokens=queries,
                                                inputs_embeds=_input_embeds.to(self.device),
                                                attention_mask=_attention_mask.to(self.device).bool(),
                                                labels=labels.to(self.device))
            logits, loss = output

            pred_ids = torch.argsort(logits, dim=2, descending=True)
            hit1 = 0
            top10 = []
            for i in range(bz):
                top10.append([])
                pred_seq = pred_ids[i, label_mask[i, 0]].tolist()
                for pred in pred_seq:
                    if pred in self.allowed_vocab_ids:
                        top10[-1].append(pred)
                        if len(top10[-1]) >= 10:
                            break
                pred = top10[-1][0]
                if pred == label_ids[i, 0]:
                    hit1 += 1

            if return_candidates:
                return loss, hit1, top10
            return loss, hit1

        if 'bert' in self.args.model_name:
            return bert_out()
        elif 'gpt' in self.args.model_name:
            return gpt_out()
        elif 'megatron' in self.args.model_name:
            return megatron_out()
        else:
            raise NotImplementedError()
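

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). The fields set on
# `args` below are only those read directly in this module; the real entry
# point builds a fuller argparse namespace (with additional options consumed by
# create_model and get_vocab_by_strategy), so treat this as an assumption-laden
# outline rather than a runnable script.
#
#   from types import SimpleNamespace
#
#   args = SimpleNamespace(
#       data_dir='data/LAMA',            # directory containing relations.jsonl
#       model_name='bert-base-cased',
#       relation_id='P19',               # e.g. place of birth
#       pseudo_token='[PROMPT]',
#       use_original_template=False,
#       use_lm_finetune=False,
#   )
#   model = PTuneForLAMA(args, device='cuda:0', template=(3, 3, 3))
#   loss, hit1 = model(['Dante Alighieri'], ['Florence'])   # batch of one fact
# ---------------------------------------------------------------------------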