diff --git a/CHANGELOG.md b/CHANGELOG.md index fa50c6a13405..1fbf6d6ac532 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,9 @@ To release a new version, please update the changelog as followed: - Updated nemo's use of the logging library. from nemo import logging is now the reccomended way of using the nemo logger. neural_factory.logger and all other instances of logger are now deprecated and planned for removal in the next version. Please see PR 267 for complete change information. ([PR #267](https://github.com/NVIDIA/NeMo/pull/267), [PR #283](https://github.com/NVIDIA/NeMo/pull/283), [PR #305](https://github.com/NVIDIA/NeMo/pull/305), [PR #311](https://github.com/NVIDIA/NeMo/pull/311)) - @blisc +- Added TRADE (dialogue state tracking model) on MultiWOZ dataset +([PR #322](https://github.com/NVIDIA/NeMo/pull/322)) - @chiphuyen, @VahidooX + ### Dependencies Update - Added dependency on `wrapt` (the new version of the `deprecated` warning) - @tkornuta-nvidia, @DEKHTIARJonathan diff --git a/examples/nlp/dialogue_state_tracking_trade.py b/examples/nlp/dialogue_state_tracking_trade.py new file mode 100644 index 000000000000..996e0195d721 --- /dev/null +++ b/examples/nlp/dialogue_state_tracking_trade.py @@ -0,0 +1,226 @@ +# ============================================================================= +# Copyright 2019 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +""" An implementation of the paper "Transferable Multi-Domain State Generator +for Task-Oriented Dialogue Systems" (Wu et al., 2019 - ACL 2019) +Adopted from: https://github.com/jasonwu0731/trade-dst +""" + +import argparse +import math +import os + +import numpy as np + +import nemo.collections.nlp as nemo_nlp +import nemo.core as nemo_core +from nemo import logging +from nemo.backends.pytorch.common import EncoderRNN +from nemo.collections.nlp.callbacks.state_tracking_trade_callback import eval_epochs_done_callback, eval_iter_callback +from nemo.collections.nlp.data.datasets.state_tracking_trade_dataset import MultiWOZDataDesc +from nemo.utils.lr_policies import get_lr_policy + +parser = argparse.ArgumentParser(description='Dialog state tracking with TRADE model on MultiWOZ dataset') +parser.add_argument("--local_rank", default=None, type=int) +parser.add_argument("--batch_size", default=16, type=int) +parser.add_argument("--eval_batch_size", default=16, type=int) +parser.add_argument("--num_gpus", default=1, type=int) +parser.add_argument("--num_epochs", default=10, type=int) +parser.add_argument("--lr_warmup_proportion", default=0.0, type=float) +parser.add_argument("--lr", default=0.001, type=float) +parser.add_argument("--lr_policy", default=None, type=str) +parser.add_argument("--min_lr", default=1e-4, type=float) +parser.add_argument("--weight_decay", default=0.0, type=float) +parser.add_argument("--emb_dim", default=400, type=int) +parser.add_argument("--hid_dim", default=400, type=int) +parser.add_argument("--n_layers", default=1, type=int) +parser.add_argument("--dropout", default=0.2, type=float) +parser.add_argument("--input_dropout", default=0.2, type=float) +parser.add_argument("--data_dir", default='data/statetracking/multiwoz2.1', type=str) +parser.add_argument("--train_file_prefix", default='train', type=str) +parser.add_argument("--eval_file_prefix", 
def create_pipeline(num_samples, batch_size, num_gpus, input_dropout, data_prefix, is_training):
    """Assemble the TRADE graph (data layer -> encoder -> decoder -> losses) for one data split.

    Args:
        num_samples: forwarded to the data layer as ``num_samples``.
        batch_size: batch size handed to the data layer.
        num_gpus: number of GPUs; only used here to compute the steps per epoch.
        input_dropout: forwarded to the data layer as ``input_dropout``.
        data_prefix: split name, forwarded to the data layer as ``mode``.
        is_training: True for the training split (shuffling allowed, only the
            losses are returned for evaluation), False for evaluation.

    Returns:
        Tuple ``(tensors_to_evaluate, total_loss, ptr_loss, gate_loss,
        steps_per_epoch, data_layer)``.
    """
    logging.info(f"Loading {data_prefix} data...")

    # Evaluation data is never shuffled, whatever --shuffle_data says.
    shuffle = False
    if is_training:
        shuffle = args.shuffle_data

    data_layer = nemo_nlp.nm.data_layers.MultiWOZDataLayer(
        args.data_dir,
        data_desc.domains,
        all_domains=data_desc.all_domains,
        vocab=data_desc.vocab,
        slots=data_desc.slots,
        gating_dict=data_desc.gating_dict,
        num_samples=num_samples,
        shuffle=shuffle,
        num_workers=0,
        batch_size=batch_size,
        mode=data_prefix,
        is_training=is_training,
        input_dropout=input_dropout,
    )
    src_ids, src_lens, tgt_ids, tgt_lens, gate_labels, turn_domain = data_layer()

    data_size = len(data_layer)
    logging.info(f'The length of data layer is {data_size}')

    # Only the step computation below sees the capped batch size; the data
    # layer above was already constructed with the requested one.
    if data_size < batch_size:
        logging.warning("Batch_size is larger than the dataset size")
        logging.warning("Reducing batch_size to dataset size")
        batch_size = data_size

    samples_per_step = batch_size * num_gpus
    steps_per_epoch = math.ceil(data_size / samples_per_step)
    logging.info(f"Steps_per_epoch = {steps_per_epoch}")

    enc_outputs, enc_hidden = encoder(inputs=src_ids, input_lens=src_lens)

    point_outputs, gate_outputs = decoder(
        encoder_hidden=enc_hidden,
        encoder_outputs=enc_outputs,
        input_lens=src_lens,
        src_ids=src_ids,
        targets=tgt_ids,
    )

    gate_loss = gate_loss_fn(logits=gate_outputs, labels=gate_labels)
    ptr_loss = ptr_loss_fn(logits=point_outputs, targets=tgt_ids, loss_mask=tgt_lens)
    total_loss = total_loss_fn(loss_1=gate_loss, loss_2=ptr_loss)

    # Training only needs the losses; evaluation also needs predictions and labels.
    tensors_to_evaluate = (
        [total_loss, gate_loss, ptr_loss]
        if is_training
        else [total_loss, point_outputs, gate_outputs, gate_labels, turn_domain, tgt_ids, tgt_lens]
    )

    return tensors_to_evaluate, total_loss, ptr_loss, gate_loss, steps_per_epoch, data_layer
