Commit 0e816c6

update

rusty1s committed Apr 15, 2024
1 parent c095c62 commit 0e816c6
Showing 3 changed files with 146 additions and 133 deletions.
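The substantive change, repeated across all three files, swaps the raw CUDA runtime call cudaSetDevice() for c10's helper c10::cuda::MaybeSetDevice(); the helper, as we read it, avoids eagerly creating a CUDA context on the target device when none exists yet. The remaining hunks only reflow the AT_DISPATCH_ALL_TYPES_AND2 call sites into clang-format-style indentation. A minimal sketch of the device-selection pattern; select_device_for is a hypothetical wrapper, not code from this commit:

    #include <c10/cuda/CUDAFunctions.h> // declares c10::cuda::MaybeSetDevice
    #include <torch/extension.h>

    // Hypothetical helper mirroring the call sites changed below.
    void select_device_for(const torch::Tensor &src) {
      // Before this commit: cudaSetDevice(src.get_device());
      // MaybeSetDevice switches the current device like cudaSetDevice does,
      // but may skip context creation on a device that was never touched.
      c10::cuda::MaybeSetDevice(src.get_device());
    }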
49 changes: 26 additions & 23 deletions csrc/cuda/scatter_cuda.cu
@@ -63,7 +63,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
   CHECK_CUDA(index);
   if (optional_out.has_value())
     CHECK_CUDA(optional_out.value());
-  cudaSetDevice(src.get_device());
+  c10::cuda::MaybeSetDevice(src.get_device());
 
   CHECK_INPUT(src.dim() == index.dim());
   for (auto i = 0; i < index.dim() - 1; i++)
@@ -111,28 +111,31 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
-    auto src_data = src.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
-
-    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
-      if (!optional_out.has_value())
-        out.fill_(Reducer<scalar_t, REDUCE>::init());
-
-      scatter_kernel<scalar_t, REDUCE>
-          <<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
-              src_data, index_info, out_data, E, K, N, src.numel());
-
-      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
-        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
-
-      if (REDUCE == MIN || REDUCE == MAX)
-        scatter_arg_kernel<scalar_t>
-            <<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
-                src_data, index_info, out_data, arg_out_data, E, K, N,
-                src.numel());
-    });
-  });
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_",
+      [&] {
+        auto src_data = src.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
+
+        AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+          if (!optional_out.has_value())
+            out.fill_(Reducer<scalar_t, REDUCE>::init());
+
+          scatter_kernel<scalar_t, REDUCE>
+              <<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
+                  src_data, index_info, out_data, E, K, N, src.numel());
+
+          if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
+            out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(),
+                             (scalar_t)0);
+
+          if (REDUCE == MIN || REDUCE == MAX)
+            scatter_arg_kernel<scalar_t>
+                <<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
+                    src_data, index_info, out_data, arg_out_data, E, K, N,
+                    src.numel());
+        });
+      });
 
   return std::make_tuple(out, arg_out);
 }
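The reflowed blocks follow the standard ATen dispatch pattern: AT_DISPATCH_ALL_TYPES_AND2 instantiates the lambda once per supported dtype (all standard types plus Half and BFloat16), binding scalar_t inside it, before the kernel launch on the current stream. A self-contained sketch of the same pattern with a trivial stand-in kernel; fill_kernel and fill_cuda are illustrations, not part of this repository:

    #include <ATen/Dispatch.h>
    #include <ATen/cuda/CUDAContext.h>
    #include <torch/extension.h>

    // Trivial stand-in for the scatter/segment kernels above: write one value
    // into every element of a flat buffer.
    template <typename scalar_t>
    __global__ void fill_kernel(scalar_t *out, scalar_t value, int64_t numel) {
      int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < numel)
        out[i] = value;
    }

    void fill_cuda(torch::Tensor out, double value) {
      auto stream = at::cuda::getCurrentCUDAStream();
      const int threads = 256;
      const int blocks =
          static_cast<int>((out.numel() + threads - 1) / threads);
      // The macro expands to a switch over out.scalar_type(); the lambda body
      // is compiled once per dtype with scalar_t bound accordingly.
      AT_DISPATCH_ALL_TYPES_AND2(
          at::ScalarType::Half, at::ScalarType::BFloat16, out.scalar_type(),
          "_", [&] {
            fill_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
                out.data_ptr<scalar_t>(), static_cast<scalar_t>(value),
                out.numel());
          });
    }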
161 changes: 83 additions & 78 deletions csrc/cuda/segment_coo_cuda.cu
@@ -157,7 +157,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
   CHECK_CUDA(index);
   if (optional_out.has_value())
     CHECK_CUDA(optional_out.value());
-  cudaSetDevice(src.get_device());
+  c10::cuda::MaybeSetDevice(src.get_device());
 
   CHECK_INPUT(src.dim() >= index.dim());
 
@@ -214,70 +214,73 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
-    auto src_data = src.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
-
-    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
-      if (!optional_out.has_value())
-        out.fill_(Reducer<scalar_t, REDUCE>::init());
-
-      if (K == 1)
-        segment_coo_kernel<scalar_t, REDUCE, true>
-            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
-                                                   out_data, E, N);
-      else if (avg_len <= 8)
-        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
-            <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
-               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
-                                         N);
-      else if (avg_len <= 16)
-        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
-            <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
-               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
-                                         N);
-      else if (avg_len <= 32)
-        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
-            <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
-               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
-                                         N);
-      else
-        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
-            <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
-               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
-                                         N);
-      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
-        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
-      if (REDUCE == MIN || REDUCE == MAX) {
-        if (K == 1)
-          segment_coo_arg_kernel<scalar_t>
-              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
-                  src_data, index_info, out_data, arg_out_data, E, N);
-        else
-          segment_coo_arg_broadcast_kernel<scalar_t>
-              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
-                  src_data, index_info, out_data, arg_out_data, E, K, N);
-      }
-      if (REDUCE == MEAN) {
-        auto count_data = arg_out.value().data_ptr<scalar_t>();
-        segment_coo_kernel<scalar_t, SUM, false>
-            <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
-                                                   count_data, E, N);
-        arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
-                                     (scalar_t)1);
-        auto count = arg_out.value();
-        for (int i = dim + 1; i < out.dim(); i++)
-          count = count.unsqueeze(-1);
-        if (out.is_floating_point())
-          out.true_divide_(count);
-        else
-          out.div_(count, "floor");
-      }
-    });
-  });
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_",
+      [&] {
+        auto src_data = src.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
+
+        AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+          if (!optional_out.has_value())
+            out.fill_(Reducer<scalar_t, REDUCE>::init());
+
+          if (K == 1)
+            segment_coo_kernel<scalar_t, REDUCE, true>
+                <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
+                                                       out_data, E, N);
+          else if (avg_len <= 8)
+            segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
+                <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
+                   dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E,
+                                             K, N);
+          else if (avg_len <= 16)
+            segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
+                <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
+                   dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E,
+                                             K, N);
+          else if (avg_len <= 32)
+            segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
+                <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
+                   dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E,
+                                             K, N);
+          else
+            segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
+                <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
+                   dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E,
+                                             K, N);
+          if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
+            out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(),
+                             (scalar_t)0);
+          if (REDUCE == MIN || REDUCE == MAX) {
+            if (K == 1)
+              segment_coo_arg_kernel<scalar_t>
+                  <<<BLOCKS(1, E), THREADS, 0, stream>>>(
+                      src_data, index_info, out_data, arg_out_data, E, N);
+            else
+              segment_coo_arg_broadcast_kernel<scalar_t>
+                  <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
+                      src_data, index_info, out_data, arg_out_data, E, K, N);
+          }
+          if (REDUCE == MEAN) {
+            auto count_data = arg_out.value().data_ptr<scalar_t>();
+            segment_coo_kernel<scalar_t, SUM, false>
+                <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
+                                                       count_data, E, N);
+            arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
+                                         (scalar_t)1);
+            auto count = arg_out.value();
+            for (int i = dim + 1; i < out.dim(); i++)
+              count = count.unsqueeze(-1);
+            if (out.is_floating_point())
+              out.true_divide_(count);
+            else
+              out.div_(count, "floor");
+          }
+        });
+      });
   return std::make_tuple(out, arg_out);
 }
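In the hunk above, avg_len (the average segment length) selects a tile-size template parameter of 4, 8, 16, or 32 for segment_coo_broadcast_kernel. The ladder can be read as the following helper; the thresholds are copied from the launches, while the helper itself is hypothetical and does not appear in the commit:

    // Hypothetical restatement of the avg_len dispatch above. The returned
    // value is the kernel's tile-size template parameter, presumably the
    // number of consecutive elements each thread folds per pass.
    static int coo_broadcast_tile_size(float avg_len) {
      if (avg_len <= 8)
        return 4;
      if (avg_len <= 16)
        return 8;
      if (avg_len <= 32)
        return 16;
      return 32; // long segments get the largest tile
    }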
@@ -330,7 +333,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
   CHECK_CUDA(index);
   if (optional_out.has_value())
     CHECK_CUDA(optional_out.value());
-  cudaSetDevice(src.get_device());
+  c10::cuda::MaybeSetDevice(src.get_device());
   CHECK_INPUT(src.dim() >= index.dim());
@@ -365,18 +368,20 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
-    auto src_data = src.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
-    if (K == 1)
-      gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, E, N);
-    else
-      gather_coo_broadcast_kernel<scalar_t>
-          <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
-                                                     out_data, E, K, N);
-  });
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_",
+      [&] {
+        auto src_data = src.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
+        if (K == 1)
+          gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
+              src_data, index_info, out_data, E, N);
+        else
+          gather_coo_broadcast_kernel<scalar_t>
+              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
+                                                         out_data, E, K, N);
+      });
   return out;
 }
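The MEAN branch in segment_coo_cuda above reuses arg_out as a per-slot counter: a SUM launch with a nullptr src tallies how many source elements land in each output position, counts are clamped to at least one, broadcast over the trailing dims, and divided out. A host-side restatement in plain LibTorch ops, assuming a one-dimensional index and that the counting launch works as described; normalize_mean is a hypothetical name:

    #include <torch/torch.h>

    torch::Tensor normalize_mean(torch::Tensor out, torch::Tensor index,
                                 int64_t dim) {
      // Count how many source elements hit each slot along `dim`.
      auto count = torch::zeros({out.size(dim)}, out.options());
      count.scatter_add_(0, index, torch::ones_like(index, out.options()));
      count.clamp_min_(1); // empty slots would otherwise divide by zero
      // Unsqueeze so count broadcasts against out's trailing dimensions,
      // mirroring the loop in the diff above.
      for (int64_t i = dim + 1; i < out.dim(); i++)
        count = count.unsqueeze(-1);
      if (out.is_floating_point())
        return out.true_divide_(count);
      return out.div_(count, "floor");
    }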
69 changes: 37 additions & 32 deletions csrc/cuda/segment_csr_cuda.cu
@@ -47,8 +47,8 @@ segment_csr_kernel(const scalar_t *src_data,
     // Parallel reduction inside a single warp.
     if (REDUCE == MIN || REDUCE == MAX)
       arg_tmp = SHFL_DOWN_SYNC(FULL_MASK, arg, i);
-    Reducer<scalar_t, REDUCE>::update(
-        &val, SHFL_DOWN_SYNC(FULL_MASK, val, i), &arg, arg_tmp);
+    Reducer<scalar_t, REDUCE>::update(&val, SHFL_DOWN_SYNC(FULL_MASK, val, i),
+                                      &arg, arg_tmp);
   }
 
   if (lane_idx == 0) {
@@ -102,7 +102,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
   CHECK_CUDA(indptr);
   if (optional_out.has_value())
     CHECK_CUDA(optional_out.value());
-  cudaSetDevice(src.get_device());
+  c10::cuda::MaybeSetDevice(src.get_device());
 
   CHECK_INPUT(src.dim() >= indptr.dim());
 
@@ -147,22 +147,24 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
 
   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
-    auto src_data = src.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
-
-    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
-      if (K == 1) {
-        segment_csr_kernel<scalar_t, REDUCE, 1>
-            <<<BLOCKS(32, N), THREADS, 0, stream>>>(
-                src_data, indptr_info, out_data, arg_out_data, N, E);
-      } else {
-        segment_csr_broadcast_kernel<scalar_t, REDUCE>
-            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
-                src_data, indptr_info, out_data, arg_out_data, N, K, E);
-      }
-    });
-  });
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_",
+      [&] {
+        auto src_data = src.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
+
+        AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+          if (K == 1) {
+            segment_csr_kernel<scalar_t, REDUCE, 1>
+                <<<BLOCKS(32, N), THREADS, 0, stream>>>(
+                    src_data, indptr_info, out_data, arg_out_data, N, E);
+          } else {
+            segment_csr_broadcast_kernel<scalar_t, REDUCE>
+                <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
+                    src_data, indptr_info, out_data, arg_out_data, N, K, E);
+          }
+        });
+      });
 
   return std::make_tuple(out, arg_out);
 }
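The first hunk in this file only reflows the warp-level reduction step; the underlying pattern is a standard shuffle-down tree over the 32 lanes of a warp. A self-contained sketch with max as the reducer; warp_max is an illustration, and SHFL_DOWN_SYNC in the real code is presumably this repository's thin wrapper over the intrinsic used below:

    #include <cuda_runtime.h>

    #define FULL_MASK 0xffffffff // all 32 lanes participate

    // Each iteration halves the active width; after the loop, lane 0 holds
    // the maximum over the whole warp.
    __device__ float warp_max(float val) {
      for (int offset = 16; offset > 0; offset /= 2)
        val = fmaxf(val, __shfl_down_sync(FULL_MASK, val, offset));
      return val; // meaningful in lane 0
    }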
@@ -222,7 +224,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
   CHECK_CUDA(indptr);
   if (optional_out.has_value())
     CHECK_CUDA(optional_out.value());
-  cudaSetDevice(src.get_device());
+  c10::cuda::MaybeSetDevice(src.get_device());
 
   CHECK_INPUT(src.dim() >= indptr.dim());
 
@@ -264,18 +266,21 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
 
   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
-    auto src_data = src.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
-
-    if (K == 1)
-      gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
-          src_data, indptr_info, out_data, N, E);
-    else
-      gather_csr_broadcast_kernel<scalar_t>
-          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                     out_data, N, K, E);
-  });
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_",
+      [&] {
+        auto src_data = src.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
+
+        if (K == 1)
+          gather_csr_kernel<scalar_t, 4>
+              <<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                         out_data, N, E);
+        else
+          gather_csr_broadcast_kernel<scalar_t>
+              <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                         out_data, N, K, E);
+      });
 
   return out;
 }
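c10::cuda::MaybeSetDevice is a relatively recent addition to PyTorch, so a build that must also target older releases could guard the new call sites with a shim along these lines. The version cutoff shown is an assumption for illustration, not something this commit establishes:

    #include <c10/cuda/CUDAFunctions.h>
    #include <torch/version.h>

    #if TORCH_VERSION_MAJOR < 2 // assumed cutoff; adjust to the actual release
    namespace c10 {
    namespace cuda {
    // Fallback for releases that predate MaybeSetDevice: eagerly set the
    // device, accepting that this may create a context the newer helper
    // would avoid.
    inline void MaybeSetDevice(c10::DeviceIndex device) { set_device(device); }
    } // namespace cuda
    } // namespace c10
    #endif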
