reduce size of beam_search topk
yufenglee committed Feb 4, 2023
1 parent fefd5d1 commit f9da238
Showing 3 changed files with 30 additions and 82 deletions.
53 changes: 2 additions & 51 deletions onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu
@@ -291,7 +291,7 @@ void LaunchBatchTopKKernel(const T* topk_scores,
                            int32_t num_beams,
                            int32_t k,
                            cudaStream_t stream) {
-  ORT_ENFORCE(k <= 256, "LaunchBatchTopKKernel doesn't support k > 256");
+  ORT_ENFORCE(k <= 64, "LaunchBatchTopKKernel doesn't support k > 64");
 
 #define BatchTopKKernelLauncher(K)                                          \
   BatchTopKKernel<T, I, K, 32><<<batch_size, 32, 0, stream>>>(topk_scores,  \
@@ -311,12 +311,8 @@ void LaunchBatchTopKKernel(const T* topk_scores,
     BatchTopKKernelLauncher(16);
   } else if (k <= 32) {
     BatchTopKKernelLauncher(32);
-  } else if (k <= 64) {
-    BatchTopKKernelLauncher(64);
-  } else if (k <= 128) {
-    BatchTopKKernelLauncher(128);
   } else {
-    BatchTopKKernelLauncher(256);
+    BatchTopKKernelLauncher(64);
   }
 }
 
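Note on the change above: BatchTopKKernel takes K as a template parameter, so the launcher rounds the runtime k up to the next supported power of two, and every supported value costs a separate compiled instantiation. Capping k at 64 (beam search requests k = 2 * num_beams here, so this presumes num_beams never exceeds 32; that bound is an inference, not stated in the diff) lets the 128 and 256 variants be dropped. A minimal sketch of the same compile-time-K dispatch pattern, with hypothetical names rather than the ORT kernels:

// topk_dispatch_sketch.cu: compile-time-K dispatch in a nutshell.
// TopKSketchKernel and DispatchTopK are hypothetical, not ORT code.
#include <cstdint>
#include <cuda_runtime.h>

template <int K>
__global__ void TopKSketchKernel(const float* scores, int32_t n) {
  // A real kernel would keep K running best candidates per block in
  // registers/shared memory; because K is a template parameter, the
  // compiler emits one specialized binary per supported K.
}

void DispatchTopK(const float* scores, int32_t n, int k, cudaStream_t stream) {
  // Round k up to the next supported power of two, as the launcher above
  // does; anything past the 64 cap is a usage error after this commit.
  if (k <= 4) {
    TopKSketchKernel<4><<<1, 32, 0, stream>>>(scores, n);
  } else if (k <= 8) {
    TopKSketchKernel<8><<<1, 32, 0, stream>>>(scores, n);
  } else if (k <= 16) {
    TopKSketchKernel<16><<<1, 32, 0, stream>>>(scores, n);
  } else if (k <= 32) {
    TopKSketchKernel<32><<<1, 32, 0, stream>>>(scores, n);
  } else {
    TopKSketchKernel<64><<<1, 32, 0, stream>>>(scores, n);
  }
}

Rounding up to powers of two keeps the number of instantiations logarithmic in the maximum supported k, which is why shrinking the cap from 256 to 64 removes exactly two kernel variants.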
@@ -330,36 +326,6 @@ template void LaunchBatchTopKKernel(const float* topk_scores,
                                     int32_t k,
                                     cudaStream_t stream);
 
-template void LaunchBatchTopKKernel(const float* topk_scores,
-                                    const int64_t* topk_tokens,
-                                    int32_t* next_indices,
-                                    int32_t* next_tokens,
-                                    float* next_scores,
-                                    int32_t batch_size,
-                                    int32_t num_beams,
-                                    int32_t k,
-                                    cudaStream_t stream);
-
-template void LaunchBatchTopKKernel(const half* topk_scores,
-                                    const int32_t* topk_tokens,
-                                    int32_t* next_indices,
-                                    int32_t* next_tokens,
-                                    half* next_scores,
-                                    int32_t batch_size,
-                                    int32_t num_beams,
-                                    int32_t k,
-                                    cudaStream_t stream);
-
-template void LaunchBatchTopKKernel(const half* topk_scores,
-                                    const int64_t* topk_tokens,
-                                    int32_t* next_indices,
-                                    int32_t* next_tokens,
-                                    half* next_scores,
-                                    int32_t batch_size,
-                                    int32_t num_beams,
-                                    int32_t k,
-                                    cudaStream_t stream);
-
 template <typename T>
 void BeamSearchTopK(
     const T* input,
@@ -426,21 +392,6 @@ template void BeamSearchTopK(
     int32_t* output_indices,
     cudaStream_t stream);
 
-template void BeamSearchTopK(
-    const half* input,
-    int32_t batch_size,
-    int32_t num_beams,
-    int32_t vocab_size,
-    int32_t k,
-    half* tmp_values_1st_stage,
-    int32_t* tmp_indices_1st_stage,
-    half* tmp_values_2st_stage,
-    int32_t* tmp_indices_2st_stage,
-    half* output_values,
-    int32_t* output_tokens,
-    int32_t* output_indices,
-    cudaStream_t stream);
-
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
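Most of the deletions in this file are explicit template instantiations that the commit leaves without callers: the float/int64_t, half/int32_t, and half/int64_t variants of LaunchBatchTopKKernel and the half variant of BeamSearchTopK. Each `template void ...;` line forces the compiler to emit full object code for that type combination whether or not anything calls it, so dropping unused ones directly shrinks the library. A small self-contained illustration of the mechanism (hypothetical names, not ORT code):

// instantiation_sketch.cu: explicit instantiation in a nutshell.
#include <cstdint>

template <typename T, typename I>
void Launch(const T* scores, const I* tokens) {
  // ... a kernel launch would go here ...
}

// Each line below forces the compiler to emit a full copy of Launch for
// that type pair, even if the program contains zero call sites for it:
template void Launch<float, int32_t>(const float*, const int32_t*);  // kept
// template void Launch<float, int64_t>(const float*, const int64_t*);  // dropped
// template void Launch<half, int32_t>(const half*, const int32_t*);    // dropped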
12 changes: 0 additions & 12 deletions onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h
@@ -11,18 +11,6 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
-template <typename T, typename I>
-void LaunchBatchTopKKernel(
-    const T* topk_scores,
-    const I* topk_indices,
-    int32_t* next_indices,
-    int32_t* next_tokens,
-    T* next_scores,
-    int32_t batch_size,
-    int32_t num_beams,
-    int32_t k,
-    cudaStream_t stream);
-
 template <typename T>
 void BeamSearchTopK(
     const T* input,
47 changes: 28 additions & 19 deletions onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -430,12 +430,16 @@ Status ProcessLogits(const OrtValue& logits, //
     dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
     dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
 #endif
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
+                                         beam_state->next_scores.data(),
+                                         beam_state->next_scores.size_bytes(),
+                                         cudaMemcpyDeviceToHost,
+                                         cuda_stream));
   } else {
     // Apply top-k selection like the following:
     //   next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
     //   next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
-    // int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size};
-    int64_t next_token_scores_dims[] = {batch_size * num_beams, vocab_size};
+    int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size};
 
     TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
     auto element_type = DataTypeImpl::GetType<float>();
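The dims swap above is the crux of this file's change: scores were previously shaped {batch_size * num_beams, vocab_size} with top-k taken per beam row, which required a second cross-beam reduction afterwards, whereas the new shape {batch_size, num_beams * vocab_size} lets a single top-(2 * num_beams) per batch entry do both jobs. The flat index that TopK returns then encodes beam and token together. A tiny worked example of that encoding (the numbers are illustrative, not from the source):

// flat_index_example.cu: host-side demo of the flat-index encoding
// implied by the new shape {batch_size, num_beams * vocab_size}.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t vocab_size = 50257;                // e.g. a GPT-2-sized vocabulary
  const int64_t flat_index = 3 * vocab_size + 42;  // candidate: token 42 from beam 3

  // Same arithmetic as the torch-style comment in the diff:
  //   next_indices = (next_tokens / vocab_size).long()
  //   next_tokens  = next_tokens % vocab_size
  std::printf("beam=%lld token=%lld\n",
              (long long)(flat_index / vocab_size),   // -> 3
              (long long)(flat_index % vocab_size));  // -> 42
  return 0;
}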
@@ -450,31 +454,36 @@ Status ProcessLogits(const OrtValue& logits, //
     constexpr bool sorted = true;  // results returned in sorted order.
 
     std::unique_ptr<Tensor> topk_scores = Tensor::CreateDefault();
-    std::unique_ptr<Tensor> topk_tokens = Tensor::CreateDefault();
+    std::unique_ptr<Tensor> topk_indices = Tensor::CreateDefault();
     ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, ort_stream, thread_pool,
-                             *topk_scores, *topk_tokens));
+                             *topk_scores, *topk_indices));
 
 #ifdef DEBUG_GENERATION
     dumper->Print("topk_scores", *(topk_scores.get()));
-    dumper->Print("topk_tokens", *(topk_tokens.get()));
+    dumper->Print("topk_indices", *(topk_indices.get()));
 #endif
 
+    // Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following:
+    //   next_indices = (next_tokens / vocab_size).long()
+    //   next_tokens = next_tokens % vocab_size
+    const int64_t* next_token_indices = topk_indices->Data<int64_t>();
+    cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
+                                batch_size, top_k, vocab_size, cuda_stream);
+
+    const float* data = topk_scores->Data<float>();
+#ifdef DEBUG_GENERATION
+    dumper->Print("next_scores before scorer", data, batch_size, top_k);
+    dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k);
+    dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
+#endif
+
-    cuda::LaunchBatchTopKKernel(topk_scores->Data<float>(),
-                                topk_tokens->Data<int64_t>(),
-                                beam_state->next_indices.data(),
-                                beam_state->next_tokens.data(),
-                                beam_state->next_scores.data(),
-                                batch_size,
-                                num_beams,
-                                2 * num_beams,
-                                cuda_stream);
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
+                                         data,
+                                         topk_scores->Shape().Size() * sizeof(float),
+                                         cudaMemcpyDeviceToHost,
+                                         cuda_stream));
   }
 
-  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
-                                       beam_state->next_scores.data(),
-                                       beam_state->next_scores.size_bytes(),
-                                       cudaMemcpyDeviceToHost,
-                                       cuda_stream));
   CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(),
                                        beam_state->next_tokens.data(),
                                        beam_state->next_tokens.size_bytes(),
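The body of LaunchNextTokenKernel is not part of this diff. Below is a minimal sketch of what such a kernel plausibly does, splitting each flat top-k index into a beam index and a token id on the GPU, assuming the argument order visible at the call site above; the "Sketch" names are hypothetical, not the ORT implementation:

// next_token_sketch.cu
#include <cstdint>
#include <cuda_runtime.h>

__global__ void NextTokenKernelSketch(const int64_t* flat_indices,
                                      int32_t* next_indices,  // beam per candidate
                                      int32_t* next_tokens,   // token per candidate
                                      int32_t total,          // batch_size * top_k
                                      int32_t vocab_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < total) {
    next_indices[i] = static_cast<int32_t>(flat_indices[i] / vocab_size);
    next_tokens[i] = static_cast<int32_t>(flat_indices[i] % vocab_size);
  }
}

void LaunchNextTokenKernelSketch(const int64_t* flat_indices, int32_t* next_indices,
                                 int32_t* next_tokens, int32_t batch_size,
                                 int32_t top_k, int32_t vocab_size,
                                 cudaStream_t stream) {
  // One thread per candidate; the work is a trivial div/mod, so a small
  // elementwise launch is enough.
  const int32_t total = batch_size * top_k;
  const int threads = 256;
  const int blocks = (total + threads - 1) / threads;
  NextTokenKernelSketch<<<blocks, threads, 0, stream>>>(
      flat_indices, next_indices, next_tokens, total, vocab_size);
}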