+ is_training=True, +) + +tensors_eval, total_loss_eval, ptr_loss_eval, gate_loss_eval, steps_per_epoch_eval, data_layer_eval = create_pipeline( + args.num_eval_samples, + batch_size=args.eval_batch_size, + num_gpus=args.num_gpus, + input_dropout=0.0, + data_prefix=args.eval_file_prefix, + is_training=False, +) + +# Create callbacks for train and eval modes +train_callback = nemo_core.SimpleLossLoggerCallback( + tensors=[total_loss_train, gate_loss_train, ptr_loss_train], + print_func=lambda x: logging.info( + f'Loss:{str(np.round(x[0].item(), 3))}, ' + f'Gate Loss:{str(np.round(x[1].item(), 3))}, ' + f'Pointer Loss:{str(np.round(x[2].item(), 3))}' + ), + tb_writer=nf.tb_writer, + get_tb_values=lambda x: [["loss", x[0]], ["gate_loss", x[1]], ["pointer_loss", x[2]]], + step_freq=steps_per_epoch_train, +) + +eval_callback = nemo_core.EvaluatorCallback( + eval_tensors=tensors_eval, + user_iter_callback=lambda x, y: eval_iter_callback(x, y, data_desc), + user_epochs_done_callback=lambda x: eval_epochs_done_callback(x, data_desc), + tb_writer=nf.tb_writer, + eval_step=steps_per_epoch_train, +) + +ckpt_callback = nemo_core.CheckpointCallback( + folder=nf.checkpoint_dir, epoch_freq=args.save_epoch_freq, step_freq=args.save_step_freq +) + +if args.lr_policy is not None: + total_steps = args.num_epochs * steps_per_epoch_train + lr_policy_fn = get_lr_policy( + args.lr_policy, total_steps=total_steps, warmup_ratio=args.lr_warmup_proportion, min_lr=args.min_lr + ) +else: + lr_policy_fn = None + +grad_norm_clip = args.grad_norm_clip if args.grad_norm_clip > 0 else None +nf.train( + tensors_to_optimize=[total_loss_train], + callbacks=[eval_callback, train_callback, ckpt_callback], + lr_policy=lr_policy_fn, + optimizer=args.optimizer_kind, + optimization_params={ + "num_epochs": args.num_epochs, + "lr": args.lr, + "grad_norm_clip": grad_norm_clip, + "weight_decay": args.weight_decay, + }, +) diff --git a/examples/nlp/scripts/multiwoz/process_multiwoz.py 
b/examples/nlp/scripts/multiwoz/process_multiwoz.py new file mode 100644 index 000000000000..bcdeec21bc0b --- /dev/null +++ b/examples/nlp/scripts/multiwoz/process_multiwoz.py @@ -0,0 +1,400 @@ +7 #!/usr/bin/python + +# ============================================================================= +# Copyright 2019 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# ============================================================================= +# Copyright 2019 Salesforce Research and Paweł Budzianowski. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom +# the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
+# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR +# THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# ============================================================================= + +""" +Dataset: http://dialogue.mi.eng.cam.ac.uk/index.php/corpus/ + +Code based on: +https://github.com/jasonwu0731/trade-dst +https://github.com/budzianowski/multiwoz +""" + +import argparse +import json +import os +import re +import shutil + +from nemo.collections.nlp.data.datasets.datasets_utils import if_exist + +parser = argparse.ArgumentParser(description='Process MultiWOZ dataset') +parser.add_argument("--data_dir", default='../../data/statetracking/MULTIWOZ2.1', type=str) +parser.add_argument("--out_dir", default='../../data/statetracking/multiwoz', type=str) +args = parser.parse_args() + +if not os.path.exists(args.data_dir): + raise FileNotFoundError(f"{args.data_dir} doesn't exist.") + +DOMAINS = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police'] +PHONE_NUM_TMPL = '\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})' +POSTCODE_TMPL = ( + '([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?' + '[a-z]{1}[\. 
def normalize(text):
    """Lower-case and canonicalize an utterance or slot value.

    Steps, in order: lower-case and strip; expand "b&b" / "b and b" to
    "bed and breakfast"; drop double quotes, angle brackets, '@' and
    parentheses; map curly apostrophes to "'"; put spaces around ``? . , !``
    when they follow a non-digit; replace whole tokens found in the
    module-level ``REPLACEMENTS`` table; collapse whitespace; join digit
    pairs separated by a single space.

    Args:
        text: the raw string to normalize.

    Returns:
        The normalized string.
    """
    text = text.lower().strip()

    # hotel domain pfb30
    text = re.sub(r"b&b", "bed and breakfast", text)
    text = re.sub(r"b and b", "bed and breakfast", text)
    # Raw strings throughout: the previous non-raw patterns relied on invalid
    # escape sequences such as '\D' and '\(' being passed through unchanged.
    text = re.sub(r'["<>@()]', '', text)  # remove quotes and brackets
    text = re.sub("[\u2018\u2019]", "'", text)  # curly apostrophes -> ASCII
    # add space around punctuations (only when preceded by a non-digit, so
    # decimals and times such as "10.30" stay intact)
    text = re.sub(r'(\D)([?.,!])', r'\1 \2 ', text)

    # Token-level replacements; str.split() never yields empty or padded
    # tokens, so no further stripping/filtering is needed.
    clean_tokens = [REPLACEMENTS.get(token, token) for token in text.split()]

    text = ' '.join(clean_tokens)  # also removes extra spaces
    text = re.sub(r'(\d) (\d)', r'\1\2', text)  # concatenate numbers

    return text
booking.append(int(len(bstate[domain]['book']['booked']) != 0)) + else: + if bstate[domain]['book'][slot]: + booking.append(1) + curr_bvalue = [f"{domain}-book {slot.strip().lower()}", normalize(bstate[domain]['book'][slot])] + summary_bvalue.append(curr_bvalue) + else: + booking.append(0) + if domain == 'train': + if 'people' not in bstate[domain]['book']: + booking.append(0) + if 'ticket' not in bstate[domain]['book']: # TODO: possibly elif + booking.append(0) + summary_bstate += booking + + for slot in bstate[domain]['semi']: + slot_enc = [0, 0, 0] # not mentioned, dontcare, filled + if bstate[domain]['semi'][slot] == 'not mentioned': + slot_enc[0] = 1 + elif bstate[domain]['semi'][slot] in DONT_CARES: + slot_enc[1] = 1 + summary_bvalue.append([f"{domain}-{slot.strip().lower()}", "dontcare"]) + elif bstate[domain]['semi'][slot]: + curr_bvalue = [f"{domain}-{slot.strip().lower()}", normalize(bstate[domain]['semi'][slot])] + summary_bvalue.append(curr_bvalue) + if sum(slot_enc) > 0: + domain_active = True + summary_bstate += slot_enc + + if domain_active: # quasi domain-tracker + summary_bstate += [1] + active_domain.append(domain) + else: + summary_bstate += [0] + + assert len(summary_bstate) == 94 + if get_goal: + return active_domain + return summary_bstate, summary_bvalue + + +def get_new_goal(prev_turn, curr_turn): + """ If multiple domains are updated between turns, + return all of them + """ + new_goals = [] + # Sometimes, metadata is an empty dictionary, bug? 
def create_data(data_dir):
    """Load the raw MultiWOZ dialogues and annotate every turn in place.

    Reads ``data.json`` and ``dialogue_acts.json`` from ``data_dir``. For each
    dialogue the turn texts are normalized, every user turn gets a ``domain``
    and every system turn gets its ``dialogue_acts``.

    Args:
        data_dir: directory containing ``data.json`` and ``dialogue_acts.json``.

    Returns:
        Dict mapping each dialogue id (as found in ``data.json``) to its
        annotated dialogue dict.

    Raises:
        KeyError: if a dialogue has no entry in ``dialogue_acts.json``.
    """
    # Context managers so the file handles are closed deterministically.
    with open(f'{data_dir}/data.json', 'r') as fin:
        data = json.load(fin)
    with open(f'{data_dir}/dialogue_acts.json', 'r') as fin:
        dialog_acts = json.load(fin)

    delex_data = {}
    suffix = '.json'

    for dialog_id, dialog in data.items():
        # dialogue_acts.json is keyed by the id without its ".json" suffix.
        # str.strip('.json') would strip any of the characters . j s o n from
        # BOTH ends, so remove the literal suffix instead.
        acts_key = dialog_id[: -len(suffix)] if dialog_id.endswith(suffix) else dialog_id
        curr_dialog_acts = dialog_acts[acts_key]
        goals = [key for key in dialog['goal'] if key in DOMAINS and dialog['goal'][key]]

        last_goal, act_idx = '', 1
        for idx, turn in enumerate(dialog['log']):
            dialog['log'][idx]['text'] = normalize(turn['text'])

            if idx % 2 == 1:  # system's turn
                cur_goal = get_goal(idx, dialog['log'], goals, last_goal)
                last_goal = cur_goal

                dialog['log'][idx - 1]['domain'] = cur_goal  # human's domain
                dialog['log'][idx]['dialogue_acts'] = get_dialog_act(curr_dialog_acts, str(act_idx))
                act_idx += 1

            # NOTE(review): runs for every turn and, on system turns, after
            # act_idx was incremented -- this mirrors the upstream TRADE
            # preprocessing; confirm it is intended.
            dialog['log'][idx]['text'] = fix_delex(curr_dialog_acts, str(act_idx), dialog['log'][idx]['text'])

        delex_data[dialog_id] = dialog
    return delex_data
