PyTorch implementation of the models described in the paper Minimizing the Bag-of-Ngrams Difference for Non-Autoregressive Neural Machine Translation.
- Python 3.6
- PyTorch >= 0.4
- Numpy
- NLTK
- torchtext 0.2.1
- torchvision
- revtok
- multiset
- ipdb
This code is based on dl4mt-nonauto and RSI-NAT. We mainly modified model.py (lines 1107-1292).
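For orientation, below is a minimal, illustrative PyTorch sketch of a bag-of-bigrams (n=2) L1 objective. It is not the code in model.py; the function name, arguments, and the single-sentence setting are our own simplifications. It clips the expected count of each reference bigram under the position-wise independent NAT distribution by its reference count, and turns the matched mass into an L1-style distance between the two bags of bigrams.

```python
import torch
from collections import Counter

def bon_l1_bigram_loss(log_probs, ref):
    """Illustrative BoN-L1 loss for bigrams (hypothetical helper, not model.py).

    log_probs: (T, V) tensor of per-position log-probabilities for a NAT output of length T.
    ref:       list of reference token ids of length T_ref.
    """
    probs = log_probs.exp()                              # (T, V)
    T = probs.size(0)
    # Reference bag of bigrams: bigram -> count.
    ref_bigrams = Counter(zip(ref[:-1], ref[1:]))
    # Matched mass: for each reference bigram (a, b), the expected count under the
    # model is sum_t p_t(a) * p_{t+1}(b); clip it by the reference count.
    match = probs.new_zeros(())
    for (a, b), count in ref_bigrams.items():
        expected = (probs[:-1, a] * probs[1:, b]).sum()
        match = match + torch.min(expected, expected.new_tensor(float(count)))
    # L1 distance between the two bags:
    # (number of reference bigrams) + (expected number of model bigrams) - 2 * match.
    return (len(ref) - 1) + (T - 1) - 2 * match
```

A batched, GPU-efficient version of this idea is what the modified lines of model.py implement.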
The original translation corpora (IWSLT'16 En-De, WMT'16 En-Ro, WMT'14 En-De) can be downloaded from their respective sources. We recommend downloading the preprocessed corpora released in dl4mt-nonauto.
Set the correct path to the data in the data_path() function located in data.py before running the code.
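As a purely hypothetical sketch of that edit (the real data_path() signature and dataset keys may differ; the paths below are placeholders):

```python
# data.py -- hypothetical sketch; adapt to the actual data_path() in the repository.
def data_path(dataset):
    # Map each dataset name to the directory holding its preprocessed corpus.
    # Replace the placeholder paths with your local download locations.
    roots = {
        'iwslt-ende': '/path/to/iwslt16.tokenized.de-en/',
        'wmt16-enro': '/path/to/wmt16.tokenized.en-ro/',
        'wmt14-ende': '/path/to/wmt14.tokenized.en-de/',
    }
    return roots[dataset]
```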
Combine the BoN objective and the cross-entropy loss to train NAT from scratch. This process usually takes about 5 days.
$ sh joint_wmt.sh
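For intuition, the joint objective can be sketched as a weighted combination of the two losses. The snippet below reuses the bon_l1_bigram_loss sketch above on dummy data for a single sentence; alpha is an illustrative weight of our own, not a flag of joint_wmt.sh.

```python
import torch
import torch.nn.functional as F

# Dummy single-sentence example; alpha is an illustrative weight, not a script flag.
T, V, alpha = 6, 100, 0.5
logits = torch.randn(T, V, requires_grad=True)   # NAT decoder outputs for one sentence
targets = torch.randint(1, V, (T,))              # reference token ids

ce_loss = F.cross_entropy(logits, targets)
bon_loss = bon_l1_bigram_loss(F.log_softmax(logits, dim=-1), targets.tolist())
loss = alpha * bon_loss + (1.0 - alpha) * ce_loss   # joint objective
loss.backward()
```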
Take a checkpoint and train the length prediction model. This process usually takes about 1 day.
$ sh tune_wmt.sh
Decode the test set. This process usually takes about 20 seconds.
$ sh decode_wmt.sh
Alternatively, first train a NAT model using the cross-entropy loss. This process usually takes about 5 days.
$ sh mle_wmt.sh
Then, take a pre-trained checkpoint and finetune the NAT model using the BoN objective. This process usually takes about 3 hours.
$ sh bontune_wmt.sh
Take a finetuned checkpoint and train the length prediction model. This process usually takes about 1 day.
$ sh tune_wmt.sh
Decode the test set. This process usually takes about 20 seconds.
$ sh decode_wmt.sh
We also implement Reinforce-NAT (lines 1294-1390) described in the paper Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation. See RSI-NAT for usage.
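For orientation, a generic REINFORCE-style update for a NAT decoder looks roughly like the sketch below. This is a textbook policy-gradient example with a placeholder reward function, not the Reinforce-NAT code in model.py; see the RSI-NAT repository for the actual implementation.

```python
import torch
import torch.nn.functional as F

def reinforce_step(logits, reward_fn):
    """Generic REINFORCE sketch (not the repository's Reinforce-NAT).

    logits:    (T, V) NAT decoder outputs for one sentence.
    reward_fn: callable mapping a list of sampled token ids to a scalar reward,
               e.g. a sentence-level BLEU/GLEU score against the reference.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    dist = torch.distributions.Categorical(logits=log_probs)
    sample = dist.sample()                        # (T,) sampled translation
    reward = reward_fn(sample.tolist())           # scalar sequence-level reward
    # Policy-gradient loss: raise the log-probability of high-reward samples.
    loss = -(reward * dist.log_prob(sample).sum())
    return loss
```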
If you find the resources in this repository useful, please consider citing:
@article{Shao:19,
author = {Chenze Shao and Jinchao Zhang and Yang Feng and Fandong Meng and Jie Zhou},
title = {Minimizing the Bag-of-Ngrams Difference for Non-Autoregressive Neural Machine Translation},
year = {2019},
journal = {arXiv preprint arXiv:1911.09320},
}