[maca] support deepseekv2 for maca backend. #133

Open
wants to merge 2 commits into base: main
9 changes: 8 additions & 1 deletion dlinfer/ops/llm.py
@@ -539,6 +539,7 @@ def fused_moe(
topk_weights: Tensor,
gate_up_weights: Tensor,
down_weights: Tensor,
renormalize: bool = False,
) -> Tensor:
"""
Implement the Fused Mixture of Experts (MoE) model.
@@ -556,7 +557,13 @@

"""
return vendor_ops_registry["fused_moe"](
hidden_states, top_k, topk_ids, topk_weights, gate_up_weights, down_weights
hidden_states,
top_k,
topk_ids,
topk_weights,
gate_up_weights,
down_weights,
renormalize,
)
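For illustration, a minimal sketch of calling the updated wrapper with the new renormalize flag. The import path follows this file's location, but the shapes and the presence of a registered vendor backend are assumptions, not part of this PR:

import torch
from dlinfer.ops.llm import fused_moe

# Hypothetical sizes: 4 tokens, hidden size 16, 8 experts, top-2 routing.
hidden_states = torch.randn(4, 16, dtype=torch.float16)
router_logits = torch.rand(4, 8)                               # [tokens, experts]
topk_weights, topk_ids = router_logits.softmax(-1).topk(2, dim=-1)
gate_up_weights = torch.randn(8, 32, 16, dtype=torch.float16)  # [experts, 2*inter, hidden]
down_weights = torch.randn(8, 16, 16, dtype=torch.float16)     # [experts, hidden, inter]

# renormalize=True rescales the top-k weights to sum to 1 per token;
# the keyword defaults to False so existing callers keep their behavior.
out = fused_moe(hidden_states, 2, topk_ids, topk_weights,
                gate_up_weights, down_weights, renormalize=True)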


2 changes: 2 additions & 0 deletions dlinfer/vendor/maca/csrc/CMakeLists.txt
@@ -17,6 +17,8 @@ find_library(torch_python_LIBRARY torch_python PATHS

set(DLINFER_VLLM_SRC
"pybind.cpp"
"cache_kernels.cu"
"attention/attention_kernels.cu"
"pos_encoding_kernels.cu"
"moe_align_block_size_kernels.cu"
"moe/topk_softmax_kernels.cu"
3 changes: 3 additions & 0 deletions dlinfer/vendor/maca/csrc/attention/attention_kernels.cu
@@ -1537,6 +1537,9 @@ void paged_attention_v1_launcher(
case 256:
LAUNCH_PAGED_ATTENTION_V1_32N(256);
break;
case 576:
LAUNCH_PAGED_ATTENTION_V1_32N(576);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
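Presumably the new 576 case covers DeepSeek-V2 Lite's MLA head layout, where each cached key packs the 512-dim compressed ("nope") part together with a 64-dim rotary part; 512 and 576 appear in the maca_ops.py changes below, while the exact split is an assumption. A one-line Python sanity check of that arithmetic:

kv_lora_rank, qk_rope_head_dim = 512, 64      # assumed DeepSeek-V2 Lite MLA dimensions
assert kv_lora_rank + qk_rope_head_dim == 576  # the head size added to the switch above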
24 changes: 16 additions & 8 deletions dlinfer/vendor/maca/csrc/cache_kernels.cu
@@ -215,7 +215,7 @@ __global__ void reshape_and_cache_kernel_layout(
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x,
const float kv_scale) {
const float kv_scale, bool is_deepseek) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
@@ -250,12 +250,16 @@
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
if (!is_deepseek) {
value_cache[tgt_value_idx] = tgt_value;
}
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
if (!is_deepseek) {
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}
}
}
}
@@ -270,7 +274,7 @@ __global__ void reshape_and_cache_kernel_layout_opt(
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x,
const float kv_scale) {
const float kv_scale, bool is_deepseek) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
@@ -303,7 +307,9 @@
head_offset;

*(float4*)(key_cache + tgt_key_idx) = *(float4*)(key + src_key_idx);
*(float4*)(value_cache + tgt_value_idx) = *(float4*)(value + src_value_idx);
if (!is_deepseek) {
*(float4*)(value_cache + tgt_value_idx) = *(float4*)(value + src_value_idx);
}
/*
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
@@ -415,7 +421,7 @@ void reshape_and_cache(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, kv_scale);
num_heads, head_size, block_size, x, kv_scale, is_deepseek);

#define CALL_RESHAPE_AND_CACHE_LAYOUT_OPT(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel_layout_opt<KV_T, CACHE_T, KV_DTYPE> \
@@ -425,7 +431,7 @@
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, kv_scale);
num_heads, head_size, block_size, x, kv_scale, is_deepseek);
void reshape_and_cache_new(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
@@ -439,6 +445,8 @@ void reshape_and_cache_new(
int num_heads = key.size(1);
int head_size = key.size(2);

bool is_deepseek = key.size(-1) != value.size(-1);

int x;
int block_size;
x = key_cache.size(4);
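The new is_deepseek flag is derived from a key/value width mismatch and makes the kernel skip the value-cache write, since in the MLA layout the 576-wide cached key already contains the 512-wide value part. A rough Python analogue of the detection, under that assumption:

import torch

def is_mla_kv(key: torch.Tensor, value: torch.Tensor) -> bool:
    # Mirrors `bool is_deepseek = key.size(-1) != value.size(-1);`
    return key.size(-1) != value.size(-1)

key = torch.randn(4, 1, 576)    # [num_tokens, num_kv_heads, head_size]
value = torch.randn(4, 1, 512)  # value ("nope") part only
assert is_mla_kv(key, value)    # -> only the key cache gets written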
34 changes: 25 additions & 9 deletions dlinfer/vendor/maca/csrc/pybind.cpp
@@ -1,6 +1,7 @@
// 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
#include <torch/extension.h>

#include "cache.h"
#include "moe/moe_ops.h"
#include "ops.h"

@@ -18,6 +19,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops.def("paged_attention_v1",
&paged_attention_v1,
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");

// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def("rotary_embedding",
@@ -27,15 +43,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
" Tensor cos, Tensor sin,"
" bool is_neox) -> ()");

// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops.def("batched_rotary_embedding",
&batched_rotary_embedding,
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()");
// Cache ops
ops.def("reshape_and_cache_new",
&reshape_and_cache_new,
"reshape_and_cache_new(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float kv_scale,"
" float v_scale) -> ()");

// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
52 changes: 26 additions & 26 deletions dlinfer/vendor/maca/maca_ops.py
@@ -116,10 +116,11 @@ def prefill_attention(

# for deepseek v2 lite.
if query.shape[-1] == 576:
batch_size = kv_seq_len.dim()
batch_size = kv_seq_len.size(0)
head_dim = query.shape[-1]
nope_size = value.shape[-1]
groups = num_q_heads // num_kv_heads
value = torch.nn.functional.pad(value, [0, head_dim - nope_size], value=0)

input_type = query.dtype
query = query.to(torch.float32)
@@ -129,7 +130,7 @@
# (bs, seq_len, num_head, head_dim)
query = query.view(batch_size, -1, num_q_heads, head_dim)
key = key.view(batch_size, -1, num_kv_heads, head_dim)
value = value.view(batch_size, -1, num_kv_heads, nope_size)
value = value.view(batch_size, -1, num_kv_heads, head_dim)
key = key.repeat(1, 1, groups, 1)
value = value.repeat(1, 1, groups, 1)

@@ -147,7 +148,7 @@
attn_output = attn_output.transpose(1, 2).flatten(0, 1)
attn_output = attn_output[..., :nope_size].contiguous()
attn_output = attn_output.to(input_type)
return attn_output[..., :512].contiguous()
return attn_output
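As a standalone illustration of the branch above: the 512-wide values are zero-padded to the 576-wide query/key head size so a regular attention call can be used, and the padded tail is sliced off the output afterwards. The shapes below are illustrative, not taken from a real model config:

import torch

head_dim, nope_size = 576, 512
q = torch.randn(1, 16, 8, head_dim, dtype=torch.float32)   # [bs, heads, seq, head_dim]
k = torch.randn(1, 16, 8, head_dim, dtype=torch.float32)
v = torch.randn(1, 16, 8, nope_size, dtype=torch.float32)  # values are only 512 wide
v = torch.nn.functional.pad(v, [0, head_dim - nope_size], value=0)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
out = out[..., :nope_size]  # drop the zero-padded tail, as the output slice above does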

# for cogvlm vl part.
if query.size(-2) != num_q_heads:
@@ -188,7 +189,7 @@ def fill_kv_cache(
quant_bits: int,
) -> Tuple[Tensor, Tensor]:
kv_indices = kv_indices.squeeze(-1)
vllm._custom_ops.reshape_and_cache_new(
maca_ext_ops.reshape_and_cache_new(
key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
)
return key_cache, value_cache
@@ -218,9 +219,15 @@
softmax_scale = float(1 / math.sqrt(query.size(-1)))

num_kv_heads = value_cache.size(1)
block_size = value_cache.size(2)
block_size = value_cache.size(-2)
output = torch.empty_like(query)
vllm._custom_ops.paged_attention_v1(

# for deepseek v2 lite.
if query.size(-1) == 576:
value_cache = key_cache.transpose(2, 3).reshape(
-1, num_kv_heads, block_size, 576
)
maca_ext_ops.paged_attention_v1(
output,
query,
key_cache,
@@ -241,8 +248,7 @@
1, # blocksparse_block_size
1, # blocksparse_head_sliding_step
)

return output
return output[..., :512]
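For the MLA decode path, value_cache is not a separate buffer: the code above reinterprets the paged key cache, which already holds the 512-dim value part inside each 576-wide key, as a [num_blocks, num_kv_heads, block_size, 576] tensor, and slices the attention output back to 512. A shape-only sketch of that reinterpretation, with the block layout and x factor assumed to follow the usual vLLM key-cache layout:

import torch

num_blocks, num_kv_heads, block_size, head_size, x = 2, 1, 16, 576, 8
# Assumed key-cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x]
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x)
value_cache = key_cache.transpose(2, 3).reshape(-1, num_kv_heads, block_size, head_size)
assert value_cache.shape == (num_blocks, num_kv_heads, block_size, head_size)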


@register_ops(vendor_ops_registry)
@@ -301,9 +307,12 @@ def rms_norm(
weight: Tensor,
epsilon: float,
) -> Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
weight = weight.to(torch.float32)
output = torch.empty_like(hidden_states)
vllm._custom_ops.rms_norm(output, hidden_states, weight, epsilon)
return output
return output.to(input_dtype)
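The reworked rms_norm upcasts to float32 before calling the kernel and casts the result back, avoiding precision loss when hidden_states is fp16/bf16. For reference, a pure-PyTorch version with the same dtype handling (not the MACA kernel itself):

import torch

def rms_norm_ref(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + epsilon)
    return (weight.to(torch.float32) * x).to(input_dtype)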


@register_ops(vendor_ops_registry)
@@ -354,25 +363,16 @@ def fused_moe(
topk_weights: torch.Tensor,
gate_up_weights: torch.Tensor,
down_weights: torch.Tensor,
renormalize: bool = False,
):
N, D = hidden_states.shape
hidden_states = hidden_states.view(N, -1, D).repeat(1, top_k, 1).reshape(-1, D)
out = torch.zeros(
N * top_k,
down_weights.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
N = hidden_states.size(0)
topk_weights = topk_weights.reshape(N, top_k)
topk_ids = topk_ids.reshape(N, top_k)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return vllm.model_executor.layers.fused_moe.fused_experts(
hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids
)
for i in range(gate_up_weights.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = silu_and_mul(
hidden_states[mask] @ gate_up_weights[i].transpose(0, 1)
) @ down_weights[i].transpose(0, 1)
return (
out.view(N, -1, down_weights.shape[1])
* topk_weights.view(N, -1, 1).to(out.dtype)
).sum(dim=1)


@register_ops(vendor_ops_registry)
Expand Down