def partition_data(data, infold, outfold):
    """Partition the data into train, valid, and test sets
    based on the list of val and test specified in the dataset.

    Dialogues listed in ``testListFile.json`` / ``valListFile.json`` (read
    from ``infold``) go to the test / validation set; everything else is
    training data. Dialogues for which ``get_dialog`` returns nothing are
    dropped. Nothing is done if all output files already exist.

    Args:
        data: dict mapping dialogue id to annotated dialogue.
        infold: directory holding ``ontology.json`` and the two list files.
        outfold: directory that receives ``ontology.json``,
            ``trainListFile.json`` and the three ``*_dialogs.json`` files.
    """
    expected_files = [
        'trainListFile.json',
        'val_dialogs.json',
        'test_dialogs.json',
        'train_dialogs.json',
        'ontology.json',
    ]
    if if_exist(outfold, expected_files):
        print(f'Data is already processed and stored at {outfold}')
        return
    os.makedirs(outfold, exist_ok=True)
    shutil.copyfile(f'{infold}/ontology.json', f'{outfold}/ontology.json')

    # Sets give O(1) membership tests in the per-dialogue loop below.
    with open(f'{infold}/testListFile.json', 'r') as fin:
        test_files = {line.strip() for line in fin}

    with open(f'{infold}/valListFile.json', 'r') as fin:
        val_files = {line.strip() for line in fin}

    train_ids = []
    train_dialogs, val_dialogs, test_dialogs = [], [], []

    for dialog_id, dialog in data.items():
        domains = [key for key in dialog['goal'] if key in DOMAINS and dialog['goal'][key]]

        dial = get_dialog(dialog)
        if not dial:
            continue  # discarded by the cleaning step, or no turns at all

        dialogue = {'dialog_idx': dialog_id, 'domains': list(set(domains)), 'dialog': []}
        last_bs = []

        for idx, turn in enumerate(dial):
            # The system fields of turn idx come from the PREVIOUS system
            # response, i.e. what the user is replying to.
            turn_dl = {
                'sys_transcript': dial[idx - 1]['sys'] if idx > 0 else "",
                'turn_idx': idx,
                'transcript': turn['usr'],
                'sys_acts': dial[idx - 1]['sys_a'] if idx > 0 else [],
                'domain': turn['domain'],
            }
            turn_dl['belief_state'] = [{"slots": [s], "act": "inform"} for s in turn['bvs']]
            # turn_label keeps only the slots that were not in the previous belief state.
            turn_dl['turn_label'] = [bs["slots"][0] for bs in turn_dl['belief_state'] if bs not in last_bs]
            last_bs = turn_dl['belief_state']
            dialogue['dialog'].append(turn_dl)

        if dialog_id in test_files:
            test_dialogs.append(dialogue)
        elif dialog_id in val_files:
            val_dialogs.append(dialogue)
        else:
            train_ids.append(dialog_id)
            train_dialogs.append(dialogue)

    print(f"Dialogs: {len(train_dialogs)} train, {len(val_dialogs)} val, {len(test_dialogs)} test.")

    # save all dialogues
    with open(f'{outfold}/val_dialogs.json', 'w') as fout:
        json.dump(val_dialogs, fout, indent=4)

    with open(f'{outfold}/test_dialogs.json', 'w') as fout:
        json.dump(test_dialogs, fout, indent=4)

    with open(f'{outfold}/train_dialogs.json', 'w') as fout:
        json.dump(train_dialogs, fout, indent=4)

    # Written in one go under a context manager so the handle cannot leak
    # if anything above raises.
    with open(f'{outfold}/trainListFile.json', 'w') as fout:
        fout.writelines(f'{dialog_id}\n' for dialog_id in train_ids)
+restaurants restaurant -s +hotels hotel -s +laptops laptop -s +cheaper cheap -er +dinners dinner -s +lunches lunch -s +breakfasts breakfast -s +expensively expensive -ly +moderately moderate -ly +cheaply cheap -ly +prices price -s +places place -s +venues venue -s +ranges range -s +meals meal -s +locations location -s +areas area -s +policies policy -s +children child -s +kids kid -s +kidfriendly kid friendly +cards card -s +upmarket expensive +inpricey cheap +inches inch -s +uses use -s +dimensions dimension -s +driverange drive range +includes include -s +computers computer -s +machines machine -s +families family -s +ratings rating -s +constraints constraint -s +pricerange price range +batteryrating battery rating +requirements requirement -s +drives drive -s +specifications specification -s +weightrange weight range +harddrive hard drive +batterylife battery life +businesses business -s +hours hour -s +one 1 +two 2 +three 3 +four 4 +five 5 +six 6 +seven 7 +eight 8 +nine 9 +ten 10 +eleven 11 +twelve 12 +anywhere any where +good bye goodbye diff --git a/nemo/backends/pytorch/common/rnn.py b/nemo/backends/pytorch/common/rnn.py index 4b8e994223eb..4a112c18c6eb 100644 --- a/nemo/backends/pytorch/common/rnn.py +++ b/nemo/backends/pytorch/common/rnn.py @@ -1,4 +1,18 @@ -__all__ = ['DecoderRNN'] +# ============================================================================= +# Copyright 2019 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class EncoderRNN(TrainableNM):
    """Simple RNN-based encoder using bidirectional GRU cells.

    Args:
        input_dim: number of rows of the embedding table (vocabulary size).
        emb_dim: embedding size.
        hid_dim: GRU hidden size (per direction).
        dropout: dropout probability applied to the embeddings and, for
            multi-layer GRUs, between GRU layers.
        n_layers: number of GRU layers.
        pad_idx: index of the padding token in the embedding table.
        embedding_to_load: optional tensor copied into the embedding weights;
            when None the weights are drawn from N(0, 0.1).
        sum_hidden: if True, the forward and backward directions are summed
            in both the outputs and the final hidden state; otherwise they
            are concatenated.
    """

    @property
    def input_ports(self):
        """Returns definitions of module input ports.

        inputs:
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

        input_lens (optional):
            0: AxisType(BatchTag)
        """
        return {
            'inputs': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            'input_lens': NeuralType({0: AxisType(BatchTag),}, optional=True),
        }

    @property
    def output_ports(self):
        """Returns definitions of module output ports.

        outputs:
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

            2: AxisType(ChannelTag)

        hidden:
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

            2: AxisType(ChannelTag)
        """
        return {
            'outputs': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag)}),
            'hidden': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag)}),
        }

    def __init__(
        self, input_dim, emb_dim, hid_dim, dropout, n_layers=1, pad_idx=1, embedding_to_load=None, sum_hidden=True
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        if embedding_to_load is not None:
            self.embedding.weight.data.copy_(embedding_to_load)
        else:
            self.embedding.weight.data.normal_(0, 0.1)
        # GRU dropout only acts between stacked layers; with a single layer it
        # is a no-op that merely triggers a PyTorch warning, so disable it.
        rnn_dropout = dropout if n_layers > 1 else 0.0
        self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, batch_first=True, dropout=rnn_dropout, bidirectional=True)
        self.sum_hidden = sum_hidden
        self.to(self._device)

    def forward(self, inputs, input_lens=None):
        """Encode a batch of token-id sequences.

        Args:
            inputs: token ids of shape (batch, seq_len).
            input_lens: optional sequence lengths; when given, the embedded
                batch is packed before it is fed to the GRU.

        Returns:
            outputs: (batch, seq_len, hidden_size) if ``sum_hidden`` else
                (batch, seq_len, num_directions * hidden_size).
            hidden: (batch, num_layers, hidden_size) if ``sum_hidden`` else
                (batch, num_layers, num_directions * hidden_size).
        """
        embedded = self.embedding(inputs)
        embedded = self.dropout(embedded)
        if input_lens is not None:
            embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lens, batch_first=True)

        outputs, hidden = self.rnn(embedded)
        # hidden of shape (num_layers * num_directions, batch, hidden_size)
        if input_lens is not None:
            outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        # The GRU is batch_first, so in both cases outputs is already
        # (batch, seq_len, num_directions * hidden_size); the unpacked path
        # must NOT be transposed.

        batch_size = hidden.size()[1]
        num_directions = 2 if self.rnn.bidirectional else 1

        # separate final hidden states by layer and direction
        hidden = hidden.view(self.rnn.num_layers, num_directions, batch_size, self.rnn.hidden_size)
        hidden = hidden.transpose(2, 0).transpose(1, 2)
        # hidden shape: batch x num_layer x num_directions x hidden_size
        if self.sum_hidden and self.rnn.bidirectional:
            hidden = hidden[:, :, 0, :] + hidden[:, :, 1, :]
            outputs = outputs[:, :, : self.rnn.hidden_size] + outputs[:, :, self.rnn.hidden_size :]
        else:
            hidden = hidden.reshape(batch_size, self.rnn.num_layers, -1)
        # hidden is now of shape (batch, num_layer, [num_directions] * hidden_size)

        return outputs, hidden
