From 7b9a89900808310c8eb4216a28aed2b2e4d987ba Mon Sep 17 00:00:00 2001 From: nvidia Date: Thu, 4 Jun 2020 11:51:03 -0700 Subject: [PATCH] fightihg with belief state Signed-off-by: nvidia --- .../rule_based_policy_multiwoz.py | 2 +- .../dialogue_state_tracking/nlg_multiwoz.py | 37 ++++++----- .../rule_based_multiwoz_bot.py | 66 +++++++++++++------ .../trade_state_update_nm.py | 11 ++-- nemo/core/neural_types/elements.py | 4 -- nemo/core/neural_types/neural_type.py | 2 +- 6 files changed, 75 insertions(+), 47 deletions(-) diff --git a/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py b/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py index b23d3dd2f923..4ba0d7b4f3f1 100644 --- a/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py +++ b/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py @@ -101,7 +101,7 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state) ) # 2. Forward pass throught Dialog Policy Manager module (Rule-Based, queries a "simple DB" to get required data). - system_acts, belief_state = dialog_pipeline.modules[dialog_pipeline.steps[4]].forward( + belief_state, system_acts = dialog_pipeline.modules[dialog_pipeline.steps[4]].forward( belief_state=belief_state, request_state=request_state ) diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/nlg_multiwoz.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/nlg_multiwoz.py index 2c3c077daf39..b4b53196e7e0 100644 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/nlg_multiwoz.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/nlg_multiwoz.py @@ -68,23 +68,7 @@ class TemplateNLGMultiWOZNM(NonTrainableNM): """Generate a natural language utterance conditioned on the dialog act. """ - @property - @add_port_docs() - def input_ports(self): - """Returns definitions of module input ports. - system_acts (list): list of system actions action produced by dialog policy module - """ - return {"system_acts": NeuralType(axes=tuple('ANY'), elements_type=VoidType())} - - @property - @add_port_docs() - def output_ports(self): - """Returns definitions of module output ports. - system_uttr (str): generated system's response - """ - return {"system_uttr": NeuralType(axes=(AxisType(kind=AxisKind.Time)), elements_type=StringType())} - - def __init__(self, mode="auto_manual"): + def __init__(self, mode="auto_manual", name=None): """ Initializes the object Args: @@ -94,6 +78,9 @@ def __init__(self, mode="auto_manual"): - `auto_manual`: use auto templates first. When fails, use manual templates. both template are dict, *_template[dialog_act][slot] is a list of templates. """ + # Call base class constructor. + NonTrainableNM.__init__(self, name=name) + self.mode = mode template_dir = os.path.dirname(os.path.abspath(__file__)) @@ -104,6 +91,22 @@ def read_json(filename): self.auto_system_template = read_json(os.path.join(template_dir, 'auto_system_template_nlg.json')) self.manual_system_template = read_json(os.path.join(template_dir, 'manual_system_template_nlg.json')) + @property + @add_port_docs() + def input_ports(self): + """Returns definitions of module input ports. + system_acts (list): list of system actions action produced by dialog policy module + """ + return {"system_acts": NeuralType(axes=tuple('ANY'), elements_type=VoidType())} + + @property + @add_port_docs() + def output_ports(self): + """Returns definitions of module output ports. + system_uttr (str): generated system's response + """ + return {"system_uttr": NeuralType(axes=(AxisType(kind=AxisKind.Time)), elements_type=StringType())} + def forward(self, system_acts): """ Generates system response diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_multiwoz_bot.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_multiwoz_bot.py index 83ec4c8e6943..5b3a34020193 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_multiwoz_bot.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_multiwoz_bot.py @@ -32,7 +32,7 @@ from nemo.backends.pytorch.nm import NonTrainableNM from nemo.collections.nlp.data.datasets.multiwoz_dataset.dbquery import Database from nemo.collections.nlp.data.datasets.multiwoz_dataset.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA -from nemo.core import AxisKind, AxisType, NeuralType, StringType, VoidType +from nemo.core.neural_types import * from nemo.utils.decorators import add_port_docs __all__ = ['RuleBasedMultiwozBotNM'] @@ -93,14 +93,42 @@ class RuleBasedMultiwozBotNM(NonTrainableNM): Rule-based bot. Implemented for Multiwoz dataset. """ + def __init__(self, data_dir: str, name: str = None): + """ + Initializes the object + Args: + data_dir (str): path to data directory + name: name of the modules (DEFAULT: none) + """ + # Call base class constructor. + NonTrainableNM.__init__(self, name=name) + # Set init values of attributes. + self.last_state = {} + self.db = Database(data_dir) + self.last_request_state = {} + self.last_belief_state = {} + self.recommend_flag = -1 + self.choice = "" + @property @add_port_docs() def input_ports(self): """Returns definitions of module input ports. """ return { - 'belief_state': NeuralType(axes=(AxisType(kind=AxisKind.Time, is_list=True)), elements_type=StringType()), - 'request_state': NeuralType(axes=(AxisType(kind=AxisKind.Time)), elements_type=StringType()), + 'belief_state': NeuralType( + axes=[ + AxisType(kind=AxisKind.Batch, is_list=True), + AxisType( + kind=AxisKind.MultiWOZDomain, is_list=True + ), # always 7 domains - but cannot set size with is_list! + ], + elements_type=Length(), + ), + 'request_state': NeuralType( + axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)], + elements_type=StringType(), + ), } @property @@ -111,23 +139,21 @@ def output_ports(self): belief_state (dict): dialogue state with slot-slot_values pairs for all domains """ return { - 'system_acts': NeuralType(axes=tuple('ANY'), elements_type=VoidType()), - 'belief_state': NeuralType(axes=(AxisType(kind=AxisKind.Time, is_list=True)), elements_type=StringType()), + 'belief_state': NeuralType( + axes=[ + AxisType(kind=AxisKind.Batch, is_list=True), + AxisType( + kind=AxisKind.MultiWOZDomain, is_list=True + ), # always 7 domains - but cannot set size with is_list! + ], + elements_type=Length(), + ), + 'system_acts': NeuralType( + axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)], + elements_type=StringType(), + ), } - def __init__(self, data_dir): - """ - Initializes the object - Args: - data_dir (str): path to data directory - """ - self.last_state = {} - self.db = Database(data_dir) - self.last_request_state = {} - self.last_belief_state = {} - self.recommend_flag = -1 - self.choice = "" - def forward(self, belief_state, request_state): """ Generated System Act and add it to the belief state @@ -135,8 +161,8 @@ def forward(self, belief_state, request_state): belief_state (dict): dialogue state with slot-slot_values pairs for all domains request_state (dict): requested slots dict Returns: - system_acts (list): DA(Dialog Act), in the form of {act_type1: [[slot_name_1, value_1], [slot_name_2, value_2], ...], ...} belief_state (dict): updated belief state + system_acts (list): DA(Dialog Act), in the form of {act_type1: [[slot_name_1, value_1], [slot_name_2, value_2], ...], ...} """ if self.recommend_flag != -1: @@ -198,7 +224,7 @@ def forward(self, belief_state, request_state): logging.debug("DPM output: %s", system_acts) logging.debug("Belief State after DPM: %s", belief_state) logging.debug("Request State after DPM: %s", request_state) - return system_acts, belief_state + return belief_state, system_acts def _update_greeting(self, user_act, DA): """ General request / inform. """ diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py index a780f607a5b1..07ed0fd13e40 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py @@ -54,7 +54,7 @@ def input_ports(self): AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.MultiWOZDomain, is_list=True), # 7 domains ], - elements_type=MultiWOZSlotValue(), + elements_type=Length(), ), 'user_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()), } @@ -66,11 +66,14 @@ def output_ports(self): """ return { 'belief_state': NeuralType( - axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True)], - elements_type=StringType(), + axes=[ + AxisType(kind=AxisKind.Batch, is_list=True), + AxisType(kind=AxisKind.MultiWOZDomain, is_list=True), # 7 domains + ], + elements_type=Length(), ), 'request_state': NeuralType( - axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True)], + axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)], elements_type=StringType(), ), } diff --git a/nemo/core/neural_types/elements.py b/nemo/core/neural_types/elements.py index 4ab261b67a16..909044ecd00c 100644 --- a/nemo/core/neural_types/elements.py +++ b/nemo/core/neural_types/elements.py @@ -280,10 +280,6 @@ class Length(IntType): """Type representing an element storing a "length" (e.g. length of a list).""" -class Length(IntType): - """Type representing an element storing a "length" (e.g. length of a list).""" - - class SlotValue(ElementType): """Element type representing slot-value pair.""" diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index 29634b2d2595..03c5273715d7 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -121,7 +121,7 @@ def compare_and_raise_error(self, parent_type_name, port_name, second_object): and type_comatibility != NeuralTypeComparisonResult.GREATER ): raise NeuralPortNmTensorMismatchError( - parent_type_name, port_name, str(self), str(second_object), type_comatibility + parent_type_name, port_name, str(self), str(second_object.ntype), type_comatibility ) @staticmethod