From f4696b75534cac73c559d43e79dc25d71be32c25 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 15 Apr 2024 11:35:50 +0200 Subject: [PATCH] update (#436) --- csrc/cuda/scatter_cuda.cu | 2 +- csrc/cuda/segment_coo_cuda.cu | 4 ++-- csrc/cuda/segment_csr_cuda.cu | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/cuda/scatter_cuda.cu b/csrc/cuda/scatter_cuda.cu index 4c827116..6dabad59 100644 --- a/csrc/cuda/scatter_cuda.cu +++ b/csrc/cuda/scatter_cuda.cu @@ -63,7 +63,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, CHECK_CUDA(index); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() == index.dim()); for (auto i = 0; i < index.dim() - 1; i++) diff --git a/csrc/cuda/segment_coo_cuda.cu b/csrc/cuda/segment_coo_cuda.cu index 50678da7..ea73e418 100644 --- a/csrc/cuda/segment_coo_cuda.cu +++ b/csrc/cuda/segment_coo_cuda.cu @@ -157,7 +157,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, CHECK_CUDA(index); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() >= index.dim()); @@ -330,7 +330,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, CHECK_CUDA(index); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() >= index.dim()); diff --git a/csrc/cuda/segment_csr_cuda.cu b/csrc/cuda/segment_csr_cuda.cu index e42e6655..9e6426f5 100644 --- a/csrc/cuda/segment_csr_cuda.cu +++ b/csrc/cuda/segment_csr_cuda.cu @@ -102,7 +102,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, CHECK_CUDA(indptr); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() >= indptr.dim()); @@ -222,7 +222,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, CHECK_CUDA(indptr); if (optional_out.has_value()) CHECK_CUDA(optional_out.value()); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); CHECK_INPUT(src.dim() >= indptr.dim());