Commit

fix dsk diff

yuanlehome committed Feb 13, 2025
1 parent 5a1c4ac commit d3e48c8
Showing 3 changed files with 45 additions and 23 deletions.
26 changes: 16 additions & 10 deletions csrc/gpu/step.cu
@@ -31,6 +31,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
@@ -43,6 +44,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
int *block_table_now = block_tables + tid * block_num_per_seq;
if (stop_flags[tid] && !is_block_step[tid]) {
// reclaim the blocks
first_token_ids[tid] = -1;
const int encoder_block_len = encoder_block_lens[tid];
const int decoder_used_len = used_list_len[tid];
if (decoder_used_len > 0) {
@@ -166,11 +168,11 @@ __global__ void recover_block(int *recover_block_list, // [bsz]
int *encoder_block_lens,
int *used_list_len,
const int64_t *next_tokens,
const int64_t *first_token_ids,
const int bsz,
const int block_num_per_seq,
const int length,
const int pre_id_length,
const int first_token_id) {
const int pre_id_length) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
__shared__ int ori_free_list_len;
@@ -189,7 +191,8 @@ __global__ void recover_block(int *recover_block_list, // [bsz]
seq_lens_encoder[recover_id] = seq_len;
stop_flags[recover_id] = false;
input_ids_now[ori_seq_len_encoder + step_idx_now - 1] = next_tokens[recover_id]; // next tokens
input_ids_now[0] = first_token_id; // set first prompt token
input_ids_now[0] =
first_token_ids[recover_id]; // set first prompt token
const int ori_free_list_len_tid0 = atomicSub(free_list_len, decoder_used_len);
ori_free_list_len = ori_free_list_len_tid0;
#ifdef DEBUG_STEP
@@ -234,9 +237,9 @@ void StepPaddle(const paddle::Tensor& stop_flags,
const paddle::Tensor& pre_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& next_tokens,
const paddle::Tensor &first_token_ids,
const int block_size,
const int encoder_decoder_block_num,
const int64_t first_token_id,
const int speculate_step_token_num) {
auto cu_stream = seq_lens_this_time.stream();
const int bsz = seq_lens_this_time.shape()[0];
@@ -264,6 +267,7 @@ void StepPaddle(const paddle::Tensor& stop_flags,
const_cast<int*>(used_list_len.data<int>()),
const_cast<int*>(free_list.data<int>()),
const_cast<int*>(free_list_len.data<int>()),
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
bsz,
block_size,
block_num_per_seq,
@@ -300,11 +304,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
const_cast<int*>(encoder_block_lens.data<int>()),
const_cast<int*>(used_list_len.data<int>()),
next_tokens.data<int64_t>(),
first_token_ids.data<int64_t>(),
bsz,
block_num_per_seq,
length,
pre_id_length,
first_token_id
pre_id_length
);
#ifdef DEBUG_STEP
#ifdef PADDLE_WITH_HIP
@@ -337,10 +341,10 @@ PD_BUILD_OP(step_paddle)
"input_ids",
"pre_ids",
"step_idx",
"next_tokens"})
"next_tokens",
"first_token_ids",})
.Attrs({"block_size: int",
"encoder_decoder_block_num: int",
"first_token_id: int64_t",
"speculate_step_token_num: int"})
.Outputs({"stop_flags_out",
"seq_lens_this_time_out",
@@ -358,7 +362,8 @@ PD_BUILD_OP(step_paddle)
"used_list_len_out",
"free_list_out",
"free_list_len_out",
"input_ids_out"})
"input_ids_out",
"first_token_ids_out",})
.SetInplaceMap({{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
@@ -375,5 +380,6 @@ PD_BUILD_OP(step_paddle)
{"used_list_len", "used_list_len_out"},
{"free_list", "free_list_out"},
{"free_list_len", "free_list_len_out"},
{"input_ids", "input_ids_out"}})
{"input_ids", "input_ids_out"},
{"first_token_ids", "first_token_ids_out"}})
.SetKernelFn(PD_KERNEL(StepPaddle));
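
The step.cu change above switches step_paddle from a scalar first_token_id attribute to a per-sequence first_token_ids int64 tensor input (updated in place): free_and_dispatch_block marks a freed slot with -1, and recover_block restores each recovered sequence's own first prompt token. The NumPy sketch below only illustrates that per-sequence bookkeeping; it is not part of the op's API, and names such as recover_ids are hypothetical.

import numpy as np

bsz = 4
# per-sequence first prompt tokens (previously a single scalar attribute)
first_token_ids = np.array([101, 102, 103, 104], dtype=np.int64)
input_ids = np.zeros((bsz, 8), dtype=np.int64)
stop_flags = np.array([True, False, False, False])
is_block_step = np.array([False, False, True, False])

# free_and_dispatch_block: a stopped, non-block-step sequence frees its slot
for tid in range(bsz):
    if stop_flags[tid] and not is_block_step[tid]:
        first_token_ids[tid] = -1  # mark the freed slot

# recover_block: each recovered sequence restores its own first prompt token
recover_ids = [3]  # hypothetical indices of sequences being recovered
for rid in recover_ids:
    input_ids[rid, 0] = first_token_ids[rid]  # was: a shared first_token_id

print(first_token_ids)   # [ -1 102 103 104]
print(input_ids[3, 0])   # 104
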
40 changes: 28 additions & 12 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -92,10 +92,12 @@ def __init__(
* attn_factor
)

cache = self._compute_cos_sin_cache()
cos_cache, sin_cache = self._compute_cos_sin_cache()

self.cos_sin_cache: paddle.Tensor
self.register_buffer("cos_sin_cache", cache, persistable=True)
self.cos_cache: paddle.Tensor
self.register_buffer("cos_cache", cos_cache, persistable=True)
self.sin_cache: paddle.Tensor
self.register_buffer("sin_cache", sin_cache, persistable=True)

def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
pos_freqs = self.base ** (paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim)
@@ -114,23 +116,37 @@ def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
def _compute_cos_sin_cache(self) -> paddle.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32)
freqs = paddle.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = paddle.concat((cos, sin), axis=-1)
return cache.cast(self._dtype)

freqs = paddle.outer(t, inv_freq)
emb = paddle.concat((freqs, freqs), axis=-1)
cos = emb.cos() * self.mscale
sin = emb.sin() * self.mscale

return cos.cast(self._dtype), sin.cast(self._dtype)

def forward(
self,
position_ids: paddle.Tensor,
query: paddle.Tensor,
key: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
from paddlenlp_ops import fused_rotary_position_encoding
cos = self.cos_cache[position_ids].unsqueeze(1)
sin = self.sin_cache[position_ids].unsqueeze(1)

def rotate_half(x):
"""Rotates half the hidden axiss of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x

s, h, d = query.shape
query = query.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])

s, h, d = key.shape
key = key.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])

# In-place operations that update the query and key tensors.
os.environ["stride_in_no_check_dy2st_diff"] = "1"
fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)
query = (query * cos) + (rotate_half(query) * sin)
key = (key * cos) + (rotate_half(key) * sin)

return query, key

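The deepseek_v2/modeling.py change above replaces the fused cos_sin_cache buffer with separate cos_cache and sin_cache buffers (built via paddle.outer and the duplicated (freqs, freqs) layout), and forward() now applies the rotation with plain Paddle ops through rotate_half instead of the fused_rotary_position_encoding custom op. The reshape/transpose on query and key reorders the last dimension from interleaved pairs into the half-split layout that rotate_half expects. A minimal sketch of that reordering, for illustration only (not code from the commit):

import paddle

# illustration of the layout permutation used in forward() above:
# interleaved pairs [x0, y0, x1, y1, x2, y2] -> half-split [x0, x1, x2, y0, y1, y2]
s, h, d = 1, 1, 6
q = paddle.arange(s * h * d, dtype="float32").reshape([s, h, d])  # [0, 1, 2, 3, 4, 5]
q_perm = q.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
print(q_perm.numpy())  # [[[0. 2. 4. 1. 3. 5.]]]
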
2 changes: 1 addition & 1 deletion (file path not shown)
@@ -1160,7 +1160,7 @@ def compute_fused_moe(self, tmp_out, i):
def get_moe_scores(
gating_output: paddle.Tensor,
config: MoeConfig,
) -> (paddle.Tensor, paddle.Tensor):
) -> tuple[paddle.Tensor, paddle.Tensor]:

num_token = gating_output.shape[0]
num_expert_group = config.num_expert_group
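In the last hunk, the return annotation -> (paddle.Tensor, paddle.Tensor) evaluated to a plain tuple of classes rather than a type, which static type checkers flag as invalid; tuple[paddle.Tensor, paddle.Tensor] (the builtin generic, Python 3.9+) is the standard spelling, with typing.Tuple[...] as the pre-3.9 equivalent. A small self-contained illustration with hypothetical function names (not from the commit):

from typing import Tuple

import paddle

# hypothetical helpers for illustration only
def top_scores(x: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor]:
    # builtin generic spelling, Python 3.9+
    return x.max(axis=-1), x.argmax(axis=-1)

def top_scores_legacy(x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
    # typing.Tuple spelling for older interpreters
    return x.max(axis=-1), x.argmax(axis=-1)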
