Sampling op (#13426)
### Description

Adds a `Sampling` operator for the CPU and CUDA execution providers, supporting both the Hugging Face sampling path and a custom sampling path.

Co-authored-by: Ubuntu <[email protected]>
wangyems and Ubuntu authored Dec 23, 2022
1 parent 4d2dc8b commit 68518a1
Showing 41 changed files with 2,404 additions and 509 deletions.
7 changes: 5 additions & 2 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -75,10 +75,13 @@ set(contrib_ops_excluded_files
"transformers/beam_search.h"
"transformers/generation_device_helper.cc"
"transformers/generation_device_helper.h"
"transformers/beam_search_impl.cu"
"transformers/beam_search_impl.h"
"transformers/generation_cuda_impl.cu"
"transformers/generation_cuda_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
"transformers/sampling.cc"
"transformers/sampling.h"
"transformers/sampling_cuda_helper.h"
"transformers/dump_cuda_tensor.cc"
"transformers/dump_cuda_tensor.h"
"conv_transpose_with_dynamic_pads.cc"
84 changes: 84 additions & 0 deletions docs/ContribOperators.md
@@ -74,6 +74,7 @@ Do not modify directly.*
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
* <a href="#com.microsoft.Snpe">com.microsoft.Snpe</a>
* <a href="#com.microsoft.SparseToDenseMatMul">com.microsoft.SparseToDenseMatMul</a>
@@ -3810,6 +3811,89 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.Sampling"></a><a name="com.microsoft.sampling">**com.microsoft.Sampling**</a>

Greedy Sampling for text generation.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>custom</tt> : int</dt>
<dd>If set to 1, custom sampling logic is used.</dd>
<dt><tt>decoder</tt> : graph (required)</dt>
<dd>Decoder subgraph to execute in a loop.</dd>
<dt><tt>decoder_start_token_id</tt> : int</dt>
<dd>The id of the token that indicates decoding starts.</dd>
<dt><tt>encoder</tt> : graph</dt>
<dd>The subgraph for initialization of encoder and decoder. It will be called once before the decoder subgraph.</dd>
<dt><tt>eos_token_id</tt> : int (required)</dt>
<dd>The id of the end-of-sequence token.</dd>
<dt><tt>filter_value</tt> : float</dt>
<dd>All filtered values will be set to this float value.</dd>
<dt><tt>init_decoder</tt> : graph</dt>
<dd>The subgraph for the first decoding run. It will be called once before the `decoder` subgraph. This is relevant only for the GPT-2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs.</dd>
<dt><tt>min_tokens_to_keep</tt> : int</dt>
<dd>Minimum number of tokens we keep per batch example in the output.</dd>
<dt><tt>model_type</tt> : int</dt>
<dd>Model type: 0 for decoder-only models like GPT-2; 1 for encoder-decoder models like BART.</dd>
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
<dd>The size of n-grams that are not allowed to repeat in the generated output.</dd>
<dt><tt>pad_token_id</tt> : int (required)</dt>
<dd>The id of the padding token.</dd>
<dt><tt>presence_penalty</tt> : float</dt>
<dd>Presence penalty for custom sampling.</dd>
<dt><tt>temperature</tt> : float</dt>
<dd>The value used to modulate the next token probabilities.</dd>
<dt><tt>top_p</tt> : float</dt>
<dd>If set to a float < 1, only the smallest set of the most probable tokens with probabilities that add up to `top_p` or higher is kept for generation (see the sketch after this list).</dd>
<dt><tt>vocab_size</tt> : int</dt>
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape.</dd>
</dl>
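
Taken together, `temperature`, `top_p`, `filter_value`, and `min_tokens_to_keep` describe a top-p (nucleus) filtering step over the next-token logits. The following is a minimal NumPy sketch of Hugging Face-style top-p filtering under those attribute names; it illustrates the technique only and is not the kernel's actual implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def top_p_filter(logits, top_p=0.9, filter_value=-1e9, min_tokens_to_keep=1, temperature=1.0):
    # Scale logits by temperature before computing probabilities.
    scores = logits / temperature
    # Sort tokens by descending score and accumulate their probabilities.
    order = np.argsort(-scores, axis=-1)
    sorted_scores = np.take_along_axis(scores, order, axis=-1)
    cumulative = np.cumsum(softmax(sorted_scores), axis=-1)
    # Drop tokens once the cumulative probability exceeds top_p, shifted
    # right by one so the token that crosses the threshold is kept.
    remove = cumulative > top_p
    remove[..., 1:] = remove[..., :-1].copy()
    remove[..., 0] = False
    remove[..., :min_tokens_to_keep] = False  # always keep at least this many
    # Scatter the removal mask back to the original vocabulary order.
    mask = np.zeros_like(remove)
    np.put_along_axis(mask, order, remove, axis=-1)
    return np.where(mask, filter_value, scores)
```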

#### Inputs (2 - 8)

<dl>
<dt><tt>input_ids</tt> : I</dt>
<dd>The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)</dd>
<dt><tt>max_length</tt> : I</dt>
<dd>The maximum length of the sequence to be generated. Shape is (1)</dd>
<dt><tt>min_length</tt> (optional) : I</dt>
<dd>The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)</dd>
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts values > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary. Words masked with 0 are not allowed to be generated, while words masked with 1 are allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary for the first step. Words masked with 0 are not allowed to be generated, while words masked with 1 are allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
<dt><tt>presence_mask</tt> (optional) : I</dt>
<dd>Presence penalty mask. Shape is (batch_size, vocab_size)</dd>
</dl>
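
As a concrete illustration of this input layout, a `Sampling` node can be constructed with the ONNX helper API, passing empty strings for optional inputs that are skipped. This is a hedged sketch: the token ids are illustrative, and a real model needs an actual decoder subgraph rather than the placeholder below.

```python
from onnx import helper

# Placeholder decoder subgraph; a real model requires a decoder GraphProto
# whose inputs/outputs match what the Sampling op expects.
decoder_graph = helper.make_graph(nodes=[], name="decoder", inputs=[], outputs=[])

node = helper.make_node(
    "Sampling",
    # input_ids and max_length, then "" for the skipped optional inputs
    # (min_length, repetition_penalty, vocab_mask, prefix_vocab_mask),
    # then attention_mask.
    inputs=["input_ids", "max_length", "", "", "", "", "attention_mask"],
    outputs=["sequences"],
    domain="com.microsoft",
    model_type=0,        # decoder-only model such as GPT-2
    eos_token_id=50256,  # illustrative GPT-2 token ids
    pad_token_id=50256,
    temperature=1.0,
    top_p=0.9,
    decoder=decoder_graph,
)
```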

#### Outputs (1 - 2)

<dl>
<dt><tt>sequences</tt> : I</dt>
<dd>Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)</dd>
<dt><tt>filtered_logits</tt> (optional) : T</dt>
<dd>Filtered logits used as input to the multinomial function, for debugging purposes. Shape is (batch_size, vocab_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>I</tt> : tensor(int32)</dt>
<dd>Constrain input and output types to integer tensors.</dd>
</dl>
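
Assuming a model containing this operator has already been exported (for example with onnxruntime's generation-model conversion tooling), running it from Python might look like the following sketch; the model path and graph input names are assumptions matching the spec above:

```python
import numpy as np
import onnxruntime as ort

# "gpt2_sampling.onnx" is a hypothetical model whose graph inputs match the
# spec: input_ids (batch_size, sequence_length) and max_length (1).
session = ort.InferenceSession("gpt2_sampling.onnx", providers=["CPUExecutionProvider"])

inputs = {
    "input_ids": np.array([[464, 3290, 318]], dtype=np.int32),
    "max_length": np.array([20], dtype=np.int32),
}
(sequences,) = session.run(["sequences"], inputs)  # (batch_size, max_sequence_length)
print(sequences)
```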


### <a name="com.microsoft.SkipLayerNormalization"></a><a name="com.microsoft.skiplayernormalization">**com.microsoft.SkipLayerNormalization**</a>

Skip and Layer Normalization Fusion
2 changes: 2 additions & 0 deletions docs/OperatorKernels.md
@@ -438,6 +438,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@@ -797,6 +798,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -18,6 +18,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
@@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range)>,
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -61,16 +61,17 @@ void BeamSearch::Init(const OpKernelInfo& info) {
parameters_.ParseFromAttributes(info);

// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt ||
parameters_.model_type == IGenerationParameters::kModelTypeT5);

ONNX_NAMESPACE::GraphProto proto;
if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {

if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
// Make sure the encoder sub-graph attribute is present for the T5 model.
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
}

if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
if (info.GetAttr<ONNX_NAMESPACE::GraphProto>("init_decoder", &proto).IsOK()) {
has_init_decoder_ = true;
@@ -87,7 +88,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
const auto& node = Node();
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
if (attribute_name == "decoder") {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
@@ -113,8 +114,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
init_run_gpt_subgraph_ = std::move(res.second);
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
}

} else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) {
} else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
if (attribute_name == "encoder") {
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
@@ -167,7 +167,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
// Make a copy of parameters since we will update it based on inputs later
BeamSearchParameters parameters = parameters_;

if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32
BeamSearchGpt<float> impl{
*ctx_internal,
@@ -191,7 +191,8 @@ Status BeamSearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
context.Input<Tensor>(0), // input_ids
context.Input<Tensor>(7), // vocab_mask
context.Input<Tensor>(8), // prefix_vocab_mask
context.Input<Tensor>(9))); // attention_mask
context.Input<Tensor>(9), // attention_mask
nullptr)); // presence_mask

return Status::OK();
}
@@ -18,7 +18,7 @@ Status BeamSearchParameters::Validate() const {
}

void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IBeamSearchParameters::kModelTypeGpt));
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeGpt));
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
Expand All @@ -35,7 +35,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
batch_size = static_cast<int>(dims[0]);

// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;
sequence_length = (this->model_type == IGenerationParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;

auto* max_length_tensor = context->Input<Tensor>(1);
max_length = max_length_tensor ? static_cast<int>(*max_length_tensor->Data<int32_t>()) : kMaxSequenceLength;
@@ -71,10 +71,9 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
if (vocab_size == -1) {
if (vocab_size == -1 || vocab_size == 0) {
vocab_size = vocabulary_size;
}

num_heads = heads;
head_size = hidden_size_per_head;
num_layers = layers;
@@ -10,7 +10,7 @@ namespace onnxruntime {
namespace contrib {
namespace transformers {

struct BeamSearchParameters : public IBeamSearchParameters {
struct BeamSearchParameters : public IGenerationParameters {
Status Validate() const;

int BatchBeamSize() const { return batch_size * num_beams; }
25 changes: 24 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
@@ -87,7 +87,8 @@ class GenerateBase {
const Tensor* input_ids,
const Tensor* vocab_mask,
const Tensor* prefix_vocab_mask,
const Tensor* attention_mask) const {
const Tensor* attention_mask,
const Tensor* presence_mask) const {
const auto& dims = input_ids->Shape().GetDims();
if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -149,6 +150,28 @@ class GenerateBase {
}
}

if (presence_mask != nullptr) {
const auto& dims_presence = presence_mask->Shape().GetDims();
if (dims_presence.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'presence_mask' is expected to have 2 dimensions, got ", dims_presence.size());
}

// presence_mask first dimension should be same as the first dimension of input_ids
if (static_cast<int>(dims_presence[0]) != static_cast<int>(dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input_ids and presence_mask must have the same batch_size");
}

if (static_cast<int>(dims_presence[1]) != parameters->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'presence_mask' shape[1] shall be vocab_size, got ", dims_presence[1]);
}

// store presence mask in parameters.
parameters->presence_mask = presence_mask->DataAsSpan<int32_t>();
}

return Status::OK();
}
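
For context, the presence mask validated above drives the presence penalty used by custom sampling. A hedged sketch of how such a penalty is conventionally applied (an assumed formulation, not this PR's kernel code):

```python
import numpy as np

def apply_presence_penalty(logits, presence_mask, presence_penalty):
    # Subtract presence_penalty from the logits of every token flagged in
    # presence_mask (shape: batch_size x vocab_size).
    return logits - presence_penalty * presence_mask.astype(logits.dtype)
```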
