diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 814a076fae83e..ef402f17d32aa 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3945,7 +3945,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (2 - 8)
+#### Inputs (2 - 9)
- input_ids : I
@@ -3964,6 +3964,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- Custom attention mask. Shape is (batch_size, sequence_length)
- presence_mask (optional) : I
- Presence penalty mask. Shape is (batch_size, vocab_size)
+- seed (optional) : I
+- Seed for random number generator. Shape is (1)
#### Outputs (1 - 2)
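The new optional `seed` input makes the op's sampling reproducible across runs. A minimal sketch of feeding it from Python, assuming a model converted with the seed input enabled (the model path and feed values are illustrative):

```python
# Hypothetical reproducibility check for the new optional `seed` input.
# Assumes "gpt2_sampling.onnx" was produced by convert_generation.py with --seed.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("gpt2_sampling.onnx", providers=["CPUExecutionProvider"])
feeds = {
    "input_ids": np.zeros((1, 4), dtype=np.int32),
    "max_length": np.array([8], dtype=np.int32),
    "min_length": np.array([1], dtype=np.int32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "seed": np.array([42], dtype=np.int32),  # shape (1), must be >= 0
}
# Two runs with the same seed should yield identical sequences.
a = sess.run(["sequences"], feeds)[0]
b = sess.run(["sequences"], feeds)[0]
assert np.array_equal(a, b)
```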
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index a1cdb7e30d04b..1d42210731bde 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -450,7 +450,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@@ -812,7 +812,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
+|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
index 6297bcf0b39d8..d389e4ad76016 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
@@ -105,6 +105,12 @@ void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, in
DumpCpuTensor(name, tensor, dim0, dim1);
}
+void CpuTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
+ if (!is_enabled_)
+ return;
+ DumpCpuTensor(name, tensor, dim0, dim1);
+}
+
void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
if (!is_enabled_)
return;
@@ -180,6 +186,9 @@ void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const {
void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
}
+void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
+}
+
void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
index 27d883008ff58..78b1e6cf9a536 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
@@ -16,6 +16,7 @@ class CpuTensorConsoleDumper : public IConsoleDumper {
virtual ~CpuTensorConsoleDumper() {}
void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override;
+ void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 4cc5bf380f2af..e7dca4d3264de 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -172,6 +172,7 @@ class IConsoleDumper {
bool IsEnabled() const { return is_enabled_; }
virtual void Print(const char* name, const float* tensor, int dim0, int dim1) const = 0;
virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const = 0;
+ virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0;
virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const = 0;
virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const = 0;
virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const = 0;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
index 1e3c7035ff336..5caa7c0b27eb8 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
@@ -10,9 +10,10 @@ template <typename T>
void filter_scores(std::vector<size_t>& sorted_indice,
                   gsl::span<T>& next_token_score,
                   const transformers::IGenerationParameters* parameters,
-                  size_t index) {
-  size_t real_index = sorted_indice[index];
-  next_token_score[real_index] = (T)parameters->filter_value;
+                  size_t chunk_offset,
+                  size_t offset) {
+  size_t real_index = sorted_indice[chunk_offset + offset];
+  next_token_score[chunk_offset + real_index] = (T)parameters->filter_value;
}

template <typename T>
@@ -23,12 +24,12 @@ void cumulate_and_filter_custom(gsl::span<T>& next_token_scores,
  for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
    size_t offset = i * parameters->vocab_size;
    if (cumulative_probs[offset] > parameters->top_p) {
-      filter_scores(sorted_indices, next_token_scores, parameters, 1 + offset);
+      filter_scores(sorted_indices, next_token_scores, parameters, offset, 1);
    }
    for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - 1; j++) {
      cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
      if (cumulative_probs[j + offset] > parameters->top_p) {
-        filter_scores(sorted_indices, next_token_scores, parameters, j + offset + 1);
+        filter_scores(sorted_indices, next_token_scores, parameters, offset, j + 1);
      }
    }
  }
@@ -42,12 +43,12 @@ void cumulate_and_filter(gsl::span<T>& next_token_scores,
  for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
    size_t offset = i * parameters->vocab_size;
    if (cumulative_probs[offset] <= 1 - parameters->top_p) {
-      filter_scores(sorted_indices, next_token_scores, parameters, offset);
+      filter_scores(sorted_indices, next_token_scores, parameters, offset, 0);
    }
    for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - static_cast<size_t>(parameters->min_tokens_to_keep); j++) {
      cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
      if (cumulative_probs[j + offset] <= 1 - parameters->top_p) {
-        filter_scores(sorted_indices, next_token_scores, parameters, j + offset);
+        filter_scores(sorted_indices, next_token_scores, parameters, offset, j);
      }
    }
  }
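For reference, a minimal NumPy sketch of the per-batch top-p filtering that the corrected `chunk_offset`/`offset` split implements (descending-sort variant; using `filter_value = -inf` for the kernel's filter constant is an assumption here):

```python
# Minimal sketch of per-batch nucleus (top-p) filtering, illustrative only.
import numpy as np

def top_p_filter(scores, top_p, filter_value=-np.inf):
    """Mask logits outside the top-p nucleus, independently per batch row."""
    filtered = scores.copy()
    for row in filtered:                 # one vocab-sized chunk per batch entry
        order = np.argsort(-row)         # row-local indices, descending score
        probs = np.exp(row[order] - row[order].max())
        probs /= probs.sum()             # softmax over the sorted scores
        cum = np.cumsum(probs)
        # sorted position k >= 1 is dropped once the mass before it exceeds top_p
        drop = order[1:][cum[:-1] > top_p]
        row[drop] = filter_value         # the top token is always kept
    return filtered
```

Each row is treated as its own vocab-sized chunk, which is exactly what the fix restores: sorted indices and filtered scores are now addressed relative to the row's offset rather than the flattened buffer.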
@@ -78,10 +79,11 @@ Status Sample(AllocatorPtr& allocator,
  for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
    auto indices_begin = sorted_indices.begin() + i * parameters->vocab_size;
    auto indices_end = sorted_indices.begin() + (i + 1) * parameters->vocab_size;
+   gsl::span<T> next_token_score = next_token_scores.subspan(i * parameters->vocab_size, parameters->vocab_size);
    std::iota(indices_begin, indices_end, 0);
    std::sort(indices_begin, indices_end,
-             [&next_token_scores, &predicator](size_t i1, size_t i2) {
-               return !predicator(next_token_scores[i1], next_token_scores[i2]);
+             [&next_token_score, &predicator](size_t i1, size_t i2) {
+               return predicator(next_token_score[i1], next_token_score[i2]);
              });
    std::sort(sorted_scores.begin() + i * parameters->vocab_size,
@@ -89,6 +91,11 @@ Status Sample(AllocatorPtr& allocator,
predicator);
}
+#ifdef DEBUG_GENERATION
+ dumper->Print("sorted_scores", sorted_scores.data(), parameters->batch_size, parameters->vocab_size);
+ dumper->Print("sorted_indices", sorted_indices.data(), parameters->batch_size, parameters->vocab_size);
+#endif
+
gsl::span<T>& cumulative_probs = sampling_state->cumulative_probs;
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
@@ -104,13 +111,10 @@ Status Sample(AllocatorPtr& allocator,
cumulate_and_filter(next_token_scores, cumulative_probs, parameters, sorted_indices);
}
-  gsl::span<T>& next_token_probs = sampling_state->h_softmaxed_score;
-  ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
-                                    parameters->vocab_size,
-                                    next_token_scores.data(),
-                                    next_token_probs.data(),
-                                    false,
-                                    thread_pool));
+#ifdef DEBUG_GENERATION
+ dumper->Print("cumulative_probs after filtering", cumulative_probs.data(), parameters->batch_size, parameters->vocab_size);
+ dumper->Print("next_token_scores after filtering", next_token_scores.data(), parameters->batch_size, parameters->vocab_size);
+#endif
// torch.multinomial()
int64_t next_token_probs_dims[] = {static_cast<int64_t>(parameters->batch_size), parameters->vocab_size};
@@ -119,7 +123,7 @@ Status Sample(AllocatorPtr& allocator,
OrtValue next_token_probs_value;
Tensor::InitOrtValue(element_type,
next_token_probs_shape,
- next_token_probs.data(),
+ next_token_scores.data(),
allocator->Info(),
next_token_probs_value);
const Tensor& input = next_token_probs_value.Get<Tensor>();
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc
index 537b0d7538ce7..6b3862f04ad72 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc
@@ -21,6 +21,14 @@ void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) {
vocab_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("vocab_size", -1));
}
+void SamplingParameters::ParseFromInputs(OpKernelContext* context) {
+ this->GreedySearchParameters::ParseFromInputs(context);
+
+  auto* seed_tensor = context->Input<Tensor>(8);
+  seed = seed_tensor ? static_cast<int>(*seed_tensor->Data<int32_t>()) : 0;
+ ORT_ENFORCE(seed >= 0, "Seed must be >= 0");
+}
+
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
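A behavioral sketch (not ORT code) of how the optional input resolves, mirroring `ParseFromInputs` above:

```python
# Behavioral sketch of SamplingParameters::ParseFromInputs (illustrative only).
def resolve_seed(seed_tensor):
    # Input index 8 is optional; a missing tensor falls back to seed 0.
    seed = int(seed_tensor[0]) if seed_tensor is not None else 0
    if seed < 0:
        raise ValueError("Seed must be >= 0")
    return seed
```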
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h
index 6c0f866f09fe2..af203abc15d01 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h
@@ -12,6 +12,8 @@ namespace transformers {
struct SamplingParameters : public GreedySearchParameters {
void ParseFromAttributes(const OpKernelInfo& info);
+
+ void ParseFromInputs(OpKernelContext* context);
};
} // namespace transformers
diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
index 55abeaf657950..a83871ac1d09c 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -145,6 +145,11 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, i
DumpGpuTensor(name, tensor, dim0, dim1, true);
}
+void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const {
+ if (is_enabled_)
+ DumpGpuTensor(name, tensor, dim0, dim1, true);
+}
+
void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
if (is_enabled_)
DumpGpuTensor(name, tensor, dim0, dim1, true);
@@ -212,6 +217,9 @@ void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const {
void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const {
}
+void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const {
+}
+
void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const {
}
diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h
index 1a10ba8910a5e..8641df06d8830 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h
@@ -19,6 +19,7 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons
virtual ~CudaTensorConsoleDumper() {}
void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override;
+ void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override;
void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override;
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 15f8599f52534..a8c870d1442cf 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1155,6 +1155,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
.Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
.Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
.Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
+ .Input(8, "seed", "Seed for random number generator. Shape is (1)", "I", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I")
.Output(1, "filtered_logits", "Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index 09ee7539c4ee4..82678fa53105e 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -273,6 +273,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
)
model_group.set_defaults(presence_mask=False)
+ model_group.add_argument(
+ "--seed",
+ required=False,
+ action="store_true",
+ help="Random seed for sampling op",
+ )
+ model_group.set_defaults(seed=False)
+
beam_parameters_group = parser.add_argument_group(
"Beam search parameters not stored in the output model, for testing parity and performance"
)
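For example, a hypothetical converter invocation enabling the new flag (the module path is real; model and output names are placeholders):

```python
# Hypothetical invocation; "--seed" only adds the optional graph input, the
# runtime seed value is supplied when the converted model is executed.
from onnxruntime.transformers import convert_generation

convert_generation.main([
    "-m", "gpt2",
    "--output", "gpt2_sampling.onnx",
    "--top_p", "0.6",  # a top_p in (0, 1) selects the sampling conversion path
    "--seed",          # store_true flag added above
])
```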
@@ -1531,6 +1539,11 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
if is_sampling and args.custom and args.presence_mask:
inputs.append("presence_mask")
+ else:
+ inputs.append("")
+
+ if is_sampling and args.seed:
+ inputs.append("seed")
outputs = ["sequences"]
if args.output_sequences_scores:
@@ -1709,6 +1722,10 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
)
graph_inputs.append(presence_mask)
+ if is_sampling and args.seed:
+ seed = onnx.helper.make_tensor_value_info("seed", TensorProto.INT32, [1])
+ graph_inputs.append(seed)
+
# graph outputs
sequences = None
if is_beamsearch:
@@ -2278,9 +2295,13 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None
if args.model_type == "gpt2" and is_greedy:
if args.top_p > 0.0 and args.top_p < 1.0:
convert_generation_model(args, GenerationType.SAMPLING)
- logger.info("The test for gpt2_sampling onnx model is not implemented yet")
- return
- convert_generation_model(args, GenerationType.GREEDYSEARCH)
+ logger.info(
+ "The test for gpt2_sampling onnx model is limited to non-custom model with small top_p(e.g <=0.01) value. The result should be the same as gpt2 greedy search."
+ )
+ if args.top_p > 0.01 or args.custom or args.seed:
+ return
+ else:
+ convert_generation_model(args, GenerationType.GREEDYSEARCH)
else:
convert_generation_model(args)
diff --git a/onnxruntime/test/contrib_ops/sampling_test.cc b/onnxruntime/test/contrib_ops/sampling_test.cc
new file mode 100644
index 0000000000000..48992f24a3234
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/sampling_test.cc
@@ -0,0 +1,155 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <memory>
+#include <vector>
+#include "gtest/gtest.h"
+#include "core/common/gsl.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "test/common/cuda_op_test_utils.h"
+
+extern std::unique_ptr<Ort::Env> ort_env;
+
+namespace onnxruntime {
+namespace test {
+
+#if defined(__linux__) && !defined(__ANDROID__)
+#ifdef USE_CUDA
+TEST(SamplingTest, Gpt2Sampling_CUDA) {
+  std::vector<int32_t> input_ids{
+ 0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
+ 41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572,
+ 0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328};
+
+  std::vector<int32_t> max_length{15};
+  std::vector<int32_t> min_length{1};
+  std::vector<float> repetition_penalty{1.0f};
+
+  std::vector<int32_t> expected_output{
+ 0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, 125, 543, 668,
+ 41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572, 776, 213, 697,
+ 0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328, 450};
+
+ const int64_t batch_size = 3;
+ const int64_t sequence_length = 12;
+
+  std::vector<int64_t> input_ids_shape{batch_size, sequence_length};
+
+  std::vector<int64_t> parameter_shape{1};
+
+  std::vector<int64_t> expected_output_shape{input_ids_shape[0], max_length[0]};
+
+ Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+ auto input_ids_tensor = Ort::Value::CreateTensor(
+ info, input_ids.data(), input_ids.size(), input_ids_shape.data(), input_ids_shape.size());
+
+ auto max_length_tensor = Ort::Value::CreateTensor(
+ info, max_length.data(), max_length.size(), parameter_shape.data(), parameter_shape.size());
+
+ auto min_length_tensor = Ort::Value::CreateTensor(
+ info, min_length.data(), min_length.size(), parameter_shape.data(), parameter_shape.size());
+
+ auto repetition_penalty_tensor = Ort::Value::CreateTensor(
+ info, repetition_penalty.data(), repetition_penalty.size(), parameter_shape.data(), parameter_shape.size());
+
+  std::vector<Ort::Value> ort_inputs;
+ ort_inputs.push_back(std::move(input_ids_tensor));
+ ort_inputs.push_back(std::move(max_length_tensor));
+ ort_inputs.push_back(std::move(min_length_tensor));
+ ort_inputs.push_back(std::move(repetition_penalty_tensor));
+ const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"};
+ const char* const output_names[] = {"sequences"};
+
+ Ort::SessionOptions session_options;
+ constexpr int min_cuda_architecture = 530;
+ if (HasCudaEnvironment(min_cuda_architecture)) {
+ Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+
+ Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_sampling.onnx"), session_options);
+
+ auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
+ output_names, 1);
+
+ ASSERT_EQ(ort_outputs.size(), 1U);
+ const auto& sequences = ort_outputs[0];
+ ASSERT_TRUE(sequences.IsTensor());
+
+ auto result_ts = sequences.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, result_ts.GetElementType());
+
+ ASSERT_EQ(expected_output_shape, result_ts.GetShape());
+    const auto* result_vals = sequences.GetTensorData<int32_t>();
+ auto result_span = gsl::make_span(result_vals, expected_output.size());
+
+ ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
+ }
+}
+#endif
+
+TEST(SamplingTest, Gpt2Sampling_CPU) {
+  std::vector<int32_t> input_ids{
+ 0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
+ 41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572,
+ 0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328};
+
+  std::vector<int32_t> max_length{15};
+  std::vector<int32_t> min_length{1};
+  std::vector<float> repetition_penalty{1.0f};
+
+  std::vector<int32_t> expected_output{
+ 0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, 125, 669, 28,
+ 41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572, 475, 944, 527,
+ 0, 0, 0, 52, 328, 219, 328, 206, 288, 227, 896, 328, 210};
+
+ const int64_t batch_size = 3;
+ const int64_t sequence_length = 12;
+  std::vector<int64_t> input_ids_shape{batch_size, sequence_length};
+
+  std::vector<int64_t> parameter_shape{1};
+
+  std::vector<int64_t> expected_output_shape{input_ids_shape[0], max_length[0]};
+
+ Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+ auto input_ids_tensor = Ort::Value::CreateTensor(
+ info, input_ids.data(), input_ids.size(), input_ids_shape.data(), input_ids_shape.size());
+
+ auto max_length_tensor = Ort::Value::CreateTensor(
+ info, max_length.data(), max_length.size(), parameter_shape.data(), parameter_shape.size());
+
+ auto min_length_tensor = Ort::Value::CreateTensor(
+ info, min_length.data(), min_length.size(), parameter_shape.data(), parameter_shape.size());
+
+ auto repetition_penalty_tensor = Ort::Value::CreateTensor(
+ info, repetition_penalty.data(), repetition_penalty.size(), parameter_shape.data(), parameter_shape.size());
+
+  std::vector<Ort::Value> ort_inputs;
+ ort_inputs.push_back(std::move(input_ids_tensor));
+ ort_inputs.push_back(std::move(max_length_tensor));
+ ort_inputs.push_back(std::move(min_length_tensor));
+ ort_inputs.push_back(std::move(repetition_penalty_tensor));
+ const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"};
+ const char* const output_names[] = {"sequences"};
+
+ Ort::SessionOptions session_options;
+ Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_sampling.onnx"), session_options);
+
+ auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
+ output_names, 1);
+
+ ASSERT_EQ(ort_outputs.size(), 1U);
+ const auto& sequences = ort_outputs[0];
+ ASSERT_TRUE(sequences.IsTensor());
+
+ auto result_ts = sequences.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, result_ts.GetElementType());
+
+ ASSERT_EQ(expected_output_shape, result_ts.GetShape());
+  const auto* result_vals = sequences.GetTensorData<int32_t>();
+ auto result_span = gsl::make_span(result_vals, expected_output.size());
+
+ ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
+}
+#endif
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/transformers/tiny_gpt2_sampling.onnx b/onnxruntime/test/testdata/transformers/tiny_gpt2_sampling.onnx
new file mode 100644
index 0000000000000..c9ecb269f72e6
Binary files /dev/null and b/onnxruntime/test/testdata/transformers/tiny_gpt2_sampling.onnx differ