From d4664e0b1a5449bda7abadddd38e9b539c506609 Mon Sep 17 00:00:00 2001 From: wangyems Date: Mon, 24 Oct 2022 20:27:48 +0000 Subject: [PATCH 01/51] logits wrapper cpu --- .../cpu/transformers/logits_processor.cc | 93 +++++++++++++++++++ .../cpu/transformers/logits_processor.h | 57 ++++++++++++ 2 files changed, 150 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 2f1e657c8e64a..2abc084ffa9b4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -6,8 +6,12 @@ #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/common/span_utils.h" +#include "core/providers/cpu/math/softmax_shared.h" #include "contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/dump_tensor.h" +#include +#include +#include namespace onnxruntime { namespace contrib { @@ -187,6 +191,91 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, #endif } +template +TemperatureLogitsProcessor::TemperatureLogitsProcessor(float temperature) : temperature_(temperature) { +} + +template +void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, + NextTokenScores& next_token_scores) { + if (temperature_ == 1.0f) { + return; + } + + T* p = next_token_scores.scores.data(); + for (int i = 0; i < next_token_scores.scores.size(); i++) { + *p /= temperature_; + } +} + +template +TopPLogitsProcessor::TopPLogitsProcessor(float top_p, float filter_value, int min_tokens_to_keep) + : top_p_(top_p), filter_value_(filter_value), min_tokens_to_keep_(min_tokens_to_keep) { +} + +template +void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, + NextTokenScores& next_token_scores) { + const int batch_beam_size = next_token_scores.batch_beam_size; + const int vocab_size = next_token_scores.vocab_size; + + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.GetScores(i); + + std::vector sorted_scores(beam_token_scores.begin(), beam_token_scores.end()); + + std::vector sorted_indices(beam_token_scores.size()); + std::iota(sorted_indices.begin(), sorted_indices.end(), 0); + std::sort(sorted_indices.begin(), + sorted_indices.end(), + [&sorted_scores](size_t i1, size_t i2) { + return sorted_scores[i1] < sorted_scores[i2]; + }); + + std::sort(sorted_scores.begin(), sorted_scores.end()); + std::vector cumulative_probs(vocab_size); + ORT_RETURN_IF_ERROR(SoftmaxCPU(1, // rows + vocab_size, // elements per row + sorted_scores.data(), + cumulative_probs.data(), + false, + thread_pool)); + + std::unordered_set sorted_indices_to_remove; + if (cumulative_probs[0] > top_p_) { + sorted_indices_to_remove.insert(1); + } + for (size_t j = 1; j < vocab_size - 1; j++) { + cumulative_probs[j] += cumulative_probs[j - 1]; + if (cumulative_probs[j] > top_p_) { + sorted_indices_to_remove.insert(j + 1); + } + } + + for (auto it = sorted_indices_to_remove.begin(); it != sorted_indices_to_remove.end(); ++it) { + size_t index_to_remove = sorted_indices[*it]; + beam_token_scores[index_to_remove] = filter_value_; + } + } +} + +template +PresencePenaltyLogitsProcessor::PresencePenaltyLogitsProcessor(const gsl::span& presence_mask, + float presence_penalty) + : presence_mask_(presence_mask), presence_penalty_(presence_penalty) { +} + +template +void PresencePenaltyLogitsProcessor::Process(const ISequences* sequences, + NextTokenScores& 
next_token_scores) { + assert(!presence_mask_.empty()); + + T* p = next_token_scores.scores.data(); + for (int i = 0; i < next_token_scores.scores.size(); i++) { + *p -= presence_mask_[i] * presence_penalty_; + } +} + void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { LogitsProcessorInitImpl(parameters); } @@ -195,6 +284,10 @@ void LogitsProcessorList::Init(const GreedySearchParameters& parameters) { LogitsProcessorInitImpl(parameters); } +void LogitsProcessorList::Init(const BeamSamplingParameters& parameters) { + LogitsProcessorInitImpl(parameters); +} + void LogitsProcessorList::Process(const ISequences* sequences, gsl::span& next_token_scores, int step) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 1a2fba19bfd3d..2255d455d51d0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -96,11 +96,52 @@ class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor { const int batch_size_; }; +template +class TemperatureLogitsProcessor : public ILogitsProcessor { + public: + TemperatureLogitsProcessor(int temperature); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + int temperature_; +}; + +template +class TopPLogitsProcessor : public ILogitsProcessor { + public: + TopPLogitsProcessor(float top_p, float filter_value, int min_tokens_to_keep); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + float top_p_; + float filter_value_; + int min_tokens_to_keep_; +}; + +template +class PresencePenaltyLogitsProcessor : public ILogitsProcessor { + public: + PresencePenaltyLogitsProcessor(const gsl::span& presence_mask, + float presence_penalty); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + gsl::span presence_mask_; + float presence_penalty_; +}; + class LogitsProcessorList : public ILogitsProcessorList { public: LogitsProcessorList() = default; void Init(const BeamSearchParameters& parameters); void Init(const GreedySearchParameters& parameters); + void Init(const BeamSamplingParameters& parameters); void Process(const ISequences* sequences, gsl::span& next_token_scores, int step); private: @@ -140,6 +181,18 @@ class LogitsProcessorList : public ILogitsProcessorList { processor_list_.push_back(min_length_processor_.get()); } + if (parameters.temperature > 0) { + temperature_processor_ = std::make_unique>(parameters.temperature); + processor_list_.push_back(temperature_processor_.get()); + } + + if (parameters.top_p > 0) { + top_p_processor_ = std::make_unique>(parameters.top_p, + parameters.filter_value, + parameters.min_tokens_to_keep); + processor_list_.push_back(top_p_processor_.get()); + } + batch_beam_size_ = parameters.BatchBeamSize(); vocab_size_ = parameters.vocab_size; } @@ -148,11 +201,15 @@ class LogitsProcessorList : public ILogitsProcessorList { int vocab_size_; InlinedVector*> processor_list_; + onnxruntime::concurrency::ThreadPool* thread_pool_; + std::unique_ptr> repetition_penalty_processor_; std::unique_ptr> no_repeat_ngram_processor_; std::unique_ptr> vocab_mask_processor_; std::unique_ptr> prefix_vocab_mask_processor_; std::unique_ptr> min_length_processor_; + std::unique_ptr> temperature_processor_; + std::unique_ptr> top_p_processor_; }; } // namespace transformers From 
c5fb195060edf1fcd8f113645cad93ede754fc95 Mon Sep 17 00:00:00 2001 From: wangyems Date: Mon, 24 Oct 2022 20:57:50 +0000 Subject: [PATCH 02/51] add sampling parameters --- .../cpu/transformers/logits_processor.cc | 12 ++++++++-- .../cpu/transformers/sampling_parameters.cc | 24 +++++++++++++++++++ .../cpu/transformers/sampling_parameters.h | 24 +++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 2abc084ffa9b4..6defafc086a75 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -216,6 +216,10 @@ TopPLogitsProcessor::TopPLogitsProcessor(float top_p, float filter_value, int template void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, NextTokenScores& next_token_scores) { + if (top_p_ == 0.0f) { + return; + } + const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; @@ -234,8 +238,8 @@ void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, std::sort(sorted_scores.begin(), sorted_scores.end()); std::vector cumulative_probs(vocab_size); - ORT_RETURN_IF_ERROR(SoftmaxCPU(1, // rows - vocab_size, // elements per row + ORT_RETURN_IF_ERROR(SoftmaxCPU(1, + vocab_size, sorted_scores.data(), cumulative_probs.data(), false, @@ -268,6 +272,10 @@ PresencePenaltyLogitsProcessor::PresencePenaltyLogitsProcessor(const gsl::spa template void PresencePenaltyLogitsProcessor::Process(const ISequences* sequences, NextTokenScores& next_token_scores) { + if (presence_penalty_ == 0.0f) { + return; + } + assert(!presence_mask_.empty()); T* p = next_token_scores.scores.data(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc new file mode 100644 index 0000000000000..2ec80917014ec --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "contrib_ops/cpu/transformers/sampling_parameters.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +constexpr int kMaxSequenceLength = 4096; + +void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) { + model_type = static_cast(info.GetAttrOrDefault("model_type", 0)); + eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); + pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); + decoder_start_token_id = static_cast(info.GetAttrOrDefault("decoder_start_token_id", -1)); + no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0)); + temperature = info.GetAttrOrDefault("temperature", 1.0f); + top_p = info.GetAttrOrDefault("top_p", 0.0f); + presence_penalty = info.GetAttrOrDefault("presence_penalty", 0.0f); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h new file mode 100644 index 0000000000000..729e54933c82f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/transformers/greedy_search_parameters.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +struct SamplingParameters : public GreadySearchParameters { + void ParseFromAttributes(const OpKernelInfo& info); + + float presence_penalty; + float temperature; + float top_p; + gsl::span presence_mask; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime From d664ac6880c4e50036edd18ac595349fc276be14 Mon Sep 17 00:00:00 2001 From: wangyems Date: Mon, 24 Oct 2022 22:58:58 +0000 Subject: [PATCH 03/51] register sampling cpu --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../cpu/transformers/beam_search.cc | 12 +- .../cpu/transformers/beam_search_impl_base.h | 2 +- .../transformers/beam_search_parameters.cc | 4 +- .../cpu/transformers/beam_search_parameters.h | 2 +- .../transformers/generation_device_helper.cc | 8 +- .../transformers/generation_device_helper.h | 8 +- .../cpu/transformers/generation_shared.h | 7 +- .../cpu/transformers/greedy_search.cc | 8 +- .../cpu/transformers/greedy_search.h | 3 +- .../transformers/greedy_search_impl_base.h | 24 ++-- .../cpu/transformers/greedy_search_impl_gpt.h | 38 +++--- .../cpu/transformers/logits_processor.cc | 41 ++++--- .../cpu/transformers/logits_processor.h | 29 +++-- .../contrib_ops/cpu/transformers/sampling.cc | 116 ++++++++++++++++++ .../contrib_ops/cpu/transformers/sampling.h | 30 +++++ .../cpu/transformers/sampling_parameters.cc | 1 + .../cpu/transformers/sampling_parameters.h | 7 +- .../transformers/generation_device_helper.cc | 12 +- .../transformers/generation_device_helper.h | 4 +- .../core/graph/contrib_ops/contrib_defs.cc | 29 +++++ 21 files changed, 290 insertions(+), 97 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/transformers/sampling.cc create mode 100644 onnxruntime/contrib_ops/cpu/transformers/sampling.h diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 3f459dff99a12..a04ef0d71b113 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ 
b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -18,6 +18,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); @@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index d4132675edcd3..225a667b73d32 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -60,11 +60,11 @@ void BeamSearch::Init(const OpKernelInfo& info) { parameters_.ParseFromAttributes(info); // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5) - ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt || - parameters_.model_type == IBeamSearchParameters::kModelTypeT5); + ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt || + parameters_.model_type == IGenerationParameters::kModelTypeT5); ONNX_NAMESPACE::GraphProto proto; - if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); } @@ -77,7 +77,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { const auto& node = Node(); - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { if (attribute_name == "decoder") { ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer()); @@ -88,7 +88,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, gpt_subgraph_->head_size, gpt_subgraph_->num_layers); } - } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { + } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { if (attribute_name == "encoder") { ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); @@ -135,7 +135,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { // Make a copy of parameters since we will update it based on inputs later BeamSearchParameters parameters = parameters_; - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32 
BeamSearchGpt impl{ *ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index f3764fcbfa1e3..863982972923f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -217,7 +217,7 @@ Status BeamSearchBase::Initialize() { if (!IsCuda()) { // Logits processor is used in CPU only. In CUDA, cuda kernels are used instead. // Initialize processors after CheckInputs so that parameters_->vocab_mask is ready. - logits_processors_.Init(*parameters_); + logits_processors_.Init(*parameters_, thread_pool_); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 0269efd319ee1..260cfc7ebf654 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -18,7 +18,7 @@ Status BeamSearchParameters::Validate() const { } void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { - model_type = static_cast(info.GetAttrOrDefault("model_type", IBeamSearchParameters::kModelTypeGpt)); + model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeGpt)); early_stopping = info.GetAttrOrDefault("early_stopping", 0) == 1; eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); @@ -35,7 +35,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { batch_size = static_cast(dims[0]); // For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1 - sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast(dims[1]) : 1; + sequence_length = (this->model_type == IGenerationParameters::kModelTypeGpt) ? static_cast(dims[1]) : 1; auto* max_length_tensor = context->Input(1); max_length = max_length_tensor ? 
static_cast(*max_length_tensor->Data()) : kMaxSequenceLength; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 1a3a87bd3f6aa..0cb2b39976cc3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace transformers { -struct BeamSearchParameters : public IBeamSearchParameters { +struct BeamSearchParameters : public IGenerationParameters { Status Validate() const; int BatchBeamSize() const { return batch_size * num_beams; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 6ce41ae86d86c..f853bd74a59be 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -242,7 +242,7 @@ Status ProcessLogits(const OrtValue& logits, // onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors transformers::IBeamScorer* beam_scorer, // beam scorer - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper @@ -403,7 +403,7 @@ Status GreedySearchProcessLogits( AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper @@ -821,7 +821,7 @@ template Status ProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, transformers::IBeamScorer* beam_scorer, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, void* stream, const transformers::IConsoleDumper* dumper); @@ -833,7 +833,7 @@ template Status GreedySearchProcessLogits( AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, void* stream, const transformers::IConsoleDumper* dumper); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index ac64dc9cf598f..e4c571dfd2b51 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -82,7 +82,7 @@ using ProcessLogitsFunc = std::function; // tensor dumper @@ -95,7 +95,7 @@ using GreedySearchProcessLogitsFunc = std::function; // tensor dumper @@ -199,7 +199,7 @@ Status ProcessLogits(const OrtValue& logits, // onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) 
transformers::ILogitsProcessorList* logits_processors, // logits processors transformers::IBeamScorer* beam_scorer, // beam scorer - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper @@ -211,7 +211,7 @@ Status GreedySearchProcessLogits(const OrtValue& logits, AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index edbebcc81ad9b..5e5accdc5b962 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -96,7 +96,7 @@ class IBeamScorer { Tensor* output_sequence_scores) = 0; }; -struct IBeamSearchParameters { +struct IGenerationParameters { static constexpr int kModelTypeGpt = 0; static constexpr int kModelTypeT5 = 1; @@ -107,6 +107,10 @@ struct IBeamSearchParameters { int decoder_start_token_id; int no_repeat_ngram_size; bool early_stopping; + float presence_penalty; + float temperature; + float top_p; + float filter_value; // Parameters from inputs int min_length; @@ -120,6 +124,7 @@ struct IBeamSearchParameters { gsl::span vocab_mask; gsl::span prefix_vocab_mask; + gsl::span presence_mask; // Parameters from outputs. 
bool output_scores; // whether scores existed in output diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc index 0299912cab8c2..2fc97a8f3347e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc @@ -73,7 +73,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat const std::string& attribute_name, const SessionState& subgraph_session_state) { const auto& node = Node(); - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { // GPT-2 + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2 if (attribute_name == "decoder") { ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); @@ -85,7 +85,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat gpt_subgraph_->head_size, gpt_subgraph_->num_layers); } - } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { // encoder-decoder like T5 + } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5 ORT_THROW("Not Implemented"); // if (attribute_name == "encoder") { // ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, @@ -130,7 +130,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const { if (parameters_.model_type == 0) { // GPT-2 // Subgraph has constraint that the output is either float or float16 if (!gpt_subgraph_->IsOutputFloat16()) { - GreedySearchGpt impl{ + GreedySearchGpt impl{ *ctx_internal, *decoder_session_state, *gpt_subgraph_, @@ -149,7 +149,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const { return impl.Execute(*decoder_feeds_fetches_manager_); } else { - GreedySearchGpt impl{ + GreedySearchGpt impl{ *ctx_internal, *decoder_session_state, *gpt_subgraph_, diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h index 281834d9a9f29..a545c68b1ef48 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h @@ -21,7 +21,6 @@ namespace transformers { using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel -// bugbug: refactor class GreedySearch : public IControlFlowKernel { public: explicit GreedySearch(const OpKernelInfo& info) @@ -70,7 +69,7 @@ class GreedySearch : public IControlFlowKernel { update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; } - private: + protected: // Device specific functions GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::TopkFunc topk_func_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 6c54fc7198cf1..8978258961db6 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -68,7 +68,7 @@ struct GreedySearchState : public IGreedySearchState { }; // Base class of gready search implementation that is common for both GPT-2 and Bart/T5. 
-template +template class GreedySearchBase : public GenerateBase { public: GreedySearchBase(OpKernelContextInternal& context, @@ -76,7 +76,7 @@ class GreedySearchBase : public GenerateBase { concurrency::ThreadPool* thread_pool, void* cuda_stream, IConsoleDumper* cuda_dumper, - GreedySearchParameters& params, + ParametersT& params, const GenerationDeviceHelper::TopkFunc& topk_func, const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func) @@ -114,14 +114,14 @@ class GreedySearchBase : public GenerateBase { AllocatorPtr& allocator, int counter); - GreedySearchParameters* parameters_; + ParametersT* parameters_; // Device specific functions GenerationDeviceHelper::GreedySearchProcessLogitsFunc process_logits_func_; }; -template -Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context) { +template +Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context) { // Input shapes: // input_ids : (batch_size, sequence_length) // vocab_mask : (vocab_size) or nullptr @@ -134,8 +134,8 @@ Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context) return Status::OK(); } -template -Status GreedySearchBase::Initialize() { +template +Status GreedySearchBase::Initialize() { ORT_RETURN_IF_ERROR(this->context_.GetTempSpaceAllocator(&this->temp_space_allocator_)); ORT_RETURN_IF_ERROR(CheckScalarInput("max_length", 1, true)); @@ -149,14 +149,14 @@ Status GreedySearchBase::Initialize() { if (!this->IsCuda()) { // Logits processor is used in CPU only. In CUDA, cuda kernels are used instead. // Initialize processors after CheckInputs so that parameters_->vocab_mask is ready. - this->logits_processors_.Init(*parameters_); + this->logits_processors_.Init(*parameters_, thread_pool_); } return Status::OK(); } -template -Status GreedySearchBase::ProcessLogits( +template +Status GreedySearchBase::ProcessLogits( const OrtValue& logits, GreedySearchState& greedy_state, AllocatorPtr& allocator, @@ -166,8 +166,8 @@ Status GreedySearchBase::ProcessLogits( parameters_, counter, this->cuda_stream_, this->GetConsoleDumper()); } -template -Status GreedySearchBase::GenerateNextToken( +template +Status GreedySearchBase::GenerateNextToken( const OrtValue& logits, gsl::span& next_tokens, GreedySearchState& greedy_state, diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index cbbcf29672add..26d19a4fbdc84 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -14,8 +14,8 @@ namespace contrib { namespace transformers { // Beam search implementation for GPT-2 model. 
-template -class GreedySearchGpt : public GreedySearchBase { +template +class GreedySearchGpt : public GreedySearchBase { public: GreedySearchGpt(OpKernelContextInternal& context, const SessionState& decoder_session_state, @@ -23,7 +23,7 @@ class GreedySearchGpt : public GreedySearchBase { concurrency::ThreadPool* thread_pool, void* cuda_stream, IConsoleDumper* cuda_dumper, - GreedySearchParameters& params, + ParametersT& params, const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func, const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, const GenerationDeviceHelper::TopkFunc& topk_func, @@ -31,15 +31,15 @@ class GreedySearchGpt : public GreedySearchBase { const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func) - : GreedySearchBase(context, - decoder_session_state, - thread_pool, - cuda_stream, - cuda_dumper, - params, - topk_func, - process_logits_func, - device_copy_func), + : GreedySearchBase(context, + decoder_session_state, + thread_pool, + cuda_stream, + cuda_dumper, + params, + topk_func, + process_logits_func, + device_copy_func), gpt_subgraph_(gpt_subgraph), create_inputs_func_(create_inputs_func), add_to_feeds_func_(add_to_feeds_func), @@ -76,8 +76,8 @@ class GreedySearchGpt : public GreedySearchBase { GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_; }; -template -Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths, +template +Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths, OrtValue& expanded_input_ids, std::vector& feeds, IAllocatorUniquePtr& buffer) { @@ -97,8 +97,8 @@ Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengt buffer); } -template -Status GreedySearchGpt::UpdateFeeds( +template +Status GreedySearchGpt::UpdateFeeds( const std::vector& last_outputs, std::vector& next_inputs, int current_length, @@ -120,10 +120,10 @@ Status GreedySearchGpt::UpdateFeeds( gpt_subgraph_.GetFirstPresentOutputIndex()); } -template -Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds_fetches_manager) { +template +Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds_fetches_manager) { auto status = Status::OK(); - const GreedySearchParameters* parameters = this->parameters_; + const ParametersT* parameters = this->parameters_; // Allocate output tensors. 
int64_t sequences_dims[] = {parameters->batch_size, parameters->max_length}; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 6defafc086a75..26b597b09c6d1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -203,14 +203,15 @@ void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, } T* p = next_token_scores.scores.data(); - for (int i = 0; i < next_token_scores.scores.size(); i++) { + for (size_t i = 0; i < next_token_scores.scores.size(); i++) { *p /= temperature_; } } template -TopPLogitsProcessor::TopPLogitsProcessor(float top_p, float filter_value, int min_tokens_to_keep) - : top_p_(top_p), filter_value_(filter_value), min_tokens_to_keep_(min_tokens_to_keep) { +TopPLogitsProcessor::TopPLogitsProcessor(float top_p, float filter_value, + onnxruntime::concurrency::ThreadPool* thread_pool) + : top_p_(top_p), filter_value_(filter_value), thread_pool_(thread_pool) { } template @@ -238,18 +239,19 @@ void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, std::sort(sorted_scores.begin(), sorted_scores.end()); std::vector cumulative_probs(vocab_size); - ORT_RETURN_IF_ERROR(SoftmaxCPU(1, - vocab_size, - sorted_scores.data(), - cumulative_probs.data(), - false, - thread_pool)); + // bugbug + ORT_UNUSED_PARAMETER(SoftmaxCPU(1, + vocab_size, + sorted_scores.data(), + cumulative_probs.data(), + false, + thread_pool_)); std::unordered_set sorted_indices_to_remove; if (cumulative_probs[0] > top_p_) { sorted_indices_to_remove.insert(1); } - for (size_t j = 1; j < vocab_size - 1; j++) { + for (size_t j = 1; j < static_cast(vocab_size) - 1; j++) { cumulative_probs[j] += cumulative_probs[j - 1]; if (cumulative_probs[j] > top_p_) { sorted_indices_to_remove.insert(j + 1); @@ -270,7 +272,7 @@ PresencePenaltyLogitsProcessor::PresencePenaltyLogitsProcessor(const gsl::spa } template -void PresencePenaltyLogitsProcessor::Process(const ISequences* sequences, +void PresencePenaltyLogitsProcessor::Process(const ISequences*, NextTokenScores& next_token_scores) { if (presence_penalty_ == 0.0f) { return; @@ -279,21 +281,24 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences* sequences, assert(!presence_mask_.empty()); T* p = next_token_scores.scores.data(); - for (int i = 0; i < next_token_scores.scores.size(); i++) { + for (size_t i = 0; i < next_token_scores.scores.size(); i++) { *p -= presence_mask_[i] * presence_penalty_; } } -void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { - LogitsProcessorInitImpl(parameters); +void LogitsProcessorList::Init(const BeamSearchParameters& parameters, + onnxruntime::concurrency::ThreadPool* thread_pool) { + LogitsProcessorInitImpl(parameters, thread_pool); } -void LogitsProcessorList::Init(const GreedySearchParameters& parameters) { - LogitsProcessorInitImpl(parameters); +void LogitsProcessorList::Init(const GreedySearchParameters& parameters, + onnxruntime::concurrency::ThreadPool* thread_pool) { + LogitsProcessorInitImpl(parameters, thread_pool); } -void LogitsProcessorList::Init(const BeamSamplingParameters& parameters) { - LogitsProcessorInitImpl(parameters); +void LogitsProcessorList::Init(const SamplingParameters& parameters, + onnxruntime::concurrency::ThreadPool* thread_pool) { + LogitsProcessorInitImpl(parameters, thread_pool); } void LogitsProcessorList::Process(const ISequences* sequences, diff --git 
a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 2255d455d51d0..0005ed0016857 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -7,6 +7,7 @@ #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_parameters.h" #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" +#include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" namespace onnxruntime { @@ -99,7 +100,7 @@ class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor { template class TemperatureLogitsProcessor : public ILogitsProcessor { public: - TemperatureLogitsProcessor(int temperature); + TemperatureLogitsProcessor(float temperature); void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override; @@ -111,7 +112,8 @@ class TemperatureLogitsProcessor : public ILogitsProcessor { template class TopPLogitsProcessor : public ILogitsProcessor { public: - TopPLogitsProcessor(float top_p, float filter_value, int min_tokens_to_keep); + TopPLogitsProcessor(float top_p, float filter_value, + onnxruntime::concurrency::ThreadPool* thread_pool); void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override; @@ -120,6 +122,7 @@ class TopPLogitsProcessor : public ILogitsProcessor { float top_p_; float filter_value_; int min_tokens_to_keep_; + onnxruntime::concurrency::ThreadPool* thread_pool_; }; template @@ -139,14 +142,15 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { class LogitsProcessorList : public ILogitsProcessorList { public: LogitsProcessorList() = default; - void Init(const BeamSearchParameters& parameters); - void Init(const GreedySearchParameters& parameters); - void Init(const BeamSamplingParameters& parameters); + void Init(const BeamSearchParameters& parameters, onnxruntime::concurrency::ThreadPool* thread_pool); + void Init(const GreedySearchParameters& parameters, onnxruntime::concurrency::ThreadPool* thread_pool); + void Init(const SamplingParameters& parameters, onnxruntime::concurrency::ThreadPool* thread_pool); void Process(const ISequences* sequences, gsl::span& next_token_scores, int step); private: template - void LogitsProcessorInitImpl(const GenerationParametersT& parameters) { + void LogitsProcessorInitImpl(const GenerationParametersT& parameters, + onnxruntime::concurrency::ThreadPool* thread_pool) { processor_list_.clear(); if (parameters.repetition_penalty != 1.0f) { // 1.0 means no penalty @@ -189,10 +193,18 @@ class LogitsProcessorList : public ILogitsProcessorList { if (parameters.top_p > 0) { top_p_processor_ = std::make_unique>(parameters.top_p, parameters.filter_value, - parameters.min_tokens_to_keep); + thread_pool); processor_list_.push_back(top_p_processor_.get()); } + if (!parameters.presence_mask.empty()) { + presence_penalty_processor_ = std::make_unique< + PresencePenaltyLogitsProcessor + >(parameters.presence_mask, + parameters.presence_penalty); + processor_list_.push_back(presence_penalty_processor_.get()); + } + batch_beam_size_ = parameters.BatchBeamSize(); vocab_size_ = parameters.vocab_size; } @@ -201,8 +213,6 @@ class LogitsProcessorList : public ILogitsProcessorList { int vocab_size_; InlinedVector*> processor_list_; - onnxruntime::concurrency::ThreadPool* thread_pool_; - std::unique_ptr> repetition_penalty_processor_; 
std::unique_ptr> no_repeat_ngram_processor_; std::unique_ptr> vocab_mask_processor_; @@ -210,6 +220,7 @@ class LogitsProcessorList : public ILogitsProcessorList { std::unique_ptr> min_length_processor_; std::unique_ptr> temperature_processor_; std::unique_ptr> top_p_processor_; + std::unique_ptr> presence_penalty_processor_; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc new file mode 100644 index 0000000000000..09d831ba38ff4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// there's no way to use a raw pointer as the copy destination with std::copy_n +// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset +// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include +#include +#include +#include "core/common/safeint.h" +#include "core/providers/cpu/math/top_k.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/framework/allocator.h" +#include "core/framework/framework_common.h" +#include "core/framework/feeds_fetches_manager.h" +#include "core/framework/op_kernel_context_internal.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/framework/session_options.h" +#include "core/framework/TensorSeq.h" +#include "core/framework/ort_value.h" +#include "gsl/gsl" +#include "contrib_ops/cpu/transformers/sampling.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/sequences.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" +#include "contrib_ops/cpu/transformers/greedy_search_impl_gpt.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Sampling, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + transformers::Sampling); + +REGISTER_KERNEL_TYPED(float) + +namespace transformers { + +Status Sampling::Compute(OpKernelContext* ctx) const { + auto* ctx_internal = static_cast(ctx); + + auto* decoder_session_state = ctx_internal->SubgraphSessionState("decoder"); + ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute."); + ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); + + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + // make a copy since we will update the parameters based on inputs later + SamplingParameters parameters = parameters_; + + if (parameters_.model_type == 0) { // GPT-2 + // Subgraph has constraint that the output is either float or float16 + if (!gpt_subgraph_->IsOutputFloat16()) { + GreedySearchGpt impl{ + *ctx_internal, + *decoder_session_state, + *gpt_subgraph_, + thread_pool, + cuda_stream_, + dumper_, + parameters, + GenerationCpuDeviceHelper::CreateGptInputs, + add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, + topk_func_ ? 
topk_func_ : GenerationCpuDeviceHelper::TopK, + process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::GreedySearchProcessLogits, + init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState, + device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, + update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; + ORT_RETURN_IF_ERROR(impl.Initialize()); + + return impl.Execute(*decoder_feeds_fetches_manager_); + } else { + GreedySearchGpt impl{ + *ctx_internal, + *decoder_session_state, + *gpt_subgraph_, + thread_pool, + cuda_stream_, + dumper_, + parameters, + GenerationCpuDeviceHelper::CreateGptInputs, + add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, + topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, + process_logits_fp16_func_, + init_greedy_state_fp16_func_, + device_copy_func_, + update_gpt_feeds_fp16_func_}; + ORT_RETURN_IF_ERROR(impl.Initialize()); + + return impl.Execute(*decoder_feeds_fetches_manager_); + } + } + + return Status::OK(); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.h b/onnxruntime/contrib_ops/cpu/transformers/sampling.h new file mode 100644 index 0000000000000..fbc1cc5d7adc3 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include "core/common/common.h" +#include "contrib_ops/cpu/transformers/greedy_search.h" +#include "contrib_ops/cpu/transformers/sampling_parameters.h" +#include "contrib_ops/cpu/transformers/subgraph_gpt.h" +#include "contrib_ops/cpu/transformers/generation_device_helper.h" + + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +class Sampling : public GreedySearch { + public: + explicit Sampling(const OpKernelInfo& info) : GreedySearch(info) {} + + Status Compute(OpKernelContext* ctx) const override; + + private: + SamplingParameters parameters_; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc index 2ec80917014ec..51153c33912da 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -16,6 +16,7 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) { no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0)); temperature = info.GetAttrOrDefault("temperature", 1.0f); top_p = info.GetAttrOrDefault("top_p", 0.0f); + filter_value = info.GetAttrOrDefault("filter_value", -std::numeric_limits::infinity()); presence_penalty = info.GetAttrOrDefault("presence_penalty", 0.0f); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h index 729e54933c82f..6c0f866f09fe2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h @@ -10,13 +10,8 @@ namespace onnxruntime { namespace contrib { namespace transformers { -struct SamplingParameters : public GreadySearchParameters { +struct SamplingParameters : public 
GreedySearchParameters { void ParseFromAttributes(const OpKernelInfo& info); - - float presence_penalty; - float temperature; - float top_p; - gsl::span presence_mask; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 46e4066d7e286..257efaa9c0eb6 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -212,7 +212,7 @@ Status ProcessLogits(const OrtValue& logits, // onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors transformers::IBeamScorer* beam_scorer, // beam scorer - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper @@ -456,7 +456,7 @@ Status GreedySearchProcessLogits( AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper @@ -885,7 +885,7 @@ template Status ProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, transformers::IBeamScorer* beam_scorer, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, void* stream, const transformers::IConsoleDumper* dumper); @@ -897,7 +897,7 @@ template Status GreedySearchProcessLogits( AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, void* stream, const transformers::IConsoleDumper* dumper); @@ -950,7 +950,7 @@ template Status ProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, transformers::IBeamScorer* beam_scorer, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, void* stream, const transformers::IConsoleDumper* dumper); @@ -962,7 +962,7 @@ template Status GreedySearchProcessLogits( AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, void* stream, const transformers::IConsoleDumper* dumper); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index e0e5890d0b36c..9c9b5fd90ec51 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -54,7 +54,7 @@ Status ProcessLogits(const 
OrtValue& logits, // onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors transformers::IBeamScorer* beam_scorer, // beam scorer - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper @@ -66,7 +66,7 @@ Status GreedySearchProcessLogits(const OrtValue& logits, AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5d577ebb3be1e..cb6ff56a32ba1 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1106,6 +1106,35 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, GreedySearchShapeInference(ctx); })); +ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, + OpSchema() + .SetDoc("Greedy Sampling for text generation.") + .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) + .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) + .Attr("temperature", "temperature for sampling", AttributeProto::FLOAT, 1.0f) + .Attr("top_p", "top_p for sampling", AttributeProto::FLOAT, 0.0f) + .Attr("filter_value", "filter value for top_p", AttributeProto::FLOAT, -1e20f) + .Attr("presence_penalty", "presence penalty for sampling", AttributeProto::FLOAT, 0.0f) + .Attr("model_type", "model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast(0)) + .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) + .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") + .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") + .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) + .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. 
Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) + .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) + .Input(7, "presence_penalty_mask", "presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") + .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + GreedySearchShapeInference(ctx); + })); + ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") From 3331c8f12d7c46e96c2a2b7f7cf6700e4b61395a Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 25 Oct 2022 00:36:48 +0000 Subject: [PATCH 04/51] add multinomial cpu --- .../transformers/generation_device_helper.cc | 108 ++++++++++++------ .../transformers/generation_device_helper.h | 2 + .../cpu/transformers/generation_shared.h | 1 + .../transformers/greedy_search_impl_base.h | 7 +- .../contrib_ops/cpu/transformers/sampling.cc | 15 --- .../contrib_ops/cpu/transformers/sampling.h | 6 - .../transformers/generation_device_helper.cc | 5 +- .../transformers/generation_device_helper.h | 1 + .../core/providers/cpu/generator/random.cc | 38 ++++-- .../core/providers/cpu/generator/random.h | 9 ++ 10 files changed, 125 insertions(+), 67 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index f853bd74a59be..d3852f89bc447 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -6,6 +6,7 @@ #include #include "core/providers/cpu/math/top_k.h" #include "core/providers/cpu/math/softmax_shared.h" +#include "core/providers/cpu/generator/random.h" #include "core/common/safeint.h" #include "core/common/gsl.h" #include "contrib_ops/cpu/transformers/sequences.h" @@ -404,6 +405,7 @@ Status GreedySearchProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper @@ -447,44 +449,81 @@ Status GreedySearchProcessLogits( dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size); #endif - // next_tokens = torch.argmax(scores, dim=-1) - int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; - TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); - auto element_type = DataTypeImpl::GetType(); - OrtValue next_token_scores_value; - Tensor::InitOrtValue(element_type, - next_token_scores_shape, - next_token_scores.data(), - allocator->Info(), - next_token_scores_value); - const Tensor& input = next_token_scores_value.Get(); - - constexpr int axis = 1; - constexpr unsigned top_k = 1; - constexpr bool largest = true; - constexpr bool sorted = false; - - Tensor topk_scores; - Tensor topk_indices; - ORT_RETURN_IF_ERROR( - TopK(&input, - axis, - top_k, - largest, - sorted, - allocator, - stream, - 
thread_pool, - topk_scores, - topk_indices)); + if (do_sampling) { + // bugbug: is this Softmax really needed? + gsl::span& next_token_probs = greedy_state->next_token_probs; + ORT_RETURN_IF_ERROR(SoftmaxCPU(batch_size, + vocab_size, + next_token_scores.data(), + next_token_probs.data(), + false, + thread_pool)); + + // torch.multinomial() + int64_t next_token_probs_dims[] = {static_cast(batch_size), vocab_size}; + TensorShape next_token_probs_shape(&next_token_probs_dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue next_token_probs_value; + Tensor::InitOrtValue(element_type, + next_token_probs_shape, + next_token_probs.data(), + allocator->Info(), + next_token_probs_value); + const Tensor& input = next_token_probs_value.Get(); + + float seed = 0.f; + std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(seed)}; + Tensor sampled_idx; + ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocator, + input, + batch_size, + vocab_size, + 1, + generator, + sampled_idx)); + + gsl::span next_token_indices = sampled_idx.DataAsSpan(); + gsl::copy(next_token_indices, greedy_state->next_tokens_cpu); + } else { + // next_tokens = torch.argmax(scores, dim=-1) + int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; + TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue next_token_scores_value; + Tensor::InitOrtValue(element_type, + next_token_scores_shape, + next_token_scores.data(), + allocator->Info(), + next_token_scores_value); + const Tensor& input = next_token_scores_value.Get(); + + constexpr int axis = 1; + constexpr unsigned top_k = 1; + constexpr bool largest = true; + constexpr bool sorted = false; + + Tensor topk_scores; + Tensor topk_indices; + ORT_RETURN_IF_ERROR(TopK(&input, + axis, + top_k, + largest, + sorted, + allocator, + stream, + thread_pool, + topk_scores, + topk_indices)); #ifdef DEBUG_GENERATION - dumper->Print("topk_scores", topk_scores); - dumper->Print("topk_indices", topk_indices); + dumper->Print("topk_scores", topk_scores); + dumper->Print("topk_indices", topk_indices); #endif - gsl::span next_token_indices = topk_indices.DataAsSpan(); - gsl::copy(next_token_indices, greedy_state->next_tokens_cpu); + gsl::span next_token_indices = topk_indices.DataAsSpan(); + gsl::copy(next_token_indices, greedy_state->next_tokens_cpu); + } + #ifdef DEBUG_GENERATION gsl::span next_tokens(greedy_state->next_tokens_cpu.data(), @@ -834,6 +873,7 @@ template Status GreedySearchProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, const transformers::IGenerationParameters* parameters, + bool do_sampling, int step, void* stream, const transformers::IConsoleDumper* dumper); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index e4c571dfd2b51..e2d46ce244eea 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -96,6 +96,7 @@ using GreedySearchProcessLogitsFunc = std::function; // tensor dumper @@ -212,6 +213,7 @@ Status GreedySearchProcessLogits(const OrtValue& logits, onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors const transformers::IGenerationParameters* parameters, // parameters + 
bool do_sampling, // whether to do sampling int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 5e5accdc5b962..ebf18152ab584 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -62,6 +62,7 @@ struct IGreedySearchState { gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. gsl::span eos_meet; // shape (batch_size) gsl::span next_token_scores; // shape (batch_size, vocab_size) + gsl::span next_token_probs; // shape (batch_size, vocab_size) gsl::span next_tokens; // shape (batch_size) }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 8978258961db6..02b38bbb675b2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -41,6 +41,7 @@ struct GreedySearchState : public IGreedySearchState { // below buffers are on cpu or cuda size_t next_token_size = SafeInt(batch_size) * vocab_size; this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); + this->next_token_probs = AllocateBuffer(allocator, next_token_probs_buffer_, next_token_size); this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size); } @@ -61,6 +62,7 @@ struct GreedySearchState : public IGreedySearchState { BufferUniquePtr sequences_space_buffer_; BufferUniquePtr sequence_lengths_buffer_; BufferUniquePtr next_token_scores_buffer_; + BufferUniquePtr next_token_probs_buffer_; BufferUniquePtr next_tokens_buffer_; BufferUniquePtr next_tokens_cpu_buffer_; BufferUniquePtr next_positions_buffer_; @@ -161,9 +163,10 @@ Status GreedySearchBase::ProcessLogits( GreedySearchState& greedy_state, AllocatorPtr& allocator, int counter) { + bool use_sampling = std::is_same::value; return process_logits_func_(logits, &greedy_state, &(greedy_state.sequences), allocator, - this->thread_pool_, &this->logits_processors_, - parameters_, counter, this->cuda_stream_, this->GetConsoleDumper()); + this->thread_pool_, &this->logits_processors_, parameters_, + use_sampling, counter, this->cuda_stream_, this->GetConsoleDumper()); } template diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc index 09d831ba38ff4..42e3da40fc4e0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -9,23 +9,8 @@ #pragma warning(disable : 4996) #endif -#include -#include -#include -#include "core/common/safeint.h" -#include "core/providers/cpu/math/top_k.h" -#include "core/providers/cpu/tensor/utils.h" -#include "core/framework/allocator.h" -#include "core/framework/framework_common.h" -#include "core/framework/feeds_fetches_manager.h" #include "core/framework/op_kernel_context_internal.h" -#include "core/framework/session_state.h" -#include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" -#include "core/framework/session_options.h" -#include "core/framework/TensorSeq.h" -#include "core/framework/ort_value.h" -#include "gsl/gsl" #include "contrib_ops/cpu/transformers/sampling.h" #include 
"contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/sequences.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.h b/onnxruntime/contrib_ops/cpu/transformers/sampling.h index fbc1cc5d7adc3..890559f54cc76 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.h @@ -2,14 +2,8 @@ // Licensed under the MIT License. #pragma once -#include -#include -#include "core/common/common.h" #include "contrib_ops/cpu/transformers/greedy_search.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" -#include "contrib_ops/cpu/transformers/subgraph_gpt.h" -#include "contrib_ops/cpu/transformers/generation_device_helper.h" - namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 257efaa9c0eb6..e1c82b0891b25 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -457,11 +457,12 @@ Status GreedySearchProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper ORT_UNUSED_PARAMETER(logits_processors); - + ORT_UNUSED_PARAMETER(do_sampling); #ifndef DEBUG_GENERATION ORT_UNUSED_PARAMETER(dumper); #endif @@ -898,6 +899,7 @@ template Status GreedySearchProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, const transformers::IGenerationParameters* parameters, + bool do_sampling, int step, void* stream, const transformers::IConsoleDumper* dumper); @@ -963,6 +965,7 @@ template Status GreedySearchProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, const transformers::IGenerationParameters* parameters, + bool do_sampling, int step, void* stream, const transformers::IConsoleDumper* dumper); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index 9c9b5fd90ec51..2fe57e8d35c19 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -67,6 +67,7 @@ Status GreedySearchProcessLogits(const OrtValue& logits, onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc index 3add38239e957..84a283c3c863e 100644 --- a/onnxruntime/core/providers/cpu/generator/random.cc +++ b/onnxruntime/core/providers/cpu/generator/random.cc @@ -207,13 +207,13 @@ template using EigenVector = Eigen::TensorMap>; template 
-static Status MultinomialCompute(OpKernelContext* ctx, - const Tensor& X, - const int64_t batch_size, - const int64_t num_classes, - const int64_t num_samples, - std::default_random_engine& generator, - Tensor& Y) { +Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y) { if (!utils::HasType()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build."); } @@ -227,8 +227,6 @@ static Status MultinomialCompute(OpKernelContext* ctx, Matrix output = Matrix(Y.MutableData(), Y_dims); // BEGIN create temporary tensor - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); auto cdf_data = static_cast(alloc->Alloc(SafeInt(sizeof(double)) * num_classes)); BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(std::move(alloc))); Eigen::array cdf_dims = {{num_classes}}; @@ -271,6 +269,20 @@ static Status MultinomialCompute(OpKernelContext* ctx, return Status::OK(); } +template +static Status MultinomialCompute(OpKernelContext* ctx, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y) { + // BEGIN create temporary tensor + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); + return MultinomialComputeShared(alloc, X, batch_size, num_classes, num_samples, generator, Y); +} + Status Multinomial::Compute(OpKernelContext* ctx) const { const auto* tensor_pointer = ctx->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); @@ -408,4 +420,12 @@ void GenerateData(std::default_random_engine& generator, TDistribution distribut } } +template Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y); + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/random.h b/onnxruntime/core/providers/cpu/generator/random.h index 7f5a761853d0a..dc99328b6c5d6 100644 --- a/onnxruntime/core/providers/cpu/generator/random.h +++ b/onnxruntime/core/providers/cpu/generator/random.h @@ -13,6 +13,15 @@ namespace onnxruntime { +template +Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y); + class RandomNormal final : public OpKernel { public: RandomNormal(const OpKernelInfo& info) : OpKernel(info) { From e7e4d97c2f783d204a21ebd3f5300626690728c7 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 25 Oct 2022 05:47:08 +0000 Subject: [PATCH 05/51] fix build --- onnxruntime/contrib_ops/cpu/transformers/logits_processor.h | 3 +-- .../contrib_ops/cpu/transformers/sampling_parameters.cc | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 0005ed0016857..8b2bc4a2811be 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -106,7 +106,7 @@ class TemperatureLogitsProcessor : public ILogitsProcessor { NextTokenScores& next_token_scores) override; private: - int temperature_; + float 
temperature_; }; template @@ -121,7 +121,6 @@ class TopPLogitsProcessor : public ILogitsProcessor { private: float top_p_; float filter_value_; - int min_tokens_to_keep_; onnxruntime::concurrency::ThreadPool* thread_pool_; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc index 51153c33912da..4da912b1bce24 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -6,8 +6,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -constexpr int kMaxSequenceLength = 4096; - void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) { model_type = static_cast(info.GetAttrOrDefault("model_type", 0)); eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); From bc42d2975f74c163309fb7aad669ff0131734d26 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 25 Oct 2022 22:02:54 +0000 Subject: [PATCH 06/51] logits wrapper cuda --- .../cuda/transformers/beam_search_impl.cu | 27 +++++++ .../cuda/transformers/beam_search_impl.h | 3 + .../transformers/generation_device_helper.cc | 72 +++++++++++-------- 3 files changed, 71 insertions(+), 31 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu index 6bc52758c7cc3..c4eb9c968c616 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu @@ -64,6 +64,9 @@ __global__ void LogitsProcessKernel( T* next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, + const int* presence_mask, + float presence_penalty, + float temperature, int num_beams, int vocab_size, int total_elements, @@ -136,6 +139,18 @@ __global__ void LogitsProcessKernel( if (word_id == demote_token_id) { next_token_scores[index] = cub::FpLimits::Lowest(); } + + // PresencePenaltyLogitsProcessor + if (presence_mask != nullptr && presence_mask[index] == 1) { + float score = (float)next_token_scores[index] - presence_penalty; + next_token_scores[index] = (T)score; + } + + // TemperatureLogitsProcessor + if (temperature != 1.0f) { + float score = (float)(next_token_scores[index]); + next_token_scores[index] = (T)(score / temperature); + } } } @@ -144,6 +159,9 @@ void LaunchLogitsProcessKernel( T* next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, + const int* presence_mask, + float presence_penalty, + float temperature, int batch_size, int num_beams, int vocab_size, @@ -161,6 +179,9 @@ void LaunchLogitsProcessKernel( next_token_scores, vocab_mask, prefix_vocab_mask, + presence_mask, + presence_penalty, + temperature, num_beams, vocab_size, total_elements, @@ -177,6 +198,9 @@ template void LaunchLogitsProcessKernel( float* next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, + const int* presence_mask, + float presence_penalty, + float temperature, int batch_size, int num_beams, int vocab_size, @@ -192,6 +216,9 @@ template void LaunchLogitsProcessKernel( half* next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, + const int* presence_mask, + float presence_penalty, + float temperature, int batch_size, int num_beams, int vocab_size, diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h index b1685326a1279..57b3ffde0a9bd 100644 --- 
a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h @@ -29,6 +29,9 @@ void LaunchLogitsProcessKernel( T* next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, + const int* presence_mask, + float presence_penalty, + float temperature, int batch_size, int num_beams, int vocab_size, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e1c82b0891b25..9d2a216e82924 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -317,6 +317,9 @@ Status ProcessLogits(const OrtValue& logits, // next_token_scores.data(), parameters->vocab_mask.data(), step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. + parameters->presence_mask.data(), + parameters->presence_penalty, + parameters->temperature, parameters->batch_size, parameters->num_beams, parameters->vocab_size, @@ -462,7 +465,6 @@ Status GreedySearchProcessLogits( void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper ORT_UNUSED_PARAMETER(logits_processors); - ORT_UNUSED_PARAMETER(do_sampling); #ifndef DEBUG_GENERATION ORT_UNUSED_PARAMETER(dumper); #endif @@ -531,6 +533,9 @@ Status GreedySearchProcessLogits( reinterpret_cast(next_token_scores.data()), parameters->vocab_mask.data(), step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. + parameters->presence_mask.data(), + parameters->presence_penalty, + parameters->temperature, parameters->batch_size, parameters->num_beams, parameters->vocab_size, @@ -549,41 +554,46 @@ Status GreedySearchProcessLogits( // TODO(wy): support output_scores in greedy search ORT_UNUSED_PARAMETER(output_scores); - // next_tokens = torch.argmax(scores, dim=-1) - int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; - TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); - auto element_type = DataTypeImpl::GetType(); - OrtValue next_token_scores_value; - Tensor::InitOrtValue(element_type, - next_token_scores_shape, - next_token_scores.data(), - allocator->Info(), - next_token_scores_value); - const Tensor& input = next_token_scores_value.Get(); - - constexpr int axis = 1; - constexpr unsigned top_k = static_cast(1); - constexpr bool largest = true; - constexpr bool sorted = false; - - auto topk_scores = Tensor::CreateDefault(); - auto topk_indices = Tensor::CreateDefault(); - ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, - *topk_scores, *topk_indices)); + if (do_sampling) { + ORT_UNUSED_PARAMETER(do_sampling); + + } else { + // next_tokens = torch.argmax(scores, dim=-1) + int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; + TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue next_token_scores_value; + Tensor::InitOrtValue(element_type, + next_token_scores_shape, + next_token_scores.data(), + allocator->Info(), + next_token_scores_value); + const Tensor& input = next_token_scores_value.Get(); + + constexpr int axis = 1; + constexpr unsigned top_k = static_cast(1); + constexpr bool largest = true; + constexpr bool sorted = false; + + auto topk_scores = Tensor::CreateDefault(); + 
auto topk_indices = Tensor::CreateDefault(); + ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, + *topk_scores, *topk_indices)); #ifdef DEBUG_GENERATION - dumper->Print("topk_scores", *(topk_scores.get())); - dumper->Print("topk_indices", *(topk_indices.get())); + dumper->Print("topk_scores", *(topk_scores.get())); + dumper->Print("topk_indices", *(topk_indices.get())); #endif - const int64_t* next_token_indices = topk_indices->Data(); + const int64_t* next_token_indices = topk_indices->Data(); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), - next_token_indices, - greedy_state->next_tokens_cpu.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), + next_token_indices, + greedy_state->next_tokens_cpu.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } #ifdef DEBUG_GENERATION dumper->Print("greedy_state->next_tokens", greedy_state->next_tokens.data(), batch_size, 1); From 713d932651a4283feaf8cbdf21f5fd2a7b931e18 Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 26 Oct 2022 21:27:17 +0000 Subject: [PATCH 07/51] fix a bug --- .../cpu/transformers/generation_device_helper.cc | 3 ++- .../contrib_ops/cpu/transformers/generation_shared.h | 10 +++++----- .../contrib_ops/cpu/transformers/logits_processor.cc | 1 + 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index d3852f89bc447..75b8f9d1d48ae 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -449,6 +449,8 @@ Status GreedySearchProcessLogits( dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size); #endif + constexpr unsigned top_k = 1; + if (do_sampling) { // bugbug: is this Softmax really needed? 
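Aside on the "is this Softmax really needed?" question: a multinomial draw needs non-negative weights, and raw logits can be negative, so some form of exp/softmax has to run before sampling; full normalization is optional, since an inverse-CDF sampler can renormalize internally. A minimal, self-contained sketch of the equivalent draw on already-softmaxed scores (illustrative only, not the ORT code path; SampleOneToken and its arguments are placeholder names):

#include <cstdint>
#include <random>
#include <vector>

// Draw one token id from softmaxed probabilities of a single batch row,
// equivalent in spirit to torch.multinomial(probs, 1).
// std::discrete_distribution normalizes its weights internally, so the hard
// requirement is "non-negative weights", not "weights sum to 1".
int64_t SampleOneToken(const std::vector<float>& probs, std::default_random_engine& rng) {
  std::discrete_distribution<int64_t> dist(probs.begin(), probs.end());
  return dist(rng);
}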
gsl::span& next_token_probs = greedy_state->next_token_probs; @@ -498,7 +500,6 @@ Status GreedySearchProcessLogits( const Tensor& input = next_token_scores_value.Get(); constexpr int axis = 1; - constexpr unsigned top_k = 1; constexpr bool largest = true; constexpr bool sorted = false; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index ebf18152ab584..615323e5c5da8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -8,9 +8,9 @@ #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -#ifndef NDEBUG -//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search -#endif +//#ifndef NDEBUG +#define DEBUG_GENERATION 1 // uncomment it for debugging beam search +//#endif namespace onnxruntime { @@ -109,8 +109,8 @@ struct IGenerationParameters { int no_repeat_ngram_size; bool early_stopping; float presence_penalty; - float temperature; - float top_p; + float temperature = 1.0f; + float top_p = 0.0f; float filter_value; // Parameters from inputs diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 26b597b09c6d1..5368a907755d3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -229,6 +229,7 @@ void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, std::vector sorted_scores(beam_token_scores.begin(), beam_token_scores.end()); + // bugbug: decending sort std::vector sorted_indices(beam_token_scores.size()); std::iota(sorted_indices.begin(), sorted_indices.end(), 0); std::sort(sorted_indices.begin(), From 400614e1c526cd6729e2c9f78c70d4df541be2c7 Mon Sep 17 00:00:00 2001 From: wangyems Date: Thu, 27 Oct 2022 05:41:12 +0000 Subject: [PATCH 08/51] add cub::radixsort --- .../cpu/transformers/generation_shared.h | 6 +- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 + ...search_impl.cu => generation_cuda_impl.cu} | 122 +++++++++++++++++- ...m_search_impl.h => generation_cuda_impl.h} | 26 ++++ .../transformers/generation_device_helper.cc | 51 +++++++- .../contrib_ops/cuda/transformers/sampling.cc | 69 ++++++++++ .../contrib_ops/cuda/transformers/sampling.h | 26 ++++ 7 files changed, 296 insertions(+), 6 deletions(-) rename onnxruntime/contrib_ops/cuda/transformers/{beam_search_impl.cu => generation_cuda_impl.cu} (64%) rename onnxruntime/contrib_ops/cuda/transformers/{beam_search_impl.h => generation_cuda_impl.h} (63%) create mode 100644 onnxruntime/contrib_ops/cuda/transformers/sampling.cc create mode 100644 onnxruntime/contrib_ops/cuda/transformers/sampling.h diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 615323e5c5da8..802024c41667b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -8,9 +8,9 @@ #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -//#ifndef NDEBUG -#define DEBUG_GENERATION 1 // uncomment it for debugging beam search -//#endif +#ifndef NDEBUG +//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search +#endif namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 
84e8b0bfc9763..385f3bc9d735b 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -75,6 +75,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); @@ -189,6 +190,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu similarity index 64% rename from onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu rename to onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index c4eb9c968c616..436db4f82ee94 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -4,7 +4,9 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "cub/util_type.cuh" -#include "contrib_ops/cuda/transformers/beam_search_impl.h" +#include +#include "contrib_ops/cuda/transformers/generation_cuda_impl.h" + namespace onnxruntime { namespace contrib { @@ -298,6 +300,124 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data, old_mask_data, mask_data, next_positions, batch_beam_size, current_length); } +// bugbug: merge those kernels into one +template +size_t GetTempStorageSize(const T *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream) { + size_t temp_storage_bytes = 0; + cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, + temp_storage_bytes, + d_keys_in, + (T*)nullptr, + d_values_in, + (int*)nullptr, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); + return temp_storage_bytes; +} + +template size_t GetTempStorageSize( + const float *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream); + +template size_t GetTempStorageSize( + const half *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream); + +// bugbug: merge to one kernel +__global__ void SetupParamsKernel(int* d_values_in, + int* d_offsets, + int batch_size, + int vocab_size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int total_elements = batch_size * vocab_size; + if (index < total_elements) { + d_values_in[index] = index % vocab_size; + } + if (index < batch_size + 1) { + d_offsets[index] = index * vocab_size; + } +} + +void LaunchSetupParamsKernel(int* d_values_in, + int* d_offsets, + int 
batch_size, + int vocab_size, + cudaStream_t stream) { +int total_elements = batch_size * vocab_size; +constexpr int blockSize = 256; +const int gridSize = (total_elements + blockSize - 1) / blockSize; +SetupParamsKernel<<>>(d_values_in, + d_offsets, + batch_size, + vocab_size); +} + +template +void LaunchSortPairsDescending(void *d_temp_storage, + size_t temp_storage_bytes, + const T *d_keys_in, + T *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream) { + cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); +} + +template void LaunchSortPairsDescending(void *d_temp_storage, + size_t temp_storage_bytes, + const float *d_keys_in, + float *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream); + +template void LaunchSortPairsDescending(void *d_temp_storage, + size_t temp_storage_bytes, + const half *d_keys_in, + half *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h similarity index 63% rename from onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h rename to onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 57b3ffde0a9bd..4e592a80daa83 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -58,6 +58,32 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data, int current_length, cudaStream_t stream); +template +size_t GetTempStorageSize(const T *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream); + +void LaunchSetupParamsKernel(int* d_values_in, + int* d_offsets, + int batch_size, + int vocab_size, + cudaStream_t stream); + +template +void LaunchSortPairsDescending(void *d_temp_storage, + size_t temp_storage_bytes, + const T *d_keys_in, + T *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 9d2a216e82924..801a548d2152b 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -4,13 +4,14 @@ #include #include #include "core/providers/shared_library/provider_api.h" +#include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/math/topk_impl.h" #include "core/providers/cuda/math/softmax.h" #include "core/providers/cuda/shared_inc/accumulation_type.h" #include "core/framework/ort_value.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" #include -#include "contrib_ops/cuda/transformers/beam_search_impl.h" +#include "contrib_ops/cuda/transformers/generation_cuda_impl.h" #include 
"contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" @@ -555,7 +556,53 @@ Status GreedySearchProcessLogits( ORT_UNUSED_PARAMETER(output_scores); if (do_sampling) { - ORT_UNUSED_PARAMETER(do_sampling); + // bugbug: move this outside probably in execute()? + size_t bytes = SafeInt(sizeof(int)) * 2 * parameters->batch_size * parameters->vocab_size + + SafeInt(sizeof(int)) * (parameters->batch_size + 1) + + SafeInt(sizeof(CudaT)) * parameters->batch_size * parameters->vocab_size + + SafeInt(sizeof(float) * parameters->batch_size * parameters->vocab_size); + void* data = allocator->Alloc(bytes); + BufferUniquePtr workspace_buffer(data, BufferDeleter(allocator)); + int* d_index_buffer_in = reinterpret_cast(workspace_buffer.get()); + int* d_index_buffer_out = d_index_buffer_in + parameters->batch_size * parameters->vocab_size; + int* d_offset_buffer = d_index_buffer_out + parameters->batch_size * parameters->vocab_size; + CudaT* d_sorted_score_buffer = reinterpret_cast(d_offset_buffer + parameters->batch_size + 1); + float* d_softmaxed_score_buffer = reinterpret_cast(d_sorted_score_buffer + parameters->batch_size * parameters->vocab_size); + + size_t temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), + d_index_buffer_in, + d_offset_buffer, + parameters->vocab_size, + parameters->batch_size, + cuda_stream); + + cuda::LaunchSetupParamsKernel(d_index_buffer_in, + d_offset_buffer, + parameters->batch_size, + parameters->vocab_size, + cuda_stream); + + void* temp_storage = allocator->Alloc(temp_storage_bytes); + BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); + + cuda::LaunchSortPairsDescending(temp_storage_buffer.get(), + temp_storage_bytes, + reinterpret_cast(next_token_scores.data()), + d_sorted_score_buffer, + d_index_buffer_in, + d_index_buffer_out, + parameters->vocab_size, + parameters->batch_size, + d_offset_buffer, + cuda_stream); + + dispatch_blockwise_softmax_forward(cuda_stream, + d_softmaxed_score_buffer, + d_sorted_score_buffer, + parameters->vocab_size, + parameters->vocab_size, + parameters->batch_size); + } else { // next_tokens = torch.argmax(scores, dim=-1) diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling.cc b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc new file mode 100644 index 0000000000000..9707a3d03daf2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_execution_provider.h" +#include "contrib_ops/cuda/transformers/sampling.h" +#include "contrib_ops/cuda/transformers/generation_device_helper.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + Sampling, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'input_ids' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 1) // 'max_length' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 2) // 'min_length' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 3) // 'repetition_penalty' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 6) // 'custom_attention_mask' needs to be on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Sampling); + +transformers::CudaTensorConsoleDumper g_cuda_dumper_sampling; + +Sampling::Sampling(const OpKernelInfo& info) + : onnxruntime::contrib::transformers::Sampling(info) { + SetComputeStream(static_cast(info.GetExecutionProvider()->GetComputeStream())); + + SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds, + GenerationCudaDeviceHelper::TopK, + GenerationCudaDeviceHelper::DeviceCopy, + GenerationCudaDeviceHelper::GreedySearchProcessLogits, + GenerationCudaDeviceHelper::GreedySearchProcessLogits, + GenerationCudaDeviceHelper::InitGreedyState, + GenerationCudaDeviceHelper::InitGreedyState); + + SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds, + GenerationCudaDeviceHelper::UpdateGptFeeds); + + SetConsoleDumper(&g_cuda_dumper_sampling); +} + +Status Sampling::ComputeInternal(OpKernelContext* context) const { + return onnxruntime::contrib::transformers::Sampling::Compute(context); +} + +Status Sampling::Compute(OpKernelContext* context) const { + auto s = ComputeInternal(context); + + if (s.IsOK()) { + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err)); + } + } + + return s; +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling.h b/onnxruntime/contrib_ops/cuda/transformers/sampling.h new file mode 100644 index 0000000000000..65bee53573ec4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "contrib_ops/cpu/transformers/sampling.h" + +namespace onnxruntime { +class SessionState; + +namespace contrib { +namespace cuda { + +class Sampling final : public onnxruntime::contrib::transformers::Sampling { + public: + Sampling(const OpKernelInfo& info); + + Status Compute(OpKernelContext* context) const override; + + private: + Status ComputeInternal(OpKernelContext* context) const; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime From 3dfde5681e36515a87aa2473e3927245537d3ddc Mon Sep 17 00:00:00 2001 From: wangyems Date: Thu, 27 Oct 2022 21:17:40 +0000 Subject: [PATCH 09/51] add filterlogits cuda --- .../cuda/transformers/generation_cuda_impl.cu | 87 +++++++++++++++++-- .../cuda/transformers/generation_cuda_impl.h | 10 +++ .../transformers/generation_device_helper.cc | 27 +++++- 3 files changed, 114 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 436db4f82ee94..80ef7eb999307 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -361,13 +361,13 @@ void LaunchSetupParamsKernel(int* d_values_in, int batch_size, int vocab_size, cudaStream_t stream) { -int total_elements = batch_size * vocab_size; -constexpr int blockSize = 256; -const int gridSize = (total_elements + blockSize - 1) / blockSize; -SetupParamsKernel<<>>(d_values_in, - d_offsets, - batch_size, - vocab_size); + int total_elements = batch_size * vocab_size; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + SetupParamsKernel<<>>(d_values_in, + d_offsets, + batch_size, + vocab_size); } template @@ -418,6 +418,79 @@ template void LaunchSortPairsDescending(void *d_temp_storage, int *d_offsets, cudaStream_t stream); +// A trick here: cumuliative sum of the sorted logits is a temporarily variable in the kernel. +template +__global__ void FilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + T* d_logits_in_out, + float top_p, + float filter_value, + int batch_size, + int vocab_size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + + int vocab_idx = index % vocab_size; + int batch_id = index / vocab_size; + int start_index = batch_id * vocab_size; + + int count = vocab_idx; + float sum = 0.0f; + while (count != 0) { + sum += d_sorted_logits_in[start_index]; + ++start_index; + --count; + } + + if (sum > top_p) { + // Shift the indices to the right by one according to the Turing implementation. 
+ int shifted_index = index + 1; + if (shifted_index % vocab_size != 0) { + int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index]; + d_logits_in_out[original_index] = (T)filter_value; + } + } +} + +template +void LaunchFilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + T* d_logits_in_out, + float top_p, + float filter_value, + int batch_size, + int vocab_size, + cudaStream_t stream) { + int total_elements = batch_size * vocab_size; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + FilterLogitsKernel<<>>(d_sorted_logits_in, + d_sorted_indices, + d_logits_in_out, + top_p, + filter_value, + batch_size, + vocab_size); +} + +template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + float* d_logits_in_out, + float top_p, + float filter_value, + int batch_size, + int vocab_size, + cudaStream_t stream); + +template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + half* d_logits_in_out, + float top_p, + float filter_value, + int batch_size, + int vocab_size, + cudaStream_t stream); + + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 4e592a80daa83..20b160f4b452e 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -84,6 +84,16 @@ void LaunchSortPairsDescending(void *d_temp_storage, int *d_offsets, cudaStream_t stream); +template +void LaunchFilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + T* d_logits_in_out, + float top_p, + float filter_value, + int batch_size, + int vocab_size, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 801a548d2152b..172d55f85ee27 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -560,14 +560,15 @@ Status GreedySearchProcessLogits( size_t bytes = SafeInt(sizeof(int)) * 2 * parameters->batch_size * parameters->vocab_size + SafeInt(sizeof(int)) * (parameters->batch_size + 1) + SafeInt(sizeof(CudaT)) * parameters->batch_size * parameters->vocab_size + - SafeInt(sizeof(float) * parameters->batch_size * parameters->vocab_size); + SafeInt(2 * sizeof(float) * parameters->batch_size * parameters->vocab_size); void* data = allocator->Alloc(bytes); BufferUniquePtr workspace_buffer(data, BufferDeleter(allocator)); int* d_index_buffer_in = reinterpret_cast(workspace_buffer.get()); int* d_index_buffer_out = d_index_buffer_in + parameters->batch_size * parameters->vocab_size; int* d_offset_buffer = d_index_buffer_out + parameters->batch_size * parameters->vocab_size; CudaT* d_sorted_score_buffer = reinterpret_cast(d_offset_buffer + parameters->batch_size + 1); - float* d_softmaxed_score_buffer = reinterpret_cast(d_sorted_score_buffer + parameters->batch_size * parameters->vocab_size); + float* d_sorted_softmaxed_score_buffer = reinterpret_cast(d_sorted_score_buffer + parameters->batch_size * parameters->vocab_size); + float* d_softmaxed_score_buffer = d_sorted_softmaxed_score_buffer + 
parameters->batch_size * parameters->vocab_size; size_t temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), d_index_buffer_in, @@ -597,12 +598,32 @@ Status GreedySearchProcessLogits( cuda_stream); dispatch_blockwise_softmax_forward(cuda_stream, - d_softmaxed_score_buffer, + d_sorted_softmaxed_score_buffer, d_sorted_score_buffer, parameters->vocab_size, parameters->vocab_size, parameters->batch_size); + cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score_buffer, + d_index_buffer_out, + reinterpret_cast(next_token_scores.data()), + parameters->top_p, + parameters->filter_value, + parameters->batch_size, + parameters->vocab_size, + cuda_stream); + + // bugbug: actually we can only do softmax at the very beginning and sort the softmaxed scores. + // Not sure if the order change will affect the result. + dispatch_blockwise_softmax_forward(cuda_stream, + d_softmaxed_score_buffer, + reinterpret_cast(next_token_scores.data()), + parameters->vocab_size, + parameters->vocab_size, + parameters->batch_size); + + // multinomial sampling + } else { // next_tokens = torch.argmax(scores, dim=-1) From e1c0046d605a735a4c46d8e04cf068af401dbdc3 Mon Sep 17 00:00:00 2001 From: wangyems Date: Fri, 28 Oct 2022 05:11:54 +0000 Subject: [PATCH 10/51] provider --- .../contrib_ops/cpu/transformers/sampling.cc | 10 ++++ .../contrib_ops/cpu/transformers/sampling.h | 6 +++ .../transformers/generation_device_helper.cc | 48 +++++++++---------- onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../core/providers/cpu/cpu_provider_shared.cc | 6 +++ .../core/providers/cpu/cpu_provider_shared.h | 5 ++ .../provider_bridge_provider.cc | 10 ++++ 7 files changed, 63 insertions(+), 24 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc index 42e3da40fc4e0..2dc313e3f7316 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -38,6 +38,16 @@ REGISTER_KERNEL_TYPED(float) namespace transformers { +void Sampling::Init(const OpKernelInfo& info) { + this->Init(info); +} + +Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, + const std::string& attribute_name, + const SessionState& subgraph_session_state) { + return this->SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); +} + Status Sampling::Compute(OpKernelContext* ctx) const { auto* ctx_internal = static_cast(ctx); diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.h b/onnxruntime/contrib_ops/cpu/transformers/sampling.h index 890559f54cc76..ae6ccc9389f8a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.h @@ -13,8 +13,14 @@ class Sampling : public GreedySearch { public: explicit Sampling(const OpKernelInfo& info) : GreedySearch(info) {} + void Init(const OpKernelInfo& info); + Status Compute(OpKernelContext* ctx) const override; + Status SetupSubgraphExecutionInfo(const SessionState& session_state, + const std::string& attribute_name, + const SessionState& subgraph_session_state) override; + private: SamplingParameters parameters_; }; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 172d55f85ee27..bc5c2d2be3349 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ 
b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -570,12 +570,12 @@ Status GreedySearchProcessLogits( float* d_sorted_softmaxed_score_buffer = reinterpret_cast(d_sorted_score_buffer + parameters->batch_size * parameters->vocab_size); float* d_softmaxed_score_buffer = d_sorted_softmaxed_score_buffer + parameters->batch_size * parameters->vocab_size; - size_t temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), - d_index_buffer_in, - d_offset_buffer, - parameters->vocab_size, - parameters->batch_size, - cuda_stream); + size_t temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), + d_index_buffer_in, + d_offset_buffer, + parameters->vocab_size, + parameters->batch_size, + cuda_stream); cuda::LaunchSetupParamsKernel(d_index_buffer_in, d_offset_buffer, @@ -586,16 +586,16 @@ Status GreedySearchProcessLogits( void* temp_storage = allocator->Alloc(temp_storage_bytes); BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); - cuda::LaunchSortPairsDescending(temp_storage_buffer.get(), - temp_storage_bytes, - reinterpret_cast(next_token_scores.data()), - d_sorted_score_buffer, - d_index_buffer_in, - d_index_buffer_out, - parameters->vocab_size, - parameters->batch_size, - d_offset_buffer, - cuda_stream); + cuda::LaunchSortPairsDescending(temp_storage_buffer.get(), + temp_storage_bytes, + reinterpret_cast(next_token_scores.data()), + d_sorted_score_buffer, + d_index_buffer_in, + d_index_buffer_out, + parameters->vocab_size, + parameters->batch_size, + d_offset_buffer, + cuda_stream); dispatch_blockwise_softmax_forward(cuda_stream, d_sorted_softmaxed_score_buffer, @@ -604,14 +604,14 @@ Status GreedySearchProcessLogits( parameters->vocab_size, parameters->batch_size); - cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score_buffer, - d_index_buffer_out, - reinterpret_cast(next_token_scores.data()), - parameters->top_p, - parameters->filter_value, - parameters->batch_size, - parameters->vocab_size, - cuda_stream); + cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score_buffer, + d_index_buffer_out, + reinterpret_cast(next_token_scores.data()), + parameters->top_p, + parameters->filter_value, + parameters->batch_size, + parameters->vocab_size, + cuda_stream); // bugbug: actually we can only do softmax at the very beginning and sort the softmaxed scores. // Not sure if the order change will affect the result. 
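To summarize the do_sampling path assembled above: each batch row's logits are sorted in descending order by the segmented radix sort, softmax is taken over the sorted scores, FilterLogitsKernel masks every token outside the top_p nucleus back in the original vocabulary order (the shift by one keeps the first token whose cumulative probability crosses the threshold), and a second softmax over the filtered logits feeds the multinomial draw that the later "multinomial cuda" commit adds. For comparison, a minimal CPU sketch of the usual nucleus-filtering rule (illustrative only; NucleusFilter and its arguments are placeholders, not code from the patch):

#include <cstddef>
#include <vector>

// Standard top-p (nucleus) filtering: keep the highest-probability tokens until
// their cumulative probability reaches top_p, mask everything else.
// sorted_probs are softmaxed scores in descending order; sorted_indices map a
// sorted position back to the original vocabulary index.
void NucleusFilter(const std::vector<float>& sorted_probs,
                   const std::vector<int>& sorted_indices,
                   float top_p, float filter_value,
                   std::vector<float>& logits /* original order, modified in place */) {
  float cumulative = 0.0f;
  for (size_t i = 0; i < sorted_probs.size(); ++i) {
    if (cumulative > top_p) {
      logits[sorted_indices[i]] = filter_value;  // outside the nucleus
    }
    cumulative += sorted_probs[i];
  }
}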
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index ff32e568e0eaf..338baf82d3532 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -81,6 +81,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer); @@ -165,6 +166,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index 2122269f29a4d..225661b7a3538 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -33,6 +33,7 @@ #include "contrib_ops/cpu/bert/longformer_attention_base.h" #include "contrib_ops/cpu/transformers/beam_search.h" #include "contrib_ops/cpu/transformers/greedy_search.h" +#include "contrib_ops/cpu/transformers/sampling.h" #ifdef ENABLE_ATEN #include "contrib_ops/cpu/aten_ops/aten_op.h" #endif @@ -243,6 +244,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU { subgraph_session_state); } + void Sampling__Init(contrib::transformers::Sampling* p, const OpKernelInfo& info) override { p->contrib::transformers::Sampling::Init(info); } + Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); } + Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } + + #ifdef ENABLE_ATEN Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } #endif diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 9b77a0c774d68..e640e1ac389d1 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -9,6 +9,7 @@ class AttentionBase; namespace transformers { class BeamSearch; class GreedySearch; +class Sampling; } } // namespace contrib @@ -173,6 +174,10 @@ struct ProviderHostCPU { const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + virtual void Sampling__Init(contrib::transformers::Sampling* p, const OpKernelInfo& info) = 0; + virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0; + virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + #ifdef ENABLE_ATEN virtual Status ATen__Compute(const contrib::ATen* p, 
OpKernelContext* p_ctx) = 0; #endif diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 418a4a1ce8e7a..5f14d34a091a5 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -32,6 +32,7 @@ #include "contrib_ops/cpu/bert/longformer_attention_base.h" #include "contrib_ops/cpu/transformers/beam_search.h" #include "contrib_ops/cpu/transformers/greedy_search.h" +#include "contrib_ops/cpu/transformers/sampling.h" #ifdef ENABLE_ATEN #include "contrib_ops/cpu/aten_ops/aten_op.h" #endif @@ -592,6 +593,15 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat return g_host_cpu.GreedySearch__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } + +void Sampling::Init(const OpKernelInfo& info) { g_host_cpu.Sampling__Init(this, info); } + +Status Sampling::Compute(OpKernelContext* ctx) const { return g_host_cpu.Sampling__Compute(this, ctx); } + +Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, + const SessionState& subgraph_session_state) { + return g_host_cpu.Sampling__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } + } // namespace transformers #ifdef ENABLE_ATEN From 7c64df9e72c35ed7948b94347c7886aad6929736 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 1 Nov 2022 00:31:13 +0000 Subject: [PATCH 11/51] multinomial cuda --- .../cuda/transformers/generation_cuda_impl.cu | 169 +++++++++++++++++- .../cuda/transformers/generation_cuda_impl.h | 7 + .../transformers/generation_device_helper.cc | 19 +- 3 files changed, 191 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 80ef7eb999307..c70122e9b355a 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -4,6 +4,7 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "cub/util_type.cuh" +#include #include #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" @@ -300,7 +301,7 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data, old_mask_data, mask_data, next_positions, batch_beam_size, current_length); } -// bugbug: merge those kernels into one +// TODO: merge those kernels into one template size_t GetTempStorageSize(const T *d_keys_in, const int* d_values_in, @@ -341,7 +342,7 @@ template size_t GetTempStorageSize( int num_segments, cudaStream_t stream); -// bugbug: merge to one kernel +// TODO: merge to one kernel __global__ void SetupParamsKernel(int* d_values_in, int* d_offsets, int batch_size, @@ -491,6 +492,170 @@ template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, cudaStream_t stream); +// Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu +template +__global__ void sampleMultinomialOnce( + int64_t* dest, + int distributions, + int categories, + scalar_t* sampled, + scalar_t* dist, + int stride_dist, // dist->stride(0) + int stride_categories // dist->stride(1) +) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + extern __shared__ unsigned char 
my_smem[]; + __shared__ bool found; + __shared__ unsigned foundPos; + accscalar_t *smem = reinterpret_cast(my_smem); + accscalar_t accZero = static_cast(0); + scalar_t zero = static_cast(0); + for (int curDist = blockIdx.x; + curDist < distributions; curDist += gridDim.x) { + // Each block handles one distribution + // First pass, find the total sum of the distribution + accscalar_t sum = accZero; + scalar_t val; + for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) { + val = dist[curDist * stride_dist + cat * stride_categories]; + // CUDA_KERNEL_ASSERT(!at::_isnan(val)); + // CUDA_KERNEL_ASSERT(!_isinf(val)); + // CUDA_KERNEL_ASSERT(!(val < zero)); + sum = sum + static_cast(val); + } + // threadIdx.x == 0 has the sum value from this + // sum = cuda_utils::BlockReduceSum(sum, smem); + sum = BlockReduce(tmp_storage).Reduce(sum, cub::Sum()); + // Broadcast sum and sample value + if (threadIdx.x == 0) { + // Make sure the sum of our distribution didn't overflow + // CUDA_KERNEL_ASSERT(!_isinf(val)); + // CUDA_KERNEL_ASSERT(sum > accZero); + foundPos = 0; + smem[0] = sum; + smem[1] = sampled[curDist]; + } + __syncthreads(); + sum = smem[0]; + scalar_t sample = static_cast(smem[1]); + __syncthreads(); + if (sum == accZero) { + // Choose the first element + if (threadIdx.x == 0) { + dest[curDist] = 0; + } + continue; + } + int chunks = (categories + (int)blockDim.x - 1) / blockDim.x; + accscalar_t prevHighProb = accZero; + found = false; + for (int chunk = 0; chunk < chunks && !found; ++chunk) { + // All threads in bounds load a value + int cat = chunk * blockDim.x + threadIdx.x; + accscalar_t dist_val = cat < categories ? + static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : + accZero; + smem[threadIdx.x] = dist_val; + __syncthreads(); + // Perform an inclusive prefix sum of the shared memory contents + for (int offset = 1; offset < blockDim.x; offset *= 2) { + accscalar_t val = accZero; + if (threadIdx.x >= offset) { + val = smem[threadIdx.x - offset] + smem[threadIdx.x]; + } + __syncthreads(); + if (threadIdx.x >= offset) { + smem[threadIdx.x] = val; + } + __syncthreads(); + } + // Each thread will check to see if the sample falls in its + // bucket + scalar_t curBucket = + static_cast(smem[threadIdx.x] + prevHighProb); + scalar_t prevBucket = static_cast( + threadIdx.x == 0 ? prevHighProb + : smem[threadIdx.x - 1] + prevHighProb); + bool inBucket = + (cat < categories) && + (!(sample >= curBucket) && + (sample >= prevBucket) && + (dist_val > zero)); + if (inBucket) { + // We're done; we have the sample + // Torch indices are 1-based + atomicMax(&foundPos, cat); + found = true; + } + // Store the previous scan's high value for future use + prevHighProb = prevHighProb + smem[blockDim.x - 1]; + __syncthreads(); + } + if (threadIdx.x == 0) { + if (found) { + dest[curDist] = foundPos; + } else { + // This should address a rare bug where we don't select a valid index. This likely occurs when + // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but + // and our uniform sample is greater than this value. In this case we likely have unitialized memory + // in dest[curDist]. So basically we will loop through the distribution and pick the largest index + // where the distribution is non-zero. This is obviously terribly inefficient, but due to the + // rarity in which this occurs, this should not be an issue. 
+ for (int cat = categories - 1; cat >= 0; --cat) { + if (dist[curDist * stride_dist + cat * stride_categories] > zero) { + dest[curDist] = cat; + break; + } + } + } + } + } +} + +// Only support n_sample = 1 +void TorchMultinomialKernelLauncher(float* d_input, + float* d_sampled, + int64_t* d_output, + int batch_size, + int vocab_size, + cudaStream_t stream) +{ + int device; + cudaGetDevice(&device); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device); + + int numSM = props.multiProcessorCount; + int maxThreads = props.maxThreadsPerBlock; + int warp_size = 32; //at::cuda::warp_size(); + int requiredWarps = (vocab_size + warp_size - 1) / warp_size; + int requiredThreads = std::min(maxThreads, requiredWarps * warp_size); + int requiredShared = requiredThreads * sizeof(float); + + // bugbug: randomize d_sampled + + dim3 block(requiredThreads); + dim3 grid(std::min(batch_size, numSM * 4)); + + if (block.x == 1024) { + const int block_size = 1024; + sampleMultinomialOnce + <<>>(d_output, + batch_size, + vocab_size, + d_sampled, + d_input, + vocab_size, + batch_size); + } else { + printf("Please add more cases for block size"); + } +} + + + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 20b160f4b452e..e13bb038b24f2 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -94,6 +94,13 @@ void LaunchFilterLogitsKernel(float* d_sorted_logits_in, int vocab_size, cudaStream_t stream); +void TorchMultinomialKernelLauncher(float* d_input, + float* d_sampled, + int64_t* d_output, + int batch_size, + int vocab_size, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index bc5c2d2be3349..69e09092df849 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -560,7 +560,9 @@ Status GreedySearchProcessLogits( size_t bytes = SafeInt(sizeof(int)) * 2 * parameters->batch_size * parameters->vocab_size + SafeInt(sizeof(int)) * (parameters->batch_size + 1) + SafeInt(sizeof(CudaT)) * parameters->batch_size * parameters->vocab_size + - SafeInt(2 * sizeof(float) * parameters->batch_size * parameters->vocab_size); + SafeInt(2 * sizeof(float) * parameters->batch_size * parameters->vocab_size) + + SafeInt(sizeof(float)) * parameters->batch_size + + SafeInt(sizeof(int64_t)) * parameters->batch_size; void* data = allocator->Alloc(bytes); BufferUniquePtr workspace_buffer(data, BufferDeleter(allocator)); int* d_index_buffer_in = reinterpret_cast(workspace_buffer.get()); @@ -569,6 +571,8 @@ Status GreedySearchProcessLogits( CudaT* d_sorted_score_buffer = reinterpret_cast(d_offset_buffer + parameters->batch_size + 1); float* d_sorted_softmaxed_score_buffer = reinterpret_cast(d_sorted_score_buffer + parameters->batch_size * parameters->vocab_size); float* d_softmaxed_score_buffer = d_sorted_softmaxed_score_buffer + parameters->batch_size * parameters->vocab_size; + float* d_sampled = d_softmaxed_score_buffer + 2 * parameters->batch_size * parameters->vocab_size; + int64_t* d_indices = reinterpret_cast(d_sampled + parameters->batch_size); size_t 
temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), d_index_buffer_in, @@ -623,8 +627,19 @@ Status GreedySearchProcessLogits( parameters->batch_size); // multinomial sampling + cuda::TorchMultinomialKernelLauncher(d_softmaxed_score_buffer, + d_sampled, + d_indices, + parameters->batch_size, + parameters->vocab_size, + cuda_stream); - + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), + d_indices, + greedy_state->next_tokens_cpu.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); } else { // next_tokens = torch.argmax(scores, dim=-1) int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; From 25f58fe3c2fa9af91fbcb6133f0104242efb0b2d Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 1 Nov 2022 01:15:43 +0000 Subject: [PATCH 12/51] use curand and add seed --- .../cpu/transformers/generation_device_helper.cc | 3 +-- .../cpu/transformers/generation_shared.h | 1 + .../cuda/transformers/generation_cuda_impl.cu | 13 +++++++++---- .../cuda/transformers/generation_cuda_impl.h | 4 +++- .../cuda/transformers/generation_device_helper.cc | 6 +++++- 5 files changed, 19 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 75b8f9d1d48ae..f358430707dbf 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -473,8 +473,7 @@ Status GreedySearchProcessLogits( next_token_probs_value); const Tensor& input = next_token_probs_value.Get(); - float seed = 0.f; - std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(seed)}; + std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(parameters->seed)}; Tensor sampled_idx; ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocator, input, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 802024c41667b..72381be677b3a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -112,6 +112,7 @@ struct IGenerationParameters { float temperature = 1.0f; float top_p = 0.0f; float filter_value; + int seed = 0; // Parameters from inputs int min_length; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index c70122e9b355a..f90a81c98f728 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -501,7 +501,8 @@ __global__ void sampleMultinomialOnce( scalar_t* sampled, scalar_t* dist, int stride_dist, // dist->stride(0) - int stride_categories // dist->stride(1) + int stride_categories, // dist->stride(1) + curandState_t* curandstate ) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -535,7 +536,8 @@ __global__ void sampleMultinomialOnce( // CUDA_KERNEL_ASSERT(sum > accZero); foundPos = 0; smem[0] = sum; - smem[1] = sampled[curDist]; + //smem[1] = sampled[curDist]; + smem[1] = static_cast(curand_uniform(curandstate + blockIdx.x)); } __syncthreads(); sum = smem[0]; @@ -620,8 +622,10 @@ void TorchMultinomialKernelLauncher(float* 
d_input, int64_t* d_output, int batch_size, int vocab_size, - cudaStream_t stream) + cudaStream_t stream, + curandState_t* curandstate) { + // Store the props in class variables int device; cudaGetDevice(&device); cudaDeviceProp props; @@ -648,7 +652,8 @@ void TorchMultinomialKernelLauncher(float* d_input, d_sampled, d_input, vocab_size, - batch_size); + batch_size, + curandstate); } else { printf("Please add more cases for block size"); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index e13bb038b24f2..60fa6aaa8b71a 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -5,6 +5,7 @@ #include #include +#include namespace onnxruntime { namespace contrib { @@ -99,7 +100,8 @@ void TorchMultinomialKernelLauncher(float* d_input, int64_t* d_output, int batch_size, int vocab_size, - cudaStream_t stream); + cudaStream_t stream, + curandState_t* curandstate); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 69e09092df849..a9ede240d49ab 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -627,12 +627,16 @@ Status GreedySearchProcessLogits( parameters->batch_size); // multinomial sampling + // bugbug: move curandState initialization out of the loop + curandState state; + curand_init(static_cast(parameters->seed), 0, 0, &state); cuda::TorchMultinomialKernelLauncher(d_softmaxed_score_buffer, d_sampled, d_indices, parameters->batch_size, parameters->vocab_size, - cuda_stream); + cuda_stream, + &state); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), d_indices, From 6a603289836077a3d5896759a6c627dc0d1028a1 Mon Sep 17 00:00:00 2001 From: wangye Date: Wed, 2 Nov 2022 01:21:15 +0000 Subject: [PATCH 13/51] fix a bug --- .../cpu/transformers/generation_shared.h | 6 +- .../cpu/transformers/greedy_search.h | 2 +- .../contrib_ops/cpu/transformers/sampling.cc | 29 ++++++- .../contrib_ops/cpu/transformers/sampling.h | 79 ++++++++++++++++++- .../cpu/transformers/sampling_parameters.cc | 1 + 5 files changed, 108 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 72381be677b3a..411bad3a5856e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -8,9 +8,9 @@ #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -#ifndef NDEBUG -//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search -#endif +//#ifndef NDEBUG +#define DEBUG_GENERATION 1 // uncomment it for debugging beam search +//#endif namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h index a545c68b1ef48..1b3ca1e36bcd7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h @@ -69,7 +69,7 @@ class GreedySearch : public IControlFlowKernel { update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; } - protected: + private: // Device specific functions 
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::TopkFunc topk_func_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc index 2dc313e3f7316..e6d1b5e656750 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -39,13 +39,38 @@ REGISTER_KERNEL_TYPED(float) namespace transformers { void Sampling::Init(const OpKernelInfo& info) { - this->Init(info); + parameters_.ParseFromAttributes(info); + + // Check model_type 0 (GPT-2) + ORT_ENFORCE(parameters_.model_type == 0); + + ONNX_NAMESPACE::GraphProto proto; + + ORT_ENFORCE(info.GetAttr("decoder", &proto).IsOK()); + ORT_IGNORE_RETURN_VALUE(proto); } Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { - return this->SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); + const auto& node = Node(); + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2 + if (attribute_name == "decoder") { + ORT_ENFORCE(gpt_subgraph_ == nullptr, + "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer()); + ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state)); + decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); + parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size, + gpt_subgraph_->num_heads, + gpt_subgraph_->head_size, + gpt_subgraph_->num_layers); + } + } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5 + ORT_THROW("Not Implemented"); + } + + return Status::OK(); } Status Sampling::Compute(OpKernelContext* ctx) const { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.h b/onnxruntime/contrib_ops/cpu/transformers/sampling.h index ae6ccc9389f8a..fb12f1dbc3db2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.h @@ -2,16 +2,32 @@ // Licensed under the MIT License. 
#pragma once -#include "contrib_ops/cpu/transformers/greedy_search.h" +#include +#include +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/controlflow/utils.h" +#include "contrib_ops/cpu/transformers/subgraph_gpt.h" +#include "contrib_ops/cpu/transformers/generation_device_helper.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" namespace onnxruntime { +class FeedsFetchesManager; + namespace contrib { namespace transformers { -class Sampling : public GreedySearch { +using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel + +class Sampling : public IControlFlowKernel { public: - explicit Sampling(const OpKernelInfo& info) : GreedySearch(info) {} + explicit Sampling(const OpKernelInfo& info) + : IControlFlowKernel(info), + decoder_feeds_fetches_manager_(nullptr), + cuda_stream_(nullptr), + dumper_(nullptr) { + Init(info); + } void Init(const OpKernelInfo& info); @@ -21,7 +37,64 @@ class Sampling : public GreedySearch { const std::string& attribute_name, const SessionState& subgraph_session_state) override; + protected: + void SetComputeStream(void* stream) { cuda_stream_ = stream; } + void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; } + + // device helpers that is same for both GPT and encoder-decoder models. + void SetDeviceHelpers( + const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + const GenerationDeviceHelper::TopkFunc& topk_func, + const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, + const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_func, + const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_fp16_func, + const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_func, + const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_fp16_func) { + add_to_feeds_func_ = add_to_feeds_func; + topk_func_ = topk_func; + device_copy_func_ = device_copy_func; + process_logits_func_ = process_logits_func; + process_logits_fp16_func_ = process_logits_fp16_func; + init_greedy_state_func_ = init_greedy_state_func; + init_greedy_state_fp16_func_ = init_greedy_state_fp16_func; + } + + void SetDeviceHelpers_Gpt( + const GenerationDeviceHelper::UpdateGptFeedsFunc& update_gpt_feeds_func, + const GenerationDeviceHelper::UpdateGptFeedsFunc& update_gpt_feeds_fp16_func) { + update_gpt_feeds_func_ = update_gpt_feeds_func; + update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; + } + private: + // Device specific functions + GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; + GenerationDeviceHelper::TopkFunc topk_func_; + GenerationDeviceHelper::DeviceCopyFunc device_copy_func_; + + GenerationDeviceHelper::GreedySearchProcessLogitsFunc process_logits_func_; + GenerationDeviceHelper::GreedySearchProcessLogitsFunc process_logits_fp16_func_; + + GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_func_; + GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_fp16_func_; + + //------------------------------------------------------------ + // Device specific functions for GPT + //------------------------------------------------------------ + GenerationDeviceHelper::UpdateGptFeedsFunc update_gpt_feeds_func_; + GenerationDeviceHelper::UpdateGptFeedsFunc update_gpt_feeds_fp16_func_; + + //------------------------------------------------------------ + // Subgraph and FeedsFetchesManager re-used for each subgraph execution. 
+ //------------------------------------------------------------ + std::unique_ptr gpt_subgraph_; + + FeedsFetchesManager* decoder_feeds_fetches_manager_; + + void* cuda_stream_; + + IConsoleDumper* dumper_; + SamplingParameters parameters_; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc index 4da912b1bce24..2f9649b3021c5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -7,6 +7,7 @@ namespace contrib { namespace transformers { void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) { + std::cout << "parse sampling params" << std::endl; model_type = static_cast(info.GetAttrOrDefault("model_type", 0)); eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); From 0363d1b63626ca206c05f4223fc1777c280710ff Mon Sep 17 00:00:00 2001 From: wangye Date: Thu, 3 Nov 2022 01:04:47 +0000 Subject: [PATCH 14/51] fix a few crash --- .../transformers/generation_device_helper.cc | 29 ++++++---- .../cpu/transformers/greedy_search_impl_gpt.h | 1 - .../cpu/transformers/logits_processor.cc | 21 ++++++-- .../cpu/transformers/sampling_parameters.cc | 1 - .../cuda/transformers/generation_cuda_impl.cu | 17 +++--- .../cuda/transformers/generation_cuda_impl.h | 3 +- .../transformers/generation_device_helper.cc | 53 ++++++++++++++++--- 7 files changed, 91 insertions(+), 34 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index f358430707dbf..e50bdf40080d4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -474,17 +474,31 @@ Status GreedySearchProcessLogits( const Tensor& input = next_token_probs_value.Get(); std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(parameters->seed)}; - Tensor sampled_idx; - ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocator, + + int64_t sampled_idx_dims[] = {static_cast(batch_size), 1}; + TensorShape sampled_idx_shape(&sampled_idx_dims[0], 2); + + gsl::span& next_token_idx = greedy_state->next_tokens_cpu; + + OrtValue sampled_idx_ov; + Tensor::InitOrtValue(DataTypeImpl::GetType(), + sampled_idx_shape, + next_token_idx.data(), + allocator->Info(), + sampled_idx_ov); + Tensor* sampled_idx = sampled_idx_ov.GetMutable(); + + AllocatorPtr allocator_temp = allocator; + ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocator_temp, input, batch_size, vocab_size, 1, generator, - sampled_idx)); - - gsl::span next_token_indices = sampled_idx.DataAsSpan(); - gsl::copy(next_token_indices, greedy_state->next_tokens_cpu); + *sampled_idx)); +#ifdef DEBUG_GENERATION + dumper->Print("sampled_idx", *sampled_idx); +#endif } else { // next_tokens = torch.argmax(scores, dim=-1) int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; @@ -598,7 +612,6 @@ Status UpdateGptFeeds( // last_outputs: logits, present_0, present_1, ... // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 ORT_UNUSED_PARAMETER(stream); - // The following updates inputs for subgraph // Update input_ids with next tokens. 
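For reference, the do_sampling path in GreedySearchProcessLogits above (and the CUDA multinomial kernel earlier in this series) boils down to: sort the tokens by logit, apply the top-p cutoff on the cumulative softmax, renormalize what is left, and draw one token by walking the cumulative distribution with a seeded uniform sample. The following is only an illustrative host-side sketch of that pipeline, not the ORT kernels themselves; SampleTopP is a placeholder name, and the real code masks filtered logits with filter_value and reuses SoftmaxCPU / the CUDA softmax rather than renormalizing in place.

#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

// Returns one sampled token id from a row of logits, keeping only the smallest
// set of tokens whose cumulative probability exceeds top_p.
int SampleTopP(const std::vector<float>& logits, float top_p, unsigned seed) {
  const size_t vocab = logits.size();

  // Sort token ids by logit, descending.
  std::vector<size_t> order(vocab);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](size_t a, size_t b) { return logits[a] > logits[b]; });

  // Softmax over the sorted logits (subtract the max for numerical stability).
  std::vector<double> probs(vocab);
  const double max_logit = logits[order[0]];
  double sum = 0.0;
  for (size_t i = 0; i < vocab; ++i) {
    probs[i] = std::exp(static_cast<double>(logits[order[i]]) - max_logit);
    sum += probs[i];
  }

  // Find the top-p cutoff on the cumulative distribution; always keep at least one token.
  size_t kept = vocab;
  double cumulative = 0.0;
  for (size_t i = 0; i < vocab; ++i) {
    probs[i] /= sum;
    cumulative += probs[i];
    if (cumulative > top_p) {
      kept = i + 1;
      break;
    }
  }

  // Renormalize the kept tokens and draw one by walking the cumulative sum
  // until it passes a seeded uniform sample (inverse-CDF sampling).
  const double kept_sum = std::accumulate(probs.begin(), probs.begin() + kept, 0.0);
  std::default_random_engine generator(seed);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  const double u = uniform(generator) * kept_sum;
  double running = 0.0;
  for (size_t i = 0; i < kept; ++i) {
    running += probs[i];
    if (u <= running) {
      return static_cast<int>(order[i]);
    }
  }
  return static_cast<int>(order[kept - 1]);  // guard against rounding drift
}

The sampleMultinomialOnce kernel performs the same inverse-CDF walk per distribution, but replaces the serial cumulative loop with a shared-memory inclusive prefix sum so that each thread can test its own bucket.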
@@ -614,7 +627,6 @@ Status UpdateGptFeeds( input_ids_data[i] = beam_next_tokens[i]; } next_inputs[0] = input_ids; - if (increase_position) { // Update position IDs int32_t* position_data = position_ids.GetMutable()->MutableData(); @@ -623,7 +635,6 @@ Status UpdateGptFeeds( } } next_inputs[1] = position_ids; - // Update attention mask const OrtValue& old_mask = next_inputs[2]; const int32_t* old_mask_data = old_mask.Get().Data(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 26d19a4fbdc84..1efc8e3d6ed7d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -225,7 +225,6 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds } fetches.clear(); } - // Copy the sequences to output gsl::span output = output_sequences->MutableDataAsSpan(); for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 5368a907755d3..1c4547b7ab873 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -205,7 +205,12 @@ void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, T* p = next_token_scores.scores.data(); for (size_t i = 0; i < next_token_scores.scores.size(); i++) { *p /= temperature_; + ++p; } + +#ifdef DEBUG_GENERATION + DumpScores("TemperatureLogitsProcessor", next_token_scores); +#endif } template @@ -229,18 +234,18 @@ void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, std::vector sorted_scores(beam_token_scores.begin(), beam_token_scores.end()); - // bugbug: decending sort + // decending sort std::vector sorted_indices(beam_token_scores.size()); std::iota(sorted_indices.begin(), sorted_indices.end(), 0); std::sort(sorted_indices.begin(), sorted_indices.end(), [&sorted_scores](size_t i1, size_t i2) { - return sorted_scores[i1] < sorted_scores[i2]; + return sorted_scores[i1] > sorted_scores[i2]; }); - std::sort(sorted_scores.begin(), sorted_scores.end()); + std::sort(sorted_scores.begin(), sorted_scores.end(), std::greater()); std::vector cumulative_probs(vocab_size); - // bugbug + // todo: batch ORT_UNUSED_PARAMETER(SoftmaxCPU(1, vocab_size, sorted_scores.data(), @@ -264,6 +269,10 @@ void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, beam_token_scores[index_to_remove] = filter_value_; } } + +#ifdef DEBUG_GENERATION + DumpScores("TopPLogitsProcessor", next_token_scores); +#endif } template @@ -285,6 +294,10 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, for (size_t i = 0; i < next_token_scores.scores.size(); i++) { *p -= presence_mask_[i] * presence_penalty_; } + +#ifdef DEBUG_GENERATION + DumpScores("PresencePenaltyLogitsProcessor", next_token_scores); +#endif } void LogitsProcessorList::Init(const BeamSearchParameters& parameters, diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc index 2f9649b3021c5..4da912b1bce24 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -7,7 +7,6 @@ namespace contrib { namespace transformers { void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) 
{ - std::cout << "parse sampling params" << std::endl; model_type = static_cast(info.GetAttrOrDefault("model_type", 0)); eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index f90a81c98f728..eb6296b1592eb 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -501,8 +501,7 @@ __global__ void sampleMultinomialOnce( scalar_t* sampled, scalar_t* dist, int stride_dist, // dist->stride(0) - int stride_categories, // dist->stride(1) - curandState_t* curandstate + int stride_categories // dist->stride(1) ) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -515,6 +514,7 @@ __global__ void sampleMultinomialOnce( scalar_t zero = static_cast(0); for (int curDist = blockIdx.x; curDist < distributions; curDist += gridDim.x) { + // Each block handles one distribution // First pass, find the total sum of the distribution accscalar_t sum = accZero; @@ -527,8 +527,7 @@ __global__ void sampleMultinomialOnce( sum = sum + static_cast(val); } // threadIdx.x == 0 has the sum value from this - // sum = cuda_utils::BlockReduceSum(sum, smem); - sum = BlockReduce(tmp_storage).Reduce(sum, cub::Sum()); + sum = BlockReduce(tmp_storage).Reduce(sum, cub::Sum()); // sum = cuda_utils::BlockReduceSum(sum, smem); // Broadcast sum and sample value if (threadIdx.x == 0) { // Make sure the sum of our distribution didn't overflow @@ -536,8 +535,7 @@ __global__ void sampleMultinomialOnce( // CUDA_KERNEL_ASSERT(sum > accZero); foundPos = 0; smem[0] = sum; - //smem[1] = sampled[curDist]; - smem[1] = static_cast(curand_uniform(curandstate + blockIdx.x)); + smem[1] = sampled[curDist]; } __syncthreads(); sum = smem[0]; @@ -622,8 +620,7 @@ void TorchMultinomialKernelLauncher(float* d_input, int64_t* d_output, int batch_size, int vocab_size, - cudaStream_t stream, - curandState_t* curandstate) + cudaStream_t stream) { // Store the props in class variables int device; @@ -639,7 +636,6 @@ void TorchMultinomialKernelLauncher(float* d_input, int requiredShared = requiredThreads * sizeof(float); // bugbug: randomize d_sampled - dim3 block(requiredThreads); dim3 grid(std::min(batch_size, numSM * 4)); @@ -652,8 +648,7 @@ void TorchMultinomialKernelLauncher(float* d_input, d_sampled, d_input, vocab_size, - batch_size, - curandstate); + batch_size); } else { printf("Please add more cases for block size"); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 60fa6aaa8b71a..02b803b8855b9 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -100,8 +100,7 @@ void TorchMultinomialKernelLauncher(float* d_input, int64_t* d_output, int batch_size, int vocab_size, - cudaStream_t stream, - curandState_t* curandstate); + cudaStream_t stream); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index a9ede240d49ab..44a6a6fed13c4 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ 
b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -3,6 +3,7 @@ #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/math/topk_impl.h" @@ -571,7 +572,7 @@ Status GreedySearchProcessLogits( CudaT* d_sorted_score_buffer = reinterpret_cast(d_offset_buffer + parameters->batch_size + 1); float* d_sorted_softmaxed_score_buffer = reinterpret_cast(d_sorted_score_buffer + parameters->batch_size * parameters->vocab_size); float* d_softmaxed_score_buffer = d_sorted_softmaxed_score_buffer + parameters->batch_size * parameters->vocab_size; - float* d_sampled = d_softmaxed_score_buffer + 2 * parameters->batch_size * parameters->vocab_size; + float* d_sampled = d_softmaxed_score_buffer + parameters->batch_size * parameters->vocab_size; int64_t* d_indices = reinterpret_cast(d_sampled + parameters->batch_size); size_t temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), @@ -581,6 +582,10 @@ Status GreedySearchProcessLogits( parameters->batch_size, cuda_stream); +#ifdef DEBUG_GENERATION + dumper->Print("temp_storage_bytes", temp_storage_bytes, true); +#endif + cuda::LaunchSetupParamsKernel(d_index_buffer_in, d_offset_buffer, parameters->batch_size, @@ -601,6 +606,12 @@ Status GreedySearchProcessLogits( d_offset_buffer, cuda_stream); +#ifdef DEBUG_GENERATION + dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score_buffer), batch_size, vocab_size); + dumper->Print("d_index_buffer_in", d_index_buffer_in, batch_size, vocab_size); + dumper->Print("d_index_buffer_out", d_index_buffer_out, batch_size, vocab_size); +#endif + dispatch_blockwise_softmax_forward(cuda_stream, d_sorted_softmaxed_score_buffer, d_sorted_score_buffer, @@ -608,6 +619,10 @@ Status GreedySearchProcessLogits( parameters->vocab_size, parameters->batch_size); +#ifdef DEBUG_GENERATION + dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score_buffer, batch_size, vocab_size); +#endif + cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score_buffer, d_index_buffer_out, reinterpret_cast(next_token_scores.data()), @@ -617,6 +632,10 @@ Status GreedySearchProcessLogits( parameters->vocab_size, cuda_stream); +#ifdef DEBUG_GENERATION + dumper->Print("next_token_scores after filtering logits", next_token_scores.data(), batch_size, vocab_size); +#endif + // bugbug: actually we can only do softmax at the very beginning and sort the softmaxed scores. // Not sure if the order change will affect the result. 
dispatch_blockwise_softmax_forward(cuda_stream, @@ -626,17 +645,37 @@ Status GreedySearchProcessLogits( parameters->vocab_size, parameters->batch_size); +#ifdef DEBUG_GENERATION + dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score_buffer, batch_size, vocab_size); +#endif + // multinomial sampling - // bugbug: move curandState initialization out of the loop - curandState state; - curand_init(static_cast(parameters->seed), 0, 0, &state); + std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(parameters->seed)}; + std::uniform_real_distribution distribution(0.0, 1.0); + std::vector sampled(parameters->batch_size); + for (int i = 0; i < parameters->batch_size; ++i) { + sampled[i] = distribution(generator); + } + +#ifdef DEBUG_GENERATION + dumper->Print("sampled value on cpu", sampled.data(), batch_size, 1); +#endif + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled, + sampled.data(), + sizeof(float) * parameters->batch_size, + cudaMemcpyHostToDevice, + cuda_stream)); cuda::TorchMultinomialKernelLauncher(d_softmaxed_score_buffer, d_sampled, d_indices, parameters->batch_size, parameters->vocab_size, - cuda_stream, - &state); + cuda_stream); + +#ifdef DEBUG_GENERATION + dumper->Print("d_sampled", d_sampled, batch_size, 1); + dumper->Print("d_indices", d_indices, batch_size, 1); +#endif CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), d_indices, @@ -644,6 +683,8 @@ Status GreedySearchProcessLogits( cudaMemcpyDeviceToHost, cuda_stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + + cudaDeviceSynchronize(); } else { // next_tokens = torch.argmax(scores, dim=-1) int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; From 3b5392fbbe8a33181805969dc24970a531c9c1fe Mon Sep 17 00:00:00 2001 From: wangye Date: Thu, 3 Nov 2022 02:26:23 +0000 Subject: [PATCH 15/51] try to fix win build --- .../cpu/transformers/logits_processor.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 1c4547b7ab873..4efa12e6d5f0f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -246,12 +246,15 @@ void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, std::sort(sorted_scores.begin(), sorted_scores.end(), std::greater()); std::vector cumulative_probs(vocab_size); // todo: batch - ORT_UNUSED_PARAMETER(SoftmaxCPU(1, - vocab_size, - sorted_scores.data(), - cumulative_probs.data(), - false, - thread_pool_)); + Status status = SoftmaxCPU(1, + vocab_size, + sorted_scores.data(), + cumulative_probs.data(), + false, + thread_pool_); + if (!status.IsOK()) { + ORT_THROW(status.ErrorMessage()); + } std::unordered_set sorted_indices_to_remove; if (cumulative_probs[0] > top_p_) { From 7b8cc2ac8b450547cb52bb25f3f732f230acfe12 Mon Sep 17 00:00:00 2001 From: wangye Date: Thu, 3 Nov 2022 03:35:05 +0000 Subject: [PATCH 16/51] no debug --- .../contrib_ops/cpu/transformers/generation_shared.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 411bad3a5856e..72381be677b3a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -8,9 +8,9 @@ #include 
"core/framework/allocator.h" #include "core/framework/ort_value.h" -//#ifndef NDEBUG -#define DEBUG_GENERATION 1 // uncomment it for debugging beam search -//#endif +#ifndef NDEBUG +//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search +#endif namespace onnxruntime { From b985964791126103c3ea484bba561359a616ea5c Mon Sep 17 00:00:00 2001 From: wangye Date: Thu, 3 Nov 2022 06:28:06 +0000 Subject: [PATCH 17/51] update --- .../transformers/generation_device_helper.cc | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 44a6a6fed13c4..21f91a5c6fd8c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -18,6 +18,10 @@ #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cuda/transformers/beam_search_topk.h" +#ifdef DEBUG_GENERATION +#include +#endif + namespace onnxruntime { namespace concurrency { class ThreadPool; @@ -556,6 +560,8 @@ Status GreedySearchProcessLogits( // TODO(wy): support output_scores in greedy search ORT_UNUSED_PARAMETER(output_scores); + BufferUniquePtr workspace_buffer; + BufferUniquePtr storage_buffer; if (do_sampling) { // bugbug: move this outside probably in execute()? size_t bytes = SafeInt(sizeof(int)) * 2 * parameters->batch_size * parameters->vocab_size + @@ -565,7 +571,8 @@ Status GreedySearchProcessLogits( SafeInt(sizeof(float)) * parameters->batch_size + SafeInt(sizeof(int64_t)) * parameters->batch_size; void* data = allocator->Alloc(bytes); - BufferUniquePtr workspace_buffer(data, BufferDeleter(allocator)); + BufferUniquePtr workspace_buffer_temp(data, BufferDeleter(allocator)); + workspace_buffer = std::move(workspace_buffer_temp); int* d_index_buffer_in = reinterpret_cast(workspace_buffer.get()); int* d_index_buffer_out = d_index_buffer_in + parameters->batch_size * parameters->vocab_size; int* d_offset_buffer = d_index_buffer_out + parameters->batch_size * parameters->vocab_size; @@ -594,8 +601,8 @@ Status GreedySearchProcessLogits( void* temp_storage = allocator->Alloc(temp_storage_bytes); BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); - - cuda::LaunchSortPairsDescending(temp_storage_buffer.get(), + storage_buffer = std::move(temp_storage_buffer); + cuda::LaunchSortPairsDescending(storage_buffer.get(), temp_storage_bytes, reinterpret_cast(next_token_scores.data()), d_sorted_score_buffer, @@ -607,7 +614,8 @@ Status GreedySearchProcessLogits( cuda_stream); #ifdef DEBUG_GENERATION - dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score_buffer), batch_size, vocab_size); + //cudaDeviceSynchronize(); + //dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score_buffer), batch_size, vocab_size); dumper->Print("d_index_buffer_in", d_index_buffer_in, batch_size, vocab_size); dumper->Print("d_index_buffer_out", d_index_buffer_out, batch_size, vocab_size); #endif @@ -620,6 +628,7 @@ Status GreedySearchProcessLogits( parameters->batch_size); #ifdef DEBUG_GENERATION + //cudaDeviceSynchronize(); dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score_buffer, batch_size, vocab_size); #endif @@ -633,7 +642,8 @@ Status GreedySearchProcessLogits( cuda_stream); #ifdef DEBUG_GENERATION - dumper->Print("next_token_scores after filtering logits", next_token_scores.data(), 
batch_size, vocab_size); + //cudaDeviceSynchronize(); + //dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), batch_size, vocab_size); #endif // bugbug: actually we can only do softmax at the very beginning and sort the softmaxed scores. @@ -646,6 +656,7 @@ Status GreedySearchProcessLogits( parameters->batch_size); #ifdef DEBUG_GENERATION + //cudaDeviceSynchronize(); dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score_buffer, batch_size, vocab_size); #endif @@ -655,11 +666,11 @@ Status GreedySearchProcessLogits( std::vector sampled(parameters->batch_size); for (int i = 0; i < parameters->batch_size; ++i) { sampled[i] = distribution(generator); - } - #ifdef DEBUG_GENERATION - dumper->Print("sampled value on cpu", sampled.data(), batch_size, 1); + std::cout << "sampled value on cpu: " << sampled[i] << std::endl; #endif + } + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled, sampled.data(), sizeof(float) * parameters->batch_size, @@ -673,6 +684,7 @@ Status GreedySearchProcessLogits( cuda_stream); #ifdef DEBUG_GENERATION + //cudaDeviceSynchronize(); dumper->Print("d_sampled", d_sampled, batch_size, 1); dumper->Print("d_indices", d_indices, batch_size, 1); #endif @@ -683,8 +695,6 @@ Status GreedySearchProcessLogits( cudaMemcpyDeviceToHost, cuda_stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); - - cudaDeviceSynchronize(); } else { // next_tokens = torch.argmax(scores, dim=-1) int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; From 7830bec8eea93b3ab8869e59b30a9580bb7d9a08 Mon Sep 17 00:00:00 2001 From: wangye Date: Thu, 3 Nov 2022 19:33:48 +0000 Subject: [PATCH 18/51] update --- .../transformers/generation_device_helper.cc | 15 +- .../transformers/convert_generation_temp.py | 1590 +++++++++++++++++ 2 files changed, 1594 insertions(+), 11 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/convert_generation_temp.py diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 21f91a5c6fd8c..e2478aa3b9b91 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -614,8 +614,7 @@ Status GreedySearchProcessLogits( cuda_stream); #ifdef DEBUG_GENERATION - //cudaDeviceSynchronize(); - //dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score_buffer), batch_size, vocab_size); + dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score_buffer), batch_size, vocab_size); dumper->Print("d_index_buffer_in", d_index_buffer_in, batch_size, vocab_size); dumper->Print("d_index_buffer_out", d_index_buffer_out, batch_size, vocab_size); #endif @@ -628,7 +627,6 @@ Status GreedySearchProcessLogits( parameters->batch_size); #ifdef DEBUG_GENERATION - //cudaDeviceSynchronize(); dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score_buffer, batch_size, vocab_size); #endif @@ -642,8 +640,7 @@ Status GreedySearchProcessLogits( cuda_stream); #ifdef DEBUG_GENERATION - //cudaDeviceSynchronize(); - //dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), batch_size, vocab_size); + dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), batch_size, vocab_size); #endif // bugbug: actually we can only do softmax at the very beginning and sort the softmaxed scores. 
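A note on the idea in the comment above: softmax is order-preserving within a row, since softmax(x)_i = exp(x_i) / sum_j exp(x_j) and exp is strictly increasing, so x_i > x_k implies softmax(x)_i > softmax(x)_k. Sorting the softmaxed scores therefore gives exactly the same ranking, and the same top-p cutoff set, as sorting the raw logits. The one caveat is the second softmax after filtering: it only reduces to a renormalization of the kept probabilities when the filtered logits are pushed to effectively negative infinity (a very large negative filter_value), so a single-softmax rewrite would need to zero out the removed tokens and renormalize explicitly.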
@@ -656,14 +653,14 @@ Status GreedySearchProcessLogits( parameters->batch_size); #ifdef DEBUG_GENERATION - //cudaDeviceSynchronize(); dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score_buffer, batch_size, vocab_size); #endif // multinomial sampling - std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(parameters->seed)}; + std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(parameters->seed + step)}; std::uniform_real_distribution distribution(0.0, 1.0); std::vector sampled(parameters->batch_size); + distribution(generator); // the first one is subnormal numbers for (int i = 0; i < parameters->batch_size; ++i) { sampled[i] = distribution(generator); #ifdef DEBUG_GENERATION @@ -684,7 +681,6 @@ Status GreedySearchProcessLogits( cuda_stream); #ifdef DEBUG_GENERATION - //cudaDeviceSynchronize(); dumper->Print("d_sampled", d_sampled, batch_size, 1); dumper->Print("d_indices", d_indices, batch_size, 1); #endif @@ -733,9 +729,6 @@ Status GreedySearchProcessLogits( CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); } -#ifdef DEBUG_GENERATION - dumper->Print("greedy_state->next_tokens", greedy_state->next_tokens.data(), batch_size, 1); -#endif return Status::OK(); } diff --git a/onnxruntime/python/tools/transformers/convert_generation_temp.py b/onnxruntime/python/tools/transformers/convert_generation_temp.py new file mode 100644 index 0000000000000..a22e3f67e2e93 --- /dev/null +++ b/onnxruntime/python/tools/transformers/convert_generation_temp.py @@ -0,0 +1,1590 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- +""" +This converts GPT2 or T5 model to onnx with beam search operator. + +Example 1: convert gpt2 model with beam search: + python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx + +Example 2: convert T5 model with beam search in two steps: + cd ./models/t5 + python convert_to_onnx.py -m t5-small + cd ../.. + python convert_generation.py -m t5-small --model_type t5 \ + --decoder_onnx ./models/t5/onnx_models/t5-small_decoder.onnx \ + --encoder_decoder_init_onnx ./models/t5/onnx_models/t5-small_encoder_decoder_init.onnx \ + --output ./models/t5/onnx_models/t5_small_beam_search.onnx + +Example 3: convert T5 model with beam search. All in one step: + python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx + +Example 4: convert MT5 model with external data file like mt5-base-beamsearch.onnx.data in below example. 
+ python convert_generation.py -m google/mt5-base --model_type mt5 --output mt5-base-beamsearch.onnx -e + +Example 5: convert gpt2 model with sampling: + python convert_generation_temp.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 +""" + +import argparse +import logging +import os +import sys +import time +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import onnx +import torch +from benchmark_helper import Precision +from onnx import GraphProto, ModelProto, TensorProto +from transformers import ( + GPT2Config, + GPT2LMHeadModel, + GPT2Tokenizer, + MT5Config, + MT5ForConditionalGeneration, + T5Config, + T5ForConditionalGeneration, + T5Tokenizer, +) + +from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers + +sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2")) +from gpt2_helper import PRETRAINED_GPT2_MODELS # noqa: E402 +from models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx # noqa: E402 + +sys.path.append(os.path.join(os.path.dirname(__file__), "models", "t5")) +from benchmark_helper import setup_logger +from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models # noqa: E402 +from models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS # noqa: E402 +from onnx_model import OnnxModel + +logger = logging.getLogger("") + + +class GenerationType(Enum): + BEAMSEARCH = "beam_search" + GREEDYSEARCH = "greedy_search" + + def __str__(self): + return self.value + + +def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: + """Parse arguments + + Args: + argv (Optional[List[str]], optional): _description_. Defaults to None. + + Returns: + argparse.Namespace: Parsed arguments. + """ + parser = argparse.ArgumentParser() + + input_group = parser.add_argument_group("Input options") + + input_group.add_argument( + "-m", + "--model_name_or_path", + required=True, + type=str, + help="Pytorch model checkpoint path, or pretrained model name in the list: " + + ", ".join(PRETRAINED_GPT2_MODELS + PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS), + ) + + input_group.add_argument( + "--model_type", + required=False, + type=str, + default="gpt2", + choices=["gpt2", "t5", "mt5"], + help="Model type (default is gpt2) in the list: " + ", ".join(["gpt2", "t5", "mt5"]), + ) + + input_group.add_argument( + "--cache_dir", + required=False, + type=str, + default=os.path.join(".", "cache_models"), + help="Directory to cache pre-trained models", + ) + + input_group.add_argument( + "--decoder_onnx", + required=False, + type=str, + default="", + help="Path of onnx model for decoder. Specify it when you have exported the model.", + ) + + input_group.add_argument( + "--encoder_decoder_init_onnx", + required=False, + type=str, + default="", + help="Path of ONNX model for encoder and decoder initialization. 
Specify it when you have exported the model.", + ) + + parser.add_argument( + "--verbose", + required=False, + action="store_true", + help="Print more information", + ) + parser.set_defaults(verbose=False) + + output_group = parser.add_argument_group("Output options") + + output_group.add_argument( + "--output", + required=True, + type=str, + help="Output path for onnx model with beam search.", + ) + + output_group.add_argument( + "-p", + "--precision", + required=False, + type=Precision, + default=Precision.FLOAT32, + choices=[Precision.FLOAT32, Precision.FLOAT16], + help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision", + ) + + output_group.add_argument( + "-e", + "--use_external_data_format", + required=False, + action="store_true", + help="save external data for model > 2G", + ) + output_group.set_defaults(use_external_data_format=False) + + output_group.add_argument( + "-s", "--run_shape_inference", required=False, action="store_true", help="run shape inference" + ) + output_group.set_defaults(run_shape_inference=False) + + output_group.add_argument( + "-i", + "--disable_shared_initializers", + required=False, + action="store_true", + help="do not share initializers in encoder and decoder. It will increase memory usage of t5/mt5 models.", + ) + output_group.set_defaults(disable_shared_initializers=False) + + model_group = parser.add_argument_group("Beam search parameters that stored in the output model") + + model_group.add_argument( + "--output_sequences_scores", + required=False, + action="store_true", + help="output sequences scores", + ) + model_group.set_defaults(output_sequences_scores=False) + + model_group.add_argument( + "--output_token_scores", + required=False, + action="store_true", + help="output token scores", + ) + model_group.set_defaults(output_token_scores=False) + + model_group.add_argument("--early_stopping", required=False, action="store_true") + model_group.set_defaults(early_stopping=False) + + model_group.add_argument( + "--no_repeat_ngram_size", + type=int, + required=False, + default=0, + help="No repeat ngram size", + ) + + model_group.add_argument( + "--vocab_mask", + required=False, + action="store_true", + help="Enable vocab_mask. This mask applies only to every generated token to filter some bad words.", + ) + model_group.set_defaults(vocab_mask=False) + + model_group.add_argument( + "--prefix_vocab_mask", + required=False, + action="store_true", + help="Enable prefix_vocab_mask. This mask can be used to filter bad words in the first generated token only", + ) + model_group.set_defaults(prefix_vocab_mask=False) + + model_group.add_argument( + "--custom_attention_mask", + required=False, + action="store_true", + help="Enable custom_attention_mask. 
This mask can be used to replace default encoder attention mask", + ) + model_group.set_defaults(custom_attention_mask=False) + + beam_parameters_group = parser.add_argument_group( + "Beam search parameters not stored in the output model, for testing parity and performance" + ) + + beam_parameters_group.add_argument("--min_length", type=int, required=False, default=1, help="Min sequence length") + + beam_parameters_group.add_argument("--max_length", type=int, required=False, default=50, help="Max sequence length") + + beam_parameters_group.add_argument("--num_beams", type=int, required=False, default=4, help="Beam size") + + beam_parameters_group.add_argument( + "--num_return_sequences", + type=int, + required=False, + default=1, + help="Number of return sequence <= num_beams", + ) + + beam_parameters_group.add_argument( + "--length_penalty", + type=float, + required=False, + default=1, + help="Positive. >1 to penalize and <1 to encourage short sentence.", + ) + + beam_parameters_group.add_argument( + "--repetition_penalty", + type=float, + required=False, + default=1, + help="Positive. >1 to penalize and <1 to encourage.", + ) + + beam_parameters_group.add_argument( + "--vocab_size", + type=int, + required=False, + default=-1, + help="Vocab_size of the underlying model used to decide the shape of vocab mask", + ) + + test_group = parser.add_argument_group("Other options for testing parity and performance") + + test_group.add_argument( + "--use_gpu", required=False, action="store_true", help="use GPU for inference. Required for fp16." + ) + test_group.set_defaults(use_gpu=False) + + test_group.add_argument( + "--disable_parity", + required=False, + action="store_true", + help="do not run parity test", + ) + test_group.set_defaults(disable_parity=False) + + test_group.add_argument( + "--torch_performance", + required=False, + action="store_true", + help="test PyTorch performance", + ) + test_group.set_defaults(torch_performance=False) + + test_group.add_argument( + "--total_runs", + required=False, + type=int, + default=1, + help="Number of times of inference for latency measurement", + ) + + test_group.add_argument( + "--save_test_data", + required=False, + action="store_true", + help="save test data for onnxruntimer_perf_test tool", + ) + test_group.set_defaults(save_test_data=False) + + args = parser.parse_args(argv) + + return args + + +def gpt2_to_onnx(args: argparse.Namespace): + """Convert GPT-2 model to onnx + + Args: + args (argparse.Namespace): arguments parsed from command line + """ + model_name = args.model_name_or_path + + arguments = [ + "--model_name_or_path", + model_name, + "--output", + args.decoder_onnx, + "--optimize_onnx", + "--precision", + "fp32" if args.precision == Precision.FLOAT32 else "fp16", + "--test_runs", + "1", + "--test_cases", + "10", + "--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, position_ids and attention_mask + "--overwrite", # Overwrite onnx file if existed + ] + if args.use_gpu: + arguments.append("--use_gpu") + if args.use_external_data_format: + arguments.append("--use_external_data_format") + + if args.precision == Precision.FLOAT16: + assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu" + # TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision') + # Need change cuda kernel to support a combination of fp32 logits and fp16 past state. + # Currently logits and past state shall be same data type. 
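+        # --op_block_list keeps the listed op types (Add, LayerNormalization, FastGelu)
+        # in float32 while the rest of the graph is converted to float16.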
+ arguments.extend(["--op_block_list", "Add", "LayerNormalization", "FastGelu"]) + + if args.verbose: + logger.info(f"arguments for convert_to_onnx:{arguments}") + + convert_gpt2_to_onnx(argv=arguments) + + +def t5_to_onnx(args: argparse.Namespace): + """Convert T5 model to onnx + + Args: + args (argparse.Namespace): arguments parsed from command line + """ + paths = export_t5_onnx_models( + args.model_name_or_path, + args.cache_dir, + Path(args.output).parent, + use_gpu=args.use_gpu, + use_external_data_format=args.use_external_data_format, + optimize_onnx=False, + precision=args.precision, + verbose=False, + use_decoder_start_token=False, + merge_encoder_and_decoder_init=True, + overwrite=True, + disable_auto_mixed_precision=False, + use_int32_inputs=True, + model_type=args.model_type, + ) + + logger.debug(f"onnx model for encoder: {paths[0]}") + logger.debug(f"onnx model for decoder: {paths[1]}") + args.encoder_decoder_init_onnx = paths[0] + args.decoder_onnx = paths[1] + + +def shape_inference(onnx_path: str, use_external_data_format: bool = True): + """Shape inference on an onnx file, which will be overwritten. + + Args: + onnx_path (str): Path of onnx model + use_external_data_format(bool): output tensors to external data or not. + """ + # Run symbolic shape inference to walk around ORT shape inference issue for subgraph. + from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference + + model = onnx.load_model(onnx_path, load_external_data=True) + out = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=False) + if out: + OnnxModel.save(out, onnx_path, save_as_external_data=use_external_data_format) + else: + logger.warning("Failed to run symbolic shape inference on the model.") + + +def create_ort_session(model_path: str, use_gpu: bool) -> InferenceSession: + """Create OnnxRuntime session. + + Args: + model_path (str): onnx model path + use_gpu (bool): use GPU or not + + Raises: + RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified. + + Returns: + onnxruntime.InferenceSession: The created session. + """ + sess_options = SessionOptions() + sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL + execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"] + if use_gpu: + if "CUDAExecutionProvider" not in get_available_providers(): + raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!") + else: + logger.info("use CUDAExecutionProvider") + + ort_session = InferenceSession(model_path, sess_options, providers=execution_providers) + return ort_session + + +def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision): + """Verify GPT-2 subgraph + + Args: + graph (onnx.GraphProto): onnx graph of GPT-2 + precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. + + Raises: + ValueError: Number of inputs not expected. + ValueError: Input name is not expected. + ValueError: Input data type is not expected. + ValueError: Number of outputs not expected. + ValueError: Output name is not expected. + ValueError: Output data type is not expected. 
+ """ + is_float16 = Precision.FLOAT16 == precision + + input_count = len(graph.input) + layer_count = input_count - 3 + assert layer_count >= 1 + + expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)] + if len(graph.input) != len(expected_inputs): + raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") + + for i, expected_input in enumerate(expected_inputs): + if graph.input[i].name != expected_input: + raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}") + + expected_type = TensorProto.INT32 + if i >= 3: + expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT + + input_type = graph.input[i].type.tensor_type.elem_type + if input_type != expected_type: + raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") + logger.info("Verifying GPT-2 graph inputs: name and data type are good.") + + expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)] + if len(graph.output) != len(expected_outputs): + raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") + + for i, expected_output in enumerate(expected_outputs): + if graph.output[i].name != expected_output: + raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") + + expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT + output_type = graph.output[i].type.tensor_type.elem_type + if output_type != expected_type: + raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {output_type}") + logger.info("Verifying GPT-2 graph outputs: name and data type are good.") + + # TODO(tianleiwu): verify shapes of inputs and outputs. + return + + +def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision): + """Verify T5 decoder subgraph + + Args: + graph (onnx.GraphProto): onnx graph of T5 decoder + precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. + + Raises: + ValueError: Number of inputs not expected. + ValueError: Input name is not expected. + ValueError: Input data type is not expected. + ValueError: Number of outputs not expected. + ValueError: Output name is not expected. + ValueError: Output data type is not expected. + """ + is_float16 = Precision.FLOAT16 == precision + float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT + + input_count = len(graph.input) + layer_count = (input_count - 3) // 4 + assert layer_count >= 1 + + # Expect inputs: + # input_ids: int32 (B, 1) + # encoder_attention_mask: int32 (B, encode_sequence_length) + # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + + # past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) + # past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) + # ... (for each self attention layer) + + # past_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + # past_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + # ... 
(for each cross attention layer) + + # TODO: encoder_hidden_states is optional + expected_inputs = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"] + for i in range(layer_count): + expected_inputs.append(f"past_key_self_{i}") + expected_inputs.append(f"past_value_self_{i}") + for i in range(layer_count): + expected_inputs.append(f"past_key_cross_{i}") + expected_inputs.append(f"past_value_cross_{i}") + + if len(graph.input) != len(expected_inputs): + raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") + + for i, expected_input in enumerate(expected_inputs): + if graph.input[i].name != expected_input: + raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}") + + expected_type = TensorProto.INT32 if i < 2 else float_type + input_type = graph.input[i].type.tensor_type.elem_type + if input_type != expected_type: + raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") + + # Expect outputs: + # logits: (B, 1, vocab_size) + # present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size) + # present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size) + # ... (for each self attention layer) + expected_outputs = ["logits"] + for i in range(layer_count): + expected_outputs.append(f"present_key_self_{i}") + expected_outputs.append(f"present_value_self_{i}") + + if len(graph.output) != len(expected_outputs): + raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") + + for i, expected_output in enumerate(expected_outputs): + if graph.output[i].name != expected_output: + raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") + output_type = graph.output[i].type.tensor_type.elem_type + if output_type != float_type: + raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. Got {output_type}") + + +def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision): + """Verify T5 decoder subgraph + + Args: + graph (onnx.GraphProto): onnx graph of T5 decoder + precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. + + Raises: + ValueError: Number of inputs not expected. + ValueError: Input name is not expected. + ValueError: Input data type is not expected. + ValueError: Number of outputs not expected. + ValueError: Output name is not expected. + ValueError: Output data type is not expected. + """ + is_float16 = Precision.FLOAT16 == precision + layer_count = (len(graph.output) - 2) // 4 + assert layer_count >= 1 + + # Expect 3 inputs: + # encoder_input_ids: int32 (B, encode_sequence_length) + # encoder_attention_mask: int32 (B, encode_sequence_length) + # decoder_input_ids: int32 (B, 1) + expected_inputs = ["encoder_input_ids", "encoder_attention_mask", "decoder_input_ids"] + if len(graph.input) != len(expected_inputs): + raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") + + for i, expected_input in enumerate(expected_inputs): + if graph.input[i].name != expected_input: + raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}") + + expected_type = TensorProto.INT32 + input_type = graph.input[i].type.tensor_type.elem_type + if input_type != expected_type: + raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. 
Got {input_type}") + + # Expected outputs: + # logits: (B, 1, vocab_size) + # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + # present_key_self_0: (B, num_heads, 1, head_size) + # present_value_self_0: (B, num_heads, 1, head_size) + # ... (for each self attention layer) + # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + # ... (for each cross attention layer) + expected_outputs = ["logits", "encoder_hidden_states"] + for i in range(layer_count): + expected_outputs.append(f"present_key_self_{i}") + expected_outputs.append(f"present_value_self_{i}") + for i in range(layer_count): + expected_outputs.append(f"present_key_cross_{i}") + expected_outputs.append(f"present_value_cross_{i}") + + if len(graph.output) != len(expected_outputs): + raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") + + for i, expected_output in enumerate(expected_outputs): + if graph.output[i].name != expected_output: + raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") + + expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT + output_type = graph.output[i].type.tensor_type.elem_type + if output_type != expected_type: + raise ValueError(f"Output {i} is expected to have onnx data type {expected_type}. Got {output_type}") + + logger.info("T5 encoder graph verified: name and data type of inputs and outputs are good.") + + +def remove_shared_initializers( + graph1: GraphProto, + graph2: GraphProto, + shared_prefix: str = "shared_", + min_elements: int = 1024, +): + """Remove initializers with same value from two graphs. + + Args: + graph1 (GraphProto): the first graph to process + graph2 (GraphProto): the second graph to process + shared_prefix (str): add prefix to the shared initializers among two graphs + min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024. 
+ """ + + mapping_initializers_1 = {} + mapping_initializers_2 = {} + shared_initializers_1 = [] + shared_initializers_2 = [] + shared_initializers_names = [] + + for initializer1 in graph1.initializer: + if not (initializer1.dims and sum(initializer1.dims) >= min_elements): + continue + + for initializer2 in graph2.initializer: + if not (initializer2.dims and sum(initializer2.dims) >= min_elements): + continue + + if OnnxModel.has_same_value(initializer1, initializer2): + mapping_initializers_1[initializer1.name] = shared_prefix + initializer2.name + shared_initializers_1.append(initializer1) + + if initializer2.name not in mapping_initializers_2: + shared_name = shared_prefix + initializer2.name + mapping_initializers_2[initializer2.name] = shared_name + shared_initializers_2.append(initializer2) + shared_initializers_names.append(shared_name) + break + + logger.debug(f"shared initializers:{shared_initializers_names}") + + # Make sure new name does not exist in graph 1 + for node in graph1.node: + for j in range(len(node.input)): + if node.input[j] in shared_initializers_names: + raise RuntimeError(f"name is found in graph 1: {node.input[j]}") + + # Make sure new name does not exist in graph 2 + for node in graph2.node: + for j in range(len(node.input)): + if node.input[j] in shared_initializers_names: + raise RuntimeError(f"name is found in graph 2: {node.input[j]}") + + # Remove shared initializers from graph 2 + for initializer in shared_initializers_2: + graph2.initializer.remove(initializer) + + # Rename value info for old names in graph 2 + for value_info in graph2.value_info: + if value_info.name in mapping_initializers_2: + value_info.name = mapping_initializers_2[value_info.name] + + # Rename nodes inputs in graph 2: + for node in graph2.node: + for j in range(len(node.input)): + if node.input[j] in mapping_initializers_2: + new_name = mapping_initializers_2[node.input[j]] + logger.debug(f"graph 2 rename node {node.name} input {j} from {node.input[j]} to {new_name}") + node.input[j] = new_name + + # Remove shared initializers from graph 1 + for initializer in shared_initializers_1: + graph1.initializer.remove(initializer) + + # Rename value info for old names in graph 1 + for value_info in graph1.value_info: + if value_info.name in mapping_initializers_1: + value_info.name = mapping_initializers_1[value_info.name] + + # Rename nodes inputs in graph 1: + for node in graph1.node: + for j in range(len(node.input)): + if node.input[j] in mapping_initializers_1: + new_name = mapping_initializers_1[node.input[j]] + logger.debug(f"graph 1 rename node {node.name} input {j} from {node.input[j]} to {new_name}") + node.input[j] = new_name + + # Rename shared initializers in graph 2 + for initializer in shared_initializers_2: + initializer.name = mapping_initializers_2[initializer.name] + + for initializer in shared_initializers_2: + shape = onnx.numpy_helper.to_array(initializer).shape + value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape) + # Need add value_info for initializers moved to parent graph. Otherwise, ORT will fail. 
+ graph1.value_info.append(value_info) + graph2.value_info.append(value_info) + + return shared_initializers_2 + + +def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto): + encoder = OnnxModel(encoder_model) + decoder = OnnxModel(decoder_model) + encoder.add_prefix_to_names("e_") + decoder.add_prefix_to_names("d_") + encoder.remove_duplicated_initializer() + decoder.remove_duplicated_initializer() + initializers = remove_shared_initializers(encoder.model.graph, decoder.model.graph, "s_") + return initializers + + +def move_initializers( + graph: GraphProto, + min_elements: int = 1024, +) -> List[TensorProto]: + """Remove initializers of a graph, when they have number of elements larger than a threshold. + + Args: + graph (GraphProto): the graph. + min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024. + + Returns: + List[TensorProto]: initializers that are removed from the graph. + """ + moved_initializers = [] + for tensor in graph.initializer: + if not (tensor.dims and sum(tensor.dims) >= min_elements): + continue + moved_initializers.append(tensor) + + for initializer in moved_initializers: + graph.initializer.remove(initializer) + + # Add type info, otherwise ORT will raise error: "input arg (*) does not have type information set by parent node." + for initializer in moved_initializers: + shape = onnx.numpy_helper.to_array(initializer).shape + value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape) + graph.value_info.append(value_info) + + return moved_initializers + + +def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH): + """Convert model according to command line arguments. + + Args: + args (argparse.Namespace): arguments parsed from command line + """ + is_gpt2: bool = args.model_type == "gpt2" + is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH + + if is_greedysearch: + if not is_gpt2: + raise NotImplementedError("Currently only gpt2 with greedy search is supported") + if args.output_sequences_scores: + raise NotImplementedError("output_sequences_scores currently is not supported in greedy search") + if args.output_token_scores: + raise NotImplementedError("output_token_scores currently is not supported in greedy search") + + if is_gpt2: + if args.decoder_onnx and os.path.exists(args.decoder_onnx): + logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}") + else: + if not args.decoder_onnx: + onnx_filename = "gpt2_past_{}.onnx".format("fp16" if args.precision == Precision.FLOAT16 else "fp32") + args.decoder_onnx = Path(Path(args.output).parent, onnx_filename).as_posix() + + logger.info(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...") + gpt2_to_onnx(args) + else: # t5 or mt5 + if args.decoder_onnx and args.encoder_decoder_init_onnx: + logger.info( + f"skip convert_to_onnx since paths specified: {args.decoder_onnx} and {args.encoder_decoder_init_onnx}" + ) + else: + logger.info(f"Convert model {args.model_name_or_path} to onnx ...") + t5_to_onnx(args) + + if args.run_shape_inference: + logger.info(f"Run symbolic shape inference on {args.decoder_onnx}. 
The file will be overwritten.") + shape_inference(args.decoder_onnx, args.use_external_data_format) + + if is_gpt2: + config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + elif args.model_type == "t5": + config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + else: + config = MT5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + + if args.verbose: + logger.info(f"Config={config}") + + eos_token_id = config.eos_token_id + pad_token_id = config.eos_token_id if is_gpt2 else config.pad_token_id + vocab_size = config.vocab_size + + # if vocab_size is given in parameters use that. + if args.vocab_size != -1: + vocab_size = args.vocab_size + + decoder_model = onnx.load_model(args.decoder_onnx, load_external_data=True) + decoder_model.graph.name = f"{args.model_type} decoder" + + if args.model_type == "gpt2": + verify_gpt2_subgraph(decoder_model.graph, args.precision) + else: + verify_t5_decoder_subgraph(decoder_model.graph, args.precision) + + inputs = ( + [ + "input_ids", + "max_length", + "min_length", + "num_beams", + "num_return_sequences", + "length_penalty", + "repetition_penalty", + ] + if not is_greedysearch + else [ + "input_ids", + "max_length", + "min_length", + "repetition_penalty", + ] + ) + + outputs = ["sequences"] + if args.output_sequences_scores: + outputs.append("sequences_scores") + + if args.output_token_scores: + assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" + outputs.append("scores") + + node = ( + onnx.helper.make_node( + "BeamSearch", + inputs=inputs, + outputs=outputs, + name=f"BeamSearch_{args.model_type}", + ) + if not is_greedysearch + else onnx.helper.make_node( + "Sampling", + inputs=inputs, + outputs=outputs, + name=f"GreedySearch_{args.model_type}", + ) + ) + + node.domain = "com.microsoft" + + attr_to_extend = ( + [ + onnx.helper.make_attribute("eos_token_id", eos_token_id), + onnx.helper.make_attribute("pad_token_id", pad_token_id), + onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), + onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), + ] + if not is_greedysearch + else [ + onnx.helper.make_attribute("eos_token_id", eos_token_id), + onnx.helper.make_attribute("pad_token_id", pad_token_id), + onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), + onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + onnx.helper.make_attribute("temperature", 1.1), + onnx.helper.make_attribute("top_p", 0.6), + onnx.helper.make_attribute("filter_value", 0.0), + onnx.helper.make_attribute("presence_penalty", 0.0), + ] + ) + node.attribute.extend(attr_to_extend) + + initializers = [] + if args.model_type in ["t5", "mt5"]: + if args.run_shape_inference: + logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.") + shape_inference(args.encoder_decoder_init_onnx, args.use_external_data_format) + encoder_model = onnx.load_model(args.encoder_decoder_init_onnx, load_external_data=True) + encoder_model.graph.name = f"{args.model_type} encoder and decoder init" + verify_t5_encoder_decoder_init_subgraph(encoder_model.graph, args.precision) + + if not args.disable_shared_initializers: + # Unique shared initializers from the decoder and decoder_init could reduce memory usage in inference. 
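+            # The returned initializers are later added to the main graph (see make_graph below), and the
+            # encoder/decoder subgraphs reference them through the renamed shared names.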
+ initializers = get_shared_initializers(encoder_model, decoder_model) + logger.info( + f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in subgraphs are moved to the main graph" + ) + + # TODO(tianleiwu): investigate the following which causes error in inference + # Move initializer from subgraph to main graph could reduce memory usage in inference. + # moved_initializers = move_initializers(encoder_model.graph) + # logger.info( + # f"{len(moved_initializers)} initializers ({[i.name for i in moved_initializers]}) from the encoder are moved to the main graph" + # ) + # initializers.extend(moved_initializers) + + node.attribute.extend( + [ + onnx.helper.make_attribute("encoder", encoder_model.graph), + onnx.helper.make_attribute("decoder", decoder_model.graph), + onnx.helper.make_attribute( + "decoder_start_token_id", + config.decoder_start_token_id if len(encoder_model.graph.input) == 3 else -1, + ), + ] + ) + else: + # Move initializer from subgraph to main graph could reduce memory usage in inference. + initializers = move_initializers(decoder_model.graph) + logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph") + + node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph)) + + # graph inputs + input_ids = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"]) + max_length = onnx.helper.make_tensor_value_info("max_length", TensorProto.INT32, [1]) + min_length = onnx.helper.make_tensor_value_info("min_length", TensorProto.INT32, [1]) + num_beams = onnx.helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1]) + num_return_sequences = onnx.helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) + length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) + repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) + + graph_inputs = ( + [ + input_ids, + max_length, + min_length, + num_beams, + num_return_sequences, + length_penalty, + repetition_penalty, + ] + if not is_greedysearch + else [ + input_ids, + max_length, + min_length, + repetition_penalty, + ] + ) + + if args.vocab_mask: + vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size]) + graph_inputs.append(vocab_mask) + + if args.prefix_vocab_mask: + prefix_vocab_mask = onnx.helper.make_tensor_value_info( + "prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size] + ) + graph_inputs.append(prefix_vocab_mask) + + if args.custom_attention_mask: + attention_mask = onnx.helper.make_tensor_value_info( + "attention_mask", TensorProto.INT32, ["batch_size", "sequence_length"] + ) + graph_inputs.append(attention_mask) + + # graph outputs + sequences = ( + onnx.helper.make_tensor_value_info( + "sequences", + TensorProto.INT32, + ["batch_size", "num_return_sequences", "max_length"], + ) + if not is_greedysearch + else onnx.helper.make_tensor_value_info( + "sequences", + TensorProto.INT32, + ["batch_size", "max_length"], + ) + ) + + sequences_scores = onnx.helper.make_tensor_value_info( + "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] + ) + + scores = onnx.helper.make_tensor_value_info( + "scores", + TensorProto.FLOAT, + ["max_length - sequence_length", "batch_size", "num_beams", vocab_size], + ) + + graph_outputs = [sequences] + + if args.output_sequences_scores: + graph_outputs.append(sequences_scores) + + if 
args.output_token_scores: + graph_outputs.append(scores) + + new_graph = onnx.helper.make_graph( + [node], + f"{args.model_type} beam search" if not is_greedysearch else f"{args.model_type} greedy search", + graph_inputs, + graph_outputs, + initializers, + ) + + # Create the model + new_model = onnx.helper.make_model( + new_graph, + producer_name="onnxruntime.transformers", + opset_imports=decoder_model.opset_import, + ) + + # TODO(tianleiwu): move shared initializers from T5 encoder and decoder subgraphs to parent graph to save memory. + if args.use_external_data_format: + from packaging import version + + if version.parse(onnx.__version__) < version.parse("1.12.0"): + logger.warning("Require onnx >= 1.12 to save large (>2GB) model!") + + OnnxModel.save( + new_model, + args.output, + save_as_external_data=True, + all_tensors_to_one_file=True, + ) + else: + onnx.save(new_model, args.output) + logger.info(f"model save to {args.output}") + + +def test_torch_performance( + args: argparse.Namespace, + model: Union[GPT2LMHeadModel, T5ForConditionalGeneration], + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + eos_token_id: int, + pad_token_id: int, + bad_words_ids: List[List[int]], +) -> Dict[str, Any]: + """Test PyTorch performance of text generation. + + Args: + args (argparse.Namespace): arguments parsed from command line + model (Union[GPT2LMHeadModel, T5ForConditionalGeneration]): PyTorch model + input_ids (torch.Tensor): input_ids + attention_mask (torch.Tensor): Attention mask + eos_token_id (int): EOS token ID + pad_token_id (int): Padding token ID + bad_words_ids (List[List[int]]): Words shall not be generated. + + Raises: + RuntimeError: PyTorch with CUDA is not available for --use_gpu + + Returns: + Dict[str, Any]: A dictionary with string with metric name, and value can be integer or string. 
+ """ + if args.use_gpu and not torch.cuda.is_available(): + raise RuntimeError("Please install PyTorch with Cuda for testing gpu performance.") + + if args.precision == Precision.FLOAT16: + model.half() + + device = torch.device("cuda:0" if args.use_gpu else "cpu") + model.to(device) + + torch.set_grad_enabled(False) + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + + torch_latency = [] + for _ in range(args.total_runs): + start = time.time() + _ = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=args.max_length, + min_length=args.min_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty, + bad_words_ids=bad_words_ids, + return_dict_in_generate=True, + output_scores=args.output_sequences_scores or args.output_token_scores, + ) + torch_latency.append(time.time() - start) + batch_size = input_ids.shape[0] + from benchmark_helper import get_latency_result + + return get_latency_result(torch_latency, batch_size) + + +def create_attention_mask(input_ids, pad_token_id): + attention_mask = np.ones(input_ids.shape, dtype=np.int32) + for i in range(input_ids.shape[0]): + abs_pos = 0 + for j in range(input_ids.shape[1]): + if input_ids[i][j] == pad_token_id and abs_pos == 0: + attention_mask[i][j] = 0 + else: + abs_pos += 1 + return attention_mask + + +def test_gpt_model(args: argparse.Namespace, sentences: Optional[List[str]] = None, is_greedy: bool = False): + """Test GPT-2 model + + Args: + args (argparse.Namespace): arguments parsed from command line + sentences (Optional[List[str]], optional): input text. Defaults to None. + + Returns: + Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. 
+ """ + assert args.model_type == "gpt2" + + tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token + + model = GPT2LMHeadModel.from_pretrained( + args.model_name_or_path, + cache_dir=args.cache_dir, + pad_token_id=tokenizer.eos_token_id, + ) + + # Use different length sentences to test batching + if sentences is None: + sentences = [ + "The product is released", + "I enjoy walking in the park", + "Test best way to invest", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + bad_words = "walk in park" + bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True) + bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list + if args.vocab_mask: + logger.debug("bad_words_ids", bad_words_ids) + else: + bad_words_ids = [] + + config = model.config + eos_token_id = config.eos_token_id + pad_token_id = config.eos_token_id + vocab_size = config.vocab_size + + torch_decoded_sequences = [] + beam_outputs = None + if not args.disable_parity: + print("-" * 50) + print("Test PyTorch model and beam search with huggingface transformers...") + beam_outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=args.max_length, + min_length=args.min_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty, + bad_words_ids=bad_words_ids if bad_words_ids else None, + return_dict_in_generate=True, + output_scores=args.output_sequences_scores or args.output_token_scores, + ) + print("input_ids", input_ids) + print("huggingface transformers outputs:") + print("sequences", beam_outputs.sequences) + if args.output_sequences_scores: + print("sequences_scores", beam_outputs.sequences_scores) + if args.output_token_scores: + print("scores", beam_outputs.scores) + for i, sequence in enumerate(beam_outputs.sequences): + decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True) + torch_decoded_sequences.append(decoded_sequence) + print(f"{i}: {decoded_sequence}") + + print("-" * 50) + print("Testing beam search with onnxruntime...") + + ort_session = create_ort_session(args.output, args.use_gpu) + print("ort session created") + if is_greedy: + print("is_greedy") + inputs = { + "input_ids": input_ids.cpu().numpy().astype(np.int32), + "max_length": np.array([args.max_length], dtype=np.int32), + "min_length": np.array([args.min_length], dtype=np.int32), + "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), + } + print(inputs) + else: + inputs = { + "input_ids": input_ids.cpu().numpy().astype(np.int32), + "max_length": np.array([args.max_length], dtype=np.int32), + "min_length": np.array([args.min_length], dtype=np.int32), + "num_beams": np.array([args.num_beams], dtype=np.int32), + "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32), + "length_penalty": np.array([args.length_penalty], dtype=np.float32), + "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), + } + + if args.vocab_mask: + vocab_mask = np.ones((vocab_size), dtype=np.int32) + if args.vocab_mask: + for bad_word_id in bad_words_ids: + 
vocab_mask[bad_word_id] = 0 + inputs["vocab_mask"] = vocab_mask + + if args.custom_attention_mask: + inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id) + + batch_size = input_ids.shape[0] + if args.prefix_vocab_mask: + logger.info("Use prefix vocab mask with all ones in ORT, but no corresponding setting for Torch model.") + prefix_vocab_mask = np.ones((batch_size, vocab_size), dtype=np.int32) + inputs["prefix_vocab_mask"] = prefix_vocab_mask + + logger.debug("ORT inputs", inputs) + result = ort_session.run(None, inputs) + + if args.save_test_data: + test_data_dir = Path(args.output).parent.as_posix() + logger.debug("test_data_dir", test_data_dir) + from bert_test_data import output_test_data + + all_inputs = [inputs] + for i, inputs in enumerate(all_inputs): + dir = os.path.join(test_data_dir, "test_data_set_" + str(i)) + output_test_data(dir, inputs) + + # Test performance + latency = [] + for _ in range(args.total_runs): + start = time.time() + _ = ort_session.run(None, inputs) + latency.append(time.time() - start) + + from benchmark_helper import get_latency_result + + batch_size = input_ids.shape[0] + output = get_latency_result(latency, batch_size) + + print("ORT outputs:") + sequences = result[0] + print("sequences", sequences) + if args.output_sequences_scores: + print("sequences_scores", result[1]) + if args.output_token_scores: + print("scores", result[2]) + + if is_greedy: + (batch_size, max_length) = sequences.shape + ort_decoded_sequences = [] + for i in range(batch_size): + decoded_sequence = tokenizer.decode(sequences[i], skip_special_tokens=True) + ort_decoded_sequences.append(decoded_sequence) + print(f"batch {i} sequence: {decoded_sequence}") + else: + (batch_size, num_sequences, max_length) = sequences.shape + ort_decoded_sequences = [] + for i in range(batch_size): + for j in range(num_sequences): + decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) + ort_decoded_sequences.append(decoded_sequence) + print(f"batch {i} sequence {j}: {decoded_sequence}") + + if beam_outputs: + torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1) + ort_sequences = torch.LongTensor(sequences) + print("-" * 50) + print("Torch Sequences:") + print(torch_sequences) + print(torch_decoded_sequences) + print("-" * 50) + print("ORT Sequences:") + print(ort_sequences) + print(ort_decoded_sequences) + print("-" * 50) + # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not. + is_same = torch_decoded_sequences == ort_decoded_sequences + print("Torch and ORT result is ", "same" if is_same else "different") + output["parity"] = is_same + + if args.torch_performance: + torch_latency_output = test_torch_performance( + args, + model, + input_ids, + attention_mask, + eos_token_id, + pad_token_id, + bad_words_ids, + ) + print("Torch Latency", torch_latency_output) + + print("ORT", output) + + return output + + +def test_t5_model(args: argparse.Namespace, sentences: Optional[List[str]] = None): + """Test T5 or MT5 model + + Args: + args (argparse.Namespace): arguments parsed from command line + sentences (Optional[List[str]], optional): input text. Defaults to None. + + Returns: + Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. 
+ """ + assert args.model_type in ["t5", "mt5"] + + if args.prefix_vocab_mask: + logger.debug("Skipping parity test as prefix vocab mask is not implemented by Hugging Face") + return None + + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + tokenizer.padding_side = "left" + + if args.model_type == "t5": + model = T5ForConditionalGeneration.from_pretrained( + args.model_name_or_path, + cache_dir=args.cache_dir, + ) + else: + model = MT5ForConditionalGeneration.from_pretrained( + args.model_name_or_path, + cache_dir=args.cache_dir, + ) + + # Use different length sentences to test batching + if sentences is None: + sentences = [ + "translate English to French: The product is released", + "summarize: research continues to show that pets bring real health benefits to their owners." + + "Having a dog around can lead to lower levels of stress for both adults and kids.", + # "summarize: I enjoy walking in the park. It makes my mind feel calm and refreshed. " + # + "I enjoy looking at the trees, flowers, and wildlife around me, and listening to sound from natural.", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + bad_words = "walk in park" + bad_words_ids = tokenizer.encode(bad_words)[:-1] # exclude the last token (EOS) + bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list + if args.vocab_mask: + logger.debug("bad_words_ids", bad_words_ids) + else: + bad_words_ids = [] + + config = model.config + eos_token_id = config.eos_token_id + pad_token_id = config.pad_token_id + vocab_size = config.vocab_size + logger.debug(f"eos_token_id:{eos_token_id}, pad_token_id:{pad_token_id}, vocab_size:{vocab_size}") + + torch_decoded_sequences = [] + if not args.disable_parity: + print("-" * 50) + print("Test PyTorch model and beam search with huggingface transformers...") + beam_outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=args.max_length, + min_length=args.min_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty, + bad_words_ids=bad_words_ids if bad_words_ids else None, + return_dict_in_generate=True, + output_scores=args.output_sequences_scores or args.output_token_scores, + ) + + print("input_ids", input_ids) + print("huggingface transformers outputs:") + print("sequences", beam_outputs.sequences) + if args.output_sequences_scores: + print("sequences_scores", beam_outputs.sequences_scores) + if args.output_token_scores: + print("scores", beam_outputs.scores) + for i, sequence in enumerate(beam_outputs.sequences): + decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True) + torch_decoded_sequences.append(decoded_sequence) + print("{}: {}".format(i, decoded_sequence)) + + print("-" * 50) + print("Testing beam search with onnxruntime...") + + ort_session = create_ort_session(args.output, args.use_gpu) + + vocab_mask = np.ones((vocab_size), dtype=np.int32) + if args.vocab_mask: + for bad_word_id in bad_words_ids: + vocab_mask[bad_word_id] = 0 + + inputs = { + "input_ids": input_ids.cpu().numpy().astype(np.int32), + "max_length": np.array([args.max_length], dtype=np.int32), + "min_length": np.array([args.min_length], 
dtype=np.int32), + "num_beams": np.array([args.num_beams], dtype=np.int32), + "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32), + "length_penalty": np.array([args.length_penalty], dtype=np.float32), + "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), + } + + if args.vocab_mask: + inputs["vocab_mask"] = vocab_mask + + if args.custom_attention_mask: + inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id) + + if args.save_test_data: + test_data_dir = Path(args.output).parent.as_posix() + logger.debug("test_data_dir", test_data_dir) + from bert_test_data import output_test_data + + all_inputs = [inputs] + for i, inputs in enumerate(all_inputs): + dir = os.path.join(test_data_dir, "test_data_set_" + str(i)) + output_test_data(dir, inputs) + + logger.debug("ORT inputs", inputs) + + # Test performance + latency = [] + for _ in range(args.total_runs): + start = time.time() + result = ort_session.run(None, inputs) + latency.append(time.time() - start) + batch_size = input_ids.shape[0] + from benchmark_helper import get_latency_result + + output = get_latency_result(latency, batch_size) + + print("ORT outputs:") + sequences = result[0] + print("sequences", sequences) + if args.output_sequences_scores: + print("sequences_scores", result[1]) + if args.output_token_scores: + print("scores", result[2]) + + (batch_size, num_sequences, max_length) = sequences.shape + ort_decoded_sequences = [] + for i in range(batch_size): + for j in range(num_sequences): + decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) + ort_decoded_sequences.append(decoded_sequence) + print(f"batch {i} sequence {j}: {decoded_sequence}") + + if not args.disable_parity: + torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1) + ort_sequences = torch.LongTensor(sequences) + print("-" * 50) + print("Torch Sequences:") + print(torch_sequences) + print(torch_decoded_sequences) + print("-" * 50) + print("ORT Sequences:") + print(ort_sequences) + print(ort_decoded_sequences) + print("-" * 50) + # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not. + is_same = torch_decoded_sequences == ort_decoded_sequences + print("Torch and ORT result is ", "same" if is_same else "different") + output["parity"] = is_same + + if args.torch_performance: + torch_latency_output = test_torch_performance( + args, + model, + input_ids, + attention_mask, + eos_token_id, + pad_token_id, + bad_words_ids, + ) + print("Torch Latency", torch_latency_output) + + print("ORT", output) + return output + + +def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None): + """Main entry function + + Args: + argv (Optional[List[str]], optional): _description_. Defaults to None. + sentences (Optional[List[str]], optional): input text. Defaults to None. + + Raises: + ValueError: Path does not exist: --encoder_decoder_init_onnx + ValueError: Path does not exist: --decoder_onnx + ValueError: --decoder_onnx and --encoder_decoder_init_onnx are not used together for T5 + + Returns: + Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. 
+ """ + + args = parse_arguments(argv) + setup_logger(args.verbose) + + if args.model_type in ["t5", "mt5"]: + if args.encoder_decoder_init_onnx and not os.path.exists(args.encoder_decoder_init_onnx): + raise ValueError(f"Path does not exist: --encoder_decoder_init_onnx {args.encoder_decoder_init_onnx}") + if args.decoder_onnx and not os.path.exists(args.decoder_onnx): + raise ValueError(f"Path does not exist: --decoder_onnx {args.decoder_onnx}") + if (args.encoder_decoder_init_onnx and not args.decoder_onnx) or ( + args.decoder_onnx and not args.encoder_decoder_init_onnx + ): + raise ValueError("--decoder_onnx shall use together with --encoder_decoder_init_onnx") + + is_greedy = args.num_beams == 1 and args.num_return_sequences == 1 + + if args.model_type == "gpt2" and is_greedy: + convert_generation_model(args, GenerationType.GREEDYSEARCH) + else: + convert_generation_model(args) + + logger.info("start testing model...") + if args.model_type in ["t5", "mt5"]: + result = test_t5_model(args, sentences=sentences) + else: + result = test_gpt_model(args, sentences=sentences, is_greedy=is_greedy) + + if result: + if args.use_external_data_format: + logger.info(f"Output files: {args.output}, {args.output}.data") + else: + logger.info(f"Output file: {args.output}") + + return result + + +if __name__ == "__main__": + main() From 33d232ee80b08df8ff41da9c214d41676cc7b613 Mon Sep 17 00:00:00 2001 From: wangyems Date: Fri, 4 Nov 2022 02:36:25 +0000 Subject: [PATCH 19/51] support 4d mask --- .../cpu/transformers/generate_impl_base.h | 18 +++++--- .../transformers/generation_device_helper.cc | 43 ++++++++++++------- .../transformers/generation_device_helper.cc | 29 +++++++------ 3 files changed, 55 insertions(+), 35 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index 751d5cce18f33..cbdd18410b3c3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -139,13 +139,19 @@ class GenerateBase { if (attention_mask != nullptr) { const auto& dims_attn = attention_mask->Shape().GetDims(); - if (dims_attn.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size()); - } - if (!SpanEq(dims_attn, dims)) { + if (dims_attn.size() == 2) { + if (!SpanEq(dims_attn, dims)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'attention_mask' is expected to have same shape as input_ids"); + } + } else if (dims_attn.size() == 4) { + if (dims_attn[0] != dims[0] || dims_attn[1] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'attention_mask' is expected to shape [batch_size, 1, max_sequence_length, max_sequence_length]"); + } + } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'attention_mask' is expected to have same shape as input_ids"); + "Input 'attention_mask' is expected to have 2 or 4 dimensions, got ", dims_attn.size()); } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index e50bdf40080d4..5a1d4ac725d47 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -138,7 +138,8 @@ Status CreateGptInputs( OrtValue attention_mask; if (attn_mask_value != nullptr) { const 
Tensor& attn_mask = attn_mask_value->Get(); - Tensor::InitOrtValue(element_type, input_ids_shape, const_cast(&attn_mask)->MutableData(), + const TensorShape& attn_mask_shape = attn_mask.Shape(); // 2d or 4d + Tensor::InitOrtValue(element_type, attn_mask_shape, const_cast(&attn_mask)->MutableData(), allocator->Info(), attention_mask); } else { auto mask_type = DataTypeImpl::GetType(); @@ -176,9 +177,16 @@ Status CreateGptInputs( // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance. - ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids); - ExpandInputs(position_ids, num_beams, allocator, expanded_position_ids); - ExpandInputs(attention_mask, num_beams, allocator, expanded_attention_mask); + if (num_beams == 1) { + expanded_input_ids = input_ids; + expanded_position_ids = position_ids; + expanded_attention_mask = attention_mask; + } else { + // bugbug: 4d not supported here + ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids); + ExpandInputs(position_ids, num_beams, allocator, expanded_position_ids); + ExpandInputs(attention_mask, num_beams, allocator, expanded_attention_mask); + } return Status::OK(); } @@ -637,19 +645,22 @@ Status UpdateGptFeeds( next_inputs[1] = position_ids; // Update attention mask const OrtValue& old_mask = next_inputs[2]; - const int32_t* old_mask_data = old_mask.Get().Data(); - int64_t mask_dims[] = {batch_beam_size, current_length}; - TensorShape mask_shape(&mask_dims[0], 2); - OrtValue attention_mask; - Tensor::InitOrtValue(int32_type, mask_shape, allocator, attention_mask); - int32_t* mask_data = attention_mask.GetMutable()->MutableData(); - for (int i = 0; i < batch_beam_size; i++) { - for (int j = 0; j < current_length - 1; j++) { - mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j]; + const auto& mask_dims = old_mask.Get().Shape().GetDims(); + if (mask_dims.size() == 2) { + const int32_t* old_mask_data = old_mask.Get().Data(); + int64_t mask_dims[] = {batch_beam_size, current_length}; + TensorShape mask_shape(&mask_dims[0], 2); + OrtValue attention_mask; + Tensor::InitOrtValue(int32_type, mask_shape, allocator, attention_mask); + int32_t* mask_data = attention_mask.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + for (int j = 0; j < current_length - 1; j++) { + mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j]; + } + mask_data[i * current_length + current_length - 1] = 1; } - mask_data[i * current_length + current_length - 1] = 1; - } - next_inputs[2] = attention_mask; + next_inputs[2] = attention_mask; + } // if mask_dims.size() == 4 do nothing // Update past state if (num_beams == 1) { diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e2478aa3b9b91..2cc2542775c16 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -862,19 +862,22 @@ Status UpdateGptFeeds( // Update attention mask const OrtValue& old_mask = next_inputs[2]; - const int32_t* old_mask_data = old_mask.Get().Data(); - int64_t mask_dims[] = {batch_beam_size, current_length}; - TensorShape mask_shape(&mask_dims[0], 2); - OrtValue attention_mask; - auto mask_type = DataTypeImpl::GetType(); - Tensor::InitOrtValue(mask_type, mask_shape, 
allocator, attention_mask); - int32_t* mask_data = attention_mask.GetMutable()->MutableData(); - - // Launch kernel to update position_ids and attention_mask for next iteration - cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length, - reinterpret_cast(stream)); - - next_inputs[2] = attention_mask; + const auto& mask_dims = old_mask.Get().Shape().GetDims(); + if (mask_dims.size() == 2) { + const int32_t* old_mask_data = old_mask.Get().Data(); + int64_t mask_dims[] = {batch_beam_size, current_length}; + TensorShape mask_shape(&mask_dims[0], 2); + OrtValue attention_mask; + auto mask_type = DataTypeImpl::GetType(); + Tensor::InitOrtValue(mask_type, mask_shape, allocator, attention_mask); + int32_t* mask_data = attention_mask.GetMutable()->MutableData(); + + // Launch kernel to update position_ids and attention_mask for next iteration + cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length, + reinterpret_cast(stream)); + + next_inputs[2] = attention_mask; + } // do nothing for mask_dims.size() == 4 // Update past state if (num_beams == 1) { From e9d2215a1f4426bf1c9a419c80b58e07e599c44a Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 8 Nov 2022 01:39:22 +0000 Subject: [PATCH 20/51] update --- .../cuda/transformers/generation_device_helper.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 2cc2542775c16..50c3fcf8b0732 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -585,7 +585,7 @@ Status GreedySearchProcessLogits( size_t temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), d_index_buffer_in, d_offset_buffer, - parameters->vocab_size, + parameters->batch_size * parameters->vocab_size, parameters->batch_size, cuda_stream); @@ -599,6 +599,10 @@ Status GreedySearchProcessLogits( parameters->vocab_size, cuda_stream); +#ifdef DEBUG_GENERATION + dumper->Print("d_offset_buffer", d_offset_buffer, batch_size + 1, 1); +#endif + void* temp_storage = allocator->Alloc(temp_storage_bytes); BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); storage_buffer = std::move(temp_storage_buffer); @@ -608,7 +612,7 @@ Status GreedySearchProcessLogits( d_sorted_score_buffer, d_index_buffer_in, d_index_buffer_out, - parameters->vocab_size, + parameters->batch_size * parameters->vocab_size, parameters->batch_size, d_offset_buffer, cuda_stream); From e48867c34a779633c4f0c27d011b67869a44a4c2 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 8 Nov 2022 18:51:19 +0000 Subject: [PATCH 21/51] check presence mask --- .../cpu/transformers/beam_search_impl_base.h | 3 ++- .../cpu/transformers/generate_impl_base.h | 25 ++++++++++++++++++- .../cpu/transformers/generation_shared.h | 6 ++--- .../transformers/greedy_search_impl_base.h | 3 ++- .../core/graph/contrib_ops/contrib_defs.cc | 2 +- 5 files changed, 32 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 863982972923f..0fb22b1578176 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -191,7 +191,8 @@ Status 
BeamSearchBase::CheckInputs(const OpKernelContextInternal& context) { context.Input(0), // input_ids context.Input(7), // vocab_mask context.Input(8), // prefix_vocab_mask - context.Input(9))); // attention_mask + context.Input(9), // attention_mask + nullptr)); // presence_mask return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index cbdd18410b3c3..b44ba6c75f87a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -87,7 +87,8 @@ class GenerateBase { const Tensor* input_ids, const Tensor* vocab_mask, const Tensor* prefix_vocab_mask, - const Tensor* attention_mask) const { + const Tensor* attention_mask, + const Tensor* presence_mask) const { const auto& dims = input_ids->Shape().GetDims(); if (dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -155,6 +156,28 @@ class GenerateBase { } } + if (presence_mask != nullptr) { + const auto& dims_presence = presence_mask->Shape().GetDims(); + if (dims_presence.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'presence_mask' is expected to have 2 dimensions, got ", dims_presence.size()); + } + + // presence_mask first dimension should be same as the first dimension of input_ids + if (static_cast(dims_presence[0]) != static_cast(dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input_ids and presence_mask must have the same batch_size"); + } + + if (static_cast(dims_presence[1]) != parameters->vocab_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'presence_mask' shape[1] shall be vocab_size, got ", dims_presence[1]); + } + + // store prefix vocab mask in parameters. 
+ parameters->presence_mask = presence_mask->DataAsSpan(); + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 72381be677b3a..411bad3a5856e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -8,9 +8,9 @@ #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -#ifndef NDEBUG -//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search -#endif +//#ifndef NDEBUG +#define DEBUG_GENERATION 1 // uncomment it for debugging beam search +//#endif namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 02b38bbb675b2..cbee8913a29d6 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -131,7 +131,8 @@ Status GreedySearchBase::CheckInputs(const OpKernelContextIntern context.Input(0), // input_ids context.Input(4), // vocab_mask context.Input(5), // prefix_vocab_mask - nullptr)); // attention_mask + context.Input(6), // attention_mask + context.Input(7))); // presence_mask return Status::OK(); } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index cb6ff56a32ba1..6300e749b2acb 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1127,7 +1127,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) - .Input(7, "presence_penalty_mask", "presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) + .Input(7, "presence_mask", "presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. 
Shape is (batch_size, max_sequence_length)", "I") .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") From daebfaabb0b4ca392abb18708e4cacd6feeb7a57 Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 9 Nov 2022 19:27:16 +0000 Subject: [PATCH 22/51] fix crash --- .../transformers/generation_device_helper.cc | 9 +- .../transformers/generation_device_helper.h | 2 + .../cpu/transformers/generation_shared.h | 14 +++ .../transformers/greedy_search_impl_base.h | 39 +++++++- .../cpu/transformers/greedy_search_impl_gpt.h | 10 +- .../cuda/transformers/generation_cuda_impl.cu | 6 ++ .../transformers/generation_device_helper.cc | 95 +++++++++---------- .../transformers/generation_device_helper.h | 1 + 8 files changed, 119 insertions(+), 57 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 5a1d4ac725d47..5e13b3acd4d8f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -178,9 +178,9 @@ Status CreateGptInputs( // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance. if (num_beams == 1) { - expanded_input_ids = input_ids; - expanded_position_ids = position_ids; - expanded_attention_mask = attention_mask; + expanded_input_ids = std::move(input_ids); + expanded_position_ids = std::move(position_ids); + expanded_attention_mask = std::move(attention_mask); } else { // bugbug: 4d not supported here ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids); @@ -408,6 +408,7 @@ template Status GreedySearchProcessLogits( const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingCudaState* sampling_state, // sampling_state transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) @@ -417,6 +418,7 @@ Status GreedySearchProcessLogits( int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper + ORT_UNUSED_PARAMETER(sampling_state); #ifndef DEBUG_GENERATION ORT_UNUSED_PARAMETER(dumper); #endif @@ -890,6 +892,7 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, + transformers::ISamplingCudaState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index e2d46ce244eea..9b1d850623256 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -91,6 +91,7 @@ template using GreedySearchProcessLogitsFunc = std::function* greedy_state, // state + transformers::ISamplingCudaState* sampling_state, // sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // 
thread pool (for CPU only) @@ -208,6 +209,7 @@ Status ProcessLogits(const OrtValue& logits, // template Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingCudaState* sampling_state, // sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 411bad3a5856e..c5a3a8cf649b2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -66,6 +66,20 @@ struct IGreedySearchState { gsl::span next_tokens; // shape (batch_size) }; +template +struct ISamplingCudaState { + gsl::span d_index_in; + gsl::span d_index_out; + gsl::span d_offset; + gsl::span d_sorted_score; + gsl::span d_sorted_softmaxed_score; + gsl::span d_softmaxed_score; + gsl::span d_sampled; + gsl::span d_indices; + + BufferUniquePtr storage_buffer; +}; + class ISequences { public: virtual ~ISequences() {} diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index cbee8913a29d6..493729e437d5c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -11,6 +11,37 @@ namespace contrib { namespace transformers { +template +struct SamplingState : public ISamplingCudaState { + void Init(AllocatorPtr allocator, + int batch_size, + int vocab_size, + bool is_cuda) { + if (!is_cuda) { + return; + } + int total_count = batch_size * vocab_size; + this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count)); + this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count)); + this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1)); + this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); + this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); + this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); + this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); + this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); + } + + private: + BufferUniquePtr d_index_in_buffer_; + BufferUniquePtr d_index_out_buffer_; + BufferUniquePtr d_offset_buffer_; + BufferUniquePtr d_sorted_score_buffer_; + BufferUniquePtr d_sorted_softmaxed_score_buffer_; + BufferUniquePtr d_softmaxed_score_buffer_; + BufferUniquePtr d_sampled_buffer_; + BufferUniquePtr d_indices_buffer_; +}; + template struct GreedySearchState : public IGreedySearchState { Sequences sequences; @@ -107,12 +138,14 @@ class GreedySearchBase : public GenerateBase { Status GenerateNextToken(const OrtValue& logits, gsl::span& next_tokens, GreedySearchState& greedy_state, + ISamplingCudaState& sampling_state, int counter, int eos_token_id); // Calculate scores from logits, then apply filtering and select next token for each beam. 
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph GreedySearchState& greedy_state, + ISamplingCudaState& sampling_state, AllocatorPtr& allocator, int counter); @@ -162,10 +195,11 @@ template Status GreedySearchBase::ProcessLogits( const OrtValue& logits, GreedySearchState& greedy_state, + ISamplingCudaState& sampling_state, AllocatorPtr& allocator, int counter) { bool use_sampling = std::is_same::value; - return process_logits_func_(logits, &greedy_state, &(greedy_state.sequences), allocator, + return process_logits_func_(logits, &greedy_state, &sampling_state, &(greedy_state.sequences), allocator, this->thread_pool_, &this->logits_processors_, parameters_, use_sampling, counter, this->cuda_stream_, this->GetConsoleDumper()); } @@ -175,10 +209,11 @@ Status GreedySearchBase::GenerateNextToken( const OrtValue& logits, gsl::span& next_tokens, GreedySearchState& greedy_state, + ISamplingCudaState& sampling_state, int counter, int eos_token_id) { // Process logits to get next token scores - ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, this->temp_space_allocator_, counter)); + ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, sampling_state, this->temp_space_allocator_, counter)); next_tokens = greedy_state.next_tokens; for (size_t i = 0; i < next_tokens.size(); i++) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 1efc8e3d6ed7d..f04f44ee4390d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -142,10 +142,17 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds parameters->max_length, this->IsCuda()); + SamplingState sampling_state; + if (std::is_same::value) { + sampling_state.Init(this->temp_space_allocator_, + static_cast(parameters->BatchBeamSize()), + static_cast(parameters->vocab_size), + this->IsCuda()); + } + IAllocatorUniquePtr buffer; OrtValue expanded_input_ids_in_cpu; ORT_RETURN_IF_ERROR(CreateInitialFeeds(greedy_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer)); - init_greedy_state_func_(&greedy_state, greedy_state.sequence_lengths, this->cuda_stream_); @@ -198,6 +205,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits, next_tokens, greedy_state, + sampling_state, iteration_counter, parameters->eos_token_id)); // When all batches are finished, stop earlier to avoid wasting computation. 
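For reference, the device buffers added to SamplingState above correspond to the stages of standard top-p (nucleus) sampling: sort the logits descending, softmax the sorted scores, cut the tail once the cumulative probability passes top_p, renormalize what is kept, and draw one token per batch entry with a uniform random value (d_sampled) into d_indices. The single-batch CPU sketch below is illustrative only and is not part of these changes; SampleTopP and its structure are assumptions, and the CUDA path performs the equivalent steps with the kernels patched in the following diffs.

// Illustrative CPU reference for the top-p + multinomial flow (hypothetical helper).
#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

static int SampleTopP(const std::vector<float>& logits, float top_p, std::mt19937& rng) {
  const size_t vocab = logits.size();

  // Sort token ids by descending logit (mirrors d_sorted_score / d_index_out).
  std::vector<size_t> order(vocab);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](size_t a, size_t b) { return logits[a] > logits[b]; });

  // Softmax over the sorted scores (mirrors d_sorted_softmaxed_score).
  std::vector<double> probs(vocab);
  const double max_logit = logits[order[0]];
  double denom = 0.0;
  for (size_t i = 0; i < vocab; ++i) {
    probs[i] = std::exp(static_cast<double>(logits[order[i]]) - max_logit);
    denom += probs[i];
  }
  for (size_t i = 0; i < vocab; ++i) probs[i] /= denom;

  // Keep the smallest prefix whose cumulative probability reaches top_p;
  // everything after this prefix is what the filtering step masks out.
  size_t keep = vocab;
  double cumulative = 0.0;
  for (size_t i = 0; i < vocab; ++i) {
    cumulative += probs[i];
    if (cumulative > top_p) { keep = i + 1; break; }
  }

  // Renormalize the kept prefix and draw one token by inverse CDF,
  // using a single uniform sample per batch entry (mirrors d_sampled -> d_indices).
  double kept_mass = 0.0;
  for (size_t i = 0; i < keep; ++i) kept_mass += probs[i];
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  const double r = uniform(rng) * kept_mass;
  double acc = 0.0;
  for (size_t i = 0; i < keep; ++i) {
    acc += probs[i];
    if (r <= acc) return static_cast<int>(order[i]);
  }
  return static_cast<int>(order[keep - 1]);
}

Because softmax restricted to the kept prefix and renormalized equals the softmax of the kept logits alone, filtering before or after the softmax yields the same sampling distribution when the filtered entries are set to a large negative value.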
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index eb6296b1592eb..769d37b69fcb5 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -430,6 +430,10 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in, int vocab_size) { int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= batch_size * vocab_size) { + return; + } + int vocab_idx = index % vocab_size; int batch_id = index / vocab_size; int start_index = batch_id * vocab_size; @@ -449,6 +453,8 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in, int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index]; d_logits_in_out[original_index] = (T)filter_value; } + } else { + int original_index = batch_id * vocab_size + d_sorted_indices[index]; } } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 50c3fcf8b0732..5bff03219aef6 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -115,7 +115,6 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, cudaStream_t stream = static_cast(provider->GetComputeStream()); auto pinned_buffer = IAllocator::MakeUniquePtr(pinned_allocator, total_bytes); char* pinned_data = static_cast(pinned_buffer.get()); - // Copy tensors to one pinned memory buffer (so that we only need copy to GPU once) char* destination = pinned_data; for (auto& input : inputs) { @@ -132,16 +131,13 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, "AddToFeeds: An implementation for the input type ", dataType, " is not supported yet"); } - // Do not need alignment because GPT has int32 inputs (past is empty) and T5 encoder has int64 inputs. destination += bytes; } } - if (!buffer) { buffer = provider->GetScratchBuffer(total_bytes); } - char* gpu_data = buffer.get(); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream)); @@ -151,7 +147,6 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, CUDA_RETURN_IF_ERROR(cudaEventCreate(&isCopyDone)); CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, stream)); CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone)); - // TODO(tianleiwu): allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call. const OrtMemoryInfo& location = provider->GetAllocator(0, OrtMemTypeDefault)->Info(); for (auto& input : inputs) { @@ -461,6 +456,7 @@ template Status GreedySearchProcessLogits( const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingCudaState* sampling_state, // buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) @@ -560,31 +556,15 @@ Status GreedySearchProcessLogits( // TODO(wy): support output_scores in greedy search ORT_UNUSED_PARAMETER(output_scores); - BufferUniquePtr workspace_buffer; - BufferUniquePtr storage_buffer; + BufferUniquePtr& storage_buffer = sampling_state->storage_buffer; if (do_sampling) { - // bugbug: move this outside probably in execute()? 
- size_t bytes = SafeInt(sizeof(int)) * 2 * parameters->batch_size * parameters->vocab_size + - SafeInt(sizeof(int)) * (parameters->batch_size + 1) + - SafeInt(sizeof(CudaT)) * parameters->batch_size * parameters->vocab_size + - SafeInt(2 * sizeof(float) * parameters->batch_size * parameters->vocab_size) + - SafeInt(sizeof(float)) * parameters->batch_size + - SafeInt(sizeof(int64_t)) * parameters->batch_size; - void* data = allocator->Alloc(bytes); - BufferUniquePtr workspace_buffer_temp(data, BufferDeleter(allocator)); - workspace_buffer = std::move(workspace_buffer_temp); - int* d_index_buffer_in = reinterpret_cast(workspace_buffer.get()); - int* d_index_buffer_out = d_index_buffer_in + parameters->batch_size * parameters->vocab_size; - int* d_offset_buffer = d_index_buffer_out + parameters->batch_size * parameters->vocab_size; - CudaT* d_sorted_score_buffer = reinterpret_cast(d_offset_buffer + parameters->batch_size + 1); - float* d_sorted_softmaxed_score_buffer = reinterpret_cast(d_sorted_score_buffer + parameters->batch_size * parameters->vocab_size); - float* d_softmaxed_score_buffer = d_sorted_softmaxed_score_buffer + parameters->batch_size * parameters->vocab_size; - float* d_sampled = d_softmaxed_score_buffer + parameters->batch_size * parameters->vocab_size; - int64_t* d_indices = reinterpret_cast(d_sampled + parameters->batch_size); + + gsl::span& d_index_in = sampling_state->d_index_in; + gsl::span& d_offset = sampling_state->d_offset; size_t temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), - d_index_buffer_in, - d_offset_buffer, + d_index_in.data(), + d_offset.data(), parameters->batch_size * parameters->vocab_size, parameters->batch_size, cuda_stream); @@ -593,49 +573,52 @@ Status GreedySearchProcessLogits( dumper->Print("temp_storage_bytes", temp_storage_bytes, true); #endif - cuda::LaunchSetupParamsKernel(d_index_buffer_in, - d_offset_buffer, + cuda::LaunchSetupParamsKernel(d_index_in.data(), + d_offset.data(), parameters->batch_size, parameters->vocab_size, cuda_stream); #ifdef DEBUG_GENERATION - dumper->Print("d_offset_buffer", d_offset_buffer, batch_size + 1, 1); + dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1); #endif void* temp_storage = allocator->Alloc(temp_storage_bytes); BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); storage_buffer = std::move(temp_storage_buffer); + gsl::span d_sorted_score = sampling_state->d_sorted_score; + gsl::span d_index_out = sampling_state->d_index_out; cuda::LaunchSortPairsDescending(storage_buffer.get(), temp_storage_bytes, reinterpret_cast(next_token_scores.data()), - d_sorted_score_buffer, - d_index_buffer_in, - d_index_buffer_out, + reinterpret_cast(d_sorted_score.data()), + d_index_in.data(), + d_index_out.data(), parameters->batch_size * parameters->vocab_size, parameters->batch_size, - d_offset_buffer, + d_offset.data(), cuda_stream); #ifdef DEBUG_GENERATION - dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score_buffer), batch_size, vocab_size); - dumper->Print("d_index_buffer_in", d_index_buffer_in, batch_size, vocab_size); - dumper->Print("d_index_buffer_out", d_index_buffer_out, batch_size, vocab_size); + dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score.data()), batch_size, vocab_size); + dumper->Print("d_index_buffer_in", d_index_in.data(), batch_size, vocab_size); + dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size); #endif + gsl::span d_sorted_softmaxed_score = 
sampling_state->d_sorted_softmaxed_score; dispatch_blockwise_softmax_forward(cuda_stream, - d_sorted_softmaxed_score_buffer, - d_sorted_score_buffer, + d_sorted_softmaxed_score.data(), + reinterpret_cast(d_sorted_score.data()), parameters->vocab_size, parameters->vocab_size, parameters->batch_size); #ifdef DEBUG_GENERATION - dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score_buffer, batch_size, vocab_size); + dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), batch_size, vocab_size); #endif - cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score_buffer, - d_index_buffer_out, + cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score.data(), + d_index_out.data(), reinterpret_cast(next_token_scores.data()), parameters->top_p, parameters->filter_value, @@ -649,15 +632,16 @@ Status GreedySearchProcessLogits( // bugbug: actually we can only do softmax at the very beginning and sort the softmaxed scores. // Not sure if the order change will affect the result. + gsl::span d_softmaxed_score = sampling_state->d_softmaxed_score; dispatch_blockwise_softmax_forward(cuda_stream, - d_softmaxed_score_buffer, + d_softmaxed_score.data(), reinterpret_cast(next_token_scores.data()), parameters->vocab_size, parameters->vocab_size, parameters->batch_size); #ifdef DEBUG_GENERATION - dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score_buffer, batch_size, vocab_size); + dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), batch_size, vocab_size); #endif // multinomial sampling @@ -672,28 +656,35 @@ Status GreedySearchProcessLogits( #endif } - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled, + gsl::span d_sampled = sampling_state->d_sampled; + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(), sampled.data(), sizeof(float) * parameters->batch_size, cudaMemcpyHostToDevice, cuda_stream)); - cuda::TorchMultinomialKernelLauncher(d_softmaxed_score_buffer, - d_sampled, - d_indices, + +#ifdef DEBUG_GENERATION + dumper->Print("d_sampled", d_sampled.data(), batch_size, 1); +#endif + + gsl::span d_indices = sampling_state->d_indices; + cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(), + d_sampled.data(), + d_indices.data(), parameters->batch_size, parameters->vocab_size, cuda_stream); #ifdef DEBUG_GENERATION - dumper->Print("d_sampled", d_sampled, batch_size, 1); - dumper->Print("d_indices", d_indices, batch_size, 1); + dumper->Print("d_indices", d_indices.data(), batch_size, 1); #endif CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), - d_indices, + d_indices.data(), greedy_state->next_tokens_cpu.size_bytes(), cudaMemcpyDeviceToHost, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); } else { // next_tokens = torch.argmax(scores, dim=-1) @@ -1042,6 +1033,7 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, + transformers::ISamplingCudaState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, @@ -1108,6 +1100,7 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, + transformers::ISamplingCudaState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h 
b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index 2fe57e8d35c19..73385450af428 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -62,6 +62,7 @@ Status ProcessLogits(const OrtValue& logits, // template Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingCudaState* sampling_state,// sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) From 0fd09a8456f0752e15799c780c61c38347954af9 Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 9 Nov 2022 21:59:03 +0000 Subject: [PATCH 23/51] reuse buffer --- .../cuda/transformers/generation_cuda_impl.cu | 2 -- .../transformers/generation_device_helper.cc | 36 +++++++++++-------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 769d37b69fcb5..0dee9bc4df16b 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -453,8 +453,6 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in, int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index]; d_logits_in_out[original_index] = (T)filter_value; } - } else { - int original_index = batch_id * vocab_size + d_sorted_indices[index]; } } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 5bff03219aef6..e555711736b6a 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -556,36 +556,42 @@ Status GreedySearchProcessLogits( // TODO(wy): support output_scores in greedy search ORT_UNUSED_PARAMETER(output_scores); - BufferUniquePtr& storage_buffer = sampling_state->storage_buffer; + if (do_sampling) { gsl::span& d_index_in = sampling_state->d_index_in; gsl::span& d_offset = sampling_state->d_offset; - size_t temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), - d_index_in.data(), - d_offset.data(), - parameters->batch_size * parameters->vocab_size, - parameters->batch_size, - cuda_stream); + BufferUniquePtr& storage_buffer = sampling_state->storage_buffer; + std::cout << "step" << step << std::endl; + size_t temp_storage_bytes = 0; + if (step == 1) { + temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), + d_index_in.data(), + d_offset.data(), + parameters->batch_size * parameters->vocab_size, + parameters->batch_size, + cuda_stream); #ifdef DEBUG_GENERATION dumper->Print("temp_storage_bytes", temp_storage_bytes, true); #endif - cuda::LaunchSetupParamsKernel(d_index_in.data(), - d_offset.data(), - parameters->batch_size, - parameters->vocab_size, - cuda_stream); + cuda::LaunchSetupParamsKernel(d_index_in.data(), + d_offset.data(), + parameters->batch_size, + parameters->vocab_size, + cuda_stream); #ifdef DEBUG_GENERATION dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1); #endif - void* temp_storage = allocator->Alloc(temp_storage_bytes); - BufferUniquePtr 
temp_storage_buffer(temp_storage, BufferDeleter(allocator)); - storage_buffer = std::move(temp_storage_buffer); + void* temp_storage = allocator->Alloc(temp_storage_bytes); + BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); + storage_buffer = std::move(temp_storage_buffer); + } + gsl::span d_sorted_score = sampling_state->d_sorted_score; gsl::span d_index_out = sampling_state->d_index_out; cuda::LaunchSortPairsDescending(storage_buffer.get(), From 61ee848e266bcf518185baf0eafacb4b9aa40cc2 Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 9 Nov 2022 23:46:11 +0000 Subject: [PATCH 24/51] update presence_mask --- .../transformers/generation_device_helper.cc | 2 ++ .../cpu/transformers/generation_shared.h | 1 + .../transformers/greedy_search_impl_base.h | 3 +++ .../cuda/transformers/generation_cuda_impl.cu | 24 +++++++++++++++---- .../cuda/transformers/generation_cuda_impl.h | 3 ++- .../transformers/generation_device_helper.cc | 14 ++++++++--- 6 files changed, 38 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 5e13b3acd4d8f..0788a84443d27 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -506,6 +506,8 @@ Status GreedySearchProcessLogits( 1, generator, *sampled_idx)); + // TODO: update presense_mask + #ifdef DEBUG_GENERATION dumper->Print("sampled_idx", *sampled_idx); #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index c5a3a8cf649b2..4578e8c652a87 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -76,6 +76,7 @@ struct ISamplingCudaState { gsl::span d_softmaxed_score; gsl::span d_sampled; gsl::span d_indices; + gsl::span d_presence_mask; BufferUniquePtr storage_buffer; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 493729e437d5c..4c404617040f9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -29,6 +29,8 @@ struct SamplingState : public ISamplingCudaState { this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); + // TODO: do not allocate this buffer there's no presence_mask + this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); } private: @@ -40,6 +42,7 @@ struct SamplingState : public ISamplingCudaState { BufferUniquePtr d_softmaxed_score_buffer_; BufferUniquePtr d_sampled_buffer_; BufferUniquePtr d_indices_buffer_; + BufferUniquePtr d_presence_mask_buffer_; }; template diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 0dee9bc4df16b..1cae0106a176c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -162,7 +162,7 @@ void LaunchLogitsProcessKernel( T* 
next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, - const int* presence_mask, + int* presence_mask, float presence_penalty, float temperature, int batch_size, @@ -201,7 +201,7 @@ template void LaunchLogitsProcessKernel( float* next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, - const int* presence_mask, + int* presence_mask, float presence_penalty, float temperature, int batch_size, @@ -219,7 +219,7 @@ template void LaunchLogitsProcessKernel( half* next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, - const int* presence_mask, + int* presence_mask, float presence_penalty, float temperature, int batch_size, @@ -505,7 +505,8 @@ __global__ void sampleMultinomialOnce( scalar_t* sampled, scalar_t* dist, int stride_dist, // dist->stride(0) - int stride_categories // dist->stride(1) + int stride_categories, // dist->stride(1) + int* d_presence_mask ) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -616,6 +617,17 @@ __global__ void sampleMultinomialOnce( } } } + + // update presence mask + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= distributions * categories) { + return; + } + int dist_idx = index / categories; + int cat_idx = index % categories; + if (dest[dist_idx] == cat_idx) { + d_presence_mask[index] = 1; + } } // Only support n_sample = 1 @@ -624,6 +636,7 @@ void TorchMultinomialKernelLauncher(float* d_input, int64_t* d_output, int batch_size, int vocab_size, + int* d_presence_mask, cudaStream_t stream) { // Store the props in class variables @@ -652,7 +665,8 @@ void TorchMultinomialKernelLauncher(float* d_input, d_sampled, d_input, vocab_size, - batch_size); + batch_size, + d_presence_mask); } else { printf("Please add more cases for block size"); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 02b803b8855b9..193f25b7908c7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -30,7 +30,7 @@ void LaunchLogitsProcessKernel( T* next_token_scores, const int* vocab_mask, const int* prefix_vocab_mask, - const int* presence_mask, + int* presence_mask, float presence_penalty, float temperature, int batch_size, @@ -100,6 +100,7 @@ void TorchMultinomialKernelLauncher(float* d_input, int64_t* d_output, int batch_size, int vocab_size, + int* d_presence_mask, cudaStream_t stream); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e555711736b6a..f02b35fcd3f36 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -318,7 +318,7 @@ Status ProcessLogits(const OrtValue& logits, // next_token_scores.data(), parameters->vocab_mask.data(), step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. 
- parameters->presence_mask.data(), + nullptr, // parameters->presence_mask.data(), parameters->presence_penalty, parameters->temperature, parameters->batch_size, @@ -531,11 +531,18 @@ Status GreedySearchProcessLogits( cudaMemcpyHostToDevice, cuda_stream)); } + // Copy parameters->presence_mask to sampling_state->presence_mask + gsl::span& presence_mask = sampling_state->d_presence_mask; + if (step == 1 && parameters->presence_mask.data() != nullptr) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(presence_mask.data(), parameters->presence_mask.data(), + sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream)); + } + cuda::LaunchLogitsProcessKernel( reinterpret_cast(next_token_scores.data()), parameters->vocab_mask.data(), step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. - parameters->presence_mask.data(), + parameters->presence_mask.data() ? presence_mask.data() : nullptr, parameters->presence_penalty, parameters->temperature, parameters->batch_size, @@ -558,7 +565,6 @@ Status GreedySearchProcessLogits( if (do_sampling) { - gsl::span& d_index_in = sampling_state->d_index_in; gsl::span& d_offset = sampling_state->d_offset; @@ -655,6 +661,7 @@ Status GreedySearchProcessLogits( std::uniform_real_distribution distribution(0.0, 1.0); std::vector sampled(parameters->batch_size); distribution(generator); // the first one is subnormal numbers + // need to optimize the random generation. current version is for debug. for (int i = 0; i < parameters->batch_size; ++i) { sampled[i] = distribution(generator); #ifdef DEBUG_GENERATION @@ -679,6 +686,7 @@ Status GreedySearchProcessLogits( d_indices.data(), parameters->batch_size, parameters->vocab_size, + presence_mask.data(), cuda_stream); #ifdef DEBUG_GENERATION From f7cc0a9ae853c15aeee973f2a436554d49401af3 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 15 Nov 2022 08:19:01 +0000 Subject: [PATCH 25/51] generate random numbers at once --- .../cpu/transformers/generation_shared.h | 4 +- .../transformers/greedy_search_impl_base.h | 17 +++++++- .../cpu/transformers/greedy_search_impl_gpt.h | 3 ++ .../cuda/transformers/generation_cuda_impl.cu | 4 +- .../transformers/generation_device_helper.cc | 42 +++++++------------ 5 files changed, 39 insertions(+), 31 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 4578e8c652a87..89fffb4727f1a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -75,10 +75,12 @@ struct ISamplingCudaState { gsl::span d_sorted_softmaxed_score; gsl::span d_softmaxed_score; gsl::span d_sampled; + gsl::span h_sampled_all; gsl::span d_indices; gsl::span d_presence_mask; BufferUniquePtr storage_buffer; + size_t temp_storage_bytes; }; class ISequences { @@ -127,7 +129,7 @@ struct IGenerationParameters { float temperature = 1.0f; float top_p = 0.0f; float filter_value; - int seed = 0; + int seed = 1234; // Parameters from inputs int min_length; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 4c404617040f9..3808e5a5c2bb7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#pragma once +#include #include #include "contrib_ops/cpu/transformers/generation_shared.h" #include "contrib_ops/cpu/transformers/generate_impl_base.h" @@ -14,8 +15,11 @@ namespace transformers { template struct SamplingState : public ISamplingCudaState { void Init(AllocatorPtr allocator, + AllocatorPtr cpu_allocator, int batch_size, int vocab_size, + int max_iter, + int seed, bool is_cuda) { if (!is_cuda) { return; @@ -28,9 +32,19 @@ struct SamplingState : public ISamplingCudaState { this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); + this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); - // TODO: do not allocate this buffer there's no presence_mask + // TODO: do not allocate this buffer if there's no presence_mask this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); + + this->temp_storage_bytes = 0; + + std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(seed)}; + std::uniform_real_distribution distribution(0.0, 1.0); + distribution(generator); + for (size_t i = 0; i < this->h_sampled_all.size(); ++i) { + this->h_sampled_all[i] = distribution(generator); + } } private: @@ -41,6 +55,7 @@ struct SamplingState : public ISamplingCudaState { BufferUniquePtr d_sorted_softmaxed_score_buffer_; BufferUniquePtr d_softmaxed_score_buffer_; BufferUniquePtr d_sampled_buffer_; + BufferUniquePtr h_sampled_all_buffer_; BufferUniquePtr d_indices_buffer_; BufferUniquePtr d_presence_mask_buffer_; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index f04f44ee4390d..a7d86a5dee9d9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -145,8 +145,11 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds SamplingState sampling_state; if (std::is_same::value) { sampling_state.Init(this->temp_space_allocator_, + this->cpu_allocator_, static_cast(parameters->BatchBeamSize()), static_cast(parameters->vocab_size), + static_cast(parameters->max_length - parameters->sequence_length), + parameters->seed, this->IsCuda()); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 1cae0106a176c..134b6118b1590 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -127,7 +127,7 @@ __global__ void LogitsProcessKernel( // VocabMaskLogitsProcessor if (vocab_mask != nullptr && vocab_mask[word_id] == 0) { - next_token_scores[index] = cub::FpLimits::Lowest(); + next_token_scores[index] = (T)cub::FpLimits::Lowest(); return; } @@ -140,7 +140,7 @@ __global__ void LogitsProcessKernel( // MinLengthLogitsProcessor if (word_id == demote_token_id) { - next_token_scores[index] = cub::FpLimits::Lowest(); + next_token_scores[index] = (T)cub::FpLimits::Lowest(); } // PresencePenaltyLogitsProcessor diff --git 
a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index f02b35fcd3f36..d3f10c8febc53 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -3,7 +3,6 @@ #include #include -#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/math/topk_impl.h" @@ -569,8 +568,7 @@ Status GreedySearchProcessLogits( gsl::span& d_offset = sampling_state->d_offset; BufferUniquePtr& storage_buffer = sampling_state->storage_buffer; - std::cout << "step" << step << std::endl; - size_t temp_storage_bytes = 0; + size_t& temp_storage_bytes = sampling_state->temp_storage_bytes; if (step == 1) { temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), d_index_in.data(), @@ -579,10 +577,6 @@ Status GreedySearchProcessLogits( parameters->batch_size, cuda_stream); -#ifdef DEBUG_GENERATION - dumper->Print("temp_storage_bytes", temp_storage_bytes, true); -#endif - cuda::LaunchSetupParamsKernel(d_index_in.data(), d_offset.data(), parameters->batch_size, @@ -593,13 +587,18 @@ Status GreedySearchProcessLogits( dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1); #endif - void* temp_storage = allocator->Alloc(temp_storage_bytes); + void* temp_storage = allocator->Alloc(sampling_state->temp_storage_bytes); BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); storage_buffer = std::move(temp_storage_buffer); } - gsl::span d_sorted_score = sampling_state->d_sorted_score; - gsl::span d_index_out = sampling_state->d_index_out; + gsl::span& d_sorted_score = sampling_state->d_sorted_score; + gsl::span& d_index_out = sampling_state->d_index_out; + +#ifdef DEBUG_GENERATION + dumper->Print("temp_storage_bytes", temp_storage_bytes, true); +#endif + cuda::LaunchSortPairsDescending(storage_buffer.get(), temp_storage_bytes, reinterpret_cast(next_token_scores.data()), @@ -617,7 +616,7 @@ Status GreedySearchProcessLogits( dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size); #endif - gsl::span d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score; + gsl::span& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score; dispatch_blockwise_softmax_forward(cuda_stream, d_sorted_softmaxed_score.data(), reinterpret_cast(d_sorted_score.data()), @@ -644,7 +643,7 @@ Status GreedySearchProcessLogits( // bugbug: actually we can only do softmax at the very beginning and sort the softmaxed scores. // Not sure if the order change will affect the result. - gsl::span d_softmaxed_score = sampling_state->d_softmaxed_score; + gsl::span& d_softmaxed_score = sampling_state->d_softmaxed_score; dispatch_blockwise_softmax_forward(cuda_stream, d_softmaxed_score.data(), reinterpret_cast(next_token_scores.data()), @@ -657,21 +656,10 @@ Status GreedySearchProcessLogits( #endif // multinomial sampling - std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(parameters->seed + step)}; - std::uniform_real_distribution distribution(0.0, 1.0); - std::vector sampled(parameters->batch_size); - distribution(generator); // the first one is subnormal numbers - // need to optimize the random generation. current version is for debug. 
- for (int i = 0; i < parameters->batch_size; ++i) { - sampled[i] = distribution(generator); -#ifdef DEBUG_GENERATION - std::cout << "sampled value on cpu: " << sampled[i] << std::endl; -#endif - } - - gsl::span d_sampled = sampling_state->d_sampled; + gsl::span& d_sampled = sampling_state->d_sampled; + gsl::span& h_sampled_all = sampling_state->h_sampled_all; CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(), - sampled.data(), + h_sampled_all.data() + (step - 1) * parameters->batch_size, sizeof(float) * parameters->batch_size, cudaMemcpyHostToDevice, cuda_stream)); @@ -680,7 +668,7 @@ Status GreedySearchProcessLogits( dumper->Print("d_sampled", d_sampled.data(), batch_size, 1); #endif - gsl::span d_indices = sampling_state->d_indices; + gsl::span& d_indices = sampling_state->d_indices; cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(), d_sampled.data(), d_indices.data(), From 09825ce5c3f44dcbf6990b6197c7e0b669e1929d Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 16 Nov 2022 01:35:18 +0000 Subject: [PATCH 26/51] fix bugs --- .../contrib_ops/cpu/transformers/generation_shared.h | 8 ++++---- .../cuda/transformers/generation_device_helper.cc | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 89fffb4727f1a..5a8b0f46a72bc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -8,9 +8,9 @@ #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -//#ifndef NDEBUG -#define DEBUG_GENERATION 1 // uncomment it for debugging beam search -//#endif +#ifndef NDEBUG +//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search +#endif namespace onnxruntime { @@ -129,7 +129,7 @@ struct IGenerationParameters { float temperature = 1.0f; float top_p = 0.0f; float filter_value; - int seed = 1234; + int seed = 0; // Parameters from inputs int min_length; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index d3f10c8febc53..60f08aaddf87c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -547,7 +547,7 @@ Status GreedySearchProcessLogits( parameters->batch_size, parameters->num_beams, parameters->vocab_size, - (parameters->min_length > 0 && current_sequence_length < parameters->min_length) ? parameters->eos_token_id : -1, + (parameters->min_length > 0 && current_sequence_length < parameters->sequence_length + parameters->min_length) ? 
parameters->eos_token_id : -1, reinterpret_cast(sequences_buffer.get()), parameters->max_length, current_sequence_length, @@ -562,7 +562,6 @@ Status GreedySearchProcessLogits( // TODO(wy): support output_scores in greedy search ORT_UNUSED_PARAMETER(output_scores); - if (do_sampling) { gsl::span& d_index_in = sampling_state->d_index_in; gsl::span& d_offset = sampling_state->d_offset; From cd12a2e53c516c5e09eb2081766e1f73966f3e87 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 29 Nov 2022 21:33:30 +0000 Subject: [PATCH 27/51] for debug purpose --- .../cpu/transformers/generation_shared.h | 1 + .../transformers/greedy_search_impl_base.h | 2 ++ .../cpu/transformers/greedy_search_impl_gpt.h | 20 +++++++++++++++++++ .../transformers/generation_device_helper.cc | 6 ++++++ .../contrib_ops/cuda/transformers/sampling.cc | 1 + .../core/graph/contrib_ops/contrib_defs.cc | 8 ++++++++ 6 files changed, 38 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 5a8b0f46a72bc..48aee8e30936d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -74,6 +74,7 @@ struct ISamplingCudaState { gsl::span d_sorted_score; gsl::span d_sorted_softmaxed_score; gsl::span d_softmaxed_score; + gsl::span h_softmaxed_score; gsl::span d_sampled; gsl::span h_sampled_all; gsl::span d_indices; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 3808e5a5c2bb7..083dac9d5334c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -31,6 +31,7 @@ struct SamplingState : public ISamplingCudaState { this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); + this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count)); this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); @@ -54,6 +55,7 @@ struct SamplingState : public ISamplingCudaState { BufferUniquePtr d_sorted_score_buffer_; BufferUniquePtr d_sorted_softmaxed_score_buffer_; BufferUniquePtr d_softmaxed_score_buffer_; + BufferUniquePtr h_softmaxed_score_buffer_; BufferUniquePtr d_sampled_buffer_; BufferUniquePtr h_sampled_all_buffer_; BufferUniquePtr d_indices_buffer_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index a7d86a5dee9d9..9eb3c4dd728bd 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -130,6 +130,12 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0])); Tensor* output_sequences = this->context_.Output(0, sequences_shape); + std::cout << "vocab_size: " << 
parameters->vocab_size << std::endl; + int64_t debug_logits_dims[] = {parameters->batch_size, parameters->vocab_size}; + TensorShape debug_logits_shape(&debug_logits_dims[0], sizeof(debug_logits_dims) / sizeof(debug_logits_dims[0])); + std::cout << "debug_logits_shape: " << debug_logits_shape << std::endl; + Tensor* debug_logits = this->context_.Output(1, debug_logits_shape); + std::vector feeds; std::vector fetches; @@ -246,6 +252,20 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds gsl::copy(sequence_source, batch_output); } + //copy the debug logits to output + if (this->IsCuda() && debug_logits != nullptr) { + gsl::span logits_to_debug = debug_logits->MutableDataAsSpan(); + for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) { + std::cout << "batch_id: " << batch_id << std::endl; + auto batch_output = logits_to_debug.subspan( + static_cast(batch_id) * parameters->vocab_size, + parameters->vocab_size); + gsl::span batch_logits_to_debug = gsl::make_span(sampling_state.h_softmaxed_score.data() + batch_id * parameters->vocab_size, parameters->vocab_size); + + gsl::copy(batch_logits_to_debug, batch_output); + } + } + return status; } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 60f08aaddf87c..3eb2733957227 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -686,6 +686,12 @@ Status GreedySearchProcessLogits( cudaMemcpyDeviceToHost, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state->h_softmaxed_score.data(), + d_softmaxed_score.data(), + sampling_state->h_softmaxed_score.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); } else { // next_tokens = torch.argmax(scores, dim=-1) diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling.cc b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc index 9707a3d03daf2..736f0e9457553 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc @@ -23,6 +23,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 3) // 'repetition_penalty' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 6) // 'custom_attention_mask' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'logits_to_debug' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Sampling); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 6300e749b2acb..df48d6509ce1c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -510,6 +510,13 @@ void GreedySearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { sequences_shape.add_dim()->set_dim_value(batch_size); sequences_shape.add_dim()->set_dim_value(max_length_value); updateOutputShape(ctx, 0, sequences_shape); + + if (ctx.getNumOutputs() > 1) { + ONNX_NAMESPACE::TensorShapeProto logits_to_debug_shape; + logits_to_debug_shape.add_dim()->set_dim_value(batch_size); + logits_to_debug_shape.add_dim(); + updateOutputShape(ctx, 1, logits_to_debug_shape); + } } constexpr const char* Gelu_ver1_doc = @@ -1129,6 +1136,7 @@ 
ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") + .Output(1, "logits_before_multinomial", "logits_before_multinomial(debug purpose). Shape is (batch_size, vocab_size)", "T") .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { From 05e32d0e5b2dda0e0533cbe2e96e179f7b62cbd3 Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 30 Nov 2022 01:02:58 +0000 Subject: [PATCH 28/51] minor change --- .../contrib_ops/cuda/transformers/generation_cuda_impl.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 134b6118b1590..d0b73f11ae71f 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -440,7 +440,7 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in, int count = vocab_idx; float sum = 0.0f; - while (count != 0) { + while (count >= 0) { sum += d_sorted_logits_in[start_index]; ++start_index; --count; From dda532c99444fdd4178ea3450fd402dbb397cb21 Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 30 Nov 2022 01:59:12 +0000 Subject: [PATCH 29/51] remove some printings --- .../contrib_ops/cpu/transformers/greedy_search_impl_gpt.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 9eb3c4dd728bd..5c4f36f60ebca 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -130,10 +130,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0])); Tensor* output_sequences = this->context_.Output(0, sequences_shape); - std::cout << "vocab_size: " << parameters->vocab_size << std::endl; int64_t debug_logits_dims[] = {parameters->batch_size, parameters->vocab_size}; TensorShape debug_logits_shape(&debug_logits_dims[0], sizeof(debug_logits_dims) / sizeof(debug_logits_dims[0])); - std::cout << "debug_logits_shape: " << debug_logits_shape << std::endl; Tensor* debug_logits = this->context_.Output(1, debug_logits_shape); std::vector feeds; @@ -256,7 +254,6 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds if (this->IsCuda() && debug_logits != nullptr) { gsl::span logits_to_debug = debug_logits->MutableDataAsSpan(); for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) { - std::cout << "batch_id: " << batch_id << std::endl; auto batch_output = logits_to_debug.subspan( static_cast(batch_id) * parameters->vocab_size, parameters->vocab_size); From 653b28b393d7ddde7d02137667099255a91704ea Mon Sep 17 00:00:00 2001 From: wangyems Date: Thu, 1 Dec 2022 21:59:38 +0000 Subject: [PATCH 30/51] optional logits --- onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index df48d6509ce1c..3111f080e04e8 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1136,7 +1136,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") - .Output(1, "logits_before_multinomial", "logits_before_multinomial(debug purpose). Shape is (batch_size, vocab_size)", "T") + .Output(1, "logits_before_multinomial", "logits_before_multinomial(debug purpose). Shape is (batch_size, vocab_size)", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { From ad4e21bc2f3c65069d8a82468a0cc4734c9ccfab Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 6 Dec 2022 01:10:00 +0000 Subject: [PATCH 31/51] refactor cuda impl --- .../transformers/generation_device_helper.cc | 136 +------------ .../cuda/transformers/sampling_cuda_helper.h | 178 ++++++++++++++++++ 2 files changed, 183 insertions(+), 131 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 3eb2733957227..a3916ef09cfe7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -16,6 +16,7 @@ #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cuda/transformers/beam_search_topk.h" +#include "sampling_cuda_helper.h" #ifdef DEBUG_GENERATION #include @@ -455,7 +456,7 @@ template Status GreedySearchProcessLogits( const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state - transformers::ISamplingCudaState* sampling_state, // buffers + transformers::ISamplingCudaState* sampling_state, // buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) @@ -563,136 +564,8 @@ Status GreedySearchProcessLogits( ORT_UNUSED_PARAMETER(output_scores); if (do_sampling) { - gsl::span& d_index_in = sampling_state->d_index_in; - gsl::span& d_offset = sampling_state->d_offset; - - BufferUniquePtr& storage_buffer = sampling_state->storage_buffer; - size_t& temp_storage_bytes = sampling_state->temp_storage_bytes; - if (step == 1) { - temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), - d_index_in.data(), - d_offset.data(), - parameters->batch_size * parameters->vocab_size, - parameters->batch_size, - cuda_stream); - - cuda::LaunchSetupParamsKernel(d_index_in.data(), - d_offset.data(), - parameters->batch_size, - parameters->vocab_size, - cuda_stream); - -#ifdef DEBUG_GENERATION - dumper->Print("d_offset_buffer", 
d_offset.data(), batch_size + 1, 1); -#endif - - void* temp_storage = allocator->Alloc(sampling_state->temp_storage_bytes); - BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); - storage_buffer = std::move(temp_storage_buffer); - } - - gsl::span& d_sorted_score = sampling_state->d_sorted_score; - gsl::span& d_index_out = sampling_state->d_index_out; - -#ifdef DEBUG_GENERATION - dumper->Print("temp_storage_bytes", temp_storage_bytes, true); -#endif - - cuda::LaunchSortPairsDescending(storage_buffer.get(), - temp_storage_bytes, - reinterpret_cast(next_token_scores.data()), - reinterpret_cast(d_sorted_score.data()), - d_index_in.data(), - d_index_out.data(), - parameters->batch_size * parameters->vocab_size, - parameters->batch_size, - d_offset.data(), - cuda_stream); - -#ifdef DEBUG_GENERATION - dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score.data()), batch_size, vocab_size); - dumper->Print("d_index_buffer_in", d_index_in.data(), batch_size, vocab_size); - dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size); -#endif - - gsl::span& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score; - dispatch_blockwise_softmax_forward(cuda_stream, - d_sorted_softmaxed_score.data(), - reinterpret_cast(d_sorted_score.data()), - parameters->vocab_size, - parameters->vocab_size, - parameters->batch_size); - -#ifdef DEBUG_GENERATION - dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), batch_size, vocab_size); -#endif - - cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score.data(), - d_index_out.data(), - reinterpret_cast(next_token_scores.data()), - parameters->top_p, - parameters->filter_value, - parameters->batch_size, - parameters->vocab_size, - cuda_stream); - -#ifdef DEBUG_GENERATION - dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), batch_size, vocab_size); -#endif - - // bugbug: actually we can only do softmax at the very beginning and sort the softmaxed scores. - // Not sure if the order change will affect the result. 
- gsl::span& d_softmaxed_score = sampling_state->d_softmaxed_score; - dispatch_blockwise_softmax_forward(cuda_stream, - d_softmaxed_score.data(), - reinterpret_cast(next_token_scores.data()), - parameters->vocab_size, - parameters->vocab_size, - parameters->batch_size); - -#ifdef DEBUG_GENERATION - dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), batch_size, vocab_size); -#endif - - // multinomial sampling - gsl::span& d_sampled = sampling_state->d_sampled; - gsl::span& h_sampled_all = sampling_state->h_sampled_all; - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(), - h_sampled_all.data() + (step - 1) * parameters->batch_size, - sizeof(float) * parameters->batch_size, - cudaMemcpyHostToDevice, - cuda_stream)); - -#ifdef DEBUG_GENERATION - dumper->Print("d_sampled", d_sampled.data(), batch_size, 1); -#endif - - gsl::span& d_indices = sampling_state->d_indices; - cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(), - d_sampled.data(), - d_indices.data(), - parameters->batch_size, - parameters->vocab_size, - presence_mask.data(), - cuda_stream); - -#ifdef DEBUG_GENERATION - dumper->Print("d_indices", d_indices.data(), batch_size, 1); -#endif - - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), - d_indices.data(), - greedy_state->next_tokens_cpu.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream)); - - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state->h_softmaxed_score.data(), - d_softmaxed_score.data(), - sampling_state->h_softmaxed_score.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream)); - - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + SamplingCudaHelper::TopPSamplingCuda top_p_sampler(allocator, cuda_stream, sampling_state, greedy_state, parameters); + ORT_RETURN_IF_ERROR(top_p_sampler.Sample(step, next_token_scores)); } else { // next_tokens = torch.argmax(scores, dim=-1) int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; @@ -865,6 +738,7 @@ Status UpdateGptFeeds( // Update attention mask const OrtValue& old_mask = next_inputs[2]; const auto& mask_dims = old_mask.Get().Shape().GetDims(); + if (mask_dims.size() == 2) { const int32_t* old_mask_data = old_mask.Get().Data(); int64_t mask_dims[] = {batch_beam_size, current_length}; diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h new file mode 100644 index 0000000000000..70fedad102a6b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/tensor/utils.h" +#include "contrib_ops/cpu/transformers/generation_shared.h" + +#ifdef DEBUG_GENERATION +#include +#endif + +namespace onnxruntime { +namespace contrib { +namespace SamplingCudaHelper { + +template +class TopPSamplingCuda{ + public: + TopPSamplingCuda(AllocatorPtr& allocator, + cudaStream_t cuda_stream, + transformers::ISamplingCudaState* sampling_state, + transformers::IGreedySearchState* greedy_state, + const transformers::IGenerationParameters* parameters): + allocator_(allocator), + cuda_stream_(cuda_stream), + sampling_state_(sampling_state), + greedy_state_(greedy_state), + parameters_(parameters) {} + + Status Sample(int step, gsl::span& next_token_scores) { + typedef typename ToCudaType::MappedType CudaT; + + gsl::span& d_index_in = sampling_state_->d_index_in; + gsl::span& d_offset = sampling_state_->d_offset; + + BufferUniquePtr& storage_buffer = sampling_state_->storage_buffer; + size_t& temp_storage_bytes = sampling_state_->temp_storage_bytes; + if (step == 1) { + temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), + d_index_in.data(), + d_offset.data(), + parameters_->batch_size * parameters_->vocab_size, + parameters_->batch_size, + cuda_stream_); + + cuda::LaunchSetupParamsKernel(d_index_in.data(), + d_offset.data(), + parameters_->batch_size, + parameters_->vocab_size, + cuda_stream_); + + #ifdef DEBUG_GENERATION + dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1); + #endif + + void* temp_storage = allocator_->Alloc(sampling_state_->temp_storage_bytes); + BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator_)); + storage_buffer = std::move(temp_storage_buffer); + } + + gsl::span& d_sorted_score = sampling_state_->d_sorted_score; + gsl::span& d_index_out = sampling_state_->d_index_out; + + #ifdef DEBUG_GENERATION + dumper->Print("temp_storage_bytes", temp_storage_bytes, true); + #endif + + cuda::LaunchSortPairsDescending(storage_buffer.get(), + temp_storage_bytes, + reinterpret_cast(next_token_scores.data()), + reinterpret_cast(d_sorted_score.data()), + d_index_in.data(), + d_index_out.data(), + parameters_->batch_size * parameters_->vocab_size, + parameters_->batch_size, + d_offset.data(), + cuda_stream_); + + #ifdef DEBUG_GENERATION + dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score.data()), batch_size, vocab_size); + dumper->Print("d_index_buffer_in", d_index_in.data(), batch_size, vocab_size); + dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size); + #endif + + gsl::span& d_sorted_softmaxed_score = sampling_state_->d_sorted_softmaxed_score; + dispatch_blockwise_softmax_forward(cuda_stream_, + d_sorted_softmaxed_score.data(), + reinterpret_cast(d_sorted_score.data()), + parameters_->vocab_size, + parameters_->vocab_size, + parameters_->batch_size); + + #ifdef DEBUG_GENERATION + dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), batch_size, vocab_size); + #endif + + cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score.data(), + d_index_out.data(), + reinterpret_cast(next_token_scores.data()), + parameters_->top_p, + parameters_->filter_value, + parameters_->batch_size, + parameters_->vocab_size, + cuda_stream_); + + #ifdef DEBUG_GENERATION + dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), batch_size, vocab_size); + #endif + + // TODO(wy): actually 
we can only do softmax at the very beginning and sort the softmaxed scores. + gsl::span& d_softmaxed_score = sampling_state_->d_softmaxed_score; + dispatch_blockwise_softmax_forward(cuda_stream_, + d_softmaxed_score.data(), + reinterpret_cast(next_token_scores.data()), + parameters_->vocab_size, + parameters_->vocab_size, + parameters_->batch_size); + + #ifdef DEBUG_GENERATION + dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), batch_size, vocab_size); + #endif + + // multinomial sampling + gsl::span& d_sampled = sampling_state_->d_sampled; + gsl::span& h_sampled_all = sampling_state_->h_sampled_all; + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(), + h_sampled_all.data() + (step - 1) * parameters_->batch_size, + sizeof(float) * parameters_->batch_size, + cudaMemcpyHostToDevice, + cuda_stream_)); + + #ifdef DEBUG_GENERATION + dumper->Print("d_sampled", d_sampled.data(), batch_size, 1); + #endif + + gsl::span& d_indices = sampling_state_->d_indices; + gsl::span& presence_mask = sampling_state_->d_presence_mask; + cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(), + d_sampled.data(), + d_indices.data(), + parameters_->batch_size, + parameters_->vocab_size, + presence_mask.data(), + cuda_stream_); + + #ifdef DEBUG_GENERATION + dumper->Print("d_indices", d_indices.data(), batch_size, 1); + #endif + + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state_->next_tokens_cpu.data(), + sampling_state_->d_indices.data(), + greedy_state_->next_tokens_cpu.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream_)); + + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state_->h_softmaxed_score.data(), + sampling_state_->d_softmaxed_score.data(), + sampling_state_->h_softmaxed_score.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream_)); + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream_)); + + return Status::OK(); + } + + private: + AllocatorPtr& allocator_; + cudaStream_t cuda_stream_; + transformers::ISamplingCudaState* sampling_state_; + transformers::IGreedySearchState* greedy_state_; + const transformers::IGenerationParameters* parameters_; +}; + +} // namespace SamplingCudaHelper +} // namespace contrib +} // namespace onnxruntime From d597492e31fe62075863be0986ef57f71fdf9d3c Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 6 Dec 2022 02:33:55 +0000 Subject: [PATCH 32/51] fix build --- onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h index 70fedad102a6b..aba84868f5277 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -5,6 +5,7 @@ #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "core/providers/cpu/tensor/utils.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include "core/providers/cuda/math/softmax.h" #ifdef DEBUG_GENERATION #include From e877315b7957675cb08a5c7ee7608a55b497cfaf Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 7 Dec 2022 23:00:21 +0000 Subject: [PATCH 33/51] refactor cpu kernel --- .../cpu/transformers/beam_search_impl_base.h | 2 +- .../cpu/transformers/generate_impl_base.h | 18 +- .../transformers/generation_device_helper.cc | 168 ++++------- .../transformers/generation_device_helper.h | 4 +- .../cpu/transformers/generation_shared.h | 7 +- .../transformers/greedy_search_impl_base.h | 59 ++-- 
.../cpu/transformers/logits_processor.cc | 80 +---- .../cpu/transformers/logits_processor.h | 41 +-- .../cpu/transformers/sampling_cpu_helper.h | 143 +++++++++ .../transformers/generation_device_helper.cc | 97 +++--- .../transformers/generation_device_helper.h | 2 +- .../cuda/transformers/sampling_cuda_helper.h | 281 +++++++++--------- 12 files changed, 463 insertions(+), 439 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 0fb22b1578176..d0cbbfa6e4949 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -218,7 +218,7 @@ Status BeamSearchBase::Initialize() { if (!IsCuda()) { // Logits processor is used in CPU only. In CUDA, cuda kernels are used instead. // Initialize processors after CheckInputs so that parameters_->vocab_mask is ready. - logits_processors_.Init(*parameters_, thread_pool_); + logits_processors_.Init(*parameters_); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index b44ba6c75f87a..c27d7802abc65 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -140,19 +140,13 @@ class GenerateBase { if (attention_mask != nullptr) { const auto& dims_attn = attention_mask->Shape().GetDims(); - if (dims_attn.size() == 2) { - if (!SpanEq(dims_attn, dims)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'attention_mask' is expected to have same shape as input_ids"); - } - } else if (dims_attn.size() == 4) { - if (dims_attn[0] != dims[0] || dims_attn[1] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'attention_mask' is expected to shape [batch_size, 1, max_sequence_length, max_sequence_length]"); - } - } else { + if (dims_attn.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size()); + } + if (!SpanEq(dims_attn, dims)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'attention_mask' is expected to have 2 or 4 dimensions, got ", dims_attn.size()); + "Input 'attention_mask' is expected to have same shape as input_ids"); } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 0788a84443d27..b902ee57a5467 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -12,6 +12,7 @@ #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_scorer.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" +#include "contrib_ops/cpu/transformers/sampling_cpu_helper.h" #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" @@ -138,8 +139,7 @@ Status CreateGptInputs( OrtValue attention_mask; if (attn_mask_value != nullptr) { const Tensor& attn_mask = attn_mask_value->Get(); - const TensorShape& attn_mask_shape = attn_mask.Shape(); // 2d or 4d - Tensor::InitOrtValue(element_type, attn_mask_shape, 
const_cast(&attn_mask)->MutableData(), + Tensor::InitOrtValue(element_type, input_ids_shape, const_cast(&attn_mask)->MutableData(), allocator->Info(), attention_mask); } else { auto mask_type = DataTypeImpl::GetType(); @@ -181,13 +181,13 @@ Status CreateGptInputs( expanded_input_ids = std::move(input_ids); expanded_position_ids = std::move(position_ids); expanded_attention_mask = std::move(attention_mask); - } else { - // bugbug: 4d not supported here - ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids); - ExpandInputs(position_ids, num_beams, allocator, expanded_position_ids); - ExpandInputs(attention_mask, num_beams, allocator, expanded_attention_mask); + return Status::OK(); } + ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids); + ExpandInputs(position_ids, num_beams, allocator, expanded_position_ids); + ExpandInputs(attention_mask, num_beams, allocator, expanded_attention_mask); + return Status::OK(); } @@ -408,7 +408,7 @@ template Status GreedySearchProcessLogits( const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state - transformers::ISamplingCudaState* sampling_state, // sampling_state + transformers::ISamplingState* sampling_state, // sampling_state transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) @@ -418,7 +418,6 @@ Status GreedySearchProcessLogits( int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper - ORT_UNUSED_PARAMETER(sampling_state); #ifndef DEBUG_GENERATION ORT_UNUSED_PARAMETER(dumper); #endif @@ -462,94 +461,51 @@ Status GreedySearchProcessLogits( constexpr unsigned top_k = 1; if (do_sampling) { - // bugbug: is this Softmax really needed? 
- gsl::span& next_token_probs = greedy_state->next_token_probs; - ORT_RETURN_IF_ERROR(SoftmaxCPU(batch_size, - vocab_size, - next_token_scores.data(), - next_token_probs.data(), - false, - thread_pool)); - - // torch.multinomial() - int64_t next_token_probs_dims[] = {static_cast(batch_size), vocab_size}; - TensorShape next_token_probs_shape(&next_token_probs_dims[0], 2); - auto element_type = DataTypeImpl::GetType(); - OrtValue next_token_probs_value; - Tensor::InitOrtValue(element_type, - next_token_probs_shape, - next_token_probs.data(), - allocator->Info(), - next_token_probs_value); - const Tensor& input = next_token_probs_value.Get(); - - std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(parameters->seed)}; - - int64_t sampled_idx_dims[] = {static_cast(batch_size), 1}; - TensorShape sampled_idx_shape(&sampled_idx_dims[0], 2); - - gsl::span& next_token_idx = greedy_state->next_tokens_cpu; - - OrtValue sampled_idx_ov; - Tensor::InitOrtValue(DataTypeImpl::GetType(), - sampled_idx_shape, - next_token_idx.data(), - allocator->Info(), - sampled_idx_ov); - Tensor* sampled_idx = sampled_idx_ov.GetMutable(); - - AllocatorPtr allocator_temp = allocator; - ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocator_temp, - input, - batch_size, - vocab_size, - 1, - generator, - *sampled_idx)); - // TODO: update presense_mask + SamplingCpuHelper::TopPSamplingCpu top_p_sampler(allocator, + thread_pool, + sampling_state, + greedy_state, + parameters); + ORT_RETURN_IF_ERROR(top_p_sampler.Sample(next_token_scores)); -#ifdef DEBUG_GENERATION - dumper->Print("sampled_idx", *sampled_idx); -#endif - } else { - // next_tokens = torch.argmax(scores, dim=-1) - int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; - TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); - auto element_type = DataTypeImpl::GetType(); - OrtValue next_token_scores_value; - Tensor::InitOrtValue(element_type, - next_token_scores_shape, - next_token_scores.data(), - allocator->Info(), - next_token_scores_value); - const Tensor& input = next_token_scores_value.Get(); - - constexpr int axis = 1; - constexpr bool largest = true; - constexpr bool sorted = false; - - Tensor topk_scores; - Tensor topk_indices; - ORT_RETURN_IF_ERROR(TopK(&input, - axis, - top_k, - largest, - sorted, - allocator, - stream, - thread_pool, - topk_scores, - topk_indices)); + return Status::OK(); + } + // next_tokens = torch.argmax(scores, dim=-1) + int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; + TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue next_token_scores_value; + Tensor::InitOrtValue(element_type, + next_token_scores_shape, + next_token_scores.data(), + allocator->Info(), + next_token_scores_value); + const Tensor& input = next_token_scores_value.Get(); + + constexpr int axis = 1; + constexpr bool largest = true; + constexpr bool sorted = false; + + Tensor topk_scores; + Tensor topk_indices; + ORT_RETURN_IF_ERROR(TopK(&input, + axis, + top_k, + largest, + sorted, + allocator, + stream, + thread_pool, + topk_scores, + topk_indices)); #ifdef DEBUG_GENERATION dumper->Print("topk_scores", topk_scores); dumper->Print("topk_indices", topk_indices); #endif - gsl::span next_token_indices = topk_indices.DataAsSpan(); - gsl::copy(next_token_indices, greedy_state->next_tokens_cpu); - } - + gsl::span next_token_indices = topk_indices.DataAsSpan(); + gsl::copy(next_token_indices, 
greedy_state->next_tokens_cpu); #ifdef DEBUG_GENERATION gsl::span next_tokens(greedy_state->next_tokens_cpu.data(), @@ -624,6 +580,7 @@ Status UpdateGptFeeds( // last_outputs: logits, present_0, present_1, ... // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 ORT_UNUSED_PARAMETER(stream); + // The following updates inputs for subgraph // Update input_ids with next tokens. @@ -639,6 +596,7 @@ Status UpdateGptFeeds( input_ids_data[i] = beam_next_tokens[i]; } next_inputs[0] = input_ids; + if (increase_position) { // Update position IDs int32_t* position_data = position_ids.GetMutable()->MutableData(); @@ -647,24 +605,22 @@ Status UpdateGptFeeds( } } next_inputs[1] = position_ids; + // Update attention mask const OrtValue& old_mask = next_inputs[2]; - const auto& mask_dims = old_mask.Get().Shape().GetDims(); - if (mask_dims.size() == 2) { - const int32_t* old_mask_data = old_mask.Get().Data(); - int64_t mask_dims[] = {batch_beam_size, current_length}; - TensorShape mask_shape(&mask_dims[0], 2); - OrtValue attention_mask; - Tensor::InitOrtValue(int32_type, mask_shape, allocator, attention_mask); - int32_t* mask_data = attention_mask.GetMutable()->MutableData(); - for (int i = 0; i < batch_beam_size; i++) { - for (int j = 0; j < current_length - 1; j++) { - mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j]; - } - mask_data[i * current_length + current_length - 1] = 1; + const int32_t* old_mask_data = old_mask.Get().Data(); + int64_t mask_dims[] = {batch_beam_size, current_length}; + TensorShape mask_shape(&mask_dims[0], 2); + OrtValue attention_mask; + Tensor::InitOrtValue(int32_type, mask_shape, allocator, attention_mask); + int32_t* mask_data = attention_mask.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + for (int j = 0; j < current_length - 1; j++) { + mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j]; } - next_inputs[2] = attention_mask; - } // if mask_dims.size() == 4 do nothing + mask_data[i * current_length + current_length - 1] = 1; + } + next_inputs[2] = attention_mask; // Update past state if (num_beams == 1) { @@ -894,7 +850,7 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, - transformers::ISamplingCudaState* sampling_state, + transformers::ISamplingState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 9b1d850623256..4a19dbe9438d6 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -91,7 +91,7 @@ template using GreedySearchProcessLogitsFunc = std::function* greedy_state, // state - transformers::ISamplingCudaState* sampling_state, // sampling buffers + transformers::ISamplingState* sampling_state, // sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) @@ -209,7 +209,7 @@ Status ProcessLogits(const OrtValue& logits, // template Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state - 
transformers::ISamplingCudaState* sampling_state, // sampling buffers + transformers::ISamplingState* sampling_state, // sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 48aee8e30936d..d6d34f39beb3d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/common/gsl.h" #include "core/framework/allocator.h" #include "core/framework/ort_value.h" @@ -33,7 +34,7 @@ struct IBeamSearchState { gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) gsl::span remaining_scores; // portion of scores that is available for appending next token scores. gsl::span topk_buffer; // temp buffer for topk computation, including: - // 1st stage needs: + // 1st stage needs: // temp score: (batch_size * num_beams * parts_vocab, 2 * num_beams) // temp token: (batch_size * num_beams * parts_vocab, 2 * num_beams) // 2nd stage needs: @@ -62,12 +63,11 @@ struct IGreedySearchState { gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. gsl::span eos_meet; // shape (batch_size) gsl::span next_token_scores; // shape (batch_size, vocab_size) - gsl::span next_token_probs; // shape (batch_size, vocab_size) gsl::span next_tokens; // shape (batch_size) }; template -struct ISamplingCudaState { +struct ISamplingState { gsl::span d_index_in; gsl::span d_index_out; gsl::span d_offset; @@ -82,6 +82,7 @@ struct ISamplingCudaState { BufferUniquePtr storage_buffer; size_t temp_storage_bytes; + std::default_random_engine generator; }; class ISequences { diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 083dac9d5334c..16b3d1f6b6ecc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -13,7 +13,7 @@ namespace contrib { namespace transformers { template -struct SamplingState : public ISamplingCudaState { +struct SamplingState : public ISamplingState { void Init(AllocatorPtr allocator, AllocatorPtr cpu_allocator, int batch_size, @@ -21,30 +21,31 @@ struct SamplingState : public ISamplingCudaState { int max_iter, int seed, bool is_cuda) { - if (!is_cuda) { - return; - } int total_count = batch_size * vocab_size; - this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count)); - this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count)); - this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1)); - this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); - this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); + this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); - this->h_sampled_all = 
AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); - this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); - // TODO: do not allocate this buffer if there's no presence_mask - this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); - - this->temp_storage_bytes = 0; - - std::default_random_engine generator = std::default_random_engine{gsl::narrow_cast(seed)}; - std::uniform_real_distribution distribution(0.0, 1.0); - distribution(generator); - for (size_t i = 0; i < this->h_sampled_all.size(); ++i) { - this->h_sampled_all[i] = distribution(generator); + + this->generator = std::default_random_engine{gsl::narrow_cast(seed)}; + + if (is_cuda) { + this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count)); + this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count)); + this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1)); + this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); + this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); + this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); + this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); + this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); + this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); + this->temp_storage_bytes = 0; + // TODO(wy): Do not allocate this buffer if there's no presence_mask + this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); + + std::uniform_real_distribution distribution(0.0, 1.0); + distribution(this->generator); + for (size_t i = 0; i < this->h_sampled_all.size(); ++i) { + this->h_sampled_all[i] = distribution(this->generator); + } } } @@ -92,7 +93,6 @@ struct GreedySearchState : public IGreedySearchState { // below buffers are on cpu or cuda size_t next_token_size = SafeInt(batch_size) * vocab_size; this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_token_probs = AllocateBuffer(allocator, next_token_probs_buffer_, next_token_size); this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size); } @@ -113,7 +113,6 @@ struct GreedySearchState : public IGreedySearchState { BufferUniquePtr sequences_space_buffer_; BufferUniquePtr sequence_lengths_buffer_; BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_token_probs_buffer_; BufferUniquePtr next_tokens_buffer_; BufferUniquePtr next_tokens_cpu_buffer_; BufferUniquePtr next_positions_buffer_; @@ -158,14 +157,14 @@ class GreedySearchBase : public GenerateBase { Status GenerateNextToken(const OrtValue& logits, gsl::span& next_tokens, GreedySearchState& greedy_state, - ISamplingCudaState& sampling_state, + ISamplingState& sampling_state, int counter, int eos_token_id); // Calculate scores from logits, then apply filtering and select next token for each beam. 
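For reference, the SamplingState::Init change above seeds one std::default_random_engine and, on the CUDA path, pre-draws batch_size * max_iter uniform values into h_sampled_all so each generation step can consume one batch-sized slice. A minimal standalone sketch of that host-side pattern (plain C++, illustrative values, no ORT types):

#include <cstddef>
#include <cstdio>
#include <random>
#include <vector>

int main() {
  const int batch_size = 2;
  const int max_iter = 3;
  const unsigned int seed = 123;

  std::default_random_engine generator{seed};
  std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
  distribution(generator);  // one draw is made before the buffer is filled, as in the Init() above

  // Pre-draw all per-step uniforms on the host (mirrors h_sampled_all).
  std::vector<float> h_sampled_all(static_cast<std::size_t>(batch_size) * max_iter);
  for (float& v : h_sampled_all) {
    v = distribution(generator);
  }

  for (int step = 1; step <= max_iter; ++step) {
    // Per step, the CUDA path copies one batch-sized slice of these draws to the
    // device (d_sampled) for the multinomial kernel.
    const float* slice = h_sampled_all.data() + static_cast<std::size_t>(step - 1) * batch_size;
    std::printf("step %d uniform draws: %.4f %.4f\n", step, slice[0], slice[1]);
  }
  return 0;
}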
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph GreedySearchState& greedy_state, - ISamplingCudaState& sampling_state, + ISamplingState& sampling_state, AllocatorPtr& allocator, int counter); @@ -205,7 +204,7 @@ Status GreedySearchBase::Initialize() { if (!this->IsCuda()) { // Logits processor is used in CPU only. In CUDA, cuda kernels are used instead. // Initialize processors after CheckInputs so that parameters_->vocab_mask is ready. - this->logits_processors_.Init(*parameters_, thread_pool_); + this->logits_processors_.Init(*parameters_); } return Status::OK(); @@ -215,7 +214,7 @@ template Status GreedySearchBase::ProcessLogits( const OrtValue& logits, GreedySearchState& greedy_state, - ISamplingCudaState& sampling_state, + ISamplingState& sampling_state, AllocatorPtr& allocator, int counter) { bool use_sampling = std::is_same::value; @@ -229,7 +228,7 @@ Status GreedySearchBase::GenerateNextToken( const OrtValue& logits, gsl::span& next_tokens, GreedySearchState& greedy_state, - ISamplingCudaState& sampling_state, + ISamplingState& sampling_state, int counter, int eos_token_id) { // Process logits to get next token scores diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 4efa12e6d5f0f..d0641fedf978e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -213,71 +213,6 @@ void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, #endif } -template -TopPLogitsProcessor::TopPLogitsProcessor(float top_p, float filter_value, - onnxruntime::concurrency::ThreadPool* thread_pool) - : top_p_(top_p), filter_value_(filter_value), thread_pool_(thread_pool) { -} - -template -void TopPLogitsProcessor::Process(const ISequences* /*sequences*/, - NextTokenScores& next_token_scores) { - if (top_p_ == 0.0f) { - return; - } - - const int batch_beam_size = next_token_scores.batch_beam_size; - const int vocab_size = next_token_scores.vocab_size; - - for (int i = 0; i < batch_beam_size; i++) { - gsl::span beam_token_scores = next_token_scores.GetScores(i); - - std::vector sorted_scores(beam_token_scores.begin(), beam_token_scores.end()); - - // decending sort - std::vector sorted_indices(beam_token_scores.size()); - std::iota(sorted_indices.begin(), sorted_indices.end(), 0); - std::sort(sorted_indices.begin(), - sorted_indices.end(), - [&sorted_scores](size_t i1, size_t i2) { - return sorted_scores[i1] > sorted_scores[i2]; - }); - - std::sort(sorted_scores.begin(), sorted_scores.end(), std::greater()); - std::vector cumulative_probs(vocab_size); - // todo: batch - Status status = SoftmaxCPU(1, - vocab_size, - sorted_scores.data(), - cumulative_probs.data(), - false, - thread_pool_); - if (!status.IsOK()) { - ORT_THROW(status.ErrorMessage()); - } - - std::unordered_set sorted_indices_to_remove; - if (cumulative_probs[0] > top_p_) { - sorted_indices_to_remove.insert(1); - } - for (size_t j = 1; j < static_cast(vocab_size) - 1; j++) { - cumulative_probs[j] += cumulative_probs[j - 1]; - if (cumulative_probs[j] > top_p_) { - sorted_indices_to_remove.insert(j + 1); - } - } - - for (auto it = sorted_indices_to_remove.begin(); it != sorted_indices_to_remove.end(); ++it) { - size_t index_to_remove = sorted_indices[*it]; - beam_token_scores[index_to_remove] = filter_value_; - } - } - -#ifdef DEBUG_GENERATION - DumpScores("TopPLogitsProcessor", next_token_scores); -#endif -} - template 
PresencePenaltyLogitsProcessor::PresencePenaltyLogitsProcessor(const gsl::span& presence_mask, float presence_penalty) @@ -303,19 +238,16 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, #endif } -void LogitsProcessorList::Init(const BeamSearchParameters& parameters, - onnxruntime::concurrency::ThreadPool* thread_pool) { - LogitsProcessorInitImpl(parameters, thread_pool); +void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { + LogitsProcessorInitImpl(parameters); } -void LogitsProcessorList::Init(const GreedySearchParameters& parameters, - onnxruntime::concurrency::ThreadPool* thread_pool) { - LogitsProcessorInitImpl(parameters, thread_pool); +void LogitsProcessorList::Init(const GreedySearchParameters& parameters) { + LogitsProcessorInitImpl(parameters); } -void LogitsProcessorList::Init(const SamplingParameters& parameters, - onnxruntime::concurrency::ThreadPool* thread_pool) { - LogitsProcessorInitImpl(parameters, thread_pool); +void LogitsProcessorList::Init(const SamplingParameters& parameters) { + LogitsProcessorInitImpl(parameters); } void LogitsProcessorList::Process(const ISequences* sequences, diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 8b2bc4a2811be..4a516474c82bd 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -109,20 +109,20 @@ class TemperatureLogitsProcessor : public ILogitsProcessor { float temperature_; }; -template -class TopPLogitsProcessor : public ILogitsProcessor { - public: - TopPLogitsProcessor(float top_p, float filter_value, - onnxruntime::concurrency::ThreadPool* thread_pool); +// template +// class TopPLogitsProcessor : public ILogitsProcessor { +// public: +// TopPLogitsProcessor(float top_p, float filter_value, +// onnxruntime::concurrency::ThreadPool* thread_pool); - void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; +// void Process(const ISequences* sequences, +// NextTokenScores& next_token_scores) override; - private: - float top_p_; - float filter_value_; - onnxruntime::concurrency::ThreadPool* thread_pool_; -}; +// private: +// float top_p_; +// float filter_value_; +// onnxruntime::concurrency::ThreadPool* thread_pool_; +// }; template class PresencePenaltyLogitsProcessor : public ILogitsProcessor { @@ -141,15 +141,14 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { class LogitsProcessorList : public ILogitsProcessorList { public: LogitsProcessorList() = default; - void Init(const BeamSearchParameters& parameters, onnxruntime::concurrency::ThreadPool* thread_pool); - void Init(const GreedySearchParameters& parameters, onnxruntime::concurrency::ThreadPool* thread_pool); - void Init(const SamplingParameters& parameters, onnxruntime::concurrency::ThreadPool* thread_pool); + void Init(const BeamSearchParameters& parameters); + void Init(const GreedySearchParameters& parameters); + void Init(const SamplingParameters& parameters); void Process(const ISequences* sequences, gsl::span& next_token_scores, int step); private: template - void LogitsProcessorInitImpl(const GenerationParametersT& parameters, - onnxruntime::concurrency::ThreadPool* thread_pool) { + void LogitsProcessorInitImpl(const GenerationParametersT& parameters) { processor_list_.clear(); if (parameters.repetition_penalty != 1.0f) { // 1.0 means no penalty @@ -189,13 +188,6 @@ class LogitsProcessorList : public 
ILogitsProcessorList { processor_list_.push_back(temperature_processor_.get()); } - if (parameters.top_p > 0) { - top_p_processor_ = std::make_unique>(parameters.top_p, - parameters.filter_value, - thread_pool); - processor_list_.push_back(top_p_processor_.get()); - } - if (!parameters.presence_mask.empty()) { presence_penalty_processor_ = std::make_unique< PresencePenaltyLogitsProcessor @@ -218,7 +210,6 @@ class LogitsProcessorList : public ILogitsProcessorList { std::unique_ptr> prefix_vocab_mask_processor_; std::unique_ptr> min_length_processor_; std::unique_ptr> temperature_processor_; - std::unique_ptr> top_p_processor_; std::unique_ptr> presence_penalty_processor_; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h new file mode 100644 index 0000000000000..5adcfe15a6922 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace SamplingCpuHelper { + +template +class TopPSamplingCpu{ + public: + TopPSamplingCpu(AllocatorPtr& allocator, + onnxruntime::concurrency::ThreadPool* thread_pool, + transformers::ISamplingState* sampling_state, + transformers::IGreedySearchState* greedy_state, + const transformers::IGenerationParameters* parameters): + allocator_(allocator), + thread_pool_(thread_pool), + sampling_state_(sampling_state), + greedy_state_(greedy_state), + parameters_(parameters) {} + + Status Sample(gsl::span& next_token_scores); + + private: + void filter_scores(std::vector& sorted_indice, gsl::span& next_token_score, size_t index); + + AllocatorPtr& allocator_; + onnxruntime::concurrency::ThreadPool* thread_pool_; + transformers::ISamplingState* sampling_state_; + transformers::IGreedySearchState* greedy_state_; + const transformers::IGenerationParameters* parameters_; +}; + +template +void TopPSamplingCpu::filter_scores(std::vector& sorted_indice, + gsl::span& next_token_score, + size_t index) { + size_t real_index = sorted_indice[index]; + next_token_score[real_index] = parameters_->filter_value; +} + +template +Status TopPSamplingCpu::Sample(gsl::span& next_token_scores) { + if (parameters_->top_p == 0.0f) { + ORT_THROW("top_p shall be greater than 0"); + } + + for (int i = 0; i < parameters_->batch_size; i++) { + gsl::span next_token_score = next_token_scores.subspan(i * parameters_->vocab_size, + parameters_->vocab_size); + + // Copy the vector + std::vector sorted_score(next_token_score.begin(), next_token_score.end()); + + // Decending sort + std::vector sorted_indice(parameters_->vocab_size); + std::iota(sorted_indice.begin(), sorted_indice.end(), 0); + std::sort(sorted_indice.begin(), + sorted_indice.end(), + [&sorted_score](size_t i1, size_t i2) { + return sorted_score[i1] > sorted_score[i2]; + }); + + std::sort(sorted_score.begin(), sorted_score.end(), std::greater()); + std::vector cumulative_prob(parameters_->vocab_size); + + // TODO: batch + ORT_RETURN_IF_ERROR(SoftmaxCPU(1, + parameters_->vocab_size, + sorted_score.data(), + cumulative_prob.data(), + false, + thread_pool_)); + + if (cumulative_prob[0] > parameters_->top_p) { + filter_scores(sorted_indice, next_token_score, 1); + } + for (size_t i = 1; i < static_cast(parameters_->vocab_size) - 1; i++) { + cumulative_prob[i] += cumulative_prob[i - 1]; + if (cumulative_prob[i] > parameters_->top_p) { + 
filter_scores(sorted_indice, next_token_score, i + 1); + } + } + } + + // TODO(wy): This softmax may not be necessary. + gsl::span& next_token_probs = sampling_state_->h_softmaxed_score; + ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters_->batch_size, + parameters_->vocab_size, + next_token_scores.data(), + next_token_probs.data(), + false, + thread_pool_)); + + // torch.multinomial() + int64_t next_token_probs_dims[] = {static_cast(parameters_->batch_size), parameters_->vocab_size}; + TensorShape next_token_probs_shape(&next_token_probs_dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue next_token_probs_value; + Tensor::InitOrtValue(element_type, + next_token_probs_shape, + next_token_probs.data(), + allocator_->Info(), + next_token_probs_value); + const Tensor& input = next_token_probs_value.Get(); + + std::default_random_engine& generator = sampling_state_->generator; + + int64_t sampled_idx_dims[] = {static_cast(parameters_->batch_size), 1}; + TensorShape sampled_idx_shape(&sampled_idx_dims[0], 2); + + gsl::span& next_token_idx = greedy_state_->next_tokens_cpu; + + OrtValue sampled_idx_ov; + Tensor::InitOrtValue(DataTypeImpl::GetType(), + sampled_idx_shape, + next_token_idx.data(), + allocator_->Info(), + sampled_idx_ov); + Tensor* sampled_idx = sampled_idx_ov.GetMutable(); + + // Copy the allocator because MultinomialComputeShared() uses move(allocator) + AllocatorPtr allocator_temp = allocator_; + ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocator_temp, + input, + parameters_->batch_size, + parameters_->vocab_size, + 1, + generator, + *sampled_idx)); + // TODO: update presense_mask() + +#ifdef DEBUG_GENERATION + dumper->Print("sampled_idx", *sampled_idx); +#endif + + return Status::OK(); +} + +} // namespace SamplingCudaHelper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index a3916ef09cfe7..dfca731f8148d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -456,7 +456,7 @@ template Status GreedySearchProcessLogits( const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state - transformers::ISamplingCudaState* sampling_state, // buffers + transformers::ISamplingState* sampling_state, // buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) @@ -564,45 +564,52 @@ Status GreedySearchProcessLogits( ORT_UNUSED_PARAMETER(output_scores); if (do_sampling) { - SamplingCudaHelper::TopPSamplingCuda top_p_sampler(allocator, cuda_stream, sampling_state, greedy_state, parameters); + // TODO: Move the ctor out of the function. 
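For reference, the CPU sampler above reduces to two steps per batch row: nucleus (top-p) filtering of the logits, then a multinomial draw from the softmax of what survives. A minimal standalone sketch of that flow, with no ORT types and hypothetical helper names (TopPFilter, SampleToken), keeping the same keep-at-least-the-top-token behavior:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <numeric>
#include <random>
#include <vector>

// Keep the smallest set of highest-probability tokens whose cumulative probability
// exceeds top_p; set every other logit to filter_value. Operates on one batch row.
void TopPFilter(std::vector<float>& scores, float top_p, float filter_value) {
  const std::size_t vocab_size = scores.size();

  // Indices of the logits in descending score order.
  std::vector<std::size_t> sorted_indices(vocab_size);
  std::iota(sorted_indices.begin(), sorted_indices.end(), 0);
  std::sort(sorted_indices.begin(), sorted_indices.end(),
            [&scores](std::size_t a, std::size_t b) { return scores[a] > scores[b]; });

  // Softmax over the sorted scores.
  std::vector<float> probs(vocab_size);
  const float max_score = scores[sorted_indices[0]];
  float sum = 0.0f;
  for (std::size_t j = 0; j < vocab_size; ++j) {
    probs[j] = std::exp(scores[sorted_indices[j]] - max_score);
    sum += probs[j];
  }

  // Walk the cumulative distribution; once it passes top_p, mask every later token
  // (the top-1 token is always kept, as in the helper above).
  float cumulative = 0.0f;
  for (std::size_t j = 0; j < vocab_size; ++j) {
    cumulative += probs[j] / sum;
    if (cumulative > top_p) {
      for (std::size_t k = j + 1; k < vocab_size; ++k) {
        scores[sorted_indices[k]] = filter_value;
      }
      break;
    }
  }
}

// Softmax the filtered logits and draw one token id by inverse CDF, mirroring the
// multinomial step that follows the filtering.
std::size_t SampleToken(const std::vector<float>& scores, std::default_random_engine& generator) {
  std::vector<float> probs(scores.size());
  const float max_score = *std::max_element(scores.begin(), scores.end());
  float sum = 0.0f;
  for (std::size_t j = 0; j < scores.size(); ++j) {
    probs[j] = std::exp(scores[j] - max_score);  // filtered (-inf) entries become 0
    sum += probs[j];
  }
  std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
  const float draw = distribution(generator) * sum;
  float cumulative = 0.0f;
  for (std::size_t j = 0; j < probs.size(); ++j) {
    cumulative += probs[j];
    if (draw <= cumulative) {
      return j;
    }
  }
  return probs.size() - 1;
}

int main() {
  std::vector<float> logits = {2.0f, 1.0f, 0.5f, -1.0f, -3.0f};
  TopPFilter(logits, /*top_p=*/0.9f, /*filter_value=*/-std::numeric_limits<float>::infinity());
  std::default_random_engine generator{42u};
  std::printf("sampled token id: %zu\n", SampleToken(logits, generator));
  return 0;
}

In the actual helper the softmax is batched via SoftmaxCPU and the draw goes through MultinomialComputeShared; the sketch only mirrors the per-row math.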
+ SamplingCudaHelper::TopPSamplingCuda top_p_sampler(allocator, + cuda_stream, + sampling_state, + greedy_state, + parameters); ORT_RETURN_IF_ERROR(top_p_sampler.Sample(step, next_token_scores)); - } else { - // next_tokens = torch.argmax(scores, dim=-1) - int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; - TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); - auto element_type = DataTypeImpl::GetType(); - OrtValue next_token_scores_value; - Tensor::InitOrtValue(element_type, - next_token_scores_shape, - next_token_scores.data(), - allocator->Info(), - next_token_scores_value); - const Tensor& input = next_token_scores_value.Get(); - constexpr int axis = 1; - constexpr unsigned top_k = static_cast(1); - constexpr bool largest = true; - constexpr bool sorted = false; + return Status::OK(); + } - auto topk_scores = Tensor::CreateDefault(); - auto topk_indices = Tensor::CreateDefault(); - ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, - *topk_scores, *topk_indices)); + // next_tokens = torch.argmax(scores, dim=-1) + int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; + TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue next_token_scores_value; + Tensor::InitOrtValue(element_type, + next_token_scores_shape, + next_token_scores.data(), + allocator->Info(), + next_token_scores_value); + const Tensor& input = next_token_scores_value.Get(); + + constexpr int axis = 1; + constexpr unsigned top_k = static_cast(1); + constexpr bool largest = true; + constexpr bool sorted = false; + + auto topk_scores = Tensor::CreateDefault(); + auto topk_indices = Tensor::CreateDefault(); + ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, + *topk_scores, *topk_indices)); #ifdef DEBUG_GENERATION dumper->Print("topk_scores", *(topk_scores.get())); dumper->Print("topk_indices", *(topk_indices.get())); #endif - const int64_t* next_token_indices = topk_indices->Data(); + const int64_t* next_token_indices = topk_indices->Data(); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), - next_token_indices, - greedy_state->next_tokens_cpu.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); - } + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), + next_token_indices, + greedy_state->next_tokens_cpu.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); return Status::OK(); } @@ -737,23 +744,19 @@ Status UpdateGptFeeds( // Update attention mask const OrtValue& old_mask = next_inputs[2]; - const auto& mask_dims = old_mask.Get().Shape().GetDims(); - - if (mask_dims.size() == 2) { - const int32_t* old_mask_data = old_mask.Get().Data(); - int64_t mask_dims[] = {batch_beam_size, current_length}; - TensorShape mask_shape(&mask_dims[0], 2); - OrtValue attention_mask; - auto mask_type = DataTypeImpl::GetType(); - Tensor::InitOrtValue(mask_type, mask_shape, allocator, attention_mask); - int32_t* mask_data = attention_mask.GetMutable()->MutableData(); + const int32_t* old_mask_data = old_mask.Get().Data(); + int64_t mask_dims[] = {batch_beam_size, current_length}; + TensorShape mask_shape(&mask_dims[0], 2); + OrtValue attention_mask; + auto mask_type = DataTypeImpl::GetType(); + Tensor::InitOrtValue(mask_type, mask_shape, allocator, 
attention_mask); + int32_t* mask_data = attention_mask.GetMutable()->MutableData(); - // Launch kernel to update position_ids and attention_mask for next iteration - cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length, - reinterpret_cast(stream)); + // Launch kernel to update position_ids and attention_mask for next iteration + cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length, + reinterpret_cast(stream)); - next_inputs[2] = attention_mask; - } // do nothing for mask_dims.size() == 4 + next_inputs[2] = attention_mask; // Update past state if (num_beams == 1) { @@ -914,7 +917,7 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, - transformers::ISamplingCudaState* sampling_state, + transformers::ISamplingState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, @@ -981,7 +984,7 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, - transformers::ISamplingCudaState* sampling_state, + transformers::ISamplingState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index 73385450af428..c065c8254f508 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -62,7 +62,7 @@ Status ProcessLogits(const OrtValue& logits, // template Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state - transformers::ISamplingCudaState* sampling_state,// sampling buffers + transformers::ISamplingState* sampling_state,// sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h index aba84868f5277..18e3cb5c6da68 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -20,7 +20,7 @@ class TopPSamplingCuda{ public: TopPSamplingCuda(AllocatorPtr& allocator, cudaStream_t cuda_stream, - transformers::ISamplingCudaState* sampling_state, + transformers::ISamplingState* sampling_state, transformers::IGreedySearchState* greedy_state, const transformers::IGenerationParameters* parameters): allocator_(allocator), @@ -29,151 +29,156 @@ class TopPSamplingCuda{ greedy_state_(greedy_state), parameters_(parameters) {} - Status Sample(int step, gsl::span& next_token_scores) { - typedef typename ToCudaType::MappedType CudaT; - - gsl::span& d_index_in = sampling_state_->d_index_in; - gsl::span& d_offset = sampling_state_->d_offset; - - BufferUniquePtr& storage_buffer = sampling_state_->storage_buffer; - size_t& temp_storage_bytes = sampling_state_->temp_storage_bytes; - if (step == 1) { - temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), - 
d_index_in.data(), - d_offset.data(), - parameters_->batch_size * parameters_->vocab_size, - parameters_->batch_size, - cuda_stream_); - - cuda::LaunchSetupParamsKernel(d_index_in.data(), - d_offset.data(), - parameters_->batch_size, - parameters_->vocab_size, - cuda_stream_); - - #ifdef DEBUG_GENERATION - dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1); - #endif - - void* temp_storage = allocator_->Alloc(sampling_state_->temp_storage_bytes); - BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator_)); - storage_buffer = std::move(temp_storage_buffer); - } - - gsl::span& d_sorted_score = sampling_state_->d_sorted_score; - gsl::span& d_index_out = sampling_state_->d_index_out; - - #ifdef DEBUG_GENERATION - dumper->Print("temp_storage_bytes", temp_storage_bytes, true); - #endif - - cuda::LaunchSortPairsDescending(storage_buffer.get(), - temp_storage_bytes, - reinterpret_cast(next_token_scores.data()), - reinterpret_cast(d_sorted_score.data()), - d_index_in.data(), - d_index_out.data(), - parameters_->batch_size * parameters_->vocab_size, - parameters_->batch_size, - d_offset.data(), - cuda_stream_); - - #ifdef DEBUG_GENERATION - dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score.data()), batch_size, vocab_size); - dumper->Print("d_index_buffer_in", d_index_in.data(), batch_size, vocab_size); - dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size); - #endif - - gsl::span& d_sorted_softmaxed_score = sampling_state_->d_sorted_softmaxed_score; - dispatch_blockwise_softmax_forward(cuda_stream_, - d_sorted_softmaxed_score.data(), - reinterpret_cast(d_sorted_score.data()), - parameters_->vocab_size, - parameters_->vocab_size, - parameters_->batch_size); - - #ifdef DEBUG_GENERATION - dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), batch_size, vocab_size); - #endif - - cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score.data(), - d_index_out.data(), - reinterpret_cast(next_token_scores.data()), - parameters_->top_p, - parameters_->filter_value, - parameters_->batch_size, - parameters_->vocab_size, - cuda_stream_); - - #ifdef DEBUG_GENERATION - dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), batch_size, vocab_size); - #endif - - // TODO(wy): actually we can only do softmax at the very beginning and sort the softmaxed scores. 
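On the TODO above (whether the softmax could be applied once, before the sort): softmax preserves the ordering of its inputs and commutes with any permutation, so sorting softmaxed scores yields the same values as softmaxing the sorted logits. A small standalone check (plain C++, illustrative values, not part of the patch logic):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

static std::vector<float> Softmax(const std::vector<float>& x) {
  std::vector<float> y(x.size());
  const float max_x = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - max_x);
    sum += y[i];
  }
  for (float& v : y) {
    v /= sum;
  }
  return y;
}

int main() {
  const std::vector<float> logits = {0.3f, 2.1f, -1.0f, 0.7f};

  // Order used by the helper today: sort the logits descending, then softmax.
  std::vector<float> sorted_logits = logits;
  std::sort(sorted_logits.begin(), sorted_logits.end(), std::greater<float>());
  const std::vector<float> a = Softmax(sorted_logits);

  // Order suggested by the TODO: softmax once, then sort the probabilities.
  std::vector<float> b = Softmax(logits);
  std::sort(b.begin(), b.end(), std::greater<float>());

  // The two agree up to float rounding, since softmax keeps the relative order
  // of its inputs and commutes with any permutation.
  for (std::size_t i = 0; i < a.size(); ++i) {
    std::printf("%zu: %.6f vs %.6f\n", i, a[i], b[i]);
  }
  return 0;
}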
- gsl::span& d_softmaxed_score = sampling_state_->d_softmaxed_score; - dispatch_blockwise_softmax_forward(cuda_stream_, - d_softmaxed_score.data(), - reinterpret_cast(next_token_scores.data()), - parameters_->vocab_size, - parameters_->vocab_size, - parameters_->batch_size); - - #ifdef DEBUG_GENERATION - dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), batch_size, vocab_size); - #endif - - // multinomial sampling - gsl::span& d_sampled = sampling_state_->d_sampled; - gsl::span& h_sampled_all = sampling_state_->h_sampled_all; - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(), - h_sampled_all.data() + (step - 1) * parameters_->batch_size, - sizeof(float) * parameters_->batch_size, - cudaMemcpyHostToDevice, - cuda_stream_)); - - #ifdef DEBUG_GENERATION - dumper->Print("d_sampled", d_sampled.data(), batch_size, 1); - #endif - - gsl::span& d_indices = sampling_state_->d_indices; - gsl::span& presence_mask = sampling_state_->d_presence_mask; - cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(), - d_sampled.data(), - d_indices.data(), - parameters_->batch_size, - parameters_->vocab_size, - presence_mask.data(), - cuda_stream_); - - #ifdef DEBUG_GENERATION - dumper->Print("d_indices", d_indices.data(), batch_size, 1); - #endif - - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state_->next_tokens_cpu.data(), - sampling_state_->d_indices.data(), - greedy_state_->next_tokens_cpu.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream_)); - - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state_->h_softmaxed_score.data(), - sampling_state_->d_softmaxed_score.data(), - sampling_state_->h_softmaxed_score.size_bytes(), - cudaMemcpyDeviceToHost, - cuda_stream_)); - - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream_)); - - return Status::OK(); - } + Status Sample(int step, gsl::span& next_token_scores); private: AllocatorPtr& allocator_; cudaStream_t cuda_stream_; - transformers::ISamplingCudaState* sampling_state_; + transformers::ISamplingState* sampling_state_; transformers::IGreedySearchState* greedy_state_; const transformers::IGenerationParameters* parameters_; }; +template +Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { + typedef typename ToCudaType::MappedType CudaT; + + gsl::span& d_index_in = sampling_state_->d_index_in; + gsl::span& d_offset = sampling_state_->d_offset; + + BufferUniquePtr& storage_buffer = sampling_state_->storage_buffer; + size_t& temp_storage_bytes = sampling_state_->temp_storage_bytes; + if (step == 1) { + temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), + d_index_in.data(), + d_offset.data(), + parameters_->batch_size * parameters_->vocab_size, + parameters_->batch_size, + cuda_stream_); + + cuda::LaunchSetupParamsKernel(d_index_in.data(), + d_offset.data(), + parameters_->batch_size, + parameters_->vocab_size, + cuda_stream_); + +#ifdef DEBUG_GENERATION + dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1); +#endif + + void* temp_storage = allocator_->Alloc(sampling_state_->temp_storage_bytes); + BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator_)); + storage_buffer = std::move(temp_storage_buffer); + } + + gsl::span& d_sorted_score = sampling_state_->d_sorted_score; + gsl::span& d_index_out = sampling_state_->d_index_out; + +#ifdef DEBUG_GENERATION + dumper->Print("temp_storage_bytes", temp_storage_bytes, true); +#endif + + cuda::LaunchSortPairsDescending(storage_buffer.get(), + temp_storage_bytes, + 
reinterpret_cast(next_token_scores.data()), + reinterpret_cast(d_sorted_score.data()), + d_index_in.data(), + d_index_out.data(), + parameters_->batch_size * parameters_->vocab_size, + parameters_->batch_size, + d_offset.data(), + cuda_stream_); + +#ifdef DEBUG_GENERATION + dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score.data()), batch_size, vocab_size); + dumper->Print("d_index_buffer_in", d_index_in.data(), batch_size, vocab_size); + dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size); +#endif + + gsl::span& d_sorted_softmaxed_score = sampling_state_->d_sorted_softmaxed_score; + dispatch_blockwise_softmax_forward(cuda_stream_, + d_sorted_softmaxed_score.data(), + reinterpret_cast(d_sorted_score.data()), + parameters_->vocab_size, + parameters_->vocab_size, + parameters_->vocab_size, + parameters_->batch_size); + +#ifdef DEBUG_GENERATION + dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), batch_size, vocab_size); +#endif + + cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score.data(), + d_index_out.data(), + reinterpret_cast(next_token_scores.data()), + parameters_->top_p, + parameters_->filter_value, + parameters_->batch_size, + parameters_->vocab_size, + cuda_stream_); + +#ifdef DEBUG_GENERATION + dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), batch_size, vocab_size); +#endif + + // TODO(wy): Can we only do softmax at the very beginning and sort the softmaxed scores. + gsl::span& d_softmaxed_score = sampling_state_->d_softmaxed_score; + dispatch_blockwise_softmax_forward(cuda_stream_, + d_softmaxed_score.data(), + reinterpret_cast(next_token_scores.data()), + parameters_->vocab_size, + parameters_->vocab_size, + parameters_->vocab_size, + parameters_->batch_size); + +#ifdef DEBUG_GENERATION + dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), batch_size, vocab_size); +#endif + + // Multinomial sampling + gsl::span& d_sampled = sampling_state_->d_sampled; + gsl::span& h_sampled_all = sampling_state_->h_sampled_all; + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(), + h_sampled_all.data() + (step - 1) * parameters_->batch_size, + sizeof(float) * parameters_->batch_size, + cudaMemcpyHostToDevice, + cuda_stream_)); + +#ifdef DEBUG_GENERATION + dumper->Print("d_sampled", d_sampled.data(), batch_size, 1); +#endif + + gsl::span& d_indices = sampling_state_->d_indices; + gsl::span& presence_mask = sampling_state_->d_presence_mask; + cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(), + d_sampled.data(), + d_indices.data(), + parameters_->batch_size, + parameters_->vocab_size, + presence_mask.data(), + cuda_stream_); + +#ifdef DEBUG_GENERATION + dumper->Print("d_indices", d_indices.data(), batch_size, 1); +#endif + + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state_->next_tokens_cpu.data(), + sampling_state_->d_indices.data(), + greedy_state_->next_tokens_cpu.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream_)); + + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state_->h_softmaxed_score.data(), + sampling_state_->d_softmaxed_score.data(), + sampling_state_->h_softmaxed_score.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream_)); + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream_)); + + return Status::OK(); +} + } // namespace SamplingCudaHelper } // namespace contrib } // namespace onnxruntime From 6752f7899a49f56f70b54686796122c8f8164825 Mon Sep 17 00:00:00 2001 From: wangyems Date: Thu, 8 Dec 2022 
04:15:23 +0000 Subject: [PATCH 34/51] refactor --- .../transformers/generation_device_helper.cc | 10 ++- .../cpu/transformers/generation_shared.h | 4 ++ .../transformers/greedy_search_impl_base.h | 6 +- .../cpu/transformers/sampling_cpu_helper.h | 70 ++++++++++--------- .../transformers/generation_device_helper.cc | 3 +- .../cuda/transformers/sampling_cuda_helper.h | 39 +++++++---- 6 files changed, 80 insertions(+), 52 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index b902ee57a5467..634bfb4e12637 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -418,9 +418,6 @@ Status GreedySearchProcessLogits( int step, // iteration counter void* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper -#ifndef DEBUG_GENERATION - ORT_UNUSED_PARAMETER(dumper); -#endif int batch_size = parameters->batch_size; int vocab_size = parameters->vocab_size; @@ -458,18 +455,18 @@ Status GreedySearchProcessLogits( dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size); #endif - constexpr unsigned top_k = 1; - if (do_sampling) { SamplingCpuHelper::TopPSamplingCpu top_p_sampler(allocator, thread_pool, sampling_state, greedy_state, - parameters); + parameters, + dumper); ORT_RETURN_IF_ERROR(top_p_sampler.Sample(next_token_scores)); return Status::OK(); } + // next_tokens = torch.argmax(scores, dim=-1) int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); @@ -482,6 +479,7 @@ Status GreedySearchProcessLogits( next_token_scores_value); const Tensor& input = next_token_scores_value.Get(); + constexpr unsigned top_k = 1; constexpr int axis = 1; constexpr bool largest = true; constexpr bool sorted = false; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index d6d34f39beb3d..2aafbcbcda736 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -83,6 +83,10 @@ struct ISamplingState { BufferUniquePtr storage_buffer; size_t temp_storage_bytes; std::default_random_engine generator; + + std::vector sorted_scores; + std::vector sorted_indices; + std::vector cumulative_probs; }; class ISequences { diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 16b3d1f6b6ecc..d0acf6fef0ebd 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -38,7 +38,7 @@ struct SamplingState : public ISamplingState { this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); this->temp_storage_bytes = 0; - // TODO(wy): Do not allocate this buffer if there's no presence_mask + // TODO: Do not allocate this buffer if there's no presence_mask this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); std::uniform_real_distribution distribution(0.0, 1.0); @@ -46,6 +46,10 @@ struct SamplingState : public 
ISamplingState { for (size_t i = 0; i < this->h_sampled_all.size(); ++i) { this->h_sampled_all[i] = distribution(this->generator); } + } else { + this->sorted_scores.reserve(total_count); + this->sorted_indices.reserve(total_count); + this->cumulative_probs.reserve(total_count); } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h index 5adcfe15a6922..f45f22561c359 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h @@ -13,12 +13,14 @@ class TopPSamplingCpu{ onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ISamplingState* sampling_state, transformers::IGreedySearchState* greedy_state, - const transformers::IGenerationParameters* parameters): + const transformers::IGenerationParameters* parameters, + const transformers::IConsoleDumper* dumper): allocator_(allocator), thread_pool_(thread_pool), sampling_state_(sampling_state), greedy_state_(greedy_state), - parameters_(parameters) {} + parameters_(parameters), + dumper_(dumper) {} Status Sample(gsl::span& next_token_scores); @@ -30,6 +32,7 @@ class TopPSamplingCpu{ transformers::ISamplingState* sampling_state_; transformers::IGreedySearchState* greedy_state_; const transformers::IGenerationParameters* parameters_; + const transformers::IConsoleDumper* dumper_; }; template @@ -46,40 +49,43 @@ Status TopPSamplingCpu::Sample(gsl::span& next_token_scores) { ORT_THROW("top_p shall be greater than 0"); } - for (int i = 0; i < parameters_->batch_size; i++) { - gsl::span next_token_score = next_token_scores.subspan(i * parameters_->vocab_size, - parameters_->vocab_size); - - // Copy the vector - std::vector sorted_score(next_token_score.begin(), next_token_score.end()); - - // Decending sort - std::vector sorted_indice(parameters_->vocab_size); - std::iota(sorted_indice.begin(), sorted_indice.end(), 0); - std::sort(sorted_indice.begin(), - sorted_indice.end(), - [&sorted_score](size_t i1, size_t i2) { - return sorted_score[i1] > sorted_score[i2]; + std::vector& sorted_scores = sampling_state_->sorted_scores; + sorted_scores.assign(next_token_scores.begin(), next_token_scores.end()); + // Decending sort + std::vector& sorted_indices = sampling_state_->sorted_indices; + + for (size_t i = 0; i < static_cast(parameters_->batch_size); i++) { + auto indices_begin = sorted_indices.begin() + i * parameters_->vocab_size; + auto indices_end = sorted_indices.begin() + (i + 1) * parameters_->vocab_size; + std::iota(indices_begin, indices_end, 0); + std::sort(indices_begin, indices_end, + [&next_token_scores](size_t i1, size_t i2) { + return next_token_scores[i1] > next_token_scores[i2]; }); - std::sort(sorted_score.begin(), sorted_score.end(), std::greater()); - std::vector cumulative_prob(parameters_->vocab_size); + std::sort(sorted_scores.begin() + i * parameters_->vocab_size, + sorted_scores.end() + (i + 1) * parameters_->vocab_size, + std::greater()); + } + + std::vector& cumulative_probs = sampling_state_->cumulative_probs; - // TODO: batch - ORT_RETURN_IF_ERROR(SoftmaxCPU(1, - parameters_->vocab_size, - sorted_score.data(), - cumulative_prob.data(), - false, - thread_pool_)); + ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters_->batch_size, + parameters_->vocab_size, + sorted_scores.data(), + cumulative_probs.data(), + false, + thread_pool_)); - if (cumulative_prob[0] > parameters_->top_p) { - filter_scores(sorted_indice, next_token_score, 1); + for (size_t i = 0; i 
< static_cast(parameters_->batch_size); i++) { + size_t offset = i * parameters_->vocab_size; + if (cumulative_probs[offset] > parameters_->top_p) { + filter_scores(sorted_indices, next_token_scores, 1 + offset); } - for (size_t i = 1; i < static_cast(parameters_->vocab_size) - 1; i++) { - cumulative_prob[i] += cumulative_prob[i - 1]; - if (cumulative_prob[i] > parameters_->top_p) { - filter_scores(sorted_indice, next_token_score, i + 1); + for (size_t j = 1; j < static_cast(parameters_->vocab_size) - 1; j++) { + cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; + if (cumulative_probs[j + offset] > parameters_->top_p) { + filter_scores(sorted_indices, next_token_scores, j + offset + 1); } } } @@ -132,7 +138,7 @@ Status TopPSamplingCpu::Sample(gsl::span& next_token_scores) { // TODO: update presense_mask() #ifdef DEBUG_GENERATION - dumper->Print("sampled_idx", *sampled_idx); + dumper_->Print("sampled_idx", *sampled_idx); #endif return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index dfca731f8148d..884f80ca3a1e4 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -569,7 +569,8 @@ Status GreedySearchProcessLogits( cuda_stream, sampling_state, greedy_state, - parameters); + parameters, + dumper); ORT_RETURN_IF_ERROR(top_p_sampler.Sample(step, next_token_scores)); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h index 18e3cb5c6da68..52939efc5d866 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -22,12 +22,14 @@ class TopPSamplingCuda{ cudaStream_t cuda_stream, transformers::ISamplingState* sampling_state, transformers::IGreedySearchState* greedy_state, - const transformers::IGenerationParameters* parameters): + const transformers::IGenerationParameters* parameters, + const transformers::IConsoleDumper* dumper): allocator_(allocator), cuda_stream_(cuda_stream), sampling_state_(sampling_state), greedy_state_(greedy_state), - parameters_(parameters) {} + parameters_(parameters), + dumper_(dumper) {} Status Sample(int step, gsl::span& next_token_scores); @@ -37,6 +39,7 @@ class TopPSamplingCuda{ transformers::ISamplingState* sampling_state_; transformers::IGreedySearchState* greedy_state_; const transformers::IGenerationParameters* parameters_; + const transformers::IConsoleDumper* dumper_; }; template @@ -63,7 +66,7 @@ Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { cuda_stream_); #ifdef DEBUG_GENERATION - dumper->Print("d_offset_buffer", d_offset.data(), batch_size + 1, 1); + dumper_->Print("d_offset_buffer", d_offset.data(), parameters_->batch_size + 1, 1); #endif void* temp_storage = allocator_->Alloc(sampling_state_->temp_storage_bytes); @@ -75,7 +78,7 @@ Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { gsl::span& d_index_out = sampling_state_->d_index_out; #ifdef DEBUG_GENERATION - dumper->Print("temp_storage_bytes", temp_storage_bytes, true); + dumper_->Print("temp_storage_bytes", sampling_state_->temp_storage_bytes, true); #endif cuda::LaunchSortPairsDescending(storage_buffer.get(), @@ -90,9 +93,12 @@ Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { 
cuda_stream_); #ifdef DEBUG_GENERATION - dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score.data()), batch_size, vocab_size); - dumper->Print("d_index_buffer_in", d_index_in.data(), batch_size, vocab_size); - dumper->Print("d_index_buffer_out", d_index_out.data(), batch_size, vocab_size); + dumper_->Print("d_sorted_score_buffer", + reinterpret_cast(d_sorted_score.data()), + parameters_->batch_size, + parameters_->vocab_size); + dumper_->Print("d_index_buffer_in", d_index_in.data(), parameters_->batch_size, parameters_->vocab_size); + dumper_->Print("d_index_buffer_out", d_index_out.data(), parameters_->batch_size, parameters_->vocab_size); #endif gsl::span& d_sorted_softmaxed_score = sampling_state_->d_sorted_softmaxed_score; @@ -105,7 +111,10 @@ Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { parameters_->batch_size); #ifdef DEBUG_GENERATION - dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), batch_size, vocab_size); + dumper_->Print("d_sorted_softmaxed_score_buffer", + d_sorted_softmaxed_score.data(), + parameters_->batch_size, + parameters_->vocab_size); #endif cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score.data(), @@ -118,7 +127,10 @@ Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { cuda_stream_); #ifdef DEBUG_GENERATION - dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), batch_size, vocab_size); + dumper_->Print("next_token_scores after filtering logits", + reinterpret_cast(next_token_scores.data()), + parameters_->batch_size, + parameters_->vocab_size); #endif // TODO(wy): Can we only do softmax at the very beginning and sort the softmaxed scores. @@ -132,7 +144,10 @@ Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { parameters_->batch_size); #ifdef DEBUG_GENERATION - dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), batch_size, vocab_size); + dumper_->Print("d_softmaxed_score_buffer", + d_softmaxed_score.data(), + parameters_->batch_size, + parameters_->vocab_size); #endif // Multinomial sampling @@ -145,7 +160,7 @@ Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { cuda_stream_)); #ifdef DEBUG_GENERATION - dumper->Print("d_sampled", d_sampled.data(), batch_size, 1); + dumper_->Print("d_sampled", d_sampled.data(), parameters_->batch_size, 1); #endif gsl::span& d_indices = sampling_state_->d_indices; @@ -159,7 +174,7 @@ Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { cuda_stream_); #ifdef DEBUG_GENERATION - dumper->Print("d_indices", d_indices.data(), batch_size, 1); + dumper_->Print("d_indices", d_indices.data(), parameters_->batch_size, 1); #endif CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state_->next_tokens_cpu.data(), From 0ad2a3562662d74498579cbd7ede3f1600d8689c Mon Sep 17 00:00:00 2001 From: wangyems Date: Thu, 8 Dec 2022 23:18:16 +0000 Subject: [PATCH 35/51] add huggingface logic cpu --- .../transformers/beam_search_parameters.cc | 3 +- .../transformers/generation_device_helper.cc | 13 +-- .../cpu/transformers/generation_shared.h | 6 +- .../transformers/greedy_search_impl_base.h | 10 +- .../cpu/transformers/greedy_search_impl_gpt.h | 25 +++-- .../cpu/transformers/sampling_cpu_helper.h | 91 +++++++++++++------ .../cpu/transformers/sampling_parameters.cc | 1 + .../core/graph/contrib_ops/contrib_defs.cc | 3 +- 8 files changed, 96 insertions(+), 56 deletions(-) diff --git 
a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 260cfc7ebf654..bd3a72e989af0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -71,10 +71,9 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) - if (vocab_size == -1) { + if (vocab_size == -1 || vocab_size == 0) { vocab_size = vocabulary_size; } - num_heads = heads; head_size = hidden_size_per_head; num_layers = layers; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 634bfb4e12637..d7b8f60e3f9a6 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -456,12 +456,13 @@ Status GreedySearchProcessLogits( #endif if (do_sampling) { - SamplingCpuHelper::TopPSamplingCpu top_p_sampler(allocator, - thread_pool, - sampling_state, - greedy_state, - parameters, - dumper); + // std::smaller for huggingface version and std::greater for custom version. + SamplingCpuHelper::TopPSamplingCpu> top_p_sampler(allocator, + thread_pool, + sampling_state, + greedy_state, + parameters, + dumper); ORT_RETURN_IF_ERROR(top_p_sampler.Sample(next_token_scores)); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 2aafbcbcda736..73d83877d6e72 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -84,9 +84,8 @@ struct ISamplingState { size_t temp_storage_bytes; std::default_random_engine generator; - std::vector sorted_scores; - std::vector sorted_indices; - std::vector cumulative_probs; + gsl::span sorted_scores; + gsl::span cumulative_probs; }; class ISequences { @@ -136,6 +135,7 @@ struct IGenerationParameters { float top_p = 0.0f; float filter_value; int seed = 0; + int min_tokens_to_keep = 1; // Parameters from inputs int min_length; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index d0acf6fef0ebd..dbeb18e939aa1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -42,14 +42,14 @@ struct SamplingState : public ISamplingState { this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); std::uniform_real_distribution distribution(0.0, 1.0); - distribution(this->generator); + static_cast(distribution(this->generator)); for (size_t i = 0; i < this->h_sampled_all.size(); ++i) { this->h_sampled_all[i] = distribution(this->generator); } } else { - this->sorted_scores.reserve(total_count); - this->sorted_indices.reserve(total_count); - this->cumulative_probs.reserve(total_count); + // TODO: Some buffer can be reused for CPU + this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, 
SafeInt(total_count)); + this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count)); } } @@ -65,6 +65,8 @@ struct SamplingState : public ISamplingState { BufferUniquePtr h_sampled_all_buffer_; BufferUniquePtr d_indices_buffer_; BufferUniquePtr d_presence_mask_buffer_; + BufferUniquePtr sorted_scores_buffer_; + BufferUniquePtr cumulative_probs_buffer_; }; template diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 5c4f36f60ebca..3b4bab4bc02c3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -130,10 +130,6 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0])); Tensor* output_sequences = this->context_.Output(0, sequences_shape); - int64_t debug_logits_dims[] = {parameters->batch_size, parameters->vocab_size}; - TensorShape debug_logits_shape(&debug_logits_dims[0], sizeof(debug_logits_dims) / sizeof(debug_logits_dims[0])); - Tensor* debug_logits = this->context_.Output(1, debug_logits_shape); - std::vector feeds; std::vector fetches; @@ -240,6 +236,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds } fetches.clear(); } + // Copy the sequences to output gsl::span output = output_sequences->MutableDataAsSpan(); for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) { @@ -250,18 +247,26 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds gsl::copy(sequence_source, batch_output); } - //copy the debug logits to output - if (this->IsCuda() && debug_logits != nullptr) { - gsl::span logits_to_debug = debug_logits->MutableDataAsSpan(); +#ifdef DEBUG_GENERATION + // Debug the one step filtered logits for sampling + int64_t filtered_logits_dims[] = {parameters->batch_size, parameters->vocab_size}; + TensorShape filtered_logits_shape(&filtered_logits_dims[0], + sizeof(filtered_logits_dims) / sizeof(filtered_logits_dims[0])); + Tensor* filtered_logits = this->context_.Output(1, filtered_logits_shape); + if (filtered_logits != nullptr) { + gsl::span filtered_logits_span = filtered_logits->MutableDataAsSpan(); for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) { - auto batch_output = logits_to_debug.subspan( + auto batch_output = filtered_logits_span.subspan( static_cast(batch_id) * parameters->vocab_size, parameters->vocab_size); - gsl::span batch_logits_to_debug = gsl::make_span(sampling_state.h_softmaxed_score.data() + batch_id * parameters->vocab_size, parameters->vocab_size); + gsl::span batch_filtered_logits = gsl::make_span(sampling_state.h_softmaxed_score.data() + + batch_id * parameters->vocab_size, + parameters->vocab_size); - gsl::copy(batch_logits_to_debug, batch_output); + gsl::copy(batch_filtered_logits, batch_output); } } +#endif return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h index f45f22561c359..c192606908bc3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h @@ -6,7 +6,7 @@ namespace onnxruntime { namespace contrib { namespace SamplingCpuHelper { -template +template class TopPSamplingCpu{ public: TopPSamplingCpu(AllocatorPtr& allocator, @@ -26,6 +26,12 @@ class 
TopPSamplingCpu{ private: void filter_scores(std::vector& sorted_indice, gsl::span& next_token_score, size_t index); + void cumulate_and_filter(gsl::span& next_token_scores, + gsl::span& cumulative_probs, + std::vector& sorted_indices); + void cumulate_and_filter_custom(gsl::span& next_token_scores, + gsl::span& cumulative_probs, + std::vector& sorted_indices); AllocatorPtr& allocator_; onnxruntime::concurrency::ThreadPool* thread_pool_; @@ -35,40 +41,74 @@ class TopPSamplingCpu{ const transformers::IConsoleDumper* dumper_; }; -template -void TopPSamplingCpu::filter_scores(std::vector& sorted_indice, +template +void TopPSamplingCpu::filter_scores(std::vector& sorted_indice, gsl::span& next_token_score, size_t index) { size_t real_index = sorted_indice[index]; - next_token_score[real_index] = parameters_->filter_value; + next_token_score[real_index] = (T)parameters_->filter_value; } -template -Status TopPSamplingCpu::Sample(gsl::span& next_token_scores) { - if (parameters_->top_p == 0.0f) { - ORT_THROW("top_p shall be greater than 0"); +template +void TopPSamplingCpu::cumulate_and_filter_custom(gsl::span& next_token_scores, + gsl::span& cumulative_probs, + std::vector& sorted_indices) { + for (size_t i = 0; i < static_cast(parameters_->batch_size); i++) { + size_t offset = i * parameters_->vocab_size; + if (cumulative_probs[offset] > parameters_->top_p) { + filter_scores(sorted_indices, next_token_scores, 1 + offset); + } + for (size_t j = 1; j < static_cast(parameters_->vocab_size) - 1; j++) { + cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; + if (cumulative_probs[j + offset] > parameters_->top_p) { + filter_scores(sorted_indices, next_token_scores, j + offset + 1); + } + } + } +} + +template +void TopPSamplingCpu::cumulate_and_filter(gsl::span& next_token_scores, + gsl::span& cumulative_probs, + std::vector& sorted_indices) { + for (size_t i = 0; i < static_cast(parameters_->batch_size); i++) { + size_t offset = i * parameters_->vocab_size; + if (cumulative_probs[offset] <= 1 - parameters_->top_p) { + filter_scores(sorted_indices, next_token_scores, offset); + } + for (size_t j = 1; j < static_cast(parameters_->vocab_size - parameters_->min_tokens_to_keep); j++) { + cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; + if (cumulative_probs[j + offset] <= 1 - parameters_->top_p) { + filter_scores(sorted_indices, next_token_scores, j + offset); + } + } } +} + +template +Status TopPSamplingCpu::Sample(gsl::span& next_token_scores) { + gsl::span& sorted_scores = sampling_state_->sorted_scores; + memcpy(sorted_scores.data(), next_token_scores.data(), next_token_scores.size_bytes()); + std::vector sorted_indices(parameters_->batch_size * parameters_->vocab_size); - std::vector& sorted_scores = sampling_state_->sorted_scores; - sorted_scores.assign(next_token_scores.begin(), next_token_scores.end()); - // Decending sort - std::vector& sorted_indices = sampling_state_->sorted_indices; + Predicator predicator; + // TODO: This could be optimized with allocated buffer and handwritten sort algorithm for (size_t i = 0; i < static_cast(parameters_->batch_size); i++) { auto indices_begin = sorted_indices.begin() + i * parameters_->vocab_size; auto indices_end = sorted_indices.begin() + (i + 1) * parameters_->vocab_size; std::iota(indices_begin, indices_end, 0); std::sort(indices_begin, indices_end, - [&next_token_scores](size_t i1, size_t i2) { - return next_token_scores[i1] > next_token_scores[i2]; + [&next_token_scores, &predicator](size_t i1, size_t i2) { + return 
!predicator(next_token_scores[i1], next_token_scores[i2]); }); std::sort(sorted_scores.begin() + i * parameters_->vocab_size, - sorted_scores.end() + (i + 1) * parameters_->vocab_size, - std::greater()); + sorted_scores.begin() + (i + 1) * parameters_->vocab_size, + predicator); } - std::vector& cumulative_probs = sampling_state_->cumulative_probs; + gsl::span& cumulative_probs = sampling_state_->cumulative_probs; ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters_->batch_size, parameters_->vocab_size, @@ -77,20 +117,12 @@ Status TopPSamplingCpu::Sample(gsl::span& next_token_scores) { false, thread_pool_)); - for (size_t i = 0; i < static_cast(parameters_->batch_size); i++) { - size_t offset = i * parameters_->vocab_size; - if (cumulative_probs[offset] > parameters_->top_p) { - filter_scores(sorted_indices, next_token_scores, 1 + offset); - } - for (size_t j = 1; j < static_cast(parameters_->vocab_size) - 1; j++) { - cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; - if (cumulative_probs[j + offset] > parameters_->top_p) { - filter_scores(sorted_indices, next_token_scores, j + offset + 1); - } - } + if (std::is_same>::value) { + cumulate_and_filter_custom(next_token_scores, cumulative_probs, sorted_indices); + } else { + cumulate_and_filter(next_token_scores, cumulative_probs, sorted_indices); } - // TODO(wy): This softmax may not be necessary. gsl::span& next_token_probs = sampling_state_->h_softmaxed_score; ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters_->batch_size, parameters_->vocab_size, @@ -136,7 +168,6 @@ Status TopPSamplingCpu::Sample(gsl::span& next_token_scores) { generator, *sampled_idx)); // TODO: update presense_mask() - #ifdef DEBUG_GENERATION dumper_->Print("sampled_idx", *sampled_idx); #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc index 4da912b1bce24..856cc730da478 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -15,6 +15,7 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) { temperature = info.GetAttrOrDefault("temperature", 1.0f); top_p = info.GetAttrOrDefault("top_p", 0.0f); filter_value = info.GetAttrOrDefault("filter_value", -std::numeric_limits::infinity()); + min_tokens_to_keep = static_cast(info.GetAttrOrDefault("min_tokens_to_keep", 1)); presence_penalty = info.GetAttrOrDefault("presence_penalty", 0.0f); } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 3111f080e04e8..09f2f7e69e51d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1123,6 +1123,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Attr("temperature", "temperature for sampling", AttributeProto::FLOAT, 1.0f) .Attr("top_p", "top_p for sampling", AttributeProto::FLOAT, 0.0f) .Attr("filter_value", "filter value for top_p", AttributeProto::FLOAT, -1e20f) + .Attr("min_tokens_to_keep", "min_tokens_to_keep", AttributeProto::INT, static_cast(1)) .Attr("presence_penalty", "presence penalty for sampling", AttributeProto::FLOAT, 0.0f) .Attr("model_type", "model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast(0)) .Attr("encoder", "The subgraph for initialization of encoder and decoder. 
It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) @@ -1136,7 +1137,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") - .Output(1, "logits_before_multinomial", "logits_before_multinomial(debug purpose). Shape is (batch_size, vocab_size)", "T", OpSchema::Optional) + .Output(1, "filtered_logits", "filtered logits as input to the mutinomial function . Shape is (batch_size, vocab_size)", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { From 6d758d9d4664006dd4af7f5faefa16a38b548b1b Mon Sep 17 00:00:00 2001 From: wangyems Date: Sat, 10 Dec 2022 02:41:21 +0000 Subject: [PATCH 36/51] huggingface topp cuda --- .../transformers/generation_device_helper.cc | 15 +- .../cpu/transformers/generation_shared.h | 15 +- .../cpu/transformers/sampling_cpu_helper.h | 169 ++++++------ .../cpu/transformers/sampling_parameters.cc | 3 +- .../cuda/transformers/generation_cuda_impl.cu | 244 ++++++++++++------ .../cuda/transformers/generation_cuda_impl.h | 39 +-- .../transformers/generation_device_helper.cc | 16 +- .../cuda/transformers/sampling_cuda_helper.h | 207 +++++++-------- .../core/graph/contrib_ops/contrib_defs.cc | 3 +- 9 files changed, 387 insertions(+), 324 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index d7b8f60e3f9a6..a5a857e98143d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -456,14 +456,13 @@ Status GreedySearchProcessLogits( #endif if (do_sampling) { - // std::smaller for huggingface version and std::greater for custom version. - SamplingCpuHelper::TopPSamplingCpu> top_p_sampler(allocator, - thread_pool, - sampling_state, - greedy_state, - parameters, - dumper); - ORT_RETURN_IF_ERROR(top_p_sampler.Sample(next_token_scores)); + ORT_RETURN_IF_ERROR(SamplingCpuHelper::Sample(allocator, + thread_pool, + next_token_scores, + sampling_state, + greedy_state, + parameters, + dumper)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 73d83877d6e72..4cc5bf380f2af 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -130,12 +130,6 @@ struct IGenerationParameters { int decoder_start_token_id; int no_repeat_ngram_size; bool early_stopping; - float presence_penalty; - float temperature = 1.0f; - float top_p = 0.0f; - float filter_value; - int seed = 0; - int min_tokens_to_keep = 1; // Parameters from inputs int min_length; @@ -159,6 +153,15 @@ struct IGenerationParameters { int num_heads; int head_size; int num_layers; + + // Parameters for TopK/TopP sampling. 
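The parameters collected under the "Parameters for TopK/TopP sampling" comment (presence_penalty, filter_value, temperature, top_p, seed, min_tokens_to_keep, custom_sampling) drive the CPU sampling helper reworked in this patch. As a reading aid, here is a minimal host-side sketch of the HuggingFace-style (ascending-sort) top-p filter those parameters control, written for a single batch row; the function name TopPFilterRow and the single-row framing are illustrative assumptions for this sketch, not part of the ONNX Runtime sources.

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Sketch only: HuggingFace-style top-p filtering for one row of logits.
void TopPFilterRow(std::vector<float>& logits, float top_p,
                   int min_tokens_to_keep, float filter_value) {
  const size_t vocab = logits.size();
  if (vocab == 0) return;

  // Token ids ordered by ascending logit, matching the convention in this path.
  std::vector<size_t> order(vocab);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](size_t a, size_t b) { return logits[a] < logits[b]; });

  // Softmax over the sorted logits, then a running cumulative probability.
  std::vector<double> probs(vocab);
  const double max_logit = logits[order.back()];
  double denom = 0.0;
  for (size_t i = 0; i < vocab; ++i) {
    probs[i] = std::exp(static_cast<double>(logits[order[i]]) - max_logit);
    denom += probs[i];
  }

  double cumulative = 0.0;
  for (size_t i = 0; i < vocab; ++i) {
    cumulative += probs[i] / denom;
    // Remove the low-probability tail, but never the last min_tokens_to_keep
    // positions (the most probable tokens under the ascending sort).
    if (cumulative <= 1.0 - static_cast<double>(top_p) &&
        i + static_cast<size_t>(min_tokens_to_keep) < vocab) {
      logits[order[i]] = filter_value;
    }
  }
}

In this convention, tokens whose cumulative mass from the bottom of the ranking stays at or below 1 - top_p are pushed to filter_value, so the surviving tokens are the smallest set covering at least top_p of the probability mass, with min_tokens_to_keep tokens always retained.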
+ float presence_penalty; + float filter_value; + float temperature = 1.0f; + float top_p = 0.0f; + int seed = 0; + int min_tokens_to_keep = 1; + bool custom_sampling = false; }; class IConsoleDumper { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h index c192606908bc3..105e2a9a588ef 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h @@ -6,170 +6,151 @@ namespace onnxruntime { namespace contrib { namespace SamplingCpuHelper { -template -class TopPSamplingCpu{ - public: - TopPSamplingCpu(AllocatorPtr& allocator, - onnxruntime::concurrency::ThreadPool* thread_pool, - transformers::ISamplingState* sampling_state, - transformers::IGreedySearchState* greedy_state, - const transformers::IGenerationParameters* parameters, - const transformers::IConsoleDumper* dumper): - allocator_(allocator), - thread_pool_(thread_pool), - sampling_state_(sampling_state), - greedy_state_(greedy_state), - parameters_(parameters), - dumper_(dumper) {} - - Status Sample(gsl::span& next_token_scores); - - private: - void filter_scores(std::vector& sorted_indice, gsl::span& next_token_score, size_t index); - void cumulate_and_filter(gsl::span& next_token_scores, - gsl::span& cumulative_probs, - std::vector& sorted_indices); - void cumulate_and_filter_custom(gsl::span& next_token_scores, - gsl::span& cumulative_probs, - std::vector& sorted_indices); - - AllocatorPtr& allocator_; - onnxruntime::concurrency::ThreadPool* thread_pool_; - transformers::ISamplingState* sampling_state_; - transformers::IGreedySearchState* greedy_state_; - const transformers::IGenerationParameters* parameters_; - const transformers::IConsoleDumper* dumper_; -}; - -template -void TopPSamplingCpu::filter_scores(std::vector& sorted_indice, - gsl::span& next_token_score, - size_t index) { +template +void filter_scores(std::vector& sorted_indice, + gsl::span& next_token_score, + const transformers::IGenerationParameters* parameters, + size_t index) { size_t real_index = sorted_indice[index]; - next_token_score[real_index] = (T)parameters_->filter_value; + next_token_score[real_index] = (T)parameters->filter_value; } -template -void TopPSamplingCpu::cumulate_and_filter_custom(gsl::span& next_token_scores, - gsl::span& cumulative_probs, - std::vector& sorted_indices) { - for (size_t i = 0; i < static_cast(parameters_->batch_size); i++) { - size_t offset = i * parameters_->vocab_size; - if (cumulative_probs[offset] > parameters_->top_p) { - filter_scores(sorted_indices, next_token_scores, 1 + offset); +template +void cumulate_and_filter_custom(gsl::span& next_token_scores, + gsl::span& cumulative_probs, + const transformers::IGenerationParameters* parameters, + std::vector& sorted_indices) { + for (size_t i = 0; i < static_cast(parameters->batch_size); i++) { + size_t offset = i * parameters->vocab_size; + if (cumulative_probs[offset] > parameters->top_p) { + filter_scores(sorted_indices, next_token_scores, parameters, 1 + offset); } - for (size_t j = 1; j < static_cast(parameters_->vocab_size) - 1; j++) { + for (size_t j = 1; j < static_cast(parameters->vocab_size) - 1; j++) { cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; - if (cumulative_probs[j + offset] > parameters_->top_p) { - filter_scores(sorted_indices, next_token_scores, j + offset + 1); + if (cumulative_probs[j + offset] > parameters->top_p) { + filter_scores(sorted_indices, 
next_token_scores, parameters, j + offset + 1); } } } } -template -void TopPSamplingCpu::cumulate_and_filter(gsl::span& next_token_scores, - gsl::span& cumulative_probs, - std::vector& sorted_indices) { - for (size_t i = 0; i < static_cast(parameters_->batch_size); i++) { - size_t offset = i * parameters_->vocab_size; - if (cumulative_probs[offset] <= 1 - parameters_->top_p) { - filter_scores(sorted_indices, next_token_scores, offset); +template +void cumulate_and_filter(gsl::span& next_token_scores, + gsl::span& cumulative_probs, + const transformers::IGenerationParameters* parameters, + std::vector& sorted_indices) { + for (size_t i = 0; i < static_cast(parameters->batch_size); i++) { + size_t offset = i * parameters->vocab_size; + if (cumulative_probs[offset] <= 1 - parameters->top_p) { + filter_scores(sorted_indices, next_token_scores, parameters, offset); } - for (size_t j = 1; j < static_cast(parameters_->vocab_size - parameters_->min_tokens_to_keep); j++) { + for (size_t j = 1; j < static_cast(parameters->vocab_size - parameters->min_tokens_to_keep); j++) { cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; - if (cumulative_probs[j + offset] <= 1 - parameters_->top_p) { - filter_scores(sorted_indices, next_token_scores, j + offset); + if (cumulative_probs[j + offset] <= 1 - parameters->top_p) { + filter_scores(sorted_indices, next_token_scores, parameters, j + offset); } } } } -template -Status TopPSamplingCpu::Sample(gsl::span& next_token_scores) { - gsl::span& sorted_scores = sampling_state_->sorted_scores; +template +Status Sample(AllocatorPtr& allocator, + onnxruntime::concurrency::ThreadPool* thread_pool, + gsl::span& next_token_scores, + transformers::ISamplingState* sampling_state, + transformers::IGreedySearchState* greedy_state, + const transformers::IGenerationParameters* parameters, + const transformers::IConsoleDumper* dumper) { + ORT_UNUSED_PARAMETER(dumper); + + gsl::span& sorted_scores = sampling_state->sorted_scores; memcpy(sorted_scores.data(), next_token_scores.data(), next_token_scores.size_bytes()); - std::vector sorted_indices(parameters_->batch_size * parameters_->vocab_size); + std::vector sorted_indices(parameters->batch_size * parameters->vocab_size); - Predicator predicator; + std::function predicator; + if (parameters->custom_sampling) { + predicator = std::greater(); + } else { + predicator = std::less(); + } // TODO: This could be optimized with allocated buffer and handwritten sort algorithm - for (size_t i = 0; i < static_cast(parameters_->batch_size); i++) { - auto indices_begin = sorted_indices.begin() + i * parameters_->vocab_size; - auto indices_end = sorted_indices.begin() + (i + 1) * parameters_->vocab_size; + for (size_t i = 0; i < static_cast(parameters->batch_size); i++) { + auto indices_begin = sorted_indices.begin() + i * parameters->vocab_size; + auto indices_end = sorted_indices.begin() + (i + 1) * parameters->vocab_size; std::iota(indices_begin, indices_end, 0); std::sort(indices_begin, indices_end, [&next_token_scores, &predicator](size_t i1, size_t i2) { return !predicator(next_token_scores[i1], next_token_scores[i2]); }); - std::sort(sorted_scores.begin() + i * parameters_->vocab_size, - sorted_scores.begin() + (i + 1) * parameters_->vocab_size, + std::sort(sorted_scores.begin() + i * parameters->vocab_size, + sorted_scores.begin() + (i + 1) * parameters->vocab_size, predicator); } - gsl::span& cumulative_probs = sampling_state_->cumulative_probs; + gsl::span& cumulative_probs = sampling_state->cumulative_probs; - 
ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters_->batch_size, - parameters_->vocab_size, + ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters->batch_size, + parameters->vocab_size, sorted_scores.data(), cumulative_probs.data(), false, - thread_pool_)); + thread_pool)); - if (std::is_same>::value) { - cumulate_and_filter_custom(next_token_scores, cumulative_probs, sorted_indices); + if (parameters->custom_sampling) { + cumulate_and_filter_custom(next_token_scores, cumulative_probs, parameters, sorted_indices); } else { - cumulate_and_filter(next_token_scores, cumulative_probs, sorted_indices); + cumulate_and_filter(next_token_scores, cumulative_probs, parameters, sorted_indices); } - gsl::span& next_token_probs = sampling_state_->h_softmaxed_score; - ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters_->batch_size, - parameters_->vocab_size, + gsl::span& next_token_probs = sampling_state->h_softmaxed_score; + ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters->batch_size, + parameters->vocab_size, next_token_scores.data(), next_token_probs.data(), false, - thread_pool_)); + thread_pool)); // torch.multinomial() - int64_t next_token_probs_dims[] = {static_cast(parameters_->batch_size), parameters_->vocab_size}; + int64_t next_token_probs_dims[] = {static_cast(parameters->batch_size), parameters->vocab_size}; TensorShape next_token_probs_shape(&next_token_probs_dims[0], 2); auto element_type = DataTypeImpl::GetType(); OrtValue next_token_probs_value; Tensor::InitOrtValue(element_type, next_token_probs_shape, next_token_probs.data(), - allocator_->Info(), + allocator->Info(), next_token_probs_value); const Tensor& input = next_token_probs_value.Get(); - std::default_random_engine& generator = sampling_state_->generator; + std::default_random_engine& generator = sampling_state->generator; - int64_t sampled_idx_dims[] = {static_cast(parameters_->batch_size), 1}; + int64_t sampled_idx_dims[] = {static_cast(parameters->batch_size), 1}; TensorShape sampled_idx_shape(&sampled_idx_dims[0], 2); - gsl::span& next_token_idx = greedy_state_->next_tokens_cpu; + gsl::span& next_token_idx = greedy_state->next_tokens_cpu; OrtValue sampled_idx_ov; Tensor::InitOrtValue(DataTypeImpl::GetType(), sampled_idx_shape, next_token_idx.data(), - allocator_->Info(), + allocator->Info(), sampled_idx_ov); Tensor* sampled_idx = sampled_idx_ov.GetMutable(); // Copy the allocator because MultinomialComputeShared() uses move(allocator) - AllocatorPtr allocator_temp = allocator_; - ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocator_temp, + AllocatorPtr allocatortemp = allocator; + ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocatortemp, input, - parameters_->batch_size, - parameters_->vocab_size, + parameters->batch_size, + parameters->vocab_size, 1, generator, *sampled_idx)); // TODO: update presense_mask() #ifdef DEBUG_GENERATION - dumper_->Print("sampled_idx", *sampled_idx); + dumper->Print("sampled_idx", *sampled_idx); #endif return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc index 856cc730da478..b85854f32c04a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -15,8 +15,9 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) { temperature = info.GetAttrOrDefault("temperature", 1.0f); top_p = info.GetAttrOrDefault("top_p", 0.0f); filter_value = info.GetAttrOrDefault("filter_value", -std::numeric_limits::infinity()); - 
min_tokens_to_keep = static_cast(info.GetAttrOrDefault("min_tokens_to_keep", 1)); + min_tokens_to_keep = static_cast(info.GetAttrOrDefault("min_tokens_to_keep", 0)); presence_penalty = info.GetAttrOrDefault("presence_penalty", 0.0f); + custom_sampling = static_cast(info.GetAttrOrDefault("custom", 0)); } } // namespace transformers diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index d0b73f11ae71f..890958fa6d6c7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -301,15 +301,16 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data, old_mask_data, mask_data, next_positions, batch_beam_size, current_length); } -// TODO: merge those kernels into one template -size_t GetTempStorageSize(const T *d_keys_in, - const int* d_values_in, - int* d_offsets, - int num_items, - int num_segments, - cudaStream_t stream) { - size_t temp_storage_bytes = 0; +void GetTempStorageSize(const T *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes) { + if (is_descending) { cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, d_keys_in, @@ -323,24 +324,42 @@ size_t GetTempStorageSize(const T *d_keys_in, 0, sizeof(T) * 8, stream); - return temp_storage_bytes; + } else { + cub::DeviceSegmentedRadixSort::SortPairs(nullptr, + temp_storage_bytes, + d_keys_in, + (T*)nullptr, + d_values_in, + (int*)nullptr, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); + } } -template size_t GetTempStorageSize( +template void GetTempStorageSize( const float *d_keys_in, const int* d_values_in, int* d_offsets, int num_items, int num_segments, - cudaStream_t stream); + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes); -template size_t GetTempStorageSize( +template void GetTempStorageSize( const half *d_keys_in, const int* d_values_in, int* d_offsets, int num_items, int num_segments, - cudaStream_t stream); + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes); // TODO: merge to one kernel __global__ void SetupParamsKernel(int* d_values_in, @@ -372,7 +391,7 @@ void LaunchSetupParamsKernel(int* d_values_in, } template -void LaunchSortPairsDescending(void *d_temp_storage, +void LaunchSortPairs(void *d_temp_storage, size_t temp_storage_bytes, const T *d_keys_in, T *d_keys_out, @@ -381,51 +400,106 @@ void LaunchSortPairsDescending(void *d_temp_storage, int num_items, int num_segments, int *d_offsets, - cudaStream_t stream) { - cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, - temp_storage_bytes, - d_keys_in, - d_keys_out, - d_values_in, - d_values_out, - num_items, - num_segments, - d_offsets, - d_offsets + 1, - 0, - sizeof(T) * 8, - stream); + cudaStream_t stream, + bool is_descending) { + if (is_descending) { + cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); + } else { + cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); + } +} + 
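LaunchSortPairs above wraps cub::DeviceSegmentedRadixSort so that every batch row of the [batch_size, vocab_size] score buffer is sorted independently, ascending for the HuggingFace-style filter and descending when custom sampling is requested. The following host-side C++ sketch shows the same per-segment behaviour in serial form; SegmentedSortPairs and its parameter names are invented for illustration and are not ORT APIs.

#include <algorithm>
#include <numeric>
#include <vector>

// Sketch only: per-batch-row ("segmented") sort of scores with paired token ids.
void SegmentedSortPairs(const std::vector<float>& keys_in,  // batch_size * vocab_size scores
                        std::vector<float>& keys_out,       // sorted scores, same layout
                        std::vector<int>& values_out,       // token ids reordered with the scores
                        int batch_size, int vocab_size, bool is_descending) {
  keys_out.assign(keys_in.size(), 0.0f);
  values_out.assign(keys_in.size(), 0);

  for (int b = 0; b < batch_size; ++b) {
    const size_t offset = static_cast<size_t>(b) * vocab_size;
    int* ids = values_out.data() + offset;
    std::iota(ids, ids + vocab_size, 0);

    // Order the token ids of this segment by their scores.
    std::sort(ids, ids + vocab_size, [&](int lhs, int rhs) {
      const float sl = keys_in[offset + lhs];
      const float sr = keys_in[offset + rhs];
      return is_descending ? sl > sr : sl < sr;
    });

    // Emit the scores in the same order, like the sorted keys from CUB.
    for (int j = 0; j < vocab_size; ++j) {
      keys_out[offset + j] = keys_in[offset + ids[j]];
    }
  }
}

On the GPU the equivalent work happens in one call: the scratch size is queried once at step 1 via GetTempStorageSize, the buffer is cached in the sampling state, and the segment boundaries are the d_offset entries placed at multiples of vocab_size by LaunchSetupParamsKernel.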
+template void LaunchSortPairs(void *d_temp_storage, + size_t temp_storage_bytes, + const float *d_keys_in, + float *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream, + bool is_descending); + +template void LaunchSortPairs(void *d_temp_storage, + size_t temp_storage_bytes, + const half *d_keys_in, + half *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream, + bool is_descending); + +template +__global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in, + const int* d_sorted_indices, + T* d_logits_in_out, + float top_p_threshold, + float filter_value, + int batch_size, + int vocab_size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= batch_size * vocab_size) { + return; + } + + int vocab_idx = index % vocab_size; + int batch_id = index / vocab_size; + int start_index = batch_id * vocab_size; + + int count = vocab_idx; + float sum = 0.0f; + while (count >= 0) { + sum += d_sorted_logits_in[start_index]; + ++start_index; + --count; + } + + if (sum > top_p_threshold) { + // Shift the indices to the right by one according to the custom implementation. + int shifted_index = index + 1; + if (shifted_index % vocab_size != 0) { + int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index]; + d_logits_in_out[original_index] = (T)filter_value; + } + } } -template void LaunchSortPairsDescending(void *d_temp_storage, - size_t temp_storage_bytes, - const float *d_keys_in, - float *d_keys_out, - const int *d_values_in, - int *d_values_out, - int num_items, - int num_segments, - int *d_offsets, - cudaStream_t stream); - -template void LaunchSortPairsDescending(void *d_temp_storage, - size_t temp_storage_bytes, - const half *d_keys_in, - half *d_keys_out, - const int *d_values_in, - int *d_values_out, - int num_items, - int num_segments, - int *d_offsets, - cudaStream_t stream); - -// A trick here: cumuliative sum of the sorted logits is a temporarily variable in the kernel. template __global__ void FilterLogitsKernel(float* d_sorted_logits_in, const int* d_sorted_indices, T* d_logits_in_out, - float top_p, + float top_p_threshold, float filter_value, + int min_tokens_to_keep, int batch_size, int vocab_size) { int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -446,11 +520,9 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in, --count; } - if (sum > top_p) { - // Shift the indices to the right by one according to the Turing implementation. 
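For the custom (descending-sort) path, FilterLogitsKernelCustom lets every thread recompute the cumulative probability of its sorted row up to its own position and, once that mass exceeds top_p, masks the token one position further down the ranking, so the token that crosses the threshold is itself kept. A serial host-side sketch of the same rule for one batch row follows; FilterRowCustom and its arguments are illustrative names only, not the kernel's interface.

#include <vector>

// Sketch only: the custom (descending) top-p rule for one row. sorted_probs is
// the softmaxed row sorted descending; sorted_indices maps a sorted position
// back to the original token id; logits is the unsorted row, edited in place.
void FilterRowCustom(const std::vector<float>& sorted_probs,
                     const std::vector<int>& sorted_indices,
                     std::vector<float>& logits,
                     float top_p, float filter_value) {
  float cumulative = 0.0f;
  for (size_t i = 0; i < sorted_probs.size(); ++i) {
    cumulative += sorted_probs[i];
    // Once the running mass passes top_p, everything ranked after position i
    // is masked; the shift by one keeps the token that crossed the threshold.
    if (cumulative > top_p && i + 1 < sorted_probs.size()) {
      logits[sorted_indices[i + 1]] = filter_value;
    }
  }
}

In the kernel, each thread rebuilds its own prefix sum with an O(vocab_size) loop, which avoids an explicit scan at the cost of redundant per-thread work.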
- int shifted_index = index + 1; - if (shifted_index % vocab_size != 0) { - int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index]; + if (sum <= top_p_threshold) { + if (index % vocab_size + min_tokens_to_keep < vocab_size) { + int original_index = batch_id * vocab_size + d_sorted_indices[index]; d_logits_in_out[original_index] = (T)filter_value; } } @@ -462,19 +534,32 @@ void LaunchFilterLogitsKernel(float* d_sorted_logits_in, T* d_logits_in_out, float top_p, float filter_value, + int min_tokens_to_keep, int batch_size, int vocab_size, - cudaStream_t stream) { + cudaStream_t stream, + bool is_descending) { int total_elements = batch_size * vocab_size; constexpr int blockSize = 256; const int gridSize = (total_elements + blockSize - 1) / blockSize; - FilterLogitsKernel<<>>(d_sorted_logits_in, - d_sorted_indices, - d_logits_in_out, - top_p, - filter_value, - batch_size, - vocab_size); + if (is_descending) { + FilterLogitsKernelCustom<<>>(d_sorted_logits_in, + d_sorted_indices, + d_logits_in_out, + top_p, + filter_value, + batch_size, + vocab_size); + } else { + FilterLogitsKernel<<>>(d_sorted_logits_in, + d_sorted_indices, + d_logits_in_out, + 1 - top_p, + filter_value, + min_tokens_to_keep, + batch_size, + vocab_size); + } } template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, @@ -482,32 +567,34 @@ template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, float* d_logits_in_out, float top_p, float filter_value, + int min_tokens_to_keep, int batch_size, int vocab_size, - cudaStream_t stream); + cudaStream_t stream, + bool is_descending); template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, const int* d_sorted_indices, half* d_logits_in_out, float top_p, float filter_value, + int min_tokens_to_keep, int batch_size, int vocab_size, - cudaStream_t stream); + cudaStream_t stream, + bool is_descending); // Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu template -__global__ void sampleMultinomialOnce( - int64_t* dest, - int distributions, - int categories, - scalar_t* sampled, - scalar_t* dist, - int stride_dist, // dist->stride(0) - int stride_categories, // dist->stride(1) - int* d_presence_mask -) { +__global__ void sampleMultinomialOnce(int64_t* dest, + int distributions, + int categories, + scalar_t* sampled, + scalar_t* dist, + int stride_dist, // dist->stride(0) + int stride_categories, // dist->stride(1) + int* d_presence_mask) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -576,8 +663,7 @@ __global__ void sampleMultinomialOnce( } __syncthreads(); } - // Each thread will check to see if the sample falls in its - // bucket + // Each thread will check to see if the sample falls in its bucket scalar_t curBucket = static_cast(smem[threadIdx.x] + prevHighProb); scalar_t prevBucket = static_cast( diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 193f25b7908c7..01a49b8e2ce9d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -60,12 +60,14 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data, cudaStream_t stream); template -size_t GetTempStorageSize(const T *d_keys_in, - const int* d_values_in, - int* d_offsets, - int num_items, - int num_segments, - cudaStream_t stream); +void GetTempStorageSize(const T *d_keys_in, + 
const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes); void LaunchSetupParamsKernel(int* d_values_in, int* d_offsets, @@ -74,16 +76,17 @@ void LaunchSetupParamsKernel(int* d_values_in, cudaStream_t stream); template -void LaunchSortPairsDescending(void *d_temp_storage, - size_t temp_storage_bytes, - const T *d_keys_in, - T *d_keys_out, - const int *d_values_in, - int *d_values_out, - int num_items, - int num_segments, - int *d_offsets, - cudaStream_t stream); +void LaunchSortPairs(void *d_temp_storage, + size_t temp_storage_bytes, + const T *d_keys_in, + T *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream, + bool is_descending); template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, @@ -91,9 +94,11 @@ void LaunchFilterLogitsKernel(float* d_sorted_logits_in, T* d_logits_in_out, float top_p, float filter_value, + int min_tokens_to_keep, int batch_size, int vocab_size, - cudaStream_t stream); + cudaStream_t stream, + bool is_descending); void TorchMultinomialKernelLauncher(float* d_input, float* d_sampled, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 884f80ca3a1e4..de428131371d0 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -564,14 +564,14 @@ Status GreedySearchProcessLogits( ORT_UNUSED_PARAMETER(output_scores); if (do_sampling) { - // TODO: Move the ctor out of the function. - SamplingCudaHelper::TopPSamplingCuda top_p_sampler(allocator, - cuda_stream, - sampling_state, - greedy_state, - parameters, - dumper); - ORT_RETURN_IF_ERROR(top_p_sampler.Sample(step, next_token_scores)); + ORT_RETURN_IF_ERROR(SamplingCudaHelper::Sample(allocator, + cuda_stream, + next_token_scores, + sampling_state, + greedy_state, + parameters, + step, + dumper)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h index 52939efc5d866..092612e674967 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -16,180 +16,167 @@ namespace contrib { namespace SamplingCudaHelper { template -class TopPSamplingCuda{ - public: - TopPSamplingCuda(AllocatorPtr& allocator, - cudaStream_t cuda_stream, - transformers::ISamplingState* sampling_state, - transformers::IGreedySearchState* greedy_state, - const transformers::IGenerationParameters* parameters, - const transformers::IConsoleDumper* dumper): - allocator_(allocator), - cuda_stream_(cuda_stream), - sampling_state_(sampling_state), - greedy_state_(greedy_state), - parameters_(parameters), - dumper_(dumper) {} - - Status Sample(int step, gsl::span& next_token_scores); - - private: - AllocatorPtr& allocator_; - cudaStream_t cuda_stream_; - transformers::ISamplingState* sampling_state_; - transformers::IGreedySearchState* greedy_state_; - const transformers::IGenerationParameters* parameters_; - const transformers::IConsoleDumper* dumper_; -}; - -template -Status TopPSamplingCuda::Sample(int step, gsl::span& next_token_scores) { +Status Sample(AllocatorPtr& allocator, + cudaStream_t cuda_stream, + gsl::span& next_token_scores, + 
transformers::ISamplingState* sampling_state, + transformers::IGreedySearchState* greedy_state, + const transformers::IGenerationParameters* parameters, + int step, + const transformers::IConsoleDumper* dumper) { + ORT_UNUSED_PARAMETER(dumper); typedef typename ToCudaType::MappedType CudaT; - gsl::span& d_index_in = sampling_state_->d_index_in; - gsl::span& d_offset = sampling_state_->d_offset; + gsl::span& d_index_in = sampling_state->d_index_in; + gsl::span& d_offset = sampling_state->d_offset; + + BufferUniquePtr& storage_buffer = sampling_state->storage_buffer; + size_t& temp_storage_bytes = sampling_state->temp_storage_bytes; - BufferUniquePtr& storage_buffer = sampling_state_->storage_buffer; - size_t& temp_storage_bytes = sampling_state_->temp_storage_bytes; + bool is_descending = parameters->custom_sampling; if (step == 1) { - temp_storage_bytes = cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), - d_index_in.data(), - d_offset.data(), - parameters_->batch_size * parameters_->vocab_size, - parameters_->batch_size, - cuda_stream_); + cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), + d_index_in.data(), + d_offset.data(), + parameters->batch_size * parameters->vocab_size, + parameters->batch_size, + cuda_stream, + is_descending, + temp_storage_bytes); cuda::LaunchSetupParamsKernel(d_index_in.data(), d_offset.data(), - parameters_->batch_size, - parameters_->vocab_size, - cuda_stream_); + parameters->batch_size, + parameters->vocab_size, + cuda_stream); #ifdef DEBUG_GENERATION - dumper_->Print("d_offset_buffer", d_offset.data(), parameters_->batch_size + 1, 1); + dumper->Print("d_offset_buffer", d_offset.data(), parameters->batch_size + 1, 1); #endif - void* temp_storage = allocator_->Alloc(sampling_state_->temp_storage_bytes); - BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator_)); + void* temp_storage = allocator->Alloc(sampling_state->temp_storage_bytes); + BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); storage_buffer = std::move(temp_storage_buffer); } - gsl::span& d_sorted_score = sampling_state_->d_sorted_score; - gsl::span& d_index_out = sampling_state_->d_index_out; + gsl::span& d_sorted_score = sampling_state->d_sorted_score; + gsl::span& d_index_out = sampling_state->d_index_out; #ifdef DEBUG_GENERATION - dumper_->Print("temp_storage_bytes", sampling_state_->temp_storage_bytes, true); + dumper->Print("temp_storage_bytes", sampling_state->temp_storage_bytes, true); #endif - cuda::LaunchSortPairsDescending(storage_buffer.get(), - temp_storage_bytes, - reinterpret_cast(next_token_scores.data()), - reinterpret_cast(d_sorted_score.data()), - d_index_in.data(), - d_index_out.data(), - parameters_->batch_size * parameters_->vocab_size, - parameters_->batch_size, - d_offset.data(), - cuda_stream_); + cuda::LaunchSortPairs(storage_buffer.get(), + temp_storage_bytes, + reinterpret_cast(next_token_scores.data()), + reinterpret_cast(d_sorted_score.data()), + d_index_in.data(), + d_index_out.data(), + parameters->batch_size * parameters->vocab_size, + parameters->batch_size, + d_offset.data(), + cuda_stream, + is_descending); #ifdef DEBUG_GENERATION - dumper_->Print("d_sorted_score_buffer", + dumper->Print("d_sorted_score_buffer", reinterpret_cast(d_sorted_score.data()), - parameters_->batch_size, - parameters_->vocab_size); - dumper_->Print("d_index_buffer_in", d_index_in.data(), parameters_->batch_size, parameters_->vocab_size); - dumper_->Print("d_index_buffer_out", d_index_out.data(), 
parameters_->batch_size, parameters_->vocab_size); + parameters->batch_size, + parameters->vocab_size); + dumper->Print("d_index_buffer_in", d_index_in.data(), parameters->batch_size, parameters->vocab_size); + dumper->Print("d_index_buffer_out", d_index_out.data(), parameters->batch_size, parameters->vocab_size); #endif - gsl::span& d_sorted_softmaxed_score = sampling_state_->d_sorted_softmaxed_score; - dispatch_blockwise_softmax_forward(cuda_stream_, + gsl::span& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score; + dispatch_blockwise_softmax_forward(cuda_stream, d_sorted_softmaxed_score.data(), reinterpret_cast(d_sorted_score.data()), - parameters_->vocab_size, - parameters_->vocab_size, - parameters_->vocab_size, - parameters_->batch_size); + parameters->vocab_size, + parameters->vocab_size, + parameters->vocab_size, + parameters->batch_size); #ifdef DEBUG_GENERATION - dumper_->Print("d_sorted_softmaxed_score_buffer", + dumper->Print("d_sorted_softmaxed_score_buffer", d_sorted_softmaxed_score.data(), - parameters_->batch_size, - parameters_->vocab_size); + parameters->batch_size, + parameters->vocab_size); #endif cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score.data(), d_index_out.data(), reinterpret_cast(next_token_scores.data()), - parameters_->top_p, - parameters_->filter_value, - parameters_->batch_size, - parameters_->vocab_size, - cuda_stream_); + parameters->top_p, + parameters->filter_value, + parameters->min_tokens_to_keep, + parameters->batch_size, + parameters->vocab_size, + cuda_stream, + is_descending); #ifdef DEBUG_GENERATION - dumper_->Print("next_token_scores after filtering logits", + dumper->Print("next_token_scores after filtering logits", reinterpret_cast(next_token_scores.data()), - parameters_->batch_size, - parameters_->vocab_size); + parameters->batch_size, + parameters->vocab_size); #endif - // TODO(wy): Can we only do softmax at the very beginning and sort the softmaxed scores. 
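After filtering, the surviving logits are softmaxed again into d_softmaxed_score and one token is drawn per batch row. The draw consumes a single uniform sample per row, pre-generated on the host into h_sampled_all and copied to d_sampled for the current step, and selects the first token whose cumulative probability exceeds that sample; sampleMultinomialOnce performs this search block-wide on the GPU. A one-row host sketch of the same inverse-CDF draw is shown here; SampleFromRow is an illustrative name, not ORT code.

#include <vector>

// Sketch only: inverse-CDF ("multinomial") draw for one row of probabilities.
int SampleFromRow(const std::vector<float>& probs, float u) {
  float cumulative = 0.0f;
  for (size_t token = 0; token < probs.size(); ++token) {
    cumulative += probs[token];
    if (u < cumulative) {
      return static_cast<int>(token);
    }
  }
  // Guard against floating-point drift when u is very close to 1.
  return static_cast<int>(probs.size()) - 1;
}

The chosen indices and the softmaxed scores are then copied back to the host buffers (next_tokens_cpu and h_softmaxed_score) with a stream synchronize before they are read.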
- gsl::span& d_softmaxed_score = sampling_state_->d_softmaxed_score; - dispatch_blockwise_softmax_forward(cuda_stream_, + gsl::span& d_softmaxed_score = sampling_state->d_softmaxed_score; + dispatch_blockwise_softmax_forward(cuda_stream, d_softmaxed_score.data(), reinterpret_cast(next_token_scores.data()), - parameters_->vocab_size, - parameters_->vocab_size, - parameters_->vocab_size, - parameters_->batch_size); + parameters->vocab_size, + parameters->vocab_size, + parameters->vocab_size, + parameters->batch_size); #ifdef DEBUG_GENERATION - dumper_->Print("d_softmaxed_score_buffer", + dumper->Print("d_softmaxed_score_buffer", d_softmaxed_score.data(), - parameters_->batch_size, - parameters_->vocab_size); + parameters->batch_size, + parameters->vocab_size); #endif // Multinomial sampling - gsl::span& d_sampled = sampling_state_->d_sampled; - gsl::span& h_sampled_all = sampling_state_->h_sampled_all; + gsl::span& d_sampled = sampling_state->d_sampled; + gsl::span& h_sampled_all = sampling_state->h_sampled_all; CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(), - h_sampled_all.data() + (step - 1) * parameters_->batch_size, - sizeof(float) * parameters_->batch_size, + h_sampled_all.data() + (step - 1) * parameters->batch_size, + sizeof(float) * parameters->batch_size, cudaMemcpyHostToDevice, - cuda_stream_)); + cuda_stream)); #ifdef DEBUG_GENERATION - dumper_->Print("d_sampled", d_sampled.data(), parameters_->batch_size, 1); + dumper->Print("d_sampled", d_sampled.data(), parameters->batch_size, 1); #endif - gsl::span& d_indices = sampling_state_->d_indices; - gsl::span& presence_mask = sampling_state_->d_presence_mask; + gsl::span& d_indices = sampling_state->d_indices; + gsl::span& presence_mask = sampling_state->d_presence_mask; cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(), d_sampled.data(), d_indices.data(), - parameters_->batch_size, - parameters_->vocab_size, + parameters->batch_size, + parameters->vocab_size, presence_mask.data(), - cuda_stream_); + cuda_stream); #ifdef DEBUG_GENERATION - dumper_->Print("d_indices", d_indices.data(), parameters_->batch_size, 1); + dumper->Print("d_indices", d_indices.data(), parameters->batch_size, 1); #endif - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state_->next_tokens_cpu.data(), - sampling_state_->d_indices.data(), - greedy_state_->next_tokens_cpu.size_bytes(), + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), + sampling_state->d_indices.data(), + greedy_state->next_tokens_cpu.size_bytes(), cudaMemcpyDeviceToHost, - cuda_stream_)); + cuda_stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state_->h_softmaxed_score.data(), - sampling_state_->d_softmaxed_score.data(), - sampling_state_->h_softmaxed_score.size_bytes(), + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state->h_softmaxed_score.data(), + sampling_state->d_softmaxed_score.data(), + sampling_state->h_softmaxed_score.size_bytes(), cudaMemcpyDeviceToHost, - cuda_stream_)); + cuda_stream)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream_)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); return Status::OK(); } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 09f2f7e69e51d..0e08e8d629b17 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1123,8 +1123,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Attr("temperature", "temperature for sampling", AttributeProto::FLOAT, 
1.0f) .Attr("top_p", "top_p for sampling", AttributeProto::FLOAT, 0.0f) .Attr("filter_value", "filter value for top_p", AttributeProto::FLOAT, -1e20f) - .Attr("min_tokens_to_keep", "min_tokens_to_keep", AttributeProto::INT, static_cast(1)) + .Attr("min_tokens_to_keep", "min_tokens_to_keep", AttributeProto::INT, static_cast(0)) .Attr("presence_penalty", "presence penalty for sampling", AttributeProto::FLOAT, 0.0f) + .Attr("custom", "absence penalty for sampling", AttributeProto::INT, static_cast(0)) .Attr("model_type", "model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast(0)) .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) From dc32745144c5bb559ee1e518c04174c42161fbbb Mon Sep 17 00:00:00 2001 From: wangyems Date: Mon, 12 Dec 2022 22:16:20 +0000 Subject: [PATCH 37/51] refactor --- .../cpu/transformers/greedy_search_impl_gpt.h | 2 +- .../cpu/transformers/sampling_parameters.cc | 1 + .../transformers/convert_generation_temp.py | 1590 ----------------- 3 files changed, 2 insertions(+), 1591 deletions(-) delete mode 100644 onnxruntime/python/tools/transformers/convert_generation_temp.py diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 3b4bab4bc02c3..d4c8c1590fe1c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -200,7 +200,6 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds ExecutionMode::ORT_SEQUENTIAL, this->context_.GetTerminateFlag(), this->context_.Logger()); - ORT_RETURN_IF_ERROR(status); const OrtValue& logits = fetches[0]; @@ -211,6 +210,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds sampling_state, iteration_counter, parameters->eos_token_id)); + // When all batches are finished, stop earlier to avoid wasting computation. gsl::span& eos_meet = greedy_state.eos_meet; size_t batch_id = 0; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc index b85854f32c04a..537b0d7538ce7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -18,6 +18,7 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) { min_tokens_to_keep = static_cast(info.GetAttrOrDefault("min_tokens_to_keep", 0)); presence_penalty = info.GetAttrOrDefault("presence_penalty", 0.0f); custom_sampling = static_cast(info.GetAttrOrDefault("custom", 0)); + vocab_size = static_cast(info.GetAttrOrDefault("vocab_size", -1)); } } // namespace transformers diff --git a/onnxruntime/python/tools/transformers/convert_generation_temp.py b/onnxruntime/python/tools/transformers/convert_generation_temp.py deleted file mode 100644 index a22e3f67e2e93..0000000000000 --- a/onnxruntime/python/tools/transformers/convert_generation_temp.py +++ /dev/null @@ -1,1590 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
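As a point of reference for the Sampling schema change above, a node with these attributes could be assembled from Python roughly as follows. This is a hedged sketch, not part of the patch: the token ids and attribute values are placeholders, and the decoder subgraph is an empty stand-in for a real GPT-2 decoder graph.

from onnx import helper

# Placeholder decoder subgraph; a real export would pass the exported GPT-2 decoder graph here.
decoder_graph = helper.make_graph([], "decoder", [], [])

node = helper.make_node(
    "Sampling",
    inputs=["input_ids", "max_length", "min_length", "repetition_penalty"],
    outputs=["sequences"],
    name="Sampling_gpt2",
    domain="com.microsoft",
)
node.attribute.extend([
    helper.make_attribute("eos_token_id", 50256),
    helper.make_attribute("pad_token_id", 50256),
    helper.make_attribute("model_type", 0),            # 0: decoder-only models such as GPT-2
    helper.make_attribute("temperature", 1.0),
    helper.make_attribute("top_p", 0.6),
    helper.make_attribute("filter_value", -1e20),
    helper.make_attribute("min_tokens_to_keep", 1),
    helper.make_attribute("presence_penalty", 0.0),
    helper.make_attribute("custom", 0),                # attribute added by this commit
    helper.make_attribute("decoder", decoder_graph),
])
print(node)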
-# ------------------------------------------------------------------------- -""" -This converts GPT2 or T5 model to onnx with beam search operator. - -Example 1: convert gpt2 model with beam search: - python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx - -Example 2: convert T5 model with beam search in two steps: - cd ./models/t5 - python convert_to_onnx.py -m t5-small - cd ../.. - python convert_generation.py -m t5-small --model_type t5 \ - --decoder_onnx ./models/t5/onnx_models/t5-small_decoder.onnx \ - --encoder_decoder_init_onnx ./models/t5/onnx_models/t5-small_encoder_decoder_init.onnx \ - --output ./models/t5/onnx_models/t5_small_beam_search.onnx - -Example 3: convert T5 model with beam search. All in one step: - python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx - -Example 4: convert MT5 model with external data file like mt5-base-beamsearch.onnx.data in below example. - python convert_generation.py -m google/mt5-base --model_type mt5 --output mt5-base-beamsearch.onnx -e - -Example 5: convert gpt2 model with sampling: - python convert_generation_temp.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 -""" - -import argparse -import logging -import os -import sys -import time -from enum import Enum -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import numpy as np -import onnx -import torch -from benchmark_helper import Precision -from onnx import GraphProto, ModelProto, TensorProto -from transformers import ( - GPT2Config, - GPT2LMHeadModel, - GPT2Tokenizer, - MT5Config, - MT5ForConditionalGeneration, - T5Config, - T5ForConditionalGeneration, - T5Tokenizer, -) - -from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers - -sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2")) -from gpt2_helper import PRETRAINED_GPT2_MODELS # noqa: E402 -from models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx # noqa: E402 - -sys.path.append(os.path.join(os.path.dirname(__file__), "models", "t5")) -from benchmark_helper import setup_logger -from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models # noqa: E402 -from models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS # noqa: E402 -from onnx_model import OnnxModel - -logger = logging.getLogger("") - - -class GenerationType(Enum): - BEAMSEARCH = "beam_search" - GREEDYSEARCH = "greedy_search" - - def __str__(self): - return self.value - - -def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: - """Parse arguments - - Args: - argv (Optional[List[str]], optional): _description_. Defaults to None. - - Returns: - argparse.Namespace: Parsed arguments. 
- """ - parser = argparse.ArgumentParser() - - input_group = parser.add_argument_group("Input options") - - input_group.add_argument( - "-m", - "--model_name_or_path", - required=True, - type=str, - help="Pytorch model checkpoint path, or pretrained model name in the list: " - + ", ".join(PRETRAINED_GPT2_MODELS + PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS), - ) - - input_group.add_argument( - "--model_type", - required=False, - type=str, - default="gpt2", - choices=["gpt2", "t5", "mt5"], - help="Model type (default is gpt2) in the list: " + ", ".join(["gpt2", "t5", "mt5"]), - ) - - input_group.add_argument( - "--cache_dir", - required=False, - type=str, - default=os.path.join(".", "cache_models"), - help="Directory to cache pre-trained models", - ) - - input_group.add_argument( - "--decoder_onnx", - required=False, - type=str, - default="", - help="Path of onnx model for decoder. Specify it when you have exported the model.", - ) - - input_group.add_argument( - "--encoder_decoder_init_onnx", - required=False, - type=str, - default="", - help="Path of ONNX model for encoder and decoder initialization. Specify it when you have exported the model.", - ) - - parser.add_argument( - "--verbose", - required=False, - action="store_true", - help="Print more information", - ) - parser.set_defaults(verbose=False) - - output_group = parser.add_argument_group("Output options") - - output_group.add_argument( - "--output", - required=True, - type=str, - help="Output path for onnx model with beam search.", - ) - - output_group.add_argument( - "-p", - "--precision", - required=False, - type=Precision, - default=Precision.FLOAT32, - choices=[Precision.FLOAT32, Precision.FLOAT16], - help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision", - ) - - output_group.add_argument( - "-e", - "--use_external_data_format", - required=False, - action="store_true", - help="save external data for model > 2G", - ) - output_group.set_defaults(use_external_data_format=False) - - output_group.add_argument( - "-s", "--run_shape_inference", required=False, action="store_true", help="run shape inference" - ) - output_group.set_defaults(run_shape_inference=False) - - output_group.add_argument( - "-i", - "--disable_shared_initializers", - required=False, - action="store_true", - help="do not share initializers in encoder and decoder. It will increase memory usage of t5/mt5 models.", - ) - output_group.set_defaults(disable_shared_initializers=False) - - model_group = parser.add_argument_group("Beam search parameters that stored in the output model") - - model_group.add_argument( - "--output_sequences_scores", - required=False, - action="store_true", - help="output sequences scores", - ) - model_group.set_defaults(output_sequences_scores=False) - - model_group.add_argument( - "--output_token_scores", - required=False, - action="store_true", - help="output token scores", - ) - model_group.set_defaults(output_token_scores=False) - - model_group.add_argument("--early_stopping", required=False, action="store_true") - model_group.set_defaults(early_stopping=False) - - model_group.add_argument( - "--no_repeat_ngram_size", - type=int, - required=False, - default=0, - help="No repeat ngram size", - ) - - model_group.add_argument( - "--vocab_mask", - required=False, - action="store_true", - help="Enable vocab_mask. 
This mask applies only to every generated token to filter some bad words.", - ) - model_group.set_defaults(vocab_mask=False) - - model_group.add_argument( - "--prefix_vocab_mask", - required=False, - action="store_true", - help="Enable prefix_vocab_mask. This mask can be used to filter bad words in the first generated token only", - ) - model_group.set_defaults(prefix_vocab_mask=False) - - model_group.add_argument( - "--custom_attention_mask", - required=False, - action="store_true", - help="Enable custom_attention_mask. This mask can be used to replace default encoder attention mask", - ) - model_group.set_defaults(custom_attention_mask=False) - - beam_parameters_group = parser.add_argument_group( - "Beam search parameters not stored in the output model, for testing parity and performance" - ) - - beam_parameters_group.add_argument("--min_length", type=int, required=False, default=1, help="Min sequence length") - - beam_parameters_group.add_argument("--max_length", type=int, required=False, default=50, help="Max sequence length") - - beam_parameters_group.add_argument("--num_beams", type=int, required=False, default=4, help="Beam size") - - beam_parameters_group.add_argument( - "--num_return_sequences", - type=int, - required=False, - default=1, - help="Number of return sequence <= num_beams", - ) - - beam_parameters_group.add_argument( - "--length_penalty", - type=float, - required=False, - default=1, - help="Positive. >1 to penalize and <1 to encourage short sentence.", - ) - - beam_parameters_group.add_argument( - "--repetition_penalty", - type=float, - required=False, - default=1, - help="Positive. >1 to penalize and <1 to encourage.", - ) - - beam_parameters_group.add_argument( - "--vocab_size", - type=int, - required=False, - default=-1, - help="Vocab_size of the underlying model used to decide the shape of vocab mask", - ) - - test_group = parser.add_argument_group("Other options for testing parity and performance") - - test_group.add_argument( - "--use_gpu", required=False, action="store_true", help="use GPU for inference. Required for fp16." 
- ) - test_group.set_defaults(use_gpu=False) - - test_group.add_argument( - "--disable_parity", - required=False, - action="store_true", - help="do not run parity test", - ) - test_group.set_defaults(disable_parity=False) - - test_group.add_argument( - "--torch_performance", - required=False, - action="store_true", - help="test PyTorch performance", - ) - test_group.set_defaults(torch_performance=False) - - test_group.add_argument( - "--total_runs", - required=False, - type=int, - default=1, - help="Number of times of inference for latency measurement", - ) - - test_group.add_argument( - "--save_test_data", - required=False, - action="store_true", - help="save test data for onnxruntimer_perf_test tool", - ) - test_group.set_defaults(save_test_data=False) - - args = parser.parse_args(argv) - - return args - - -def gpt2_to_onnx(args: argparse.Namespace): - """Convert GPT-2 model to onnx - - Args: - args (argparse.Namespace): arguments parsed from command line - """ - model_name = args.model_name_or_path - - arguments = [ - "--model_name_or_path", - model_name, - "--output", - args.decoder_onnx, - "--optimize_onnx", - "--precision", - "fp32" if args.precision == Precision.FLOAT32 else "fp16", - "--test_runs", - "1", - "--test_cases", - "10", - "--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, position_ids and attention_mask - "--overwrite", # Overwrite onnx file if existed - ] - if args.use_gpu: - arguments.append("--use_gpu") - if args.use_external_data_format: - arguments.append("--use_external_data_format") - - if args.precision == Precision.FLOAT16: - assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu" - # TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision') - # Need change cuda kernel to support a combination of fp32 logits and fp16 past state. - # Currently logits and past state shall be same data type. - arguments.extend(["--op_block_list", "Add", "LayerNormalization", "FastGelu"]) - - if args.verbose: - logger.info(f"arguments for convert_to_onnx:{arguments}") - - convert_gpt2_to_onnx(argv=arguments) - - -def t5_to_onnx(args: argparse.Namespace): - """Convert T5 model to onnx - - Args: - args (argparse.Namespace): arguments parsed from command line - """ - paths = export_t5_onnx_models( - args.model_name_or_path, - args.cache_dir, - Path(args.output).parent, - use_gpu=args.use_gpu, - use_external_data_format=args.use_external_data_format, - optimize_onnx=False, - precision=args.precision, - verbose=False, - use_decoder_start_token=False, - merge_encoder_and_decoder_init=True, - overwrite=True, - disable_auto_mixed_precision=False, - use_int32_inputs=True, - model_type=args.model_type, - ) - - logger.debug(f"onnx model for encoder: {paths[0]}") - logger.debug(f"onnx model for decoder: {paths[1]}") - args.encoder_decoder_init_onnx = paths[0] - args.decoder_onnx = paths[1] - - -def shape_inference(onnx_path: str, use_external_data_format: bool = True): - """Shape inference on an onnx file, which will be overwritten. - - Args: - onnx_path (str): Path of onnx model - use_external_data_format(bool): output tensors to external data or not. - """ - # Run symbolic shape inference to walk around ORT shape inference issue for subgraph. 
- from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - - model = onnx.load_model(onnx_path, load_external_data=True) - out = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=False) - if out: - OnnxModel.save(out, onnx_path, save_as_external_data=use_external_data_format) - else: - logger.warning("Failed to run symbolic shape inference on the model.") - - -def create_ort_session(model_path: str, use_gpu: bool) -> InferenceSession: - """Create OnnxRuntime session. - - Args: - model_path (str): onnx model path - use_gpu (bool): use GPU or not - - Raises: - RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified. - - Returns: - onnxruntime.InferenceSession: The created session. - """ - sess_options = SessionOptions() - sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL - execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"] - if use_gpu: - if "CUDAExecutionProvider" not in get_available_providers(): - raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!") - else: - logger.info("use CUDAExecutionProvider") - - ort_session = InferenceSession(model_path, sess_options, providers=execution_providers) - return ort_session - - -def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision): - """Verify GPT-2 subgraph - - Args: - graph (onnx.GraphProto): onnx graph of GPT-2 - precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. - - Raises: - ValueError: Number of inputs not expected. - ValueError: Input name is not expected. - ValueError: Input data type is not expected. - ValueError: Number of outputs not expected. - ValueError: Output name is not expected. - ValueError: Output data type is not expected. - """ - is_float16 = Precision.FLOAT16 == precision - - input_count = len(graph.input) - layer_count = input_count - 3 - assert layer_count >= 1 - - expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)] - if len(graph.input) != len(expected_inputs): - raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") - - for i, expected_input in enumerate(expected_inputs): - if graph.input[i].name != expected_input: - raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}") - - expected_type = TensorProto.INT32 - if i >= 3: - expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT - - input_type = graph.input[i].type.tensor_type.elem_type - if input_type != expected_type: - raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") - logger.info("Verifying GPT-2 graph inputs: name and data type are good.") - - expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)] - if len(graph.output) != len(expected_outputs): - raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") - - for i, expected_output in enumerate(expected_outputs): - if graph.output[i].name != expected_output: - raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") - - expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT - output_type = graph.output[i].type.tensor_type.elem_type - if output_type != expected_type: - raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. 
Got {output_type}") - logger.info("Verifying GPT-2 graph outputs: name and data type are good.") - - # TODO(tianleiwu): verify shapes of inputs and outputs. - return - - -def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision): - """Verify T5 decoder subgraph - - Args: - graph (onnx.GraphProto): onnx graph of T5 decoder - precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. - - Raises: - ValueError: Number of inputs not expected. - ValueError: Input name is not expected. - ValueError: Input data type is not expected. - ValueError: Number of outputs not expected. - ValueError: Output name is not expected. - ValueError: Output data type is not expected. - """ - is_float16 = Precision.FLOAT16 == precision - float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT - - input_count = len(graph.input) - layer_count = (input_count - 3) // 4 - assert layer_count >= 1 - - # Expect inputs: - # input_ids: int32 (B, 1) - # encoder_attention_mask: int32 (B, encode_sequence_length) - # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) - - # past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) - # past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) - # ... (for each self attention layer) - - # past_key_cross_0: (B, num_heads, encode_sequence_length, head_size) - # past_value_cross_0: (B, num_heads, encode_sequence_length, head_size) - # ... (for each cross attention layer) - - # TODO: encoder_hidden_states is optional - expected_inputs = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"] - for i in range(layer_count): - expected_inputs.append(f"past_key_self_{i}") - expected_inputs.append(f"past_value_self_{i}") - for i in range(layer_count): - expected_inputs.append(f"past_key_cross_{i}") - expected_inputs.append(f"past_value_cross_{i}") - - if len(graph.input) != len(expected_inputs): - raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") - - for i, expected_input in enumerate(expected_inputs): - if graph.input[i].name != expected_input: - raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}") - - expected_type = TensorProto.INT32 if i < 2 else float_type - input_type = graph.input[i].type.tensor_type.elem_type - if input_type != expected_type: - raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") - - # Expect outputs: - # logits: (B, 1, vocab_size) - # present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size) - # present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size) - # ... (for each self attention layer) - expected_outputs = ["logits"] - for i in range(layer_count): - expected_outputs.append(f"present_key_self_{i}") - expected_outputs.append(f"present_value_self_{i}") - - if len(graph.output) != len(expected_outputs): - raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") - - for i, expected_output in enumerate(expected_outputs): - if graph.output[i].name != expected_output: - raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") - output_type = graph.output[i].type.tensor_type.elem_type - if output_type != float_type: - raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. 
Got {output_type}") - - -def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision): - """Verify T5 decoder subgraph - - Args: - graph (onnx.GraphProto): onnx graph of T5 decoder - precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. - - Raises: - ValueError: Number of inputs not expected. - ValueError: Input name is not expected. - ValueError: Input data type is not expected. - ValueError: Number of outputs not expected. - ValueError: Output name is not expected. - ValueError: Output data type is not expected. - """ - is_float16 = Precision.FLOAT16 == precision - layer_count = (len(graph.output) - 2) // 4 - assert layer_count >= 1 - - # Expect 3 inputs: - # encoder_input_ids: int32 (B, encode_sequence_length) - # encoder_attention_mask: int32 (B, encode_sequence_length) - # decoder_input_ids: int32 (B, 1) - expected_inputs = ["encoder_input_ids", "encoder_attention_mask", "decoder_input_ids"] - if len(graph.input) != len(expected_inputs): - raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") - - for i, expected_input in enumerate(expected_inputs): - if graph.input[i].name != expected_input: - raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}") - - expected_type = TensorProto.INT32 - input_type = graph.input[i].type.tensor_type.elem_type - if input_type != expected_type: - raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") - - # Expected outputs: - # logits: (B, 1, vocab_size) - # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) - # present_key_self_0: (B, num_heads, 1, head_size) - # present_value_self_0: (B, num_heads, 1, head_size) - # ... (for each self attention layer) - # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) - # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) - # ... (for each cross attention layer) - expected_outputs = ["logits", "encoder_hidden_states"] - for i in range(layer_count): - expected_outputs.append(f"present_key_self_{i}") - expected_outputs.append(f"present_value_self_{i}") - for i in range(layer_count): - expected_outputs.append(f"present_key_cross_{i}") - expected_outputs.append(f"present_value_cross_{i}") - - if len(graph.output) != len(expected_outputs): - raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") - - for i, expected_output in enumerate(expected_outputs): - if graph.output[i].name != expected_output: - raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") - - expected_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT - output_type = graph.output[i].type.tensor_type.elem_type - if output_type != expected_type: - raise ValueError(f"Output {i} is expected to have onnx data type {expected_type}. Got {output_type}") - - logger.info("T5 encoder graph verified: name and data type of inputs and outputs are good.") - - -def remove_shared_initializers( - graph1: GraphProto, - graph2: GraphProto, - shared_prefix: str = "shared_", - min_elements: int = 1024, -): - """Remove initializers with same value from two graphs. - - Args: - graph1 (GraphProto): the first graph to process - graph2 (GraphProto): the second graph to process - shared_prefix (str): add prefix to the shared initializers among two graphs - min_elements (int, optional): minimal number of elements for initializers to be considered. 
Defaults to 1024. - """ - - mapping_initializers_1 = {} - mapping_initializers_2 = {} - shared_initializers_1 = [] - shared_initializers_2 = [] - shared_initializers_names = [] - - for initializer1 in graph1.initializer: - if not (initializer1.dims and sum(initializer1.dims) >= min_elements): - continue - - for initializer2 in graph2.initializer: - if not (initializer2.dims and sum(initializer2.dims) >= min_elements): - continue - - if OnnxModel.has_same_value(initializer1, initializer2): - mapping_initializers_1[initializer1.name] = shared_prefix + initializer2.name - shared_initializers_1.append(initializer1) - - if initializer2.name not in mapping_initializers_2: - shared_name = shared_prefix + initializer2.name - mapping_initializers_2[initializer2.name] = shared_name - shared_initializers_2.append(initializer2) - shared_initializers_names.append(shared_name) - break - - logger.debug(f"shared initializers:{shared_initializers_names}") - - # Make sure new name does not exist in graph 1 - for node in graph1.node: - for j in range(len(node.input)): - if node.input[j] in shared_initializers_names: - raise RuntimeError(f"name is found in graph 1: {node.input[j]}") - - # Make sure new name does not exist in graph 2 - for node in graph2.node: - for j in range(len(node.input)): - if node.input[j] in shared_initializers_names: - raise RuntimeError(f"name is found in graph 2: {node.input[j]}") - - # Remove shared initializers from graph 2 - for initializer in shared_initializers_2: - graph2.initializer.remove(initializer) - - # Rename value info for old names in graph 2 - for value_info in graph2.value_info: - if value_info.name in mapping_initializers_2: - value_info.name = mapping_initializers_2[value_info.name] - - # Rename nodes inputs in graph 2: - for node in graph2.node: - for j in range(len(node.input)): - if node.input[j] in mapping_initializers_2: - new_name = mapping_initializers_2[node.input[j]] - logger.debug(f"graph 2 rename node {node.name} input {j} from {node.input[j]} to {new_name}") - node.input[j] = new_name - - # Remove shared initializers from graph 1 - for initializer in shared_initializers_1: - graph1.initializer.remove(initializer) - - # Rename value info for old names in graph 1 - for value_info in graph1.value_info: - if value_info.name in mapping_initializers_1: - value_info.name = mapping_initializers_1[value_info.name] - - # Rename nodes inputs in graph 1: - for node in graph1.node: - for j in range(len(node.input)): - if node.input[j] in mapping_initializers_1: - new_name = mapping_initializers_1[node.input[j]] - logger.debug(f"graph 1 rename node {node.name} input {j} from {node.input[j]} to {new_name}") - node.input[j] = new_name - - # Rename shared initializers in graph 2 - for initializer in shared_initializers_2: - initializer.name = mapping_initializers_2[initializer.name] - - for initializer in shared_initializers_2: - shape = onnx.numpy_helper.to_array(initializer).shape - value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape) - # Need add value_info for initializers moved to parent graph. Otherwise, ORT will fail. 
- graph1.value_info.append(value_info) - graph2.value_info.append(value_info) - - return shared_initializers_2 - - -def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto): - encoder = OnnxModel(encoder_model) - decoder = OnnxModel(decoder_model) - encoder.add_prefix_to_names("e_") - decoder.add_prefix_to_names("d_") - encoder.remove_duplicated_initializer() - decoder.remove_duplicated_initializer() - initializers = remove_shared_initializers(encoder.model.graph, decoder.model.graph, "s_") - return initializers - - -def move_initializers( - graph: GraphProto, - min_elements: int = 1024, -) -> List[TensorProto]: - """Remove initializers of a graph, when they have number of elements larger than a threshold. - - Args: - graph (GraphProto): the graph. - min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024. - - Returns: - List[TensorProto]: initializers that are removed from the graph. - """ - moved_initializers = [] - for tensor in graph.initializer: - if not (tensor.dims and sum(tensor.dims) >= min_elements): - continue - moved_initializers.append(tensor) - - for initializer in moved_initializers: - graph.initializer.remove(initializer) - - # Add type info, otherwise ORT will raise error: "input arg (*) does not have type information set by parent node." - for initializer in moved_initializers: - shape = onnx.numpy_helper.to_array(initializer).shape - value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape) - graph.value_info.append(value_info) - - return moved_initializers - - -def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH): - """Convert model according to command line arguments. - - Args: - args (argparse.Namespace): arguments parsed from command line - """ - is_gpt2: bool = args.model_type == "gpt2" - is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH - - if is_greedysearch: - if not is_gpt2: - raise NotImplementedError("Currently only gpt2 with greedy search is supported") - if args.output_sequences_scores: - raise NotImplementedError("output_sequences_scores currently is not supported in greedy search") - if args.output_token_scores: - raise NotImplementedError("output_token_scores currently is not supported in greedy search") - - if is_gpt2: - if args.decoder_onnx and os.path.exists(args.decoder_onnx): - logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}") - else: - if not args.decoder_onnx: - onnx_filename = "gpt2_past_{}.onnx".format("fp16" if args.precision == Precision.FLOAT16 else "fp32") - args.decoder_onnx = Path(Path(args.output).parent, onnx_filename).as_posix() - - logger.info(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...") - gpt2_to_onnx(args) - else: # t5 or mt5 - if args.decoder_onnx and args.encoder_decoder_init_onnx: - logger.info( - f"skip convert_to_onnx since paths specified: {args.decoder_onnx} and {args.encoder_decoder_init_onnx}" - ) - else: - logger.info(f"Convert model {args.model_name_or_path} to onnx ...") - t5_to_onnx(args) - - if args.run_shape_inference: - logger.info(f"Run symbolic shape inference on {args.decoder_onnx}. 
The file will be overwritten.") - shape_inference(args.decoder_onnx, args.use_external_data_format) - - if is_gpt2: - config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) - elif args.model_type == "t5": - config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) - else: - config = MT5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) - - if args.verbose: - logger.info(f"Config={config}") - - eos_token_id = config.eos_token_id - pad_token_id = config.eos_token_id if is_gpt2 else config.pad_token_id - vocab_size = config.vocab_size - - # if vocab_size is given in parameters use that. - if args.vocab_size != -1: - vocab_size = args.vocab_size - - decoder_model = onnx.load_model(args.decoder_onnx, load_external_data=True) - decoder_model.graph.name = f"{args.model_type} decoder" - - if args.model_type == "gpt2": - verify_gpt2_subgraph(decoder_model.graph, args.precision) - else: - verify_t5_decoder_subgraph(decoder_model.graph, args.precision) - - inputs = ( - [ - "input_ids", - "max_length", - "min_length", - "num_beams", - "num_return_sequences", - "length_penalty", - "repetition_penalty", - ] - if not is_greedysearch - else [ - "input_ids", - "max_length", - "min_length", - "repetition_penalty", - ] - ) - - outputs = ["sequences"] - if args.output_sequences_scores: - outputs.append("sequences_scores") - - if args.output_token_scores: - assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" - outputs.append("scores") - - node = ( - onnx.helper.make_node( - "BeamSearch", - inputs=inputs, - outputs=outputs, - name=f"BeamSearch_{args.model_type}", - ) - if not is_greedysearch - else onnx.helper.make_node( - "Sampling", - inputs=inputs, - outputs=outputs, - name=f"GreedySearch_{args.model_type}", - ) - ) - - node.domain = "com.microsoft" - - attr_to_extend = ( - [ - onnx.helper.make_attribute("eos_token_id", eos_token_id), - onnx.helper.make_attribute("pad_token_id", pad_token_id), - onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), - onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), - ] - if not is_greedysearch - else [ - onnx.helper.make_attribute("eos_token_id", eos_token_id), - onnx.helper.make_attribute("pad_token_id", pad_token_id), - onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), - onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - onnx.helper.make_attribute("temperature", 1.1), - onnx.helper.make_attribute("top_p", 0.6), - onnx.helper.make_attribute("filter_value", 0.0), - onnx.helper.make_attribute("presence_penalty", 0.0), - ] - ) - node.attribute.extend(attr_to_extend) - - initializers = [] - if args.model_type in ["t5", "mt5"]: - if args.run_shape_inference: - logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.") - shape_inference(args.encoder_decoder_init_onnx, args.use_external_data_format) - encoder_model = onnx.load_model(args.encoder_decoder_init_onnx, load_external_data=True) - encoder_model.graph.name = f"{args.model_type} encoder and decoder init" - verify_t5_encoder_decoder_init_subgraph(encoder_model.graph, args.precision) - - if not args.disable_shared_initializers: - # Unique shared initializers from the decoder and decoder_init could reduce memory usage in inference. 
- initializers = get_shared_initializers(encoder_model, decoder_model) - logger.info( - f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in subgraphs are moved to the main graph" - ) - - # TODO(tianleiwu): investigate the following which causes error in inference - # Move initializer from subgraph to main graph could reduce memory usage in inference. - # moved_initializers = move_initializers(encoder_model.graph) - # logger.info( - # f"{len(moved_initializers)} initializers ({[i.name for i in moved_initializers]}) from the encoder are moved to the main graph" - # ) - # initializers.extend(moved_initializers) - - node.attribute.extend( - [ - onnx.helper.make_attribute("encoder", encoder_model.graph), - onnx.helper.make_attribute("decoder", decoder_model.graph), - onnx.helper.make_attribute( - "decoder_start_token_id", - config.decoder_start_token_id if len(encoder_model.graph.input) == 3 else -1, - ), - ] - ) - else: - # Move initializer from subgraph to main graph could reduce memory usage in inference. - initializers = move_initializers(decoder_model.graph) - logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph") - - node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph)) - - # graph inputs - input_ids = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"]) - max_length = onnx.helper.make_tensor_value_info("max_length", TensorProto.INT32, [1]) - min_length = onnx.helper.make_tensor_value_info("min_length", TensorProto.INT32, [1]) - num_beams = onnx.helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1]) - num_return_sequences = onnx.helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) - length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) - repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) - - graph_inputs = ( - [ - input_ids, - max_length, - min_length, - num_beams, - num_return_sequences, - length_penalty, - repetition_penalty, - ] - if not is_greedysearch - else [ - input_ids, - max_length, - min_length, - repetition_penalty, - ] - ) - - if args.vocab_mask: - vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size]) - graph_inputs.append(vocab_mask) - - if args.prefix_vocab_mask: - prefix_vocab_mask = onnx.helper.make_tensor_value_info( - "prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size] - ) - graph_inputs.append(prefix_vocab_mask) - - if args.custom_attention_mask: - attention_mask = onnx.helper.make_tensor_value_info( - "attention_mask", TensorProto.INT32, ["batch_size", "sequence_length"] - ) - graph_inputs.append(attention_mask) - - # graph outputs - sequences = ( - onnx.helper.make_tensor_value_info( - "sequences", - TensorProto.INT32, - ["batch_size", "num_return_sequences", "max_length"], - ) - if not is_greedysearch - else onnx.helper.make_tensor_value_info( - "sequences", - TensorProto.INT32, - ["batch_size", "max_length"], - ) - ) - - sequences_scores = onnx.helper.make_tensor_value_info( - "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] - ) - - scores = onnx.helper.make_tensor_value_info( - "scores", - TensorProto.FLOAT, - ["max_length - sequence_length", "batch_size", "num_beams", vocab_size], - ) - - graph_outputs = [sequences] - - if args.output_sequences_scores: - graph_outputs.append(sequences_scores) - - if 
args.output_token_scores: - graph_outputs.append(scores) - - new_graph = onnx.helper.make_graph( - [node], - f"{args.model_type} beam search" if not is_greedysearch else f"{args.model_type} greedy search", - graph_inputs, - graph_outputs, - initializers, - ) - - # Create the model - new_model = onnx.helper.make_model( - new_graph, - producer_name="onnxruntime.transformers", - opset_imports=decoder_model.opset_import, - ) - - # TODO(tianleiwu): move shared initializers from T5 encoder and decoder subgraphs to parent graph to save memory. - if args.use_external_data_format: - from packaging import version - - if version.parse(onnx.__version__) < version.parse("1.12.0"): - logger.warning("Require onnx >= 1.12 to save large (>2GB) model!") - - OnnxModel.save( - new_model, - args.output, - save_as_external_data=True, - all_tensors_to_one_file=True, - ) - else: - onnx.save(new_model, args.output) - logger.info(f"model save to {args.output}") - - -def test_torch_performance( - args: argparse.Namespace, - model: Union[GPT2LMHeadModel, T5ForConditionalGeneration], - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - eos_token_id: int, - pad_token_id: int, - bad_words_ids: List[List[int]], -) -> Dict[str, Any]: - """Test PyTorch performance of text generation. - - Args: - args (argparse.Namespace): arguments parsed from command line - model (Union[GPT2LMHeadModel, T5ForConditionalGeneration]): PyTorch model - input_ids (torch.Tensor): input_ids - attention_mask (torch.Tensor): Attention mask - eos_token_id (int): EOS token ID - pad_token_id (int): Padding token ID - bad_words_ids (List[List[int]]): Words shall not be generated. - - Raises: - RuntimeError: PyTorch with CUDA is not available for --use_gpu - - Returns: - Dict[str, Any]: A dictionary with string with metric name, and value can be integer or string. 
- """ - if args.use_gpu and not torch.cuda.is_available(): - raise RuntimeError("Please install PyTorch with Cuda for testing gpu performance.") - - if args.precision == Precision.FLOAT16: - model.half() - - device = torch.device("cuda:0" if args.use_gpu else "cpu") - model.to(device) - - torch.set_grad_enabled(False) - input_ids = input_ids.to(device) - attention_mask = attention_mask.to(device) - - torch_latency = [] - for _ in range(args.total_runs): - start = time.time() - _ = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_length=args.max_length, - min_length=args.min_length, - num_beams=args.num_beams, - early_stopping=args.early_stopping, - no_repeat_ngram_size=args.no_repeat_ngram_size, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - num_return_sequences=args.num_return_sequences, - length_penalty=args.length_penalty, - repetition_penalty=args.repetition_penalty, - bad_words_ids=bad_words_ids, - return_dict_in_generate=True, - output_scores=args.output_sequences_scores or args.output_token_scores, - ) - torch_latency.append(time.time() - start) - batch_size = input_ids.shape[0] - from benchmark_helper import get_latency_result - - return get_latency_result(torch_latency, batch_size) - - -def create_attention_mask(input_ids, pad_token_id): - attention_mask = np.ones(input_ids.shape, dtype=np.int32) - for i in range(input_ids.shape[0]): - abs_pos = 0 - for j in range(input_ids.shape[1]): - if input_ids[i][j] == pad_token_id and abs_pos == 0: - attention_mask[i][j] = 0 - else: - abs_pos += 1 - return attention_mask - - -def test_gpt_model(args: argparse.Namespace, sentences: Optional[List[str]] = None, is_greedy: bool = False): - """Test GPT-2 model - - Args: - args (argparse.Namespace): arguments parsed from command line - sentences (Optional[List[str]], optional): input text. Defaults to None. - - Returns: - Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. 
- """ - assert args.model_type == "gpt2" - - tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) - tokenizer.padding_side = "left" - tokenizer.pad_token = tokenizer.eos_token - - model = GPT2LMHeadModel.from_pretrained( - args.model_name_or_path, - cache_dir=args.cache_dir, - pad_token_id=tokenizer.eos_token_id, - ) - - # Use different length sentences to test batching - if sentences is None: - sentences = [ - "The product is released", - "I enjoy walking in the park", - "Test best way to invest", - ] - - inputs = tokenizer(sentences, return_tensors="pt", padding=True) - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - - bad_words = "walk in park" - bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True) - bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list - if args.vocab_mask: - logger.debug("bad_words_ids", bad_words_ids) - else: - bad_words_ids = [] - - config = model.config - eos_token_id = config.eos_token_id - pad_token_id = config.eos_token_id - vocab_size = config.vocab_size - - torch_decoded_sequences = [] - beam_outputs = None - if not args.disable_parity: - print("-" * 50) - print("Test PyTorch model and beam search with huggingface transformers...") - beam_outputs = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_length=args.max_length, - min_length=args.min_length, - num_beams=args.num_beams, - early_stopping=args.early_stopping, - no_repeat_ngram_size=args.no_repeat_ngram_size, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - num_return_sequences=args.num_return_sequences, - length_penalty=args.length_penalty, - repetition_penalty=args.repetition_penalty, - bad_words_ids=bad_words_ids if bad_words_ids else None, - return_dict_in_generate=True, - output_scores=args.output_sequences_scores or args.output_token_scores, - ) - print("input_ids", input_ids) - print("huggingface transformers outputs:") - print("sequences", beam_outputs.sequences) - if args.output_sequences_scores: - print("sequences_scores", beam_outputs.sequences_scores) - if args.output_token_scores: - print("scores", beam_outputs.scores) - for i, sequence in enumerate(beam_outputs.sequences): - decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True) - torch_decoded_sequences.append(decoded_sequence) - print(f"{i}: {decoded_sequence}") - - print("-" * 50) - print("Testing beam search with onnxruntime...") - - ort_session = create_ort_session(args.output, args.use_gpu) - print("ort session created") - if is_greedy: - print("is_greedy") - inputs = { - "input_ids": input_ids.cpu().numpy().astype(np.int32), - "max_length": np.array([args.max_length], dtype=np.int32), - "min_length": np.array([args.min_length], dtype=np.int32), - "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), - } - print(inputs) - else: - inputs = { - "input_ids": input_ids.cpu().numpy().astype(np.int32), - "max_length": np.array([args.max_length], dtype=np.int32), - "min_length": np.array([args.min_length], dtype=np.int32), - "num_beams": np.array([args.num_beams], dtype=np.int32), - "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32), - "length_penalty": np.array([args.length_penalty], dtype=np.float32), - "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), - } - - if args.vocab_mask: - vocab_mask = np.ones((vocab_size), dtype=np.int32) - if args.vocab_mask: - for bad_word_id in bad_words_ids: - 
vocab_mask[bad_word_id] = 0 - inputs["vocab_mask"] = vocab_mask - - if args.custom_attention_mask: - inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id) - - batch_size = input_ids.shape[0] - if args.prefix_vocab_mask: - logger.info("Use prefix vocab mask with all ones in ORT, but no corresponding setting for Torch model.") - prefix_vocab_mask = np.ones((batch_size, vocab_size), dtype=np.int32) - inputs["prefix_vocab_mask"] = prefix_vocab_mask - - logger.debug("ORT inputs", inputs) - result = ort_session.run(None, inputs) - - if args.save_test_data: - test_data_dir = Path(args.output).parent.as_posix() - logger.debug("test_data_dir", test_data_dir) - from bert_test_data import output_test_data - - all_inputs = [inputs] - for i, inputs in enumerate(all_inputs): - dir = os.path.join(test_data_dir, "test_data_set_" + str(i)) - output_test_data(dir, inputs) - - # Test performance - latency = [] - for _ in range(args.total_runs): - start = time.time() - _ = ort_session.run(None, inputs) - latency.append(time.time() - start) - - from benchmark_helper import get_latency_result - - batch_size = input_ids.shape[0] - output = get_latency_result(latency, batch_size) - - print("ORT outputs:") - sequences = result[0] - print("sequences", sequences) - if args.output_sequences_scores: - print("sequences_scores", result[1]) - if args.output_token_scores: - print("scores", result[2]) - - if is_greedy: - (batch_size, max_length) = sequences.shape - ort_decoded_sequences = [] - for i in range(batch_size): - decoded_sequence = tokenizer.decode(sequences[i], skip_special_tokens=True) - ort_decoded_sequences.append(decoded_sequence) - print(f"batch {i} sequence: {decoded_sequence}") - else: - (batch_size, num_sequences, max_length) = sequences.shape - ort_decoded_sequences = [] - for i in range(batch_size): - for j in range(num_sequences): - decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) - ort_decoded_sequences.append(decoded_sequence) - print(f"batch {i} sequence {j}: {decoded_sequence}") - - if beam_outputs: - torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1) - ort_sequences = torch.LongTensor(sequences) - print("-" * 50) - print("Torch Sequences:") - print(torch_sequences) - print(torch_decoded_sequences) - print("-" * 50) - print("ORT Sequences:") - print(ort_sequences) - print(ort_decoded_sequences) - print("-" * 50) - # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not. - is_same = torch_decoded_sequences == ort_decoded_sequences - print("Torch and ORT result is ", "same" if is_same else "different") - output["parity"] = is_same - - if args.torch_performance: - torch_latency_output = test_torch_performance( - args, - model, - input_ids, - attention_mask, - eos_token_id, - pad_token_id, - bad_words_ids, - ) - print("Torch Latency", torch_latency_output) - - print("ORT", output) - - return output - - -def test_t5_model(args: argparse.Namespace, sentences: Optional[List[str]] = None): - """Test T5 or MT5 model - - Args: - args (argparse.Namespace): arguments parsed from command line - sentences (Optional[List[str]], optional): input text. Defaults to None. - - Returns: - Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. 
- """ - assert args.model_type in ["t5", "mt5"] - - if args.prefix_vocab_mask: - logger.debug("Skipping parity test as prefix vocab mask is not implemented by Hugging Face") - return None - - tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) - tokenizer.padding_side = "left" - - if args.model_type == "t5": - model = T5ForConditionalGeneration.from_pretrained( - args.model_name_or_path, - cache_dir=args.cache_dir, - ) - else: - model = MT5ForConditionalGeneration.from_pretrained( - args.model_name_or_path, - cache_dir=args.cache_dir, - ) - - # Use different length sentences to test batching - if sentences is None: - sentences = [ - "translate English to French: The product is released", - "summarize: research continues to show that pets bring real health benefits to their owners." - + "Having a dog around can lead to lower levels of stress for both adults and kids.", - # "summarize: I enjoy walking in the park. It makes my mind feel calm and refreshed. " - # + "I enjoy looking at the trees, flowers, and wildlife around me, and listening to sound from natural.", - ] - - inputs = tokenizer(sentences, return_tensors="pt", padding=True) - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - - bad_words = "walk in park" - bad_words_ids = tokenizer.encode(bad_words)[:-1] # exclude the last token (EOS) - bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list - if args.vocab_mask: - logger.debug("bad_words_ids", bad_words_ids) - else: - bad_words_ids = [] - - config = model.config - eos_token_id = config.eos_token_id - pad_token_id = config.pad_token_id - vocab_size = config.vocab_size - logger.debug(f"eos_token_id:{eos_token_id}, pad_token_id:{pad_token_id}, vocab_size:{vocab_size}") - - torch_decoded_sequences = [] - if not args.disable_parity: - print("-" * 50) - print("Test PyTorch model and beam search with huggingface transformers...") - beam_outputs = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_length=args.max_length, - min_length=args.min_length, - num_beams=args.num_beams, - early_stopping=args.early_stopping, - no_repeat_ngram_size=args.no_repeat_ngram_size, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - num_return_sequences=args.num_return_sequences, - length_penalty=args.length_penalty, - repetition_penalty=args.repetition_penalty, - bad_words_ids=bad_words_ids if bad_words_ids else None, - return_dict_in_generate=True, - output_scores=args.output_sequences_scores or args.output_token_scores, - ) - - print("input_ids", input_ids) - print("huggingface transformers outputs:") - print("sequences", beam_outputs.sequences) - if args.output_sequences_scores: - print("sequences_scores", beam_outputs.sequences_scores) - if args.output_token_scores: - print("scores", beam_outputs.scores) - for i, sequence in enumerate(beam_outputs.sequences): - decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True) - torch_decoded_sequences.append(decoded_sequence) - print("{}: {}".format(i, decoded_sequence)) - - print("-" * 50) - print("Testing beam search with onnxruntime...") - - ort_session = create_ort_session(args.output, args.use_gpu) - - vocab_mask = np.ones((vocab_size), dtype=np.int32) - if args.vocab_mask: - for bad_word_id in bad_words_ids: - vocab_mask[bad_word_id] = 0 - - inputs = { - "input_ids": input_ids.cpu().numpy().astype(np.int32), - "max_length": np.array([args.max_length], dtype=np.int32), - "min_length": np.array([args.min_length], 
dtype=np.int32), - "num_beams": np.array([args.num_beams], dtype=np.int32), - "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32), - "length_penalty": np.array([args.length_penalty], dtype=np.float32), - "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), - } - - if args.vocab_mask: - inputs["vocab_mask"] = vocab_mask - - if args.custom_attention_mask: - inputs["attention_mask"] = create_attention_mask(input_ids, pad_token_id) - - if args.save_test_data: - test_data_dir = Path(args.output).parent.as_posix() - logger.debug("test_data_dir", test_data_dir) - from bert_test_data import output_test_data - - all_inputs = [inputs] - for i, inputs in enumerate(all_inputs): - dir = os.path.join(test_data_dir, "test_data_set_" + str(i)) - output_test_data(dir, inputs) - - logger.debug("ORT inputs", inputs) - - # Test performance - latency = [] - for _ in range(args.total_runs): - start = time.time() - result = ort_session.run(None, inputs) - latency.append(time.time() - start) - batch_size = input_ids.shape[0] - from benchmark_helper import get_latency_result - - output = get_latency_result(latency, batch_size) - - print("ORT outputs:") - sequences = result[0] - print("sequences", sequences) - if args.output_sequences_scores: - print("sequences_scores", result[1]) - if args.output_token_scores: - print("scores", result[2]) - - (batch_size, num_sequences, max_length) = sequences.shape - ort_decoded_sequences = [] - for i in range(batch_size): - for j in range(num_sequences): - decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) - ort_decoded_sequences.append(decoded_sequence) - print(f"batch {i} sequence {j}: {decoded_sequence}") - - if not args.disable_parity: - torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1) - ort_sequences = torch.LongTensor(sequences) - print("-" * 50) - print("Torch Sequences:") - print(torch_sequences) - print(torch_decoded_sequences) - print("-" * 50) - print("ORT Sequences:") - print(ort_sequences) - print(ort_decoded_sequences) - print("-" * 50) - # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not. - is_same = torch_decoded_sequences == ort_decoded_sequences - print("Torch and ORT result is ", "same" if is_same else "different") - output["parity"] = is_same - - if args.torch_performance: - torch_latency_output = test_torch_performance( - args, - model, - input_ids, - attention_mask, - eos_token_id, - pad_token_id, - bad_words_ids, - ) - print("Torch Latency", torch_latency_output) - - print("ORT", output) - return output - - -def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None): - """Main entry function - - Args: - argv (Optional[List[str]], optional): _description_. Defaults to None. - sentences (Optional[List[str]], optional): input text. Defaults to None. - - Raises: - ValueError: Path does not exist: --encoder_decoder_init_onnx - ValueError: Path does not exist: --decoder_onnx - ValueError: --decoder_onnx and --encoder_decoder_init_onnx are not used together for T5 - - Returns: - Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. 
- """ - - args = parse_arguments(argv) - setup_logger(args.verbose) - - if args.model_type in ["t5", "mt5"]: - if args.encoder_decoder_init_onnx and not os.path.exists(args.encoder_decoder_init_onnx): - raise ValueError(f"Path does not exist: --encoder_decoder_init_onnx {args.encoder_decoder_init_onnx}") - if args.decoder_onnx and not os.path.exists(args.decoder_onnx): - raise ValueError(f"Path does not exist: --decoder_onnx {args.decoder_onnx}") - if (args.encoder_decoder_init_onnx and not args.decoder_onnx) or ( - args.decoder_onnx and not args.encoder_decoder_init_onnx - ): - raise ValueError("--decoder_onnx shall use together with --encoder_decoder_init_onnx") - - is_greedy = args.num_beams == 1 and args.num_return_sequences == 1 - - if args.model_type == "gpt2" and is_greedy: - convert_generation_model(args, GenerationType.GREEDYSEARCH) - else: - convert_generation_model(args) - - logger.info("start testing model...") - if args.model_type in ["t5", "mt5"]: - result = test_t5_model(args, sentences=sentences) - else: - result = test_gpt_model(args, sentences=sentences, is_greedy=is_greedy) - - if result: - if args.use_external_data_format: - logger.info(f"Output files: {args.output}, {args.output}.data") - else: - logger.info(f"Output file: {args.output}") - - return result - - -if __name__ == "__main__": - main() From e7dd78a1e2169539edecee7c06b66f63fe64188c Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Mon, 12 Dec 2022 22:50:07 -0800 Subject: [PATCH 38/51] Update OperatorKernels.md --- docs/OperatorKernels.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9c345ee5f8091..cd50bc192c99e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -438,6 +438,7 @@ Do not modify directly.* |QuickGelu|*in* X:**T**
 *out* Y:**T**|1+|**T** = tensor(float)|
 |Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
 |SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)|
 |SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@@ -797,6 +798,7 @@ Do not modify directly.*
 |RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
 |RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
 |TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br>
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| From a1c8ea2184493ce3bd8db0adaf4c796c8b665476 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 13 Dec 2022 19:41:34 +0000 Subject: [PATCH 39/51] prefast warning --- docs/ContribOperators.md | 80 +++++++++++++++++++ .../cpu/transformers/sampling_cpu_helper.h | 4 +- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 0d782a96a4ca0..fba0c7ded7fa2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -74,6 +74,7 @@ Do not modify directly.* * com.microsoft.RestorePadding * com.microsoft.Rfft * com.microsoft.SampleOp + * com.microsoft.Sampling * com.microsoft.SkipLayerNormalization * com.microsoft.Snpe * com.microsoft.SparseToDenseMatMul @@ -3802,6 +3803,85 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.Sampling** + + Greedy Sampling for text generation. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
custom : int
+
absence penalty for sampling
+
decoder : graph (required)
+
Decoder subgraph to execute in a loop.
+
decoder_start_token_id : int
+
The id of the token that indicates decoding starts.
+
encoder : graph
+
The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
+
eos_token_id : int (required)
+
The id of the end-of-sequence token
+
filter_value : float
+
filter value for top_p
+
min_tokens_to_keep : int
+
min_tokens_to_keep
+
model_type : int
+
model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart
+
no_repeat_ngram_size : int
+
no repeat ngrams size
+
pad_token_id : int (required)
+
The id of the padding token
+
presence_penalty : float
+
presence penalty for sampling
+
temperature : float
+
temperature for sampling
+
top_p : float
+
top_p for sampling
+
+ +#### Inputs (2 - 8) + +
+
input_ids : I
+
The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)
+
max_length : I
+
The maximum length of the sequence to be generated. Shape is (1)
+
min_length (optional) : I
+
The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)
+
repetition_penalty (optional) : T
+
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
+
vocab_mask (optional) : I
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
prefix_vocab_mask (optional) : I
+
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
+
attention_mask (optional) : I
+
Custom attention mask. Shape is (batch_size, sequence_length)
+
presence_mask (optional) : I
+
presence penalty mask. Shape is (batch_size, vocab_size)
+
+ +#### Outputs (1 - 2) + +
+
sequences : I
+
Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)
+
filtered_logits (optional) : T
+
filtered logits as input to the mutinomial function . Shape is (batch_size, vocab_size)
+
+ +#### Type Constraints + +
+
T : tensor(float)
+
Constrain input and output types to float tensors.
+
I : tensor(int32)
+
Constrain to integer types
+
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h index 105e2a9a588ef..1e3c7035ff336 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h @@ -44,7 +44,7 @@ void cumulate_and_filter(gsl::span& next_token_scores, if (cumulative_probs[offset] <= 1 - parameters->top_p) { filter_scores(sorted_indices, next_token_scores, parameters, offset); } - for (size_t j = 1; j < static_cast(parameters->vocab_size - parameters->min_tokens_to_keep); j++) { + for (size_t j = 1; j < static_cast(parameters->vocab_size) - static_cast(parameters->min_tokens_to_keep); j++) { cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; if (cumulative_probs[j + offset] <= 1 - parameters->top_p) { filter_scores(sorted_indices, next_token_scores, parameters, j + offset); @@ -65,7 +65,7 @@ Status Sample(AllocatorPtr& allocator, gsl::span& sorted_scores = sampling_state->sorted_scores; memcpy(sorted_scores.data(), next_token_scores.data(), next_token_scores.size_bytes()); - std::vector sorted_indices(parameters->batch_size * parameters->vocab_size); + std::vector sorted_indices(static_cast(parameters->batch_size) * static_cast(parameters->vocab_size)); std::function predicator; if (parameters->custom_sampling) { From 88575f50bf9774f841abb5bdfa4997756b6213f6 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 13 Dec 2022 22:20:23 +0000 Subject: [PATCH 40/51] exclude amd build --- cmake/onnxruntime_rocm_hipify.cmake | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 76d6e3d220c60..6c38b0782fae6 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -75,10 +75,13 @@ set(contrib_ops_excluded_files "transformers/beam_search.h" "transformers/generation_device_helper.cc" "transformers/generation_device_helper.h" - "transformers/beam_search_impl.cu" - "transformers/beam_search_impl.h" + "transformers/generation_cuda_impl.cu" + "transformers/generation_cuda_impl.h" "transformers/greedy_search.cc" "transformers/greedy_search.h" + "transformers/sampling.cc" + "transformers/sampling.h" + "transformers/sampling_cuda_helper.h" "transformers/dump_cuda_tensor.cc" "transformers/dump_cuda_tensor.h" "conv_transpose_with_dynamic_pads.cc" From 6ff8f23552f9ed5b841e04e295fbc303da54221f Mon Sep 17 00:00:00 2001 From: wangyems Date: Wed, 14 Dec 2022 23:18:14 +0000 Subject: [PATCH 41/51] ffix build issue due to rebase --- .../contrib_ops/cpu/transformers/beam_search.cc | 6 +++--- .../contrib_ops/cpu/transformers/greedy_search.cc | 10 +++++----- .../cpu/transformers/greedy_search_impl_gpt.h | 4 ++-- onnxruntime/contrib_ops/cpu/transformers/sampling.cc | 8 ++++++-- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 4e2d9c398fc36..6ad28a07a0e0a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -66,12 +66,12 @@ void BeamSearch::Init(const OpKernelInfo& info) { ONNX_NAMESPACE::GraphProto proto; - if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type != 
IGenerationParameters::kModelTypeGpt) { // Make sure the encoder sub-graph attribute is present for the T5 model. ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); } - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // Check if the init_decoder sub-graph attribute is present for the GPT2 model. if (info.GetAttr("init_decoder", &proto).IsOK()) { has_init_decoder_ = true; @@ -114,7 +114,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, init_run_gpt_subgraph_ = std::move(res.second); init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager(); } - } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { + } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { if (attribute_name == "encoder") { ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc index 93202798226ba..4a6b3b58ff848 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc @@ -80,16 +80,16 @@ void GreedySearch::Init(const OpKernelInfo& info) { parameters_.ParseFromAttributes(info); // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5) - ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt || - parameters_.model_type == IBeamSearchParameters::kModelTypeT5); + ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt || + parameters_.model_type == IGenerationParameters::kModelTypeT5); ONNX_NAMESPACE::GraphProto proto; - if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { // Make sure the encoder sub-graph attribute is present for the T5 model. ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); } - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // Check if the init_decoder sub-graph attribute is present for the GPT2 model. 
if (info.GetAttr("init_decoder", &proto).IsOK()) { has_init_decoder_ = true; @@ -130,7 +130,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat init_run_gpt_subgraph_ = std::move(res.second); init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager(); } - } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { // encoder-decoder like T5 + } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5 ORT_THROW("Not Implemented"); // if (attribute_name == "encoder") { // ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 4e970d8b0c1fa..aa0206f4cdc68 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -153,8 +153,8 @@ Status GreedySearchGpt::UpdateFeeds( } template -Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager, - const FeedsFetchesManager& feeds_fetches_manager) { +Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager, + const FeedsFetchesManager& feeds_fetches_manager) { auto status = Status::OK(); const ParametersT* parameters = this->parameters_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc index e6d1b5e656750..4d9ed8fde8cdb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -90,6 +90,8 @@ Status Sampling::Compute(OpKernelContext* ctx) const { if (!gpt_subgraph_->IsOutputFloat16()) { GreedySearchGpt impl{ *ctx_internal, + nullptr, // init decoder + nullptr, *decoder_session_state, *gpt_subgraph_, thread_pool, @@ -105,10 +107,12 @@ Status Sampling::Compute(OpKernelContext* ctx) const { update_gpt_feeds_func_ ? 
update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; ORT_RETURN_IF_ERROR(impl.Initialize()); - return impl.Execute(*decoder_feeds_fetches_manager_); + return impl.Execute(nullptr, *decoder_feeds_fetches_manager_); } else { GreedySearchGpt impl{ *ctx_internal, + nullptr, // init decoder + nullptr, *decoder_session_state, *gpt_subgraph_, thread_pool, @@ -124,7 +128,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const { update_gpt_feeds_fp16_func_}; ORT_RETURN_IF_ERROR(impl.Initialize()); - return impl.Execute(*decoder_feeds_fetches_manager_); + return impl.Execute(nullptr, *decoder_feeds_fetches_manager_); } } From d74cc1da01e699ed77b198957c147d4c00a5f77b Mon Sep 17 00:00:00 2001 From: wangyems Date: Thu, 15 Dec 2022 21:42:53 +0000 Subject: [PATCH 42/51] fix build issue due to rebase --- onnxruntime/contrib_ops/cpu/transformers/sampling.cc | 4 ++-- onnxruntime/contrib_ops/cuda/transformers/sampling.cc | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc index 4d9ed8fde8cdb..85696d93cf25b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -95,7 +95,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const { *decoder_session_state, *gpt_subgraph_, thread_pool, - cuda_stream_, + ctx->GetComputeStream(), dumper_, parameters, GenerationCpuDeviceHelper::CreateGptInputs, @@ -116,7 +116,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const { *decoder_session_state, *gpt_subgraph_, thread_pool, - cuda_stream_, + ctx->GetComputeStream(), dumper_, parameters, GenerationCpuDeviceHelper::CreateGptInputs, diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling.cc b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc index 736f0e9457553..a758112f6f5e7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc @@ -32,8 +32,6 @@ transformers::CudaTensorConsoleDumper g_cuda_dumper_sampling; Sampling::Sampling(const OpKernelInfo& info) : onnxruntime::contrib::transformers::Sampling(info) { - SetComputeStream(static_cast(info.GetExecutionProvider()->GetComputeStream())); - SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds, GenerationCudaDeviceHelper::TopK, GenerationCudaDeviceHelper::DeviceCopy, From 64bd22c7565b9e5307d7ca4abd085c7fe3fadf5e Mon Sep 17 00:00:00 2001 From: wangyems Date: Mon, 19 Dec 2022 21:11:39 +0000 Subject: [PATCH 43/51] add sampling conversion script --- .../tools/transformers/convert_generation.py | 171 ++++++++++++++---- 1 file changed, 133 insertions(+), 38 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 8163b0cacee01..c589487196512 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -25,6 +25,9 @@ Example 5: convert gpt2 model with greedy search: python convert_generation.py -m gpt2 --output gpt2_greedy_search.onnx --num_beams 1 --num_return_sequences 1 + +Example 6: convert gpt2 model with sampling: + python convert_generation.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 --top_p 0.6 """ import argparse @@ -72,6 +75,7 @@ class GenerationType(Enum): BEAMSEARCH = "beam_search" GREEDYSEARCH = "greedy_search" + SAMPLING = "sampling" def __str__(self): return 
self.value @@ -253,6 +257,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: ) model_group.set_defaults(custom_attention_mask=False) + model_group.add_argument( + "--presence_mask", + required=False, + action="store_true", + help="Presence mask for custom sampling", + ) + model_group.set_defaults(presence_mask=False) + beam_parameters_group = parser.add_argument_group( "Beam search parameters not stored in the output model, for testing parity and performance" ) @@ -287,6 +299,54 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: help="Positive. >1 to penalize and <1 to encourage.", ) + beam_parameters_group.add_argument( + "--temperature", + type=float, + required=False, + default=1.0, + help="The value used to module the next token probabilities.", + ) + + beam_parameters_group.add_argument( + "--top_p", + type=float, + required=False, + default=1.0, + help="Top P for sampling", + ) + + beam_parameters_group.add_argument( + "--filter_value", + type=float, + required=False, + default=-float("Inf"), + help="Filter value for Top P sampling", + ) + + beam_parameters_group.add_argument( + "--min_tokens_to_keep", + type=int, + required=False, + default=1, + help="Minimumber of tokens we keep per batch example in the output.", + ) + + beam_parameters_group.add_argument( + "--presence_penalty", + type=float, + required=False, + default=0.0, + help="presence penalty for custom sampling.", + ) + + beam_parameters_group.add_argument( + "--custom", + type=int, + required=False, + default=0, + help="If 1 customized top P logic is applied", + ) + beam_parameters_group.add_argument( "--vocab_size", type=int, @@ -1095,15 +1155,17 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati args (argparse.Namespace): arguments parsed from command line """ is_gpt2: bool = args.model_type == "gpt2" + is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH + is_sampling: bool = generation_type == GenerationType.SAMPLING - if is_greedysearch: + if is_greedysearch or is_sampling: if not is_gpt2: - raise NotImplementedError("Currently only gpt2 with greedy search is supported") + raise NotImplementedError("Currently only gpt2 with greedy search/sampling is supported") if args.output_sequences_scores: - raise NotImplementedError("output_sequences_scores currently is not supported in greedy search") + raise NotImplementedError("output_sequences_scores currently is not supported in greedy search/sampling") if args.output_token_scores: - raise NotImplementedError("output_token_scores currently is not supported in greedy search") + raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling") if is_gpt2: if args.decoder_onnx and os.path.exists(args.decoder_onnx): @@ -1222,8 +1284,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati else: verify_t5_decoder_subgraph(decoder_model.graph, args.precision) - inputs = ( - [ + inputs = None + if is_beamsearch: + inputs = [ "input_ids", "max_length", "min_length", @@ -1232,14 +1295,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati "length_penalty", "repetition_penalty", ] - if not is_greedysearch - else [ + elif is_greedysearch or is_sampling: + inputs = [ "input_ids", "max_length", "min_length", "repetition_penalty", ] - ) if args.vocab_mask: inputs.append("vocab_mask") @@ -1256,6 +1318,9 @@ def 
convert_generation_model(args: argparse.Namespace, generation_type: Generati else: inputs.append("") + if is_sampling and args.custom and args.presence_mask: + inputs.append("presence_mask") + outputs = ["sequences"] if args.output_sequences_scores: outputs.append("sequences_scores") @@ -1264,40 +1329,60 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" outputs.append("scores") - node = ( - onnx.helper.make_node( + node = None + if is_beamsearch: + node = onnx.helper.make_node( "BeamSearch", inputs=inputs, outputs=outputs, name=f"BeamSearch_{args.model_type}", ) - if not is_greedysearch - else onnx.helper.make_node( + elif is_greedysearch: + node = onnx.helper.make_node( "GreedySearch", inputs=inputs, outputs=outputs, name=f"GreedySearch_{args.model_type}", ) - ) + elif is_sampling: + node = onnx.helper.make_node( + "Sampling", + inputs=inputs, + outputs=outputs, + name=f"Sampling_{args.model_type}", + ) node.domain = "com.microsoft" - attr_to_extend = ( - [ + attr_to_extend = None + if is_beamsearch: + attr_to_extend = [ onnx.helper.make_attribute("eos_token_id", eos_token_id), onnx.helper.make_attribute("pad_token_id", pad_token_id), onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), ] - if not is_greedysearch - else [ + elif is_greedysearch: + attr_to_extend = [ onnx.helper.make_attribute("eos_token_id", eos_token_id), onnx.helper.make_attribute("pad_token_id", pad_token_id), onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), ] - ) + elif is_sampling: + attr_to_extend = [ + onnx.helper.make_attribute("eos_token_id", eos_token_id), + onnx.helper.make_attribute("pad_token_id", pad_token_id), + onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), + onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + onnx.helper.make_attribute("temperature", args.temperature), + onnx.helper.make_attribute("top_p", args.top_p), + onnx.helper.make_attribute("filter_value", args.filter_value), + onnx.helper.make_attribute("min_tokens_to_keep", args.min_tokens_to_keep), + onnx.helper.make_attribute("custom", args.custom), + onnx.helper.make_attribute("presence_penalty", args.presence_penalty), + ] # Explicitly pass in the vocab size via an attribute if logits_matmul_weight_padded: @@ -1367,8 +1452,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) - graph_inputs = ( - [ + graph_inputs = None + if is_beamsearch: + graph_inputs = [ input_ids, max_length, min_length, @@ -1377,14 +1463,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati length_penalty, repetition_penalty, ] - if not is_greedysearch - else [ + elif is_greedysearch or is_sampling: + graph_inputs = [ input_ids, max_length, min_length, repetition_penalty, ] - ) if args.vocab_mask: vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size]) @@ -1402,37 +1487,43 @@ def 
convert_generation_model(args: argparse.Namespace, generation_type: Generati ) graph_inputs.append(attention_mask) + if args.custom and args.presence_mask: + presence_mask = onnx.helper.make_tensor_value_info( + "presence_mask", TensorProto.INT32, ["batch_size", vocab_size] + ) + graph_inputs.append(presence_mask) + # graph outputs - sequences = ( - onnx.helper.make_tensor_value_info( + sequences = None + if is_beamsearch: + sequences = onnx.helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"], ) - if not is_greedysearch - else onnx.helper.make_tensor_value_info( + elif is_greedysearch or is_sampling: + sequences = ( + onnx.helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "max_length"], ) ) - sequences_scores = onnx.helper.make_tensor_value_info( - "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] - ) - - scores = onnx.helper.make_tensor_value_info( - "scores", - TensorProto.FLOAT, - ["max_length - sequence_length", "batch_size", "num_beams", vocab_size], - ) - graph_outputs = [sequences] if args.output_sequences_scores: + sequences_scores = onnx.helper.make_tensor_value_info( + "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] + ) graph_outputs.append(sequences_scores) if args.output_token_scores: + scores = onnx.helper.make_tensor_value_info( + "scores", + TensorProto.FLOAT, + ["max_length - sequence_length", "batch_size", "num_beams", vocab_size], + ) graph_outputs.append(scores) new_graph = onnx.helper.make_graph( @@ -1971,6 +2062,10 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None is_greedy = args.num_beams == 1 and args.num_return_sequences == 1 if args.model_type == "gpt2" and is_greedy: + if args.top_p > 0.0 and args.top_p < 1.0: + convert_generation_model(args, GenerationType.SAMPLING) + logger.info("test sampling model is not implemented yet") + return convert_generation_model(args, GenerationType.GREEDYSEARCH) else: convert_generation_model(args) From a74b76b84af13d9366c0bedb4f1633b9580d4897 Mon Sep 17 00:00:00 2001 From: wangyems Date: Mon, 19 Dec 2022 23:04:06 +0000 Subject: [PATCH 44/51] remove dup files --- .../cuda/transformers/beam_search_impl.cu | 286 ------------------ .../cuda/transformers/beam_search_impl.h | 61 ---- 2 files changed, 347 deletions(-) delete mode 100644 onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu delete mode 100644 onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu deleted file mode 100644 index 3b530cef6be8b..0000000000000 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cu_inc/common.cuh" -#include "cub/util_type.cuh" -#include "contrib_ops/cuda/transformers/beam_search_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { -__global__ void InitKernel(float* beam_scores, - int num_beams, - int total_elements) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < total_elements) { - int beam_index = index % num_beams; - beam_scores[index] = beam_index > 0 ? 
static_cast(-1e9) : 0.0f; - } -} - -void LaunchInitKernel( - float* beam_scores, - int batch_size, - int num_beams, - cudaStream_t stream) { - int total_elements = batch_size * num_beams; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - InitKernel<<>>(beam_scores, num_beams, total_elements); -} - -__global__ void NextTokenKernel(const int64_t* next_token_indices, - int32_t* next_indices, - int32_t* next_tokens, - int vocab_size, - int total_elements) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < total_elements) { - next_indices[index] = next_token_indices[index] / vocab_size; - next_tokens[index] = next_token_indices[index] % vocab_size; - } -} - -void LaunchNextTokenKernel(const int64_t* next_token_indices, - int32_t* next_indices, - int32_t* next_tokens, - int batch_size, - int top_k, - int vocab_size, - cudaStream_t stream) { - int total_elements = batch_size * top_k; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - NextTokenKernel<<>>(next_token_indices, - next_indices, - next_tokens, - vocab_size, - total_elements); -} - -template -__global__ void LogitsProcessKernel( - T* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int num_beams, - int vocab_size, - int padded_vocab_size, - int total_elements, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < total_elements) { - int batch_beam_index = index / padded_vocab_size; - int word_id = index % padded_vocab_size; - - if (word_id >= vocab_size) { - // Set any value within the padding region to the lowest value so that it isn't picked - next_token_scores[index] = cub::FpLimits::Lowest(); - } else { - // RepetitionPenaltyLogitsProcessor - if (repetition_penalty != 1.0f) { - int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length; - bool found = false; - for (int i = 0; i < current_sequence_length; i++) { - if (current_sequence[i] == word_id) { - found = true; - break; - } - } - if (found) { - float score = (float)next_token_scores[index]; - next_token_scores[index] = (T)(score < 0 ? 
score * repetition_penalty : score / repetition_penalty); - } - } - - // NoRepeatNGramLogitsProcessor - if (no_repeat_ngram_size > 0 && current_sequence_length >= no_repeat_ngram_size) { - int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length; - bool found = false; - for (int i = no_repeat_ngram_size - 1; i < current_sequence_length; i++) { - if (current_sequence[i] == word_id) { // last token of n-gram matched - found = true; - for (int j = 0; j < no_repeat_ngram_size - 1; j++) { // match the remaining N-1 tokens - if (current_sequence[i - j - 1] != current_sequence[current_sequence_length - 1 - j]) { - found = false; - break; - } - } - if (found) { - break; - } - } - } - - if (found) { - next_token_scores[index] = cub::FpLimits::Lowest(); - return; - } - } - - // VocabMaskLogitsProcessor - if (vocab_mask != nullptr && vocab_mask[word_id] == 0) { - next_token_scores[index] = cub::FpLimits::Lowest(); - return; - } - - // PrefixVocabMaskLogitsProcessor - int batch_id = batch_beam_index / num_beams; - if (prefix_vocab_mask != nullptr && prefix_vocab_mask[batch_id * vocab_size + word_id] == 0) { - next_token_scores[index] = cub::FpLimits::Lowest(); - return; - } - - // MinLengthLogitsProcessor - if (word_id == demote_token_id) { - next_token_scores[index] = cub::FpLimits::Lowest(); - } - } - } -} - -template -void LaunchLogitsProcessKernel( - T* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int batch_size, - int num_beams, - int vocab_size, - int padded_vocab_size, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size, - cudaStream_t stream) { - int total_elements = batch_size * num_beams * padded_vocab_size; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - LogitsProcessKernel<<>>( - next_token_scores, - vocab_mask, - prefix_vocab_mask, - num_beams, - vocab_size, - padded_vocab_size, - total_elements, - demote_token_id, - sequences, - max_sequence_length, - current_sequence_length, - repetition_penalty, - no_repeat_ngram_size); -} - -// Instantiation -template void LaunchLogitsProcessKernel( - float* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int batch_size, - int num_beams, - int vocab_size, - int padded_vocab_size, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size, - cudaStream_t stream); - -template void LaunchLogitsProcessKernel( - half* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int batch_size, - int num_beams, - int vocab_size, - int padded_vocab_size, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size, - cudaStream_t stream); - -__global__ void AddProbsKernel(float* log_probs, - float* cum_log_probs, - const int vocab_size, - const int total_elements) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int batch_beam_index = index / vocab_size; - - if (index < total_elements) - log_probs[index] += cum_log_probs[batch_beam_index]; -} - -template -void LaunchAddProbsKernel(T* log_probs, - T* cum_log_probs, - const int batch_size, - const int num_beams, - const int vocab_size, - cudaStream_t stream) { - int total_elements = batch_size * num_beams * vocab_size; - constexpr int blockSize = 256; 
- const int gridSize = (total_elements + blockSize - 1) / blockSize; - AddProbsKernel<<>>(log_probs, cum_log_probs, vocab_size, total_elements); -} - -template void LaunchAddProbsKernel( - float* log_probs, - float* cum_log_probs, - const int batch_size, - const int num_beams, - const int vocab_size, - cudaStream_t stream); - -template -__global__ void UpdateGptInputsKernel(const T* old_mask_data, - T* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < batch_beam_size * current_length) { - // Update attention mask. - int i = index / current_length; - int j = index % current_length; - mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast(1); - - if (next_positions != nullptr) { - // Update sequence length (or next positions). - if (index < batch_beam_size) { - next_positions[index]++; - } - } - } -} - -void LaunchUpdateGptKernel(const int32_t* old_mask_data, - int32_t* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length, - cudaStream_t stream) { - assert(current_length > 0); - int total_elements = batch_beam_size * current_length; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - UpdateGptInputsKernel<<>>( - old_mask_data, mask_data, next_positions, batch_beam_size, current_length); -} - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h deleted file mode 100644 index 9b122ca797695..0000000000000 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -void LaunchInitKernel( - float* beam_scores, - int batch_size, - int num_beams, - cudaStream_t stream); - -template -void LaunchAddProbsKernel(T* log_probs, - T* cum_log_probs, - const int batch_size, - const int num_beams, - const int vocab_size, - cudaStream_t stream); - -template -void LaunchLogitsProcessKernel( - T* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int batch_size, - int num_beams, - int vocab_size, - int padded_vocab_size, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size, - cudaStream_t stream); - -void LaunchNextTokenKernel(const int64_t* next_token_indices, - int32_t* next_indices, - int32_t* next_tokens, - int batch_size, - int top_k, - int vocab_size, - cudaStream_t stream); - -void LaunchUpdateGptKernel(const int32_t* old_mask_data, - int32_t* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length, - cudaStream_t stream); - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime From 647c9041331074b7f8c4e1b6d9f9ee15df0fc4d0 Mon Sep 17 00:00:00 2001 From: wangyems Date: Mon, 19 Dec 2022 23:14:44 +0000 Subject: [PATCH 45/51] format python --- onnxruntime/python/tools/transformers/convert_generation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 563eab645d6d4..ce830d27cbd69 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1617,13 +1617,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati ["batch_size", "num_return_sequences", "max_length"], ) elif is_greedysearch or is_sampling: - sequences = ( - onnx.helper.make_tensor_value_info( + sequences = onnx.helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "max_length"], ) - ) graph_outputs = [sequences] From 693c4ebb43bbb080e7b6a41af10a2a61716579ab Mon Sep 17 00:00:00 2001 From: wangyems Date: Mon, 19 Dec 2022 23:37:02 +0000 Subject: [PATCH 46/51] fix build --- .../contrib_ops/cuda/transformers/generation_cuda_impl.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 533f70966a961..a8b21f59f5174 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -174,6 +174,7 @@ void LaunchLogitsProcessKernel( int batch_size, int num_beams, int vocab_size, + int padded_vocab_size, int demote_token_id, int32_t* sequences, int max_sequence_length, @@ -193,6 +194,7 @@ void LaunchLogitsProcessKernel( temperature, num_beams, vocab_size, + padded_vocab_size, total_elements, demote_token_id, sequences, @@ -213,6 +215,7 @@ template void LaunchLogitsProcessKernel( int batch_size, int num_beams, int vocab_size, + int padded_vocab_size, int demote_token_id, int32_t* sequences, int max_sequence_length, @@ -231,6 +234,7 @@ template void LaunchLogitsProcessKernel( int batch_size, int num_beams, int vocab_size, + int padded_vocab_size, int demote_token_id, int32_t* sequences, int max_sequence_length, From 
bc71be8752a07a44896693f5fccdb6ce2a22198f Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 20 Dec 2022 03:41:16 +0000 Subject: [PATCH 47/51] support decoder init --- .../cpu/transformers/greedy_search.cc | 3 +- .../contrib_ops/cpu/transformers/sampling.cc | 65 ++++++++++++++----- .../contrib_ops/cpu/transformers/sampling.h | 8 +-- .../core/graph/contrib_ops/contrib_defs.cc | 4 ++ 4 files changed, 59 insertions(+), 21 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc index 330d88c312d7e..a33d03738e6ec 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc @@ -81,8 +81,7 @@ void GreedySearch::Init(const OpKernelInfo& info) { parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size); // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5) - ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt || - parameters_.model_type == IGenerationParameters::kModelTypeT5); + ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt); ONNX_NAMESPACE::GraphProto proto; if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc index 85696d93cf25b..5ff3fa53974ee 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -40,14 +40,26 @@ namespace transformers { void Sampling::Init(const OpKernelInfo& info) { parameters_.ParseFromAttributes(info); + parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size); - // Check model_type 0 (GPT-2) - ORT_ENFORCE(parameters_.model_type == 0); + // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5) + ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt); ONNX_NAMESPACE::GraphProto proto; + if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { + // Make sure the encoder sub-graph attribute is present for the T5 model. + ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); + } + + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { + // Check if the init_decoder sub-graph attribute is present for the GPT2 model. + if (info.GetAttr("init_decoder", &proto).IsOK()) { + has_init_decoder_ = true; + } + } + // Make sure the decoder sub-graph attribute is present for all model types. 
ORT_ENFORCE(info.GetAttr("decoder", &proto).IsOK()); - ORT_IGNORE_RETURN_VALUE(proto); } Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, @@ -56,15 +68,29 @@ Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, const auto& node = Node(); if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2 if (attribute_name == "decoder") { - ORT_ENFORCE(gpt_subgraph_ == nullptr, - "SetupSubgraphExecutionInfo should only be called once for each subgraph."); - gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer()); - ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state)); + ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, + subgraph_session_state, parameters_); + + auto status = res.first; + if (!status.IsOK()) { + return status; + } + + gpt_subgraph_ = std::move(res.second); decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); - parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size, - gpt_subgraph_->num_heads, - gpt_subgraph_->head_size, - gpt_subgraph_->num_layers); + } else if (attribute_name == "init_decoder") { + ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, + subgraph_session_state, parameters_); + + auto status = res.first; + if (!status.IsOK()) { + return status; + } + + init_run_gpt_subgraph_ = std::move(res.second); + init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager(); } } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5 ORT_THROW("Not Implemented"); @@ -80,6 +106,15 @@ Status Sampling::Compute(OpKernelContext* ctx) const { ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute."); ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); + auto* init_run_decoder_session_state = ctx_internal->SubgraphSessionState("init_decoder"); + if (has_init_decoder_) { + ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute."); + ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); + ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ + && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_, + "past_present_share_buffer mode must be same for init decoder and decoder subgraphes"); + } + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); // make a copy since we will update the parameters based on inputs later @@ -90,8 +125,8 @@ Status Sampling::Compute(OpKernelContext* ctx) const { if (!gpt_subgraph_->IsOutputFloat16()) { GreedySearchGpt impl{ *ctx_internal, - nullptr, // init decoder - nullptr, + has_init_decoder_ ? init_run_decoder_session_state : nullptr, + has_init_decoder_ ? 
init_run_gpt_subgraph_.get() : nullptr, *decoder_session_state, *gpt_subgraph_, thread_pool, @@ -111,8 +146,8 @@ Status Sampling::Compute(OpKernelContext* ctx) const { } else { GreedySearchGpt impl{ *ctx_internal, - nullptr, // init decoder - nullptr, + has_init_decoder_ ? init_run_decoder_session_state : nullptr, + has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr, *decoder_session_state, *gpt_subgraph_, thread_pool, diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.h b/onnxruntime/contrib_ops/cpu/transformers/sampling.h index fb12f1dbc3db2..ea57ce15e2136 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.h @@ -24,7 +24,6 @@ class Sampling : public IControlFlowKernel { explicit Sampling(const OpKernelInfo& info) : IControlFlowKernel(info), decoder_feeds_fetches_manager_(nullptr), - cuda_stream_(nullptr), dumper_(nullptr) { Init(info); } @@ -38,7 +37,6 @@ class Sampling : public IControlFlowKernel { const SessionState& subgraph_session_state) override; protected: - void SetComputeStream(void* stream) { cuda_stream_ = stream; } void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; } // device helpers that is same for both GPT and encoder-decoder models. @@ -87,15 +85,17 @@ class Sampling : public IControlFlowKernel { //------------------------------------------------------------ // Subgraph and FeedsFetchesManager re-used for each subgraph execution. //------------------------------------------------------------ + std::unique_ptr init_run_gpt_subgraph_; std::unique_ptr gpt_subgraph_; FeedsFetchesManager* decoder_feeds_fetches_manager_; - - void* cuda_stream_; + FeedsFetchesManager* init_run_decoder_feeds_fetches_manager_; IConsoleDumper* dumper_; SamplingParameters parameters_; + + bool has_init_decoder_ = false; }; } // namespace transformers diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index b2db7ec9abf92..eb103b4af5b00 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1136,6 +1136,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Attr("custom", "absence penalty for sampling", AttributeProto::INT, static_cast(0)) .Attr("model_type", "model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast(0)) .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("init_decoder", + "The subgraph for the first decoding run. It will be called once before `decoder` subgraph. " + "This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs", + AttributeProto::GRAPH, OPTIONAL_VALUE) .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") .Input(1, "max_length", "The maximum length of the sequence to be generated. 
Shape is (1)", "I") From 26b88db0668bd9a42405ee4f49a43dca8a6fcce6 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 20 Dec 2022 20:56:13 +0000 Subject: [PATCH 48/51] refine contrib_defs.cc --- .../core/graph/contrib_ops/contrib_defs.cc | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index eb103b4af5b00..15f8599f52534 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1128,19 +1128,25 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) - .Attr("temperature", "temperature for sampling", AttributeProto::FLOAT, 1.0f) - .Attr("top_p", "top_p for sampling", AttributeProto::FLOAT, 0.0f) - .Attr("filter_value", "filter value for top_p", AttributeProto::FLOAT, -1e20f) - .Attr("min_tokens_to_keep", "min_tokens_to_keep", AttributeProto::INT, static_cast(0)) - .Attr("presence_penalty", "presence penalty for sampling", AttributeProto::FLOAT, 0.0f) - .Attr("custom", "absence penalty for sampling", AttributeProto::INT, static_cast(0)) - .Attr("model_type", "model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast(0)) + .Attr("temperature", "The value used to module the next token probabilities.", AttributeProto::FLOAT, 1.0f) + .Attr("top_p", + "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.", + AttributeProto::FLOAT, 0.0f) + .Attr("filter_value", "All filtered values will be set to this float value.", AttributeProto::FLOAT, -1e20f) + .Attr("min_tokens_to_keep", "Minimumber of tokens we keep per batch example in the output.", AttributeProto::INT, static_cast(0)) + .Attr("presence_penalty", "Presence penalty for custom sampling", AttributeProto::FLOAT, 0.0f) + .Attr("custom", "If 1 custom sampling logic", AttributeProto::INT, static_cast(0)) + .Attr("model_type", "Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast(0)) .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) .Attr("init_decoder", "The subgraph for the first decoding run. It will be called once before `decoder` subgraph. " "This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs", AttributeProto::GRAPH, OPTIONAL_VALUE) .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) + .Attr("vocab_size", + "Size of the vocabulary. " + "If not provided, it will be inferred from the decoder subgraph's output shape", + AttributeProto::INT, static_cast(-1)) .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. 
Shape is (1)", "I", OpSchema::Optional) @@ -1148,9 +1154,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) - .Input(7, "presence_mask", "presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) + .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") - .Output(1, "filtered_logits", "filtered logits as input to the mutinomial function . Shape is (batch_size, vocab_size)", "T", OpSchema::Optional) + .Output(1, "filtered_logits", "Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { From 78f36f5fc552595a919a41ac499fc56e119a9f03 Mon Sep 17 00:00:00 2001 From: wangyems Date: Tue, 20 Dec 2022 22:56:44 +0000 Subject: [PATCH 49/51] update docs --- docs/ContribOperators.md | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 444d4a60bbc77..601db1695969b 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3823,7 +3823,7 @@ This version of the operator has been available since version 1 of the 'com.micr
custom : int
-
absence penalty for sampling
+
If set to 1, the custom sampling logic is applied
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_start_token_id : int
@@ -3833,21 +3833,25 @@ This version of the operator has been available since version 1 of the 'com.micr
eos_token_id : int (required)
The id of the end-of-sequence token
filter_value : float
-
filter value for top_p
+
All filtered values will be set to this float value.
+
init_decoder : graph
+
The subgraph for the first decoding run. It will be called once before `decoder` subgraph. This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs
min_tokens_to_keep : int
-
min_tokens_to_keep
+
Minimum number of tokens we keep per batch example in the output.
model_type : int
-
model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart
+
Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart
no_repeat_ngram_size : int
no repeat ngrams size
pad_token_id : int (required)
The id of the padding token
presence_penalty : float
-
presence penalty for sampling
+
Presence penalty for custom sampling
temperature : float
-
temperature for sampling
+
The value used to modulate the next token probabilities.
top_p : float
-
top_p for sampling
+
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
+
vocab_size : int
+
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
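
Taken together, `temperature`, `top_p`, `filter_value`, and `min_tokens_to_keep` describe the usual top-p (nucleus) filtering step. The sketch below is a minimal NumPy illustration of that behavior, not the operator's implementation; the function and variable names are invented for the example, and the `min_tokens_to_keep` handling is simplified.

```python
import numpy as np

def top_p_filter(logits, temperature=1.0, top_p=0.9, filter_value=-1e20, min_tokens_to_keep=1):
    """Keep the smallest set of most probable tokens whose probabilities sum to >= top_p."""
    scaled = logits / temperature                      # temperature scaling
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()                               # softmax

    order = np.argsort(probs)[::-1]                    # most probable first
    cumulative = np.cumsum(probs[order])

    # First position where the cumulative mass reaches top_p; keep everything up to it.
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    cutoff = max(cutoff, min_tokens_to_keep)

    filtered = scaled.copy()
    filtered[order[cutoff:]] = filter_value            # mask everything outside the nucleus
    return filtered

# Example: a tiny vocabulary of 5 tokens; the two least probable ones are filtered out.
print(top_p_filter(np.array([2.0, 1.0, 0.5, 0.1, -1.0]), top_p=0.8))
```

Tokens assigned `filter_value` receive a score so low that they are effectively excluded from the subsequent multinomial draw.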
#### Inputs (2 - 8) @@ -3868,7 +3872,7 @@ This version of the operator has been available since version 1 of the 'com.micr
attention_mask (optional) : I
Custom attention mask. Shape is (batch_size, sequence_length)
presence_mask (optional) : I
-
presence penalty mask. Shape is (batch_size, vocab_size)
+
Presence penalty mask. Shape is (batch_size, vocab_size)
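
When `custom` is 1, the optional `presence_mask` above drives the presence penalty: in the usual formulation, any token flagged in the mask has `presence_penalty` subtracted from its score before sampling. A minimal sketch under that assumption (array and function names are illustrative only):

```python
import numpy as np

def apply_presence_penalty(next_token_scores, presence_mask, presence_penalty):
    """Subtract presence_penalty from every token flagged in presence_mask.

    next_token_scores: (batch_size, vocab_size) float scores
    presence_mask:     (batch_size, vocab_size) 0/1 mask, 1 = token already present
    """
    return next_token_scores - presence_mask * presence_penalty

# Example with batch_size=1, vocab_size=4: token 2 was already generated.
scores = np.array([[1.0, 0.5, 2.0, 0.1]])
mask = np.array([[0, 0, 1, 0]])
print(apply_presence_penalty(scores, mask, presence_penalty=0.7))  # token 2 drops to 1.3
```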
#### Outputs (1 - 2) @@ -3877,7 +3881,7 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences : I
Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)
filtered_logits (optional) : T
-
filtered logits as input to the mutinomial function . Shape is (batch_size, vocab_size)
+
Filtered logits as input to the multinomial function for debugging purposes. Shape is (batch_size, vocab_size)
#### Type Constraints

From 92231815cdfb3c44ec4e76092c3161aa211e6bc3 Mon Sep 17 00:00:00 2001
From: wangyems
Date: Tue, 20 Dec 2022 23:27:50 +0000
Subject: [PATCH 50/51] enable padded_vocab_size and decoder_init for sampling op

---
 .../python/tools/transformers/convert_generation.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index ce830d27cbd69..59913245a589c 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1306,7 +1306,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
         args.pad_vocab_size
         and args.precision == Precision.FLOAT16
         and is_gpt2
-        and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH)
+        and (is_beamsearch or is_greedysearch or is_sampling)
     ):
         logger.info(
             f"Pad logits MatMul weights for optimal MatMul perf in fp16 on {args.decoder_onnx}. "
@@ -1320,11 +1320,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
     gpt2_init_decoder_generated = False
     gpt2_init_decoder_onnx_path = None
-    if (
-        args.separate_gpt2_decoder_for_init_run
-        and is_gpt2
-        and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH)
-    ):
+    if args.separate_gpt2_decoder_for_init_run and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling):
         logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ")
         gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format(
@@ -2177,7 +2173,7 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None
     if args.model_type == "gpt2" and is_greedy:
         if args.top_p > 0.0 and args.top_p < 1.0:
             convert_generation_model(args, GenerationType.SAMPLING)
-            logger.info("test sampling model is not implemented yet")
+            logger.info("The test for gpt2_sampling onnx model is not implemented yet")
             return
         convert_generation_model(args, GenerationType.GREEDYSEARCH)
     else:

From f0383ab3ad27f3b48bb9ae52b5051592c6ebd243 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Thu, 22 Dec 2022 00:31:59 +0000
Subject: [PATCH 51/51] review comments

---
 .../contrib_ops/cpu/transformers/sampling.cc  |  4 +--
 .../cuda/transformers/generation_cuda_impl.cu | 33 +++++--------------
 2 files changed, 10 insertions(+), 27 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc
index 5ff3fa53974ee..a9b2db40e97df 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc
@@ -142,7 +142,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const {
         update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds};
     ORT_RETURN_IF_ERROR(impl.Initialize());
-    return impl.Execute(nullptr, *decoder_feeds_fetches_manager_);
+    return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
   } else {
     GreedySearchGpt impl{
         *ctx_internal,
@@ -163,7 +163,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const {
         update_gpt_feeds_fp16_func_};
     ORT_RETURN_IF_ERROR(impl.Initialize());
-    return impl.Execute(nullptr, *decoder_feeds_fetches_manager_);
+    return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
   }
 }

diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
index a8b21f59f5174..f35ef8a40ba49 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -524,6 +524,7 @@ __global__ void FilterLogitsKernel(float* d_sorted_logits_in,
   int count = vocab_idx;
   float sum = 0.0f;
+  // TODO: Optimization needed. e.g. use CUB::SCAN() for cumulative probabilities.
   while (count >= 0) {
     sum += d_sorted_logits_in[start_index];
     ++start_index;
@@ -596,7 +597,7 @@ template void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
 // Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu
-template
+template
 __global__ void sampleMultinomialOnce(int64_t* dest,
                                       int distributions,
                                       int categories,
@@ -605,9 +606,6 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
                                       int stride_dist,        // dist->stride(0)
                                       int stride_categories,  // dist->stride(1)
                                       int* d_presence_mask) {
-  using BlockReduce = cub::BlockReduce;
-  __shared__ typename BlockReduce::TempStorage tmp_storage;
-
   extern __shared__ unsigned char my_smem[];
   __shared__ bool found;
   __shared__ unsigned foundPos;
@@ -617,19 +615,9 @@ __global__ void sampleMultinomialOnce(int64_t* dest,
   for (int curDist = blockIdx.x;
        curDist < distributions;
        curDist += gridDim.x) {
-    // Each block handles one distribution
-    // First pass, find the total sum of the distribution
-    accscalar_t sum = accZero;
-    scalar_t val;
-    for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
-      val = dist[curDist * stride_dist + cat * stride_categories];
-      // CUDA_KERNEL_ASSERT(!at::_isnan(val));
-      // CUDA_KERNEL_ASSERT(!_isinf(val));
-      // CUDA_KERNEL_ASSERT(!(val < zero));
-      sum = sum + static_cast(val);
-    }
-    // threadIdx.x == 0 has the sum value from this
-    sum = BlockReduce(tmp_storage).Reduce(sum, cub::Sum());  // sum = cuda_utils::BlockReduceSum(sum, smem);
+    // Assume sum = 1 in Top P sampling as the input is softmaxed.
+    accscalar_t sum = 1;
+
     // Broadcast sum and sample value
     if (threadIdx.x == 0) {
       // Make sure the sum of our distribution didn't overflow
@@ -748,24 +736,19 @@ void TorchMultinomialKernelLauncher(float* d_input,
   int requiredThreads = std::min(maxThreads, requiredWarps * warp_size);
   int requiredShared = requiredThreads * sizeof(float);
-  // bugbug: randomize d_sampled
   dim3 block(requiredThreads);
   dim3 grid(std::min(batch_size, numSM * 4));
-  if (block.x == 1024) {
-    const int block_size = 1024;
-    sampleMultinomialOnce
+  sampleMultinomialOnce
       <<>>(d_output,
             batch_size,
            vocab_size,
            d_sampled,
            d_input,
            vocab_size,
-           batch_size,
+           1,
            d_presence_mask);
-  } else {
-    printf("Please add more cases for block size");
-  }
+  }
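
The kernel change above relies on the input distribution already being softmax output, so the per-block sum reduction can be dropped and the normalizing sum assumed to be 1; sampling then reduces to drawing a uniform value and returning the first index whose cumulative probability exceeds it. A small NumPy sketch of that idea (illustrative only, not the CUDA kernel; names are invented for the example):

```python
import numpy as np

def sample_from_softmaxed(probs, rng):
    """Draw one token id per row from probabilities that already sum to 1.

    Because each row is softmax output, no normalizing reduction is needed:
    draw u ~ U[0, 1) and return the first index whose cumulative probability
    exceeds u (inverse-CDF sampling).
    """
    cdf = np.cumsum(probs, axis=-1)              # running cumulative probabilities
    u = rng.random(size=probs.shape[0])          # one uniform draw per row
    return (cdf > u[:, None]).argmax(axis=-1)

# Example: batch of 2 rows over a 4-token vocabulary.
rng = np.random.default_rng(0)
p = np.array([[0.1, 0.2, 0.3, 0.4],
              [0.7, 0.1, 0.1, 0.1]])
print(sample_from_softmaxed(p, rng))
```

On the GPU the same cumulative-probability walk is what the TODO comment above suggests replacing with a CUB scan primitive instead of the per-thread loop.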