Improve the performance of put_along_axis #60618

Merged
merged 3 commits on Jan 9, 2024
88 changes: 34 additions & 54 deletions paddle/phi/kernels/funcs/gather_scatter_functor.cu
@@ -92,6 +92,12 @@ class ReduceMin {
};
static ReduceMin reduce_min;

__global__ void CudaMemsetAsync(int* dest, int value, size_t size) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid * sizeof(int) >= size) return;
dest[tid] = value;
}
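
Note on the helper above: cudaMemsetAsync fills memory byte by byte, so on its own it can only initialize these auxiliary int buffers to zero. Filling a buffer with a non-zero integer such as 1 or index_size + 1, as the launch sites below need, cannot be expressed as a single byte pattern, hence the small element-wise CudaMemsetAsync kernel. The following is a minimal sketch of that pattern under the same assumptions; FillIntAsync, InitAuxBuffer, and their parameters are illustrative names, not identifiers from this file.

    #include <cstdint>
    #include <cuda_runtime.h>

    // Element-wise fill: one thread per int; bytes is the buffer size in bytes.
    __global__ void FillIntAsync(int* dest, int value, size_t bytes) {
      int tid = threadIdx.x + blockIdx.x * blockDim.x;
      if (tid * sizeof(int) >= bytes) return;
      dest[tid] = value;
    }

    // Initialize an int buffer on the given stream before a kernel launch.
    void InitAuxBuffer(int* buf, size_t bytes, int value, int block,
                       cudaStream_t stream) {
      if (value == 0) {
        // A zero fill can use the byte-wise runtime call directly.
        cudaMemsetAsync(buf, 0, bytes, stream);
      } else {
        // Non-zero fill: ceiling-divide the element count into blocks.
        int64_t grid = (bytes / sizeof(int) + block - 1) / block;
        FillIntAsync<<<grid, block, 0, stream>>>(buf, value, bytes);
      }
    }

Because the fill is enqueued on the same stream as the scatter kernel, it is guaranteed to complete before that kernel starts, which is what allows the per-kernel "if (tid == 0)" initialization loops and their __syncthreads() calls to be removed in the hunks below.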

template <typename tensor_t,
typename index_t,
typename func_t,
@@ -112,13 +118,6 @@ __global__ void ScatterAssignGPUKernel(tensor_t* self_data,
int* thread_ids) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;

if (tid == 0) {
for (int i = 0; i < numel_data; i++) {
thread_ids[i] = 0;
}
}
__syncthreads();
int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop
// squeezed from the N layers loop.
/* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -267,16 +266,6 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;

if (tid == 0) {
for (int i = 0; i < numel_data; i++) {
shared_mem[i] = 0; // thread_id
if (include_self)
shared_mem[numel_data + i] = 1; // reduce size
else
shared_mem[numel_data + i] = 0;
}
}
__syncthreads();
int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop
// squeezed from the N layers loop.
/* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -384,6 +373,7 @@ struct gpu_gather_scatter_functor {
int* shared_mem;
cudaMallocAsync(
reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
ScatterAssignGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
<<<grid, block, 0, stream>>>(self_data,
dim,
@@ -405,6 +395,14 @@ struct gpu_gather_scatter_functor {
int* shared_mem;
cudaMallocAsync(
reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
cudaMemsetAsync(shared_mem, 0, sizeof(int) * self_size, stream);
if (include_self) {
int64_t grid_memset = (self_size * 2 + block - 1) / block;
CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
shared_mem, 1, shared_mem_size);
} else {
cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
}
ScatterMeanGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
<<<grid, block, 0, stream>>>(self_data,
dim,
@@ -429,6 +427,9 @@ struct gpu_gather_scatter_functor {
shared_mem_size = sizeof(int) * self_size;
cudaMallocAsync(
reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
int64_t grid_memset = (self_size + block - 1) / block;
CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
shared_mem, index_size + 1, shared_mem_size);
}
GatherScatterGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
<<<grid, block, shared_mem_size, stream>>>(self_data,
@@ -640,12 +641,6 @@ __global__ void ScatterMulInputGradGPUKernel(tensor_t* grad_data,
int* thread_ids) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;
if (tid == 0) {
for (int i = 0; i < numel_grad; i++) {
thread_ids[i] = 0;
}
}
__syncthreads();
int64_t i, j, k;
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -682,13 +677,6 @@ __global__ void ScatterMinMaxInputGradGPUKernel(tensor_t* grad_data,
int* shared_mem) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;

if (tid == 0) {
for (int i = 0; i < numel_grad; i++) {
shared_mem[i] = 1; // number of elements
}
}
__syncthreads();
int64_t i, j, k;
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -762,6 +750,7 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
int* shared_mem;
cudaMallocAsync(
reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
ScatterMulInputGradGPUKernel<tensor_t, index_t>
<<<grid, block, 0, stream>>>(grad_data,
dim,
@@ -781,6 +770,9 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
int* shared_mem;
cudaMallocAsync(
reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
int64_t grid_memset = (grad_size + block - 1) / block;
CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
shared_mem, 1, shared_mem_size);
ScatterMinMaxInputGradGPUKernel<tensor_t, index_t>
<<<grid, block, 0, stream>>>(grad_data,
dim,
@@ -816,13 +808,6 @@ __global__ void ScatterMeanInputGradGPUKernel(tensor_t* grad_data,
int* shared_mem) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;
if (tid == 0) {
for (int i = 0; i < numel_grad; i++) {
shared_mem[i] = 0; // thread_ids
shared_mem[numel_grad + i] = 1; // number of elements
}
}
__syncthreads();
int64_t i, j, k;
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -879,6 +864,10 @@ void gpu_scatter_mean_input_grad_kernel(phi::DenseTensor self,
int* shared_mem;
cudaMallocAsync(
reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
cudaMemsetAsync(shared_mem, 0, sizeof(int) * grad_size, stream);
int64_t grid_memset = (grad_size + block - 1) / block;
CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
shared_mem + grad_size, 1, sizeof(int) * grad_size);
ScatterMeanInputGradGPUKernel<tensor_t, index_t>
<<<grid, block, 0, stream>>>(grad_data,
dim,
@@ -910,12 +899,6 @@ __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data,
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;

if (tid == 0) {
for (int i = 0; i < numel_data; i++) {
thread_ids[i] = 0;
}
}
__syncthreads();
int64_t i, j, k;
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -975,6 +958,7 @@ void gpu_scatter_value_grad_kernel(phi::DenseTensor self,
int* shared_mem;
cudaMallocAsync(
reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
ScatterValueGradGPUKernel<tensor_t, index_t>
<<<grid, block, 0, stream>>>(grad_data,
dim,
@@ -1005,20 +989,10 @@ __global__ void ScatterMeanValueGradGPUKernel(tensor_t* grad_data,
int64_t outer_dim_size_grad,
int64_t numel,
int64_t numel_self,
bool include_self,
int* shared_mem) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= numel) return;

if (tid == 0) {
for (int i = 0; i < numel_self; i++) {
if (include_self)
shared_mem[i] = 1; // number of elements
else
shared_mem[i] = 0;
}
}
__syncthreads();
int64_t i, j, k;
i = tid / (select_dim_size * outer_dim_size);
int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -1114,6 +1088,13 @@ void gpu_scatter_add_mean_value_grad_kernel(
int* shared_mem;
cudaMallocAsync(
reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
if (include_self) {
int64_t grid_memset = (self_size + block - 1) / block;
CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
shared_mem, 1, shared_mem_size);
} else {
cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
}
ScatterMeanValueGradGPUKernel<tensor_t, index_t>
<<<grid, block, 0, stream>>>(grad_data,
dim,
@@ -1127,7 +1108,6 @@ void gpu_scatter_add_mean_value_grad_kernel(
outer_dim_size_grad,
index_size,
self_size,
include_self,
shared_mem);
cudaFreeAsync(shared_mem, stream);
} else if (reduce == "add") {
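Taken together, the host-side launch sequence after this change follows one pattern throughout the file, sketched below in simplified form. SomeScatterKernel stands in for kernels such as ScatterMinMaxInputGradGPUKernel, and grad_size, grid, block, and stream are assumed to be set up as in the surrounding functor code; this is not a verbatim excerpt. Everything is enqueued on a single stream, so the auxiliary buffer is initialized before the scatter kernel reads it and released only after the kernel finishes, with no extra host synchronization.

    // Simplified sketch of the new launch sequence; names are placeholders.
    int* shared_mem = nullptr;
    size_t shared_mem_size = sizeof(int) * grad_size;  // one counter per grad element
    cudaMallocAsync(reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);

    // Ceiling division: e.g. grad_size = 1000 with block = 256 gives 4 blocks
    // (1024 threads); the bounds check in CudaMemsetAsync skips the extra 24.
    int64_t grid_memset = (grad_size + block - 1) / block;
    CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(shared_mem, 1, shared_mem_size);

    // SomeScatterKernel<<<grid, block, 0, stream>>>(..., shared_mem);
    cudaFreeAsync(shared_mem, stream);  // freed in stream order, after the kernel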