
map reduce using jax.lax.scan? #17537

Sep 11, 2023 · 1 comment · 3 replies

If you want to find the shape of a function output without actually evaluating the function, you can do so with jax.eval_shape:

result_shape = jax.eval_shape(map_fn, initial_data, rng)
init = jax.tree_util.tree_map(jnp.zeros_like, result_shape)
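As a rough sketch of how that zero-initialized accumulator might feed a `jax.lax.scan` map-reduce that also threads a PRNG key: the `map_fn` body, the `data` array, and the `step` helper below are hypothetical stand-ins, not code from the original question, and the reduction is assumed to be a plain sum.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-item function: maps one data row plus a PRNG key to a pytree.
def map_fn(x, rng):
    noise = jax.random.normal(rng, x.shape)
    return {"noisy": x + noise, "sq": jnp.sum(x ** 2)}

# Hypothetical inputs: 4 items, each of shape (3,).
data = jnp.arange(12.0).reshape(4, 3)
rng = jax.random.PRNGKey(0)

# Abstractly evaluate map_fn on one item: only shapes and dtypes are computed.
result_shape = jax.eval_shape(map_fn, data[0], rng)

# Build a zero accumulator with the same pytree structure as map_fn's output.
init = jax.tree_util.tree_map(
    lambda s: jnp.zeros(s.shape, s.dtype), result_shape
)

def step(carry, x):
    acc, key = carry
    # Split so that each item sees a fresh subkey.
    key, subkey = jax.random.split(key)
    out = map_fn(x, subkey)
    # Assumed reduction: elementwise sum over every leaf of the pytree.
    acc = jax.tree_util.tree_map(jnp.add, acc, out)
    return (acc, key), None

# Scan over the leading axis of data, carrying the accumulator and the key.
(total, _), _ = jax.lax.scan(step, (init, rng), data)
```

Here the accumulator is built from the shape and dtype attributes directly, which sidesteps any question of how `jnp.zeros_like` treats abstract values; either approach should give the same `init`.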

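Note also that if you simply want a reduction of an arbitrary binary function over an array of values, we recently added jnp.frompyfunc, which wraps the function as a ufunc-like object with a reduce method. The number of inputs and outputs must be given explicitly. For example:

result = jnp.frompyfunc(map_fn, nin=2, nout=1).reduce(data)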

I don't think this is directly applicable here, though, because of the additional rng information carried in the scan.
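For the simpler case without a carried key, a minimal sketch of the `jnp.frompyfunc` route could look like the following. The `combine` function and the `values` array are hypothetical, and an explicit `identity` is passed as a precaution so that the reduction has a well-defined starting value.

```python
import jax.numpy as jnp

# Hypothetical binary scalar function: takes two scalars, returns one.
def combine(a, b):
    return a + b

# Wrap it as a ufunc-like object; reduce needs nin=2 and nout=1.
combine_ufunc = jnp.frompyfunc(combine, nin=2, nout=1, identity=0)

# Hypothetical input values.
values = jnp.array([1.0, 2.0, 3.0, 4.0])

# Fold combine over axis 0 of values.
result = combine_ufunc.reduce(values)
```

With `combine` defined as addition, `result` should match `jnp.sum(values)`; swapping in another associative binary function (with a matching identity) changes the reduction accordingly.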

@hrbigelow
@jakevdp
@hrbigelow
Answer selected by hrbigelow