feat: torch custom_op fix for rope (#569)
Follow-up fix to the changes made in #568.

torch.compile doesn't like custom ops that return their input arguments, so change the
return type of the pybind functions to `void`. Since the op is already in-place, the
returned tensors carried no new information.

A PyTorch library annotation is applied to a wrapper function for each
pybind op. The Python API doesn't change: both the in-place and
non-in-place versions call the annotated wrapper function.
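
A minimal sketch of that pattern, with hypothetical names (`_kernels` stands in for the compiled pybind extension; this is illustrative, not the actual flashinfer source):

import torch

# Sketch of the annotated-wrapper pattern described above. All names are
# illustrative; `_kernels` stands in for the compiled pybind module.
@torch.library.custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope"))
def _apply_rope(q: torch.Tensor, k: torch.Tensor, q_rope: torch.Tensor,
                k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor,
                interleave: bool, rope_scale: float, rope_theta: float) -> None:
    # The pybind op is now void and writes its results into q_rope/k_rope.
    _kernels.apply_rope(q, k, q_rope, k_rope, indptr, offsets,
                        interleave, rope_scale, rope_theta)

# Both public variants call the same annotated wrapper, so the Python API is
# unchanged: the in-place version passes q/k as their own outputs, while the
# functional version allocates fresh outputs and returns them from Python.
def apply_rope_inplace(q, k, indptr, offsets, interleave=False,
                       rope_scale=1.0, rope_theta=1e4):
    _apply_rope(q, k, q, k, indptr, offsets, interleave, rope_scale, rope_theta)

def apply_rope(q, k, indptr, offsets, interleave=False,
               rope_scale=1.0, rope_theta=1e4):
    q_rope, k_rope = torch.empty_like(q), torch.empty_like(k)
    _apply_rope(q, k, q_rope, k_rope, indptr, offsets,
                interleave, rope_scale, rope_theta)
    return q_rope, k_rope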
abcdabcd987 authored Oct 30, 2024
1 parent 4f40420 commit 3e104bc
Showing 5 changed files with 346 additions and 147 deletions.
40 changes: 17 additions & 23 deletions flashinfer-aot/csrc_aot/flashinfer_ops.cu
@@ -83,29 +83,23 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
-
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);
+
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);
+
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);
+
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length);
 
 torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);
 
34 changes: 14 additions & 20 deletions python/csrc/flashinfer_rope_ops.cu
@@ -17,29 +17,23 @@
 
 #include <vector>
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
-
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);
+
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);
+
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);
+
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("apply_rope", &apply_rope, "Apply RoPE");
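Only the registered schema changes here: `apply_rope` and friends now return nothing instead of a tensor list. A self-contained sketch (hypothetical demo op with placeholder math, not a flashinfer API) of why a void-returning, mutation-annotated op is what torch.compile accepts:

import torch

# Hypothetical stand-in for an annotated flashinfer wrapper: the op declares
# which arguments it mutates and returns None, mirroring the new void schema.
@torch.library.custom_op("demo::rope_inplace", mutates_args=("q_rope", "k_rope"))
def rope_inplace(q: torch.Tensor, k: torch.Tensor,
                 q_rope: torch.Tensor, k_rope: torch.Tensor) -> None:
    q_rope.copy_(q)  # placeholder math; the real kernel applies the rotation
    k_rope.copy_(k)

@rope_inplace.register_fake
def _(q, k, q_rope, k_rope) -> None:
    pass  # the op has no outputs, so the fake (meta) implementation is a no-op

@torch.compile(fullgraph=True)
def rotate(q, k):
    q_rope, k_rope = torch.empty_like(q), torch.empty_like(k)
    rope_inplace(q, k, q_rope, k_rope)  # traces cleanly: no aliased returns
    return q_rope, k_rope

q_rope, k_rope = rotate(torch.randn(16, 8, 64), torch.randn(16, 8, 64))

Had the op kept returning q_rope/k_rope (its own inputs), torch.compile's custom-op machinery would not accept it, which is what motivated the `void` change.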
42 changes: 14 additions & 28 deletions python/csrc/rope.cu
@@ -19,10 +19,9 @@
 
 using namespace flashinfer;
 
-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta) {
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(indptr);
@@ -65,14 +64,11 @@ std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::T
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta) {
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(pos_ids);
@@ -109,16 +105,12 @@ std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length) {
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(indptr);
@@ -162,16 +154,12 @@ std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
 
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length) {
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length) {
   CHECK_CUDA(q);  // not necessarily contiguous
   CHECK_CUDA(k);  // not necessarily contiguous
   CHECK_INPUT(pos_ids);
@@ -209,6 +197,4 @@ std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Te
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-
-  return {q_rope, k_rope};
 }
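
With `return {q_rope, k_rope};` gone from the kernels, the functional behavior lives entirely in the Python wrappers. A hedged usage sketch (assumed shapes, dtypes, and the public `flashinfer.apply_rope` / `flashinfer.apply_rope_inplace` names):

import torch
import flashinfer  # assumed installed with CUDA support

# One ragged batch of 2 sequences (10 and 6 tokens), 8 heads, head_dim 64.
q = torch.randn(16, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(16, 8, 64, dtype=torch.float16, device="cuda")
indptr = torch.tensor([0, 10, 16], dtype=torch.int32, device="cuda")
offsets = torch.zeros(2, dtype=torch.int32, device="cuda")  # past lengths

# Functional variant: outputs are allocated and returned by the Python side.
q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets)

# In-place variant: rotates q/k in their own storage and returns nothing.
flashinfer.apply_rope_inplace(q, k, indptr, offsets)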