Hi all, I've been working with some custom code that represents a dense centered/standardized matrix as a sparse linear operation, and have been running into some trouble trying to compute a memory-efficient sum of squares. Namely, what I'd like to compute (using an example with dense operations) would be […]. If I represent […]:

```python
term1 = jnp.einsum("np,np,p,p,p->", G, G, S, S, W)
term2 = jnp.einsum("nk,kp,p,nk,kp->", C, B, W, C, B)
term3 = -2 * jnp.einsum("np,p,p,nk,kp->", G, S, W, C, B)
wgt_ss = term1 + term2 + term3
```

Computing […], I *could* replace this with […]. So I have two questions:
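As an aside, here's a small self-contained check of the identity these terms appear to implement, assuming (since the dense target expression was lost above) that the goal is the weighted sum of squares of `X = G * S - C @ B`. Note that under that reading, `term2` needs distinct contraction labels for the two `C`/`B` factors; with the repeated labels as written above, only the `k == l` diagonal of the square is summed:

```python
import jax
import jax.numpy as jnp

# Hypothetical small shapes, just to check the algebra.
key = jax.random.PRNGKey(0)
kG, kC, kB, kS, kW = jax.random.split(key, 5)
n, p, k = 6, 5, 3
G = jax.random.normal(kG, (n, p))
C = jax.random.normal(kC, (n, k))
B = jax.random.normal(kB, (k, p))
S = jax.random.normal(kS, (p,))
W = jax.random.normal(kW, (p,)) ** 2  # nonnegative per-column weights

term1 = jnp.einsum("np,np,p,p,p->", G, G, S, S, W)
# "nl,lp" rather than a second "nk,kp", so the square is expanded in full.
term2 = jnp.einsum("nk,kp,p,nl,lp->", C, B, W, C, B)
term3 = -2 * jnp.einsum("np,p,p,nk,kp->", G, S, W, C, B)
wgt_ss = term1 + term2 + term3

# Dense reference: sum_{n,p} W_p * X_np**2 with X = G * S - C @ B.
X = G * S - C @ B
assert jnp.allclose(wgt_ss, jnp.einsum("np,p,np->", X, W, X), rtol=1e-4)
```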
Regarding performance: to be quite honest, the JAX sparse implementations (and in particular sparse-sparse matmul) are quite slow, so I suspect writing […]. We are working on ideas about how to make this all better, but nothing is ready for production yet.
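For concreteness, the kind of sparse-sparse product being referred to looks like this with `jax.experimental.sparse.BCOO` (a minimal sketch, not a recommendation):

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Two small BCOO matrices; the @ below lowers to a sparse-sparse matmul,
# which involves relatively expensive index manipulation.
A = sparse.BCOO.fromdense(jnp.array([[1., 0., 2.], [0., 3., 0.]]))
B = sparse.BCOO.fromdense(jnp.array([[0., 4.], [5., 0.], [0., 6.]]))
C = A @ B  # the result is itself a BCOO matrix
print(C.todense())
```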
Actually, I suspect that `G ** 2` will be much better performance-wise than `G, G`: the reason is that in your `einsum` statement, `G, G` essentially ends up computing `G * G`, and element-wise multiplication between two sparse matrices requires a set intersection operation between the specified indices of the matrices. By contrast, `G ** 2` is basically a single vectorized operation over the defined data, which should be much faster.
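To make the contrast concrete, here is a minimal sketch (assuming a BCOO representation) of what "a single vectorized operation over the defined data" looks like: the elementwise square keeps the existing sparsity pattern, so it can reuse the indices unchanged and only square the stored values:

```python
import jax.numpy as jnp
from jax.experimental import sparse

G = sparse.BCOO.fromdense(jnp.array([[0., 2., 0.], [3., 0., 4.]]))

# Same sparsity pattern as G, so no index-set intersection is needed;
# only the stored data buffer is squared.
G_sq = sparse.BCOO((G.data ** 2, G.indices), shape=G.shape)
print(G_sq.todense())
```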