Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
subwarp version gather op for small embedding size (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
chuangz0 authored May 23, 2024
1 parent a087085 commit 7352f1c
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 2 deletions.
126 changes: 125 additions & 1 deletion cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -309,6 +309,62 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
return;
}

/**
 * Compile-time predicate on an integer template argument.
 *
 * `IsPowerOfTwo<N>::value` is true exactly when N is strictly positive and has a
 * single bit set (1, 2, 4, 8, ...); it is false for zero, negative values and
 * every other positive integer.
 */
template <int N>
struct IsPowerOfTwo {
  // Non-positive N is rejected first, so the bit test is only evaluated for N > 0.
  // For positive N, N & (N - 1) clears the lowest set bit; the result is zero only
  // when that was the sole bit set.
  static constexpr bool value = (N <= 0) ? false : ((N & (N - 1)) == 0);
};

/**
 * Gather kernel in which a tile of SUB_WARP_SIZE threads cooperates on one output row.
 *
 * The thread block is partitioned into tiles of SUB_WARP_SIZE threads. Tiles are numbered
 * across the whole grid; each tile starts at the output row equal to its global tile index
 * and advances by the total number of tiles until `indice_count` is reached. Within a row,
 * every lane moves ALIGNMENT consecutive elements per step and the tile advances by
 * ALIGNMENT * SUB_WARP_SIZE elements per step.
 *
 * @tparam EmbeddingT    element type read from the embedding table
 * @tparam IndexT        integer type of the entries in `indices`
 * @tparam OutputT       element type written to `output`
 * @tparam SUB_WARP_SIZE threads per tile; must be a power of two and smaller than 32
 * @tparam ALIGNMENT     number of elements each lane moves per vector access
 *
 * @param embedding_gref global reference to the embedding table
 * @param embedding_desc description of the embedding matrix; `sizes[1]` is used as the row
 *                       width, `stride` as the row pitch and `storage_offset` as the start
 *                       offset
 * @param indices        row index to gather for each output row; rows whose index is
 *                       negative are skipped and their output is left untouched
 * @param indice_count   number of entries in `indices` (and of output rows)
 * @param output         destination buffer
 * @param output_desc    description of the output matrix; `stride` and `storage_offset` are
 *                       applied to `output` as OutputT element offsets
 *
 * NOTE(review): the inner loop only checks the first element of each ALIGNMENT-wide access
 * against the row width, so this presumably requires the row width to be a multiple of
 * ALIGNMENT and both buffers to be suitably aligned - confirm against the launcher.
 */
template <typename EmbeddingT,
          typename IndexT,
          typename OutputT,
          int SUB_WARP_SIZE = 1,
          int ALIGNMENT     = 1>
__global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
                                            wholememory_matrix_description_t embedding_desc,
                                            const IndexT* indices,
                                            int64_t indice_count,
                                            OutputT* output,
                                            wholememory_matrix_description_t output_desc)
{
  static_assert(IsPowerOfTwo<SUB_WARP_SIZE>::value && SUB_WARP_SIZE < 32,
                "SUB_WARP_SIZE must be the power of 2,and smaller than 32.");

  // Number of elements the whole tile covers in one inner-loop step.
  constexpr int kElemsPerTileStep = ALIGNMENT * SUB_WARP_SIZE;

  auto cta  = cooperative_groups::this_thread_block();
  auto tile = cooperative_groups::tiled_partition<SUB_WARP_SIZE>(cta);

  // Global index of this tile and total number of tiles in the grid.
  const int first_row = tile.meta_group_size() * blockIdx.x + tile.meta_group_rank();
  const int row_step  = tile.meta_group_size() * gridDim.x;

  // First element of a row handled by this lane.
  const int lane            = tile.thread_rank();
  const int lane_elem_begin = lane * ALIGNMENT;

  wholememory::device_reference<EmbeddingT> table(embedding_gref);

  const int row_width      = embedding_desc.sizes[1];
  const int64_t src_stride = embedding_desc.stride;
  const int64_t dst_stride = output_desc.stride;

  // Per-thread staging vectors for one ALIGNMENT-wide chunk (source type / output type).
  typed_data_vector<EmbeddingT, ALIGNMENT> src_vec;
  typed_data_vector<OutputT, ALIGNMENT> dst_vec;

  for (int64_t row = first_row; row < indice_count; row += row_step) {
    const IndexT src_row = indices[row];
    // Negative indices mark rows that must not be gathered.
    if (src_row >= 0) {
      OutputT* dst_row       = output + output_desc.storage_offset + dst_stride * row;
      const int64_t src_base = embedding_desc.storage_offset + src_row * src_stride;

      for (int col = lane_elem_begin; col < row_width; col += kElemsPerTileStep) {
        // Load one chunk from the table, convert element-wise, then store it to the output.
        mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&src_vec, &table[src_base + col]);
#pragma unroll
        for (int k = 0; k < ALIGNMENT; ++k) {
          typed_data_vector_at(dst_vec, k) =
            convert_type<EmbeddingT, OutputT>(typed_data_vector_at(src_vec, k));
        }
        mov_data<sizeof(OutputT) * ALIGNMENT>(dst_row + col, &dst_vec);
      }
    }
  }
}

