Skip to content

Commit

Permalink
UtteranceEncoder neural types wip
Browse files Browse the repository at this point in the history
Signed-off-by: nvidia <[email protected]>
  • Loading branch information
tkornuta-nvidia committed Jun 4, 2020
1 parent d913c65 commit 6f3980b
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,9 @@ def divideData(data, infold, outfold):
with open(f'{outfold}/train_dials.json', 'w') as f:
json.dump(train_dials, f, indent=4)

print(f"Saving done. Generated dialogs: {count_train} train, {count_val} val, {count_test} test.")
print(
f"Processing done and saved in `{outfold}`. Generated dialogs: {count_train} train, {count_val} val, {count_test} test."
)


if __name__ == "__main__":
Expand All @@ -508,7 +510,7 @@ def divideData(data, infold, outfold):
if_exist(abs_target_data_dir, ['ontology.json', 'dev_dials.json', 'test_dials.json', 'train_dials.json', 'db'])
and not args.overwrite_files
):
print(f'Data is already processed and stored at {abs_source_data_dir}, skipping pre-processing.')
print(f'Data is already processed and stored at {abs_target_data_dir}, skipping pre-processing.')
exit(0)

fin = open('multiwoz_mapping.pair', 'r')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

# Examples: two "separate" dialogs (one single-turn, one multiple-turn).
examples = [
#["I want to find a moderate hotel with internet and parking in the east"],
# ["I want to find a moderate hotel with internet and parking in the east"],
[
"Is there a train from Ely to Cambridge on Tuesday ?",
"I need to arrive by 11 am .",
Expand Down Expand Up @@ -202,8 +202,8 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
)

# Set evaluation mode - for trainable modules.
#trade_encoder.eval()
#trade_decoder.eval()
# trade_encoder.eval()
# trade_decoder.eval()

# "Execute" the graph - depending on the mode.
if args.mode == 'interactive':
Expand All @@ -222,7 +222,7 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
logging.info("============ Starting a new dialogue ============")
else:
# Pass the "user utterance" as inputs to the dialog pipeline.
system_uttr, belief_state, dial_history = forward_(
system_uttr, belief_state, dial_history = forward(
dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def output_ports(self):
"""Returns definitions of module output ports.
"""
return {
'belief_state': NeuralType(axes=tuple(AxisType(kind=AxisKind.Time, is_list=True)), elements_type=StringType()),
'belief_state': NeuralType(
axes=tuple(AxisType(kind=AxisKind.Time, is_list=True)), elements_type=StringType()
),
'request_state': NeuralType(axes=tuple(AxisType(kind=AxisKind.Time)), elements_type=StringType()),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch

from nemo.backends.pytorch.nm import NonTrainableNM
from nemo.core.neural_types import AxisKind, AxisType, ChannelType, LengthsType, NeuralType, StringType
from nemo.core.neural_types import AgentUtterance, AxisKind, AxisType, ChannelType, LengthsType, NeuralType, StringType
from nemo.utils import logging
from nemo.utils.decorators import add_port_docs

Expand All @@ -46,26 +46,17 @@ def input_ports(self):
"""
return {
'dial_history': NeuralType(
axes=(
AxisType(kind=AxisKind.Batch),
AxisType(kind=AxisKind.Channel, size=3),
AxisType(kind=AxisKind.Height, size=224),
AxisType(kind=AxisKind.Width, size=224),
),
axes=(AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True),),
elements_type=AgentUtterance(),
),
'user_uttr': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True)],
elements_type=StringType(),
),
'user_uttr': NeuralType(axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(kind=AxisKind.Time, is_list=True)
],
elements_type=StringType()),
),
'sys_uttr': NeuralType(
axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(kind=AxisKind.Time)
],
elements_type=StringType()
),
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time)],
elements_type=StringType(),
),
}

@property
Expand All @@ -79,11 +70,10 @@ def output_ports(self):
return {
'src_ids': NeuralType(('B', 'T'), elements_type=ChannelType()),
'src_lens': NeuralType(tuple('B'), elements_type=LengthsType()),
'dial_history': NeuralType(axes=(
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(kind=AxisKind.Time, is_list=True),
),
elements_type=StringType()),
'dial_history': NeuralType(
axes=(AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True),),
elements_type=StringType(),
),
}

def __init__(self, data_desc):
Expand Down
5 changes: 4 additions & 1 deletion nemo/core/neural_types/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
'NormalizedImageValue',
'StringLabel',
'StringType',
'AgentUtterance',
]

import abc
Expand Down Expand Up @@ -250,7 +251,9 @@ class StringType(ElementType):

class AgentUtterance(ElementType):
    """Element type representing utterance returned by an agent (user or system) participating in a dialog."""

    def __str__(self):
        """Return the human-readable description of this element type."""
        return "Utterance returned by an agent (user or system) participating in a dialog."

    def fields(self):
        """Return the field names of this element type as the tuple ("Agent", "Utterance")."""
        return ("Agent", "Utterance")

0 comments on commit 6f3980b

Please sign in to comment.