# -*- coding: utf-8 -*-
"""
GlobalPointer reference: https://github.com/gaohongkui/GlobalPointer_pytorch/blob/main/models/GlobalPointer.py
Sparse multi-label cross-entropy loss reference: the bert4keras source code
"""
import torch
import torch.nn as nn
import numpy as np


def sparse_multilabel_categorical_crossentropy(y_true=None, y_pred=None, mask_zero=False):
    '''
    PyTorch implementation of the sparse multi-label categorical cross-entropy loss.

    y_true: zero-padded gold (start, end) index pairs, shape (batch, num_heads, num_spans, 2)
    y_pred: span scores, shape (batch, num_heads, seq_len, seq_len)
    mask_zero: treat flattened index 0 (the all-zero padding span) as ignorable
    '''
    shape = y_pred.shape
    # flatten each (start, end) pair into a single index over seq_len * seq_len
    y_true = y_true[..., 0] * shape[2] + y_true[..., 1]
    y_pred = y_pred.reshape(shape[0], -1, np.prod(shape[2:]))
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred = torch.cat([y_pred, zeros], dim=-1)
    if mask_zero:
        infs = zeros + 1e12
        y_pred = torch.cat([infs, y_pred[..., 1:]], dim=-1)
    y_pos_2 = torch.gather(y_pred, index=y_true, dim=-1)
    y_pos_1 = torch.cat([y_pos_2, zeros], dim=-1)
    if mask_zero:
        y_pred = torch.cat([-infs, y_pred[..., 1:]], dim=-1)
        y_pos_2 = torch.gather(y_pred, index=y_true, dim=-1)
    pos_loss = torch.logsumexp(-y_pos_1, dim=-1)
    all_loss = torch.logsumexp(y_pred, dim=-1)
    aux_loss = torch.logsumexp(y_pos_2, dim=-1) - all_loss
    aux_loss = torch.clip(1 - torch.exp(aux_loss), 1e-10, 1)
    neg_loss = all_loss + torch.log(aux_loss)
    loss = torch.mean(torch.sum(pos_loss + neg_loss))
    return loss
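
# A minimal usage sketch for the loss above, assuming the GPLinker-style label
# layout implied by the function body: y_pred holds span scores of shape
# (batch, num_heads, seq_len, seq_len) and y_true holds zero-padded gold
# (start, end) pairs of shape (batch, num_heads, max_spans, 2), masked out via
# mask_zero=True. The helper name and all sizes are illustrative, not part of
# the original interface.
def _sparse_loss_example():
    batch, heads, seq_len, max_spans = 2, 3, 16, 4  # illustrative sizes
    scores = torch.randn(batch, heads, seq_len, seq_len)
    # gold spans, padded with (0, 0); indices must be int64 for torch.gather
    spans = torch.zeros(batch, heads, max_spans, 2, dtype=torch.long)
    spans[0, 0, 0] = torch.tensor([2, 5])   # one entity of type 0 in sample 0
    spans[1, 2, 0] = torch.tensor([7, 7])   # a single-token entity in sample 1
    return sparse_multilabel_categorical_crossentropy(
        y_true=spans, y_pred=scores, mask_zero=True)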


class RawGlobalPointer(nn.Module):
    def __init__(self, hiddensize, ent_type_size, inner_dim, RoPE=True, tril_mask=True):
        '''
        :param hiddensize: hidden size of the upstream encoder (e.g. BERT)
        :param ent_type_size: number of entity types
        :param inner_dim: query/key head dimension (e.g. 64)
        :param RoPE: whether to apply rotary position embeddings
        :param tril_mask: whether to mask out the lower triangle (start > end)
        '''
        super().__init__()
        self.ent_type_size = ent_type_size
        self.inner_dim = inner_dim
        self.hidden_size = hiddensize
        self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)
        self.RoPE = RoPE
        self.tril_mask = tril_mask

    def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
        position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)
        indices = torch.arange(0, output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000, -2 * indices / output_dim)
        embeddings = position_ids * indices
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
        embeddings = embeddings.to(self.device)
        return embeddings

    def forward(self, context_outputs, attention_mask):
        self.device = attention_mask.device
        last_hidden_state = context_outputs[0]
        batch_size = last_hidden_state.size()[0]
        seq_len = last_hidden_state.size()[1]
        outputs = self.dense(last_hidden_state)
        outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)
        outputs = torch.stack(outputs, dim=-2)
        qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]
        if self.RoPE:
            # pos_emb: (batch_size, seq_len, inner_dim)
            pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
            cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1)
            qw2 = qw2.reshape(qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], -1)
            kw2 = kw2.reshape(kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        # logits: (batch_size, ent_type_size, seq_len, seq_len)
        logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw)
        # padding mask
        pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.ent_type_size, seq_len, seq_len)
        logits = logits * pad_mask - (1 - pad_mask) * 1e12
        # exclude the lower triangle (spans with start > end)
        if self.tril_mask:
            mask = torch.tril(torch.ones_like(logits), -1)
            logits = logits - mask * 1e12
        return logits / self.inner_dim ** 0.5
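

# A minimal sketch of running RawGlobalPointer on encoder output. The sizes
# below mimic a BERT-base encoder and are assumptions, as is the helper name;
# context_outputs only needs to expose the last hidden state at index 0, so a
# plain tuple of random tensors stands in for a real encoder here.
def _global_pointer_example():
    batch, seq_len, hidden = 2, 16, 768  # illustrative sizes
    gp = RawGlobalPointer(hiddensize=hidden, ent_type_size=3, inner_dim=64)
    last_hidden_state = torch.randn(batch, seq_len, hidden)
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
    attention_mask[1, 10:] = 0  # pretend the second sample is padded
    logits = gp((last_hidden_state,), attention_mask)
    return logits  # (batch, ent_type_size, seq_len, seq_len)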


class Biaffine(nn.Module):
    def __init__(self, n_in, n_out=1, bias_x=True, bias_y=True):
        super(Biaffine, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.bias_x = bias_x
        self.bias_y = bias_y
        weight = torch.zeros((n_out, n_in + int(bias_x), n_in + int(bias_y)))
        nn.init.xavier_normal_(weight)
        self.weight = nn.Parameter(weight, requires_grad=True)

    def forward(self, x, y):
        if self.bias_x:
            x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
        if self.bias_y:
            y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
        # [batch_size, n_out, seq_len, seq_len]
        s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y)
        return s
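

# A quick shape sketch for the biaffine scorer above (helper name and sizes
# are illustrative): two (batch, seq_len, n_in) inputs are scored into a
# (batch, n_out, seq_len, seq_len) tensor, one score per token pair and class.
def _biaffine_example():
    batch, seq_len, n_in, n_out = 2, 16, 128, 5  # illustrative sizes
    biaffine = Biaffine(n_in=n_in, n_out=n_out)
    head = torch.randn(batch, seq_len, n_in)
    tail = torch.randn(batch, seq_len, n_in)
    return biaffine(head, tail)  # (batch, n_out, seq_len, seq_len)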


class MLP(nn.Module):
    def __init__(self, n_in, n_out, dropout=0):
        super().__init__()
        self.linear = nn.Linear(n_in, n_out)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(x)
        x = self.linear(x)
        x = self.activation(x)
        return x


class CoPredictor(nn.Module):
    def __init__(self, cls_num, hid_size, biaffine_size, channels, ffnn_hid_size, tril_mask=False,
                 dropout=0):
        super().__init__()
        # note: channels and ffnn_hid_size are accepted but not used in this module
        self.mlp1 = MLP(n_in=hid_size, n_out=biaffine_size, dropout=dropout)
        self.mlp2 = MLP(n_in=hid_size, n_out=biaffine_size, dropout=dropout)
        self.biaffine = Biaffine(n_in=biaffine_size, n_out=cls_num, bias_x=True, bias_y=True)
        self.dropout = nn.Dropout(dropout)
        self.cls_num = cls_num
        self.tril_mask = tril_mask

    def forward(self, x, attention_mask=None):
        inputs = x[0]
        batch_size = inputs.size()[0]
        seq_len = inputs.size()[1]
        h = self.dropout(self.mlp1(inputs))
        t = self.dropout(self.mlp2(inputs))
        # [batch, cls_num, seq_len, seq_len]
        logits = self.biaffine(h, t)
        if attention_mask is not None:  # huggingface's attention_mask
            # padding mask on the key positions
            pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.cls_num, seq_len, seq_len)
            logits = logits * pad_mask - (1 - pad_mask) * 1e12
            # exclude pairs where either the row or the column token is padding
            attn_mask = (
                1 - attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
            )
            logits = logits - attn_mask * 1e12
        if self.tril_mask:
            # exclude the lower triangle (spans with start > end)
            mask = torch.tril(torch.ones_like(logits), diagonal=-1)
            logits = logits - mask * 1e12
        return logits
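

# A minimal end-to-end sketch for CoPredictor. The encoder output is faked
# with random hidden states, and the helper name and sizes are assumptions.
# Like RawGlobalPointer, it takes the encoder output tuple plus the
# huggingface-style attention_mask and returns per-class span logits.
def _co_predictor_example():
    batch, seq_len, hidden = 2, 16, 768  # illustrative sizes
    model = CoPredictor(cls_num=3, hid_size=hidden, biaffine_size=128,
                        channels=64, ffnn_hid_size=128, tril_mask=True)
    last_hidden_state = torch.randn(batch, seq_len, hidden)
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
    attention_mask[1, 12:] = 0  # pretend the second sample is padded
    logits = model((last_hidden_state,), attention_mask)
    return logits  # (batch, cls_num, seq_len, seq_len)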