Hi, how can I make jax.numpy.linspace produce a result that is already sharded across my two JAX devices? Context: I have a host with two JAX devices. The problem is that each of these devices is limited in memory. So I can call jax.numpy.linspace(start, end, num=50000000), but bigger values of num cause an out-of-memory error. So if it is possible
Replies: 2 comments 1 reply
One way of doing this is like this (the example uses arange, but you can adapt it for linspace):
I assume you know what sharding is?
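The reply's original example is not preserved in this thread, so here is a minimal sketch of one way to do it for linspace, not necessarily what the reply showed. It assumes a 1-D device mesh over all available devices and a value of `num` that is divisible by the device count. The names `sharded_linspace`, `mesh`, and the axis name `"x"` are illustrative. Passing `out_shardings` to `jax.jit` asks the compiler to lay out the result across the devices, so each device should only need to hold its own slice, although how intermediate values are partitioned is up to the compiler.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over every available device (two in the question's setup).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Split the single array axis across the mesh axis "x".
sharding = NamedSharding(mesh, P("x"))

# `num` determines the output shape, so it has to be a static argument.
@functools.partial(jax.jit, static_argnames="num", out_shardings=sharding)
def sharded_linspace(start, stop, num):
    return jnp.linspace(start, stop, num)

# Small size for illustration; assumed to be divisible by the device count.
num = 8 * len(devices)
x = sharded_linspace(0.0, 1.0, num=num)

# Each shard lives on one device and holds num / len(devices) elements.
for shard in x.addressable_shards:
    print(shard.device, shard.data.shape)
```

If `num` is not divisible by the number of devices, the requested sharding may be rejected, so padding `num` up to a multiple of the device count is a reasonable workaround.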
Thanks very much for the answers.