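"""Train a word-level tokenizer for ASM_BERT.

Builds a WordLevel vocabulary (whitespace pre-tokenization) from cleaned
instruction text files, attaches [CLS]/[SEP] post-processing, enables
fixed-length padding, and saves the tokenizer to a JSON file.
"""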
import argparse
import os
from itertools import chain
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from process_data.utils import CURRENT_DATA_BASE, ORIGINAL_DATA_BASE, read_file

# Base directory of the BERT project (not referenced elsewhere in this script).
BASE_PATH = "/home/ming/malware/inst2vec_bert/bert/"


def parse_args():
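    """Parse command-line options for the vocabulary size and the padding length."""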
    parser = argparse.ArgumentParser(
        description="Train a word level tokenizer for ASM_BERT"
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=2000,
        help="The size of vocabulary used to train the tokenizer.",
    )
    parser.add_argument(
        "--padding_length",
        type=int,
        default=32,
help="The length will be padded to by the tokenizer.",
    )
    args = parser.parse_args()
    return args


def train_tokenizer(args, dataset):
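    """Train a WordLevel tokenizer with whitespace pre-tokenization on an iterable of text samples."""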
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        vocab_size=args.vocab_size,
        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    )
    # def batch_iterator(batch_size=1000):
    #     for i in range(0, len(dataset), batch_size):
    #         yield dataset[i : i + batch_size]["text"]
    # tokenizer.train_from_iterator(
    #     batch_iterator(), trainer=trainer, length=len(dataset)
    # )
    tokenizer.train_from_iterator(dataset, trainer)
    return tokenizer


def save_tokenizer(tokenizer, tokenizer_file):
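    """Serialize the trained tokenizer to a single JSON file."""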
    tokenizer.save(tokenizer_file)


def load_tokenizer(tokenizer_file):
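    """Return the tokenizer loaded from tokenizer_file, or None if the file does not exist yet."""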
    if not os.path.exists(tokenizer_file):
        print("{} doesn't exist, will be retrained...".format(tokenizer_file))
        return None
    print("The tokenizer has already been trained.")
    return Tokenizer.from_file(tokenizer_file)


def post_process(tokenizer):
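    """Attach a template post-processor that wraps single sequences and pairs with [CLS]/[SEP]."""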
    tokenizer.post_processor = TemplateProcessing(
        single="[CLS] $A [SEP]",
        pair="[CLS] $A [SEP] $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", tokenizer.token_to_id("[CLS]")),
            ("[SEP]", tokenizer.token_to_id("[SEP]")),
        ],
    )
    return tokenizer


def tokenizer_encode(tokenizer, data):
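    """Encode a batch of sequences (or sequence pairs) with the tokenizer."""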
    return tokenizer.encode_batch(data)


def main(tokenizer_file=""):
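    """Train, post-process, and save the tokenizer unless a trained one already exists at tokenizer_file."""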
    args = parse_args()
    tokenizer = load_tokenizer(tokenizer_file)
    if tokenizer is not None:
        return
    # json_files = [
    #     os.path.join(CURRENT_DATA_BASE, "inst.1.{}.json".format(i)) for i in range(128)
    # ]
    # dataset = load_dataset("json", data_files=json_files, field="data")
    text_files = [
        os.path.join(ORIGINAL_DATA_BASE, "inst.{}.{}.txt.clean".format(i, group))
        for group in ["pos", "neg"]
        for i in range(10)
    ]
    dataset = []
    for f in text_files:
        dataset += read_file(f)
    # Each cleaned line holds tab-separated sequences; drop the trailing newline and split into a tuple.
    dataset = [tuple(sent.rstrip("\n").split("\t")) for sent in dataset]
    print("Training tokenizer...")
    tokenizer = train_tokenizer(args, chain.from_iterable(dataset))
    tokenizer = post_process(tokenizer)
    tokenizer.enable_padding(
        pad_id=tokenizer.token_to_id("[PAD]"),
        pad_token="[PAD]",
        length=args.padding_length,
    )
    save_tokenizer(tokenizer, tokenizer_file)


if __name__ == "__main__":
    main(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json"))
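# Example (illustrative sketch): loading and applying the saved tokenizer later.
# The path mirrors the one passed to main() above; the instruction strings are
# hypothetical, and any token outside the trained vocabulary maps to [UNK].
#
#   tok = Tokenizer.from_file(os.path.join(CURRENT_DATA_BASE, "tokenizer-inst.all.json"))
#   enc = tok.encode("push ebp", "mov ebp , esp")  # encodes a sequence pair
#   print(enc.tokens, enc.ids)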