
Tests: move generate tests to the right mixin and delete redundant tests #34464

Merged: 17 commits merged into huggingface:main from move_generation_tests on Oct 30, 2024

Conversation

@gante (Member) commented Oct 28, 2024

Moves generate tests incorrectly placed in the general mixin to GenerationTesterMixin. In the process, removes redundant tests and streamlines repeated logic 👀

⚠️ reviewer: start by reviewing test_modeling_utils.py, the last file in the diff. In it, I explain what happened to each test and why -- most were actually deleted, as they were redundant.

After this PR:
✅ fewer redundant tests and less code duplication
✅ fewer flaky tests
✅ faster tests
✅ more test coverage

In a follow-up PR:
👉 Fix failing upgraded generate+FA2 test
👉 Fix failing added generate+torch.compile test

Closes #32913

@@ -378,10 +378,14 @@ def prepare_inputs_for_generation(
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
gante (Member Author):

(changes in this file were done for end-to-end compilation)

Comment on lines 140 to 146
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
"""
Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
the following situations:
1. The sequences are the same
2. The sequences are different, but the scores up until the first mismatch are nearly identical
"""
gante (Member Author):

Many generate tests wanted to check equivalence of something, and were incorrectly tagged as flaky. This function contains the correct equivalence check -- as a result, several @is_flaky were removed :D

(I honestly don't know why I haven't implemented this before, instead of adding @is_flaky 🙃 )
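
In pseudocode, the check boils down to something like this (a simplified sketch only; it assumes the dict outputs of `generate`, i.e. objects exposing `.sequences` and `.scores`, and the real helper handles more edge cases):

```python
import torch

def check_similar_generate_outputs(output_1, output_2, atol=1e-5, rtol=1e-5):
    # Case 1: the generated sequences match exactly -> nothing else to check.
    if torch.equal(output_1.sequences, output_2.sequences):
        return
    # Case 2: the sequences diverge at some position. The scores up to (and
    # including) the first mismatch must be nearly identical: a tiny numerical
    # difference merely flipped the argmax, which is not a real failure.
    prompt_len = output_1.sequences.shape[1] - len(output_1.scores)
    diverging_cols = (output_1.sequences != output_2.sequences).any(dim=0)
    first_mismatch = int(diverging_cols.nonzero()[0]) - prompt_len
    for step in range(first_mismatch + 1):
        torch.testing.assert_close(
            output_1.scores[step], output_2.scores[step], atol=atol, rtol=rtol
        )
```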

@@ -453,13 +481,12 @@ def test_greedy_generate_dict_outputs(self):
# Retrocompatibility check
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)

self._check_outputs(output_generate, main_input, model.config)
self._check_outputs(output_generate, model.config)
gante (Member Author):

We actually don't need main_input: we can obtain the batch size directly from the outputs of generate.
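
Roughly, the batch size can be read straight from the returned object, for example (illustrative only, not the exact lookup inside `_check_outputs`):

```python
# `generate` dict outputs always carry the generated ids, so the batch size is
# just the leading dimension of `sequences`, for decoder-only and
# encoder-decoder outputs alike.
batch_size = output_generate.sequences.shape[0]
```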

Comment on lines 1570 to 1571
def test_generate_from_inputs_embeds(self):
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
gante (Member Author):

Related to the deleted inputs_embeds tests -- clarifies that this test also runs for VLMs

Comment on lines +1850 to +1851
Tests that generating with static cache gives almost the same results as with dynamic cache, and that the output cache
has the expected shapes
gante (Member Author):

Merges this existing test, which checked the static cache returned by generate, with the deleted static cache equivalence test.

One generate call is enough to run both checks :)
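
Schematically, the merged test boils down to something like this (a sketch only; `model` and `inputs` stand for whatever the mixin prepares, and the real test inspects the cache shapes in detail):

```python
from transformers.cache_utils import StaticCache

def check_static_cache_generate(tester, model, inputs):
    gen_kwargs = dict(max_new_tokens=4, do_sample=False,
                      return_dict_in_generate=True, output_scores=True)
    dynamic_out = model.generate(**inputs, **gen_kwargs)
    static_out = model.generate(**inputs, **gen_kwargs, cache_implementation="static")
    # Check 1: static and dynamic caches produce (nearly) the same results
    tester._check_similar_generate_outputs(dynamic_out, static_out)
    # Check 2: the cache returned by generate is a StaticCache (with the expected shapes)
    tester.assertIsInstance(static_out.past_key_values, StaticCache)
```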

@ydshieh (Collaborator) commented Oct 28, 2024

For tests where is_flaky is removed, you ran it with --flake-finder --flake-runs 500 right? (at least for some models ?)

("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
]
)
def test_generate_compile(self, name, end_to_end):
gante (Member Author):

Instead of having a test for torch.compile against large checkpoints, let's do it against dummy checkpoints -- the importance of the test is to confirm a) we can compile AND b) compiled is equivalent to non-compiled.

Testing model.forward compilation is very similar to testing model.generate compilation, so parameterized is used. A few modifications were added to make the test more efficient (e.g. compile only once)

Note: this test doesn't have to be slow, but we currently have a few failures. Let's fix them in a follow-up PR
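
A minimal sketch of the idea behind the two parameterizations (names and helpers here are illustrative, not the exact new test):

```python
import torch

def check_compile_equivalence(model, inputs, end_to_end: bool):
    gen_kwargs = dict(do_sample=False, max_new_tokens=4)
    eager_ids = model.generate(**inputs, **gen_kwargs)
    if end_to_end:
        # Parameterization 1: compile the whole `generate` call.
        compiled_ids = torch.compile(model.generate)(**inputs, **gen_kwargs)
    else:
        # Parameterization 2: compile only the forward pass, then call `generate`.
        model.forward = torch.compile(model.forward)
        compiled_ids = model.generate(**inputs, **gen_kwargs)
    # a) compilation worked; b) compiled output matches the eager output (the real
    # test uses the score-based similarity check rather than exact equality).
    assert torch.equal(eager_ids, compiled_ids)
```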

@@ -2048,14 +2063,21 @@ def test_inherits_generation_mixin(self):
for model_class in self.all_generative_model_classes:
self.assertTrue("GenerationMixin" in str(model_class.__bases__))

@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
gante (Member Author):

The test to check that SDPA = eager is the same as the test to check that FA2 = eager. However, parameterized doesn't work here: we want different pytest.mark decorators for each test. As such, the original test was moved to a helper function.
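
Schematically, the decorator placement is the point (the helper name below is an assumption, not necessarily the one used in the PR):

```python
import pytest
from transformers.testing_utils import (
    require_flash_attn, require_torch_gpu, require_torch_sdpa, slow
)

class AttnBackendGenerateTests:
    def _eager_matches_backend_generate(self, attn_implementation):
        # Shared body (sketch): load the model with `attn_implementation` and with
        # "eager", run `generate` on both, and compare with the similarity helper.
        ...

    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        self._eager_matches_backend_generate("sdpa")

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_eager_matches_fa2_generate(self):
        self._eager_matches_backend_generate("flash_attention_2")
```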

@@ -3000,71 +2995,6 @@ def test_inputs_embeds_matches_input_ids(self):
)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))

def test_inputs_embeds_matches_input_ids_with_generate(self):
gante (Member Author):

We already had a test that checks that generating from inputs_embeds is equivalent to generating from input_ids -- test_generate_from_inputs_embeds

This test was deleted.

zucchini-nlp (Member):

one q: afaik the other test is skipped for encoder-decoder models? Can we enable it just to be sure we cover all models with generation abilities?

gante (Member Author):

Good point, thank you for catching that! I'll update test_generate_from_inputs_embeds so as to support VLMs

gante (Member Author):

test_generate_from_inputs_embeds updated for VLMs, and the model-specific overwrites were deleted. @zucchini-nlp have a look at the updated test -- you might need to add VLMs to the list of added special cases in the test

@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_generate_left_padding(self):
gante (Member Author) commented Oct 28, 2024:

We have:

  • a test that checks that we can generate with left-padding (test_left_padding_compatibility)
  • a test that checks that we can generate with FA2 (test_eager_matches_fa2_generate, added in this PR)
  • a test that checks that we can run the forward pass in FA2 with left-padding (test_flash_attn_2_inference_equivalence)

As such, this test was redundant. From generate's point of view, FA2 is a modification to model.forward, so if model.forward is equivalent then so is model.generate

Deleted.

@mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_generate_padding_right(self):
gante (Member Author) commented Oct 28, 2024:

Similar to the above (we have individual checks in place). Moreover, generating with right-padding doesn't make sense on most models :D

Deleted.

@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
gante (Member Author):

This test was pretty toothless -- it didn't even check the actual output of generate.

I've added test_eager_matches_fa2_generate in its place, which checks the outputs. It's a clone of the SDPA generate equivalence test, except that it uses FA2.

@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_reuse_cache(self):
gante (Member Author):

Similar to the FA2 + padding tests above.

In this case, we check that we can reuse a cache in test_generate_continue_from_past_key_values, no need for a FA2-specific test.

Deleted.
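
For context, the kind of check test_generate_continue_from_past_key_values performs looks roughly like this (simplified sketch, greedy decoding assumed; the real test does more):

```python
import torch

def check_continue_from_cache(model, input_ids):
    # Generate a few tokens and keep the returned cache.
    first = model.generate(input_ids, max_new_tokens=3, do_sample=False,
                           return_dict_in_generate=True, use_cache=True)
    # Continue generation by feeding back the sequence so far plus its cache.
    continued = model.generate(first.sequences, past_key_values=first.past_key_values,
                               max_new_tokens=3, do_sample=False)
    # With greedy decoding this should match a single 6-token generation.
    full = model.generate(input_ids, max_new_tokens=6, do_sample=False)
    assert torch.equal(continued, full)
```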

@@ -4999,82 +4718,6 @@ def test_custom_4d_attention_mask(self):
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

def test_static_cache_matches_dynamic(self):
gante (Member Author):

Merged the checks in this test [check that the outputs are the same] with the checks in test_generate_with_static_cache [check that the shapes are as expected]

@slow
@require_torch_accelerator
@require_read_token
def test_torch_compile(self):
gante (Member Author) commented Oct 28, 2024:

This test was running against very large checkpoints, and we had to manually specify one in _torch_compile_test_ckpt 👀

We were indeed missing a mixin test for generate with model.forward compiled, see the new test_generate_compile. The most important part of this test is to confirm that compiled == not compiled, we have checks for other parts of the problem (including checks to see whether large checkpoints are working as expected)

This heavy test was, in practice, deleted

@slow
@require_torch_gpu # Testing cuda graphs.
@require_read_token
def test_compile_cuda_graph_time(self):
gante (Member Author):

Same as above (uses heavy checkpoints)

We now have a benchmarks page, which is a much better way to track performance than a test.

@gante gante requested review from ydshieh and zucchini-nlp October 28, 2024 15:22
@gante gante changed the title Move generation tests Tests: move generate tests to the right mixin and delete redundant tests Oct 28, 2024
@gante (Member Author) commented Oct 28, 2024

For tests where is_flaky is removed, you ran it with --flake-finder --flake-runs 500 right? (at least for some models ?)

@ydshieh yes :D The pattern was actually added in a recent PR: #34386

@zucchini-nlp (Member) left a comment:

Great cleanup, good to see redundant code killed! Left a few questions. Also, I think we had test_embeds_match_ids_generate overwritten in all VLMs, so those can probably be removed now.

Overall LGTM, thanks

EDIT: forgot to mention that a few models might not be running the new tests because they don't have GenerationTesterMixin. I'm currently working on adding it to a few VLMs and noticed that tests from here might fail for them. I guess that's okay; I'll skip/fix the failing ones after enabling generation tests.

Comment on lines -4575 to -4585
# Generate with one batch only to test generation when attention mask will be None
# when real inputs are used, because there is no padding. See issue #32237 for more
dummy_input = dummy_input[:1, ...]
dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
zucchini-nlp (Member):

This was testing that FA2 with packing works in generate, and unfortunately we don't have an equivalent forward test. We could extend the test_fa2_position_ids test to cover this specific case.

gante (Member Author) commented Oct 29, 2024:

Makes sense, I'll make sure there is one packing FA2 test (testing the forward pass, and not generate -- generate does not support packing)

gante (Member Author):

@zucchini-nlp actually test_flash_attn_2_fp32_ln covers the case without attention mask, no need to add more test cases

@pytest.mark.generate
@require_torch_gpu
@slow
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
def test_generate_compile_fullgraph(self):
def test_generate_compile(self, _, end_to_end):
ydshieh (Collaborator):

RUN_SLOW=1 python3 -m pytest -v tests/models/llama/test_modeling_llama.py -k "test_generate_compile"

gives me

ling_llama.py::LlamaModelTest::test_generate_compile_1_end_to_end - RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 38, ...

gante (Member Author) commented Oct 29, 2024:

@ydshieh have a look at the TODO in the @parameterized of this test -- both parameterizations have broken cases, whose cause comes from outside this PR.

This PR is already quite big, I'd rather fix the root causes of both parameterizations in a follow-up PR 🤗 This PR moves the tests to the right place, and doesn't touch modeling code

ydshieh (Collaborator):

ah ok!

@ydshieh ydshieh mentioned this pull request Oct 29, 2024
@gante gante requested a review from ArthurZucker October 29, 2024 14:32
@ArthurZucker (Collaborator) left a comment:

Long overdue! Thanks everyone for reviewing and @gante for the cleanup!

@gante gante merged commit 8a734ea into huggingface:main Oct 30, 2024
26 checks passed
@gante gante deleted the move_generation_tests branch October 30, 2024 10:59
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Tests: move generate tests to the right mixin and delete redundant tests (huggingface#34464)

* tmp commit

* tmp commit

* cull overwrites of deleted tests

* typo

* more specific docstring

* make fixup

* parameterize at the top?

* correction

* more deletions :D

* tmp commit

* for VLMs too

* fix _check_outputs

* test nit

* make fixup

* fix another flaky

* test_generate_from_inputs_embeds -- handle missing attention mask

Successfully merging this pull request may close these issues.

More precise inputs_embeds input logic and tests
5 participants