Some changes to Sampling Op #14218

Merged (8 commits) on Jan 12, 2023

Changes shown from 2 of the 8 commits.
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -3945,7 +3945,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>

-#### Inputs (2 - 8)
+#### Inputs (2 - 9)

<dl>
<dt><tt>input_ids</tt> : I</dt>
@@ -3964,6 +3964,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
<dt><tt>presence_mask</tt> (optional) : I</dt>
<dd>Presence penalty mask. Shape is (batch_size, vocab_size)</dd>
<dt><tt>seed</tt> (optional) : I</dt>
<dd>Seed for random number generator. Shape is (1)</dd>
</dl>

#### Outputs (1 - 2)
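The new `seed` input pins the op's random number generator, so repeated runs with the same seed produce identical sequences. As a rough illustration of that contract (a pure-Python sketch, not ONNX Runtime's implementation), a seeded categorical sampler is fully reproducible:

```python
import random

def sample_tokens(probs, num_tokens, seed=0):
    """Draw token ids from a fixed categorical distribution.

    Illustrative sketch of the contract the optional `seed` input implies:
    the same non-negative seed always yields the same token sequence.
    """
    if seed < 0:
        raise ValueError("Seed must be >= 0")
    rng = random.Random(seed)  # seeded generator -> deterministic draws
    tokens = []
    for _ in range(num_tokens):
        r = rng.random()
        cumulative = 0.0
        for token_id, p in enumerate(probs):
            cumulative += p
            if r < cumulative:
                tokens.append(token_id)
                break
    return tokens

# Same seed -> same sequence, mirroring a model fed the same seed tensor.
assert sample_tokens([0.1, 0.2, 0.7], 8, seed=42) == sample_tokens([0.1, 0.2, 0.7], 8, seed=42)
```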
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -450,7 +450,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@@ -812,7 +812,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
@@ -21,6 +21,14 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) {
vocab_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("vocab_size", -1));
}

void SamplingParameters::ParseFromInputs(OpKernelContext* context) {
this->GreedySearchParameters::ParseFromInputs(context);

auto* seed_tensor = context->Input<Tensor>(8);
seed = seed_tensor ? static_cast<int>(*seed_tensor->Data<int32_t>()) : 0;
ORT_ENFORCE(seed >= 0, "Seed must be >= 0");
}

} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
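The kernel change reads the new ninth input (index 8) and falls back to seed 0 when the tensor is absent. A hypothetical pure-Python mirror of that lookup (names are illustrative, not ONNX Runtime API):

```python
def parse_seed(optional_inputs, seed_index=8):
    """Mirror of the seed lookup in SamplingParameters::ParseFromInputs.

    `optional_inputs` is a positional list where a missing optional input
    is None; the seed tensor holds a single int32 value.
    Absent tensor -> default seed 0.
    """
    tensor = optional_inputs[seed_index] if len(optional_inputs) > seed_index else None
    seed = int(tensor[0]) if tensor is not None else 0
    if seed < 0:
        raise ValueError("Seed must be >= 0")  # same check as the ORT_ENFORCE
    return seed

assert parse_seed([None] * 9) == 0          # seed input omitted
assert parse_seed([None] * 8 + [[1234]]) == 1234
```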
@@ -12,6 +12,8 @@ namespace transformers {

struct SamplingParameters : public GreedySearchParameters {
void ParseFromAttributes(const OpKernelInfo& info);

void ParseFromInputs(OpKernelContext* context);
};

} // namespace transformers
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1155,6 +1155,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
.Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
.Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
.Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
.Input(8, "seed", "Seed for random number generator. Shape is (1)", "I", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I")
.Output(1, "filtered_logits", "Filtered logits as input to the multinomial function, for debugging purposes. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
27 changes: 24 additions & 3 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -273,6 +273,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
)
model_group.set_defaults(presence_mask=False)

model_group.add_argument(
"--seed",
required=False,
action="store_true",
help="Random seed for sampling op",
)
model_group.set_defaults(seed=False)

beam_parameters_group = parser.add_argument_group(
"Beam search parameters not stored in the output model, for testing parity and performance"
)
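The new `--seed` switch is a plain boolean toggle: it only controls whether the exported graph gets a seed input, it does not carry a seed value itself. A minimal sketch of the same `store_true` pattern:

```python
import argparse

parser = argparse.ArgumentParser()
model_group = parser.add_argument_group("Model")
# Present on the command line -> True; absent -> the declared default False.
model_group.add_argument(
    "--seed",
    required=False,
    action="store_true",
    help="Random seed for sampling op",
)
model_group.set_defaults(seed=False)

assert parser.parse_args([]).seed is False
assert parser.parse_args(["--seed"]).seed is True
```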
@@ -1531,6 +1539,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati

if is_sampling and args.custom and args.presence_mask:
inputs.append("presence_mask")
else:
inputs.append("")

if is_sampling and args.seed:
inputs.append("seed")

outputs = ["sequences"]
if args.output_sequences_scores:
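Optional ONNX node inputs are positional, which is why the diff appends an empty string when `presence_mask` is skipped: the placeholder keeps `seed` at input index 8, while trailing optional inputs can simply be omitted. A sketch of the resulting list (input names taken from the op schema; the flags are simplifications of the real arguments):

```python
def build_sampling_inputs(use_presence_mask=False, use_seed=False):
    """Assemble the Sampling node's positional input names.

    A skipped optional input in the middle must appear as "", while
    trailing optional inputs may be dropped entirely.
    """
    inputs = [
        "input_ids", "max_length", "min_length", "repetition_penalty",
        "vocab_mask", "prefix_vocab_mask", "attention_mask",
    ]
    if use_presence_mask:
        inputs.append("presence_mask")
    else:
        inputs.append("")  # placeholder so seed stays at index 8
    if use_seed:
        inputs.append("seed")
    return inputs

assert build_sampling_inputs(use_seed=True).index("seed") == 8
```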
@@ -1709,6 +1722,10 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
)
graph_inputs.append(presence_mask)

if is_sampling and args.seed:
seed = onnx.helper.make_tensor_value_info("seed", TensorProto.INT32, [1])
A reviewer (Member) commented on this line:

> How about making `seed` configurable? Default it to -1 and, if > 0, use the seed passed in by args.

The author replied:

> Currently we only test sampling with the default seed and a small top_p, so this flag is limited to helping create the model with that input.
graph_inputs.append(seed)

# graph outputs
sequences = None
if is_beamsearch:
@@ -2278,9 +2295,13 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None
     if args.model_type == "gpt2" and is_greedy:
         if args.top_p > 0.0 and args.top_p < 1.0:
             convert_generation_model(args, GenerationType.SAMPLING)
-            logger.info("The test for gpt2_sampling onnx model is not implemented yet")
-            return
-        convert_generation_model(args, GenerationType.GREEDYSEARCH)
+            logger.info(
+                "The test for gpt2_sampling onnx model is limited to non-custom model with small top_p (e.g. <= 0.01) value. The result should be the same as gpt2 greedy search."
+            )
+            if args.top_p > 0.01 or args.custom or args.seed:
+                return
+        else:
+            convert_generation_model(args, GenerationType.GREEDYSEARCH)
     else:
         convert_generation_model(args)
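The dispatch above boils down to: sampling conversion runs whenever 0 < top_p < 1, but the greedy-parity test is only attempted when top_p is tiny and neither `--custom` nor `--seed` changed the model, since only then should sampling match greedy search. A hypothetical helper capturing that decision (the greedy branch's parity behavior is an assumption):

```python
def plan_gpt2_conversion(top_p, custom=False, seed=False):
    """Which conversion runs for a greedy gpt2 model, and whether the
    sampling-vs-greedy parity test is attempted afterwards (sketch)."""
    if 0.0 < top_p < 1.0:
        run_parity_test = not (top_p > 0.01 or custom or seed)
        return "sampling", run_parity_test
    return "greedysearch", True

assert plan_gpt2_conversion(0.01) == ("sampling", True)
assert plan_gpt2_conversion(0.5) == ("sampling", False)
```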

83 changes: 83 additions & 0 deletions onnxruntime/test/contrib_ops/sampling_test.cc
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <memory>
#include <vector>
#include "gtest/gtest.h"
#include "core/common/gsl.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/common/cuda_op_test_utils.h"

extern std::unique_ptr<Ort::Env> ort_env;

namespace onnxruntime {
namespace test {

TEST(SamplingTest, GptSampling) {
std::vector<int64_t> input_ids_shape{3, 12};
std::vector<int32_t> input_ids{
0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572,
0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328};

std::vector<int64_t> parameter_shape{1};
std::vector<int32_t> max_length{15};
std::vector<int32_t> min_length{1};
std::vector<float> repetition_penalty{1.0f};

std::vector<int64_t> expected_output_shape{input_ids_shape[0], max_length[0]};

std::vector<int32_t> expected_output{
0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, 125, 543, 668,
41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572, 776, 213, 697,
0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328, 450};

Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
auto input_ids_tensor = Ort::Value::CreateTensor(
info, input_ids.data(), input_ids.size(), input_ids_shape.data(), input_ids_shape.size());

auto max_length_tensor = Ort::Value::CreateTensor(
info, max_length.data(), max_length.size(), parameter_shape.data(), parameter_shape.size());

auto min_length_tensor = Ort::Value::CreateTensor(
info, min_length.data(), min_length.size(), parameter_shape.data(), parameter_shape.size());

auto repetition_penalty_tensor = Ort::Value::CreateTensor(
info, repetition_penalty.data(), repetition_penalty.size(), parameter_shape.data(), parameter_shape.size());

std::vector<Ort::Value> ort_inputs;
ort_inputs.push_back(std::move(input_ids_tensor));
ort_inputs.push_back(std::move(max_length_tensor));
ort_inputs.push_back(std::move(min_length_tensor));
ort_inputs.push_back(std::move(repetition_penalty_tensor));
const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"};
const char* const output_names[] = {"sequences"};

constexpr int min_cuda_architecture = 530;
if (HasCudaEnvironment(min_cuda_architecture)) {
Ort::SessionOptions session_options;
#ifdef USE_CUDA
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
#endif

Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_sampling.onnx"), session_options);

auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
output_names, 1);

ASSERT_EQ(ort_outputs.size(), 1U);
const auto& sequences = ort_outputs[0];
ASSERT_TRUE(sequences.IsTensor());

auto result_ts = sequences.GetTensorTypeAndShapeInfo();
ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, result_ts.GetElementType());

ASSERT_EQ(expected_output_shape, result_ts.GetShape());
const auto* result_vals = sequences.GetTensorData<int32_t>();
auto result_span = gsl::make_span(result_vals, expected_output.size());
ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
}
}

} // namespace test
} // namespace onnxruntime
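The `expected_output` in the test above is just each input row extended with newly sampled tokens out to `max_length`. That property can be stated directly (a sketch with made-up data, not the test's actual values):

```python
def rows(flat, width):
    """Split a flat row-major buffer into rows of `width` elements."""
    return [flat[i * width:(i + 1) * width] for i in range(len(flat) // width)]

def extends_prompts(input_ids, sequences, seq_len, max_len):
    """Each generated sequence must begin with its (padded) prompt; the op
    only appends max_len - seq_len sampled tokens per batch row."""
    prompts = rows(input_ids, seq_len)
    outputs = rows(sequences, max_len)
    return len(prompts) == len(outputs) and all(
        out[:seq_len] == p for p, out in zip(prompts, outputs)
    )

# Toy batch of 2 prompts of length 3, generated out to length 5.
assert extends_prompts([1, 2, 3, 4, 5, 6],
                       [1, 2, 3, 9, 9, 4, 5, 6, 8, 8], 3, 5)
```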
Binary file not shown.