Skip to content

Commit

Permalink
fightihg with belief state
Browse files Browse the repository at this point in the history
Signed-off-by: nvidia <[email protected]>
  • Loading branch information
tkornuta-nvidia committed Jun 4, 2020
1 parent f822ec7 commit 7b9a899
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
)

# 2. Forward pass throught Dialog Policy Manager module (Rule-Based, queries a "simple DB" to get required data).
system_acts, belief_state = dialog_pipeline.modules[dialog_pipeline.steps[4]].forward(
belief_state, system_acts = dialog_pipeline.modules[dialog_pipeline.steps[4]].forward(
belief_state=belief_state, request_state=request_state
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,7 @@ class TemplateNLGMultiWOZNM(NonTrainableNM):
"""Generate a natural language utterance conditioned on the dialog act.
"""

@property
@add_port_docs()
def input_ports(self):
"""Returns definitions of module input ports.
system_acts (list): list of system actions action produced by dialog policy module
"""
return {"system_acts": NeuralType(axes=tuple('ANY'), elements_type=VoidType())}

@property
@add_port_docs()
def output_ports(self):
"""Returns definitions of module output ports.
system_uttr (str): generated system's response
"""
return {"system_uttr": NeuralType(axes=(AxisType(kind=AxisKind.Time)), elements_type=StringType())}

def __init__(self, mode="auto_manual"):
def __init__(self, mode="auto_manual", name=None):
"""
Initializes the object
Args:
Expand All @@ -94,6 +78,9 @@ def __init__(self, mode="auto_manual"):
- `auto_manual`: use auto templates first. When fails, use manual templates.
both template are dict, *_template[dialog_act][slot] is a list of templates.
"""
# Call base class constructor.
NonTrainableNM.__init__(self, name=name)

self.mode = mode
template_dir = os.path.dirname(os.path.abspath(__file__))

Expand All @@ -104,6 +91,22 @@ def read_json(filename):
self.auto_system_template = read_json(os.path.join(template_dir, 'auto_system_template_nlg.json'))
self.manual_system_template = read_json(os.path.join(template_dir, 'manual_system_template_nlg.json'))

@property
@add_port_docs()
def input_ports(self):
"""Returns definitions of module input ports.
system_acts (list): list of system actions action produced by dialog policy module
"""
return {"system_acts": NeuralType(axes=tuple('ANY'), elements_type=VoidType())}

@property
@add_port_docs()
def output_ports(self):
"""Returns definitions of module output ports.
system_uttr (str): generated system's response
"""
return {"system_uttr": NeuralType(axes=(AxisType(kind=AxisKind.Time)), elements_type=StringType())}

def forward(self, system_acts):
"""
Generates system response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from nemo.backends.pytorch.nm import NonTrainableNM
from nemo.collections.nlp.data.datasets.multiwoz_dataset.dbquery import Database
from nemo.collections.nlp.data.datasets.multiwoz_dataset.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
from nemo.core import AxisKind, AxisType, NeuralType, StringType, VoidType
from nemo.core.neural_types import *
from nemo.utils.decorators import add_port_docs

__all__ = ['RuleBasedMultiwozBotNM']
Expand Down Expand Up @@ -93,14 +93,42 @@ class RuleBasedMultiwozBotNM(NonTrainableNM):
Rule-based bot. Implemented for Multiwoz dataset.
"""

def __init__(self, data_dir: str, name: str = None):
"""
Initializes the object
Args:
data_dir (str): path to data directory
name: name of the modules (DEFAULT: none)
"""
# Call base class constructor.
NonTrainableNM.__init__(self, name=name)
# Set init values of attributes.
self.last_state = {}
self.db = Database(data_dir)
self.last_request_state = {}
self.last_belief_state = {}
self.recommend_flag = -1
self.choice = ""

@property
@add_port_docs()
def input_ports(self):
"""Returns definitions of module input ports.
"""
return {
'belief_state': NeuralType(axes=(AxisType(kind=AxisKind.Time, is_list=True)), elements_type=StringType()),
'request_state': NeuralType(axes=(AxisType(kind=AxisKind.Time)), elements_type=StringType()),
'belief_state': NeuralType(
axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(
kind=AxisKind.MultiWOZDomain, is_list=True
), # always 7 domains - but cannot set size with is_list!
],
elements_type=Length(),
),
'request_state': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)],
elements_type=StringType(),
),
}