def eval_iter_callback(tensors, global_vars, data_desc):
    """Accumulate per-batch evaluation results into ``global_vars``.

    Tensors are picked from ``tensors`` by key prefix: ``loss``,
    ``point_outputs``, ``gate_outputs``, ``gating_labels`` and ``tgt_ids``;
    the first element of each value is used.

    Appends/extends these lists in ``global_vars`` (created if absent):
        loss: the loss of this batch as a numpy value.
        comp_res: per sample, one boolean per slot telling whether every
            non-padding target token was predicted correctly.
        gating_labels: gate labels of this batch.
        gating_preds: argmax of the gate outputs.

    Args:
        tensors: dict mapping tensor names to lists of tensors.
        global_vars: dict shared across evaluation iterations; mutated.
        data_desc: object exposing ``vocab.pad_id``.

    Raises:
        ValueError: if ``point_outputs``, ``gate_outputs`` or ``tgt_ids`` is
            not present in ``tensors``.
    """
    for key in ('loss', 'comp_res', 'gating_labels', 'gating_preds'):
        global_vars.setdefault(key, [])

    point_outputs = gate_outputs = tgt_ids = None
    for kv, v in tensors.items():
        if kv.startswith('loss'):
            global_vars['loss'].append(v[0].cpu().numpy())
        elif kv.startswith('point_outputs'):
            point_outputs = v[0]
        elif kv.startswith('gate_outputs'):
            gate_outputs = v[0]
        elif kv.startswith('gating_labels'):
            global_vars['gating_labels'].extend(v[0].cpu().numpy())
        elif kv.startswith('tgt_ids'):
            tgt_ids = v[0]

    # Fail with a readable message instead of an UnboundLocalError further down.
    required = (('point_outputs', point_outputs), ('gate_outputs', gate_outputs), ('tgt_ids', tgt_ids))
    missing = [name for name, value in required if value is None]
    if missing:
        raise ValueError(f'eval_iter_callback: missing evaluation tensors: {missing}')

    # assumes the vocabulary is on the last axis of point_outputs and the
    # token positions are on the last axis of tgt_ids -- TODO confirm
    point_outputs_max = torch.argmax(point_outputs, dim=-1)
    mask_paddings = tgt_ids == data_desc.vocab.pad_id
    # A position counts as correct if it matches the target or is padding.
    comp_res = (point_outputs_max == tgt_ids) | mask_paddings
    comp_res = torch.all(comp_res, dim=-1)

    global_vars['comp_res'].extend(comp_res.cpu().numpy())
    global_vars['gating_preds'].extend(torch.argmax(gate_outputs, dim=-1).cpu().numpy())
gating_preds, ptr_code): + # TODO: Calculate precision, recall, and F1 + total_slots = 0 + correct_slots = 0 + total_turns = 0 + correct_turns = 0 + for result_idx, result in enumerate(comp_res): + turn_wrong = False + total_turns += 1 + for slot_idx, slot_eq in enumerate(result): + total_slots += 1 + if gating_labels[result_idx][slot_idx] == ptr_code: + if slot_eq: + correct_slots += 1 + else: + turn_wrong = True + elif gating_labels[result_idx][slot_idx] == gating_preds[result_idx][slot_idx] or ( + slot_eq and gating_preds[result_idx][slot_idx] == ptr_code + ): + correct_slots += 1 + else: + turn_wrong = True + if not turn_wrong: + correct_turns += 1 + + turn_acc = correct_slots / float(total_slots) if total_slots != 0 else 0 + joint_acc = correct_turns / float(total_turns) if total_turns != 0 else 0 + return joint_acc, turn_acc diff --git a/nemo/collections/nlp/data/datasets/__init__.py b/nemo/collections/nlp/data/datasets/__init__.py index f0eafa0d62f1..c2decfb1c855 100644 --- a/nemo/collections/nlp/data/datasets/__init__.py +++ b/nemo/collections/nlp/data/datasets/__init__.py @@ -30,6 +30,7 @@ BertPunctuationCapitalizationInferDataset, ) from nemo.collections.nlp.data.datasets.qa_squad_dataset import SquadDataset +from nemo.collections.nlp.data.datasets.state_tracking_trade_dataset import * from nemo.collections.nlp.data.datasets.text_classification_dataset import BertTextClassificationDataset from nemo.collections.nlp.data.datasets.token_classification_dataset import ( BertTokenClassificationDataset, diff --git a/nemo/collections/nlp/data/datasets/state_tracking_trade_dataset.py b/nemo/collections/nlp/data/datasets/state_tracking_trade_dataset.py new file mode 100644 index 000000000000..0995d7c14249 --- /dev/null +++ b/nemo/collections/nlp/data/datasets/state_tracking_trade_dataset.py @@ -0,0 +1,428 @@ +# ============================================================================= +# Copyright 2019 NVIDIA. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# ============================================================================= +# Copyright 2019 Salesforce Research. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom +# the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR +# THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
class MultiWOZDataset(Dataset):
    """Turn-level MultiWOZ dataset for the TRADE model.

    Every sample is one dialogue turn: the dialogue history up to and
    including that turn, plus the gate label and target value of every slot.

    By default, use only vocab from training data
    Need to modify the code a little bit to create the vocab from all files

    Args:
        data_dir: directory containing ``{mode}_dials.json``.
        mode: file prefix of the split to read (e.g. ``train``).
        domains: container of domain names used for the domain statistics.
        all_domains: mapping from domain name to integer id.
        vocab: vocabulary exposing ``tokens2ids`` and ``eos``.
        gating_dict: mapping with the keys ``ptr``, ``dontcare`` and ``none``.
        slots: list of slot names, in model output order.
        num_samples: maximum number of turns to load; negative loads all.
        shuffle: whether to shuffle the loaded samples once.
    """

    def __init__(self, data_dir, mode, domains, all_domains, vocab, gating_dict, slots, num_samples=-1, shuffle=False):

        logging.info(f'Processing {mode} data')
        self.data_dir = data_dir
        self.mode = mode
        self.gating_dict = gating_dict
        self.domains = domains
        self.all_domains = all_domains
        self.vocab = vocab
        self.slots = slots

        self.features, self.max_len = self.get_features(num_samples, shuffle)
        if self.features:
            # An empty split has no sample to show; do not fail on it.
            logging.info("Sample 0: " + str(self.features[0]))

    def get_features(self, num_samples, shuffle):
        """Read the dialogue file and build one sample per turn.

        Returns:
            Tuple ``(samples, max_len)`` where ``max_len`` is the largest
            number of whitespace tokens in any dialogue history.

        Raises:
            ValueError: if ``num_samples`` is 0.
        """
        if num_samples == 0:
            raise ValueError("num_samples has to be positive", num_samples)

        filename = f'{self.data_dir}/{self.mode}_dials.json'
        logging.info(f'Reading from {filename}')
        with open(filename, 'r') as dials_file:
            dialogs = json.load(dials_file)

        domain_count = {}
        data = []
        max_resp_len = 0

        for dialog_dict in dialogs:
            if num_samples > 0 and len(data) >= num_samples:
                break

            dialog_history = ""
            for domain in dialog_dict['domains']:
                if domain not in self.domains:
                    continue
                domain_count[domain] = domain_count.get(domain, 0) + 1

            for turn in dialog_dict['dialogue']:
                if num_samples > 0 and len(data) >= num_samples:
                    break

                turn_uttr = turn['system_transcript'] + ' ; ' + turn['transcript']
                # The history accumulates every turn seen so far in this dialogue.
                dialog_history += turn_uttr + " ; "
                source_text = dialog_history.strip()

                turn_beliefs = fix_general_label_error_multiwoz(turn['belief_state'], self.slots)
                turn_belief_list = [f'{k}-{v}' for k, v in turn_beliefs.items()]

                gating_label, responses = [], []
                for slot in self.slots:
                    value = turn_beliefs.get(slot, "none") if slot in turn_beliefs else "none"
                    responses.append(str(value))
                    if value == "dontcare":
                        gating_label.append(self.gating_dict["dontcare"])
                    elif value == "none":
                        gating_label.append(self.gating_dict["none"])
                    else:
                        gating_label.append(self.gating_dict["ptr"])

                sample = {
                    'ID': dialog_dict['dialogue_idx'],
                    'domains': dialog_dict['domains'],
                    'turn_domain': self.all_domains[turn['domain']],
                    'turn_id': turn['turn_idx'],
                    'dialogue_history': source_text,
                    'turn_belief': turn_belief_list,
                    'gating_label': gating_label,
                    'turn_uttr': turn_uttr.strip(),
                    'responses': responses,
                }

                history_tokens = source_text.split()
                sample['context_ids'] = self.vocab.tokens2ids(history_tokens)
                # Every target value is terminated with the EOS token.
                sample['responses_ids'] = [
                    self.vocab.tokens2ids(y.split() + [self.vocab.eos]) for y in sample['responses']
                ]

                data.append(sample)
                max_resp_len = max(max_resp_len, len(history_tokens))

        logging.info(f'Domain count{domain_count}')
        logging.info(f'Max response length{max_resp_len}')
        logging.info(f'Processing {len(data)} samples')

        if shuffle:
            logging.info(f'Shuffling samples.')
            random.shuffle(data)

        return data, max_resp_len

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        item = self.features[idx]
        return {
            'dialog_id': item['ID'],
            'turn_id': item['turn_id'],
            'turn_belief': item['turn_belief'],
            'gating_label': item['gating_label'],
            'context_ids': item['context_ids'],
            'turn_domain': item['turn_domain'],
            'responses_ids': item['responses_ids'],
        }


