Skip to content

Commit

Permalink
[Gemma] Fix eager attention (#29187)
Browse files Browse the repository at this point in the history
* fix modelling code

* add tests

* fix tests

* add some logit tests

* style

* fix fix
  • Loading branch information
sanchit-gandhi authored Feb 22, 2024
1 parent fc37f38 commit 2a9b1f8
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def forward(

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)

if not output_attentions:
Expand Down
129 changes: 129 additions & 0 deletions tests/models/gemma/test_modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
Expand Down Expand Up @@ -460,6 +461,71 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_padding_right(self):
self.skipTest("Gemma flash attention does not support right padding")

@require_torch_sdpa
@require_torch_gpu
@slow
def test_sdpa_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
return

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa"
)
model_sdpa.to(torch_device)

model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)

dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)

logits = outputs.hidden_states[-1]
logits_sdpa = outputs_sdpa.hidden_states[-1]

# gemma sdpa needs a high tolerance
assert torch.allclose(logits_sdpa, logits, atol=3e-3)

@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)

model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)

dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)

logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]

# gemma flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=3e-3)


@require_torch_gpu
@slow
Expand Down Expand Up @@ -542,6 +608,69 @@ def test_model_2b_bf16(self):

self.assertEqual(output_text, EXPECTED_TEXTS)

def test_model_2b_eager(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]

model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

self.assertEqual(output_text, EXPECTED_TEXTS)

@require_torch_sdpa
def test_model_2b_sdpa(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
]

model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
)
model.to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

self.assertEqual(output_text, EXPECTED_TEXTS)

@pytest.mark.flash_attn_test
@require_flash_attn
def test_model_2b_flash_attn(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]

model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

self.assertEqual(output_text, EXPECTED_TEXTS)

@require_bitsandbytes
def test_model_2b_4bit(self):
model_id = "google/gemma-2b"
Expand Down

0 comments on commit 2a9b1f8

Please sign in to comment.