-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
116 lines (107 loc) · 4.94 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
from utils.helpers import read_lines
from gector.gec_model import GecBERTModel
def predict_for_file(input_file, output_file, model, batch_size=32):
    """Run the model over every line of ``input_file`` and save the output.

    Each line returned by ``read_lines(input_file)`` is split on whitespace
    into tokens. Token lists are grouped into batches of ``batch_size`` and
    passed to ``model.control_handle_batch``, which is expected to return a
    ``(predictions, correction_count)`` pair. Every prediction is joined
    with single spaces and written to ``output_file``, one per line.

    Args:
        input_file: Path handed to ``read_lines``.
        output_file: Path of the file that is (over)written with the output.
        model: Object exposing ``control_handle_batch(batch)``.
        batch_size: Number of sentences sent to the model per call.

    Returns:
        The sum of the correction counts reported by the model.
    """
    test_data = read_lines(input_file)
    predictions = []
    cnt_corrections = 0
    batch = []
    for sent in test_data:
        batch.append(sent.split())
        if len(batch) == batch_size:
            preds, cnt = model.control_handle_batch(batch)
            predictions.extend(preds)
            cnt_corrections += cnt
            batch = []
    if batch:
        # Route the trailing partial batch through the same method as the
        # full batches so that every sentence is processed the same way.
        preds, cnt = model.control_handle_batch(batch)
        predictions.extend(preds)
        cnt_corrections += cnt

    # NOTE: the file is opened with the platform default text encoding.
    with open(output_file, 'w') as f:
        # A trailing newline is always written, even with no predictions.
        f.write("\n".join([" ".join(x) for x in predictions]) + '\n')
    return cnt_corrections
def main(args):
    """Build the model from parsed CLI arguments and run file prediction.

    Args:
        args: Parsed argument namespace; the attributes read below must be
            present on it.
    """
    # Collect the constructor options first; keyword names are kept exactly
    # as they were spelled at this call site (including ``weigths``).
    model_options = dict(
        vocab_path=args.vocab_path,
        model_paths=args.model_path,
        max_len=args.max_len,
        min_len=args.min_len,
        iterations=args.iteration_count,
        min_error_probability=args.min_error_probability,
        lowercase_tokens=args.lowercase_tokens,
        model_name=args.transformer_model,
        special_tokens_fix=args.special_tokens_fix,
        log=False,
        confidence=args.additional_confidence,
        is_ensemble=args.is_ensemble,
        weigths=args.weights,
    )
    gec_model = GecBERTModel(**model_options)

    total_corrections = predict_for_file(
        args.input_file,
        args.output_file,
        gec_model,
        batch_size=args.batch_size,
    )
    # Scoring (m2 / ERRANT) is done outside this script; only report the count.
    print(f"Produced overall corrections: {total_corrections}")
def build_parser():
    """Create the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser: Parser with every option of the prediction
        script registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path',
                        help='Path to the model file.', nargs='+',
                        required=True)
    parser.add_argument('--vocab_path',
                        help='Path to the vocabulary.',
                        default='data/output_vocabulary'  # to use pretrained models
                        )
    parser.add_argument('--input_file',
                        help='Path to the evalset file',
                        required=True)
    parser.add_argument('--output_file',
                        help='Path to the output file',
                        required=True)
    parser.add_argument('--max_len',
                        type=int,
                        help='The max sentence length '
                             '(all longer will be truncated)',
                        default=50)
    parser.add_argument('--min_len',
                        type=int,
                        help='The minimum sentence length '
                             '(all shorter will be returned w/o changes)',
                        default=3)
    parser.add_argument('--batch_size',
                        type=int,
                        help='The number of sentences per batch.',
                        default=128)
    parser.add_argument('--lowercase_tokens',
                        type=int,
                        help='Whether to lowercase tokens.',
                        default=0)
    # Each choice needs its own comma: adjacent string literals are
    # concatenated, which previously merged 'albert' and 'bert-large'
    # into one bogus choice and made both real names unselectable.
    parser.add_argument('--transformer_model',
                        choices=['bert', 'gpt2', 'transformerxl', 'xlnet',
                                 'distilbert', 'roberta', 'albert',
                                 'bert-large', 'roberta-large', 'xlnet-large'],
                        help='Name of the transformer model.',
                        default='roberta')
    parser.add_argument('--iteration_count',
                        type=int,
                        help='The number of iterations of the model.',
                        default=5)
    parser.add_argument('--additional_confidence',
                        type=float,
                        help='How many probability to add to $KEEP token.',
                        default=0)
    parser.add_argument('--min_error_probability',
                        type=float,
                        help='Minimum probability for each action to apply. '
                             'Also, minimum error probability, as described in the paper.',
                        default=0.0)
    parser.add_argument('--special_tokens_fix',
                        type=int,
                        help='Whether to fix problem with [CLS], [SEP] tokens tokenization. '
                             'For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa.',
                        default=1)
    parser.add_argument('--is_ensemble',
                        type=int,
                        help='Whether to do ensembling.',
                        default=0)
    # Values are kept as the raw strings given on the command line.
    parser.add_argument('--weights',
                        help='Used to calculate weighted average', nargs='+',
                        default=None)
    return parser


if __name__ == '__main__':
    # read parameters
    parser = build_parser()
    args = parser.parse_args()
    main(args)