Automatically detect CUDA if it isn't set and default to upstream torch.cuda number of devices (#6605)
changm authored Feb 27, 2024
1 parent cd47390 commit cb4983e
Showing 3 changed files with 43 additions and 1 deletion.
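
For context, here is a minimal usage sketch (not part of the diff, mirroring the new test added below) of the behavior this commit enables. It assumes a machine where torch.cuda.is_available() is true and neither PJRT_DEVICE nor GPU_NUM_DEVICES has been set before torch_xla initializes its runtime.

# Sketch only: expected behavior after this commit, assuming CUDA devices are
# visible and no PJRT_DEVICE / GPU_NUM_DEVICES environment variables are set.
import os

import torch
import torch_xla.runtime as xr

assert "PJRT_DEVICE" not in os.environ      # nothing configured explicitly
print(xr.device_type())                     # expected: "CUDA"
print(os.environ.get("GPU_NUM_DEVICES") == str(torch.cuda.device_count()))  # expected: True
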
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -217,6 +217,7 @@ function run_xla_op_tests3 {
  run_torchrun "$CDIR/pjrt/test_torchrun.py"
  run_test "$CDIR/test_persistent_cache.py"
  run_test "$CDIR/test_devices.py"
  run_test "$CDIR/test_gpu_device_detection.py"
  # NOTE: this line below is testing export and doesn't care about GPU
  PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
}
34 changes: 34 additions & 0 deletions test/test_gpu_device_detection.py
@@ -0,0 +1,34 @@
import os
import sys
import unittest

import torch
import torch.cuda
import torch_xla
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


@unittest.skipIf(
    os.getenv(xenv.PJRT_DEVICE) is not None or not torch.cuda.is_available(),
    "Skipping test since PJRT_DEVICE was explicitly set or CUDA is not available.",
)
class GpuDeviceDetectionTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    os.unsetenv(xenv.PJRT_DEVICE)
    os.unsetenv(xenv.GPU_NUM_DEVICES)

  def test_automatically_detects_cuda(self):
    device_type = xr.device_type()
    self.assertEqual(device_type, "CUDA")
    self.assertEqual(os.environ[xenv.GPU_NUM_DEVICES],
                     str(torch.cuda.device_count()))

    supported_devices = xm.get_xla_supported_devices("CUDA")
    self.assertTrue(len(supported_devices) > 0)


if __name__ == "__main__":
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
9 changes: 8 additions & 1 deletion torch_xla/runtime.py
@@ -5,6 +5,7 @@
from typing import Dict, List, Optional, TypeVar

import torch
import torch.cuda
import torch_xla
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
@@ -44,10 +45,16 @@ def _maybe_select_default_device():
  if torch_xla._found_libtpu and tpu.num_available_chips() > 0:
    logging.warning('libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.')
    os.environ[xenv.PJRT_DEVICE] = 'TPU'
  # TODO(wcromar): Detect GPU device
  elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0:
    logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=CUDA')
    os.environ[xenv.PJRT_DEVICE] = 'CUDA'
  elif torch.cuda.is_available() and torch.cuda.device_count() > 0:
    num_devices_str = str(torch.cuda.device_count())
    logging.warning(
        'Found CUDA without GPU_NUM_DEVICES. Defaulting to PJRT_DEVICE=CUDA with GPU_NUM_DEVICES='
        + num_devices_str)
    os.environ[xenv.PJRT_DEVICE] = 'CUDA'
    os.environ[xenv.GPU_NUM_DEVICES] = num_devices_str
  else:
    logging.warning('Defaulting to PJRT_DEVICE=CPU')
    os.environ[xenv.PJRT_DEVICE] = 'CPU'
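
A brief sketch (again an assumption-labeled illustration, not part of the commit) of the precedence implied by the elif ordering above: an explicitly set GPU_NUM_DEVICES (or PJRT_DEVICE) still wins over the new torch.cuda fallback. This assumes no TPU is attached and that the runtime has not been initialized yet.

# Sketch only: explicit configuration takes precedence over auto-detection,
# because the GPU_NUM_DEVICES branch is checked before the torch.cuda branch.
import os

os.environ["GPU_NUM_DEVICES"] = "2"         # set before the runtime initializes

import torch_xla.runtime as xr

print(xr.device_type())                     # expected: "CUDA" (via the GPU_NUM_DEVICES branch)
print(os.environ["GPU_NUM_DEVICES"])        # stays "2"; torch.cuda.device_count() is not consulted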