Prefix match in first iteration of beam search OP (microsoft#10231)
* Add BeamSearch op schema

* Add ONNX conversion for beams search

* remove attention_mask and change input order

* add option to run baseline

* add check data type NULL

* applies VerifyNodeAndOpMatch to subgraph

* update input_ids shape

* Add node name for Cast node

* expose API for topk

* parse parameters

* Add beam search scorer

* output results

* fix typo

* use c++ template and format python

* fix build pipeline errors

* symbolic shape infer of input onnx

* output scores

* add kernel def hash

* Handle vocab_mask; move CheckSubgraph

* undo insert_cast_transformer.cc and fusion_utils.py

* fix typo

* fix merge

* update doc

* add repetition penalty

* refactoring: add GptSubgraph class

* move BeamSearchState from .h to .cc file

* adjust logits processor order

* add batch generation example

* fix repetition penalty for dup words in sequence

* Add test

* Add no repeat ngram processor

* refactoring: move logits processor to classes

* fix build warning

* show latency

* use allocator in beam state

* use allocator in sequences

* fix build error

* move next_positions to beam state

* Changes for prefix matching

* removing debugs

* removing more debugs

* clean up

* clean up

* cpu doc updated

* Updated docs

* updated prefix_vocab_mask dimension in convert script

* changes to support bxs prefix_vocab_mask in beamsearchop kernel

* doc update

* OperatorKernels.md updated

* matching docs from artifacts

* minor change in logits processor

* Addressing comments

* Updated the prefix vocab mask usage properly

Co-authored-by: Tianlei Wu <[email protected]>
viboga and tianleiwu authored Feb 2, 2022
1 parent 1aa0789 commit ad9d2e2
Showing 8 changed files with 134 additions and 12 deletions.
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -361,7 +361,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>The id of the padding token</dd>
</dl>

#### Inputs (6 - 9)
#### Inputs (6 - 10)

<dl>
<dt><tt>input_ids</tt> : I</dt>
@@ -382,6 +382,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary. Words masked with 0 are not allowed to be generated, and 1 means allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary for the first step. Words masked with 0 are not allowed to be generated, and 1 means allowed. Shape is (batch_size, vocab_size)</dd>
</dl>

#### Outputs (1 - 3)
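To make the new `prefix_vocab_mask` input documented above concrete, here is a minimal sketch of how a caller might build it for the first generation step. The vocabulary size, token ids, and variable names are illustrative assumptions, not values taken from this commit.

```python
import numpy as np

# Assumed values for illustration only.
vocab_size = 50257                                # e.g. GPT-2 vocabulary size
batch_size = 2
allowed_first_tokens = [[464, 632], [1212]]       # hypothetical token ids per query

# One row per batch entry: 1 = token allowed at the first step, 0 = blocked.
prefix_vocab_mask = np.zeros((batch_size, vocab_size), dtype=np.int32)
for b, token_ids in enumerate(allowed_first_tokens):
    prefix_vocab_mask[b, token_ids] = 1
```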
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -377,7 +377,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* temperature:**T**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* temperature:**T**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
44 changes: 37 additions & 7 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -153,12 +153,14 @@ class BeamSearchImpl {
Status GenerateNextToken(const OrtValue& logits,
gsl::span<int64_t>& beam_next_tokens,
gsl::span<int64_t>& beam_indices,
BeamSearchState<T>& beam_state);
BeamSearchState<T>& beam_state,
int counter);

// Calculate scores from logits, then apply filtering and select next token for each beam.
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
BeamSearchState<T>& beam_state,
AllocatorPtr& allocator);
AllocatorPtr& allocator,
int counter);

OpKernelContextInternal& context_;

@@ -292,6 +294,30 @@ Status BeamSearchImpl<T>::CheckInputs(const OpKernelContextInternal& context) {
parameters_->vocab_mask = vocab_mask->DataAsSpan<int32_t>();
}

const Tensor* prefix_vocab_mask = context.Input<Tensor>(9);
if (prefix_vocab_mask != nullptr) {
// prefix_vocab_mask is optional
const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims();
if (vocab_mask_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ",
vocab_mask_dims.size());
}

// The first dimension of prefix_vocab_mask should match the first dimension of input_ids
if (static_cast<int>(vocab_mask_dims[0]) != static_cast<int>(dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size");
}

// There is dependency on vocab_size parameter, which shall be set before calling this function.
if (static_cast<int>(vocab_mask_dims[1]) != parameters_->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ",
vocab_mask_dims[0]);
}

// store prefix vocab mask in parameters.
parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan<int32_t>();
}

return Status::OK();
}

@@ -346,7 +372,8 @@ template <typename T>
Status BeamSearchImpl<T>::ProcessLogits(
const OrtValue& logits,
BeamSearchState<T>& beam_state,
AllocatorPtr& allocator) {
AllocatorPtr& allocator,
int counter) {
const int64_t batch_beam_size = static_cast<int64_t>(parameters_->BatchBeamSize());
const int& vocab_size = parameters_->vocab_size;

@@ -394,7 +421,7 @@ Status BeamSearchImpl<T>::ProcessLogits(
#endif

// Apply all score processors that update scores
logits_processors_.Process(&(beam_state.sequences), next_token_scores);
logits_processors_.Process(&(beam_state.sequences), next_token_scores, counter);

#ifdef DEBUG_BEAM_SEARCH
DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
@@ -486,9 +513,10 @@ Status BeamSearchImpl<T>::GenerateNextToken(
const OrtValue& logits,
gsl::span<int64_t>& beam_next_tokens,
gsl::span<int64_t>& beam_indices,
BeamSearchState<T>& beam_state) {
BeamSearchState<T>& beam_state,
int counter) {
// Process logits to get next token scores
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_));
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_, counter));

gsl::span<T>& beam_scores = beam_scorer_->GetNextScores();
// It is optional to clone beam_scores. Change it to use same buffer also works:
@@ -587,7 +615,9 @@ Status BeamSearchImpl<T>::Execute(const FeedsFetchesManager& ffm) {
#endif

int current_length = parameters_->sequence_length;
int iteration_counter = 0;
while (current_length < parameters_->max_length) {
iteration_counter++;
#ifdef DEBUG_BEAM_SEARCH
DumpString("***CurrentLength", std::to_string(current_length), true);
#endif
@@ -600,7 +630,7 @@
const OrtValue& logits = fetches[0];
gsl::span<int64_t> beam_next_tokens;
gsl::span<int64_t> beam_indices;
ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state));
ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, iteration_counter));

// When all batches are finished, stop earlier to avoid wasting computation.
if (beam_scorer_->IsDone()) {
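The `CheckInputs` additions above enforce three constraints on the new input: it must be 2-D, its first dimension must equal the batch size of `input_ids`, and its second dimension must equal `vocab_size`. A hedged Python restatement of those checks (function and variable names are illustrative, not part of the commit):

```python
import numpy as np

def check_prefix_vocab_mask(prefix_vocab_mask, input_ids, vocab_size):
    # prefix_vocab_mask must be 2-dimensional.
    assert prefix_vocab_mask.ndim == 2, "prefix_vocab_mask is expected to have 2 dimensions"
    # Its first dimension must match the batch size of input_ids.
    assert prefix_vocab_mask.shape[0] == input_ids.shape[0], \
        "input_ids and prefix_vocab_mask must have the same batch_size"
    # Its second dimension must equal vocab_size.
    assert prefix_vocab_mask.shape[1] == vocab_size, \
        "second dimension of prefix_vocab_mask must equal vocab_size"

check_prefix_vocab_mask(np.ones((2, 1000), dtype=np.int32),
                        np.zeros((2, 8), dtype=np.int32),
                        vocab_size=1000)
```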
@@ -28,7 +28,8 @@ struct BeamSearchParameters {
int sequence_length; // deduce from second dimension of input_ids

gsl::span<const int32_t> vocab_mask;

gsl::span<const int32_t> prefix_vocab_mask;

// Parameters from outputs.
bool output_scores; // whether scores existed in output

49 changes: 48 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -9,6 +9,10 @@ namespace onnxruntime {
namespace contrib {
namespace transformers {

// beam_search_iteration holds the current iteration counter of beam search.
// Processors use it to decide whether they should act in a given iteration (e.g. only the first).
static int beam_search_iteration;

template <typename T>
gsl::span<T> NextTokenScores<T>::GetScores(int batch_beam_index) {
assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size);
@@ -146,6 +150,41 @@ void VocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
#endif
}

template <typename T>
PrefixVocabMaskLogitsProcessor<T>::PrefixVocabMaskLogitsProcessor(const gsl::span<const int32_t>& prefix_vocab_mask, int batch_size) : prefix_vocab_mask_(prefix_vocab_mask), batch_size_(batch_size) {
}

template <typename T>
void PrefixVocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
NextTokenScores<T>& next_token_scores) {
assert(!prefix_vocab_mask_.empty());

if (beam_search_iteration > 1) {
return;
}
// next_token_scores shape (batch_size * num_beams, vocab_size)
int num_beams = next_token_scores.batch_beam_size / batch_size_;
assert(num_beams * batch_size_ == next_token_scores.batch_beam_size);

// Process prefix vocabulary mask and set tokens with mask value 0 to -inf.
// prefix_vocab_mask shape (batch_size, vocab_size).
T* p = next_token_scores.scores.data();
for (int i = 0; i < batch_size_; i++) {
int prefix_vocab_mask_offset = i * next_token_scores.vocab_size;
for (int j = 0; j < num_beams; j++) {
for (int k = 0; k < next_token_scores.vocab_size; k++, p++) {
if (prefix_vocab_mask_[prefix_vocab_mask_offset + k] == 0) {
*p = std::numeric_limits<T>::lowest();
}
}
}
}

#ifdef DEBUG_BEAM_SEARCH
DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores.scores);
#endif
}

template <typename T>
void LogitsProcessorList<T>::Init(const BeamSearchParameters& parameters) {
processor_list_.clear();
@@ -165,6 +204,11 @@ void LogitsProcessorList<T>::Init(const BeamSearchParameters& parameters) {
processor_list_.push_back(vocab_mask_processor_.get());
}

if (!parameters.prefix_vocab_mask.empty()) {
prefix_vocab_mask_processor_ = std::make_unique<PrefixVocabMaskLogitsProcessor<T>>(parameters.prefix_vocab_mask, parameters.batch_size);
processor_list_.push_back(prefix_vocab_mask_processor_.get());
}

if (parameters.min_length > 0) {
min_length_processor_ = std::make_unique<MinLengthLogitsProcessor<T>>(parameters.min_length, parameters.eos_token_id);
processor_list_.push_back(min_length_processor_.get());
Expand All @@ -176,8 +220,10 @@ void LogitsProcessorList<T>::Init(const BeamSearchParameters& parameters) {

template <typename T>
void LogitsProcessorList<T>::Process(const ISequences* sequences,
gsl::span<T>& next_token_scores) {
gsl::span<T>& next_token_scores,
int counter) {
NextTokenScores<T> input_scores = {next_token_scores, batch_beam_size_, vocab_size_};
beam_search_iteration = counter;
for (size_t i = 0; i < processor_list_.size(); i++) {
processor_list_[i]->Process(sequences, input_scores);
}
@@ -188,6 +234,7 @@
template class RepetitionPenaltyLogitsProcessor<float>;
template class NoRepeatNGramLogitsProcessor<float>;
template class VocabMaskLogitsProcessor<float>;
template class PrefixVocabMaskLogitsProcessor<float>;
template class LogitsProcessorList<float>;

} // namespace transformers
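For readers following the nested loops in `PrefixVocabMaskLogitsProcessor::Process`, this is a vectorized numpy sketch of the same masking (shapes and values are illustrative): on the first iteration only, every beam of batch entry `b` gets the lowest representable score for any token whose mask value in row `b` is 0.

```python
import numpy as np

batch_size, num_beams, vocab_size = 2, 4, 8
next_token_scores = np.random.randn(batch_size * num_beams, vocab_size).astype(np.float32)
prefix_vocab_mask = np.random.randint(0, 2, size=(batch_size, vocab_size), dtype=np.int32)

iteration_counter = 1                      # the processor is a no-op after the first iteration
if iteration_counter == 1:
    # Each batch row of the mask applies to all beams of that batch entry.
    expanded = np.repeat(prefix_vocab_mask, num_beams, axis=0)   # (batch * beams, vocab)
    next_token_scores[expanded == 0] = np.finfo(np.float32).min
```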
16 changes: 15 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -76,12 +76,25 @@ class VocabMaskLogitsProcessor : public ILogitsProcessor<T> {
gsl::span<const int32_t> vocab_mask_;
};

template <typename T>
class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor<T> {
public:
PrefixVocabMaskLogitsProcessor(const gsl::span<const int32_t>& prefix_vocab_mask, int batch_size);

void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override;

private:
gsl::span<const int32_t> prefix_vocab_mask_;
const int batch_size_;
};

template <typename T>
class LogitsProcessorList {
public:
LogitsProcessorList() = default;
void Init(const BeamSearchParameters& parameters);
void Process(const ISequences* sequences, gsl::span<T>& next_token_scores);
void Process(const ISequences* sequences, gsl::span<T>& next_token_scores, int counter);

private:
int batch_beam_size_;
@@ -91,6 +104,7 @@ class LogitsProcessorList {
std::unique_ptr<RepetitionPenaltyLogitsProcessor<T>> repetition_penalty_processor_;
std::unique_ptr<NoRepeatNGramLogitsProcessor<T>> no_repeat_ngram_processor_;
std::unique_ptr<VocabMaskLogitsProcessor<T>> vocab_mask_processor_;
std::unique_ptr<PrefixVocabMaskLogitsProcessor<T>> prefix_vocab_mask_processor_;
std::unique_ptr<MinLengthLogitsProcessor<T>> min_length_processor_;
};

1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -695,6 +695,7 @@ void RegisterTextGenerationSchemas() {
"T", OpSchema::Optional)
.Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
.Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional)
.Input(9, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
.Output(2, "scores",
27 changes: 27 additions & 0 deletions onnxruntime/python/tools/transformers/convert_beam_search.py
@@ -128,6 +128,18 @@ def parse_arguments(argv=None):
default=1,
help='Positive. >1 to penalize and <1 to encourage.')

beam_search_group.add_argument('--vocab_size',
type=int,
required=False,
default=-1,
help="Vocab_size of the underlying model")

beam_search_group.add_argument('--prefix_vocab_mask',
required=False,
action='store_true',
help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete")
beam_search_group.set_defaults(prefix_vocab_mask=False)

mixed_precision_option_group = parser.add_argument_group(
"mixed precision conversion parameters that works when \"--precision fp16\" is specified")

@@ -230,12 +242,18 @@ def convert_model(args):
pad_token_id = config.eos_token_id
vocab_size = config.vocab_size

# If vocab_size is given as a command-line argument, use it.
if args.vocab_size != -1:
vocab_size = args.vocab_size

model = onnx.load(args.gpt2_onnx)
model.graph.name = "gpt2 subgraph"
inputs = [
"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty",
"repetition_penalty", "vocab_mask"
]
if args.prefix_vocab_mask:
inputs.append("prefix_vocab_mask")

outputs = ["sequences"]
if args.output_sequences_scores:
@@ -273,6 +291,10 @@ repetition_penalty, vocab_mask
repetition_penalty, vocab_mask
]

if args.prefix_vocab_mask:
prefix_vocab_mask = helper.make_tensor_value_info('prefix_vocab_mask', TensorProto.INT32, ['batch_size', vocab_size])
graph_inputs.append(prefix_vocab_mask)

# graph outputs
sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32,
['batch_size', 'num_return_sequences', 'max_length'])
@@ -301,6 +323,11 @@


def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):

if args.prefix_vocab_mask:
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
return

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
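End to end, the new flag and input would be used roughly as follows. This is a sketch only: flag names other than `--prefix_vocab_mask`/`--vocab_size`, the model file names, and the exact dtypes are assumptions inferred from this diff, not verified against the full script.

```python
# Hypothetical conversion step (other flags assumed):
#   python convert_beam_search.py --gpt2_onnx gpt2.onnx --prefix_vocab_mask --vocab_size 50257

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("gpt2_beam_search.onnx")   # assumed output path of the converter

batch_size, vocab_size = 2, 50257
feeds = {
    "input_ids": np.zeros((batch_size, 8), dtype=np.int32),
    "max_length": np.array([20], dtype=np.int32),
    "min_length": np.array([1], dtype=np.int32),
    "num_beams": np.array([4], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "temperature": np.array([1.0], dtype=np.float32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "vocab_mask": np.ones(vocab_size, dtype=np.int32),
    "prefix_vocab_mask": np.ones((batch_size, vocab_size), dtype=np.int32),  # new first-step mask
}
sequences = sess.run(None, feeds)[0]    # (batch_size, num_return_sequences, max_length)
```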
