diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 814a076fae83e..ef402f17d32aa 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3945,7 +3945,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
 </dl>
 
-#### Inputs (2 - 8)
+#### Inputs (2 - 9)
 
 <dl>
 <dt><tt>input_ids</tt> : <tt>I</tt></dt>
@@ -3964,6 +3964,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
 <dt><tt>presence_mask</tt> (optional) : <tt>I</tt></dt>
 <dd>Presence penalty mask. Shape is (batch_size, vocab_size)</dd>
+<dt><tt>seed</tt> (optional) : <tt>I</tt></dt>
+<dd>Seed for random number generator. Shape is (1)</dd>
 </dl>
 
 #### Outputs (1 - 2)
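A minimal sketch of feeding the new optional `seed` input at run time, assuming a model already exported with this Sampling op; `gpt2_sampling.onnx` is a placeholder path, and the input names and int32/float32 dtypes follow the schema above (type `I` is int32):

```python
import numpy as np
import onnxruntime as ort

# Placeholder model path; any graph exposing the Sampling op inputs works.
session = ort.InferenceSession("gpt2_sampling.onnx", providers=["CPUExecutionProvider"])

sequences = session.run(
    ["sequences"],
    {
        "input_ids": np.zeros((1, 12), dtype=np.int32),
        "max_length": np.array([15], dtype=np.int32),
        "min_length": np.array([1], dtype=np.int32),
        "repetition_penalty": np.array([1.0], dtype=np.float32),
        # The new optional input: a fixed seed makes the multinomial
        # sampling reproducible across runs.
        "seed": np.array([42], dtype=np.int32),
    },
)[0]  # -> (batch_size, max_length) int32 token IDs
```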
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index a1cdb7e30d04b..1d42210731bde 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -450,7 +450,7 @@ Do not modify directly.*
 |QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
 |SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
 |SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@@ -812,7 +812,7 @@ Do not modify directly.*
 |RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
 |RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
 |SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
 |TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
index 6297bcf0b39d8..d389e4ad76016 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
@@ -105,6 +105,12 @@ void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, in
   DumpCpuTensor(name, tensor, dim0, dim1);
 }
 
+void CpuTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
+  if (!is_enabled_)
+    return;
+  DumpCpuTensor(name, tensor, dim0, dim1);
+}
+
 void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
   if (!is_enabled_)
     return;
@@ -180,6 +186,9 @@ void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const {
 void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
 }
 
+void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
+}
+
 void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
 }
diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
index 27d883008ff58..78b1e6cf9a536 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
@@ -16,6 +16,7 @@ class CpuTensorConsoleDumper : public IConsoleDumper {
   virtual ~CpuTensorConsoleDumper() {}
   void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override;
+  void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 4cc5bf380f2af..e7dca4d3264de 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -172,6 +172,7 @@ class IConsoleDumper {
   bool IsEnabled() const { return is_enabled_; }
   virtual void Print(const char* name, const float* tensor, int dim0, int dim1) const = 0;
   virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const = 0;
+  virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0;
   virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const = 0;
   virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const = 0;
   virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const = 0;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
index 1e3c7035ff336..5caa7c0b27eb8 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
@@ -10,9 +10,10 @@
 template <typename T>
 void filter_scores(std::vector<size_t>& sorted_indice,
                    gsl::span<T>& next_token_score,
                    const transformers::IGenerationParameters* parameters,
-                   size_t index) {
-  size_t real_index = sorted_indice[index];
-  next_token_score[real_index] = (T)parameters->filter_value;
+                   size_t chunk_offset,
+                   size_t offset) {
+  size_t real_index = sorted_indice[chunk_offset + offset];
+  next_token_score[chunk_offset + real_index] = (T)parameters->filter_value;
 }
 
 template <typename T>
@@ -23,12 +24,12 @@ void cumulate_and_filter_custom(gsl::span<T>& next_token_scores,
   for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
     size_t offset = i * parameters->vocab_size;
     if (cumulative_probs[offset] > parameters->top_p) {
-      filter_scores(sorted_indices, next_token_scores, parameters, 1 + offset);
+      filter_scores(sorted_indices, next_token_scores, parameters, offset, 1);
     }
     for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - 1; j++) {
       cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
       if (cumulative_probs[j + offset] > parameters->top_p) {
-        filter_scores(sorted_indices, next_token_scores, parameters, j + offset + 1);
+        filter_scores(sorted_indices, next_token_scores, parameters, offset, j + 1);
       }
     }
   }
@@ -42,12 +43,12 @@ void cumulate_and_filter(gsl::span<T>& next_token_scores,
   for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
     size_t offset = i * parameters->vocab_size;
     if (cumulative_probs[offset] <= 1 - parameters->top_p) {
-      filter_scores(sorted_indices, next_token_scores, parameters, offset);
+      filter_scores(sorted_indices, next_token_scores, parameters, offset, 0);
     }
     for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - static_cast<size_t>(parameters->min_tokens_to_keep); j++) {
       cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
       if (cumulative_probs[j + offset] <= 1 - parameters->top_p) {
-        filter_scores(sorted_indices, next_token_scores, parameters, j + offset);
+        filter_scores(sorted_indices, next_token_scores, parameters, offset, j);
       }
     }
   }
@@ -78,10 +79,11 @@ Status Sample(AllocatorPtr& allocator,
   for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
     auto indices_begin = sorted_indices.begin() + i * parameters->vocab_size;
     auto indices_end = sorted_indices.begin() + (i + 1) * parameters->vocab_size;
+    gsl::span<T> next_token_score = next_token_scores.subspan(i * parameters->vocab_size, parameters->vocab_size);
     std::iota(indices_begin, indices_end, 0);
     std::sort(indices_begin, indices_end,
-              [&next_token_scores, &predicator](size_t i1, size_t i2) {
-                return !predicator(next_token_scores[i1], next_token_scores[i2]);
+              [&next_token_score, &predicator](size_t i1, size_t i2) {
+                return predicator(next_token_score[i1], next_token_score[i2]);
               });
 
     std::sort(sorted_scores.begin() + i * parameters->vocab_size,
@@ -89,6 +91,11 @@ Status Sample(AllocatorPtr& allocator,
               predicator);
   }
 
+#ifdef DEBUG_GENERATION
+  dumper->Print("sorted_scores", sorted_scores.data(), parameters->batch_size, parameters->vocab_size);
+  dumper->Print("sorted_indices", sorted_indices.data(), parameters->batch_size, parameters->vocab_size);
+#endif
+
   gsl::span<T>& cumulative_probs = sampling_state->cumulative_probs;
 
   ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
@@ -104,13 +111,10 @@ Status Sample(AllocatorPtr& allocator,
     cumulate_and_filter(next_token_scores, cumulative_probs, parameters, sorted_indices);
   }
 
-  gsl::span<T>& next_token_probs = sampling_state->h_softmaxed_score;
-  ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
-                                    parameters->vocab_size,
-                                    next_token_scores.data(),
-                                    next_token_probs.data(),
-                                    false,
-                                    thread_pool));
+#ifdef DEBUG_GENERATION
+  dumper->Print("cumulative_probs after filtering", cumulative_probs.data(), parameters->batch_size, parameters->vocab_size);
+  dumper->Print("next_token_scores after filtering", next_token_scores.data(), parameters->batch_size, parameters->vocab_size);
+#endif
 
   // torch.multinomial()
   int64_t next_token_probs_dims[] = {static_cast<int64_t>(parameters->batch_size), parameters->vocab_size};
@@ -119,7 +123,7 @@ Status Sample(AllocatorPtr& allocator,
   OrtValue next_token_probs_value;
   Tensor::InitOrtValue(element_type,
                        next_token_probs_shape,
-                       next_token_probs.data(),
+                       next_token_scores.data(),
                        allocator->Info(),
                        next_token_probs_value);
   const Tensor& input = next_token_probs_value.Get<Tensor>();
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc
index 537b0d7538ce7..6b3862f04ad72 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc
@@ -21,6 +21,14 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) {
   vocab_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("vocab_size", -1));
 }
 
+void SamplingParameters::ParseFromInputs(OpKernelContext* context) {
+  this->GreedySearchParameters::ParseFromInputs(context);
+
+  auto* seed_tensor = context->Input<Tensor>(8);
+  seed = seed_tensor ? static_cast<int>(*seed_tensor->Data<int32_t>()) : 0;
+  ORT_ENFORCE(seed >= 0, "Seed must be >= 0");
+}
+
 }  // namespace transformers
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h
index 6c0f866f09fe2..af203abc15d01 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h
@@ -12,6 +12,8 @@ namespace transformers {
 
 struct SamplingParameters : public GreedySearchParameters {
   void ParseFromAttributes(const OpKernelInfo& info);
+
+  void ParseFromInputs(OpKernelContext* context);
 };
 
 }  // namespace transformers
diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
index 55abeaf657950..a83871ac1d09c 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -145,6 +145,11 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, i
   DumpGpuTensor(name, tensor, dim0, dim1, true);
 }
 
+void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
+  if (is_enabled_)
+    DumpGpuTensor(name, tensor, dim0, dim1, true);
+}
+
 void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
   if (is_enabled_)
     DumpGpuTensor(name, tensor, dim0, dim1, true);
@@ -212,6 +217,9 @@ void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const {
 void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
 }
 
+void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
+}
+
 void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
 }
diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h
index 1a10ba8910a5e..8641df06d8830 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h
@@ -19,6 +19,7 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons
   virtual ~CudaTensorConsoleDumper() {}
   void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override;
+  void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override;
   void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override;
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 15f8599f52534..a8c870d1442cf 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1155,6 +1155,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
                                 .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
                                 .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
                                 .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
+                                .Input(8, "seed", "Seed for random number generator. Shape is (1)", "I", OpSchema::Optional)
                                 .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I")
                                 .Output(1, "filtered_logits", "Filtered logits as input to the multinomial function, for debugging purposes. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional)
                                 .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index 09ee7539c4ee4..82678fa53105e 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -273,6 +273,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
     )
     model_group.set_defaults(presence_mask=False)
 
+    model_group.add_argument(
+        "--seed",
+        required=False,
+        action="store_true",
+        help="Add a seed input to the sampling op",
+    )
+    model_group.set_defaults(seed=False)
+
     beam_parameters_group = parser.add_argument_group(
         "Beam search parameters not stored in the output model, for testing parity and performance"
     )
@@ -1531,6 +1539,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
 
     if is_sampling and args.custom and args.presence_mask:
         inputs.append("presence_mask")
+    else:
+        inputs.append("")
+
+    if is_sampling and args.seed:
+        inputs.append("seed")
 
     outputs = ["sequences"]
     if args.output_sequences_scores:
@@ -1709,6 +1722,10 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
         )
         graph_inputs.append(presence_mask)
 
+    if is_sampling and args.seed:
+        seed = onnx.helper.make_tensor_value_info("seed", TensorProto.INT32, [1])
+        graph_inputs.append(seed)
+
     # graph outputs
     sequences = None
     if is_beamsearch:
@@ -2278,9 +2295,13 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None
     if args.model_type == "gpt2" and is_greedy:
         if args.top_p > 0.0 and args.top_p < 1.0:
             convert_generation_model(args, GenerationType.SAMPLING)
-            logger.info("The test for gpt2_sampling onnx model is not implemented yet")
-            return
-        convert_generation_model(args, GenerationType.GREEDYSEARCH)
+            logger.info(
+                "The test for the gpt2_sampling onnx model is limited to non-custom models with a small top_p (e.g. <= 0.01), where the result should match gpt2 greedy search."
+            )
+            if args.top_p > 0.01 or args.custom or args.seed:
+                return
+        else:
+            convert_generation_model(args, GenerationType.GREEDYSEARCH)
     else:
         convert_generation_model(args)
 
diff --git a/onnxruntime/test/contrib_ops/sampling_test.cc b/onnxruntime/test/contrib_ops/sampling_test.cc
new file mode 100644
index 0000000000000..48992f24a3234
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/sampling_test.cc
@@ -0,0 +1,155 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <memory>
+#include <vector>
+#include "gtest/gtest.h"
+#include "core/common/gsl.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "test/common/cuda_op_test_utils.h"
+
+extern std::unique_ptr<Ort::Env> ort_env;
+
+namespace onnxruntime {
+namespace test {
+
+#if defined(__linux__) && !defined(__ANDROID__)
+#ifdef USE_CUDA
+TEST(SamplingTest, Gpt2Sampling_CUDA) {
+  std::vector<int32_t> input_ids{
+      0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
+      41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572,
+      0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328};
+
+  std::vector<int32_t> max_length{15};
+  std::vector<int32_t> min_length{1};
+  std::vector<float> repetition_penalty{1.0f};
+
+  std::vector<int32_t> expected_output{
+      0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, 125, 543, 668,
+      41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572, 776, 213, 697,
+      0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328, 450};
+
+  const int64_t batch_size = 3;
+  const int64_t sequence_length = 12;
+  std::vector<int64_t> input_ids_shape{batch_size, sequence_length};
+
+  std::vector<int64_t> parameter_shape{1};
+
+  std::vector<int64_t> expected_output_shape{input_ids_shape[0], max_length[0]};
+
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+  auto input_ids_tensor = Ort::Value::CreateTensor(
+      info, input_ids.data(), input_ids.size(), input_ids_shape.data(), input_ids_shape.size());
+
+  auto max_length_tensor = Ort::Value::CreateTensor(
+      info, max_length.data(), max_length.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto min_length_tensor = Ort::Value::CreateTensor(
+      info, min_length.data(), min_length.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto repetition_penalty_tensor = Ort::Value::CreateTensor(
+      info, repetition_penalty.data(), repetition_penalty.size(), parameter_shape.data(), parameter_shape.size());
+
+  std::vector<Ort::Value> ort_inputs;
+  ort_inputs.push_back(std::move(input_ids_tensor));
+  ort_inputs.push_back(std::move(max_length_tensor));
+  ort_inputs.push_back(std::move(min_length_tensor));
+  ort_inputs.push_back(std::move(repetition_penalty_tensor));
+  const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"};
+  const char* const output_names[] = {"sequences"};
+
+  Ort::SessionOptions session_options;
+  constexpr int min_cuda_architecture = 530;
+  if (HasCudaEnvironment(min_cuda_architecture)) {
+    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+
+    Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_sampling.onnx"), session_options);
+
+    auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
+                                   output_names, 1);
+
+    ASSERT_EQ(ort_outputs.size(), 1U);
+    const auto& sequences = ort_outputs[0];
+    ASSERT_TRUE(sequences.IsTensor());
+
+    auto result_ts = sequences.GetTensorTypeAndShapeInfo();
+    ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, result_ts.GetElementType());
+
+    ASSERT_EQ(expected_output_shape, result_ts.GetShape());
+    const auto* result_vals = sequences.GetTensorData<int32_t>();
+    auto result_span = gsl::make_span(result_vals, expected_output.size());
+
+    ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
+  }
+}
+#endif
+
+TEST(SamplingTest, Gpt2Sampling_CPU) {
+  std::vector<int32_t> input_ids{
+      0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
+      41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572,
+      0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328};
+
+  std::vector<int32_t> max_length{15};
+  std::vector<int32_t> min_length{1};
+  std::vector<float> repetition_penalty{1.0f};
+
+  std::vector<int32_t> expected_output{
+      0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, 125, 669, 28,
+      41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572, 475, 944, 527,
+      0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328, 210};
+
+  const int64_t batch_size = 3;
+  const int64_t sequence_length = 12;
+  std::vector<int64_t> input_ids_shape{batch_size, sequence_length};
+
+  std::vector<int64_t> parameter_shape{1};
+
+  std::vector<int64_t> expected_output_shape{input_ids_shape[0], max_length[0]};
+
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+  auto input_ids_tensor = Ort::Value::CreateTensor(
+      info, input_ids.data(), input_ids.size(), input_ids_shape.data(), input_ids_shape.size());
+
+  auto max_length_tensor = Ort::Value::CreateTensor(
+      info, max_length.data(), max_length.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto min_length_tensor = Ort::Value::CreateTensor(
+      info, min_length.data(), min_length.size(), parameter_shape.data(), parameter_shape.size());
+
+  auto repetition_penalty_tensor = Ort::Value::CreateTensor(
+      info, repetition_penalty.data(), repetition_penalty.size(), parameter_shape.data(), parameter_shape.size());
+
+  std::vector<Ort::Value> ort_inputs;
+  ort_inputs.push_back(std::move(input_ids_tensor));
+  ort_inputs.push_back(std::move(max_length_tensor));
+  ort_inputs.push_back(std::move(min_length_tensor));
+  ort_inputs.push_back(std::move(repetition_penalty_tensor));
+  const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"};
+  const char* const output_names[] = {"sequences"};
+
+  Ort::SessionOptions session_options;
+  Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_sampling.onnx"), session_options);
+
+  auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
+                                 output_names, 1);
+
+  ASSERT_EQ(ort_outputs.size(), 1U);
+  const auto& sequences = ort_outputs[0];
+  ASSERT_TRUE(sequences.IsTensor());
+
+  auto result_ts = sequences.GetTensorTypeAndShapeInfo();
+  ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, result_ts.GetElementType());
+
+  ASSERT_EQ(expected_output_shape, result_ts.GetShape());
+  const auto* result_vals = sequences.GetTensorData<int32_t>();
+  auto result_span = gsl::make_span(result_vals, expected_output.size());
+
+  ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
+}
+#endif
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/transformers/tiny_gpt2_sampling.onnx b/onnxruntime/test/testdata/transformers/tiny_gpt2_sampling.onnx
new file mode 100644
index 0000000000000..c9ecb269f72e6
Binary files /dev/null and b/onnxruntime/test/testdata/transformers/tiny_gpt2_sampling.onnx differ
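A Python counterpart to the CPU test above, useful for quick iteration without rebuilding the test binary; the model path assumes the repository root as the working directory, and the inputs mirror the test data exactly:

```python
import numpy as np
import onnxruntime as ort

# Same checked-in test model and inputs as the CPU gtest above.
session = ort.InferenceSession(
    "onnxruntime/test/testdata/transformers/tiny_gpt2_sampling.onnx",
    providers=["CPUExecutionProvider"],
)

input_ids = np.array(
    [
        [0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620],
        [41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572],
        [0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328],
    ],
    dtype=np.int32,
)
sequences = session.run(
    ["sequences"],
    {
        "input_ids": input_ids,
        "max_length": np.array([15], dtype=np.int32),
        "min_length": np.array([1], dtype=np.int32),
        "repetition_penalty": np.array([1.0], dtype=np.float32),
    },
)[0]
assert sequences.shape == (3, 15)  # (batch_size, max_length) int32 token IDs
```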