
Sampling op #13426

Merged · 56 commits · Dec 23, 2022
Commits
d4664e0
logits wrapper cpu
wangyems Oct 24, 2022
c5fb195
add sampling parameters
wangyems Oct 24, 2022
d664ac6
register sampling cpu
wangyems Oct 24, 2022
3331c8f
add multinomial cpu
wangyems Oct 25, 2022
e7e4d97
fix build
wangyems Oct 25, 2022
bc42d29
logits wrapper cuda
wangyems Oct 25, 2022
713d932
fix a bug
wangyems Oct 26, 2022
400614e
add cub::radixsort
wangyems Oct 27, 2022
3dfde56
add filterlogits cuda
wangyems Oct 27, 2022
e1c0046
provider
wangyems Oct 28, 2022
7c64df9
multinomial cuda
wangyems Nov 1, 2022
25f58fe
use curand and add seed
wangyems Nov 1, 2022
6a60328
fix a bug
wangyems Nov 2, 2022
0363d1b
fix a few crash
wangyems Nov 3, 2022
3b5392f
try to fix win build
wangyems Nov 3, 2022
7b8cc2a
no debug
wangyems Nov 3, 2022
b985964
update
wangyems Nov 3, 2022
7830bec
update
wangyems Nov 3, 2022
33d232e
support 4d mask
wangyems Nov 4, 2022
e9d2215
update
wangyems Nov 8, 2022
e48867c
check presence mask
wangyems Nov 8, 2022
daebfaa
fix crash
wangyems Nov 9, 2022
0fd09a8
reuse buffer
wangyems Nov 9, 2022
61ee848
update presence_mask
wangyems Nov 9, 2022
f7cc0a9
generate random numbers at once
wangyems Nov 15, 2022
09825ce
fix bugs
wangyems Nov 16, 2022
cd12a2e
for debug purpose
wangyems Nov 29, 2022
05e32d0
minor change
wangyems Nov 30, 2022
dda532c
remove some printings
wangyems Nov 30, 2022
653b28b
optional logits
wangyems Dec 1, 2022
ad4e21b
refactor cuda impl
wangyems Dec 6, 2022
d597492
fix build
wangyems Dec 6, 2022
e877315
refactor cpu kernel
wangyems Dec 7, 2022
6752f78
refactor
wangyems Dec 8, 2022
0ad2a35
add huggingface logic cpu
wangyems Dec 8, 2022
6d758d9
huggingface topp cuda
wangyems Dec 10, 2022
dc32745
refactor
wangyems Dec 12, 2022
e7dd78a
Update OperatorKernels.md
wangyems Dec 13, 2022
a1c8ea2
prefast warning
wangyems Dec 13, 2022
88575f5
exclude amd build
wangyems Dec 13, 2022
8893a04
Merge branch 'main' into wangye/beamsample
wangyems Dec 14, 2022
6ff8f23
fix build issue due to rebase
wangyems Dec 14, 2022
bc71d1a
Merge branch 'main' into wangye/beamsample
wangyems Dec 15, 2022
d74cc1d
fix build issue due to rebase
wangyems Dec 15, 2022
64bd22c
add sampling conversion script
wangyems Dec 19, 2022
b519968
rebase with master
wangyems Dec 19, 2022
a74b76b
remove dup files
wangyems Dec 19, 2022
647c904
format python
wangyems Dec 19, 2022
693c4eb
fix build
wangyems Dec 19, 2022
bc71be8
support decoder init
wangyems Dec 20, 2022
26b88db
refine contrib_defs.cc
wangyems Dec 20, 2022
78f36f5
update docs
wangyems Dec 20, 2022
9223181
enable padded_vocab_size and decoder_init for sampling op
wangyems Dec 20, 2022
f0383ab
review comments
Dec 22, 2022
589fc54
Merge remote-tracking branch 'origin/main' into wangye/beamsample
Dec 22, 2022
fedc839
Merge branch 'main' into wangye/beamsample
wangyems Dec 22, 2022
7 changes: 5 additions & 2 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -75,10 +75,13 @@ set(contrib_ops_excluded_files
"transformers/beam_search.h"
"transformers/generation_device_helper.cc"
"transformers/generation_device_helper.h"
"transformers/beam_search_impl.cu"
"transformers/beam_search_impl.h"
"transformers/generation_cuda_impl.cu"
"transformers/generation_cuda_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
"transformers/sampling.cc"
"transformers/sampling.h"
"transformers/sampling_cuda_helper.h"
"transformers/dump_cuda_tensor.cc"
"transformers/dump_cuda_tensor.h"
"conv_transpose_with_dynamic_pads.cc"
84 changes: 84 additions & 0 deletions docs/ContribOperators.md
@@ -74,6 +74,7 @@ Do not modify directly.*
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
* <a href="#com.microsoft.Snpe">com.microsoft.Snpe</a>
* <a href="#com.microsoft.SparseToDenseMatMul">com.microsoft.SparseToDenseMatMul</a>
@@ -3810,6 +3811,89 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.Sampling"></a><a name="com.microsoft.sampling">**com.microsoft.Sampling**</a>

Greedy Sampling for text generation.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>custom</tt> : int</dt>
<dd>If 1, custom sampling logic is used.</dd>
<dt><tt>decoder</tt> : graph (required)</dt>
<dd>Decoder subgraph to execute in a loop.</dd>
<dt><tt>decoder_start_token_id</tt> : int</dt>
<dd>The id of the token that indicates decoding starts.</dd>
<dt><tt>encoder</tt> : graph</dt>
<dd>The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
<dt><tt>eos_token_id</tt> : int (required)</dt>
<dd>The id of the end-of-sequence token</dd>
<dt><tt>filter_value</tt> : float</dt>
<dd>All filtered values will be set to this float value.</dd>
<dt><tt>init_decoder</tt> : graph</dt>
<dd>The subgraph for the first decoding run. It will be called once before `decoder` subgraph. This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs</dd>
<dt><tt>min_tokens_to_keep</tt> : int</dt>
<dd>Minimum number of tokens we keep per batch example in the output.</dd>
<dt><tt>model_type</tt> : int</dt>
<dd>Model type: 0 for decoder-only models like GPT-2; 1 for encoder-decoder models like BART.</dd>
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
<dd>N-grams of this size can occur at most once in the generated sequence.</dd>
<dt><tt>pad_token_id</tt> : int (required)</dt>
<dd>The id of the padding token</dd>
<dt><tt>presence_penalty</tt> : float</dt>
<dd>Presence penalty for custom sampling</dd>
<dt><tt>temperature</tt> : float</dt>
<dd>The value used to modulate the next-token probabilities.</dd>
<dt><tt>top_p</tt> : float</dt>
<dd>If set to a float < 1, only the smallest set of the most probable tokens whose probabilities add up to `top_p` or higher is kept for generation (see the sketch after this list).</dd>
<dt><tt>vocab_size</tt> : int</dt>
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>
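For intuition, here is a minimal NumPy sketch of the top-p (nucleus) filtering that `temperature`, `top_p`, `filter_value`, and `min_tokens_to_keep` describe. It follows the common Hugging Face-style formulation and is only an illustration, not the kernel's actual implementation.

```python
import numpy as np

def top_p_filter(logits, top_p=0.9, temperature=1.0,
                 filter_value=-np.inf, min_tokens_to_keep=1):
    # logits: (batch_size, vocab_size)
    scores = logits / temperature
    order = np.argsort(-scores, axis=-1)                  # descending by score
    sorted_scores = np.take_along_axis(scores, order, axis=-1)
    probs = np.exp(sorted_scores - sorted_scores[:, :1])  # stable softmax
    probs /= probs.sum(axis=-1, keepdims=True)
    cumulative = np.cumsum(probs, axis=-1)
    # Drop a token only if the tokens ranked before it already cover top_p,
    # so the smallest set whose mass reaches top_p survives.
    drop = (cumulative - probs) > top_p
    drop[:, :min_tokens_to_keep] = False                  # keep a minimum per row
    filtered_sorted = np.where(drop, filter_value, sorted_scores)
    # Scatter back to the original vocabulary order.
    filtered = np.empty_like(scores)
    np.put_along_axis(filtered, order, filtered_sorted, axis=-1)
    return filtered
```

Tokens outside the kept set receive `filter_value`, so a subsequent softmax assigns them (near-)zero probability before multinomial sampling.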

#### Inputs (2 - 8)

<dl>
<dt><tt>input_ids</tt> : I</dt>
<dd>The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)</dd>
<dt><tt>max_length</tt> : I</dt>
<dd>The maximum length of the sequence to be generated. Shape is (1)</dd>
<dt><tt>min_length</tt> (optional) : I</dt>
<dd>The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)</dd>
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. The default value 1.0 means no penalty. Accepts values > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary. Words masked with 0 are not allowed to be generated, while words masked with 1 are allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary for the first step. Words masked with 0 are not allowed to be generated, while words masked with 1 are allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
<dt><tt>presence_mask</tt> (optional) : I</dt>
<dd>Presence penalty mask. Shape is (batch_size, vocab_size)</dd>
</dl>

#### Outputs (1 - 2)

<dl>
<dt><tt>sequences</tt> : I</dt>
<dd>Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)</dd>
<dt><tt>filtered_logits</tt> (optional) : T</dt>
<dd>Filtered logits used as input to the multinomial function, for debugging purposes. Shape is (batch_size, vocab_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>I</tt> : tensor(int32)</dt>
<dd>Constrain to integer types</dd>
</dl>
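To show how the operator is typically wired into a model, the snippet below builds a `com.microsoft.Sampling` node with `onnx.helper`. The decoder `GraphProto` and the GPT-2-style token ids are placeholders and assumptions, so treat this as an illustrative sketch rather than a drop-in script; the sampling conversion script added in this PR automates the real wiring.

```python
import onnx
from onnx import helper

# Placeholder: a real decoder subgraph (GraphProto) must be supplied here,
# e.g. one produced by the transformers conversion tooling.
decoder_graph = onnx.GraphProto(name="decoder")

sampling_node = helper.make_node(
    "Sampling",
    inputs=["input_ids", "max_length", "min_length", "repetition_penalty"],
    outputs=["sequences"],
    domain="com.microsoft",
    model_type=0,        # decoder-only, GPT-2 style
    eos_token_id=50256,  # assumed GPT-2 end-of-text id
    pad_token_id=50256,
    temperature=1.0,
    top_p=0.9,
    decoder=decoder_graph,
)
```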


### <a name="com.microsoft.SkipLayerNormalization"></a><a name="com.microsoft.skiplayernormalization">**com.microsoft.SkipLayerNormalization**</a>

Skip and Layer Normalization Fusion
2 changes: 2 additions & 0 deletions docs/OperatorKernels.md
@@ -438,6 +438,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@@ -797,6 +798,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -18,6 +18,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
@@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range)>,
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -61,16 +61,17 @@ void BeamSearch::Init(const OpKernelInfo& info) {
parameters_.ParseFromAttributes(info);

// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt ||
parameters_.model_type == IGenerationParameters::kModelTypeT5);

ONNX_NAMESPACE::GraphProto proto;
if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {

if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
// Make sure the encoder sub-graph attribute is present for the T5 model.
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
}

if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
if (info.GetAttr<ONNX_NAMESPACE::GraphProto>("init_decoder", &proto).IsOK()) {
has_init_decoder_ = true;
@@ -87,7 +88,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
const auto& node = Node();
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
if (attribute_name == "decoder") {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
@@ -113,8 +114,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
init_run_gpt_subgraph_ = std::move(res.second);
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
}

} else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) {
} else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
if (attribute_name == "encoder") {
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
@@ -167,7 +167,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
// Make a copy of parameters since we will update it based on inputs later
BeamSearchParameters parameters = parameters_;

if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32
BeamSearchGpt<float> impl{
*ctx_internal,
@@ -191,7 +191,8 @@ Status BeamSearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
context.Input<Tensor>(0), // input_ids
context.Input<Tensor>(7), // vocab_mask
context.Input<Tensor>(8), // prefix_vocab_mask
context.Input<Tensor>(9))); // attention_mask
context.Input<Tensor>(9), // attention_mask
nullptr)); // presence_mask

return Status::OK();
}
@@ -18,7 +18,7 @@ Status BeamSearchParameters::Validate() const {
}

void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IBeamSearchParameters::kModelTypeGpt));
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeGpt));
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
@@ -35,7 +35,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
batch_size = static_cast<int>(dims[0]);

// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;
sequence_length = (this->model_type == IGenerationParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;

auto* max_length_tensor = context->Input<Tensor>(1);
max_length = max_length_tensor ? static_cast<int>(*max_length_tensor->Data<int32_t>()) : kMaxSequenceLength;
@@ -71,10 +71,9 @@ void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
if (vocab_size == -1) {
if (vocab_size == -1 || vocab_size == 0) {
vocab_size = vocabulary_size;
}

num_heads = heads;
head_size = hidden_size_per_head;
num_layers = layers;
@@ -10,7 +10,7 @@ namespace onnxruntime {
namespace contrib {
namespace transformers {

struct BeamSearchParameters : public IBeamSearchParameters {
struct BeamSearchParameters : public IGenerationParameters {
Status Validate() const;

int BatchBeamSize() const { return batch_size * num_beams; }
25 changes: 24 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
@@ -87,7 +87,8 @@ class GenerateBase {
const Tensor* input_ids,
const Tensor* vocab_mask,
const Tensor* prefix_vocab_mask,
const Tensor* attention_mask) const {
const Tensor* attention_mask,
const Tensor* presence_mask) const {
const auto& dims = input_ids->Shape().GetDims();
if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -149,6 +150,28 @@ class GenerateBase {
}
}

if (presence_mask != nullptr) {
const auto& dims_presence = presence_mask->Shape().GetDims();
if (dims_presence.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'presence_mask' is expected to have 2 dimensions, got ", dims_presence.size());
}

// presence_mask first dimension should be same as the first dimension of input_ids
if (static_cast<int>(dims_presence[0]) != static_cast<int>(dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input_ids and presence_mask must have the same batch_size");
}

if (static_cast<int>(dims_presence[1]) != parameters->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'presence_mask' shape[1] shall be vocab_size, got ", dims_presence[1]);
}

// Store the presence mask in parameters.
parameters->presence_mask = presence_mask->DataAsSpan<int32_t>();
}

return Status::OK();
}
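The check above validates shapes only. For context, here is a hedged NumPy sketch of one common way a presence penalty is applied with such a mask; the kernel's exact arithmetic may differ.

```python
import numpy as np

def apply_presence_penalty(logits, presence_mask, presence_penalty):
    # logits, presence_mask: (batch_size, vocab_size); mask entries are 0 or 1.
    # Every token flagged in the mask has its score reduced by a flat penalty.
    return logits - presence_penalty * presence_mask.astype(logits.dtype)
```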
