diff --git a/src/peft/tuners/adaption_prompt/utils.py b/src/peft/tuners/adaption_prompt/utils.py
index aadde6680c..d7a7eec5a8 100644
--- a/src/peft/tuners/adaption_prompt/utils.py
+++ b/src/peft/tuners/adaption_prompt/utils.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 
 import torch
 import torch.nn as nn
@@ -78,6 +79,28 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
     else:
         # since transformers 4.36, this is a DynamicCache instance
         seq_len += past_key_value.get_seq_length(model.layer_idx)
+
+    # For transformers > 4.37.2, `position_ids` became a required argument in the
+    # rotary embedding's forward pass, and cos/sin are indexed through the
+    # `rotary_emb` forward pass.
+    if "position_ids" in list(inspect.signature(model.rotary_emb.forward).parameters):
+        if position_ids is None:
+            kv_seq_len = value_states.shape[-2]
+            past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, model.layer_idx)
+            kv_seq_len += past_seen_tokens
+
+            new_cache_positions = torch.arange(
+                past_seen_tokens, past_seen_tokens + q_len, device=value_states.device
+            )
+            position_ids = new_cache_positions.unsqueeze(0)
+
+        cos, sin = model.rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)
+
+        # Here, cos and sin are already indexed correctly, so to avoid adding boilerplate
+        # in `llama_apply_rotary_pos_emb` we directly return the query states with the
+        # rotary embeddings applied.
+        return (query_states * cos) + (llama_rotate_half(query_states) * sin)
+
     cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
 
     return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
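
For context, here is a minimal, self-contained sketch of the two ideas the new hunk relies on: feature-detecting the rotary embedding API via `inspect.signature`, and applying rotary position embeddings directly as `(q * cos) + (rotate_half(q) * sin)`. All names below (`toy_rotary_old`, `toy_rotary_new`, `apply_rope`, `compute_rotated_queries`) are hypothetical stand-ins for illustration, not PEFT or transformers APIs, and the toy cos/sin values are not real rotary frequencies.

```python
import inspect

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the last dimension, analogous to
    # `llama_rotate_half` in the patched module.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # The same elementwise form the new code path in the patch returns.
    return (q * cos) + (rotate_half(q) * sin)


def toy_rotary_old(value_states, seq_len=None):
    # Stand-in for the older rotary forward: cos/sin cover the whole sequence
    # and are indexed by position afterwards.
    dim = value_states.shape[-1]
    t = torch.arange(seq_len, dtype=torch.float32)
    return torch.cos(t)[:, None].expand(seq_len, dim), torch.sin(t)[:, None].expand(seq_len, dim)


def toy_rotary_new(value_states, position_ids):
    # Stand-in for the newer forward: cos/sin come back already gathered per position.
    dim = value_states.shape[-1]
    t = position_ids.float()
    return torch.cos(t)[..., None].expand(*t.shape, dim), torch.sin(t)[..., None].expand(*t.shape, dim)


def compute_rotated_queries(rotary_forward, query_states, value_states, position_ids, seq_len):
    # Feature-detect the rotary API the same way the patch does.
    if "position_ids" in inspect.signature(rotary_forward).parameters:
        cos, sin = rotary_forward(value_states, position_ids=position_ids)
    else:
        cos, sin = rotary_forward(value_states, seq_len=seq_len)
        # Older path: index the full-sequence cos/sin by position, which is what
        # `llama_apply_rotary_pos_emb` does in the patched module.
        cos, sin = cos[position_ids], sin[position_ids]
    # Broadcast over the head dimension of (bsz, num_heads, q_len, head_dim) queries.
    return apply_rope(query_states, cos.unsqueeze(1), sin.unsqueeze(1))


if __name__ == "__main__":
    bsz, num_heads, q_len, head_dim = 2, 4, 5, 8
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    v = torch.randn(bsz, num_heads, q_len, head_dim)
    position_ids = torch.arange(q_len).unsqueeze(0).expand(bsz, q_len)
    out_old = compute_rotated_queries(toy_rotary_old, q, v, position_ids, seq_len=q_len)
    out_new = compute_rotated_queries(toy_rotary_new, q, v, position_ids, seq_len=q_len)
    # Both API shapes should yield the same rotated queries for the same positions.
    assert torch.allclose(out_old, out_new)
```

The signature check keeps a single code path compatible with both transformers versions, which mirrors why the patch inspects `model.rotary_emb.forward` instead of pinning a version.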