@property
Expand All @@ -111,32 +139,30 @@ def output_ports(self):
belief_state (dict): dialogue state with slot-slot_values pairs for all domains
"""
return {
'system_acts': NeuralType(axes=tuple('ANY'), elements_type=VoidType()),
'belief_state': NeuralType(axes=(AxisType(kind=AxisKind.Time, is_list=True)), elements_type=StringType()),
'belief_state': NeuralType(
axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(
kind=AxisKind.MultiWOZDomain, is_list=True
), # always 7 domains - but cannot set size with is_list!
],
elements_type=Length(),
),
'system_acts': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)],
elements_type=StringType(),
),
}

def __init__(self, data_dir):
"""
Initializes the object
Args:
data_dir (str): path to data directory
"""
self.last_state = {}
self.db = Database(data_dir)
self.last_request_state = {}
self.last_belief_state = {}
self.recommend_flag = -1
self.choice = ""

def forward(self, belief_state, request_state):
"""
Generated System Act and add it to the belief state
Args:
belief_state (dict): dialogue state with slot-slot_values pairs for all domains
request_state (dict): requested slots dict
Returns:
system_acts (list): DA(Dialog Act), in the form of {act_type1: [[slot_name_1, value_1], [slot_name_2, value_2], ...], ...}
belief_state (dict): updated belief state
system_acts (list): DA(Dialog Act), in the form of {act_type1: [[slot_name_1, value_1], [slot_name_2, value_2], ...], ...}
"""

if self.recommend_flag != -1:
Expand Down Expand Up @@ -198,7 +224,7 @@ def forward(self, belief_state, request_state):
logging.debug("DPM output: %s", system_acts)
logging.debug("Belief State after DPM: %s", belief_state)
logging.debug("Request State after DPM: %s", request_state)
return system_acts, belief_state
return belief_state, system_acts

def _update_greeting(self, user_act, DA):
""" General request / inform. """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def input_ports(self):
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(kind=AxisKind.MultiWOZDomain, is_list=True), # 7 domains
],
elements_type=MultiWOZSlotValue(),
elements_type=Length(),
),
'user_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()),
}
Expand All @@ -66,11 +66,14 @@ def output_ports(self):
"""
return {
'belief_state': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True)],
elements_type=StringType(),
axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(kind=AxisKind.MultiWOZDomain, is_list=True), # 7 domains
],
elements_type=Length(),
),
'request_state': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True)],
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)],
elements_type=StringType(),
),
}
Expand Down
4 changes: 0 additions & 4 deletions nemo/core/neural_types/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,6 @@ class Length(IntType):
"""Type representing an element storing a "length" (e.g. length of a list)."""


class Length(IntType):
"""Type representing an element storing a "length" (e.g. length of a list)."""


class SlotValue(ElementType):
"""Element type representing slot-value pair."""

Expand Down
2 changes: 1 addition & 1 deletion nemo/core/neural_types/neural_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def compare_and_raise_error(self, parent_type_name, port_name, second_object):
and type_comatibility != NeuralTypeComparisonResult.GREATER
):
raise NeuralPortNmTensorMismatchError(
parent_type_name, port_name, str(self), str(second_object), type_comatibility
parent_type_name, port_name, str(self), str(second_object.ntype), type_comatibility
)

@staticmethod
Expand Down

0 comments on commit 7b9a899

Please sign in to comment.