Skip to content

Commit

Permalink
update (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Apr 15, 2024
1 parent c095c62 commit f4696b7
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion csrc/cuda/scatter_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
c10::cuda::MaybeSetDevice(src.get_device());

CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
Expand Down
4 changes: 2 additions & 2 deletions csrc/cuda/segment_coo_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
c10::cuda::MaybeSetDevice(src.get_device());

CHECK_INPUT(src.dim() >= index.dim());

Expand Down Expand Up @@ -330,7 +330,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
c10::cuda::MaybeSetDevice(src.get_device());
CHECK_INPUT(src.dim() >= index.dim());
Expand Down
4 changes: 2 additions & 2 deletions csrc/cuda/segment_csr_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
CHECK_CUDA(indptr);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
c10::cuda::MaybeSetDevice(src.get_device());

CHECK_INPUT(src.dim() >= indptr.dim());

Expand Down Expand Up @@ -222,7 +222,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
CHECK_CUDA(indptr);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
cudaSetDevice(src.get_device());
c10::cuda::MaybeSetDevice(src.get_device());

CHECK_INPUT(src.dim() >= indptr.dim());

Expand Down

0 comments on commit f4696b7

Please sign in to comment.