diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3c7523684a2504..43d5c6faef86ed 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -699,7 +699,7 @@ def forward(
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
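
For context on the change: `get_seq_length(layer_idx)` reports how many tokens are currently cached, while `get_usable_length(kv_seq_len, layer_idx)` additionally accounts for bounded caches that will evict old entries once the incoming tokens are appended, so the RoPE position offset stays consistent with what actually remains in the cache. Below is a minimal sketch of that distinction, assuming the `Cache` interface from `cache_utils.py` around the time of this diff (`get_seq_length`, `get_max_length`, `get_usable_length`); the method bodies are illustrative, not the library source.

```python
# Sketch of the Cache-length semantics behind this diff. Method names follow
# transformers' cache_utils.py at the time; treat the bodies as illustration.
from typing import Optional


class SketchCache:
    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Number of tokens currently cached for `layer_idx`."""
        raise NotImplementedError

    def get_max_length(self) -> Optional[int]:
        """Maximum cache capacity, or None for an unbounded cache."""
        raise NotImplementedError

    def get_usable_length(self, new_seq_length: int, layer_idx: int = 0) -> int:
        """Cached length still valid after `new_seq_length` new tokens land."""
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            # A bounded cache (e.g. SinkCache) evicts old entries to make room,
            # so only part of the previous length remains usable.
            return max_length - new_seq_length
        return previous_seq_length
```

For an unbounded `DynamicCache`, `get_max_length()` returns `None` and the two calls agree, so this change only affects position computation for capacity-limited caches that drop entries.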