Skip to content

Commit

Permalink
misc: add device guard for kernels (#611)
Browse files Browse the repository at this point in the history
## plan
- [x] Check all kernels and add device guard
- [x] Complete the  tests

FIX: #452
  • Loading branch information
jeejeelee authored Nov 15, 2024
1 parent a3360ff commit b53a46f
Show file tree
Hide file tree
Showing 19 changed files with 54 additions and 5 deletions.
1 change: 1 addition & 0 deletions python/csrc/bmm_fp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
auto workspace_buffer = torch::empty(
{32 * 1024 * 1024}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));
auto lt_handle = reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
const at::cuda::OptionalCUDAGuard device_guard(A.device());
auto stream = at::cuda::getCurrentCUDAStream();

// PyTorch is row major by default. cuBLASLt is column major by default.
Expand Down
5 changes: 5 additions & 0 deletions python/csrc/cascade.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
unsigned int seq_len = v_a.size(0);
unsigned int num_heads = v_a.size(1);
unsigned int head_dim = v_a.size(2);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto v_merged = torch::empty_like(v_a, v_a.options());
auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());
Expand Down Expand Up @@ -91,6 +93,8 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
unsigned int seq_len = v.size(0);
unsigned int num_heads = v.size(1);
unsigned int head_dim = v.size(2);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] {
Expand Down Expand Up @@ -121,6 +125,7 @@ std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
unsigned int num_heads = v.size(2);
unsigned int head_dim = v.size(3);
s = s.to(torch::kFloat32);
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options());
auto s_merged = torch::empty({seq_len, num_heads}, s.options());
Expand Down
2 changes: 2 additions & 0 deletions python/csrc/group_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_proble
torch::Tensor empty_x_data, bool weight_column_major) {
unsigned int batch_size = x_ptr.size(0);
auto device = workspace_buffer.device();

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] {
Expand Down
2 changes: 2 additions & 0 deletions python/csrc/group_gemm_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer,
bool weight_column_major) {
unsigned int batch_size = x_ptr.size(0);
auto device = float_workspace_buffer.device();

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_x_data.scalar_type(), c_type, [&] {
Expand Down
4 changes: 4 additions & 0 deletions python/csrc/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight,
CHECK_EQ(output.size(0), batch_size);
CHECK_EQ(output.size(1), hidden_size);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()),
Expand Down Expand Up @@ -61,6 +62,7 @@ void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Ten
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()),
Expand All @@ -86,6 +88,7 @@ void gemma_rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& w
CHECK_EQ(output.size(0), batch_size);
CHECK_EQ(output.size(1), hidden_size);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::GemmaRMSNorm(static_cast<c_type*>(input.data_ptr()),
Expand Down Expand Up @@ -115,6 +118,7 @@ void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torc
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::GemmaFusedAddRMSNorm(
Expand Down
3 changes: 2 additions & 1 deletion python/csrc/page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
CHECK_EQ(append_key.size(2), head_dim);
CHECK_EQ(append_value.size(1), num_heads);
CHECK_EQ(append_value.size(2), head_dim);


const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

auto kv_scalar_dtype = paged_k_cache.scalar_type();
Expand Down
1 change: 1 addition & 0 deletions python/csrc/pytorch_extension_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
Expand Down
3 changes: 3 additions & 0 deletions python/csrc/quantization.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder) {
auto device = x.device();
TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'");
x = x.to(torch::kBool);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

int64_t num_elements = x.numel();
Expand Down Expand Up @@ -57,6 +59,7 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
int64_t output_nnz = output_indptr[batch_size].item<int64_t>();
auto y = torch::empty({output_nnz}, x.options().dtype(torch::kUInt8));

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaError_t status = quantization::SegmentPackBits(
static_cast<bool*>(x.data_ptr()), static_cast<uint8_t*>(y.data_ptr()),
static_cast<int32_t*>(input_indptr.data_ptr()),
Expand Down
7 changes: 6 additions & 1 deletion python/csrc/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T
size_t k_rope_stride_h = k_rope.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);


const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotary(
Expand Down Expand Up @@ -93,6 +94,7 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
size_t k_rope_stride_h = k_rope.stride(1);
pos_ids = pos_ids.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotaryPosIds(
Expand Down Expand Up @@ -145,6 +147,7 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T
size_t k_rope_stride_h = k_rope.stride(1);
pos_ids = pos_ids.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
Expand Down Expand Up @@ -195,6 +198,7 @@ void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyLlama31Rotary(
Expand Down Expand Up @@ -240,6 +244,7 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor
size_t k_rope_stride_h = k_rope.stride(1);
pos_ids = pos_ids.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyLlama31RotaryPosIds(
Expand Down
11 changes: 10 additions & 1 deletion python/csrc/sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam
probs = probs.to(torch::kFloat32);
uniform_samples = uniform_samples.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));

Expand Down Expand Up @@ -71,6 +72,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
uniform_samples = uniform_samples.to(torch::kFloat32);
top_p_arr = top_p_arr.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
Expand Down Expand Up @@ -112,6 +114,7 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
uniform_samples = uniform_samples.to(torch::kFloat32);
top_k_arr = top_k_arr.to(torch::kInt32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
Expand Down Expand Up @@ -153,6 +156,7 @@ std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
probs = probs.to(torch::kFloat32);
uniform_samples = uniform_samples.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
Expand Down Expand Up @@ -203,6 +207,7 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
probs = probs.to(torch::kFloat32);
uniform_samples = uniform_samples.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
Expand Down Expand Up @@ -236,7 +241,8 @@ torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
}
top_p_arr = top_p_arr.to(torch::kFloat32);
probs = probs.to(torch::kFloat32);


const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto renorm_probs =
torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
Expand Down Expand Up @@ -268,6 +274,7 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
top_k_arr = top_k_arr.to(torch::kInt32);
probs = probs.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto renorm_probs =
torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
Expand Down Expand Up @@ -300,6 +307,7 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
top_k_arr = top_k_arr.to(torch::kInt32);
logits = logits.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto mask_logits =
torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
Expand Down Expand Up @@ -348,6 +356,7 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
uniform_samples = uniform_samples.to(torch::kFloat32);
target_probs = target_probs.to(torch::kFloat32);

const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
torch::dtype(torch::kInt32).device(device));
Expand Down
2 changes: 2 additions & 0 deletions python/csrc_aot/batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

Expand Down Expand Up @@ -112,6 +113,7 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun(
}
uint32_t head_dim = q.size(2);

const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q);
if (maybe_lse) {
Expand Down
1 change: 1 addition & 0 deletions python/csrc_aot/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
Expand Down
1 change: 1 addition & 0 deletions python/csrc_aot/single_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
kv_len = k.size(1);
}
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q);

Expand Down
1 change: 1 addition & 0 deletions python/csrc_aot/single_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ torch::Tensor single_prefill_with_kv_cache(
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
Expand Down
4 changes: 3 additions & 1 deletion python/flashinfer/jit/batch_decode_mla_templ.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
indptr = indptr.to(torch::kCPU);
Expand Down Expand Up @@ -83,8 +84,9 @@
auto device = q_nope.device();
int64_t batch_size = q_nope.size(0);
int64_t num_qo_heads = q_nope.size(1);
int64_t page_size = paged_ckv_cache.size(1);;
int64_t page_size = paged_ckv_cache.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q_nope);
torch::Tensor lse;
Expand Down
4 changes: 3 additions & 1 deletion python/flashinfer/jit/batch_decode_templ.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");
Expand Down Expand Up @@ -93,7 +94,8 @@
page_size = paged_k_cache.size(1);
num_kv_heads = paged_k_cache.size(2);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q);
if (maybe_lse) {
Expand Down
3 changes: 3 additions & 0 deletions python/flashinfer/jit/batch_prefill_templ.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
Expand Down Expand Up @@ -92,6 +93,7 @@
}
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
Expand Down Expand Up @@ -187,6 +189,7 @@
num_kv_heads = paged_k_cache.size(2);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
Expand Down
2 changes: 2 additions & 0 deletions python/flashinfer/jit/single_decode_templ.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
num_kv_heads = k.size(0);
kv_len = k.size(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q);
Expand Down Expand Up @@ -157,6 +158,7 @@
num_kv_heads = k.size(0);
kv_len = k.size(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q);
Expand Down
2 changes: 2 additions & 0 deletions python/flashinfer/jit/single_prefill_templ.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
Expand Down Expand Up @@ -177,6 +178,7 @@
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
Expand Down

0 comments on commit b53a46f

Please sign in to comment.