Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix chatglm2 multi-turn streamchat #8867

Merged
merged 1 commit into from
Sep 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/llm/src/bigdl/llm/transformers/models/chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def chatglm2_attention_forward_8eb45c(
# adjust key and value for inference
if kv_cache is not None:
cache_k, cache_v = kv_cache
past_length = cache_k.size(2)
past_length = cache_k.size(0)

if past_length + cur_length > self.max_cache_length:
self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
Expand All @@ -159,8 +159,8 @@ def chatglm2_attention_forward_8eb45c(
self.max_cache_length,
self.hidden_size_per_attention_head,
device=device))
self.kv_cache[0][:, :, :past_length, :] = cache_k
self.kv_cache[1][:, :, :past_length, :] = cache_v
self.kv_cache[0][:, :, :past_length, :] = cache_k.permute(1, 2, 0, 3)
self.kv_cache[1][:, :, :past_length, :] = cache_v.permute(1, 2, 0, 3)
self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer

Expand Down Expand Up @@ -196,7 +196,7 @@ def chatglm2_attention_forward_8eb45c(

output = self.dense(context_layer)

return output, kv_cache
return output, (key_layer.permute(2, 0, 1, 3), value_layer.permute(2, 0, 1, 3))


def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
Expand Down Expand Up @@ -228,6 +228,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
else:
if attention_mask is not None:
attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )

if torch.is_autocast_cpu_enabled():
query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
Expand Down