Some changes to Sampling Op #14218

Merged: 8 commits, merged on Jan 12, 2023
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -3945,7 +3945,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
 </dl>
 
-#### Inputs (2 - 8)
+#### Inputs (2 - 9)
 
 <dl>
 <dt><tt>input_ids</tt> : I</dt>
@@ -3964,6 +3964,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
 <dt><tt>presence_mask</tt> (optional) : I</dt>
 <dd>Presence penalty mask. Shape is (batch_size, vocab_size)</dd>
+<dt><tt>seed</tt> (optional) : I</dt>
+<dd>Seed for random number generator. Shape is (1)</dd>
 </dl>
 
 #### Outputs (1 - 2)
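The new optional `seed` input is fed like any other graph input at inference time. A minimal sketch, assuming a model exported with these graph input names (the file name and token IDs below are hypothetical):

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("gpt2_sampling.onnx")  # hypothetical export

feed = {
    "input_ids": np.array([[464, 3290, 318]], dtype=np.int32),  # example token IDs
    "max_length": np.array([32], dtype=np.int32),
    "min_length": np.array([1], dtype=np.int32),
    "seed": np.array([42], dtype=np.int32),  # shape (1); pins the sampler's RNG
}
sequences = session.run(None, feed)[0]  # (batch_size, max_sequence_length)
print(sequences)
```

Two runs with the same seed should produce identical sequences; omitting the optional input leaves the op's default seeding in effect.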
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -450,7 +450,7 @@ Do not modify directly.*
 |QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
 |SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
 |SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@@ -812,7 +812,7 @@ Do not modify directly.*
 |RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
 |RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
 |TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
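To confirm that an exported model actually wires the ninth input, the Sampling node can be inspected directly; a small sketch (the file name is again hypothetical):

```python
import onnx

model = onnx.load("gpt2_sampling.onnx")  # hypothetical export
sampling = next(n for n in model.graph.node if n.op_type == "Sampling")
# Skipped optional inputs appear as "" placeholders; when wired,
# input index 8 is "seed".
print(list(sampling.input))
```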
9 changes: 9 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
@@ -105,6 +105,12 @@ void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const {
   DumpCpuTensor<MLFloat16>(name, tensor, dim0, dim1);
 }
 
+void CpuTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
+  if (!is_enabled_)
+    return;
+  DumpCpuTensor<size_t>(name, tensor, dim0, dim1);
+}
+
 void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
   if (!is_enabled_)
     return;
@@ -180,6 +186,9 @@ void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const {
 void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
 }
 
+void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
+}
+
 void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
 }
 
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
@@ -16,6 +16,7 @@ class CpuTensorConsoleDumper : public IConsoleDumper {
   virtual ~CpuTensorConsoleDumper() {}
   void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override;
+  void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override;
@@ -172,6 +172,7 @@ class IConsoleDumper {
   bool IsEnabled() const { return is_enabled_; }
   virtual void Print(const char* name, const float* tensor, int dim0, int dim1) const = 0;
   virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const = 0;
+  virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0;
   virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const = 0;
   virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const = 0;
   virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const = 0;
38 changes: 21 additions & 17 deletions onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
@@ -10,9 +10,10 @@ template <typename T>
 void filter_scores(std::vector<size_t>& sorted_indice,
                    gsl::span<T>& next_token_score,
                    const transformers::IGenerationParameters* parameters,
-                   size_t index) {
-  size_t real_index = sorted_indice[index];
-  next_token_score[real_index] = (T)parameters->filter_value;
+                   size_t chunk_offset,
+                   size_t offset) {
+  size_t real_index = sorted_indice[chunk_offset + offset];
+  next_token_score[chunk_offset + real_index] = (T)parameters->filter_value;
 }
 
 template <typename T>
@@ -23,12 +24,12 @@ void cumulate_and_filter_custom(gsl::span<T>& next_token_scores,
   for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
     size_t offset = i * parameters->vocab_size;
     if (cumulative_probs[offset] > parameters->top_p) {
-      filter_scores(sorted_indices, next_token_scores, parameters, 1 + offset);
+      filter_scores(sorted_indices, next_token_scores, parameters, offset, 1);
     }
     for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - 1; j++) {
       cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
       if (cumulative_probs[j + offset] > parameters->top_p) {
-        filter_scores(sorted_indices, next_token_scores, parameters, j + offset + 1);
+        filter_scores(sorted_indices, next_token_scores, parameters, offset, j + 1);
       }
     }
   }
@@ -42,12 +43,12 @@ void cumulate_and_filter(gsl::span<T>& next_token_scores,
   for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
     size_t offset = i * parameters->vocab_size;
     if (cumulative_probs[offset] <= 1 - parameters->top_p) {
-      filter_scores(sorted_indices, next_token_scores, parameters, offset);
+      filter_scores(sorted_indices, next_token_scores, parameters, offset, 0);
     }
     for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - static_cast<size_t>(parameters->min_tokens_to_keep); j++) {
       cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
       if (cumulative_probs[j + offset] <= 1 - parameters->top_p) {
-        filter_scores(sorted_indices, next_token_scores, parameters, j + offset);
+        filter_scores(sorted_indices, next_token_scores, parameters, offset, j);
       }
     }
   }
@@ -78,17 +79,23 @@ Status Sample(AllocatorPtr& allocator,
   for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
     auto indices_begin = sorted_indices.begin() + i * parameters->vocab_size;
     auto indices_end = sorted_indices.begin() + (i + 1) * parameters->vocab_size;
+    gsl::span<T> next_token_score = next_token_scores.subspan(i * parameters->vocab_size, parameters->vocab_size);
     std::iota(indices_begin, indices_end, 0);
     std::sort(indices_begin, indices_end,
-              [&next_token_scores, &predicator](size_t i1, size_t i2) {
-                return !predicator(next_token_scores[i1], next_token_scores[i2]);
+              [&next_token_score, &predicator](size_t i1, size_t i2) {
+                return predicator(next_token_score[i1], next_token_score[i2]);
               });
 
     std::sort(sorted_scores.begin() + i * parameters->vocab_size,
               sorted_scores.begin() + (i + 1) * parameters->vocab_size,
              predicator);
   }
 
+#ifdef DEBUG_GENERATION
+  dumper->Print("sorted_scores", sorted_scores.data(), parameters->batch_size, parameters->vocab_size);
+  dumper->Print("sorted_indices", sorted_indices.data(), parameters->batch_size, parameters->vocab_size);
+#endif
+
   gsl::span<T>& cumulative_probs = sampling_state->cumulative_probs;
 
   ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
@@ -104,13 +111,10 @@ Status Sample(AllocatorPtr& allocator,
     cumulate_and_filter(next_token_scores, cumulative_probs, parameters, sorted_indices);
   }
 
-  gsl::span<T>& next_token_probs = sampling_state->h_softmaxed_score;
-  ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
-                                    parameters->vocab_size,
-                                    next_token_scores.data(),
-                                    next_token_probs.data(),
-                                    false,
-                                    thread_pool));
+#ifdef DEBUG_GENERATION
+  dumper->Print("cumulative_probs after filtering", cumulative_probs.data(), parameters->batch_size, parameters->vocab_size);
+  dumper->Print("next_token_scores after filtering", next_token_scores.data(), parameters->batch_size, parameters->vocab_size);
+#endif
 
   // torch.multinomial()
   int64_t next_token_probs_dims[] = {static_cast<int64_t>(parameters->batch_size), parameters->vocab_size};
@@ -119,7 +123,7 @@ Status Sample(AllocatorPtr& allocator,
   OrtValue next_token_probs_value;
   Tensor::InitOrtValue(element_type,
                        next_token_probs_shape,
-                       next_token_probs.data(),
+                       next_token_scores.data(),
                        allocator->Info(),
                        next_token_probs_value);
   const Tensor& input = next_token_probs_value.Get<Tensor>();
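The core of the fix is in `filter_scores`: the batch row's `chunk_offset` is now passed separately, so both the sorted-index lookup and the score write land in row `i` instead of row 0, and the index sort now uses the same comparator, over the same per-row span, as the score sort. A rough NumPy re-implementation of the cumulate-and-filter step (an illustrative sketch, not the ORT source):

```python
import numpy as np

def top_p_filter(scores: np.ndarray, top_p: float, filter_value: float = -np.inf,
                 min_tokens_to_keep: int = 1) -> np.ndarray:
    """Mask out the low-probability tail of each batch row, keeping ~top_p mass."""
    filtered = scores.copy()
    for row in range(scores.shape[0]):
        order = np.argsort(scores[row])      # ascending: least likely tokens first
        probs = np.exp(scores[row][order] - scores[row].max())
        probs /= probs.sum()                 # softmax over the sorted scores
        cum = np.cumsum(probs)
        # Drop tokens while the cumulative mass stays within the (1 - top_p) tail,
        # always keeping at least min_tokens_to_keep of the most likely tokens.
        drop = cum <= (1.0 - top_p)
        drop[-min_tokens_to_keep:] = False
        filtered[row][order[drop]] = filter_value
    return filtered
```

Sorting ascending puts the least likely tokens first, so masking while the cumulative mass is still within `1 - top_p` removes the tail and keeps roughly `top_p` of the probability mass per row.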
@@ -21,6 +21,14 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) {
   vocab_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("vocab_size", -1));
 }
 
+void SamplingParameters::ParseFromInputs(OpKernelContext* context) {
+  this->GreedySearchParameters::ParseFromInputs(context);
+
+  auto* seed_tensor = context->Input<Tensor>(8);
+  seed = seed_tensor ? static_cast<int>(*seed_tensor->Data<int32_t>()) : 0;
+  ORT_ENFORCE(seed >= 0, "Seed must be >= 0");
+}
+
 }  // namespace transformers
 }  // namespace contrib
 }  // namespace onnxruntime
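`ParseFromInputs` reads the optional ninth input (index 8) on every kernel invocation and defaults to 0 when it is absent. Conceptually, a fixed seed pins the multinomial draw that picks each next token; a NumPy analogy (illustrative only, not the ORT internals):

```python
import numpy as np

probs = np.array([0.1, 0.2, 0.3, 0.4])  # softmaxed scores for one batch row

def draw(seed: int) -> int:
    # One multinomial draw -> index of the sampled token.
    return int(np.random.default_rng(seed).multinomial(1, probs).argmax())

assert draw(42) == draw(42)  # same seed, same sampled token
```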
@@ -12,6 +12,8 @@ namespace transformers {
 
 struct SamplingParameters : public GreedySearchParameters {
   void ParseFromAttributes(const OpKernelInfo& info);
+
+  void ParseFromInputs(OpKernelContext* context);
 };
 
 }  // namespace transformers
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -145,6 +145,11 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const {
   DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1, true);
 }
 
+void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
+  if (is_enabled_)
+    DumpGpuTensor<size_t>(name, tensor, dim0, dim1, true);
+}
+
 void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
   if (is_enabled_)
     DumpGpuTensor<int64_t>(name, tensor, dim0, dim1, true);
@@ -212,6 +217,9 @@ void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const {
 void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
 }
 
+void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
+}
+
 void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
 }
 
@@ -19,6 +19,7 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::IConsoleDumper {
   virtual ~CudaTensorConsoleDumper() {}
   void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override;
+  void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override;
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1155,6 +1155,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
     .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
     .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
     .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
+    .Input(8, "seed", "Seed for random number generator. Shape is (1)", "I", OpSchema::Optional)
     .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I")
     .Output(1, "filtered_logits", "Filtered logits as input to the multinomial function for debug purpose. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional)
    .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
27 changes: 24 additions & 3 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -273,6 +273,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
     )
     model_group.set_defaults(presence_mask=False)
 
+    model_group.add_argument(
+        "--seed",
+        required=False,
+        action="store_true",
+        help="Random seed for sampling op",
+    )
+    model_group.set_defaults(seed=False)
+
     beam_parameters_group = parser.add_argument_group(
         "Beam search parameters not stored in the output model, for testing parity and performance"
     )
@@ -1531,6 +1539,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH):
 
     if is_sampling and args.custom and args.presence_mask:
         inputs.append("presence_mask")
+    else:
+        inputs.append("")
+
+    if is_sampling and args.seed:
+        inputs.append("seed")
 
     outputs = ["sequences"]
     if args.output_sequences_scores:
@@ -1709,6 +1722,10 @@ def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH):
         )
         graph_inputs.append(presence_mask)
 
+    if is_sampling and args.seed:
+        seed = onnx.helper.make_tensor_value_info("seed", TensorProto.INT32, [1])
+        graph_inputs.append(seed)
+
     # graph outputs
     sequences = None
     if is_beamsearch:

> Review thread on the `seed` graph input added above:
>
> **Reviewer (Member):** How about making `seed` configurable? Default it to -1, and if it is > 0, use the seed passed in by args.
>
> **Author (Contributor):** Currently we only test sampling with the default seed and a small top_p, so this flag is limited to helping create the model with that input.
@@ -2278,9 +2295,13 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None):
     if args.model_type == "gpt2" and is_greedy:
         if args.top_p > 0.0 and args.top_p < 1.0:
             convert_generation_model(args, GenerationType.SAMPLING)
-            logger.info("The test for gpt2_sampling onnx model is not implemented yet")
-            return
-        convert_generation_model(args, GenerationType.GREEDYSEARCH)
+            logger.info(
+                "The test for gpt2_sampling onnx model is limited to non-custom model with small top_p(e.g <=0.01) value. The result should be the same as gpt2 greedy search."
+            )
+            if args.top_p > 0.01 or args.custom or args.seed:
+                return
+        else:
+            convert_generation_model(args, GenerationType.GREEDYSEARCH)
     else:
         convert_generation_model(args)
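Putting the flag together with the parity note above, a sketch of driving the converter from Python (the module path follows the ONNX Runtime transformers tooling; the output file name is hypothetical, and the flags mirror the diff above). The command-line equivalent passes the same flags to `convert_generation.py`.

```python
from onnxruntime.transformers import convert_generation

convert_generation.main([
    "-m", "gpt2",                      # model name or path
    "--output", "gpt2_sampling.onnx",  # hypothetical output file
    "--top_p", "0.01",                 # small top_p so results can match greedy search
    "--seed",                          # wire the optional seed graph input (this PR)
])
```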