-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathrasa_service.py
24 lines (17 loc) · 897 Bytes
/
rasa_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os

from rasa_nlu.model import Interpreter

from text_classification_benchmarks.api_services.api_service import ApiService
def create_import_file(train_df, classes, output_path='.'):
    """Write the training utterances as a Rasa NLU Markdown file ``nlu.md``.

    One ``## intent:<name>`` section is written per distinct value of the
    ``label`` column (in groupby's sorted key order). Each section lists that
    label's utterances as ``- <utterance>`` lines and ends with a blank line.

    Args:
        train_df: DataFrame with ``label`` and ``utterance`` columns.
        classes: Indexable by a label value; ``classes[label]`` is the intent
            name written into the section header.
        output_path: Directory in which ``nlu.md`` is created (overwritten if
            it already exists).
    """
    path = os.path.join(output_path, 'nlu.md')
    # Iterate the groups themselves instead of re-selecting rows with
    # ``.loc[group_index]``: on a DataFrame with a non-unique index that lookup
    # also returns rows that belong to other labels.
    grouped_utterances = train_df.groupby('label')['utterance']
    # Write UTF-8 explicitly so non-ASCII utterances don't depend on the
    # platform's default encoding.
    with open(path, 'w', encoding='utf-8') as f:
        for label, utterances in grouped_utterances:
            f.write('## intent:{}\n'.format(classes[label]))
            for utterance in utterances:
                # A raw line break would end the "- " bullet early and corrupt
                # the Markdown structure, so fold it into a space.
                text = str(utterance).replace('\r', ' ').replace('\n', ' ')
                f.write('- {}\n'.format(text))
            f.write('\n')
class RasaService(ApiService):
    """ApiService implementation backed by a Rasa NLU ``Interpreter``.

    The interpreter is loaded from ``model_path`` once, at construction time.
    ``classes``, ``max_api_calls`` and ``verbose`` are forwarded unchanged to
    ``ApiService``.
    """

    def __init__(self, model_path, classes, max_api_calls=None, verbose=False):
        super().__init__(classes, max_api_calls, verbose)
        # Loaded once here; every predict() call reuses this interpreter.
        self.interpreter = Interpreter.load(model_path)

    def predict(self, utterance):
        """Parse *utterance* with the interpreter and return its intent name.

        Returns the value stored under ``['intent']['name']`` in the parse
        result.
        """
        parsed = self.interpreter.parse(utterance)
        top_intent = parsed['intent']
        return top_intent['name']