
Mamba: add generative tests #31478

Merged: 4 commits from mamba_gen_tests into huggingface:main on Jun 19, 2024
Conversation

@gante (Member) commented on Jun 18, 2024:

Supersedes #31094

Fixes #30828 🤗

This PR:

  1. Enables the generative tests for Mamba
  2. Removes redundant manual skips through the has_attentions tester flag (TIL about this flag)
  3. To enable (1.), adds a new flag to PreTrainedModel -- _is_stateful (self-explanatory) [note: AFAIK we only have 3 stateful models, mamba, jamba, and rwkv, lmk if there are more!]
  4. Uses _is_stateful to set a few appropriate exceptions and test skips

Note: with has_attentions and _is_stateful, I suspect I can remove a few more skips in a follow-up PR :)
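
For context, a minimal sketch of how the two flags fit together, assuming simplified class shapes; the real attribute lives on transformers.PreTrainedModel and the skip logic in the shared test mixins, and `_stateful_skip` below is a hypothetical helper, not the actual mixin API:

```python
# Minimal sketch (not the upstream source): a class-level flag on the base
# model lets shared testers skip incompatible tests generically, replacing
# per-model @unittest.skip overrides.
class PreTrainedModel:
    _is_stateful = False  # new flag from this PR; most models keep the default


class MambaForCausalLM(PreTrainedModel):
    _is_stateful = True  # carries SSM state across steps, like jamba and rwkv


class GenerationTesterMixin:  # mixed into unittest.TestCase subclasses
    all_generative_model_classes = (MambaForCausalLM,)

    def _stateful_skip(self, model_class):
        # Hypothetical helper: a generation test that is incompatible with
        # stateful caching calls this instead of each model skipping it.
        if model_class._is_stateful:
            self.skipTest(reason="stateful models handle caches differently")
```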

@gante marked this pull request as ready for review on June 18, 2024 at 18:02
@@ -250,6 +250,8 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
@gante (Member Author) commented:
all tests passing for mamba :)
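
To make the effect concrete, here is a simplified stand-in for one of the generation tests that now runs: the method name matches the mixin's greedy test, but this body and the `model_tester` helpers are illustrative assumptions, not the real implementation:

```python
# Illustrative sketch of a mixin generation test exercised by this change:
# with all_generative_model_classes populated, greedy decoding runs end to
# end through Mamba's stateful cache.
def test_greedy_generate(self):
    for model_class in self.all_generative_model_classes:  # (MambaForCausalLM,)
        config, input_ids = self.model_tester.prepare_config_and_inputs()[:2]
        model = model_class(config).eval()
        output = model.generate(input_ids, max_new_tokens=5, do_sample=False)
        # Greedy generate appends max_new_tokens tokens to the prompt
        # (the real tester also guards against early EOS stopping).
        self.assertEqual(output.shape[-1], input_ids.shape[-1] + 5)
```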

Comment on lines -295 to -297
@unittest.skip("No attention in mamba")
def test_retain_grad_hidden_states_attentions(self):
pass
@gante (Member Author) commented:

this test was actually passing, just needed the has_attentions flag
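
A sketch of the mechanism, simplified from the common test (the exact body differs upstream): with the tester's has_attentions flag set to False, the attention half of the check never executes, so the hidden-states half passes as-is:

```python
# Simplified sketch of the shared retain-grad test: the attention assertions
# are gated on the tester flag, so attention-free models like Mamba pass
# without a per-model override.
def test_retain_grad_hidden_states_attentions(self):
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
    config.output_hidden_states = True
    config.output_attentions = self.has_attentions  # False on Mamba's tester

    model = self.all_model_classes[0](config)
    outputs = model(**inputs)

    hidden_states = outputs.hidden_states[0]
    hidden_states.retain_grad()
    if self.has_attentions:  # gated: never runs for attention-free models
        attentions = outputs.attentions[0]
        attentions.retain_grad()

    outputs.last_hidden_state.flatten()[0].backward(retain_graph=True)

    self.assertIsNotNone(hidden_states.grad)
    if self.has_attentions:
        self.assertIsNotNone(attentions.grad)
```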

Comment on lines -367 to -373
@unittest.skip("Mamba does not use attention")
def test_attention_outputs(self):
r"""
Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
it has a shape `batch_size, seq_len, hidden_size`.
"""
pass
@gante (Member Author) commented:

has_attentions skips this test
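
Roughly, the shared tester can guard the whole test itself; a sketch under the assumption that ModelTesterMixin carries the flag as a class attribute, which matches how the PR description uses it:

```python
# Sketch of a flag-guarded common test: one skip in the shared mixin replaces
# a manual @unittest.skip override in every attention-free model's test file.
class ModelTesterMixin:  # mixed into unittest.TestCase subclasses
    has_attentions = True  # Mamba's tester flips this to False

    def test_attention_outputs(self):
        if not self.has_attentions:
            self.skipTest(reason="Model does not output attentions")
        ...  # attention-shape assertions run only for attention models
```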

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) left a comment:

Beautiful - thanks for adding. Bonus points for skipping properly with skip messages ❤️

@zucchini-nlp (Member) left a comment:

Thanks for adding! Good to see some more usage of GenerationTesterMixin for non-standard models.

@gante merged commit 83259e4 into huggingface:main on Jun 19, 2024
23 checks passed
@gante deleted the mamba_gen_tests branch on June 19, 2024
itazap pushed a commit that referenced this pull request on Jun 20, 2024
Development

Successfully merging this pull request may close these issues.

Unable to run generation tests for Mamba & Jamba models
4 participants