-
I have a 2D array for which I need to calculate the sum of a rolling cumulative product. I tested it on this array:
Because of the loop unrolling in func2, compilation time is very long (a few minutes), but XLA is better able to optimize the code: func2 is around 10-50x faster, depending on the array shape. However, if I increase the array shape to (100000, 400), func2 won't compile due to OOM. Is there another approach that is faster than func1?
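The definitions of `func1` and `func2` are not reproduced above; based on the accepted answer below, the per-row computation appears to be: for each start index `r`, accumulate the running product of `x[k] * y[r]` for `k = r..n-1` and sum the partial products. A minimal NumPy sketch of that interpretation (the name `rolling_cumprod_sum` is my own, and this is an inference, not the poster's actual code):

```python
import numpy as np

def rolling_cumprod_sum(x, y):
    """Naive O(n^2) reference: for each start index r, sum the
    running products prod_{k=r..c} (x[k] * y[r]) over c = r..n-1."""
    n = x.shape[0]
    out = np.zeros(n)
    for r in range(n):
        p = 1.0
        for c in range(r, n):
            p *= x[c] * y[r]
            out[r] += p
    return out

# Tiny hand-checkable case: x = [2, 3], y = [5, 7]
# row 0: 2*5 + (2*5)*(3*5) = 10 + 150 = 160
# row 1: 3*7 = 21
print(rolling_cumprod_sum(np.array([2.0, 3.0]), np.array([5.0, 7.0])))
```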
-
I would probably rewrite this in terms of vectorized operations. This computes the same results as your two functions, without any explicit iteration:

```python
import jax.numpy as jnp
from jax import jit, vmap

def func3(x, y):
    assert x.ndim == y.ndim == 1
    assert x.shape == y.shape
    i = jnp.arange(x.shape[0])
    mask = i[None, :] < i[:, None]
    cumprod = jnp.where(mask, 1, x[None, :] * y[:, None]).cumprod(1)
    return jnp.where(mask, 0, cumprod).sum(1)

res3 = jit(vmap(func3))(x, y)
```
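To see why the masking works: replacing the entries below each row's start index with 1 lets the row-wise `cumprod` effectively begin fresh at that index, and the second `where` zeroes out those same positions so they don't contribute to the sum. Here is a NumPy translation of the same trick, checked against a direct double loop (the loop is my own reading of the intended semantics, since `func1` isn't shown in the thread):

```python
import numpy as np

def func3_np(x, y):
    """NumPy version of the masked-cumprod trick for one row pair."""
    i = np.arange(x.shape[0])
    mask = i[None, :] < i[:, None]  # True strictly below the diagonal
    # Masked entries become 1 so they don't disturb the running product.
    cumprod = np.where(mask, 1.0, x[None, :] * y[:, None]).cumprod(axis=1)
    # Zero the same entries so they don't contribute to the sum.
    return np.where(mask, 0.0, cumprod).sum(axis=1)

def reference(x, y):
    """Direct O(n^2) double loop with the same semantics."""
    n = x.shape[0]
    out = np.zeros(n)
    for r in range(n):
        p = 1.0
        for c in range(r, n):
            p *= x[c] * y[r]
            out[r] += p
    return out

rng = np.random.default_rng(0)
x, y = rng.uniform(0.5, 1.5, size=(2, 50))
assert np.allclose(func3_np(x, y), reference(x, y))
```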
-
Thank you, @jakevdp. This is a very clever solution! For larger arrays, though, it tends to be slower than func1. I think the reason is the numerous multiplications by 1 that don't contribute to the final result but are still computationally expensive. Would it be possible to implement such a function as a custom operation, as described here:
It may be possible to do this more efficiently with a custom kernel, but keep in mind that the kinds of operations that are efficient on GPUs are exactly the kinds used in my solution (full-axis reductions over statically-sized arrays). You'd probably have to play some of the same tricks in your custom kernel that I did in `func3`, but I wouldn't be surprised if you could make it faster with some thought.