Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Adaptation Prompt] Fix llama rotary embedding issue with transformers main #1459

Merged
merged 1 commit into from
Feb 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/peft/tuners/adaption_prompt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

import torch
import torch.nn as nn
Expand Down Expand Up @@ -78,6 +79,28 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
else:
# since transformers 4.36, this is a DynamicCache instance
seq_len += past_key_value.get_seq_length(model.layer_idx)

# For transformers > 4.37.2 `position_ids` became a required arguments in the
# rotary embedding's forward pass. and cos/sin are indexed through the
# `rotary_emb` forward pass.
if "position_ids" in list(inspect.signature(model.rotary_emb.forward).parameters):
if position_ids is None:
kv_seq_len = value_states.shape[-2]
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, model.layer_idx)
kv_seq_len += past_seen_tokens

new_cache_positions = torch.arange(
past_seen_tokens, past_seen_tokens + q_len, device=value_states.device
)
position_ids = new_cache_positions.unsqueeze(0)

cos, sin = model.rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)

# Here cos and sin are are already indexed correctly, therefore to avoid adding
# boilerplate in `llama_apply_rotary_pos_emb` we just return here the correct query states
# embeddings
return (query_states * cos) + (llama_rotate_half(query_states) * sin)

cos, sin = model.rotary_emb(value_states, seq_len=seq_len)

return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
Expand Down
Loading