Skip to content

Commit

Permalink
Fix test_compile_static_cache (#30991)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Jun 3, 2024
1 parent 70c8713 commit df848ac
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
7 changes: 2 additions & 5 deletions tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,11 +729,8 @@ def test_compile_static_cache(self):
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
],
7: [
"Simply put, the theory of relativity states that 1. surely nothing is faster than light.\nThe theory "
"goes that nothing travels faster than light, but the faster you go, the slower everything else will "
"be.\nThe theory of relativity",
"My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, "
"and even on a good old fashioned cheeseburger. I love it on everything. I love it so",
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory of relativ",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
],
9: [
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial"
Expand Down
5 changes: 5 additions & 0 deletions tests/models/mistral/test_modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
is_flaky,
require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
require_torch_gpu,
require_torch_sdpa,
Expand Down Expand Up @@ -658,12 +659,16 @@ def test_speculative_generation(self):
gc.collect()

@slow
@require_read_token
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("This test requires torch >= 2.3 to run.")

if self.cuda_compute_capability_major_version == 7:
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")

NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
8: [
Expand Down

0 comments on commit df848ac

Please sign in to comment.