From 2aa2a14481dda0243522e6dff018aadab9829efa Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 5 Jul 2024 15:09:04 +0900 Subject: [PATCH] Make tensor device correct when ACCELERATE_TORCH_DEVICE is defined (#31751) return correct device when ACCELERATE_TORCH_DEVICE is defined --- src/transformers/training_args.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index a782f4bf7f92d4..9f305f6ce2ee42 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2194,7 +2194,9 @@ def _setup_devices(self) -> "torch.device": # trigger an error that a device index is missing. Index 0 takes into account the # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` # will use the first GPU in that env, i.e. GPU#1 - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device( + "cuda:0" if torch.cuda.is_available() else os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu") + ) # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at # the default value. self._n_gpu = torch.cuda.device_count()