template <typename EmbeddingT, typename IndexT, typename OutputT>
void gather_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
Expand Down Expand Up @@ -338,6 +394,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
int64_t,
OutputT*,
wholememory_matrix_description_t) = nullptr;

switch (alignment) {
case 16: {
kernel_fn = gather_func_kernel<EmbeddingT, IndexT, OutputT, 16>;
Expand Down Expand Up @@ -367,6 +424,73 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
int block_size = 1024;
int block_count = indice_count > 1568 ? 1568 : indice_count;
if (gather_sms != -1) block_count = gather_sms;

// for small embedding size ,use subwarp to gather
int min_threads_per_embedding = embedding_desc.sizes[1] / alignment;
if (min_threads_per_embedding < 32) {
#define SWITCH_GATHER_FUNC_WITH_ALIGNMENT(KERNEL_NAME, SUB_WARP_SIZE) \
switch (alignment) { \
case 16: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 16>; \
break; \
} \
case 8: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 8>; \
break; \
} \
case 4: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 4>; \
break; \
} \
case 2: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 2>; \
break; \
} \
case 1: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 1>; \
break; \
} \
default: { \
WHOLEMEMORY_FAIL("gather func alignment=%d.", alignment); \
return; \
} \
}

int threads_per_embedding = 16;
if (min_threads_per_embedding >= 16) {
SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 16);
threads_per_embedding = 16;
} else if (min_threads_per_embedding < 16 && min_threads_per_embedding >= 8) {
SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 8);
threads_per_embedding = 8;
} else if (min_threads_per_embedding < 8 && min_threads_per_embedding >= 4) {
SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 4);
threads_per_embedding = 4;
} else if (min_threads_per_embedding < 4 && min_threads_per_embedding >= 2) {
SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 2);
threads_per_embedding = 2;
} else {
SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 1);
threads_per_embedding = 1;
}

#undef SWITCH_GATHER_FUNC_WITH_ALIGNMENT
block_size = 128;
int max_blocks_per_sm = 8;
WM_CUDA_CHECK(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, kernel_fn, block_size, 0));

int sm_count = 100;
int device_id = 0;
WM_CUDA_CHECK(cudaGetDevice(&device_id));
WM_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id));

// block_count = indice_count > 1568 ? 1568 : indice_count;
int min_embedding_per_block = block_size / threads_per_embedding;
block_count = min((int)(indice_count + min_embedding_per_block - 1) / min_embedding_per_block,
sm_count * max_blocks_per_sm * 4);
if (gather_sms != -1) block_count = gather_sms * max_blocks_per_sm;
}
kernel_fn<<<block_count, block_size, 0, stream>>>(embedding_gref,
embedding_desc,
static_cast<const IndexT*>(indices),
Expand Down
12 changes: 11 additions & 1 deletion cpp/tests/wholememory_ops/wholememory_gather_tests.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -311,6 +311,16 @@ INSTANTIATE_TEST_SUITE_P(
.set_embedding_dim(11)
.set_embedding_stride(12)
.set_indices_count(100005),
WholeMemoryGatherTestParam()
.set_memory_type(WHOLEMEMORY_MT_CHUNKED)
.set_embedding_dim(1)
.set_embedding_stride(1)
.set_indices_count(100005),
WholeMemoryGatherTestParam()
.set_memory_type(WHOLEMEMORY_MT_CHUNKED)
.set_embedding_dim(1)
.set_embedding_stride(2)
.set_indices_count(100005),
WholeMemoryGatherTestParam()
.set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED)
.set_embedding_dim(11)
Expand Down

0 comments on commit 7352f1c

Please sign in to comment.