diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index ef4c38e2a30..391d8e8ce1f 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -17,6 +17,7 @@ from __future__ import annotations import copy +import gc import glob import inspect import math @@ -709,6 +710,11 @@ def test_compute_mask_indices_overlap(self): @require_tf @slow class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + def _load_datasamples(self, num_samples): ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index cf41dd9a301..87206a4b9b8 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -14,6 +14,7 @@ # limitations under the License. """ Testing suite for the PyTorch Wav2Vec2 model. """ +import gc import math import multiprocessing import os @@ -1374,6 +1375,12 @@ def test_sample_negatives_with_mask(self): @require_soundfile @slow class Wav2Vec2ModelIntegrationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + gc.collect() + torch.cuda.empty_cache() + def _load_datasamples(self, num_samples): ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech