slow pmap allreduce #973

Open
christopherhesse opened this issue Jul 4, 2019 · 7 comments
Assignees
Labels
enhancement (New feature or request)
P2 (eventual) (This ought to be addressed, but has no schedule at the moment. Assignee optional)
performance (make things lean and fast)

Comments

@christopherhesse

Related to @joschu's question about direct access to device arrays, I was curious how fast a pmap allreduce would be as an alternative to trying to use nccl directly on GPU pointers.

This script (https://gist.github.com/christopherhesse/192d78f0f082d66dfb26cac112c5cf99) takes 10,000 ms per loop on 8 V100s, which is surprising to me because nccl-tests' all_reduce_perf takes about 5 ms to do what I think is the same operation. Is there an error in my script? I tried using .block_until_ready() instead of np.array() but that failed with an exception, so there's an additional copy to host memory, but even with that it seems like it should be faster.
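Roughly, the pattern being timed is something like this (a simplified sketch; the gist has the exact script):

import time
import numpy as np
from jax import pmap, lax

NUM_GPU = 8
NUM_PARAMS = 100_000_000

# all-reduce across the mapped device axis 'i'
allreduce = pmap(lambda x: lax.psum(x, 'i'), axis_name='i')

def main():
    data = np.random.rand(NUM_GPU, NUM_PARAMS).astype(np.float32)
    while True:
        start = time.time()
        # np.array pulls the result back to the host, so the timing
        # includes a device-to-host copy of the full output
        np.array(allreduce(data))
        print(time.time() - start)

if __name__ == '__main__':
    main()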

@jekbradbury commented on a similar issue here: #606 (comment)

I'm using jaxlib 0.1.21 and (I think) jax 1508405.

mattjj added the performance (make things lean and fast) label on Jul 4, 2019
@mattjj (Collaborator) commented Jul 4, 2019

That code looks reasonable to me! Maybe something fishy is going on...

@mattjj (Collaborator) commented Jul 4, 2019

The np.array call is going to force a device-to-host transfer and include that in the timing, so we should find a way to get .block_until_ready() to work. What was the error you saw? (I'll try to get a repro going so I can find out for myself.)

EDIT: I forgot you already pointed this out :) but in any case that's the first thing to sort out.

@mattjj (Collaborator) commented Jul 4, 2019

My GCP instance (also with 8xV100) running the original script is showing these numbers:

5.574406623840332
3.9834365844726562
3.4369263648986816
3.4715540409088135

so slightly better but still weird.

The problem with block_until_ready is one of inheritance: we need ShardedDeviceArray to have its own version of that method (and _check_if_deleted, which is raising the error first) rather than inheriting one from DeviceArray. This works as an implementation (notice self.device_buffers rather than self.device_buffer):

  def block_until_ready(self):
    for buf in self.device_buffers:
      buf.block_host_until_ready()
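
With that method in place, the timing loop can wait for the device computation to finish without pulling the result back to the host. Roughly (a sketch, with allreduce_fn standing in for whatever the pmapped function is called):

start = time.time()
out = allreduce_fn(data)   # a ShardedDeviceArray
out.block_until_ready()    # waits on every per-device buffer; no device-to-host copy
print(time.time() - start)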

After adding that method to ShardedDeviceArray and adjusting the script to use that, I'm seeing these numbers:

1.9906294345855713
0.09681510925292969
0.0945134162902832

So it's much better, but still ~100ms (on my setup) instead of ~5ms (on your setup). I'll try to ping some actual experts to investigate more :)

@mattjj (Collaborator) commented Jul 4, 2019

It's possible that we have ~100ms of dispatch overheads. To amortize those away I changed the computation to do 10 all-reduces (under one pmap), then divided the printed timing numbers by 10. Here's what I got:

0.20596561431884766
0.00972421169281006
0.009992289543151855
0.009689640998840333

Here's the script I'm running (check for bugs!):

import time
import numpy as np
from jax import pmap, lax, partial

NUM_GPU = 8 
NUM_PARAMS = 100_000_000
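# i.e. ~400 MB of float32 per device, ~3.2 GB across the 8 GPUs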

@partial(pmap, axis_name='i')
def do_allreduce(x):
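    # chain 10 psums under a single pmap so the per-call dispatch overhead is amortized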
    for _ in range(10):
        x = lax.psum(x, 'i')
    return x

def main():
    while True:
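        # re-shard fresh data across the devices; the host-to-device copy happens here, outside the timed region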
        data = pmap(lambda x: x)(np.random.rand(NUM_GPU, NUM_PARAMS).astype(np.float32))
        start = time.time()
        do_allreduce(data).block_until_ready()
        print((time.time() - start) / 10) 

if __name__ == '__main__':
    main()

That seems like compelling evidence that this is an overheads issue. We should look into it, but if you agree with that reading, can you say a bit about how important these pmap dispatch overheads are to you (i.e. how representative this micro-benchmark is of a workload you care about)? That would help us prioritize digging into them. If your actual computation does more work under the pmap, so that the overheads get amortized away, then maybe they aren't very salient at the moment!

@christopherhesse (Author)

Thanks for investigating!

It's interesting that you ran the same script on a GCP instance with 8xV100s and got different numbers. I, too, am using a GCP instance with 8xV100s and the times I get are a little over double what you get:

12.342523574829102
9.377157926559448
9.431708574295044
9.199352502822876

Maybe we have different versions of some dependency.

The timings I get when using block_until_ready() are similar to yours:

2.8297135829925537
0.08717107772827148
0.08565568923950195
0.08546280860900879
0.08353567123413086

Does this mean that in my case it takes 9000 ms (and still 3400 ms in your case) just to copy the data to the host? That seems odd to me.

If each GPU copied to host memory at full PCIe bandwidth in parallel, I'd expect this to take a minimum of about 25 ms.

If a single core is doing all the copying manually, from each GPU in series, that should take a minimum of about 1600 ms assuming a transfer rate of 2 GB/s for that core.
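
For concreteness, here's the arithmetic behind those two estimates (a sketch assuming float32 parameters, roughly 16 GB/s of PCIe bandwidth per GPU for the parallel case, and 2 GB/s for the single-core case):

NUM_GPU = 8
NUM_PARAMS = 100_000_000
bytes_per_gpu = NUM_PARAMS * 4            # float32 -> 400 MB per device
total_bytes = NUM_GPU * bytes_per_gpu     # 3.2 GB in total

print(bytes_per_gpu / 16e9)   # parallel copies at ~16 GB/s per link: ~0.025 s (~25 ms)
print(total_bytes / 2e9)      # one core copying everything at 2 GB/s: 1.6 s (~1600 ms)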

As for the utility of looking more into this: what I want is to transfer arrays from GPU to GPU quickly, and all-reducing gradients is one such thing you'd want to do in parallel training. 100M parameters is not an unreasonable size here.

Even if I get pmap to work for this case, I would still need to send the tensors over the network, so it might be easiest to try NCCL on pointers to the arrays directly in GPU memory. Assuming I try that next, this issue would be pretty low priority to me.

@mattjj (Collaborator) commented Jul 5, 2019

Hrm, not sure about the GCP thing. I can dig more into the config I'm using if that would be helpful, but the basics are:

- n1-standard-64 (64 vCPUs, 240 GB memory) in us-west1-b
- 300 GB SSD persistent disk
- 8 x NVIDIA Tesla V100
- "Deep Learning VM" image (maybe it's called tf-1-13-cu100-20190524)
- Miniconda (Anaconda) Python 3.7
- jax from GitHub master, jaxlib 0.1.21 from PyPI

> Does this mean that in my case it takes 9000 ms (and still 3400 ms in your case) just to copy the data to the host? That seems odd to me.

cc @hawkinsp for someone who knows how computers are supposed to work. Any thoughts?

@christopherhesse (Author)

Thanks! That should be enough information to investigate the difference in speed if it ends up impacting my application's performance.

If we end up using NCCL directly, then I expect we won't have to copy much data to main memory, so this particular issue may not matter as much to me (especially if it's somehow pmap-specific).

sudhakarsingh27 added the NVIDIA GPU (Issues specific to NVIDIA GPUs) and P1 (soon) (Assignee is working on this now, among other tasks; assignee required) labels on Aug 10, 2022
hawkinsp added the XLA label on Aug 12, 2022
sudhakarsingh27 added the enhancement and P2 (eventual) labels and removed the NVIDIA GPU, XLA, and P1 (soon) labels on Sep 21, 2022
sudhakarsingh27 assigned hawkinsp and unassigned mattjj on Sep 21, 2022