# translate.py — multi-task NMT decoding script (105 lines, 83 loc).
# -*- coding: utf-8 -*-
import logging
import torch
import os
from beaver.data import build_dataset
from beaver.infer import beam_search
from beaver.model import NMTModel
from beaver.utils import parseopt, get_device, calculate_bleu
# Timestamped INFO-level logging on the root logger.
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
# Command-line options for decoding (input/output paths, checkpoint range, beam options).
opt = parseopt.parse_translate_args()
# Device for model and tensors — presumably GPU when available; see beaver.utils.get_device.
device = get_device()
def translate(dataset, fields, model, index):
    """Decode the task-3 stream of *dataset* with *model* and write hypotheses to disk.

    Output goes to ``opt.output[2] + "." + str(index)``, one hypothesis per
    line, in decoding order.

    NOTE(review): the original carried disabled task-1 / task-2 decoding and a
    post-hoc sort by ``task3_dataset.seed`` inside a no-op triple-quoted
    string; that dead code executed nothing and has been removed (it remains
    in version-control history). Only batches whose flag == 3 are decoded;
    every other batch is skipped.

    Args:
        dataset: iterable yielding ``(batch, flag)`` pairs; exposes
            ``task3_dataset.num_examples`` for progress reporting.
        fields: field dict; ``fields["task3_tgt"].decode`` turns a prediction
            into text.
        model: the loaded NMT model, already on the target device and in
            eval mode.
        index: checkpoint step number, used as the output-file suffix.
    """
    already_3, hypothesis_3 = 0, []

    for batch, flag in dataset:
        # Only the task-3 (classification) stream is decoded here; other
        # flags fall through intentionally.
        if flag == 3:
            predictions = beam_search(opt, model, batch.src, fields, 3)
            hypothesis_3 += [fields["task3_tgt"].decode(p) for p in predictions]
            already_3 += len(predictions)
            logging.info("Task 3: %7d/%7d" % (already_3, dataset.task3_dataset.num_examples))

    with open(opt.output[2] + "." + str(index), "w", encoding="UTF-8") as out_file:
        out_file.write("\n".join(hypothesis_3))
        out_file.write("\n")

    logging.info("Translation finished. ")
def main():
    """Build the dataset once, then decode every checkpoint in the configured range.

    Checkpoints are expected at ``opt.model_path + str(step)`` for
    ``step = opt.start, opt.start + 5000, ...`` up to (excluding) ``opt.end``.
    Missing checkpoints are logged and skipped, so decoding can run while
    training is still producing them.
    """
    logging.info("Build dataset...")
    # Six parallel input files (source/target pairs for the three tasks);
    # train=False builds the inference-side dataset.
    dataset = build_dataset(opt, list(opt.input[:6]), opt.vocab, device, train=False)
    fields = dataset.fields

    # Padding ids and vocab sizes are required to rebuild the model exactly
    # as it was trained.
    task_names = ("src", "task1_tgt", "task2_tgt", "task3_tgt")
    pad_ids = {name: fields[name].pad_id for name in task_names}
    vocab_sizes = {name: len(fields[name].vocab) for name in task_names}

    # load checkpoint from model_path
    logging.info("decoding range %s - %s." % (opt.start, opt.end))
    for index in range(int(opt.start), int(opt.end), 5000):
        model_path = opt.model_path + str(index)
        logging.info("decoding checkpoint %s" % (model_path))
        if not os.path.exists(model_path):
            # Skip (and record) checkpoints that have not been written yet.
            logging.info("Checkpoint %s not found, skipping." % model_path)
            continue
        logging.info("Load checkpoint from %s." % model_path)
        # map_location keeps tensors on CPU during load; the model is moved
        # to the target device afterwards.
        checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
        logging.info("Build model...")
        model = NMTModel.load_model(checkpoint["opt"], pad_ids, vocab_sizes,
                                    checkpoint["model"]).to(device).eval()
        logging.info("Start translation...")
        # no_grad() is the idiomatic inference guard (was set_grad_enabled(False)).
        with torch.no_grad():
            translate(dataset, fields, model, index)
# Script entry point.
if __name__ == '__main__':
    main()