This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Evaluation failed when training with multiple worker GPUs #266

Closed
skyw opened this issue Aug 30, 2017 · 38 comments
@skyw

skyw commented Aug 30, 2017

I set worker_gpu to 2 and use 2 GPUs on the same node. Training is completely fine, but evaluation fails with this error:
"
tensorflow.python.framework.errors_impl.InvalidArgumentError: Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 47) and num_split 2
[[Node: split_2 = Split[T=DT_INT32, num_split=2, _device="/job:localhost/replica:0/task:0/cpu:0"](split_2/split_dim, input_reader/ExpandDims_3/_1823)]]
[[Node: split_2/_1825 = _HostRecv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:1", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1113_split_2", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/gpu:1"]()]]

Caused by op u'split_2', defined at:
File "/raid/skyw/venv/tensorflow-pip-py27/bin/t2t-trainer", line 5, in
pkg_resources.run_script('tensor2tensor==1.2.1', 't2t-trainer')
"

It isn't very clear what the error means, but single-GPU training and evaluation are fine; the problem only appears with 2 worker GPUs.
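
For reference, the message just means that tf.split was asked to make 2 equal shards of a dimension of size 47. A standalone TF 1.x snippet (independent of T2T, illustrative only) that reproduces the same error:

import tensorflow as tf

# The batch dimension is unknown at graph-construction time, so the
# "evenly divide" check only happens at run time.
x = tf.placeholder(tf.int32, shape=[None])
shards = tf.split(x, 2, axis=0)  # 2 worker GPUs -> 2 shards

with tf.Session() as sess:
    # 47 items cannot be split into 2 equal shards, so this raises:
    # InvalidArgumentError: Number of ways to split should evenly divide the
    # split dimension, but got split_dim 0 (size = 47) and num_split 2
    sess.run(shards, feed_dict={x: list(range(47))})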

@martinpopel
Contributor

I see exactly the same problem, with T2T 1.2.1 and TensorFlow 1.3. With multiple GPUs I can train with --local_eval_frequency=0, which disables the evaluation completely.

With batch_size 2048, 1024, 512 and 256, it fails just at the beginning of the evaluation (after "Starting evaluation at..." and "Restoring parameters from...").
With batch_size=128, it fails after "Evaluation [18/100]", i.e. 18 of the 100 dev-set batches are translated OK, but the 19th fails.
With batch_size=100, it fails after "Evaluation [77/100]".
With batch_size=64, the whole evaluation goes OK and the model continues to train (but quite slowly due to the low batch size, so there is no advantage in multi-GPU).

@benjamincohen1

Also having this issue. Have you made any progress on this?

@lukaszkaiser
Contributor

Could you try --eval_steps=5 or something similar? I mean: is it only failing because it has evaluated the whole set, or is it something else? Please report back whether it works. Thanks!

@martinpopel
Contributor

@lukaszkaiser My dev set has about 110k subwords (BTW: is there an easy way to compute this exactly?). When training on a single GPU with batch_size=2048, I need 55 eval steps to cover the whole set. If I use more eval_steps, I see "Out of range: End of sequence" warnings, but the training continues and approx_bleu is computed correctly.
As I wrote in my last post, when training on two GPUs with batch_size=128, the evaluation fails after 18 batches, but the whole dev set has 110k/128 ≈ 860 batches. So the failure is not caused by evaluating the whole set (even if I multiply 128 by 2 because of the two GPUs; BTW: is the evaluation parallelized?).
I don't have access to two GPUs on the same machine at the moment. Once I do, I will try --eval_steps=5 and report back. Based on my previous experiments, I guess that a batch_size bigger than 256 will fail anyway (it had failed before finishing the first eval batch).
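
As for counting the dev-set subwords exactly, one quick way is to encode the dev file with the problem's subword vocabulary. A minimal sketch (the vocab and dev-set filenames below are assumptions; adjust them to your setup, and add 1 per sentence if you also want to count the appended EOS):

from tensor2tensor.data_generators import text_encoder

# Assumed paths: the subword vocab produced by the problem's data generation
# and the plain-text source side of the dev set.
encoder = text_encoder.SubwordTextEncoder("data/vocab.ende.32768")
total_subwords = 0
with open("newstest2013.en") as dev_file:
    for line in dev_file:
        total_subwords += len(encoder.encode(line.strip()))
print(total_subwords)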

@skyw
Author

skyw commented Sep 7, 2017

I tried eval_steps=1/5/8/16 with --worker_gpu=2; none of them worked.

@martinpopel
Contributor

I tried eval_steps=5 with worker_gpu=2 and it failed for batch_size 100 and more (128, 256, 512, 1024, 2048). Usually it failed within the first evaluation (at different steps 1-4); just once it failed within the second evaluation.
With batch_size 80 or less, it worked OK (at least for 18 evaluations, when I stopped the experiment).
However, when I increased eval_steps to cover the whole dev set, it failed anyway (at [289/1051]), even with a tiny batch_size.

I guess this is the relevant portion of the long stacktrace:

  File "/home/popel/venv/tf1.3gpu/lib/python3.4/site-packages/tensor2tensor/utils/input_fn_builder.py", line 200, in cond_on_index
    return fn(cur_idx)
  File "/home/popel/venv/tf1.3gpu/lib/python3.4/site-packages/tensor2tensor/utils/model_builder.py", line 242, in nth_model
    features, skip=(skipping_is_on and skip_this_one))
  File "/home/popel/venv/tf1.3gpu/lib/python3.4/site-packages/tensor2tensor/utils/t2t_model.py", line 430, in model_fn
    sharded_features = self._shard_features(features)
  File "/home/popel/venv/tf1.3gpu/lib/python3.4/site-packages/tensor2tensor/utils/t2t_model.py", line 411, in _shard_features
    0))
  File "/home/popel/venv/tf1.3gpu/lib/python3.4/site-packages/tensorflow/python/ops/array_ops.py", line 1232, in split
    split_dim=axis, num_split=num_or_size_splits, value=value, name=name)

@martinpopel
Contributor

martinpopel commented Sep 9, 2017

Update: I checked 1.2.2 and the problem is still there.

The problem is here, in _shard_features (see the stack trace above): the dimension-0 size of the feature value v is not divisible by self._num_datashards.

Is there a way in TF to split a tensor while possibly leaving the last returned piece with fewer items if the size is not divisible?
Or using something like:

v_size = v.shape.as_list()[0]
if v_size % self._num_datashards == 0:
  list_of_v = tf.split(v, self._num_datashards, 0)
else:
  shard_size = (v_size // self._num_datashards) + 1
  list_of_v = tf.split(v, [shard_size] * (self._num_datashards - 1) + [-1])
sharded_features[k] = self._data_parallelism(tf.identity, list_of_v)

(but v_size is not known at graph creation time, so this does not work).
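
A sketch of one possible direction (illustrative only, not a tested T2T patch): tf.split also accepts a 1-D tensor of shard sizes, so the sizes can be computed from the dynamic shape instead of Python ints. Note that it still produces empty shards whenever the batch is smaller than the number of shards, which seems to be what triggers the "Reshape cannot infer" error reported below.

import tensorflow as tf

def split_uneven(v, num_datashards):
    # Shard sizes are computed with tensor ops, so the batch size does not
    # need to be known at graph-construction time.
    v_size = tf.shape(v)[0]                    # dynamic size of dimension 0
    base = v_size // num_datashards            # minimum items per shard
    rem = v_size % num_datashards              # first `rem` shards get one extra item
    sizes = base + tf.cast(tf.range(num_datashards) < rem, tf.int32)
    return tf.split(v, sizes, axis=0, num=num_datashards)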

@ehasler

ehasler commented Sep 11, 2017

Until the problem is fixed, is it possible to switch off multi-GPU decoding for evaluation without switching it off for training?

@martinpopel
Contributor

martinpopel commented Sep 11, 2017

@ehasler: With --local_eval_frequency=0 you disable the internal evaluation completely. You can run the evaluation separately (manually, from the stored checkpoints, in another job), but I don't know how to get this into Tensorboard.

@ryonakamura

I modified the above code so that it executes, and the evaluation ran partway through.

v_size = tf.shape(v)[0]
condition = tf.constant((v_size % self._num_datashards) == 0)
list_of_v = tf.cond(condition,
        lambda: tf.split(v, self._num_datashards, 0),
        lambda: tf.split(v, [(v_size // self._num_datashards) + 1] \
                * (self._num_datashards - 1) + [-1]))
sharded_features[k] = self._data_parallelism(tf.identity, list_of_v)

But I encountered another error.

INFO:tensorflow:Starting evaluation at 2017-09-13-01:18:11
2017-09-13 10:18:12.222136: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1080, pci bus id: 0000:01:00.0)
2017-09-13 10:18:12.222159: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:1) -> (device: 1, name: GeForce GTX 1080, pci bus id: 0000:02:00.0)
INFO:tensorflow:Restoring parameters from /home/ryo/t2t_train/translate_ende_wmt32k/transformer-transformer_base/model.ckpt-1105
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
2017-09-13 10:18:13.764736: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero

......

InvalidArgumentError (see above for traceback): Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
         [[Node: body/model/parallel_1/body/encoder/layer_0/self_attention/multihead_attention/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/gpu:1"](body/model/parallel_1/body/encoder/layer_0/self_attention/multihead_attention/split, body/model/parallel_1/body/encoder/layer_0/self_attention/multihead_attention/concat)]]
         [[Node: mean_3/broadcast_weights/assert_broadcastable/AssertGuard/switch_f/_2118 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_7461_mean_3/broadcast_weights/assert_broadcastable/AssertGuard/switch_f", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

@martinpopel
Contributor

TL;DR: I still don't know how to work around this problem (except for disabling the internal evaluation completely).

@ryonakamura Yes, I did something similar and also got the "Reshape cannot infer" error, and I was not able to work around it. Apparently, the problem is that one of the list_of_v shards is zero-sized. I've also tried to use just one datashard (list_of_v = [v]) for the evaluation as @ehasler suggested, but self._data_parallelism is built with n fixed to the number of GPUs and there are assertions (and even if those are disabled, we would also need to put the logits into one datashard).

@martinpopel
Contributor

martinpopel commented Sep 23, 2017

Update: I checked 1.2.3 and 1.2.4 and the problem is still there.

@lukaszkaiser, @rsepassi: Can you please comment on this issue? Can you train on multiple GPUs with evaluation? Is a fix planned (e.g. waiting for TF 1.4, as in other issues)? Or should we try to fix it ourselves, and if so, are any of the above-mentioned workarounds promising?

@rsepassi
Contributor

rsepassi commented Oct 7, 2017

Until we address the underlying issue of getting eval to work on multiple gpus, the workaround would be to have your training job separate from your eval job. For training, set --schedule=train and for the eval job (with 1 GPU), set --schedule=continuous_eval. The eval job will look for new checkpoints from the training job in --output_dir and run eval on the single gpu.

@mehmedes

Training a new model with T2T 1.2.5 and TF 1.4 allows me to evaluate using multiple gpus.

@vince62s
Contributor

@mehmedes are you sure?
I just started a training run with TF 1.4.0-rc1 and T2T 1.2.6; it still throws an error with 2 worker GPUs when it comes to evaluation.

@martinpopel
Contributor

I confirm the error (Number of ways to split...) is still there even with TF 1.4.0rc1 and T2T 1.2.6.
In some setups (just two GPUs, just ten eval steps, high local_eval_frequency) it is less likely, but it may still occur after several hours of training.
When I increase the amount of test data with --eval_steps=60, the error occurs usually within the first evaluation.

Thanks for the workaround with two jobs (one with --schedule=train and one with --schedule=continuous_eval). It has the advantage that the continuous_eval job may run on a different machine (and thus not slowing down the training).
According to the documentation, train_and_evaluate takes double resources relative to train (and continuous_train_and_eval) because it needs extra memory for loading the checkpoint to be evaluated (while keeping in memory the current model being trained). So this could be another advantage of this workaround and of not using train_and_evaluate. However, in practice I was not able to use e.g. a bigger batch size with --schedule=train (the out-of-memory error starts appearing with the same batch size).

@mehmedes

mehmedes commented Nov 1, 2017

@vince62s: I was testing on a setup with two 1080s, where it works. I just tried on four 1080 Tis and evaluation still fails.

@rsepassi rsepassi added the bug label Nov 14, 2017
@liesun1994

@mehmedes I failed too with the same configuration: four GPUs failed and two GPUs worked. Really puzzling.

@lukaszkaiser
Contributor

Please keep reporting your experiences here; it might help us figure this out too. And thanks for checking and pushing on it!

@martinpopel
Contributor

OK, I'll keep reporting. The problem is still here with tensorflow-gpu 1.4.0 and tensor2tensor 1.2.9.
I encounter the problem even with two GPUs after setting --eval_steps=68 (it actually always crashes at the 21st evaluation step during the very first evaluation; this is on translate_ende_wmt32k with batch_size=1500).
With the default --eval_steps=10 it does not crash (during the first few evaluations at least; I have not tried more).

@liesun1994

@lukaszkaiser I tried it again: when using 2 GPUs and transformer_base_single_gpu, it works well, but it fails when using transformer_base. My configuration is T2T 1.3.2 and TF 1.4.0.

@liesun1994

@martinpopel have you solved the multi-device problem recently? I used 4 GPUs and transformer_base_single_gpu, and it fails again with the same error. Really puzzling.

@liesun1994

@rsepassi Good idea, but one GPU may be wasted on evaluation. Anyway, that is a proper way to solve it.

@martinpopel
Contributor

My solution is #436 (which was merged, but then the t2t-bleu script was deleted from the git master). I also don't want to waste one GPU for evaluation (I want to run multiple training experiments in parallel, so that would mean one extra GPU per experiment), so:

  • I wrote another script which translates the test set (using t2t-decoder) with all checkpoints in a given directory. All translations are done on CPU and each takes about 75 minutes (3000 sentences with 32 CPU cores), but we have a CPU cluster, so all the translations are submitted there (via qsub) and run in parallel (if there are free machines).
  • I adapted t2t-bleu so it doesn't try to translate anything itself; it just waits until the translation files are created in a given directory. It can therefore also run on CPU, doing only the BLEU computation, which is very fast (a rough sketch of this waiting loop is below).

If you are interested I may provide my scripts (after cleaning them a bit).
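
A rough sketch of what the waiting-and-scoring loop can look like (an illustration, not the actual script; it assumes sacrebleu is available for the BLEU computation and that each checkpoint's translation is written to a separate *.txt file):

import glob
import os
import time

import sacrebleu  # assumption: any BLEU implementation would do here

def watch_and_score(translations_dir, ref_path, poll_secs=60):
    # Score each translation file exactly once, as soon as it appears.
    with open(ref_path) as f:
        refs = f.read().splitlines()
    scored = set()
    while True:
        for hyp_path in sorted(glob.glob(os.path.join(translations_dir, "*.txt"))):
            if hyp_path in scored:
                continue
            with open(hyp_path) as f:
                hyps = f.read().splitlines()
            bleu = sacrebleu.corpus_bleu(hyps, [refs]).score
            print("%s\tBLEU=%.2f" % (os.path.basename(hyp_path), bleu))
            scored.add(hyp_path)
        time.sleep(poll_secs)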

@liesun1994

@martinpopel Wow, that would be nice! If the script is easy to use, please send it to me. Many thanks.

@liesun1994

@martinpopel We don't have a CPU cluster environment...

@rsepassi
Contributor

Ok, so I think the only difference between training and evaluation when it comes to the input pipeline is whether long sequences are skipped. Try setting --hparams=eval_drop_long_sequences=True. Does that fix the issue?

If so, then I believe what's happening is that the batch size for those long sequences is not divisible by the number of GPUs.

Not entirely sure what the fix would be yet, but I think we could check to see if the batch is small and pad it out if it is.
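
A minimal sketch of that padding idea (an illustration only, not the fix that eventually landed): pad dimension 0 of each feature with zeros so its size becomes divisible by the number of data shards.

import tensorflow as tf

def pad_batch_to_shards(feature, num_datashards):
    # Number of rows missing to reach the next multiple of num_datashards
    # (0 if the batch already divides evenly).
    batch = tf.shape(feature)[0]
    missing = (num_datashards - batch % num_datashards) % num_datashards
    # Zero-pad dimension 0 only; the feature's rank is assumed to be known
    # statically. Zero-padded rows may distort metrics for features that
    # are not padding-aware.
    paddings = [[0, missing]] + [[0, 0]] * (feature.shape.ndims - 1)
    return tf.pad(feature, paddings)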

@vince62s
Contributor

I am not following; why don't you just eval on a single GPU then?
Generally speaking, I think Martin's idea (put the training on standby, do the full evaluation with actual BLEU on some test sets, on a single GPU if needed, and then return to multi-GPU training) would be a better solution.

@martinpopel
Contributor

@rsepassi: Even after adding --hparams=eval_drop_long_sequences=True the problem is still there. This is with version 1.2.9 and --eval_steps high enough to cover the whole dev set.

@mehmedes

Just tried pull request #484 (v1.4.1); multi-GPU eval works perfectly with the preset --eval_steps=1000.
Thank you!

@rsepassi
Contributor

Great. Yes, this is fixed in 1.4.1, which will be merged and pushed to PyPI shortly.

@martinpopel
Contributor

I tried 1.4.1, but it crashed with:

INFO:tensorflow:global_step/sec: 9.23761
INFO:tensorflow:loss = 5.81556, step = 1901 (10.825 sec)
INFO:tensorflow:Reading data files from data/translate_encs_wmt32k-dev*
WARNING:tensorflow:Padding the batch to ensure that remainder eval batches have a batch size divisible by the number of data shards. This may lead to incorrect metrics for non-zero-padded features, e.g. images. Use a single datashard (i.e. 1 GPU) in that case.
Traceback (most recent call last):
  File "/home/popel/work/tensor2tensor/new/tensor2tensor/bin/t2t-trainer", line 190, in <module>
    tf.app.run()
  ...
  File "/home/popel/work/tensor2tensor/new/tensor2tensor/data_generators/problem.py", line 588, in _pad_batch
    return pad_batch(features, config.data_parallelism.n)
  File "/home/popel/work/tensor2tensor/new/tensor2tensor/data_generators/problem.py", line 950, in pad_batch
    feature = features.items()[0][1]
TypeError: 'dict_items' object does not support indexing
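
For reference, this is a Python 3 issue: dict.items() returns a view, which cannot be indexed. A sketch of a version that works on both Python 2 and 3 (not necessarily the actual patch):

# Python 3: dict views don't support indexing, so features.items()[0][1] fails.
# Grab an arbitrary feature value in a 2/3-compatible way instead:
feature = next(iter(features.values()))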

@martinpopel martinpopel mentioned this issue Dec 22, 2017
@rsepassi
Contributor

Thanks! Will fix.

@martinpopel
Contributor

I confirm it is fixed (I tested with 2 GPUs so far). Thanks.

@ndvbd
Contributor

ndvbd commented Feb 2, 2018

Are you sure this is fixed? So to use 8 GPUs instead of 1, we only need to add --worker_gpu=8 and it should work out of the box? Because for me the evaluation step failed...

I get:

INFO:tensorflow:Evaluation [6/10000]
2018-02-02 21:49:02.451743: W tensorflow/core/common_runtime/bfc_allocator.cc:273] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.35GiB.  Current allocation summary follows.

@martinpopel
Contributor

For me the internal evaluation works out of the box on 8 GPUs in T2T v1.4.2. There are 8 times more warnings "W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence", but these are just warnings (and OK in this case), not errors. I can see approx_bleu, rouge etc. in Tensorboard.
So technically the internal evaluation works, although there are some related issues:

@ndvbd
Contributor

ndvbd commented Feb 5, 2018

@martinpopel I solved it by upgrading to TF 1.5.0 (cuda 9.0, cuDNN 7.0). Before that I had TF 1.4.1. Now the regular continuous_train_and_eval process with --worker_gpu=8 works out of the box.

@jazzminewang

Seeing this again with TF-gpu 1.12: training is fine, but when I want to evaluate, I see:

InvalidArgumentError (see above for traceback): Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 1) and num_split 2
[[node train/split (defined at /home/ml/jwang301/Development/RUBER/unreferenced_metric.py:146) = Split[T=DT_FLOAT, num_split=2, _device="/job:localhost/replica:0/task:0/device:CPU:0"](train/split/split_dim, multi_layer_perceptron/Reshape_1)]]

I'm not using T2T but this was the closest error thread I could find.