class Vocab:
    """Word-level vocabulary for the TRADE model.

    Reserved ids: ``UNK`` = 0, ``PAD`` = 1, ``EOS`` = 2, ``BOS`` = 3.
    """

    def __init__(self):
        self.word2idx = {'UNK': 0, 'PAD': 1, 'EOS': 2, 'BOS': 3}
        self.idx2word = ['UNK', 'PAD', 'EOS', 'BOS']
        self.unk_id = self.word2idx['UNK']
        self.pad_id = self.word2idx['PAD']
        self.eos_id = self.word2idx['EOS']
        self.bos_id = self.word2idx['BOS']
        self.unk, self.pad, self.eos, self.bos = 'UNK', 'PAD', 'EOS', 'BOS'

    def __len__(self):
        return len(self.idx2word)

    def add_word(self, word):
        """Register ``word`` under the next free id unless already known."""
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)

    def _add_slot_name(self, slot):
        """Add the domain and the space-separated parts of a ``domain-info`` slot name."""
        domain, info = slot.split('-')
        self.add_word(domain)
        for subslot in info.split(' '):
            self.add_word(subslot)

    def add_words(self, sent, level):
        """
        level == 'utterance': sent is a string
        level == 'slot': sent is a list
        level == 'belief': sent is a dictionary

        Any other ``level`` is ignored.
        """
        if level == 'utterance':
            for word in sent.split():
                self.add_word(word)
        elif level == 'slot':
            for slot in sent:
                self._add_slot_name(slot)
        elif level == 'belief':
            for slot, value in sent.items():
                self._add_slot_name(slot)
                for val in value.split(' '):
                    self.add_word(val)

    def tokens2ids(self, tokens):
        """Converts list of tokens to list of ids; unknown tokens map to ``unk_id``."""
        return [self.word2idx.get(w, self.unk_id) for w in tokens]


class MultiWOZDataDesc:
    """
    Processes MultiWOZ dataset, creates vocabulary file and list of slots.

    Args:
        data_dir: directory containing ``ontology.json`` and
            ``train_dials.json``; ``vocab.pkl`` is read from / written to it.
        domains: mapping of the domains to use.  ``None`` selects attraction,
            restaurant, taxi, train and hotel.
    """

    def __init__(self, data_dir, domains=None):
        logging.info(f'Processing MultiWOZ dataset')

        if domains is None:
            # Built per instance so the default is never shared between objects.
            domains = {"attraction": 0, "restaurant": 1, "taxi": 2, "train": 3, "hotel": 4}

        self.all_domains = {
            'attraction': 0,
            'restaurant': 1,
            'taxi': 2,
            'train': 3,
            'hotel': 4,
            'hospital': 5,
            'bus': 6,
            'police': 7,
        }
        self.gating_dict = {'ptr': 0, 'dontcare': 1, 'none': 2}

        self.data_dir = data_dir
        self.domains = domains
        self.vocab = Vocab()

        with open(f'{self.data_dir}/ontology.json', 'r') as ontology_file:
            self.ontology = json.load(ontology_file)

        self.vocab_file = None
        self.slots = None

        self.get_slots()
        self.get_vocab()

    def get_vocab(self):
        """Load the cached vocabulary from ``vocab.pkl`` or build and cache it."""
        self.vocab_file = f'{self.data_dir}/vocab.pkl'

        if os.path.exists(self.vocab_file):
            logging.info(f'Loading vocab from {self.data_dir}')
            # NOTE(security): pickle executes arbitrary code on load; only
            # point data_dir at directories you trust.
            with open(self.vocab_file, 'rb') as handle:
                self.vocab = pickle.load(handle)
        else:
            self.create_vocab()

        logging.info(f'Vocab size {len(self.vocab)}')

    def get_slots(self):
        """Collect the ontology slots of the selected domains.

        Slot names are lower-cased; spaces are removed unless the name
        contains ``book``.
        """
        used_domains = [key for key in self.ontology if key.split('-')[0] in self.domains]
        self.slots = [k.replace(' ', '').lower() if 'book' not in k else k.lower() for k in used_domains]

    def create_vocab(self):
        """Build the vocabulary from slot names and training utterances, then pickle it."""
        self.vocab.add_words(self.slots, 'slot')

        filename = f'{self.data_dir}/train_dials.json'
        logging.info(f'Building vocab from {filename}')
        with open(filename, 'r') as dials_file:
            dialogs = json.load(dials_file)

        for dialog_dict in dialogs:
            for turn in dialog_dict['dialogue']:
                self.vocab.add_words(turn['system_transcript'], 'utterance')
                self.vocab.add_words(turn['transcript'], 'utterance')

        logging.info(f'Saving vocab to {self.data_dir}')
        with open(self.vocab_file, 'wb') as handle:
            pickle.dump(self.vocab, handle)


# Lookup tables for fix_general_label_error_multiwoz.  They are module level
# so they are built once rather than on every call (the function runs per turn).
_GENERAL_TYPO = {
    # type
    "guesthouse": "guest house",
    "guesthouses": "guest house",
    "guest": "guest house",
    "mutiple sports": "multiple sports",
    "sports": "multiple sports",
    "mutliple sports": "multiple sports",
    "swimmingpool": "swimming pool",
    "concerthall": "concert hall",
    "concert": "concert hall",
    "pool": "swimming pool",
    "night club": "nightclub",
    "mus": "museum",
    "ol": "architecture",
    "colleges": "college",
    "coll": "college",
    "architectural": "architecture",
    "musuem": "museum",
    "churches": "church",
    # area
    "center": "centre",
    "center of town": "centre",
    "near city center": "centre",
    "in the north": "north",
    "cen": "centre",
    "east side": "east",
    "east area": "east",
    "west part of town": "west",
    "ce": "centre",
    "town center": "centre",
    "centre of cambridge": "centre",
    "city center": "centre",
    "the south": "south",
    "scentre": "centre",
    "town centre": "centre",
    "in town": "centre",
    "north part of town": "north",
    "centre of town": "centre",
    "cb30aq": "none",
    # price
    "mode": "moderate",
    "moderate -ly": "moderate",
    "mo": "moderate",
    # day
    "next friday": "friday",
    "monda": "monday",
    # parking
    "free parking": "free",
    # internet
    "free internet": "yes",
    # star
    "4 star": "4",
    "4 stars": "4",
    "0 star rarting": "none",
    # others
    "y": "yes",
    "any": "dontcare",
    "n": "no",
    "does not care": "dontcare",
    "not men": "none",
    "not": "none",
    "not mentioned": "none",
    '': "none",
    "not mendtioned": "none",
    "3 .": "3",
    "does not": "no",
    "fun": "none",
    "art": "none",
}
_HOTEL_RANGES = (
    "nigh",
    "moderate -ly priced",
    "bed and breakfast",
    "centre",
    "venetian",
    "intern",
    "a cheap -er hotel",
)
_LOCATIONS = ("gastropub", "la raza", "galleria", "gallery", "science", "m")
_DETAILED_HOTELS = ("hotel with free parking and free wifi", "4", "3 star hotel")
_AREAS = ("stansted airport", "cambridge", "silver street")
_ATTR_AREAS = ("norwich", "ely", "museum", "same area as hotel")


