Fix falcon with SDPA, alibi but no passed mask (#30123)
* fix falcon without attention_mask & alibi

* add test

* Update tests/models/falcon/test_modeling_falcon.py

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
fxmarty and amyeroberts authored Apr 8, 2024
1 parent 1773afc · commit 1897874
Showing 2 changed files with 35 additions and 15 deletions.
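
The gist of the fix: with alibi enabled under SDPA and no attention_mask passed, the old branch handed scaled_dot_product_attention a bare alibi bias with no causal masking; the new code always folds the prepared 4D causal mask into the alibi bias. A toy sketch of the fixed combination step follows; the shapes, slope values, and variable names are illustrative assumptions, not the transformers implementation.

# Toy sketch of masking an ALiBi bias with a 4D causal mask when attention_mask is None.
import torch

batch, heads, seq = 1, 2, 4
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Stand-in for the 4D causal mask returned by a helper like _prepare_4d_causal_attention_mask:
# 0.0 where attention is allowed, the dtype minimum where it is not.
causal = torch.triu(torch.full((seq, seq), min_dtype, dtype=dtype), diagonal=1)
mask_4d = causal.expand(batch, 1, seq, seq)

# Toy ALiBi bias: one slope per head, linear in key position, broadcast to 4D.
# (The real code also scales alibi by 1/sqrt(head_dim) before this step.)
slopes = torch.tensor([0.5, 0.25]).view(1, heads, 1, 1)
alibi = (slopes * torch.arange(seq, dtype=dtype).view(1, 1, 1, seq)).expand(batch, heads, seq, seq)

# The fix: always overwrite masked positions with the dtype minimum, instead of
# handing SDPA a bare alibi bias with no causal structure when no mask is passed.
attn_bias = torch.masked_fill(alibi, mask_4d < -1, min_dtype)

print(attn_bias[0, 0])  # upper triangle is min_dtype, so future tokens stay masked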
src/transformers/models/falcon/modeling_falcon.py (26 changes: 11 additions & 15 deletions)
@@ -1098,27 +1098,23 @@ def forward(
             elif head_mask is None:
                 alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
 
-                attention_mask_2d = attention_mask
                 # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
                 attention_mask = _prepare_4d_causal_attention_mask(
                     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
                 )
 
                 # We take care to integrate alibi bias in the attention_mask here.
-                if attention_mask_2d is None:
-                    attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
-                else:
-                    min_dtype = torch.finfo(alibi.dtype).min
-                    attention_mask = torch.masked_fill(
-                        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
-                        attention_mask < -1,
-                        min_dtype,
-                    )
-
-                    # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
-                    # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-                    if seq_length > 1 and attention_mask.device.type == "cuda":
-                        attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
+                min_dtype = torch.finfo(alibi.dtype).min
+                attention_mask = torch.masked_fill(
+                    alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+                    attention_mask < -1,
+                    min_dtype,
+                )
+
+                # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+                # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+                if seq_length > 1 and attention_mask.device.type == "cuda":
+                    attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
             else:
                 # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
                 attention_mask = _prepare_4d_causal_attention_mask(
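The comment kept in the new code points at pytorch/pytorch#110213: with the memory-efficient SDPA backend, query rows whose bias masks out every key can come back as NaNs, which is why fully-unattended rows are re-enabled on CUDA. Below is a toy sketch of that idea only; it is not the AttentionMaskConverter._unmask_unattended implementation, and the helper name is made up for illustration.

import torch

def unmask_fully_unattended_rows(bias: torch.Tensor, min_dtype: float) -> torch.Tensor:
    """Toy illustration of the idea behind AttentionMaskConverter._unmask_unattended:
    if a query row masks out every key, zero the whole row so SDPA sees a benign
    (uniform) bias rather than an all-masked row that some backends turn into NaNs."""
    fully_masked = (bias == min_dtype).all(dim=-1, keepdim=True)
    return torch.where(fully_masked, torch.zeros_like(bias), bias)

min_dtype = torch.finfo(torch.float32).min
bias = torch.zeros(1, 1, 2, 2)
bias[0, 0, 0, :] = min_dtype  # first query position attends to nothing (e.g. all padding)
print(unmask_fully_unattended_rows(bias, min_dtype)[0, 0])  # row 0 zeroed, row 1 unchanged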
tests/models/falcon/test_modeling_falcon.py (24 changes: 24 additions & 0 deletions)
@@ -666,3 +666,27 @@ def test_batched_generation(self):
         self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1])  # left-padding exists
         self.assertEqual(unpadded_gen_text[0], expected_output)
         self.assertEqual(padded_gen_text[0], expected_output)
+
+    @slow
+    @require_torch_sdpa
+    def test_falcon_alibi_sdpa_matches_eager(self):
+        input_ids = torch.randint(0, 1000, (5, 20))
+
+        config = FalconConfig(
+            vocab_size=1000,
+            hidden_size=64,
+            num_hidden_layers=3,
+            num_attention_heads=4,
+            new_decoder_architecture=True,
+            alibi=True,
+        )
+
+        falcon = FalconForCausalLM(config)
+        falcon = falcon.eval()
+
+        with torch.no_grad():
+            # output_attentions=True dispatches to eager path
+            falcon_output_eager = falcon(input_ids, output_attentions=True)[0]
+            falcon_output_sdpa = falcon(input_ids)[0]
+
+        self.assertTrue(torch.allclose(falcon_output_eager, falcon_output_sdpa, atol=1e-3))
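
The new test is gated behind @slow and @require_torch_sdpa. Assuming a transformers development checkout, one way to run just this test is sketched below; RUN_SLOW=1 is the usual opt-in for slow tests, and the exact pytest arguments are an assumption about the local setup.

# Run the new regression test in isolation from the repository root (sketch).
import os
import pytest

os.environ["RUN_SLOW"] = "1"  # transformers skips @slow tests unless this is set
raise SystemExit(
    pytest.main(
        [
            "tests/models/falcon/test_modeling_falcon.py",
            "-k", "test_falcon_alibi_sdpa_matches_eager",
            "-v",
        ]
    )
)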
