Skip to content

Commit

Permalink
remove unused
Browse files Browse the repository at this point in the history
Signed-off-by: Evelina Bakhturina <[email protected]>
  • Loading branch information
ekmb committed Jun 2, 2020
1 parent 7c943a0 commit 12256c9
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 444 deletions.
19 changes: 10 additions & 9 deletions examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,18 @@ def init_session():
return '', '', default_state()


def get_system_responce(user_uttr, system_uttr, dialog_history, state):
def get_system_response(user_uttr, system_uttr, dialog_history, state):
"""
Returns system reply by passing user utterance through TRADE Dialogue State Tracker, then the output of the TRADE model to the
Rule-base Dialogue Policy Magager and the output of the Policy Manager to the Rule-based Natural language generation module
Returns system reply by passing system and user utterances (dialogue history) through the TRADE Dialogue State Tracker,
then the output of the TRADE model goes to the Rule-base Dialogue Policy Magager
and the output of the Policy Manager goes to the Rule-based Natural language generation module
Args:
user_uttr(str): User utterance
system_uttr(str): Previous system utterance
dialog_history(str): Diaglogue history contains all previous system and user utterances
user_uttr (str): User utterance
system_uttr (str): Previous system utterance
dialog_history (str): Diaglogue history contains all previous system and user utterances
state (dict): dialogue state
Returns:
system_utterance(str): system response
system_utterance (str): system response
state (dict): updated dialogue state
"""
src_ids, src_lens = utterance_encoder.forward(state=state, user_uttr=user_uttr, sys_uttr=system_uttr)
Expand Down Expand Up @@ -177,12 +178,12 @@ def get_system_responce(user_uttr, system_uttr, dialog_history, state):
system_uttr, dialog_history, state = init_session()
logging.info("============ Starting a new dialogue ============")
else:
get_system_responce(user_uttr, system_uttr, dialog_history, state)
get_system_response(user_uttr, system_uttr, dialog_history, state)

elif args.mode == 'example':
for example in examples:
logging.info("============ Starting a new dialogue ============")
system_uttr, dialog_history, state = init_session()
for user_uttr in example:
logging.info("User utterance: %s", user_uttr)
system_uttr, state = get_system_responce(user_uttr, system_uttr, dialog_history, state)
system_uttr, state = get_system_response(user_uttr, system_uttr, dialog_history, state)
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,6 @@ def forward(self, state):
for user_act_ in user_acts:
del DA[user_act_]

# print("Sys action: ", DA)

if DA == {}:
DA = {'general-greet': [['none', 'none']]}
tuples = []
Expand Down Expand Up @@ -284,12 +282,6 @@ def _update_DA(self, user_act, user_action, state, DA):
kb_result = self.db.query(domain.lower(), constraints)
self.kb_result[domain] = deepcopy(kb_result)

# print("\tConstraint: " + "{}".format(constraints))
# print("\tCandidate Count: " + "{}".format(len(kb_result)))
# if len(kb_result) > 0:
# print("Candidate: " + "{}".format(kb_result[0]))

# print(state['user_action'])
# Respond to user's request
if intent_type == 'Request':
if self.recommend_flag > 1:
Expand All @@ -310,7 +302,6 @@ def _update_DA(self, user_act, user_action, state, DA):

else:
# There's no result matching user's constraint
# if len(state['kb_results_dict']) == 0:
if len(kb_result) == 0:
if (domain + "-NoOffer") not in DA:
DA[domain + "-NoOffer"] = []
Expand Down Expand Up @@ -468,8 +459,6 @@ def _update_train(self, user_act, user_action, state, DA):
kb_result = self.db.query('train', constraints)
self.kb_result['Train'] = deepcopy(kb_result)

# print(constraints)
# print(len(kb_result))
if user_act == 'Train-Request':
del DA['Train-Request']
if 'Train-Inform' not in DA:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

'''
This file contains code artifacts adapted from the original implementation:
https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py
https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/dst/trade/multiwoz/trade.py
https://github.com/thu-coai/ConvLab-2
'''
import copy
import re
Expand All @@ -33,76 +34,6 @@
__all__ = ['TradeStateUpdateNM']


# class UtteranceEncoderNM(NonTrainableNM):
# """
# Encodes dialogue history (system and user utterances) into a Multiwoz dataset format
# Args:
# data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots,
# and associated vocabulary
# """

# @property
# @add_port_docs()
# def input_ports(self):
# """Returns definitions of module input ports.
# state (dict): dialogue state dictionary - see nemo.collections.nlp.data.datasets.multiwoz_dataset.state
# for the format
# user_uttr (str): user utterance
# sys_uttr (str): system utterace
# """
# return {
# "state": NeuralType(axes=tuple('ANY'), element_type=VoidType()),
# "user_uttr": NeuralType(axes=tuple('ANY'), element_type=VoidType()),
# "sys_uttr": NeuralType(axes=tuple('ANY'), element_type=VoidType())
# }


# @property
# @add_port_docs()
# def output_ports(self):
# """Returns definitions of module output ports.
# src_ids (int): token ids for dialogue history
# src_lens (int): length of the tokenized dialogue history
# """
# return {
# 'src_ids': NeuralType(('B', 'T'), element_type=ChannelType()),
# 'src_lens': NeuralType(tuple('B'), elemenet_type=LengthsType()),
# }

# def __init__(self, data_desc):
# """
# Init
# Args:
# data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots,
# and associated vocabulary
# """
# super().__init__()
# self.data_desc = data_desc

# def forward(self, state, user_uttr, sys_uttr):
# """
# Returns dialogue utterances in the format accepted by the TRADE Dialogue state tracking model
# Args:
# state (dict): state dictionary - see nemo.collections.nlp.data.datasets.multiwoz_dataset.state
# for the format
# user_uttr (str): user utterance
# sys_uttr (str): system utterace
# Returns:
# src_ids (int): token ids for dialogue history
# src_lens (int): length of the tokenized dialogue history
# """
# state["history"].append(["sys", sys_uttr])
# state["history"].append(["user", user_uttr])
# state["user_action"] = user_uttr
# logging.debug("Dialogue state: %s", state)

# context = ' ; '.join([item[1].strip().lower() for item in state['history']]).strip() + ' ;'
# context_ids = self.data_desc.vocab.tokens2ids(context.split())
# src_ids = torch.tensor(context_ids).unsqueeze(0).to(self._device)
# src_lens = torch.tensor(len(context_ids)).unsqueeze(0).to(self._device)
# return src_ids, src_lens


class TradeStateUpdateNM(NonTrainableNM):
"""
Takes the predictions of the TRADE Dialogue state tracking model,
Expand Down Expand Up @@ -288,9 +219,7 @@ def normalize_value(self, value_set, domain, slot, value):
if slot not in value_set[domain]:
logging.warning('slot {} no in domain {}'.format(slot, domain))
return value
# raise Exception(
# 'slot <{}> not found in db_values[{}]'.format(
# slot, domain))

value_list = value_set[domain][slot]
# exact match or containing match
v = self._match_or_contain(value, value_list)
Expand Down
Loading

0 comments on commit 12256c9

Please sign in to comment.