def fix_general_label_error_multiwoz(labels, slots):
    """Normalise the annotated belief-state values of one turn.

    Args:
        labels: list of belief entries; the first element of each entry's
            ``'slots'`` list is a ``(slot_name, value)`` pair.
        slots: slot names whose values should be corrected.  Values of slots
            not listed here are returned unchanged.

    Returns:
        dict mapping slot name to its (possibly corrected) value.
    """
    label_dict = dict(label['slots'][0] for label in labels)

    for slot in slots:
        if slot not in label_dict:
            continue

        # general typos
        if label_dict[slot] in _GENERAL_TYPO:
            label_dict[slot] = _GENERAL_TYPO[label_dict[slot]]

        value = label_dict[slot]

        # miss match slot and value
        if (
            (slot == "hotel-type" and value in _HOTEL_RANGES)
            or (slot == "hotel-internet" and value == "4")
            or (slot == "hotel-pricerange" and value == "2")
            or (slot == "attraction-type" and value in _LOCATIONS)
            or ("area" in slot and value in ["moderate"])
            or ("day" in slot and value == "t")
        ):
            label_dict[slot] = "none"
        elif slot == "hotel-type" and value in _DETAILED_HOTELS:
            label_dict[slot] = "hotel"
        elif slot == "hotel-star" and value == "3 star hotel":
            label_dict[slot] = "3"
        elif "area" in slot:
            if value == "no":
                label_dict[slot] = "north"
            elif value == "we":
                label_dict[slot] = "west"
            elif value == "cent":
                label_dict[slot] = "centre"
        elif "day" in slot:
            if value == "we":
                label_dict[slot] = "wednesday"
            elif value == "no":
                label_dict[slot] = "none"
        elif "price" in slot and value == "ch":
            label_dict[slot] = "cheap"
        elif "internet" in slot and value == "free":
            label_dict[slot] = "yes"

        # some out-of-define classification slot values
        if (slot == "restaurant-area" and label_dict[slot] in _AREAS) or (
            slot == "attraction-area" and label_dict[slot] in _ATTR_AREAS
        ):
            label_dict[slot] = "none"

    return label_dict
b/nemo/collections/nlp/nm/data_layers/__init__.py @@ -21,6 +21,7 @@ from nemo.collections.nlp.nm.data_layers.machine_translation_datalayer import * from nemo.collections.nlp.nm.data_layers.punctuation_capitalization_datalayer import * from nemo.collections.nlp.nm.data_layers.qa_squad_datalayer import * +from nemo.collections.nlp.nm.data_layers.state_tracking_trade_datalayer import * from nemo.collections.nlp.nm.data_layers.text_classification_datalayer import * from nemo.collections.nlp.nm.data_layers.text_datalayer import * from nemo.collections.nlp.nm.data_layers.token_classification_datalayer import * diff --git a/nemo/collections/nlp/nm/data_layers/state_tracking_trade_datalayer.py b/nemo/collections/nlp/nm/data_layers/state_tracking_trade_datalayer.py new file mode 100644 index 000000000000..decfc035c25b --- /dev/null +++ b/nemo/collections/nlp/nm/data_layers/state_tracking_trade_datalayer.py @@ -0,0 +1,210 @@ +# ============================================================================= +# Copyright 2019 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# ============================================================================= +# Copyright 2019 Salesforce Research. 
class MultiWOZDataLayer(TextDataLayer):
    """Data layer that batches ``MultiWOZDataset`` samples for the TRADE model.

    Args:
        data_dir: directory with the dialogue files, forwarded to the dataset.
        domains: domains to use, forwarded to the dataset.
        all_domains: mapping from domain name to id, forwarded to the dataset.
        vocab: vocabulary, forwarded to the dataset.
        slots: list of slot names, forwarded to the dataset.
        gating_dict: gate name to id mapping, forwarded to the dataset.
        num_samples: maximum number of samples, forwarded to the dataset.
        batch_size: number of samples per batch.
        mode: split prefix, forwarded to the dataset.
        dataset_type: dataset class to instantiate.
        shuffle: forwarded to the dataset (one-off shuffle at load time).
        num_workers: number of ``DataLoader`` worker processes.
        input_dropout: probability of replacing an input id with 0.
        is_training: input dropout is only applied when this is true.
    """

    @property
    def output_ports(self):
        """Returns definitions of module output ports.

        src_ids: ids of input sequences
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

        src_lens: lengths of input sequences
            0: AxisType(BatchTag)

        tgt_ids: labels for the generator output
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

            2: AxisType(TimeTag)

        tgt_lens: lengths of the generator targets
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

        gating_labels: labels for the gating head
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

        turn_domain: list of the domains
            NeuralType(None)

        """
        return {
            "src_ids": NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            "src_lens": NeuralType({0: AxisType(BatchTag)}),
            "tgt_ids": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag), 2: AxisType(TimeTag)}),
            "tgt_lens": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}),
            "gating_labels": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}),
            "turn_domain": NeuralType(None),
        }

    def __init__(
        self,
        data_dir,
        domains,
        all_domains,
        vocab,
        slots,
        gating_dict,
        num_samples=-1,
        batch_size=16,
        mode='train',
        dataset_type=MultiWOZDataset,
        shuffle=False,
        num_workers=0,
        input_dropout=0,
        is_training=False,
    ):

        dataset_params = {
            'data_dir': data_dir,
            'domains': domains,
            'num_samples': num_samples,
            'mode': mode,
            'shuffle': shuffle,
            'all_domains': all_domains,
            'vocab': vocab,
            'slots': slots,
            'gating_dict': gating_dict,
        }
        super().__init__(dataset_type, dataset_params, batch_size=batch_size)

        if self._placement == nemo.core.DeviceType.AllGpu:
            sampler = pt_data.distributed.DistributedSampler(self._dataset)
        else:
            sampler = None

        # NOTE(review): the loader shuffles whenever no sampler is used,
        # regardless of the `shuffle` argument (which only reaches the
        # dataset) -- confirm this is intended for evaluation layers.
        self._dataloader = pt_data.DataLoader(
            dataset=self._dataset,
            batch_size=batch_size,
            shuffle=sampler is None,
            num_workers=num_workers,
            collate_fn=self._collate_fn,
            sampler=sampler,
        )
        self.pad_id = self._dataset.vocab.pad_id
        self.gating_dict = self._dataset.gating_dict
        self.input_dropout = input_dropout
        self.is_training = is_training
        self.vocab = self._dataset.vocab
        self.slots = self._dataset.slots

    def _collate_fn(self, data):
        """Merge a list of dataset samples into padded batch tensors.

        data is a list of batch_size sample
        each sample is a dictionary of features

        Returns:
            Tuple ``(src_ids, src_lens, tgt_ids, tgt_lens, gating_label,
            turn_domain)`` of tensors moved to ``self._device``.  Samples are
            ordered by decreasing context length.
        """
        pad_id = self._dataset.vocab.pad_id

        def pad_batch_context(sequences):
            '''
            merge from batch * sent_len to batch * max_len
            '''
            lengths = [len(seq) for seq in sequences]
            # Keep at least one column so an all-empty batch is still 2-D.
            max_len = max(max(lengths), 1)
            padded = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
            return torch.tensor(padded), torch.tensor(lengths)

        def pad_batch_response(sequences):
            '''
            merge from batch * nb_slot * slot_len to batch * nb_slot * max_slot_len
            '''
            lengths = [[len(value) for value in sample] for sample in sequences]
            max_len = max(max(sample_lengths) for sample_lengths in lengths)
            padded = [[value + [pad_id] * (max_len - len(value)) for value in sample] for sample in sequences]
            return torch.tensor(padded), torch.tensor(lengths)

        # Sort a copy (longest context first) instead of mutating the caller's list.
        data = sorted(data, key=lambda x: len(x['context_ids']), reverse=True)
        item_info = {key: [item[key] for item in data] for key in data[0]}

        src_ids, src_lens = pad_batch_context(item_info['context_ids'])
        tgt_ids, tgt_lens = pad_batch_response(item_info['responses_ids'])
        gating_label = torch.tensor(item_info['gating_label'])
        turn_domain = torch.tensor(item_info['turn_domain'])

        if self.input_dropout > 0 and self.is_training:
            # Each id is kept with probability 1 - input_dropout; a dropped id
            # becomes 0 because the ids are multiplied by the 0/1 mask.
            keep_mask = np.random.binomial(1, 1.0 - self.input_dropout, size=tuple(src_ids.size()))
            rand_mask = torch.from_numpy(keep_mask).long().to(src_ids.device)
            src_ids = src_ids * rand_mask

        return (
            src_ids.to(self._device),
            src_lens.to(self._device),
            tgt_ids.to(self._device),
            tgt_lens.to(self._device),
            gating_label.to(self._device),
            turn_domain.to(self._device),
        )

    @property
    def dataset(self):
        # Batches are served through `data_iterator`, so no dataset is exposed.
        return None

    @property
    def data_iterator(self):
        return self._dataloader
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# ============================================================================= +# Copyright 2019 Salesforce Research. +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom +# the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR +# THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
class TRADEMaskedCrossEntropy(LossNM):
    """
    Neural module which implements a cross entropy for trade model with masking feature.

    Note that the ``logits`` input is passed straight to ``log`` (after
    clamping to a small positive value), i.e. it is treated as probabilities
    rather than unnormalised scores.

    Args:
        logits (float): output of the classifier
        targets (long): ground truth targets
        loss_mask (long): number of leading positions of every target
            sequence that contribute to the loss


    """

    @property
    def input_ports(self):
        """Returns definitions of module input ports.

        logits: 4d tensor of logits
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

            2: AxisType(ChannelTag)

            3: AxisType(ChannelTag)

        targets: 3d tensor of labels
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

            2: AxisType(TimeTag)

        loss_mask: specifies the words to be considered in the loss calculation
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

        """
        return {
            "logits": NeuralType(
                {0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag), 3: AxisType(ChannelTag)}
            ),
            "targets": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag), 2: AxisType(TimeTag)}),
            "loss_mask": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}),
        }

    @property
    def output_ports(self):
        """Returns definitions of module output ports.

        loss: loss value
            NeuralType(None)

        """
        return {"loss": NeuralType(None)}

    def __init__(self, **kwargs):
        # Forward keyword arguments to the base class, like CrossEntropyLoss3D.
        LossNM.__init__(self, **kwargs)

    def _loss_function(self, logits, targets, loss_mask):
        """Return the mean negative log-probability of the unmasked targets."""
        eps = 1e-10
        # reshape (unlike view) also accepts non-contiguous inputs.
        probs_flat = logits.reshape(-1, logits.size(-1))
        # Clamp before the log so zero probabilities do not produce -inf.
        log_probs_flat = torch.log(torch.clamp(probs_flat, min=eps))
        target_flat = targets.reshape(-1, 1)
        losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
        losses = losses_flat.view(*targets.size())
        return self.masking(losses, loss_mask)

    @staticmethod
    def masking(losses, mask):
        """Average ``losses`` over the first ``mask[b, s]`` positions of every sequence.

        Args:
            losses: 3d tensor; the last axis is the position axis.
            mask: 2d tensor with the number of valid positions per sequence.

        Returns:
            Scalar tensor; 0 when no position is valid.
        """
        max_len = losses.size(2)

        positions = torch.arange(max_len, device=mask.device)[None, None, :]
        mask_ = (positions < mask[:, :, None]).float()
        # Clamp the denominator so an all-masked batch yields 0 instead of NaN.
        return (losses * mask_).sum() / mask_.sum().clamp(min=1.0)


