Unable to reproduce WMT En2De results #317
Comments
|
Hi Martin, thank you for the quick reply.
Is there anything I need to enable to utilize all 8 GPUs properly?
|
E.g. diskutieren is split into two lines, which shifts all subsequent lines by one, unless that is a copy-paste error from copying to the gist
|
Sorry, this was a copy-and-paste issue. Just to be sure, I checked the line counts in all files, and they are the same:
I'll try running with --worker_gpu=8 and report results here. |
Unfortunately, it fails to run if I add --worker_gpu=8, with the following error: Seems to be the same issue as here: #266 |
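For reference, a multi-GPU run of this setup is usually launched with a command of roughly the following shape. This is only a sketch with placeholder paths; the flag names are the ones I believe t2t-trainer exposes, and older releases used --problems instead of --problem.
t2t-trainer \
  --data_dir=$DATA_DIR \
  --output_dir=$TRAIN_DIR \
  --problem=translate_ende_wmt32k \
  --model=transformer \
  --hparams_set=transformer_base \
  --worker_gpu=8 \
  --train_steps=250000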
@edunov Have you checked whether the test sets have already been tokenized? Stanford's test sets, for example, are already tokenized. Maybe you've tokenized the reference text twice? Moreover, in the excerpt you posted the sentences seem to be missing full stops. Maybe this impacts your BPE score? |
@mehmedes I'm using the test set that tensor2tensor provides, I believe it comes from here: It has both a tokenized (newstest2014.tok.de) and an untokenized version (newstest2014.de), and I'm using the untokenized one. Btw, I re-ran all the experiments and here are the numbers I got (everything before model averaging):
Seems like the word-piece model works right, but for the BPE model I'm still very far from the expected numbers. From the output of the BPE model it seems that periods are completely missing, so maybe the sentences are truncated or something? Is there any option or parameter in the generator that limits the sentence length? |
For the BPE model, you need to use the tokenized test set, and do not re-run the Moses tokenizer on the decoded output. Your results look like some bug of this kind; it's hard to believe the model would be so different. There is a decode parameter limiting the length, but I doubt that is the problem. |
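To illustrate the point above, a sketch with hypothetical file names: the decoded output of the BPE model is scored directly against the tokenized reference, without piping it through tokenizer.perl again.
perl mosesdecoder/scripts/generic/multi-bleu.perl newstest2014.tok.de < bpe_model_output.de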
@edunov what did you use to reproduce 27.76 on newstest2014 (word piece)? @lukaszkaiser what is the rationale behind the fact that with one GPU only it seems to converge faster to a lower point (in BLEU terms)? Even if we train longer, it never reaches the same level. |
@vince62s I used 8 GPUs, base model and everything else is set to default. I trained the model until the trainer stopped (250k steps). I believe this is different from what they did in the paper (100k steps) and it takes longer. |
Hi all, and I run on a single GPU with the setup --problem=translate_ende_wmt32k. I monitor the eval BLEU score with --eval_run_autoregressive set. After 110861 steps I get approx_bleu_score = 0.108027 |
@edunov Have you solved the lower BPE results problem? I am puzzled now. The BPE results on newstest2014 only reach 12.16 BLEU. Really puzzling. I am using the same configuration as you mentioned. |
@lukaszkaiser Hello, when using translate_ende_bpe_32k, I ran into the same problem as @edunov: the BLEU score is too low. When I print the test log, the source sentence is always truncated. Examples follow: |
Can someone try t2t 1.3.0 and confirm if the error is fixed (by 7909c69)? |
@martinpopel I will try it. |
@martinpopel I think the problem has been solved. I trained for 180K steps with 2 GPUs, transformer_base_single_gpu, and all other settings at default. It got 22.72 on newstest2014. |
Hi @edunov, I have a question: when you compute BLEU on newstest2014 using the word-piece model, is the reference file newstest2014.de? Have you tried another way to compute BLEU: first, tokenize your result file; second, compute BLEU against newstest2014.tok.de? How would the BLEU score change? I ran the experiment (word-piece model) with the same settings except train_steps=100k, and BLEU on newstest2014 (reference file newstest2014.de) is 22.57. Will the two different ways produce significantly different BLEU scores? Thanks! |
@jiangbojian I actually tokenize and then apply compound splitting to the reference before computing BLEU score:
/tmp/t2t_datagen/newstest${YEAR}.de.atat is what I use to compute BLEU. It is important to have exactly the same processing steps for both the reference and the system output. So, whatever you do to the reference, you'll have to do to the system output as well. |
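For concreteness, a sketch of that preprocessing, loosely following get_ende_bleu.sh; the mosesdecoder path and file names are placeholders, and the same steps are applied to both the reference and the system output before scoring.
perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l de < newstest2014.de > newstest2014.tok.de
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < newstest2014.tok.de > newstest2014.de.atat
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < system_output.tok.de > system_output.atat
perl mosesdecoder/scripts/generic/multi-bleu.perl newstest2014.de.atat < system_output.atat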
@edunov Thank you for your detailed answer. |
@edunov I only tokenize the result file and the reference file. |
Yes, what it does is replace a hyphen with 3 tokens, e.g. "Science-Fiction" becomes "Science ##AT##-##AT## Fiction", so if your translation is correct you'll have 3 consecutive tokens match instead of 1. It seems to give an extra 0.7-1 BLEU point. To get newstest${YEAR}.de.atat I just applied the same compound splitting to the tokenized version of newstest${YEAR}.de |
@edunov Yes, compound splitting gives an extra 0.7-1 BLEU point for the word-piece model. By the way, for the BPE model (same settings as the paper), after applying compound splitting to the result file and the reference file, the BLEU score on newstest2014 reaches 27.11 (the BLEU score in the paper is 27.30). So, is the BLEU score (En-De, bpe32k, base model) reported in the paper computed in this way? @lukaszkaiser |
@jiangbojian Hi, I'm now trying to reproduce the Transformer, too. However, I can only achieve a ~22 BLEU score on the WMT'14 En-De dataset. My configs:
My model converged after ~8k steps. Here are my translating steps:
Could you please provide full experiment details to help me achieve a ~27 BLEU score like you did? Thank you! |
@martinpopel Hi, I noticed the result you reported with the transformer_base_single_gpu configuration (but ~500k steps and batch_size 3072) with t2t version 1.1. However, I encountered NaN loss when trying it with the current version (pip install). My command is
Can you please show the command you used to reach a 25.61 BLEU score on the same task? Thanks very much! |
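For comparison, the single-GPU setup discussed here corresponds to roughly the following command. This is a sketch only, with placeholder paths; the exact flag spelling may differ between T2T versions.
t2t-trainer \
  --data_dir=$DATA_DIR \
  --output_dir=$TRAIN_DIR \
  --problem=translate_ende_wmt32k \
  --model=transformer \
  --hparams_set=transformer_base_single_gpu \
  --hparams='batch_size=3072' \
  --train_steps=500000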
@Shrshore: What I wrote was true for T2T version 1.1 and |
@martinpopel Thanks for your reply! I'm also wondering whether the evaluation pipeline should really contain the 'compound splitting' operation. Have you tried that before? I have obtained a model output whose BLEU is only ~22 without compound splitting, but after applying compound splitting to the reference file and output file, the metric goes up to ~27 with t2t-bleu. So I'm wondering whether the result reported in the paper is calculated in this way, too. @lukaszkaiser |
I think (based on some post from @lukaszkaiser) the BLEU=28.4 in the Attention Is All You Need paper was computed with https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/get_ende_bleu.sh, which splits hyphen compounds for "historical reasons". This splitting improves the BLEU score nominally, but (of course) not the translation quality, so I do not do it (I tried it just once, when I wanted to compare my scores to the paper and measure the effect of different BLEU evaluation). |
@jiangbojian @edunov @martinpopel @Shrshore Hi, I used
|
Hi @edunov @martinpopel, and I evaluate the model using: |
Thank you for your reply. I tried "t2t-avg-all", but I still got ~15.23 BLEU. I checked the output .de file; there are too many "#" in the translated German sentences. When I deleted the "#" from the .de file and evaluated it again, I got ~21.58 BLEU, which is similar to the BLEU score without averaging checkpoints. Have you ever encountered this situation?
2018-04-11 16:28 GMT+08:00 Vu Cong Duy Hoang <[email protected]>:
… Hi,
You can use either way as follows:
mkdir $TRAIN_DIR/averaged-model
rm -rf $TRAIN_DIR/averaged-model/*
CUDA_VISIBLE_DEVICES=2 python /nfs/team/nlp/users/vhoang/ve3/lib/python3.6/site-packages/tensor2tensor/utils/avg_checkpoints.py --worker_gpu=1 --prefix=$TRAIN_DIR/ --num_last_checkpoints=10 --output_path=$TRAIN_DIR/averaged-model/averaged.ckpt
or
CUDA_VISIBLE_DEVICES=2 t2t-avg-all --model_dir=$TRAIN_DIR/ --output_dir=$TRAIN_DIR/averaged-model/ --n=10 --worker_gpu=1
|
@DC-Swind: No, I haven't encountered this situation. I always average checkpoints stored at 1-hour intervals (
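A sketch of how an hourly-checkpoint-plus-averaging workflow can be set up: the --save_checkpoints_secs and --keep_checkpoint_max flags are my assumptions about t2t-trainer, while the averaging flags follow the quoted commands above.
t2t-trainer ... --save_checkpoints_secs=3600 --keep_checkpoint_max=20
t2t-avg-all --model_dir=$TRAIN_DIR --output_dir=$TRAIN_DIR/averaged-model --n=8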
@martinpopel: I will try it. How do I specify a checkpoint by name if I want to decode using "model.ckpt-24500"? By the way, is there any convenient method of plotting the training loss or the test BLEU curve? |
@DC-Swind:
Yes, tensorboard (plus t2t-bleu and t2t-translate-all if you prefer to see the real BLEU instead of approx_bleu or if you use |
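A sketch of both answers, with hypothetical file names: the --checkpoint_path flag of t2t-decoder is an assumption about the installed version, while the t2t-bleu flags follow the T2T documentation.
t2t-decoder \
  --data_dir=$DATA_DIR --output_dir=$TRAIN_DIR \
  --problem=translate_ende_wmt32k --model=transformer --hparams_set=transformer_base \
  --checkpoint_path=$TRAIN_DIR/model.ckpt-24500 \
  --decode_from_file=newstest2014.en --decode_to_file=newstest2014.ckpt24500.de
t2t-bleu --translation=newstest2014.ckpt24500.de --reference=newstest2014.de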
If anyone is still trying to reproduce the Attention Is All You Need paper en-de wmt14 BLEU scores: note that you must manually tweak the newstest2014.de reference file (in addition to all the hacks in
|
Hi @martinpopel, I wonder what the difference is between model.ckpt-3817425 and averaged.ckpt-0? The former produces much better results. Thanks. |
@szhengac: These checkpoints were uploaded by the T2T authors, not by me, so I am not sure. I guess the former model was trained for 3817425 steps (that is 3.8M steps, while in the Attention Is All You Need only 0.1M steps were used for the base models, but it also depends on the number of GPUs and batch size) and with a newer T2T version. |
As some users still ask how to replicate the BLEU scores (after downloading the trained checkpoint from
This gives BLEU 26.50 and 29.02 for wmt13 and wmt14, respectively. And (after
this gives BLEU 26.59 and 28.33. |
Hi @martinpopel. Thanks for your reply. I wonder what beam_size and alpha were used to obtain these translations? |
The default |
Hi @martinpopel, I am confused about the score of 29.3 in Scaling Neural Machine Translation; do you know how they computed the BLEU score? |
I have just read the paper, but it seems quite clear:
In most tables they report only the multi-bleu version; only in Table 3 do they report both, and we can see that multi-bleu 29.3 corresponds to sacreBLEU 28.6. |
@martinpopel Yeah, so the multi-bleu score they got should have been produced with get_ende_bleu.sh? |
They report their multi-bleu score of 29.3 in Table 2 as if it were comparable to Vaswani et al. (2017)'s 28.4, but I doubt they followed all the tweaks (unless there was personal communication between the authors of the two papers). Luckily, we can compare the sacreBLEU scores (case.mixed-tok.13a version): 28.6 (Ott et al.) vs. 27.52 (T2T as reported above, about half a point better than Vaswani et al.). It should be noted that while it is good to use comparable and replicable BLEU (i.e. sacreBLEU), it is not everything, as most MT researchers know. It's not only BLEU: any automatic metric based on similarity to a human reference that I am aware of (especially to a single reference, as is the case in WMT) is potentially flawed. There are systems today (for some language pairs and the "WMT domain") surpassing the quality of human references (or at least coming close). This of course does not mean that the systems are better than human references in all aspects, just in some aspects. But it means that single-reference BLEU (or any other automatic metric) is not reliable for such high-quality systems. I'm curious what correlation scores we will see in the WMT18 metrics task results. |
OK, I just want to build a baseline whose BLEU score is comparable with the other systems above. I think it's a good starting point for further research. Thanks for your answer! |
Hi @martinpopel, can you please explain the complete steps needed to reproduce the correct results? I downloaded the dataset from google_drive_link for WMT16, trained the model with the problem translate_ende_wmt_bpe32k, and got a bunch of model checkpoints model.ckpt-xxxx. Now what should I do specifically to get the averaged model and the correct BLEU score? Thank you. |
Hi @nxphi47, for future research it is recommended to use sacreBLEU, used with the appropriate option. For averaging, use avg_checkpoints.py or t2t-avg-all. |
Hi @martinpopel Thank you |
When using t2t-bleu or sacrebleu, always use non-tokenized (de-tokenized) version of translation and reference files, i.e. the version which would be presented to the users. |
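For example, with sacreBLEU's built-in test sets, feeding it the detokenized system output (the file name here is hypothetical):
sacrebleu --test-set wmt14 --language-pair en-de < system_output.detok.de
It prints the score together with a signature that makes the result replicable.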
@martinpopel Thank you for your reply. |
I tried to reproduce the results from the paper on WMT En2De, base model. In my experiments I tried both the BPE and the word-piece model. Here are the steps I took to train the models:
I trained both models till the trainer finished (~1 day). The last update for BPE model was:
For word piece model the last update was:
Then I tried to evaluate both models on newstest2013, newstest2014, newstest2015. Here are the commands that I used (I'm mostly following steps from here https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/get_ende_bleu.sh)
For BPE model:
For word piece model:
Here are the BLEU scores I've got:
There is a big mismatch with the results reported in the paper, so there must be something wrong with the way I ran these experiments. Could you please provide me some guidance on how to run this properly to reproduce the results from the paper?