FIX Prefix tuning Grouped-Query Attention (#1901)
Fix prefix tuning when GQA is being used.
ttw1018 authored Jul 22, 2024
1 parent e02b938 commit 6472061
Showing 2 changed files with 26 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/peft/utils/other.py
@@ -397,6 +397,12 @@ def _prepare_prompt_learning_config(peft_config, model_config):
             raise ValueError("Please specify `num_attention_heads` in `peft_config`")
         peft_config.num_attention_heads = num_attention_heads
 
+    # For grouped-query attention, see #1901.
+    if peft_config.peft_type == "PREFIX_TUNING" and "num_key_value_heads" in model_config:
+        num_key_value_heads = model_config["num_key_value_heads"]
+        peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads
+        peft_config.num_attention_heads = num_key_value_heads
+
     if getattr(peft_config, "encoder_hidden_size", None) is None:
         setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)

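Why the rescaling is needed, in a nutshell: with grouped-query attention the key/value projections produce num_key_value_heads * head_dim features per token instead of num_attention_heads * head_dim, and prefix tuning injects its learned prefix directly into the past key/values. A minimal arithmetic sketch of the adjustment above, using made-up config values (not taken from the commit or any real model):

    # Hypothetical GQA config, for illustration only.
    model_config = {
        "hidden_size": 2048,
        "num_attention_heads": 16,  # query heads
        "num_key_value_heads": 4,   # key/value heads shared across query groups
    }

    head_dim = model_config["hidden_size"] // model_config["num_attention_heads"]  # 128

    # Pre-fix, prefix tuning sized its past key/values with hidden_size = 2048 features
    # per virtual token; a GQA attention layer only expects num_key_value_heads * head_dim.
    token_dim = model_config["hidden_size"] // model_config["num_attention_heads"] * model_config["num_key_value_heads"]
    num_heads = model_config["num_key_value_heads"]

    assert token_dim == num_heads * head_dim == 512  # what the GQA key/value cache actually holds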
21 changes: 20 additions & 1 deletion tests/test_decoder_models.py
@@ -19,7 +19,16 @@
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from peft import AdaLoraConfig, BOFTConfig, HRAConfig, LoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model
+from peft import (
+    AdaLoraConfig,
+    BOFTConfig,
+    HRAConfig,
+    LoraConfig,
+    PrefixTuningConfig,
+    PromptTuningConfig,
+    PromptTuningInit,
+    get_peft_model,
+)
 
 from .testing_common import PeftCommonTester, PeftTestConfigManager

@@ -428,3 +437,13 @@ def test_lora_layer_replication(self):
), "Expected 8 LoRA adapters since we are adding one each for up and down."
self._test_prepare_for_training(model_id, LoraConfig, config_kwargs)
self._test_generate(model_id, LoraConfig, config_kwargs)

def test_prompt_learning_with_grouped_query_attention(self):
# See 1901, fixes a bug with handling GQA
model_id = "peft-internal-testing/tiny-dummy-qwen2"
base_model = AutoModelForCausalLM.from_pretrained(model_id)
peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
model = get_peft_model(base_model, peft_config)
x = torch.tensor([[1, 2, 3]])
# does not raise
model(x)
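
For reference, a usage sketch along the lines of the new test (not part of the commit; assumes peft and transformers are installed and that the tiny GQA test model is reachable). Before this fix, the forward pass below failed with a shape mismatch once the learned prefix was attached as past key/values:

    import torch
    from transformers import AutoModelForCausalLM
    from peft import PrefixTuningConfig, get_peft_model

    model_id = "peft-internal-testing/tiny-dummy-qwen2"  # tiny Qwen2 model that uses GQA
    base_model = AutoModelForCausalLM.from_pretrained(model_id)

    peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
    model = get_peft_model(base_model, peft_config)
    model.print_trainable_parameters()  # only the prefix encoder is trainable

    input_ids = torch.tensor([[1, 2, 3]])
    with torch.no_grad():
        outputs = model(input_ids)  # raised a shape-mismatch error on GQA models before this commit
    print(outputs.logits.shape)  # (1, 3, vocab_size); the virtual tokens only affect keys/values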
