Bug introduced in v1.3.0 causing training divergence #529
This is to confirm @martinpopel's observation. The orange curve was trained on T2T 1.4.1; the other curves belong to 1.4.2. I tried increasing warmup to 16k or 32k, or decreasing the learning rate to 0.05, but the model always seems to diverge: |
@mehmedes: so you see a difference between T2T 1.4.1 and 1.4.2? Then perhaps that's another issue. I haven't noticed any difference between these versions (for my translation problem), and the changes related to transformer (i.e. not transformer_vae) in #506 seem merely cosmetic (but I may be wrong).
In my case, increasing warmup steps from 16k to 32k helped: after one day of training (170k steps), the curves are basically the same (even now, after 3 days of training, not shown in the picture below). However, all my experiments with the new version are worse than v1.2.9 in the first 12 hours of training. E.g. increasing warmup to 48k does not help (it starts similarly to the blue curve and joins the gray curve only after two days of training). |
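For context on why warmup_steps interacts with divergence: the Transformer hparams use an inverse-square-root ("noam") schedule with linear warmup, so a larger warmup both delays and lowers the peak learning rate. Below is a minimal sketch of that schedule, not T2T's exact implementation (the `learning_rate` multiplier and other hparams are omitted):

```python
def noam_learning_rate(step, hidden_size=512, warmup_steps=16000, base_lr=1.0):
    """Inverse-square-root schedule with linear warmup (Vaswani et al., 2017).

    The rate rises linearly for `warmup_steps` updates, then decays as
    step**-0.5; doubling warmup_steps lowers the peak rate by a factor of
    sqrt(2) and pushes it later in training.
    """
    step = max(step, 1)
    return base_lr * hidden_size ** -0.5 * min(step ** -0.5,
                                               step * warmup_steps ** -1.5)

# Peak learning rate is reached at step == warmup_steps:
peak_16k = noam_learning_rate(16000, warmup_steps=16000)
peak_32k = noam_learning_rate(32000, warmup_steps=32000)
print(f"peak lr: warmup=16k -> {peak_16k:.2e}, warmup=32k -> {peak_32k:.2e}")
```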
Should we still try to pinpoint which exact change from 1.2.9 to 1.3.0 caused this effect, or should we just increase warmup_steps in the single_gpu configs? There doesn't seem to be that much that changed from 1.2.9 to 1.3.0, so I'm surprised it causes this! Do you have any suggestions as to what could have caused it? We added target_weights_fn, but it doesn't look like we forgot padding. Otherwise, did anything change? Can you decode the 1.2.9-trained model with 1.3.0? Did the weights change in any way? |
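Regarding the target_weights_fn remark above: "not forgetting padding" means giving padding positions zero weight in the loss. A generic TF-1.x sketch of that idea (hypothetical helper names, not T2T's actual function):

```python
import tensorflow as tf  # TF 1.x

def padding_aware_weights(labels):
    """Per-token loss weights: 1.0 for real tokens, 0.0 for padding (id 0)."""
    return tf.to_float(tf.not_equal(labels, 0))

def weighted_xent(logits, labels):
    """Cross-entropy averaged over non-padding tokens only."""
    weights = padding_aware_weights(labels)
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                          logits=logits)
    return tf.reduce_sum(xent * weights) / tf.maximum(tf.reduce_sum(weights), 1.0)
```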
No, because the checkpoint format is not compatible (because of 01b8c31 and c10e016). |
Like @mehmedes, I really think there is an issue with 1.4.2 vs. 1.4.1. |
After training more versions for a longer time, it seems there is no single culprit commit; the issue is more complicated. See the training loss curves. I have problems plotting the BLEU curves because most of the versions have broken/missing t2t-bleu and t2t-translate-all scripts, so I would have to patch these versions first (some commits between v1.2.9 and v1.3.0 cannot be decoded with either of these versions, which I have already patched). Another problem is that until recently T2T had not properly fixed the random seed, but even so, I think the difference between the bottom three "converging" curves and the three "diverging" curves is too large to be caused by random initialization. |
T2T before v1.4.2 had not fixed the random seed, so I re-ran some experiments several times. |
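For context, "fixing the random seed" in a TF-1.x setup typically means pinning all three seed sources. A generic sketch, not the code T2T later added:

```python
import random

import numpy as np
import tensorflow as tf  # TF 1.x

def set_seeds(seed=1234):
    """Pin Python, NumPy, and TensorFlow graph-level seeds.

    Note: op-level nondeterminism (e.g. some GPU kernels, multi-threaded
    input pipelines) can still make runs differ slightly.
    """
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)
```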
@vince62s suggested the worsening may be caused by the (probably unintentional) removal of bias variables in 0ffe0e6. common_attention.py originally used the function common_layers.conv1d with the default use_bias=True. 0ffe0e6 changed this to tf.layers.dense, which also adds bias variables by default, but this default has been overridden to use_bias=False in all calls. I've changed it to use_bias=True, but it had no effect: the training loss still starts growing after 16k steps, exactly as with use_bias=False. So I still don't know which code change causes the worsening. Nevertheless, it is possible that the change of use_bias was not intentional (at least it is not mentioned in the commit log, which says just "Make Transformer fast on TPU."). @nshazeer @lukaszkaiser: Can you check whether use_bias=False is what you want and explain why the bias was dropped? |
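To make the change concrete, the attention projections in question look roughly like this; a schematic sketch, not the actual diff in 0ffe0e6:

```python
import tensorflow as tf  # TF 1.x

x = tf.placeholder(tf.float32, [None, None, 512])  # [batch, length, hidden]

# Before: common_layers.conv1d with its default use_bias=True,
# i.e. a per-position linear projection that includes a bias term.
q_old = tf.layers.dense(x, 512, use_bias=True, name="q_with_bias")

# After 0ffe0e6: tf.layers.dense called with use_bias=False everywhere,
# which drops the bias variables from the attention projections.
q_new = tf.layers.dense(x, 512, use_bias=False, name="q_without_bias")
```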
I intentionally removed the bias, since it didn't seem to have any effect on quality, and having a smaller number of variables was faster on some systems. If it turns out to be important, we can put it back. |
@nshazeer: Thank you for the info.
No, the bias has no effect according to my experiments. |
Does this problem also occur with |
All the experiments reported here are with … If there are no more ideas about what caused the error (in 0ffe0e6), we can close this issue.
Great. Thanks. |
I confirm that on 1.4.3 (with the patch for the correct learning rate schedule) the EN-DE transformer_base runs fine on 4 GPUs (batch size 5430 x 4) => BLEU 26.62. |
Thanks @vince62s, that's good news. It may be that we need to retune some of the hparams sets and update them. In general it seems to be very difficult to maintain reproducibility without freezing a codebase/codepath entirely. What we may want to consider doing is a hard fork of the entire … Open to suggestions and discussion here. This is a hard problem. |
We could also just say that to reproduce an experiment you have to be on a certain commit hash. |
If I may suggest something: |
Yeah, I do think it's reasonable to tag certain commits as the golden commits to reproduce a certain baseline (model-problem pair). Development can continue on master. So maybe for each paper publication or model-problem pair of interest, we tag a commit, and in the tag description we include a command line that reproduces it perfectly. That would be quite nice. |
Hi, I'm not sure if this is the right place (if not, I can open a new issue), but I trained a … The resulting BLEU score was 8-9! When I initially submitted the … Here are some images from TensorBoard: Should I try an older version like 1.2.9 to check whether this low BLEU score is reproducible? Here's the BLEU score graph from an older … |
@stefan-it: Thanks for sharing your results.
|
@martinpopel thanks for that hint. I trained another model with 1.5.1 using … Here are the results:
But compared to version 1.2.9 there's a difference of 29.63 BLEU points. |
IMO it does not make sense to train a big model with 200K sentences. |
Yes, if the data is translate_enmk_setimes32k with just 205,777 training sentences, then I guess using … However, no matter what training data and hyperparameter set @stefan-it used, it would be interesting to confirm whether the drop in BLEU between 1.2.9 and 1.5.1 is caused by the same commit I found. Just to be sure, I would suggest re-evaluating the BLEU results (at least the final ones) with sacreBLEU or t2t-bleu to have trustworthy numbers (with approx_bleu there is a risk that the non/autoregressive slow/fast implementation changes between versions also influence the results). And finally, a bit of terminological nitpicking: what @stefan-it calls epochs is usually called steps in T2T (or iterations, or number of updates). An epoch is usually understood as one pass over the whole training data; see #415. |
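As a concrete way to do that re-evaluation outside the training graph, sacreBLEU's Python API can score a decoded file against the reference. A minimal sketch; the file paths are placeholders:

```python
import sacrebleu

# Placeholder paths: decoded hypotheses and reference translations,
# one sentence per line, in the same order.
with open("newstest.decoded.txt") as h, open("newstest.ref.txt") as r:
    hypotheses = [line.strip() for line in h]
    references = [line.strip() for line in r]

bleu = sacrebleu.corpus_bleu(hypotheses, [references])
print(f"BLEU = {bleu.score:.2f}")
```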
Here are the BLEU scores using
So I think this problem can be replicated on at least two datasets of different sizes. |
I train with exactly the same setup (1 GPU, transformer_big_single_gpu, batch_size=1500, an en-cs dataset with 58M training sentences, checkpoints saved every hour) with T2T v1.2.9 and v1.3.0 (the version immediately following 1.2.9). While in v1.2.9 everything is OK, in v1.3.0 the training diverges after about one hour and BLEU drops to almost zero after about 5 hours of training. See the learning curves below (orange=v1.2.9, magenta=v1.3.0; the y-axis is dev-set BLEU as measured by t2t-bleu, the x-axis is training time in hours). I also attach the training loss curves for these two experiments (but now orange=v1.3.0, gray=v1.2.9, so the colors don't match the previous graph; the x-axis is now in steps):