Skip to content

Commit

Permalink
Avoid check expected exception when it is on CUDA (#34408)
Browse files Browse the repository at this point in the history
* update

* update

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Oct 25, 2024
1 parent e447185 commit f73f5e6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
5 changes: 3 additions & 2 deletions tests/pipelines/test_pipelines_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def run_pipeline_test(self, summarizer, _):
and len(summarizer.model.trainable_weights) > 0
and "GPU" in summarizer.model.trainable_weights[0].device
):
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
if str(summarizer.device) == "cpu":
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)

@require_torch
Expand Down
18 changes: 10 additions & 8 deletions tests/pipelines/test_pipelines_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,17 +493,19 @@ def run_pipeline_test(self, text_generator, _):
and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
):
# Handling of large generations
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
text_generator("This is a test" * 500, max_new_tokens=20)
if str(text_generator.device) == "cpu":
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
text_generator("This is a test" * 500, max_new_tokens=20)

outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
# Hole strategy cannot work
with self.assertRaises(ValueError):
text_generator(
"This is a test" * 500,
handle_long_generation="hole",
max_new_tokens=tokenizer.model_max_length + 10,
)
if str(text_generator.device) == "cpu":
with self.assertRaises(ValueError):
text_generator(
"This is a test" * 500,
handle_long_generation="hole",
max_new_tokens=tokenizer.model_max_length + 10,
)

@require_torch
@require_accelerate
Expand Down

0 comments on commit f73f5e6

Please sign in to comment.