
Policy Manager and NLG modules for MultiWOZ #691

Merged · 39 commits · Jun 5, 2020

Commits
5639eeb
pm+nlg for multiwoz init
ekmb May 29, 2020
a3385fb
pipeline is working, init clean up
ekmb May 29, 2020
329402a
headers added
ekmb May 29, 2020
e2460c6
fixed invalid .json file, added db files to multiwoz preprocessing
ekmb May 29, 2020
d2aaeda
code clean up
ekmb Jun 1, 2020
90b885a
lgtm fixes
ekmb Jun 1, 2020
4538da5
docs for TRADE update, jenkins for ruled_based example
ekmb Jun 1, 2020
aeb46a3
jenkins fix
ekmb Jun 1, 2020
d899791
ports refactor wip
ekmb Jun 2, 2020
5ff2d2c
ports refactor wip
ekmb Jun 2, 2020
97c9dce
wip works
ekmb Jun 2, 2020
7c943a0
neural types refactored
ekmb Jun 2, 2020
12256c9
remove unused
ekmb Jun 2, 2020
7964c7e
lgtm fixes
ekmb Jun 2, 2020
7c791ad
typo
ekmb Jun 2, 2020
73120bc
state dict splited
ekmb Jun 3, 2020
307b942
merge
ekmb Jun 3, 2020
b752fec
lgtm fixes
ekmb Jun 3, 2020
315bd79
fixing the process script, moved multiwoz_mapping.pair to multiwoz, e…
tkornuta-nvidia Jun 3, 2020
800b771
formatting fix
tkornuta-nvidia Jun 3, 2020
fa33d33
reformatted the code, ready for definition of NG by connecting the mo…
tkornuta-nvidia Jun 4, 2020
d913c65
work in progress-ess, not working, internet issues
tkornuta-nvidia Jun 4, 2020
6f3980b
UtteranceEncoder neural types wip
tkornuta-nvidia Jun 4, 2020
0fc5372
utterance encoder neural types
tkornuta-nvidia Jun 4, 2020
0afa1bc
updating trade outputs
tkornuta-nvidia Jun 4, 2020
f822ec7
updating trade outputs
tkornuta-nvidia Jun 4, 2020
7b9a899
fightihg with belief state
tkornuta-nvidia Jun 4, 2020
a314765
Cannot make second named tuple work
tkornuta-nvidia Jun 4, 2020
7de4074
reorganized files, whole pipeline handshaking works
tkornuta-nvidia Jun 4, 2020
82c274c
reorganized files, whole pipeline handshaking works
tkornuta-nvidia Jun 4, 2020
8a8b7d9
polish
tkornuta-nvidia Jun 4, 2020
f1c3855
Fix of my dummy error
tkornuta-nvidia Jun 4, 2020
2f5b4a7
new examples
tkornuta-nvidia Jun 4, 2020
db3d410
style fix
ekmb Jun 4, 2020
f9f1d45
fixed TRADE training
ekmb Jun 4, 2020
9d97337
Added module responsible for sys uttr dialog history update
tkornuta-nvidia Jun 5, 2020
3bd914e
LGTM fix
tkornuta-nvidia Jun 5, 2020
28b84b3
moved dialog specific axesc andctypes to nlp/neural_types.py, refacto…
tkornuta-nvidia Jun 5, 2020
c802bdc
style fix
tkornuta-nvidia Jun 5, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -82,6 +82,8 @@ To release a new version, please update the changelog as followed:
- Added documentation for 8 kHz model ([PR #632](https://github.com/NVIDIA/NeMo/pull/632)) - @jbalam-nv
- The Neural Graph is a high-level abstract concept empowering the users to build graphs consisting of many, interconnected Neural Modules. A user in his/her application can build any number of graphs, potentially spanning over the same modules. The import/export options combined with the lightweight API make Neural Graphs a perfect tool for rapid prototyping and experimentation. ([PR #413](https://github.com/NVIDIA/NeMo/pull/413)) - @tkornuta-nvidia
- Created the NeMo CV collection, added MNIST and CIFAR10 thin datalayers, implemented/ported several general usage trainable and non-trainable modules, added several new ElementTypes ([PR #654](https://github.com/NVIDIA/NeMo/pull/654)) - @tkornuta-nvidia
- Added SGD dataset and SGD model baseline ([PR #612](https://github.com/NVIDIA/NeMo/pull/612)) - @ekmb
- Policy Manager and Natural Language Generation Modules for MultiWOZ added ([PR #691](https://github.com/NVIDIA/NeMo/pull/691)) - @ekmb


### Changed
7 changes: 7 additions & 0 deletions Jenkinsfile
@@ -164,6 +164,13 @@ pipeline {
sh 'rm -rf examples/nlp/glue_benchmark/glue_output'
}
}
stage ('TRADE-Rule-based-DPM/NLG') {
steps {
sh 'cd examples/nlp/dialogue_state_tracking && python rule_based_policy_multiwoz.py --data_dir /home/TestData/nlp/multiwoz2.1/pm_nlg \
--encoder_ckpt /home/TestData/nlp/multiwoz2.1/pm_nlg/ckpts/EncoderRNN-EPOCH-10.pt \
--decoder_ckpt /home/TestData/nlp/multiwoz2.1/pm_nlg/ckpts/TRADEGenerator-EPOCH-10.pt'
}
}
}
}

27 changes: 27 additions & 0 deletions docs/sources/source/nlp/dialogue_state_tracking.rst
@@ -256,6 +256,33 @@ You may find the checkpoints for the trained models on MultiWOZ 2.0 and MultiWOZ
predict special values for like **don't care** or **none** for the slots. The `process_multiwoz.py`_ script extracts the additional labels from the dataset and `dialogue_state_tracking_trade.py`_ script reports the **Gating Accuracy** as well.


Complete Dialogue Pipeline with TRADE for MultiWOZ
--------------------------------------------------

The pre-trained TRADE model, as mentioned above, is a Dialogue State Tracker (DST) responsible for extracting the correct slot-value pairs from the dialogue history. Once the system has this information,
the next module, called the Dialogue Policy Manager (DPM), comes into play.
This module determines what actions the system should take, given the dialogue state passed from the DST module. For example, the DPM can request additional information from the user or
inform the user about possible options for fulfilling the user's original intent.
With the output of the DPM, the final dialogue module, called Natural Language Generation (NLG), generates the system's response to the user's utterance.

NeMo provides Rule-based DPM and Rule-based NLG modules (source: `ConvLab-2: An Open-Source Toolkit for Building, Evaluating, and Diagnosing Dialogue Systems <https://github.com/thu-coai/ConvLab-2>`_) \
to complete the dialogue pipeline based on the TRADE model and MultiWOZ dataset.
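The DST-to-DPM-to-NLG hand-off described above can be sketched in a few lines of plain Python. The classes below are hypothetical stand-ins for illustration only, not the NeMo or ConvLab-2 API:

```python
# Minimal sketch of one dialogue turn flowing through the three modules.
# RuleBasedDPM and TemplateNLG are made-up stand-ins, not the NeMo API.

class RuleBasedDPM:
    """Maps a belief state (as produced by the DST) to a system action."""

    def decide(self, belief_state):
        # If a required slot is still unfilled, request it from the user;
        # otherwise inform the user based on the tracked state.
        for slot, value in belief_state.items():
            if value is None:
                return ("request", slot)
        return ("inform", belief_state)


class TemplateNLG:
    """Turns a system action into a natural-language response."""

    def generate(self, action):
        act, payload = action
        if act == "request":
            return f"What {payload} would you like?"
        return f"Here is what I found for {payload}."


# Toy belief state standing in for the output of a DST such as TRADE.
belief_state = {"restaurant-food": "italian", "restaurant-area": None}
dpm, nlg = RuleBasedDPM(), TemplateNLG()
response = nlg.generate(dpm.decide(belief_state))
print(response)  # asks the user for the missing 'restaurant-area' slot
```

A real rule-based DPM consults the database and dialogue acts rather than a flat dict, but the control flow — state in, action out, response generated from the action — is the same.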

To evaluate TRADE's model output and its role in the complete dialogue pipeline, use ``examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py``.
Before running this script, make sure to download the pre-trained TRADE model checkpoint following the steps above.

.. code-block:: bash

    cd examples/nlp/dialogue_state_tracking
    python rule_based_policy_multiwoz.py \
        --data_dir <path to the data> \
        --encoder_ckpt <path to checkpoint folder>/EncoderRNN.pt \
        --decoder_ckpt <path to checkpoint folder>/TRADEGenerator.pt \
        --mode example
Use ``--mode interactive`` to chat with the system, and ``--hide_output`` to hide the intermediate output of the dialogue modules.


References
----------

@@ -48,14 +48,17 @@
This script can be used to process and import MultiWOZ 2.0 and 2.1 datasets.
You may find more information on how to use this example in NeMo's documentation:
https://nvidia.github.io/NeMo/nlp/dialogue_state_tracking_trade.html

This file contains code artifacts adapted from the original implementation:
https://github.com/thu-coai/ConvLab-2
"""

import argparse
import json
import os
import re
import shutil
from os.path import exists
from os.path import exists, expanduser
from shutil import copyfile

from nemo.collections.nlp.data.datasets.datasets_utils import if_exist

@@ -378,7 +381,7 @@ def divideData(data, infold, outfold):
the data for three different sets"""

os.makedirs(outfold, exist_ok=True)
shutil.copyfile(f'{infold}/ontology.json', f'{outfold}/ontology.json')
copyfile(f'{infold}/ontology.json', f'{outfold}/ontology.json')

testListFile = []
fin = open(f'{infold}/testListFile.json', 'r')
@@ -436,6 +439,40 @@ def divideData(data, infold, outfold):
train_dials.append(dialogue)
count_train += 1

value_dict = json.load(open(f'{outfold}/ontology.json'))
new_ontology = {}
# {'taxi-arrive by': list_of_values} -> {'taxi': {'arriveby': list_of_slot_values}}
for k, v in value_dict.items():
domain, slot = k.split('-')
slot = slot.replace(' ', '')
if domain in new_ontology:
new_ontology[domain][slot] = v
else:
new_ontology[domain] = {slot: v}

with open(f'{outfold}/value_dict.json', 'w') as f:
json.dump(new_ontology, f, indent=4)
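The key restructuring above can be exercised in isolation. This is a toy reproduction with made-up slot values, not the actual MultiWOZ ontology:

```python
# Toy reproduction of the ontology restructuring above: a flat
# {'domain-slot name': values} dict becomes nested {'domain': {'slotname': values}}.
def restructure_ontology(value_dict):
    new_ontology = {}
    for k, v in value_dict.items():
        domain, slot = k.split('-')
        slot = slot.replace(' ', '')  # 'arrive by' -> 'arriveby'
        new_ontology.setdefault(domain, {})[slot] = v
    return new_ontology

ontology = {'taxi-arrive by': ['17:00'], 'taxi-leave at': ['08:15']}
print(restructure_ontology(ontology))
# {'taxi': {'arriveby': ['17:00'], 'leaveat': ['08:15']}}
```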

# save all database *_db.json files in a db folder
db_fold = os.path.join(outfold, 'db')
os.makedirs(db_fold, exist_ok=True)

for f in os.listdir(infold):
if '_db.json' in f:
copyfile(f'{infold}/{f}', f'{db_fold}/{f}')
# the taxi_db.json file is missing a comma in the MultiWOZ2.1 dataset;
# check whether that is the case and fix it
if f == 'taxi_db.json':
try:
with open(os.path.join(db_fold, f)) as f_:
_ = json.load(f_)
except json.decoder.JSONDecodeError:
taxi_db_text = open(os.path.join(db_fold, f)).readlines()
taxi_db_text[2] = taxi_db_text[2].rstrip() + ',\n'
taxi_db_text = '{' + ''.join(taxi_db_text)[1:-2] + '}\n'
taxi_db_text = taxi_db_text.replace("'", '"')
open(os.path.join(db_fold, f), 'w').write(taxi_db_text)
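The validate-then-repair pattern used above can be demonstrated on an in-memory string. The broken snippet below is made up to mimic the missing-comma defect, not the real taxi_db.json contents:

```python
import json

# A JSON fragment with the comma missing between two entries,
# mimicking the taxi_db.json defect handled above.
broken = '{"taxi_colors": ["black", "white"]\n"taxi_types": ["toyota"]}'

try:
    json.loads(broken)
    repaired = broken  # already valid, nothing to do
except json.JSONDecodeError:
    # insert the missing comma between the two top-level entries
    repaired = broken.replace(']\n"', '],\n"')

data = json.loads(repaired)
print(sorted(data))  # ['taxi_colors', 'taxi_types']
```

Parsing first and rewriting only on failure, as the script does, avoids touching files that a fixed upstream dataset release has already corrected.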

# save all dialogues
with open(f'{outfold}/dev_dials.json', 'w') as f:
json.dump(val_dials, f, indent=4)
with open(f'{outfold}/train_dials.json', 'w') as f:
json.dump(train_dials, f, indent=4)

print(f"Saving done. Generated dialogs: {count_train} train, {count_val} val, {count_test} test.")
print(
f"Processing done and saved in `{outfold}`. Generated dialogs: {count_train} train, {count_val} val, {count_test} test."
)


if __name__ == "__main__":
"--source_data_dir", required=True, type=str, help='The path to the folder containing the MultiWOZ data files.'
)
parser.add_argument("--target_data_dir", default='multiwoz2.1/', type=str)
parser.add_argument("--overwrite_files", action="store_true", help="Whether to overwrite preprocessed file")
args = parser.parse_args()

if not exists(args.source_data_dir):
raise FileNotFoundError(f"{args.source_data_dir} does not exist.")
# Get the absolute path.
abs_source_data_dir = expanduser(args.source_data_dir)
abs_target_data_dir = expanduser(args.target_data_dir)

if not exists(abs_source_data_dir):
raise FileNotFoundError(f"{abs_source_data_dir} does not exist.")

# Check if the files exist
if if_exist(args.target_data_dir, ['ontology.json', 'dev_dials.json', 'test_dials.json', 'train_dials.json']):
print(f'Data is already processed and stored at {args.source_data_dir}, skipping pre-processing.')
if (
if_exist(abs_target_data_dir, ['ontology.json', 'dev_dials.json', 'test_dials.json', 'train_dials.json', 'db'])
and not args.overwrite_files
):
print(f'Data is already processed and stored at {abs_target_data_dir}, skipping pre-processing.')
exit(0)

fin = open('multiwoz_mapping.pair', 'r')

print('Creating dialogues...')
# Process MultiWOZ dataset
delex_data = createData(args.source_data_dir)
delex_data = createData(abs_source_data_dir)
# Divide data
divideData(delex_data, args.source_data_dir, args.target_data_dir)
divideData(delex_data, abs_source_data_dir, abs_target_data_dir)
@@ -30,13 +30,12 @@

import nemo.core as nemo_core
from nemo import logging
from nemo.backends.pytorch.common import EncoderRNN
from nemo.backends.pytorch.common.losses import CrossEntropyLossNM, LossAggregatorNM
from nemo.collections.nlp.callbacks.state_tracking_trade_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.collections.nlp.data.datasets.multiwoz_dataset import MultiWOZDataDesc
from nemo.collections.nlp.nm.data_layers import MultiWOZDataLayer
from nemo.collections.nlp.nm.losses import MaskedLogLoss
from nemo.collections.nlp.nm.trainables import TRADEGenerator
from nemo.collections.nlp.nm.trainables import EncoderRNN, TRADEGenerator
from nemo.utils.lr_policies import get_lr_policy

parser = argparse.ArgumentParser(description='Dialogue state tracking with TRADE model on MultiWOZ dataset')
@@ -155,8 +154,8 @@ def create_pipeline(num_samples, batch_size, num_gpus, input_dropout, data_prefi
point_outputs, gate_outputs = decoder(
encoder_hidden=hidden,
encoder_outputs=outputs,
input_lens=input_data.src_lens,
src_ids=input_data.src_ids,
dialog_lens=input_data.src_lens,
dialog_ids=input_data.src_ids,
targets=input_data.tgt_ids,
)
