slow pmap allreduce #973
Comments
That code looks reasonable to me! Maybe something fishy is going on...
EDIT: I forgot you already pointed this out :) but in any case that's the first thing to sort out.
My GCP instance (also with 8xV100) running the original script is showing these numbers:
so slightly better but still weird.

The problem with `block_until_ready` is that the object pmap returns doesn't define it yet, so I added this method:

```python
def block_until_ready(self):
    # Wait on every per-device buffer, without copying anything back to the host.
    for buf in self.device_buffers:
        buf.block_host_until_ready()
```

After adding that method to `ShardedDeviceArray`, I see:
So it's much better, but still ~100ms (on my setup) instead of ~5ms (on your setup). I'll try to ping some actual experts to investigate more :)
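For reference, one way to wire that method in is a simple monkey-patch. This is only a sketch; the `jax.interpreters.pxla` location of `ShardedDeviceArray` is an assumption about the JAX version in use here:

```python
# Sketch: attach a block_until_ready method to the class that pmap returns, so we
# can wait for the computation to finish without np.array()'s device-to-host copy.
# The module path below is an assumption about this era of JAX.
from jax.interpreters import pxla

def block_until_ready(self):
    for buf in self.device_buffers:
        buf.block_host_until_ready()

pxla.ShardedDeviceArray.block_until_ready = block_until_ready
```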
It's possible that we have ~100ms of dispatch overheads. To amortize those away I changed the computation to do 10 all-reduces (under one pmap), then divided the printed timing numbers by 10. Here's what I got:
Here's the script I'm running (check for bugs!):

```python
import time

import numpy as np
from jax import pmap, lax, partial

NUM_GPU = 8
NUM_PARAMS = 100_000_000

@partial(pmap, axis_name='i')
def do_allreduce(x):
    # Ten all-reduces under a single pmap call, to amortize dispatch overheads.
    for _ in range(10):
        x = lax.psum(x, 'i')
    return x

def main():
    while True:
        # Shard the data onto the devices first so the host-to-device transfer isn't timed.
        data = pmap(lambda x: x)(np.random.rand(NUM_GPU, NUM_PARAMS).astype(np.float32))
        start = time.time()
        do_allreduce(data).block_until_ready()
        print((time.time() - start) / 10)

if __name__ == '__main__':
    main()
```

That seems like compelling evidence that this is an overheads issue. We should look into that, but if you agree, can you say a bit about how important these pmap dispatch overheads are to you (i.e. how representative this micro-benchmark is of a workload you care about)? That can help us prioritize whether to dig into them. If your actual computation will do more work under the pmap, so that overheads get amortized away, then maybe they aren't very salient at the moment!
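To illustrate that last point, here is a rough sketch (the function and variable names are made up for illustration) of a workload where the all-reduce is just one op inside a larger pmapped step, so the per-call dispatch overhead is paid once per training step rather than once per bare all-reduce:

```python
# Illustrative only: a data-parallel training step in which the gradient psum is
# fused into a larger pmapped computation. params and batch are assumed to carry
# a leading device axis when passed in, as usual for pmap.
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax, pmap

def loss(params, batch):
    x, y = batch
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

@partial(pmap, axis_name='i')
def train_step(params, batch):
    grads = jax.grad(loss)(params, batch)
    grads = lax.psum(grads, 'i')   # the gradient all-reduce, fused into the step
    return params - 1e-3 * grads
```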
Thanks for investigating! It's interesting that you ran the same script on a GCP instance with 8xV100s and got different numbers. I, too, am using a GCP instance with 8xV100s and the times I get are a little over double what you get:
Maybe we have different versions of some dependency. The timings I get when using `block_until_ready` are:
Does this mean that in my case it takes 9000 ms (and still 3400 ms in your case) just to copy the data to the host? That seems odd to me. If each GPU copied to host memory at the full PCIe bandwidth in parallel, I'd expect this to take a minimum of 25 ms. If a single core is doing all the copying manually, from each GPU in series, that seems like it should take a minimum of 1600 ms, assuming a transfer rate of 2 GB/s for that core.

As for the utility of looking more into this: the thing I want to do is transfer arrays from GPU to GPU in a fast way, and allreducing gradients is one such thing you would want to do in parallel training; 100M parameters is not an unreasonable number here. Even if I get pmap to work for this case, I would still need to send the tensors over the network, so it might be easiest to try NCCL on pointers to the arrays directly in GPU memory. Assuming I try that next, this issue would be pretty low priority to me.
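For concreteness, the arithmetic behind those two estimates (the ~16 GB/s per-GPU PCIe figure is implied by the 25 ms estimate; the 2 GB/s single-core figure is stated above):

```python
# Back-of-the-envelope check of the transfer-time estimates above.
NUM_GPU = 8
NUM_PARAMS = 100_000_000
bytes_per_gpu = NUM_PARAMS * 4                    # float32 -> ~400 MB per GPU

parallel_ms = bytes_per_gpu / 16e9 * 1e3          # all GPUs copy at once at ~16 GB/s: ~25 ms
serial_ms = NUM_GPU * bytes_per_gpu / 2e9 * 1e3   # one core copies each GPU in turn at ~2 GB/s: ~1600 ms
print(parallel_ms, serial_ms)                     # 25.0 1600.0
```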
Hrm not sure about the GCP thing. I can dig more into the config I'm using if that would be helpful, but the basics are: n1-standard-64 (64 vCPUs, 240 GB memory) in us-west-1b
cc @hawkinsp for someone who knows how computers are supposed to work. Any thoughts?
Thanks! That should be enough information to investigate the difference in speed if it ends up impacting my application's performance. If we end up using NCCL directly then I expect we will not have to copy much data to main memory so this particular issue may not matter as much to me (especially if it is somehow pmap specific). |
Related to @joschu's question about direct access to device arrays, I was curious how fast a pmap allreduce would be as an alternative to trying to use NCCL directly on GPU pointers.

This script (https://gist.github.com/christopherhesse/192d78f0f082d66dfb26cac112c5cf99) takes 10,000 ms per loop on 8 V100s, which is surprising to me because nccl-tests' `all_reduce_perf` takes about 5 ms to do what I think is the same operation. Is there an error in my script? I tried using `.block_until_ready()` instead of `np.array()`, but that failed with an exception, so there's an additional copy to host memory; even with that, it seems like it should be faster.

@jekbradbury commented on a similar issue here: #606 (comment)
I'm using jaxlib 0.1.21 and (I think) jax 1508405.
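The gist itself isn't reproduced in this thread; based on the description above, the benchmark is presumably of roughly this shape (a reconstruction inferred from the discussion, not the actual gist contents):

```python
# Rough reconstruction: an 8-way psum over 100M float32 parameters, timed around
# an np.array() call that also pulls the result back to the host.
import time

import numpy as np
from jax import pmap, lax

NUM_GPU = 8
NUM_PARAMS = 100_000_000

allreduce = pmap(lambda x: lax.psum(x, 'i'), axis_name='i')

data = np.random.rand(NUM_GPU, NUM_PARAMS).astype(np.float32)
while True:
    start = time.time()
    out = np.array(allreduce(data))   # np.array() forces a device-to-host copy
    print(time.time() - start)
```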