From 0374411b534e51b949688c2fbffe18758864effe Mon Sep 17 00:00:00 2001
From: wangyems
Date: Thu, 8 Dec 2022 04:15:23 +0000
Subject: [PATCH] refactor

---
 .../transformers/generation_device_helper.cc | 10 ++-
 .../cpu/transformers/generation_shared.h     |  4 ++
 .../transformers/greedy_search_impl_base.h   |  6 +-
 .../cpu/transformers/sampling_cpu_helper.h   | 70 ++++++++++---------
 .../transformers/generation_device_helper.cc |  3 +-
 .../cuda/transformers/sampling_cuda_helper.h | 39 +++++++----
 6 files changed, 80 insertions(+), 52 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
index b902ee57a5467..634bfb4e12637 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
@@ -418,9 +418,6 @@ Status GreedySearchProcessLogits(
     int step,                                      // iteration counter
     void* stream,                                  // cuda stream (for CUDA only)
     const transformers::IConsoleDumper* dumper) {  // tensor dumper
-#ifndef DEBUG_GENERATION
-  ORT_UNUSED_PARAMETER(dumper);
-#endif
   int batch_size = parameters->batch_size;
   int vocab_size = parameters->vocab_size;
 
@@ -458,18 +455,18 @@ Status GreedySearchProcessLogits(
   dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size);
 #endif
 
-  constexpr unsigned top_k = 1;
-
   if (do_sampling) {
     SamplingCpuHelper::TopPSamplingCpu<T> top_p_sampler(allocator,
                                                         thread_pool,
                                                         sampling_state,
                                                         greedy_state,
-                                                        parameters);
+                                                        parameters,
+                                                        dumper);
     ORT_RETURN_IF_ERROR(top_p_sampler.Sample(next_token_scores));
     return Status::OK();
   }
 
+  // next_tokens = torch.argmax(scores, dim=-1)
   int64_t next_token_scores_dims[] = {static_cast<int64_t>(batch_size), vocab_size};
   TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
@@ -482,6 +479,7 @@ Status GreedySearchProcessLogits(
                         next_token_scores_value);
   const Tensor& input = next_token_scores_value.Get<Tensor>();
 
+  constexpr unsigned top_k = 1;
   constexpr int axis = 1;
   constexpr bool largest = true;
   constexpr bool sorted = false;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index d6d34f39beb3d..2aafbcbcda736 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -83,6 +83,10 @@ struct ISamplingState {
   BufferUniquePtr storage_buffer;
   size_t temp_storage_bytes;
   std::default_random_engine generator;
+
+  std::vector<T> sorted_scores;
+  std::vector<size_t> sorted_indices;
+  std::vector<T> cumulative_probs;
 };
 
 class ISequences {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
index 16b3d1f6b6ecc..d0acf6fef0ebd 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
@@ -38,7 +38,7 @@ struct SamplingState : public ISamplingState<T> {
     this->h_sampled_all = AllocateBuffer<float>(cpu_allocator, h_sampled_all_buffer_, SafeInt<size_t>(batch_size * max_iter));
     this->d_indices = AllocateBuffer<int32_t>(allocator, d_indices_buffer_, SafeInt<size_t>(batch_size));
     this->temp_storage_bytes = 0;
-    // TODO(wy): Do not allocate this buffer if there's no presence_mask
+    // TODO: Do not allocate this buffer if there's no presence_mask
     this->d_presence_mask = AllocateBuffer<int>(allocator, d_presence_mask_buffer_, SafeInt<size_t>(total_count));
 
     std::uniform_real_distribution<float> distribution(0.0, 1.0);
@@ -46,6 +46,10 @@ struct SamplingState : public ISamplingState<T> {
       for (size_t i = 0; i < this->h_sampled_all.size(); ++i) {
         this->h_sampled_all[i] = distribution(this->generator);
       }
+    } else {
+      this->sorted_scores.resize(total_count);
+      this->sorted_indices.resize(total_count);
+      this->cumulative_probs.resize(total_count);
     }
   }
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
index 5adcfe15a6922..f45f22561c359 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
@@ -13,12 +13,14 @@ class TopPSamplingCpu{
                   onnxruntime::concurrency::ThreadPool* thread_pool,
                   transformers::ISamplingState<T>* sampling_state,
                   transformers::IGreedySearchState<T>* greedy_state,
-                  const transformers::IGenerationParameters* parameters):
+                  const transformers::IGenerationParameters* parameters,
+                  const transformers::IConsoleDumper* dumper):
       allocator_(allocator),
       thread_pool_(thread_pool),
       sampling_state_(sampling_state),
       greedy_state_(greedy_state),
-      parameters_(parameters) {}
+      parameters_(parameters),
+      dumper_(dumper) {}
 
   Status Sample(gsl::span<T>& next_token_scores);
 
@@ -30,6 +32,7 @@ class TopPSamplingCpu{
   transformers::ISamplingState<T>* sampling_state_;
   transformers::IGreedySearchState<T>* greedy_state_;
   const transformers::IGenerationParameters* parameters_;
+  const transformers::IConsoleDumper* dumper_;
 };
 
 template <typename T>
@@ -46,40 +49,43 @@ Status TopPSamplingCpu<T>::Sample(gsl::span<T>& next_token_scores) {
     ORT_THROW("top_p shall be greater than 0");
   }
 
-  for (int i = 0; i < parameters_->batch_size; i++) {
-    gsl::span<T> next_token_score = next_token_scores.subspan(i * parameters_->vocab_size,
-                                                              parameters_->vocab_size);
-
-    // Copy the vector
-    std::vector<T> sorted_score(next_token_score.begin(), next_token_score.end());
-
-    // Decending sort
-    std::vector<size_t> sorted_indice(parameters_->vocab_size);
-    std::iota(sorted_indice.begin(), sorted_indice.end(), 0);
-    std::sort(sorted_indice.begin(),
-              sorted_indice.end(),
-              [&sorted_score](size_t i1, size_t i2) {
-                return sorted_score[i1] > sorted_score[i2];
+  std::vector<T>& sorted_scores = sampling_state_->sorted_scores;
+  sorted_scores.assign(next_token_scores.begin(), next_token_scores.end());
+  // Descending sort, one batch row at a time
+  std::vector<size_t>& sorted_indices = sampling_state_->sorted_indices;
+
+  for (size_t i = 0; i < static_cast<size_t>(parameters_->batch_size); i++) {
+    auto indices_begin = sorted_indices.begin() + i * parameters_->vocab_size;
+    auto indices_end = sorted_indices.begin() + (i + 1) * parameters_->vocab_size;
+    std::iota(indices_begin, indices_end, 0);
+    std::sort(indices_begin, indices_end,
+              [&next_token_scores, i, this](size_t i1, size_t i2) {
+                size_t offset = i * parameters_->vocab_size;
+                return next_token_scores[offset + i1] > next_token_scores[offset + i2];
               });
-    std::sort(sorted_score.begin(), sorted_score.end(), std::greater<T>());
 
-    std::vector<T> cumulative_prob(parameters_->vocab_size);
+    std::sort(sorted_scores.begin() + i * parameters_->vocab_size,
+              sorted_scores.begin() + (i + 1) * parameters_->vocab_size,
+              std::greater<T>());
+  }
+
+  std::vector<T>& cumulative_probs = sampling_state_->cumulative_probs;
 
-    // TODO: batch
-    ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(1,
-                                      parameters_->vocab_size,
-                                      sorted_score.data(),
-                                      cumulative_prob.data(),
-                                      false,
-                                      thread_pool_));
+  ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters_->batch_size,
+                                    parameters_->vocab_size,
+                                    sorted_scores.data(),
+                                    cumulative_probs.data(),
+                                    false,
+                                    thread_pool_));
-  if (cumulative_prob[0] > parameters_->top_p) {
-    filter_scores(sorted_indice, next_token_score, 1);
-  }
-  for (size_t i = 1; i < static_cast<size_t>(parameters_->vocab_size) - 1; i++) {
-    cumulative_prob[i] += cumulative_prob[i - 1];
-    if (cumulative_prob[i] > parameters_->top_p) {
-      filter_scores(sorted_indice, next_token_score, i + 1);
+  for (size_t i = 0; i < static_cast<size_t>(parameters_->batch_size); i++) {
+    size_t offset = i * parameters_->vocab_size;
+    if (cumulative_probs[offset] > parameters_->top_p) {
+      filter_scores(sorted_indices, next_token_scores, 1 + offset);
+    }
+    for (size_t j = 1; j < static_cast<size_t>(parameters_->vocab_size) - 1; j++) {
+      cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
+      if (cumulative_probs[j + offset] > parameters_->top_p) {
+        filter_scores(sorted_indices, next_token_scores, j + offset + 1);
+      }
     }
   }
@@ -132,7 +138,7 @@ Status TopPSamplingCpu<T>::Sample(gsl::span<T>& next_token_scores) {
 
   // TODO: update presense_mask()
 #ifdef DEBUG_GENERATION
-  dumper->Print("sampled_idx", *sampled_idx);
+  dumper_->Print("sampled_idx", *sampled_idx);
 #endif
 
   return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index dfca731f8148d..884f80ca3a1e4 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -569,7 +569,8 @@ Status GreedySearchProcessLogits(
                                                           cuda_stream,
                                                           sampling_state,
                                                           greedy_state,
-                                                          parameters);
+                                                          parameters,
+                                                          dumper);
     ORT_RETURN_IF_ERROR(top_p_sampler.Sample(step, next_token_scores));
 
     return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
index 18e3cb5c6da68..52939efc5d866 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
@@ -22,12 +22,14 @@ class TopPSamplingCuda{
                    cudaStream_t cuda_stream,
                    transformers::ISamplingState<T>* sampling_state,
                    transformers::IGreedySearchState<T>* greedy_state,
-                   const transformers::IGenerationParameters* parameters):
+                   const transformers::IGenerationParameters* parameters,
+                   const transformers::IConsoleDumper* dumper):
      allocator_(allocator),
      cuda_stream_(cuda_stream),
      sampling_state_(sampling_state),
      greedy_state_(greedy_state),
-     parameters_(parameters) {}
+     parameters_(parameters),
+     dumper_(dumper) {}
 
   Status Sample(int step, gsl::span<float>& next_token_scores);
 
@@ -37,6 +39,7 @@ class TopPSamplingCuda{
   transformers::ISamplingState<T>* sampling_state_;
   transformers::IGreedySearchState<T>* greedy_state_;
   const transformers::IGenerationParameters* parameters_;
+  const transformers::IConsoleDumper* dumper_;
 };
 
 template <typename T>
@@ -63,7 +66,7 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<float>& next_token_scores) {
                          cuda_stream_);
 
 #ifdef DEBUG_GENERATION
-  dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1);
+  dumper_->Print("d_offset_buffer", d_offset.data(), parameters_->batch_size + 1, 1);
 #endif
 
   void* temp_storage = allocator_->Alloc(sampling_state_->temp_storage_bytes);
@@ -75,7 +78,7 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<float>& next_token_scores) {
 
   gsl::span<int>& d_index_out = sampling_state_->d_index_out;
 
 #ifdef DEBUG_GENERATION
-  dumper->Print("temp_storage_bytes", temp_storage_bytes, true);
+  dumper_->Print("temp_storage_bytes", sampling_state_->temp_storage_bytes, true);
 #endif
 
   cuda::LaunchSortPairsDescending<CudaT>(storage_buffer.get(),
@@ -90,9 +93,12 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<float>& next_token_scores) {
                                          cuda_stream_);
 
 #ifdef DEBUG_GENERATION
-  dumper->Print("d_sorted_score_buffer", reinterpret_cast<float*>(d_sorted_score.data()), batch_size, vocab_size);
-  dumper->Print("d_index_buffer_in", d_index_in.data(), batch_size, vocab_size);
-  dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size);
+  dumper_->Print("d_sorted_score_buffer",
+                 reinterpret_cast<float*>(d_sorted_score.data()),
+                 parameters_->batch_size,
+                 parameters_->vocab_size);
+  dumper_->Print("d_index_buffer_in", d_index_in.data(), parameters_->batch_size, parameters_->vocab_size);
+  dumper_->Print("d_index_buffer_out", d_index_out.data(), parameters_->batch_size, parameters_->vocab_size);
 #endif
 
   gsl::span<float>& d_sorted_softmaxed_score = sampling_state_->d_sorted_softmaxed_score;
@@ -105,7 +111,10 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<float>& next_token_scores) {
                                     parameters_->batch_size);
 
 #ifdef DEBUG_GENERATION
-  dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), batch_size, vocab_size);
+  dumper_->Print("d_sorted_softmaxed_score_buffer",
+                 d_sorted_softmaxed_score.data(),
+                 parameters_->batch_size,
+                 parameters_->vocab_size);
 #endif
 
   cuda::LaunchFilterLogitsKernel<CudaT>(d_sorted_softmaxed_score.data(),
@@ -118,7 +127,10 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<float>& next_token_scores) {
                                         cuda_stream_);
 
 #ifdef DEBUG_GENERATION
-  dumper->Print("next_token_scores after filtering logits", reinterpret_cast<float*>(next_token_scores.data()), batch_size, vocab_size);
+  dumper_->Print("next_token_scores after filtering logits",
+                 reinterpret_cast<float*>(next_token_scores.data()),
+                 parameters_->batch_size,
+                 parameters_->vocab_size);
 #endif
 
   // TODO(wy): Can we only do softmax at the very beginning and sort the softmaxed scores.
@@ -132,7 +144,10 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<float>& next_token_scores) {
                                     parameters_->batch_size);
 
 #ifdef DEBUG_GENERATION
-  dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), batch_size, vocab_size);
+  dumper_->Print("d_softmaxed_score_buffer",
+                 d_softmaxed_score.data(),
+                 parameters_->batch_size,
+                 parameters_->vocab_size);
 #endif
 
   // Multinomial sampling
@@ -145,7 +160,7 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<float>& next_token_scores) {
                                        cuda_stream_));
 
 #ifdef DEBUG_GENERATION
-  dumper->Print("d_sampled", d_sampled.data(), batch_size, 1);
+  dumper_->Print("d_sampled", d_sampled.data(), parameters_->batch_size, 1);
 #endif
 
   gsl::span<int32_t>& d_indices = sampling_state_->d_indices;
@@ -159,7 +174,7 @@ Status TopPSamplingCuda<T>::Sample(int step, gsl::span<float>& next_token_scores) {
                                      cuda_stream_);
 
 #ifdef DEBUG_GENERATION
-  dumper->Print("d_indices", d_indices.data(), batch_size, 1);
+  dumper_->Print("d_indices", d_indices.data(), parameters_->batch_size, 1);
 #endif
 
   CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state_->next_tokens_cpu.data(),
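
For context, both helpers implement the same top-p (nucleus) sampling scheme: sort each batch row's scores in descending order, softmax them, accumulate the probabilities, mask every token past the point where the cumulative probability first exceeds top_p, and draw one token from the survivors. The following is a minimal, single-row C++ sketch of that logic for reference only; SampleTopP and its signature are illustrative names, not part of the ORT helpers touched by this patch.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Illustrative top-p (nucleus) sampling for one row of logits.
// Mirrors the helper's steps: descending sort -> softmax ->
// cumulative sum -> filter -> multinomial draw.
int SampleTopP(const std::vector<float>& logits, float top_p, std::mt19937& rng) {
  const size_t vocab = logits.size();

  // Indices sorted by descending logit (like sorted_indices above).
  std::vector<size_t> idx(vocab);
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(),
            [&](size_t a, size_t b) { return logits[a] > logits[b]; });

  // Softmax over the sorted scores (subtract the max for stability).
  std::vector<double> probs(vocab);
  const double max_logit = logits[idx[0]];
  double sum = 0.0;
  for (size_t i = 0; i < vocab; ++i) {
    probs[i] = std::exp(logits[idx[i]] - max_logit);
    sum += probs[i];
  }
  for (double& p : probs) p /= sum;

  // Keep the smallest prefix whose cumulative probability exceeds top_p.
  size_t kept = vocab;
  double cum = 0.0;
  for (size_t i = 0; i < vocab; ++i) {
    cum += probs[i];
    if (cum > top_p) { kept = i + 1; break; }
  }

  // Renormalize the kept prefix and draw one token from it.
  double kept_sum = 0.0;
  for (size_t i = 0; i < kept; ++i) kept_sum += probs[i];
  std::uniform_real_distribution<double> uniform(0.0, kept_sum);
  double r = uniform(rng);
  for (size_t i = 0; i < kept; ++i) {
    r -= probs[i];
    if (r <= 0.0) return static_cast<int>(idx[i]);
  }
  return static_cast<int>(idx[kept - 1]);  // numerical fallback
}

The committed code differs mainly in batching (one sort per row over the shared ISamplingState buffers, and one batched SoftmaxCPU call) and in the CUDA path, which performs the same steps with device kernels (cuda::LaunchSortPairsDescending, cuda::LaunchFilterLogitsKernel, and a multinomial sampling kernel) before copying the sampled tokens back to the host with cudaMemcpyAsync.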