Skip to content

Commit

Permalink
[TTA Pipeline] Test MusicGen and VITS (huggingface#26146)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi authored and parambharat committed Sep 26, 2023
1 parent 1fedeac commit 322bda5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/models/musicgen/test_modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def prepare_config_and_inputs_for_common(self):
class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {}
pipeline_model_mapping = {"text-to-audio": MusicgenForConditionalGeneration} if is_torch_available() else {}
test_pruning = False # training is not supported yet for MusicGen
test_headmasking = False
test_resize_embeddings = False
Expand Down
4 changes: 3 additions & 1 deletion tests/models/vits/test_modeling_vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ids_tensor,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
Expand Down Expand Up @@ -153,8 +154,9 @@ def create_and_check_model_forward(self, config, inputs_dict):


@require_torch
class VitsModelTest(ModelTesterMixin, unittest.TestCase):
class VitsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (VitsModel,) if is_torch_available() else ()
pipeline_model_mapping = {"text-to-audio": VitsModel} if is_torch_available() else {}
is_encoder_decoder = False
test_pruning = False
test_headmasking = False
Expand Down

0 comments on commit 322bda5

Please sign in to comment.