dataloader.py
# -*- coding: utf-8 -*-
"""
@Auth: Xhw
@Description: CHIP/CBLUE medical entity relation extraction; data from https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414
"""
import json
import numpy as np
import torch
from torch.utils.data import Dataset
def load_name(filename):
    # One JSON object per line, e.g.:
    # {"text": "产后抑郁症@区分产后抑郁症与轻度情绪失调(产后忧郁或“婴儿忧郁”)是重要的,因为轻度情绪失调不需要治疗。", "spo_list": [{"Combined": false, "predicate": "鉴别诊断", "subject": "产后抑郁症", "subject_type": "疾病", "object": {"@value": "轻度情绪失调"}, "object_type": {"@value": "疾病"}}]}
    # The raw CMeIE release carries no character offsets, but encoder() below expects
    # (subject, subject_char_span, predicate, object, object_char_span) per triple, so the
    # span of the first occurrence of each entity is located with str.find() here (an
    # assumption; the *_type fields, unused downstream, are dropped).
    D = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            line = json.loads(line)
            text = line["text"]
            spo_list = []
            for spo in line["spo_list"]:
                s, o = spo["subject"], spo["object"]["@value"]
                s_start, o_start = text.find(s), text.find(o)
                if s_start == -1 or o_start == -1:
                    continue
                spo_list.append((s, (s_start, s_start + len(s)),
                                 spo["predicate"],
                                 o, (o_start, o_start + len(o))))
            D.append({"text": text, "spo_list": spo_list})
    return D
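
# For the sample line shown in load_name above, it would yield one record of roughly this
# shape (illustrative values only, assuming the str.find-based spans described there):
# {"text": "产后抑郁症@区分产后抑郁症与...",
#  "spo_list": [("产后抑郁症", (0, 5), "鉴别诊断", "轻度情绪失调", (14, 20))]}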
def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """NumPy helper: pad a batch of sequences to the same length."""
    if length is None:
        length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
    elif not hasattr(length, '__getitem__'):
        length = [length]

    slices = [np.s_[:length[i]] for i in range(seq_dims)]
    slices = tuple(slices) if len(slices) > 1 else slices[0]
    pad_width = [(0, 0) for _ in np.shape(inputs[0])]

    outputs = []
    for x in inputs:
        x = x[slices]
        for i in range(seq_dims):
            if mode == 'post':
                pad_width[i] = (0, length[i] - np.shape(x)[i])
            elif mode == 'pre':
                pad_width[i] = (length[i] - np.shape(x)[i], 0)
            else:
                raise ValueError('"mode" argument must be "post" or "pre".')
        x = np.pad(x, pad_width, 'constant', constant_values=value)
        outputs.append(x)
    return np.array(outputs)
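
# A quick illustration of sequence_padding (hypothetical values):
#   sequence_padding([[1, 2, 3], [4, 5]])           -> [[1, 2, 3], [4, 5, 0]]
#   sequence_padding([[1, 2, 3], [4, 5]], length=4) -> [[1, 2, 3, 0], [4, 5, 0, 0]]
# With seq_dims=2 (as used in collate() below) the first two dimensions are padded jointly,
# which is how label sets of different sizes are aligned across samples.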
class data_generator(Dataset):
    def __init__(self, data, tokenizer, max_len, schema):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.schema = schema  # SPO schema: predicate -> id mapping

    def __len__(self):
        return len(self.data)
    def encoder(self, item):
        text = item["text"]
        encoder_text = self.tokenizer(text, return_offsets_mapping=True, max_length=self.max_len,
                                      truncation=True, add_special_tokens=True)
        mapping = encoder_text["offset_mapping"]
        text2tokens = self.tokenizer.tokenize(text, add_special_tokens=True, max_length=self.max_len)
        input_ids = encoder_text["input_ids"]
        token_type_ids = encoder_text["token_type_ids"]  # RoBERTa has no NSP task
        attention_mask = encoder_text["attention_mask"]
        spoes = set()
        for s, s_pos, p, o, o_pos in item['spo_list']:
            p_id = self.schema[p]
            ent2token = self.tokenizer.tokenize(s, add_special_tokens=False)
            # first collect candidate token positions (restricted to the truncated offset mapping)
            token_start_indexs = [i for i, v in enumerate(text2tokens) if v == ent2token[0] if i < len(mapping)]
            token_end_indexs = [i for i, v in enumerate(text2tokens) if v == ent2token[-1] if i < len(mapping)]
            # then keep the candidates whose char span matches; mapping[x][-1] is an exclusive char
            # offset, matching the half-open spans built in load_name (the -1 on both sides cancels)
            s_start_index_ = list(filter(lambda x: mapping[x][0] == s_pos[0], token_start_indexs))
            s_end_index_ = list(filter(lambda x: mapping[x][-1] - 1 == s_pos[1] - 1, token_end_indexs))
            ent2token = self.tokenizer.tokenize(o, add_special_tokens=False)
            token_start_indexs = [i for i, v in enumerate(text2tokens) if v == ent2token[0] if i < len(mapping)]
            token_end_indexs = [i for i, v in enumerate(text2tokens) if v == ent2token[-1] if i < len(mapping)]
            o_start_index_ = list(filter(lambda x: mapping[x][0] == o_pos[0], token_start_indexs))
            o_end_index_ = list(filter(lambda x: mapping[x][-1] - 1 == o_pos[1] - 1, token_end_indexs))
            if not s_start_index_ or not s_end_index_ or not o_start_index_ or not o_end_index_:
                continue
            if (s_start_index_[0] >= self.max_len or s_end_index_[0] >= self.max_len
                    or o_start_index_[0] >= self.max_len or o_end_index_[0] >= self.max_len):
                continue
            if s_start_index_[0] <= s_end_index_[0] and o_start_index_[0] <= o_end_index_[0]:
                spoes.add((s_start_index_[0], s_end_index_[0], p_id, o_start_index_[0], o_end_index_[0]))
        entity_labels = [set() for _ in range(2)]
        head_labels = [set() for _ in range(len(self.schema))]
        tail_labels = [set() for _ in range(len(self.schema))]
        for sh, st, p, oh, ot in spoes:
            entity_labels[0].add((sh, st))  # entity extraction: two label sets, subject spans and object spans
            entity_labels[1].add((oh, ot))
            head_labels[p].add((sh, oh))  # head/tail pair linking per predicate, similar to TPLinker
            tail_labels[p].add((st, ot))
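        # e.g. a triple whose subject spans tokens (1, 3) and object spans tokens (4, 5) under
        # predicate p contributes (1, 3) and (4, 5) to the two entity label sets, (1, 4) to
        # head_labels[p] and (3, 5) to tail_labels[p]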
        # every label group needs at least one pair, so pad empty sets with a (0, 0) placeholder
        for label in entity_labels + head_labels + tail_labels:
            if not label:
                label.add((0, 0))
        # e.g. entity_labels = [{(1,3)}, {(4,5), (7,9)}]:
        # entity_labels[0], i.e. {(1,3)}, holds the (head, tail) spans of subject entities,
        # entity_labels[1], i.e. {(4,5), (7,9)}, holds the spans of object entities;
        # the labels are then aligned to [[[1,3],[0,0]], [[4,5],[7,9]]]
        entity_labels = sequence_padding([list(l) for l in entity_labels])
        head_labels = sequence_padding([list(l) for l in head_labels])
        tail_labels = sequence_padding([list(l) for l in tail_labels])
        return text, entity_labels, head_labels, tail_labels, input_ids, attention_mask, token_type_ids
    def __getitem__(self, idx):
        item = self.data[idx]
        return self.encoder(item)

    @staticmethod
    def collate(examples):
        batch_token_ids, batch_mask_ids, batch_token_type_ids = [], [], []
        batch_entity_labels, batch_head_labels, batch_tail_labels = [], [], []
        text_list = []
        for item in examples:
            text, entity_labels, head_labels, tail_labels, input_ids, attention_mask, token_type_ids = item
            batch_entity_labels.append(entity_labels)
            batch_head_labels.append(head_labels)
            batch_tail_labels.append(tail_labels)
            batch_token_ids.append(input_ids)
            batch_mask_ids.append(attention_mask)
            batch_token_type_ids.append(token_type_ids)
            text_list.append(text)
        batch_token_ids = torch.tensor(sequence_padding(batch_token_ids)).long()
        batch_mask_ids = torch.tensor(sequence_padding(batch_mask_ids)).float()
        batch_token_type_ids = torch.tensor(sequence_padding(batch_token_type_ids)).long()  # RoBERTa has no NSP task
        batch_entity_labels = torch.tensor(sequence_padding(batch_entity_labels, seq_dims=2)).long()
        batch_head_labels = torch.tensor(sequence_padding(batch_head_labels, seq_dims=2)).long()
        batch_tail_labels = torch.tensor(sequence_padding(batch_tail_labels, seq_dims=2)).long()
        return text_list, batch_token_ids, batch_mask_ids, batch_token_type_ids, batch_entity_labels, batch_head_labels, batch_tail_labels
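
# Minimal usage sketch (not part of the original module): the checkpoint name, file path and
# one-entry schema below are placeholders and would normally come from the training script.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("hfl/chinese-roberta-wwm-ext")  # assumed checkpoint
    schema = {"鉴别诊断": 0}                      # assumed predicate -> id mapping
    data = load_name("CMeIE_train.json")          # assumed path to the CHIP/CBLUE data file
    dataset = data_generator(data, tokenizer, max_len=256, schema=schema)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=data_generator.collate)
    for texts, token_ids, mask_ids, token_type_ids, entity_labels, head_labels, tail_labels in loader:
        # token_ids: (batch, seq_len); entity_labels: (batch, 2, num_pairs, 2);
        # head_labels / tail_labels: (batch, num_predicates, num_pairs, 2)
        print(token_ids.shape, entity_labels.shape, head_labels.shape, tail_labels.shape)
        break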