Skip to content

Commit

Permalink
fix llama rotary embedding issue (huggingface#1459)
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada authored and BenjaminBossan committed Mar 14, 2024
1 parent 47b63fb commit b68c303
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/peft/tuners/adaption_prompt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

import torch
import torch.nn as nn
Expand Down Expand Up @@ -78,6 +79,28 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
else:
# since transformers 4.36, this is a DynamicCache instance
seq_len += past_key_value.get_seq_length(model.layer_idx)

# For transformers > 4.37.2, `position_ids` became a required argument in the
# rotary embedding's forward pass, and cos/sin are indexed through the
# `rotary_emb` forward pass.
if "position_ids" in list(inspect.signature(model.rotary_emb.forward).parameters):
if position_ids is None:
kv_seq_len = value_states.shape[-2]
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, model.layer_idx)
kv_seq_len += past_seen_tokens

new_cache_positions = torch.arange(
past_seen_tokens, past_seen_tokens + q_len, device=value_states.device
)
position_ids = new_cache_positions.unsqueeze(0)

cos, sin = model.rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)

# Here cos and sin are already indexed correctly; therefore, to avoid adding
# boilerplate in `llama_apply_rotary_pos_emb`, we just return the correct
# query-state embeddings here.
return (query_states * cos) + (llama_rotate_half(query_states) * sin)

cos, sin = model.rotary_emb(value_states, seq_len=seq_len)

return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
Expand Down

0 comments on commit b68c303

Please sign in to comment.