
How to make jax.numpy.linspace to produce sharded result #17447

Answered by yashk2810
jczaja asked this question in General

One way to do this is to build the array inside a jax.jit-compiled function and apply jax.lax.with_sharding_constraint to the result. The example uses arange, but you can adapt it to linspace:

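import jax
import jax.numpy as jnp

@jax.jit
def f():
  x = jnp.arange(16).reshape(8, 2)
  # `sharding` must already be defined, e.g. as a jax.sharding.NamedSharding
  return jax.lax.with_sharding_constraint(x, sharding)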

This assumes you already know which sharding you want to apply.
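For completeness, here is a minimal sketch of the same idea adapted to linspace, with the sharding defined explicitly. It assumes a 1-D device mesh over all available devices with a hypothetical axis name "x", sharding along the first array axis, and a hypothetical row count n chosen to divide evenly across devices; make_linspace is also an illustrative name, not part of JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over every available device, with axis name "x".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Shard rows across the "x" mesh axis; the second array axis is replicated.
sharding = NamedSharding(mesh, P("x", None))

# Hypothetical size: a multiple of the device count so rows split evenly.
n = 8 * len(devices)

@jax.jit
def make_linspace():
    # num (n * 2) is a static Python int, as linspace requires under jit.
    x = jnp.linspace(0.0, 1.0, n * 2).reshape(n, 2)
    # Ask the compiler to lay out the result according to `sharding`.
    return jax.lax.with_sharding_constraint(x, sharding)

y = make_linspace()
print(y.sharding)
```

On a single-device machine this still runs; the mesh simply has one device. If you prefer not to touch the function body, passing out_shardings=sharding to jax.jit is another way to request the output layout.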

Answer selected by jczaja