diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index ac2130f36c62e..efbdbb52a2507 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -361,7 +361,7 @@ This version of the operator has been available since version 1 of the 'com.micr
The id of the padding token
-#### Inputs (6 - 9)
+#### Inputs (6 - 10)
- input_ids : I
@@ -382,6 +382,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
- vocab_mask (optional) : M
- Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+- prefix_vocab_mask (optional) : M
+- Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
#### Outputs (1 - 3)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 4d885fdb15530..35ac6e0e5b154 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -377,7 +377,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)|
-|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* temperature:**T**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)|
+|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* temperature:**T**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index 9925501b1de62..ee33ac7850c39 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -153,12 +153,14 @@ class BeamSearchImpl {
Status GenerateNextToken(const OrtValue& logits,
gsl::span& beam_next_tokens,
gsl::span& beam_indices,
- BeamSearchState& beam_state);
+ BeamSearchState& beam_state,
+ int counter);
// Calculate scores from logits, then apply filtering and select next token for each beam.
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
BeamSearchState& beam_state,
- AllocatorPtr& allocator);
+ AllocatorPtr& allocator,
+ int counter);
OpKernelContextInternal& context_;
@@ -292,6 +294,30 @@ Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) {
parameters_->vocab_mask = vocab_mask->DataAsSpan();
}
+ const Tensor* prefix_vocab_mask = context.Input(9);
+ if (prefix_vocab_mask != nullptr) {
+ // prefix_vocab_mask is optional
+ const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims();
+ if (vocab_mask_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ",
+ vocab_mask_dims.size());
+ }
+
+ // prefix_vocab_mask first dimension should be same as the first dimension of input_ids
+ if (static_cast(vocab_mask_dims[0]) != static_cast(dims[0])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size");
+ }
+
+ // There is dependency on vocab_size parameter, which shall be set before calling this function.
+ if (static_cast(vocab_mask_dims[1]) != parameters_->vocab_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ",
+ vocab_mask_dims[0]);
+ }
+
+ // store prefix vocab mask in parameters.
+ parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan();
+ }
+
return Status::OK();
}
@@ -346,7 +372,8 @@ template
Status BeamSearchImpl::ProcessLogits(
const OrtValue& logits,
BeamSearchState& beam_state,
- AllocatorPtr& allocator) {
+ AllocatorPtr& allocator,
+ int counter) {
const int64_t batch_beam_size = static_cast(parameters_->BatchBeamSize());
const int& vocab_size = parameters_->vocab_size;
@@ -394,7 +421,7 @@ Status BeamSearchImpl::ProcessLogits(
#endif
// Apply all score processors that updates scores
- logits_processors_.Process(&(beam_state.sequences), next_token_scores);
+ logits_processors_.Process(&(beam_state.sequences), next_token_scores, counter);
#ifdef DEBUG_BEAM_SEARCH
DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
@@ -486,9 +513,10 @@ Status BeamSearchImpl::GenerateNextToken(
const OrtValue& logits,
gsl::span& beam_next_tokens,
gsl::span& beam_indices,
- BeamSearchState& beam_state) {
+ BeamSearchState& beam_state,
+ int counter) {
// Process logits to get next token scores
- ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_));
+ ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_, counter));
gsl::span& beam_scores = beam_scorer_->GetNextScores();
// It is optional to clone beam_scores. Change it to use same buffer also works:
@@ -587,7 +615,9 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) {
#endif
int current_length = parameters_->sequence_length;
+ int iteration_counter = 0;
while (current_length < parameters_->max_length) {
+ iteration_counter++;
#ifdef DEBUG_BEAM_SEARCH
DumpString("***CurrentLength", std::to_string(current_length), true);
#endif
@@ -600,7 +630,7 @@ Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) {
const OrtValue& logits = fetches[0];
gsl::span beam_next_tokens;
gsl::span beam_indices;
- ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state));
+ ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, iteration_counter));
// When all batches are finished, stop earlier to avoid wasting computation.
if (beam_scorer_->IsDone()) {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h
index 26de2a98408eb..389f24ecdeb66 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h
@@ -28,7 +28,8 @@ struct BeamSearchParameters {
int sequence_length; // deduce from second dimension of input_ids
gsl::span vocab_mask;
-
+ gsl::span prefix_vocab_mask;
+
// Parameters from outputs.
bool output_scores; // whether scores existed in output
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
index ca8e722bf8c93..6473ad5d14b5d 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -9,6 +9,10 @@ namespace onnxruntime {
namespace contrib {
namespace transformers {
+// beam_search_iteration represents the current iteration counter of beam search
+// This value is used to apply processors as needed in specific iteration.
+static int beam_search_iteration;
+
template
gsl::span NextTokenScores::GetScores(int batch_beam_index) {
assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size);
@@ -146,6 +150,41 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/,
#endif
}
+template
+PrefixVocabMaskLogitsProcessor::PrefixVocabMaskLogitsProcessor(const gsl::span& prefix_vocab_mask, int batch_size) : prefix_vocab_mask_(prefix_vocab_mask), batch_size_(batch_size) {
+}
+
+template
+void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/,
+ NextTokenScores& next_token_scores) {
+ assert(!prefix_vocab_mask_.empty());
+
+ if (beam_search_iteration > 1) {
+ return;
+ }
+ // next_token_scores shape (batch_size * num_beams, vocab_size)
+ int num_beams = next_token_scores.batch_beam_size / batch_size_;
+ assert(num_beams * batch_size_ == next_token_scores.batch_beam_size);
+
+ // Process prefix vocabulary mask and set tokens with mask value 0 to -inf.
+ // prefix_vocab_mask shape (batch_szie, vocab_size).
+ T* p = next_token_scores.scores.data();
+ for (int i = 0; i < batch_size_; i++) {
+ int prefix_vocab_mask_offset = i * next_token_scores.vocab_size;
+ for (int j = 0; j < num_beams; j++) {
+ for (int k = 0; k < next_token_scores.vocab_size; k++, p++) {
+ if (prefix_vocab_mask_[prefix_vocab_mask_offset + k] == 0) {
+ *p = std::numeric_limits::lowest();
+ }
+ }
+ }
+ }
+
+#ifdef DEBUG_BEAM_SEARCH
+ DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores.scores);
+#endif
+}
+
template
void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
processor_list_.clear();
@@ -165,6 +204,11 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
processor_list_.push_back(vocab_mask_processor_.get());
}
+ if (!parameters.prefix_vocab_mask.empty()) {
+ prefix_vocab_mask_processor_ = std::make_unique>(parameters.prefix_vocab_mask, parameters.batch_size);
+ processor_list_.push_back(prefix_vocab_mask_processor_.get());
+ }
+
if (parameters.min_length > 0) {
min_length_processor_ = std::make_unique>(parameters.min_length, parameters.eos_token_id);
processor_list_.push_back(min_length_processor_.get());
@@ -176,8 +220,10 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
template
void LogitsProcessorList::Process(const ISequences* sequences,
- gsl::span& next_token_scores) {
+ gsl::span& next_token_scores,
+ int counter) {
NextTokenScores input_scores = {next_token_scores, batch_beam_size_, vocab_size_};
+ beam_search_iteration = counter;
for (size_t i = 0; i < processor_list_.size(); i++) {
processor_list_[i]->Process(sequences, input_scores);
}
@@ -188,6 +234,7 @@ template class MinLengthLogitsProcessor;
template class RepetitionPenaltyLogitsProcessor;
template class NoRepeatNGramLogitsProcessor;
template class VocabMaskLogitsProcessor;
+template class PrefixVocabMaskLogitsProcessor;
template class LogitsProcessorList;
} // namespace transformers
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
index 78fe9acf63bcb..f5985b966ba05 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -76,12 +76,25 @@ class VocabMaskLogitsProcessor : public ILogitsProcessor {
gsl::span vocab_mask_;
};
+template
+class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor {
+ public:
+ PrefixVocabMaskLogitsProcessor(const gsl::span& vocab_mask, int batch_size);
+
+ void Process(const ISequences* sequences,
+ NextTokenScores& next_token_scores) override;
+
+ private:
+ gsl::span prefix_vocab_mask_;
+ const int batch_size_;
+};
+
template
class LogitsProcessorList {
public:
LogitsProcessorList() = default ;
void Init(const BeamSearchParameters& parameters);
- void Process(const ISequences* sequences, gsl::span& next_token_scores);
+ void Process(const ISequences* sequences, gsl::span& next_token_scores, int counter);
private:
int batch_beam_size_;
@@ -91,6 +104,7 @@ class LogitsProcessorList {
std::unique_ptr> repetition_penalty_processor_;
std::unique_ptr> no_repeat_ngram_processor_;
std::unique_ptr> vocab_mask_processor_;
+ std::unique_ptr> prefix_vocab_mask_processor_;
std::unique_ptr> min_length_processor_;
};
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index fcc219799b46a..950dd1760eded 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -695,6 +695,7 @@ void RegisterTextGenerationSchemas() {
"T", OpSchema::Optional)
.Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
.Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional)
+ .Input(9, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
.Output(2, "scores",
diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py
index 3f251fb9ad507..8bb25fd609b15 100644
--- a/onnxruntime/python/tools/transformers/convert_beam_search.py
+++ b/onnxruntime/python/tools/transformers/convert_beam_search.py
@@ -128,6 +128,18 @@ def parse_arguments(argv=None):
default=1,
help='Positive. >1 to penalize and <1 to encorage.')
+ beam_search_group.add_argument('--vocab_size',
+ type=int,
+ required=False,
+ default=-1,
+ help="Vocab_size of the underlying model")
+
+ beam_search_group.add_argument('--prefix_vocab_mask',
+ required=False,
+ action='store_true',
+ help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete")
+ beam_search_group.set_defaults(prefix_vocab_mask=False)
+
mixed_precision_option_group = parser.add_argument_group(
"mixed precision conversion parameters that works when \"--precision fp16\" is specified")
@@ -230,12 +242,18 @@ def convert_model(args):
pad_token_id = config.eos_token_id
vocab_size = config.vocab_size
+ # if vocab_size is given in parameters use that.
+ if args.vocab_size != -1:
+ vocab_size = args.vocab_size
+
model = onnx.load(args.gpt2_onnx)
model.graph.name = "gpt2 subgraph"
inputs = [
"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty",
"repetition_penalty", "vocab_mask"
]
+ if args.prefix_vocab_mask:
+ inputs.append("prefix_vocab_mask")
outputs = ["sequences"]
if args.output_sequences_scores:
@@ -273,6 +291,10 @@ def convert_model(args):
repetition_penalty, vocab_mask
]
+ if args.prefix_vocab_mask:
+ prefix_vocab_mask = helper.make_tensor_value_info('prefix_vocab_mask', TensorProto.INT32, ['batch_size', vocab_size])
+ graph_inputs.append(prefix_vocab_mask)
+
# graph outputs
sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32,
['batch_size', 'num_return_sequences', 'max_length'])
@@ -301,6 +323,11 @@ def convert_model(args):
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
+
+ if args.prefix_vocab_mask:
+ print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
+ return
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)