-
Could you speak more about what happens to this short-term solution in the long term? I.e., after we have sparse types, would the …
Have you considered using a custom_call instead? Some thoughts on …
-
Indeed, if the proposed operation is only available on NVIDIA hardware and nowhere else, custom_call seems like a better fit; it is normally used for backend-specific operations.
-
RFC: Add 2:4 structured sparsity support to XLA:GPU
This RFC proposes integrating support for NVIDIA's 2:4 structured sparsity into the XLA:GPU compiler. The objective is to speed up matrix multiplications in which one operand is, or can be, pruned into a format where every four consecutive elements contain exactly two nonzeros and two zeros, using the hardware acceleration that A100 and H100 GPUs provide for this format. The goal of this RFC is to make 2:4 structured sparsity available to model writers very rapidly by means of special operations that deal with the sparsity explicitly (in contrast with the longer-term implicit sparsity proposal that introduces sparse tensor types, described in the StableHLO sparsity RFC). Note that, although this RFC focuses on 2:4 structured sparsity, with minor changes it can also be generalized to block sparsity.
2:4 Structured Sparsity Storage Format
The NVIDIA 2:4 structured sparsity format is described in the data type section of the cuSPARSELt library documentation. It consists of two buffers, i.e. an array with compressed 2-bit indices and an array with the nonzero values, so every sparse matrix maps to exactly two metadata buffers. This RFC proposes a simple design that lets model writers quickly speed up 2:4 structured sparse matrix multiplications with explicit sparsity, i.e. the user is aware of all the metadata for sparse matrices and inserts a special operation for the matrix multiplication. The drawbacks of this approach are that users must be aware of all details of the sparsity format and keep the metadata together, users must explicitly introduce the accelerated matrix multiplication through special operations, and the compiler provides almost no safety checks on the metadata. The advantage of the explicit sparsity approach is that it puts accelerated sparse performance in the hands of model writers much more quickly, with minimal changes to the XLA infrastructure (IR and ops).
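To make the storage format concrete, the sketch below shows how a 2:4 sparse row splits into a values buffer and a per-element 2-bit index buffer. This is a conceptual illustration only; the exact bit packing used by cuSPARSELt differs.

```python
import numpy as np

# Conceptual 2:4 compression: every group of four consecutive elements
# keeps its two nonzero values plus their positions (0..3) within the group.
row = np.array([0., 3., 0., 5., 7., 0., 0., 2.])  # already 2:4 sparse

values, indices = [], []
for group in row.reshape(-1, 4):
    nz = np.flatnonzero(group)           # exactly two positions per group
    values.extend(group[nz].tolist())    # compressed nonzero values
    indices.extend(nz.tolist())          # 2-bit index per kept value

print(values)   # [3.0, 5.0, 7.0, 2.0]
print(indices)  # [1, 3, 0, 3]
```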
JAX Changes
This proposal completely avoids sparse tensor types at the JAX level (as was the approach taken in jax.experimental.sparse), since such types complicate ABI requirements (a sparse tensor maps to more than one buffer). Instead, a few lower-level primitives expose the metadata of the sparse storage format, together with a special matrix multiplication operation for 2:4 structured sparsity that operates directly on the metadata, as shown below.
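A minimal sketch of what this explicit API could look like follows. The names (prune_2to4, compress_2to4, sp_matmul_2to4), signatures, and layouts are illustrative assumptions rather than the final API, and the matmul is emulated here by decompressing to dense so that its reference semantics are clear; the proposed primitive would instead lower to the hardware-accelerated sparse matmul.

```python
import jax.numpy as jnp

def prune_2to4(a):
    # Keep the two largest-magnitude entries in every group of four, zero the rest.
    groups = a.reshape(-1, 4)
    keep = jnp.argsort(jnp.abs(groups), axis=-1)[:, 2:]
    rows = jnp.arange(groups.shape[0])[:, None]
    mask = jnp.zeros_like(groups).at[rows, keep].set(1.0)
    return (groups * mask).reshape(a.shape)

def compress_2to4(a):
    # Split a pruned (m, n) matrix into values (m, n // 2) and indices (m, n // 2).
    groups = a.reshape(-1, 4)
    idx = jnp.argsort(groups == 0.0, axis=-1)[:, :2]   # positions of the two nonzeros
    vals = jnp.take_along_axis(groups, idx, axis=-1)
    m = a.shape[0]
    return vals.reshape(m, -1), idx.reshape(m, -1).astype(jnp.uint8)

def sp_matmul_2to4(values, indices, n, b):
    # Reference semantics only: decompress to dense and multiply. The proposed
    # primitive would instead map to the hardware-accelerated 2:4 sparse matmul.
    m = values.shape[0]
    rows = jnp.arange(m * n // 4)[:, None]
    idx = indices.reshape(-1, 2).astype(jnp.int32)
    dense = jnp.zeros((m * n // 4, 4), values.dtype).at[rows, idx].set(values.reshape(-1, 2))
    return dense.reshape(m, n) @ b

# Intended usage (a: (m, n) with n divisible by 4, b: (n, k)):
#   a = prune_2to4(a)
#   values, indices = compress_2to4(a)
#   c = sp_matmul_2to4(values, indices, a.shape[1], b)   # equals a @ b
```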
Initially, the first two operations (prune and compress) can remain purely in JAX as a library; only the matmul primitive is recognized as a custom operation that maps to a special accelerated operation. Over time, accelerated versions of the prune and compress steps that run on the GPU can be made available as well.
An element-wise operation on the original sparse matrix can simply be applied to the values buffer, as shown below. At every point, the model writer must be fully aware of which metadata represents which matrix and is fully responsible for preserving the integrity of the metadata (e.g. scaling the indices array would corrupt the sparse storage scheme; the compiler provides no safety net against such corruption).
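For instance, a ReLU or a uniform scaling of the sparse matrix touches only the values buffer (continuing the illustrative metadata layout from the sketch above):

```python
# Element-wise ops act on the values buffer only; the indices buffer must
# not be modified, or the storage scheme is corrupted.
values = jnp.maximum(values, 0.0)   # ReLU applied to the sparse matrix
values = 2.0 * values               # uniform scaling is equally safe
# indices is left untouched
```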
To the backend, these operations manifest themselves as dense operations, which fully enables the typical optimizations such as fusion (the special accelerated matrix multiplication operation, however, needs some extra attention).
Note that it is probably a good idea to follow up the initial implementation with some additional JAX work to “hide” the metadata in a single JAX construct: not a full sparse tensor type, but a wrapper that hides some of the details from the user, e.g. a struct with values, indices, m, and n, together with zero-overhead setter and getter methods.
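One possible shape for such a wrapper is sketched below; the name Sparse24 and its fields are hypothetical. A NamedTuple is already treated as a pytree by JAX, so the wrapper adds no overhead and keeps all metadata of one sparse matrix together.

```python
from typing import NamedTuple
import jax.numpy as jnp

class Sparse24(NamedTuple):
    # Bundles the explicit metadata of one 2:4 sparse matrix; being a
    # NamedTuple, it flattens to its fields with zero overhead under jit.
    values: jnp.ndarray    # (m, n // 2) nonzero values
    indices: jnp.ndarray   # (m, n // 2) 2-bit positions (stored as uint8 here)
    m: int
    n: int

# e.g. a_sp = Sparse24(values, indices, m, n)
#      c = sp_matmul_2to4(a_sp.values, a_sp.indices, a_sp.n, b)
```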
HLO and StableHLO Changes
Since the user deals with the operations on the metadata explicitly, the only addition to the IR is a special sparse matrix multiplication operation. As mentioned earlier, an audit of the “dense” fusion optimizations is required to ensure that this special sparse operation fuses with dense operations, so that no performance is lost (element-wise operations manifest themselves as dense operations on the values part of the metadata, so they need no special attention).
For a very first exploration of adding 2:4 support (e.g. in a private fork), StableHLO does not necessarily need to change, because there is a process for exposing HLO/MHLO features to JAX and other frameworks without going through StableHLO. However, before the 2:4 code is submitted to the XLA repository, the special operation needs to be introduced to StableHLO as well; this will have to be proposed as an RFC and, once approved by the governance body, added to the StableHLO repository with compliance tests.
Code Generation for GPU
The objective of introducing the special sparse matrix multiplication operation is ultimately to map it onto efficient usage of the mma.sp instruction. The following three approaches to generating code for the special operation are possible, listed in increasing order of difficulty.
Implementation Status
This RFC is meant to solicit early feedback and also acts as a call for volunteers interested in contributing to this work. The JAX operations have already been implemented as a small library, but the proper connection with HLO and StableHLO, as well as all the backend work, still has to start.