forked from tech-srl/code2seq
-
Notifications
You must be signed in to change notification settings - Fork 17
/
interactive_predict.py
60 lines (53 loc) · 2.67 KB
/
interactive_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from common import Common
from java_extractor import JavaExtractor
from cpp_extractor import CppExtractor
# Passed as ``topk`` to Common.parse_results when predictions are formatted;
# presumably caps how many contexts are reported per prediction - confirm in Common.
SHOW_TOP_CONTEXTS = 10
# NOTE(review): not referenced in the visible code; the Java extractor is given
# config.MAX_PATH_LENGTH instead - confirm whether this constant is still needed.
MAX_PATH_LENGTH = 8
# NOTE(review): not referenced in the visible code; the Java extractor is given
# a literal max_path_width=2 - confirm whether this constant is still needed.
MAX_PATH_WIDTH = 2
# URL handed to JavaExtractor as its second positional argument.
EXTRACTION_API = 'https://po3g2dx2qa.execute-api.us-east-1.amazonaws.com/production/extractmethods'
class InteractivePredictor:
    """Console loop that predicts method names for code in a local file.

    Every round re-reads ``Input.source`` from the current working directory,
    extracts path-contexts with the language-specific extractor, feeds them
    to the model and prints the resulting predictions.
    """

    # Inputs (compared case-insensitively) that end the interactive session.
    exit_keywords = ['exit', 'quit', 'q']

    def __init__(self, config, model, language):
        """Store the model and config and build the extractor for *language*.

        Args:
            config: Configuration object. ``MAX_PATH_LENGTH`` is read here and
                ``BEAM_WIDTH`` is read when predictions are printed.
            model: Object whose ``predict(lines)`` method produces the raw
                results handed to ``Common.parse_results``.
            language: ``'java'`` or ``'cpp'``.

        Raises:
            AssertionError: If *language* is neither ``'java'`` nor ``'cpp'``.
        """
        self.model = model
        self.config = config
        if language == 'java':
            self.path_extractor = JavaExtractor(config, EXTRACTION_API, self.config.MAX_PATH_LENGTH, max_path_width=2)
        elif language == 'cpp':
            self.path_extractor = CppExtractor(config)
        else:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            raise AssertionError('Unsupported language model')

    @staticmethod
    def read_file(input_filename):
        """Return the lines of *input_filename* as a list, line endings kept."""
        with open(input_filename, 'r') as file:
            return file.readlines()

    def predict(self):
        """Prompt, extract, predict and print until an exit keyword is entered.

        Extraction failures (``ValueError`` from the extractor) are reported
        on stdout and the loop prompts again instead of terminating.
        """
        input_filename = 'Input.source'
        print('Serving')
        while True:
            print('Modify the file: "' + input_filename + '" and press any key when ready, or "q" / "exit" to exit')
            user_input = input()
            # Strip so that stray whitespace around the keyword still exits.
            if user_input.strip().lower() in self.exit_keywords:
                print('Exiting...')
                return
            source_code = ' '.join(self.read_file(input_filename))
            try:
                predict_lines, pc_info_dict = self.path_extractor.extract_paths(source_code)
            except ValueError as error:
                # Tell the user why nothing was predicted before re-prompting.
                print('Could not extract paths from "%s": %s' % (input_filename, error))
                continue
            model_results = self.model.predict(predict_lines)
            prediction_results = Common.parse_results(model_results, pc_info_dict, topk=SHOW_TOP_CONTEXTS)
            for _, method_prediction in prediction_results.items():
                self._print_method_prediction(method_prediction)

    def _print_method_prediction(self, method_prediction):
        """Print the original name and the prediction(s) for a single method."""
        print('Original name:\t' + method_prediction.original_name)
        if self.config.BEAM_WIDTH == 0:
            # One prediction per timestep, each with its attention paths.
            print('Predicted:\t%s' % [step.prediction for step in method_prediction.predictions])
            for timestep, single_timestep_prediction in enumerate(method_prediction.predictions):
                print('Attention:')
                print('TIMESTEP: %d\t: %s' % (timestep, single_timestep_prediction.prediction))
                for attention_obj in single_timestep_prediction.attention_paths:
                    print('%f\tcontext: %s,%s,%s' % (
                        attention_obj['score'], attention_obj['token1'], attention_obj['path'],
                        attention_obj['token2']))
        else:
            # With a non-zero beam width only the candidate sequences are listed.
            print('Predicted:')
            for predicted_seq in method_prediction.predictions:
                print('\t%s' % predicted_seq.prediction)