map reduce using jax.lax.scan? #17537
-
Hi, suppose you have some function you would like to map over all slices of a tensor and then sum the results. Below is my attempt to encapsulate this pattern using `jax.lax.scan`. I'm interested to know whether it will be efficient or whether there is a better way! Please let me know your thoughts.

```python
import jax


def map_sum(map_fn, data, rng):
    """
    Maps map_fn across the items of data (split along the leading axis).

    Inputs:
        data: a pytree of array leaves; every leaf has the same size along axis 0
        rng: a PRNG key
        map_fn: (data_slice, rng) -> result,
            where data_slice is a pytree holding one slice along axis 0
            of each leaf of `data`
    Returns:
        (sum of the results returned by map_fn, new_rng)
    """
    # Handle the first slice outside the scan so its result seeds the accumulator.
    initial_data = jax.tree_util.tree_map(lambda x: x[0], data)
    rest_of_data = jax.tree_util.tree_map(lambda x: x[1:], data)
    rng, subkey = jax.random.split(rng)  # never reuse a key after consuming it
    result = map_fn(initial_data, subkey)
    carry = result, rng

    def scan_fn(carry, item):
        accu, rng = carry
        rng, subkey = jax.random.split(rng)  # fresh key for every slice
        result = map_fn(item, subkey)
        accu = jax.tree_util.tree_map(lambda x, y: x + y, accu, result)
        return (accu, rng), None  # nothing is stacked per step

    # Scan over the remaining slices only; scanning over `data` would add slice 0 twice.
    carry, _ = jax.lax.scan(scan_fn, carry, rest_of_data)
    return carry
```
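For comparison, the most direct way to write "map over the leading axis, then sum" in JAX is `jax.vmap` followed by a sum over axis 0. The sketch below assumes `map_fn` has the same `(data_slice, rng) -> result` signature as above; `vmap_sum` and `noisy_row` are hypothetical names used only for illustration. Unlike the scan version, it materializes every per-slice result at once, so it trades memory for parallelism.

```python
import jax
import jax.numpy as jnp


def vmap_sum(map_fn, data, rng):
    """Hypothetical alternative to map_sum: vmap map_fn over axis 0, then sum.

    Returns (sum of map_fn over all slices, new_rng).
    """
    # Number of slices along the leading axis (all leaves share it).
    n = jax.tree_util.tree_leaves(data)[0].shape[0]
    rng, subkey = jax.random.split(rng)
    keys = jax.random.split(subkey, n)  # one independent key per slice
    # Evaluate every slice at once; each result leaf gains a leading axis of size n.
    results = jax.vmap(map_fn)(data, keys)
    # Reduce each result leaf over that leading axis.
    total = jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), results)
    return total, rng


def noisy_row(row, key):
    """Example map_fn: one row plus a little Gaussian noise."""
    return row + 0.01 * jax.random.normal(key, row.shape)


data = jnp.ones((4, 3))
total, new_rng = vmap_sum(noisy_row, data, jax.random.PRNGKey(0))
```

If the stacked results fit in memory this is usually the simplest option; the scan-based version keeps only one accumulator alive, which matters when the number of slices or the size of each result is large.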
-
If you want to find the shape of a function output without actually evaluating the function, you can do so with `jax.eval_shape`:

```python
result_shape = jax.eval_shape(map_fn, initial_data, rng)
init = jax.tree_util.tree_map(jnp.zeros_like, result_shape)
```

Note also that if you simply want a reduction of an arbitrary binary function over an array of values, we recently added `jnp.frompyfunc`, which allows this. For example:

```python
# Here map_fn must be a binary scalar function: (x, y) -> z
result = jnp.frompyfunc(map_fn, nin=2, nout=1).reduce(data)
```

I don't think this is directly applicable here, though, because of the additional `rng` information carried in the scan.
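Combining the `jax.eval_shape` suggestion with the question's scan gives a version that starts from a zero accumulator and scans over every slice, so the first slice no longer needs special handling. This is a sketch under the same assumptions as `map_sum` (`map_fn(data_slice, rng) -> result`, results combined with `+`); `map_sum_zeros` and `weighted_row` are hypothetical names. It uses `jax.ShapeDtypeStruct` to describe one slice so that no slicing is actually computed before the scan.

```python
import jax
import jax.numpy as jnp


def map_sum_zeros(map_fn, data, rng):
    """Hypothetical variant of map_sum: zero-initialized accumulator, scan over all slices.

    Returns (sum of map_fn over the leading axis of `data`, new_rng).
    """
    # Abstract description of one slice: shape without axis 0, same dtype.
    slice_spec = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape[1:], x.dtype), data)
    # Trace map_fn abstractly to learn the pytree structure, shapes and dtypes of one result.
    result_shape = jax.eval_shape(map_fn, slice_spec, rng)
    init = jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), result_shape)

    def scan_fn(carry, item):
        accu, rng = carry
        rng, subkey = jax.random.split(rng)  # fresh key for every slice
        result = map_fn(item, subkey)
        accu = jax.tree_util.tree_map(lambda x, y: x + y, accu, result)
        return (accu, rng), None  # nothing is stacked per step

    (total, rng), _ = jax.lax.scan(scan_fn, (init, rng), data)
    return total, rng


# Usage: a pytree whose leaves share a leading axis of size 4.
data = {"x": jnp.arange(12.0).reshape(4, 3), "w": jnp.array([1.0, 2.0, 3.0, 4.0])}


def weighted_row(item, key):
    """Example map_fn: a weighted row plus a little uniform noise."""
    return item["x"] * item["w"] + 0.01 * jax.random.uniform(key, (3,))


total, new_rng = map_sum_zeros(weighted_row, data, jax.random.PRNGKey(0))
```

For a deterministic `map_fn` this gives the same total as seeding the accumulator with the first result. Because `map_fn` is a plain Python function, one way to compile the whole reduction is `jax.jit(map_sum_zeros, static_argnums=0)`.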