Add Flash Attention 2 support to Musicgen and Musicgen Melody #29939

Merged · 16 commits · Apr 2, 2024
add copied form in sdpa tests melody
ylacombe committed Apr 2, 2024

commit 1146c92882e09da01fef2ea377379d97240e3958
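
For context: this commit converts MusicGen's overridden SDPA tests to the library's `# Copied from ...` convention, in which the repo's copy checker (`utils/check_copies.py`) keeps the body in sync with the common `ModelTesterMixin` test and each `# Ignore copy` marker exempts a deliberate, MusicGen-specific deviation; the docstrings that used to describe those deviations become redundant and are removed. The core pattern the copied test exercises is toggling PyTorch's SDPA kernels and comparing against the eager/math path. A minimal standalone sketch of that pattern (toy shapes, CUDA and fp16 assumed; this is not the test code itself):

import torch
import torch.nn.functional as F

# Toy tensors in (batch, heads, seq_len, head_dim) layout
q = k = v = torch.randn(2, 4, 8, 16, device="cuda", dtype=torch.float16)

with torch.no_grad():
    # Flash kernel only
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out_flash = F.scaled_dot_product_attention(q, k, v)
    # Math kernel only (the eager-equivalent reference path)
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        out_math = F.scaled_dot_product_attention(q, k, v)

# Kernels differ numerically, so compare within a tolerance, not exactly
assert torch.allclose(out_flash, out_math, atol=1e-3, rtol=1e-3)
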
19 changes: 12 additions & 7 deletions tests/models/musicgen/test_modeling_musicgen.py
@@ -1862,11 +1862,8 @@ def test_flash_attn_2_generate_use_cache(self):
     @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
     @require_torch_sdpa
     @slow
+    # Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
     def test_eager_matches_sdpa_inference(self, torch_dtype: str):
-        """
-        Model requires `decoder_input_ids` but `is_encoder_decoder=False`. Additionnally, decoder's input batch size must be
-        `num_codebooks*batch_size`.
-        """
         if not self.all_model_classes[0]._supports_sdpa:
             self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
 
@@ -1983,7 +1980,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
                             else:
                                 dummy_attention_mask = inputs_dict.get("attention_mask", None)
                                 if dummy_attention_mask is None:
+                                    # Ignore copy
                                     seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
+                                    # Ignore copy
                                     dummy_attention_mask = (
                                         torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
                                     )
@@ -2009,11 +2008,15 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
 
                             for enable_kernels in [False, True]:
                                 failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
+                                # Ignore copy
                                 batch_size_input_ids = self.model_tester.num_codebooks * batch_size
+                                # Ignore copy
                                 decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
                                     :batch_size_input_ids
                                 ]
+                                # Ignore copy
                                 if decoder_input_ids.shape[0] != batch_size_input_ids:
+                                    # Ignore copy
                                     extension = torch.ones(
                                         batch_size_input_ids - decoder_input_ids.shape[0],
                                         *decoder_input_ids.shape[1:],
@@ -2024,13 +2027,15 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
                                 decoder_input_ids = decoder_input_ids.to(torch_device)
 
                                 # TODO: never an `attention_mask` arg here?
+                                # Ignore copy
                                 other_inputs = {
                                     "decoder_input_ids": decoder_input_ids,
                                     "decoder_attention_mask": dummy_attention_mask,
                                     "output_hidden_states": True,
                                 }
 
                                 # TODO: test gradients as well (& for FA2 as well!)
+                                # Ignore copy
                                 with torch.no_grad():
                                     with torch.backends.cuda.sdp_kernel(
                                         enable_flash=enable_kernels,
@@ -2109,14 +2114,14 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
 
         self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
 
+
     @require_torch_sdpa
     @slow
+    # Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_generate
     def test_eager_matches_sdpa_generate(self):
-        """
-        Overwrite generative model classes with greedy sample model classes.
-        """
         max_new_tokens = 30
 
+        # Ignore copy
         for model_class in self.greedy_sample_model_classes:
             if not model_class._supports_sdpa:
                 self.skipTest(f"{model_class.__name__} does not support SDPA")
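
The `# Ignore copy` deviations above all stem from the quirk the removed docstring described: MusicGen reports `is_encoder_decoder=False` yet still requires `decoder_input_ids`, and its decoder flattens the codebooks into the batch axis, so the decoder batch must be `num_codebooks * batch_size`. A toy illustration of the padding step, mirroring the `torch.ones` extension in the hunk above (shapes and pad value are hypothetical):

import torch

# Hypothetical shapes: 2 samples, 4 codebooks, 5 decoder steps
batch_size, num_codebooks, seq_len = 2, 4, 5
decoder_input_ids = torch.zeros(3, seq_len, dtype=torch.long)  # too few rows

# MusicGen's decoder expects batch_size * num_codebooks rows
batch_size_input_ids = num_codebooks * batch_size
if decoder_input_ids.shape[0] != batch_size_input_ids:
    extension = torch.ones(
        batch_size_input_ids - decoder_input_ids.shape[0],
        *decoder_input_ids.shape[1:],
        dtype=decoder_input_ids.dtype,
    )
    decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)

assert decoder_input_ids.shape == (batch_size_input_ids, seq_len)  # (8, 5)
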
18 changes: 11 additions & 7 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -1818,11 +1818,8 @@ def test_flash_attn_2_generate_use_cache(self):
     @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
     @require_torch_sdpa
     @slow
+    # Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
     def test_eager_matches_sdpa_inference(self, torch_dtype: str):
-        """
-        Model requires `decoder_input_ids` but `is_encoder_decoder=False`. Additionnally, decoder's input batch size must be
-        `num_codebooks*batch_size`.
-        """
         if not self.all_model_classes[0]._supports_sdpa:
             self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
 
@@ -1939,7 +1936,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
                             else:
                                 dummy_attention_mask = inputs_dict.get("attention_mask", None)
                                 if dummy_attention_mask is None:
+                                    # Ignore copy
                                     seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
+                                    # Ignore copy
                                     dummy_attention_mask = (
                                         torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
                                     )
@@ -1965,11 +1964,15 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
 
                             for enable_kernels in [False, True]:
                                 failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
+                                # Ignore copy
                                 batch_size_input_ids = self.model_tester.num_codebooks * batch_size
+                                # Ignore copy
                                 decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
                                     :batch_size_input_ids
                                 ]
+                                # Ignore copy
                                 if decoder_input_ids.shape[0] != batch_size_input_ids:
+                                    # Ignore copy
                                     extension = torch.ones(
                                         batch_size_input_ids - decoder_input_ids.shape[0],
                                         *decoder_input_ids.shape[1:],
@@ -1980,13 +1983,15 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
                                 decoder_input_ids = decoder_input_ids.to(torch_device)
 
                                 # TODO: never an `attention_mask` arg here?
+                                # Ignore copy
                                 other_inputs = {
                                     "decoder_input_ids": decoder_input_ids,
                                     "decoder_attention_mask": dummy_attention_mask,
                                     "output_hidden_states": True,
                                 }
 
                                 # TODO: test gradients as well (& for FA2 as well!)
+                                # Ignore copy
                                 with torch.no_grad():
                                     with torch.backends.cuda.sdp_kernel(
                                         enable_flash=enable_kernels,
@@ -2067,12 +2072,11 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
 
     @require_torch_sdpa
     @slow
+    # Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_generate
     def test_eager_matches_sdpa_generate(self):
-        """
-        Overwrite generative model classes with greedy sample model classes.
-        """
         max_new_tokens = 30
 
+        # Ignore copy
         for model_class in self.greedy_sample_model_classes:
             if not model_class._supports_sdpa:
                 self.skipTest(f"{model_class.__name__} does not support SDPA")
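
Taken together with the earlier commits in this PR, the end-user effect is that both models accept an explicit attention backend at load time. A usage sketch (checkpoint name and generation settings are illustrative; "flash_attention_2" additionally requires the flash-attn package and a supported GPU):

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # or "flash_attention_2" / "eager"
).to("cuda")

inputs = processor(text=["80s pop track with punchy drums"], return_tensors="pt").to("cuda")
audio_values = model.generate(**inputs, max_new_tokens=256)

Musicgen Melody follows the same pattern via `MusicgenMelodyForConditionalGeneration`.
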