class CrossEntropyLoss3D(LossNM):
    """
    Neural module which implements a cross entropy loss for 3d logits.
    Args:
        num_classes (int): number of classes in a classifier, e.g. size
            of the vocabulary in language modeling objective
        logits (float): output of the classifier
        labels (long): ground truth labels
    """

    @property
    def input_ports(self):
        """Returns definitions of module input ports.

        logits: 3d tensor of logits
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

            2: AxisType(ChannelTag)

        labels: 2d tensor of labels
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

        """
        return {
            "logits": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag), 2: AxisType(ChannelTag)}),
            "labels": NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag)}),
        }

    @property
    def output_ports(self):
        """Returns definitions of module output ports.

        loss: loss value
            NeuralType(None)

        """
        return {"loss": NeuralType(None)}

    def __init__(self, num_classes, **kwargs):
        LossNM.__init__(self, **kwargs)
        self._criterion = torch.nn.CrossEntropyLoss()
        self.num_classes = num_classes

    def _loss_function(self, logits, labels):
        """Flatten all leading axes and apply ``torch.nn.CrossEntropyLoss``."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        logits_flatten = logits.reshape(-1, self.num_classes)
        labels_flatten = labels.reshape(-1)

        return self._criterion(logits_flatten, labels_flatten)
+# ============================================================================= + +from nemo.collections.nlp.nm.trainables.dialogue_state_tracking.state_tracking_trade_nm import * diff --git a/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/state_tracking_trade_nm.py b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/state_tracking_trade_nm.py new file mode 100644 index 000000000000..5a2aa466afe1 --- /dev/null +++ b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/state_tracking_trade_nm.py @@ -0,0 +1,235 @@ +# ============================================================================= +# Copyright 2019 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# ============================================================================= +# Copyright 2019 Salesforce Research. 
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom
# the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# =============================================================================


import random

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn as nn

from nemo.backends.pytorch.nm import TrainableNM
from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag

__all__ = ['TRADEGenerator']


class TRADEGenerator(TrainableNM):
    """State generator (decoder) of the TRADE dialogue state tracking model.

    For every slot the module runs a GRU decoder that is initialised with the
    encoder's hidden state and fed with the sum of the domain and sub-slot
    embeddings. At each decoding step it mixes a distribution over the
    vocabulary with a pointer distribution over the source tokens. The
    attention context of the first decoding step is also used to predict the
    gate of each slot.

    Args:
        vocab: vocabulary object; only ``len(vocab)`` is used by this module.
        embeddings: embedding layer shared with the encoder. Its weight matrix
            is also used as the output projection over the vocabulary.
        hid_size: hidden size of the decoder GRU and of the slot embeddings.
        dropout: dropout probability applied to the initial decoder input.
        slots: list of slot names in ``"<domain>-<slot>"`` format.
        nb_gate: number of gate classes predicted per slot.
        teacher_forcing: probability of feeding the ground-truth tokens to the
            decoder during training. The decision is drawn once per
            ``forward`` call.
    """

    # Decoding length used when no targets are given to ``forward``.
    DEFAULT_MAX_RES_LEN = 10

    @property
    def input_ports(self):
        """Returns definitions of module input ports.

        encoder_hidden: hidden states of the encoder
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

            2: AxisType(ChannelTag)

        encoder_outputs: outputs of the encoder
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

            2: AxisType(ChannelTag)

        input_lens: lengths of the input sequences to encoder
            0: AxisType(BatchTag)

        src_ids: input sequences to encoder
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

        targets: targets for the output of the generator
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

            2: AxisType(TimeTag)

        """
        return {
            'encoder_hidden': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag)}),
            'encoder_outputs': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag)}),
            'input_lens': NeuralType({0: AxisType(BatchTag)}),
            'src_ids': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
            'targets': NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag), 2: AxisType(TimeTag)}),
        }

    @property
    def output_ports(self):
        """Returns definitions of module output ports.

        point_outputs: outputs of the generator
            0: AxisType(BatchTag)

            1: AxisType(TimeTag)

            2: AxisType(ChannelTag)

            3: AxisType(ChannelTag)

        gate_outputs: outputs of gating heads
            0: AxisType(BatchTag)

            1: AxisType(ChannelTag)

            2: AxisType(ChannelTag)

        """
        # NOTE(review): `forward` returns point_outputs laid out as
        # (batch, slots, time, vocab) while axis 1 is declared as TimeTag
        # here; confirm the declared axis order with the consumers of
        # this port before changing it.
        return {
            'point_outputs': NeuralType(
                {0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag), 3: AxisType(ChannelTag)}
            ),
            'gate_outputs': NeuralType({0: AxisType(BatchTag), 1: AxisType(ChannelTag), 2: AxisType(ChannelTag)}),
        }

    def __init__(self, vocab, embeddings, hid_size, dropout, slots, nb_gate, teacher_forcing=0.5):
        super().__init__()
        self.vocab_size = len(vocab)
        self.vocab = vocab
        self.embedding = embeddings
        self.dropout = nn.Dropout(dropout)
        # NOTE: PyTorch applies the GRU `dropout` argument only between stacked
        # layers, so it has no effect on this single-layer GRU.
        self.rnn = nn.GRU(hid_size, hid_size, dropout=dropout, batch_first=True)
        self.nb_gate = nb_gate
        self.hidden_size = hid_size
        self.w_ratio = nn.Linear(3 * hid_size, 1)
        self.w_gate = nn.Linear(hid_size, nb_gate)
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.slots = slots
        self.teacher_forcing = teacher_forcing

        self._slots_split_to_index()
        self.slot_emb = nn.Embedding(len(self.slot_w2i), hid_size)
        self.slot_emb.weight.data.normal_(0, 0.1)
        self.to(self._device)

    def _slots_split_to_index(self):
        """Builds the index of slot name parts and the per-slot index tensors.

        Every slot name is split on ``'-'``. All distinct parts get an id in
        ``self.slot_w2i`` (in order of first occurrence), and the ids of the
        first part (domain) and second part (sub-slot) of each slot are stored
        in ``self.domain_idx`` and ``self.subslot_idx``.
        """
        split_slots = [slot.split('-') for slot in self.slots]
        domains = [parts[0] for parts in split_slots]
        slots = [parts[1] for parts in split_slots]
        # dict.fromkeys de-duplicates while keeping first-occurrence order.
        unique_parts = list(dict.fromkeys(part for parts in split_slots for part in parts))
        self.slot_w2i = {part: idx for idx, part in enumerate(unique_parts)}
        self.domain_idx = torch.tensor([self.slot_w2i[domain] for domain in domains], device=self._device)
        self.subslot_idx = torch.tensor([self.slot_w2i[slot] for slot in slots], device=self._device)

    def forward(self, encoder_hidden, encoder_outputs, input_lens, src_ids, targets=None):
        """Decodes the value of every slot and predicts the gate of every slot.

        Args:
            encoder_hidden: encoder hidden state; axis 0 is the batch and the
                last axis has size ``hid_size``.
            encoder_outputs: encoder outputs of shape
                (batch, max_src_len, hid_size).
            input_lens: length of every source sequence, shape (batch,).
            src_ids: source token ids, shape (batch, max_src_len).
            targets: optional target token ids of shape
                (batch, num_slots, max_res_len). When omitted,
                ``DEFAULT_MAX_RES_LEN`` steps are decoded and teacher forcing
                is disabled.

        Returns:
            Tuple ``(point_outputs, gate_outputs)`` where ``point_outputs`` has
            shape (batch, num_slots, max_res_len, vocab_size) and
            ``gate_outputs`` has shape (batch, num_slots, nb_gate).
        """
        num_slots = len(self.slots)
        batch_size = encoder_hidden.shape[0]

        if targets is None:
            # Without targets there is nothing to force-feed and no length to
            # read, so decode a fixed number of steps greedily.
            max_res_len = self.DEFAULT_MAX_RES_LEN
            use_teacher_forcing = False
        else:
            max_res_len = targets.shape[2]
            # Teacher forcing is only ever used in training mode.
            use_teacher_forcing = self.training and random.random() <= self.teacher_forcing
            # Slot-major layout to match the (slot, batch) flattening below.
            targets = targets.transpose(0, 1)

        all_point_outputs = torch.zeros(num_slots, batch_size, max_res_len, self.vocab_size, device=self._device)
        all_gate_outputs = torch.zeros(num_slots, batch_size, self.nb_gate, device=self._device)

        # Initial decoder input: domain embedding + sub-slot embedding,
        # replicated for every sample and flattened to (slots * batch, hidden).
        domain_emb = self.slot_emb(self.domain_idx).to(self._device)
        subslot_emb = self.slot_emb(self.subslot_idx).to(self._device)
        slot_emb = (domain_emb + subslot_emb).unsqueeze(1).repeat(1, batch_size, 1)
        decoder_input = self.dropout(slot_emb).view(-1, self.hidden_size)

        hidden = encoder_hidden.transpose(0, 1).repeat(num_slots, 1, 1)
        hidden = hidden.view(-1, self.hidden_size).unsqueeze(0)

        enc_len = input_lens.repeat(num_slots)
        maxlen = encoder_outputs.size(1)
        # Positions at or beyond the sequence length are padding and get a
        # score of -inf so that they receive zero attention weight.
        padding_mask_bool = torch.arange(maxlen, device=self._device)[None, :] >= enc_len[:, None]
        padding_mask = torch.zeros_like(padding_mask_bool, dtype=encoder_outputs.dtype, device=self._device)
        padding_mask.masked_fill_(mask=padding_mask_bool, value=-np.inf)

        # Loop invariants: the encoder outputs and source ids replicated for
        # every slot do not change between decoding steps.
        enc_out = encoder_outputs.repeat(num_slots, 1, 1)
        src_ids_rep = src_ids.repeat(num_slots, 1)

        for wi in range(max_res_len):
            dec_state, hidden = self.rnn(decoder_input.unsqueeze(1), hidden)

            context_vec, _, prob = TRADEGenerator.attend(enc_out, hidden.squeeze(0), padding_mask)

            if wi == 0:
                # The gates are predicted from the first step's context only.
                all_gate_outputs = torch.reshape(self.w_gate(context_vec), all_gate_outputs.size())

            p_vocab = TRADEGenerator.attend_vocab(self.embedding.weight, hidden.squeeze(0))
            p_gen_vec = torch.cat([dec_state.squeeze(1), context_vec, decoder_input], -1)
            vocab_pointer_switches = self.sigmoid(self.w_ratio(p_gen_vec))

            # Pointer distribution: scatter the attention weights onto the
            # vocabulary ids of the attended source tokens.
            p_context_ptr = torch.zeros_like(p_vocab)
            p_context_ptr.scatter_add_(1, src_ids_rep, prob)

            final_p_vocab = (1 - vocab_pointer_switches) * p_context_ptr + vocab_pointer_switches * p_vocab
            pred_word = torch.argmax(final_p_vocab, dim=1)

            all_point_outputs[:, :, wi, :] = torch.reshape(final_p_vocab, (num_slots, batch_size, self.vocab_size))

            if use_teacher_forcing:
                decoder_input = self.embedding(torch.flatten(targets[:, :, wi]))
            else:
                decoder_input = self.embedding(pred_word)

            decoder_input = decoder_input.to(self._device)

        all_point_outputs = all_point_outputs.transpose(0, 1).contiguous()
        all_gate_outputs = all_gate_outputs.transpose(0, 1).contiguous()
        return all_point_outputs, all_gate_outputs

    @staticmethod
    def attend(seq, cond, padding_mask):
        """Dot-product attention of ``cond`` over the steps of ``seq``.

        Args:
            seq: tensor of shape (N, T, H) to attend over.
            cond: query tensor of shape (N, H).
            padding_mask: additive mask of shape (N, T), added to the raw
                scores before the softmax.

        Returns:
            Tuple ``(context, masked_scores, weights)``: the attention-weighted
            sum of ``seq`` over axis 1, the masked raw scores, and their
            softmax over axis 1.
        """
        scores_ = cond.unsqueeze(1).expand_as(seq).mul(seq).sum(2)
        scores_ = scores_ + padding_mask
        scores = F.softmax(scores_, dim=1)
        context = scores.unsqueeze(2).expand_as(seq).mul(seq).sum(1)
        return context, scores_, scores

    @staticmethod
    def attend_vocab(seq, cond):
        """Returns the softmax over axis 1 of ``cond @ seq.T``.

        Args:
            seq: 2D matrix whose rows are scored against ``cond``
                (``forward`` passes the embedding weight matrix).
            cond: query tensor whose last axis matches the row size of ``seq``.
        """
        scores_ = cond.matmul(seq.transpose(1, 0))
        scores = F.softmax(scores_, dim=1)
        return scores