Jamba: update integration tests #32250
Conversation
cc @ydshieh
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
This `cuda_compute_capability_major_version` pattern is copied from other models, e.g. gemma.
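For context, a minimal sketch of how such a flag is typically initialized at module level in these test files (assuming `torch` is installed; the A10/T4 mapping follows from their sm_86/sm_75 compute capabilities):

```python
import torch

# None on CPU-only runners; otherwise the GPU's major compute capability
cuda_compute_capability_major_version = None

if torch.cuda.is_available():
    # get_device_capability() returns (major, minor): (8, 6) on A10, (7, 5) on T4
    cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
```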
(thank you for triggering the tests on the runner 🙏)
Thanks for digging into this and fixing, and for writing a detailed PR description ❤️
Agreed it's not worth digging into further given Jamba's usage, and as the generated texts appear similar despite the logits differences.
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.cuda_compute_capability_major_version == 8:
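For readers unfamiliar with the tolerance semantics: `torch.testing.assert_close` passes when `|actual - expected| <= atol + rtol * |expected|` holds elementwise, so `rtol=atol=1e-3` absorbs small hardware-dependent wobble in the logits. A tiny self-contained illustration (the values are made up):

```python
import torch

expected = torch.tensor([1.0, 2.0, -3.0])
actual = torch.tensor([1.0005, 2.001, -3.002])

# Each |actual - expected| is within atol + rtol * |expected|, so this passes
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
```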
Maybe better to use `self.skipTest(reason="Skipping for T4 runners because ...")`
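For context, `skipTest` reports the test as skipped instead of letting it silently pass without asserting anything; a minimal sketch of the suggestion (class name and reason text are illustrative, not from the PR):

```python
import unittest

class JambaLogitsTest(unittest.TestCase):
    cuda_compute_capability_major_version = 7  # illustrative: a T4 runner

    def test_logits(self):
        if self.cuda_compute_capability_major_version != 8:
            # Shows up as "skipped" in the report rather than as a hollow pass
            self.skipTest(reason="Skipping for T4 runners: expected logits were captured on A10")
        # ... logits assertions would run here on A10 runners
```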
Oops, merged before seeing this comment!
You have a good point; in fact, we should split the test in two to test (/skip) the logits separately.
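A rough sketch of that split, with hypothetical test names (only the logits half is hardware-gated, so the text check still runs everywhere):

```python
import unittest

class JambaIntegrationTest(unittest.TestCase):
    cuda_compute_capability_major_version = 8  # set from the runner's GPU in practice

    def test_generated_text(self):
        # Text comparison: runs on every runner
        ...

    def test_logits(self):
        # Logits comparison: only meaningful on the hardware the expectations came from
        if self.cuda_compute_capability_major_version != 8:
            self.skipTest(reason="Expected logits were captured on an A10 runner")
        ...
```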
* try test updates
* a few more changes
* a few more changes
* a few more changes
* [run slow] jamba
* skip logits checks on older gpus
* [run slow] jamba
* oops
* [run slow] jamba
* Update tests/models/jamba/test_modeling_jamba.py
Co-authored-by: amyeroberts <[email protected]>
* Update tests/models/jamba/test_modeling_jamba.py
Co-authored-by: amyeroberts <[email protected]>
---------
Co-authored-by: amyeroberts <[email protected]>
What does this PR do?
🟢 Fixes `generate`-related integration tests for jamba 🟢

I've checked them against:
Detective work 🕵️
Since the tests run on `ai21labs/Jamba-tiny-random`, the generation text quality doesn't matter.
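For anyone reproducing locally, a minimal sketch of exercising that checkpoint (the prompt and generation settings are illustrative; a random-weights model emits gibberish by design, so the tests compare exact tokens and logits rather than text quality):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ai21labs/Jamba-tiny-random"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=10, do_sample=False)
print(tokenizer.decode(output_ids[0]))  # gibberish is expected from random weights
```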