From b1c1c93e9954d6d20056a4f82e13914d92f33673 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 2 Jul 2024 09:29:22 -0500
Subject: [PATCH] return correct device when ACCELERATE_TORCH_DEVICE is defined

---
 src/transformers/training_args.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 5eff032774e203..c39c893470acbb 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2170,7 +2170,9 @@ def _setup_devices(self) -> "torch.device":
             # trigger an error that a device index is missing. Index 0 takes into account the
             # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
             # will use the first GPU in that env, i.e. GPU#1
-            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            device = torch.device(
+                "cuda:0" if torch.cuda.is_available() else os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu")
+            )
             # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
             # the default value.
             self._n_gpu = torch.cuda.device_count()
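
For context, a minimal standalone sketch of the fallback behaviour this patch introduces. The resolve_device helper below is hypothetical and for illustration only, not the transformers implementation: when CUDA is unavailable, the device named by ACCELERATE_TORCH_DEVICE is used instead of unconditionally falling back to "cpu".

    import os
    import torch

    def resolve_device() -> torch.device:
        # Prefer the first CUDA device when one is available.
        if torch.cuda.is_available():
            return torch.device("cuda:0")
        # Otherwise honor ACCELERATE_TORCH_DEVICE (e.g. "mps", "xpu"),
        # defaulting to "cpu" when the variable is not set.
        return torch.device(os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu"))

    if __name__ == "__main__":
        # On a CUDA-less host this prints "cpu" when the variable is unset,
        # or e.g. "mps" when ACCELERATE_TORCH_DEVICE=mps is exported.
        print(resolve_device())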