From ddef3f3ae99cb8d05d9a1ae64a2fc5b99b2e2b14 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 15 Oct 2024 01:13:02 -0400
Subject: [PATCH] Prefetch device transfer for ptrs to CPU (#529)

The FlashInfer kernel [here](https://github.com/flashinfer-ai/flashinfer/blob/main/python/flashinfer/jit/batch_prefill_templ.py#L52C15-L53) does:

```
qo_indptr = qo_indptr.to(torch::kCPU);
kv_indptr = kv_indptr.to(torch::kCPU);
```

which is a blocking device synchronization for the CPU worker. We would like
to avoid this for certain optimizations. Accordingly, this PR schedules the
device transfer ahead of time in the Python code, before the kernel is
invoked, so the CPU worker is not blocked.
---
 python/flashinfer/decode.py  | 3 +++
 python/flashinfer/prefill.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index 31694030..131f1cf2 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -568,6 +568,9 @@ def plan(
         if self.use_tensor_cores:
             self._qo_indptr_buf = qo_indptr.to(self.device)
 
+        qo_indptr = qo_indptr.to('cpu', non_blocking=True)
+        indptr = indptr.to('cpu', non_blocking=True)
+
         data_type = canonicalize_torch_dtype(data_type)
         if not q_data_type:
             q_data_type = data_type
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index e9d49cd2..474e5927 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -740,6 +740,9 @@ def plan(
             self._custom_mask_buf = packed_custom_mask.to(self.device)
             self._qk_indptr_buf = qk_indptr.to(self.device)
 
+        qo_indptr = qo_indptr.to('cpu', non_blocking=True)
+        paged_kv_indptr = paged_kv_indptr.to('cpu', non_blocking=True)
+
         if packed_custom_mask is not None:
             mask_mode = MaskMode.CUSTOM.value
         else:
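
Side note (not part of the patch): below is a minimal, self-contained sketch of the prefetch pattern the diff applies. The names `plan_kernel`, `other_plan_work`, and `plan_sketch` are hypothetical stand-ins for the JIT'ed FlashInfer plan call and the surrounding Python bookkeeping in `plan()`; the point is that rebinding the indptr tensors to non-blocking CPU copies early lets the device-to-host transfer overlap with the remaining Python work, so the kernel's internal `.to(torch::kCPU)` no longer stalls the CPU worker.

```
# Illustrative sketch only; function names are hypothetical stand-ins.
import torch


def plan_kernel(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor) -> None:
    # Stand-in for the JIT'ed FlashInfer plan kernel, which moves the
    # indptr tensors to the CPU before reading them.
    qo_indptr = qo_indptr.cpu()
    kv_indptr = kv_indptr.cpu()


def other_plan_work() -> None:
    # Stand-in for the remaining Python-side bookkeeping in plan().
    pass


def plan_sketch(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor) -> None:
    # Schedule the device->host transfers early; they can overlap with the
    # Python work below instead of blocking inside the kernel call.
    qo_indptr = qo_indptr.to('cpu', non_blocking=True)
    kv_indptr = kv_indptr.to('cpu', non_blocking=True)

    other_plan_work()

    # The tensors already live on the CPU here, so the kernel's own
    # `.to(torch::kCPU)` is effectively a no-op.
    plan_kernel(qo_indptr, kv_indptr)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    qo_indptr = torch.arange(0, 9, dtype=torch.int32, device=device)
    kv_indptr = torch.arange(0, 9, dtype=torch.int32, device=device)
    plan_sketch(qo_indptr, kv_indptr)
```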