diff --git a/csrc/cuda/scatter_cuda.cu b/csrc/cuda/scatter_cuda.cu index 4c827116..6dabad59 100644 --- a/csrc/cuda/scatter_cuda.cu +++ b/csrc/cuda/scatter_cuda.cu @@ -63,7 +63,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, CHECK_CUDA(index); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() == index.dim()); for (auto i = 0; i < index.dim() - 1; i++) diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index 50678da7..ea73e418 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -157,7 +157,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, CHECK_CUDA(index); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() >= index.dim()); @@ -330,7 +330,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, CHECK_CUDA(index); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() >= index.dim()); diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu index e42e6655..9e6426f5 100644 --- a/csrc/cuda/segment_csr_cuda.cu +++ b/csrc/cuda/segment_csr_cuda.cu @@ -102,7 +102,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, CHECK_CUDA(indptr); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() >= indptr.dim()); @@ -222,7 +222,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, CHECK_CUDA(indptr); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() >= indptr.dim());