diff --git a/docs/source/en/testing.md b/docs/source/en/testing.md
index 0179d5c6353..ec11e39be09 100644
--- a/docs/source/en/testing.md
+++ b/docs/source/en/testing.md
@@ -511,15 +511,20 @@ from transformers.testing_utils import get_gpu_count
 n_gpu = get_gpu_count()  # works with torch and tf
 ```
 
-### Testing with a specific PyTorch backend
+### Testing with a specific PyTorch backend or device
 
-To run the test suite on a specific torch backend add `TRANSFORMERS_TEST_DEVICE="$device"` where `$device` is the target backend. For example, to test on CPU only:
+To run the test suite on a specific torch device, add `TRANSFORMERS_TEST_DEVICE="$device"` where `$device` is the target device. For example, to test on CPU only:
 
 ```bash
 TRANSFORMERS_TEST_DEVICE="cpu" pytest tests/test_logging.py
 ```
 
 This variable is useful for testing custom or less common PyTorch backends such as `mps`. It can also be used to achieve the same effect as `CUDA_VISIBLE_DEVICES` by targeting specific GPUs or testing in CPU-only mode.
 
+Certain devices require an additional import after importing `torch` for the first time. The module to import can be specified with the environment variable `TRANSFORMERS_TEST_BACKEND`:
+```bash
+TRANSFORMERS_TEST_BACKEND="torch_npu" pytest tests/test_logging.py
+```
+
 ### Distributed training
 
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 18d5880a172..2d8d7f64ed2 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -16,6 +16,7 @@
 import contextlib
 import doctest
 import functools
+import importlib
 import inspect
 import logging
 import multiprocessing
@@ -629,6 +630,17 @@ def require_torch_multi_npu(test_case):
         torch_device = "npu"
     else:
         torch_device = "cpu"
+
+    if "TRANSFORMERS_TEST_BACKEND" in os.environ:
+        backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
+        try:
+            _ = importlib.import_module(backend)
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed"
+                f" module. The original error (look up to see its traceback):\n{e}"
+            ) from e
+
 else:
     torch_device = None
 
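For context on how the two variables interact, below is a minimal standalone sketch of the lookup this patch adds. It mirrors the new `testing_utils.py` logic rather than importing it, `torch_npu` is only the example backend name taken from the docs above, and the `"cpu"` fallback is a simplification of the real device probing:

```python
# Minimal sketch (not the actual transformers code) of the behaviour this
# patch introduces: TRANSFORMERS_TEST_BACKEND names a module that must be
# imported after torch so that its device type (e.g. "npu") is registered,
# and TRANSFORMERS_TEST_DEVICE then selects the device the tests run on.
import importlib
import os

import torch

backend = os.environ.get("TRANSFORMERS_TEST_BACKEND")  # e.g. "torch_npu"
if backend is not None:
    try:
        importlib.import_module(backend)
    except ModuleNotFoundError as e:
        # Same failure mode as the patched testing_utils.py: the variable
        # must name an installed module.
        raise ModuleNotFoundError(
            f"`TRANSFORMERS_TEST_BACKEND` '{backend}' is not an installed module"
        ) from e

# Falls back to "cpu" purely for this sketch; the real code also probes
# cuda/npu availability before defaulting.
device = torch.device(os.environ.get("TRANSFORMERS_TEST_DEVICE", "cpu"))
print(f"tests would run on device: {device}")
```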