feat: torch.compile and custom_op support #554
Merged
Follow-up of #552. This PR adds `torch.library` annotations to all FlashInfer kernels so that `torch.compile` can recognize them. Most of the changes are mechanical.
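For context, a minimal sketch of the kind of registration this enables, assuming PyTorch's `torch.library.custom_op` API (available since 2.4); the op name, signature, and body here are hypothetical stand-ins, not the exact registrations in this PR:

```python
import torch

# Hypothetical op; the real FlashInfer kernels are CUDA-backed.
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())
def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

# A fake (meta) implementation lets torch.compile infer output
# shapes/dtypes without running the kernel.
@rmsnorm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    return torch.empty_like(x)
```

With both registered, `torch.compile` can trace through the op instead of graph-breaking on an opaque extension call.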
I manually ran subsets of the pytest test cases as I made these changes, but since there are too many of them, and some did not pass even before this change, I cannot guarantee that everything works. To run the tests with `torch.compile`, set the `FLASHINFER_TEST_TORCH_COMPILE=1` environment variable.

Notable changes:
- Attention pybinds that support `return_lse` used to return `Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]` depending on `return_lse`, which causes trouble for `torch.compile`. I changed the pybind interface to accept `maybe_lse: Optional[torch.Tensor]` and return only one tensor; the allocation of the lse tensor moves to the Python side (see the sketch after this list). The Python API does not change.
- `chain_speculative_sampling` pybind: move the allocation of `accepted` and `emitted` from C++ to Python, because `torch.compile` doesn't like returning an input tensor as an output tensor. The Python API does not change.
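A minimal sketch of the `maybe_lse` pattern, with a stand-in for the pybind kernel and illustrative tensor shapes (not FlashInfer's actual binding); the same allocate-in-Python idea applies to the `accepted`/`emitted` buffers of `chain_speculative_sampling`:

```python
from typing import Optional, Tuple, Union

import torch

def _kernel_single_prefill(q, k, v, maybe_lse):
    # Stand-in for the pybind kernel: fills lse in place if provided,
    # and always returns exactly one tensor.
    if maybe_lse is not None:
        maybe_lse.fill_(0.0)
    return torch.empty_like(q)

def single_prefill(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Allocate lse on the Python side, so the kernel's return type is
    # always a single tensor (the lse shape here is illustrative).
    maybe_lse: Optional[torch.Tensor] = (
        torch.empty(q.shape[:2], dtype=torch.float32, device=q.device)
        if return_lse
        else None
    )
    out = _kernel_single_prefill(q, k, v, maybe_lse)
    return (out, maybe_lse) if return_lse else out
```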
Piggyback changes:

- `BatchPrefillWithRaggedKVCacheWrapper.plan`: bugfix for `qo_indptr` not being on CPU.
- `merge_state`: fix a typo in the docs.
- Change `run_return_lse(...)` to `run(..., return_lse=True)`, because `torch.compile` does not recognize `functools.partial` (see the sketch after this list).
- Change `flashinfer.xxx()` to `flashinfer.<module>.xxx()` so that the monkeypatch works.
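A small sketch of the `functools.partial` issue, with stand-in functions rather than FlashInfer's actual wrapper:

```python
import functools

import torch

def run(q: torch.Tensor, return_lse: bool = False):
    out = q * 2  # stand-in for the real kernel
    return (out, torch.zeros(len(q))) if return_lse else out

# Before (reconstruction): an alias built with functools.partial, which
# torch.compile could not trace through at the time of this PR:
run_return_lse = functools.partial(run, return_lse=True)

# After: pass the flag explicitly at the call site instead.
compiled = torch.compile(lambda q: run(q, return_lse=True))
out, lse = compiled(torch.randn(4))
```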
Unsupported for torch.compile:

- `flashinfer.quantization.segment_packbits`: because it's data-dependent (see the sketch below).
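A minimal illustration (not FlashInfer code) of why data-dependent output shapes don't compose with `torch.compile`: fake-tensor tracing cannot determine the output size from input metadata alone, and `segment_packbits` has the same property, since its output size depends on the values of its inputs:

```python
import torch

def boolean_mask(x: torch.Tensor) -> torch.Tensor:
    # The output length depends on the *values* of x, not just its
    # shape, so it cannot be inferred at trace time.
    return x[x > 0]

compiled = torch.compile(boolean_mask, fullgraph=True)
# Raises a data-dependent-output-shape error under fullgraph=True
# (unless dynamic output shape capture is explicitly enabled):
# compiled(torch.randn(8))
```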
Untouched:

- `sparse.py`: tests didn't pass beforehand, so I skipped it. It also doesn't seem to need `custom_op` annotations, as it has no CUDA kernels of its own.

Failed test cases:
- `test_batch_decode_with_paged_kv_cache[False-kv_dtype0-q_dtype0-True-0.0-NONE-NHD-128-4-4-1-54-12]`