diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index ac2130f36c62e..efbdbb52a2507 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -361,7 +361,7 @@ This version of the operator has been available since version 1 of the 'com.micr
The id of the padding token
-#### Inputs (6 - 9) +#### Inputs (6 - 10)
input_ids : I
@@ -382,6 +382,8 @@ This version of the operator has been available since version 1 of the 'com.micr
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
prefix_vocab_mask (optional) : M
+
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
#### Outputs (1 - 3) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 4d885fdb15530..35ac6e0e5b154 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -377,7 +377,7 @@ Do not modify directly.* |**Operator Domain:** *com.microsoft*|||| |Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| -|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* temperature:**T**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| +|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* temperature:**T**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| |BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)| |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 9925501b1de62..ee33ac7850c39 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -153,12 +153,14 @@ class BeamSearchImpl { Status GenerateNextToken(const OrtValue& logits, gsl::span& beam_next_tokens, gsl::span& beam_indices, - BeamSearchState& beam_state); + BeamSearchState& beam_state, + int counter); // Calculate scores from logits, then apply filtering and select next token for each beam. Status ProcessLogits(const OrtValue& logits, // logits output of subgraph BeamSearchState& beam_state, - AllocatorPtr& allocator); + AllocatorPtr& allocator, + int counter); OpKernelContextInternal& context_; @@ -292,6 +294,30 @@ Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) { parameters_->vocab_mask = vocab_mask->DataAsSpan(); } + const Tensor* prefix_vocab_mask = context.Input(9); + if (prefix_vocab_mask != nullptr) { + // prefix_vocab_mask is optional + const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims(); + if (vocab_mask_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ", + vocab_mask_dims.size()); + } + + // prefix_vocab_mask first dimension should be same as the first dimension of input_ids + if (static_cast(vocab_mask_dims[0]) != static_cast(dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size"); + } + + // There is dependency on vocab_size parameter, which shall be set before calling this function. + if (static_cast(vocab_mask_dims[1]) != parameters_->vocab_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ", + vocab_mask_dims[0]); + } + + // store prefix vocab mask in parameters. + parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan(); + } + return Status::OK(); } @@ -346,7 +372,8 @@ template Status BeamSearchImpl::ProcessLogits( const OrtValue& logits, BeamSearchState& beam_state, - AllocatorPtr& allocator) { + AllocatorPtr& allocator, + int counter) { const int64_t batch_beam_size = static_cast(parameters_->BatchBeamSize()); const int& vocab_size = parameters_->vocab_size; @@ -394,7 +421,7 @@ Status BeamSearchImpl::ProcessLogits( #endif // Apply all score processors that updates scores - logits_processors_.Process(&(beam_state.sequences), next_token_scores); + logits_processors_.Process(&(beam_state.sequences), next_token_scores, counter); #ifdef DEBUG_BEAM_SEARCH DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size); @@ -486,9 +513,10 @@ Status BeamSearchImpl::GenerateNextToken( const OrtValue& logits, gsl::span& beam_next_tokens, gsl::span& beam_indices, - BeamSearchState& beam_state) { + BeamSearchState& beam_state, + int counter) { // Process logits to get next token scores - ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_)); + ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_, counter)); gsl::span& beam_scores = beam_scorer_->GetNextScores(); // It is optional to clone beam_scores. Change it to use same buffer also works: @@ -587,7 +615,9 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { #endif int current_length = parameters_->sequence_length; + int iteration_counter = 0; while (current_length < parameters_->max_length) { + iteration_counter++; #ifdef DEBUG_BEAM_SEARCH DumpString("***CurrentLength", std::to_string(current_length), true); #endif @@ -600,7 +630,7 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) { const OrtValue& logits = fetches[0]; gsl::span beam_next_tokens; gsl::span beam_indices; - ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state)); + ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, iteration_counter)); // When all batches are finished, stop earlier to avoid wasting computation. if (beam_scorer_->IsDone()) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 26de2a98408eb..389f24ecdeb66 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -28,7 +28,8 @@ struct BeamSearchParameters { int sequence_length; // deduce from second dimension of input_ids gsl::span vocab_mask; - + gsl::span prefix_vocab_mask; + // Parameters from outputs. bool output_scores; // whether scores existed in output diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index ca8e722bf8c93..6473ad5d14b5d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -9,6 +9,10 @@ namespace onnxruntime { namespace contrib { namespace transformers { +// beam_search_iteration represents the current iteration counter of beam search +// This value is used to apply processors as needed in specific iteration. +static int beam_search_iteration; + template gsl::span NextTokenScores::GetScores(int batch_beam_index) { assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); @@ -146,6 +150,41 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, #endif } +template +PrefixVocabMaskLogitsProcessor::PrefixVocabMaskLogitsProcessor(const gsl::span& prefix_vocab_mask, int batch_size) : prefix_vocab_mask_(prefix_vocab_mask), batch_size_(batch_size) { +} + +template +void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, + NextTokenScores& next_token_scores) { + assert(!prefix_vocab_mask_.empty()); + + if (beam_search_iteration > 1) { + return; + } + // next_token_scores shape (batch_size * num_beams, vocab_size) + int num_beams = next_token_scores.batch_beam_size / batch_size_; + assert(num_beams * batch_size_ == next_token_scores.batch_beam_size); + + // Process prefix vocabulary mask and set tokens with mask value 0 to -inf. + // prefix_vocab_mask shape (batch_szie, vocab_size). + T* p = next_token_scores.scores.data(); + for (int i = 0; i < batch_size_; i++) { + int prefix_vocab_mask_offset = i * next_token_scores.vocab_size; + for (int j = 0; j < num_beams; j++) { + for (int k = 0; k < next_token_scores.vocab_size; k++, p++) { + if (prefix_vocab_mask_[prefix_vocab_mask_offset + k] == 0) { + *p = std::numeric_limits::lowest(); + } + } + } + } + +#ifdef DEBUG_BEAM_SEARCH + DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores.scores); +#endif +} + template void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { processor_list_.clear(); @@ -165,6 +204,11 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { processor_list_.push_back(vocab_mask_processor_.get()); } + if (!parameters.prefix_vocab_mask.empty()) { + prefix_vocab_mask_processor_ = std::make_unique>(parameters.prefix_vocab_mask, parameters.batch_size); + processor_list_.push_back(prefix_vocab_mask_processor_.get()); + } + if (parameters.min_length > 0) { min_length_processor_ = std::make_unique>(parameters.min_length, parameters.eos_token_id); processor_list_.push_back(min_length_processor_.get()); @@ -176,8 +220,10 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { template void LogitsProcessorList::Process(const ISequences* sequences, - gsl::span& next_token_scores) { + gsl::span& next_token_scores, + int counter) { NextTokenScores input_scores = {next_token_scores, batch_beam_size_, vocab_size_}; + beam_search_iteration = counter; for (size_t i = 0; i < processor_list_.size(); i++) { processor_list_[i]->Process(sequences, input_scores); } @@ -188,6 +234,7 @@ template class MinLengthLogitsProcessor; template class RepetitionPenaltyLogitsProcessor; template class NoRepeatNGramLogitsProcessor; template class VocabMaskLogitsProcessor; +template class PrefixVocabMaskLogitsProcessor; template class LogitsProcessorList; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 78fe9acf63bcb..f5985b966ba05 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -76,12 +76,25 @@ class VocabMaskLogitsProcessor : public ILogitsProcessor { gsl::span vocab_mask_; }; +template +class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor { + public: + PrefixVocabMaskLogitsProcessor(const gsl::span& vocab_mask, int batch_size); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + gsl::span prefix_vocab_mask_; + const int batch_size_; +}; + template class LogitsProcessorList { public: LogitsProcessorList() = default ; void Init(const BeamSearchParameters& parameters); - void Process(const ISequences* sequences, gsl::span& next_token_scores); + void Process(const ISequences* sequences, gsl::span& next_token_scores, int counter); private: int batch_beam_size_; @@ -91,6 +104,7 @@ class LogitsProcessorList { std::unique_ptr> repetition_penalty_processor_; std::unique_ptr> no_repeat_ngram_processor_; std::unique_ptr> vocab_mask_processor_; + std::unique_ptr> prefix_vocab_mask_processor_; std::unique_ptr> min_length_processor_; }; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index fcc219799b46a..950dd1760eded 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -695,6 +695,7 @@ void RegisterTextGenerationSchemas() { "T", OpSchema::Optional) .Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) .Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(9, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 3f251fb9ad507..8bb25fd609b15 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -128,6 +128,18 @@ def parse_arguments(argv=None): default=1, help='Positive. >1 to penalize and <1 to encorage.') + beam_search_group.add_argument('--vocab_size', + type=int, + required=False, + default=-1, + help="Vocab_size of the underlying model") + + beam_search_group.add_argument('--prefix_vocab_mask', + required=False, + action='store_true', + help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete") + beam_search_group.set_defaults(prefix_vocab_mask=False) + mixed_precision_option_group = parser.add_argument_group( "mixed precision conversion parameters that works when \"--precision fp16\" is specified") @@ -230,12 +242,18 @@ def convert_model(args): pad_token_id = config.eos_token_id vocab_size = config.vocab_size + # if vocab_size is given in parameters use that. + if args.vocab_size != -1: + vocab_size = args.vocab_size + model = onnx.load(args.gpt2_onnx) model.graph.name = "gpt2 subgraph" inputs = [ "input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty", "repetition_penalty", "vocab_mask" ] + if args.prefix_vocab_mask: + inputs.append("prefix_vocab_mask") outputs = ["sequences"] if args.output_sequences_scores: @@ -273,6 +291,10 @@ def convert_model(args): repetition_penalty, vocab_mask ] + if args.prefix_vocab_mask: + prefix_vocab_mask = helper.make_tensor_value_info('prefix_vocab_mask', TensorProto.INT32, ['batch_size', vocab_size]) + graph_inputs.append(prefix_vocab_mask) + # graph outputs sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32, ['batch_size', 'num_return_sequences', 'max_length']) @@ -301,6 +323,11 @@ def convert_model(args): def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): + + if args.prefix_vocab_mask: + print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face") + return + from transformers import GPT2Tokenizer, GPT2LMHeadModel tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)