diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu index 5c54c03a05d1a..dcbc733f2acb2 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu @@ -291,7 +291,7 @@ void LaunchBatchTopKKernel(const T* topk_scores, int32_t num_beams, int32_t k, cudaStream_t stream) { - ORT_ENFORCE(k <= 256, "LaunchBatchTopKKernel doesn't support k >= 256"); + ORT_ENFORCE(k <= 64, "LaunchBatchTopKKernel doesn't support k >= 64"); #define BatchTopKKernelLauncher(K) \ BatchTopKKernel<<>>(topk_scores, \ @@ -311,12 +311,8 @@ void LaunchBatchTopKKernel(const T* topk_scores, BatchTopKKernelLauncher(16); } else if (k <= 32) { BatchTopKKernelLauncher(32); - } else if (k <= 64) { - BatchTopKKernelLauncher(64); - } else if (k <= 128) { - BatchTopKKernelLauncher(128); } else { - BatchTopKKernelLauncher(256); + BatchTopKKernelLauncher(64); } } @@ -330,36 +326,6 @@ template void LaunchBatchTopKKernel(const float* topk_scores, int32_t k, cudaStream_t stream); -template void LaunchBatchTopKKernel(const float* topk_scores, - const int64_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - float* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - -template void LaunchBatchTopKKernel(const half* topk_scores, - const int32_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - half* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - -template void LaunchBatchTopKKernel(const half* topk_scores, - const int64_t* topk_tokens, - int32_t* next_indices, - int32_t* next_tokens, - half* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - template void BeamSearchTopK( const T* input, @@ -426,21 +392,6 @@ template void BeamSearchTopK( int32_t* output_indices, cudaStream_t stream); -template void BeamSearchTopK( - const half* input, - int32_t batch_size, - int32_t num_beams, - int32_t vocab_size, - int32_t k, - half* tmp_values_1st_stage, - int32_t* tmp_indices_1st_stage, - half* tmp_values_2st_stage, - int32_t* tmp_indices_2st_stage, - half* output_values, - int32_t* output_tokens, - int32_t* output_indices, - cudaStream_t stream); - } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h index 5e338b417e8a5..096448c002e36 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h @@ -11,18 +11,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template -void LaunchBatchTopKKernel( - const T* topk_scores, - const I* topk_indices, - int32_t* next_indices, - int32_t* next_tokens, - T* next_scores, - int32_t batch_size, - int32_t num_beams, - int32_t k, - cudaStream_t stream); - template void BeamSearchTopK( const T* input, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 1a5a9ac5d97b2..1c9013b092c4e 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -430,12 +430,16 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams); dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams); #endif + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), + beam_state->next_scores.data(), + beam_state->next_scores.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); } else { // Apply top-k selection like the following: // next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) // next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True) - // int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size}; - int64_t next_token_scores_dims[] = {batch_size * num_beams, vocab_size}; + int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size}; TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); auto element_type = DataTypeImpl::GetType(); @@ -450,31 +454,36 @@ Status ProcessLogits(const OrtValue& logits, // constexpr bool sorted = true; // results returned in sorted order. std::unique_ptr topk_scores = Tensor::CreateDefault(); - std::unique_ptr topk_tokens = Tensor::CreateDefault(); + std::unique_ptr topk_indices = Tensor::CreateDefault(); ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, ort_stream, thread_pool, - *topk_scores, *topk_tokens)); + *topk_scores, *topk_indices)); #ifdef DEBUG_GENERATION dumper->Print("topk_scores", *(topk_scores.get())); - dumper->Print("topk_tokens", *(topk_tokens.get())); + dumper->Print("topk_indices", *(topk_indices.get())); +#endif + + // Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following: + // next_indices = (next_tokens / vocab_size).long() + // next_tokens = next_tokens % vocab_size + const int64_t* next_token_indices = topk_indices->Data(); + cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(), + batch_size, top_k, vocab_size, cuda_stream); + + const float* data = topk_scores->Data(); +#ifdef DEBUG_GENERATION + dumper->Print("next_scores before scorer", data, batch_size, top_k); + dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k); + dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k); #endif - cuda::LaunchBatchTopKKernel(topk_scores->Data(), - topk_tokens->Data(), - beam_state->next_indices.data(), - beam_state->next_tokens.data(), - beam_state->next_scores.data(), - batch_size, - num_beams, - 2 * num_beams, - cuda_stream); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), + data, + topk_scores->Shape().Size() * sizeof(float), + cudaMemcpyDeviceToHost, + cuda_stream)); } - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), - beam_state->next_scores.data(), - beam_state->next_scores.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream)); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(), beam_state->next_tokens.data(), beam_state->next_tokens.size_bytes(),