diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake
index 30d8cbf78fb1a..8b4be045c8674 100644
--- a/cmake/adjust_global_compile_flags.cmake
+++ b/cmake/adjust_global_compile_flags.cmake
@@ -92,8 +92,13 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()
-# enable stream for all the non-minimal build
-if (NOT onnxruntime_MINIMAL_BUILD)
+# Enable streams for all non-minimal builds, except when DML is used. There is
+# currently a bug in the allocation planner: when buffers are reused and more
+# than one stream is in use, a buffer that is still being used can (rarely)
+# reach a reference count of 0. Since DML doesn't benefit from multiple
+# streams, disabling them is the safest option for now.
+# https://github.com/microsoft/onnxruntime/issues/19480
+if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
add_compile_definitions(ORT_ENABLE_STREAM)
endif()
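
Note: ORT_ENABLE_STREAM is a compile-time switch applied globally via add_compile_definitions, so downstream code selects the multi-stream path with the preprocessor. A minimal sketch of the consuming pattern (illustrative only; the function names are hypothetical, not from this patch):

    #include <cstdio>

    // ORT_ENABLE_STREAM is injected by CMake for every translation unit,
    // so each build picks exactly one of these definitions.
    #ifdef ORT_ENABLE_STREAM
    static void Run() { std::puts("multi-stream execution path"); }
    #else
    static void Run() { std::puts("single-stream fallback"); }
    #endif

    int main() { Run(); }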
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index fd26b09b09531..f91d66c22ea24 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
+- beginning_timestamp_token_id : int
+- The id of the first timestamp
- decoder : graph (required)
- Decoder subgraph to execute in a loop.
- decoder_output_cross_qk : int
- If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
- decoder_start_token_id : int
-- The id of the token that indicates decoding starts.
+- The id of the token that indicates decoding starts (i.e. the start of transcription token id)
- early_stopping : int
- early stop or not
- encoder : graph
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
- Must be 2 for whisper
- no_repeat_ngram_size : int
- no repeat ngrams size
-- no_speech_token : int
+- no_speech_token_id : int
- The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+- no_timestamps_token_id : int
+- The id of the token that indicates no timestamps
- pad_token_id : int (required)
- The id of the padding token
+- start_of_lm_token_id : int
+- The id of the token that indicates LM starts
+- transcribe_token_id : int
+- The id of the transcribe task
+- translate_token_id : int
+- The id of the translate task
- vocab_size : int
- Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
@@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
-Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produce shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
-Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. By default, collect all. Its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2], ...]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
@@ -5810,11 +5820,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
-Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+Processed beam scores for each vocabulary token at each generation step. Beam scores consist of the log softmax score of each vocabulary token plus the sum of the log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
-Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+Output the accumulated stacked Q*K in cross attentions. Let H = number of heads of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It should then return a tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, the shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+For the Whisper model, output the probabilities from the logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints
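
Note: the new *_token_id attributes correspond to the Whisper tokenizer's special tokens. For orientation, a sketch with the ids of the common multilingual checkpoint (illustrative values only; English-only models use different ids, which is why the previous hard-coded offsets had to become attributes):

    // Multilingual Whisper special-token ids, shown only as an example.
    // Models supply their own values through the attributes above.
    constexpr int kEndOfTextTokenId = 50257;           // <|endoftext|> (eos)
    constexpr int kStartOfTranscriptTokenId = 50258;   // <|startoftranscript|>
    constexpr int kTranslateTokenId = 50358;           // <|translate|>
    constexpr int kTranscribeTokenId = 50359;          // <|transcribe|>
    constexpr int kStartOfLmTokenId = 50360;           // <|startoflm|>
    constexpr int kNoSpeechTokenId = 50362;            // <|nospeech|>
    constexpr int kNoTimestampsTokenId = 50363;        // <|notimestamps|>
    constexpr int kBeginningTimestampTokenId = 50364;  // <|0.00|>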
diff --git a/js/web/package.json b/js/web/package.json
index 047de382943e6..d306390fac594 100644
--- a/js/web/package.json
+++ b/js/web/package.json
@@ -69,11 +69,14 @@
"exports": {
".": {
"node": "./dist/ort.node.min.js",
+ "types": "./types.d.ts",
"default": {
"import": "./dist/esm/ort.min.js",
"require": "./dist/cjs/ort.min.js",
+ "types": "./types.d.ts",
"default": {
"development": "./dist/ort.js",
+ "types": "./types.d.ts",
"default": "./dist/ort.min.js"
}
}
@@ -81,34 +84,41 @@
"./experimental": {
"import": "./dist/esm/ort.all.min.js",
"require": "./dist/cjs/ort.all.min.js",
+ "types": "./types.d.ts",
"default": {
"development": "./dist/ort.all.js",
+ "types": "./types.d.ts",
"default": "./dist/ort.all.min.js"
}
},
"./wasm": {
"import": "./dist/esm/ort.wasm.min.js",
"require": "./dist/cjs/ort.wasm.min.js",
+ "types": "./types.d.ts",
"default": "./dist/ort.wasm.min.js"
},
"./wasm-core": {
"import": "./dist/esm/ort.wasm-core.min.js",
"require": "./dist/cjs/ort.wasm-core.min.js",
+ "types": "./types.d.ts",
"default": "./dist/ort.wasm-core.min.js"
},
"./webgl": {
"import": "./dist/esm/ort.webgl.min.js",
"require": "./dist/cjs/ort.webgl.min.js",
+ "types": "./types.d.ts",
"default": "./dist/ort.webgl.min.js"
},
"./webgpu": {
"import": "./dist/esm/ort.webgpu.min.js",
"require": "./dist/cjs/ort.webgpu.min.js",
+ "types": "./types.d.ts",
"default": "./dist/ort.webgpu.min.js"
},
"./training": {
"import": "./dist/esm/ort.training.wasm.min.js",
"require": "./dist/cjs/ort.training.wasm.min.js",
+ "types": "./types.d.ts",
"default": "./dist/ort.training.wasm.min.js"
}
},
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
index 56d950ca2f41e..d65ff9c5fb4f8 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
@@ -258,7 +258,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
cpu_state.sequences.InitDevice(beam_state.sequences_device);
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
- nullptr,
+ this->ort_stream_,
DeviceCopyDirection::hostToDevice));
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 94547887d3a90..3dbdd7b0fcd70 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -214,7 +214,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
cpu_state.sequences.InitDevice(beam_state.sequences_device);
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
- nullptr,
+ this->ort_stream_,
DeviceCopyDirection::hostToDevice));
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
index 91b93a125ad7a..97dc513d4b54f 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
@@ -134,8 +134,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
TensorShape no_speech_probs_shape{parameters->batch_size};
Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape);
if (no_speech_probs && no_speech_probs->MutableData<T>()) {
- ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size,
- "no_speech_token id out of range, it is ", parameters->no_speech_token,
+ ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size,
+ "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id,
", vocab_size is ", parameters->vocab_size);
this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData<T>();
}
@@ -226,7 +226,7 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
cpu_state.sequences.InitDevice(beam_state.sequences_device);
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
- nullptr,
+ this->ort_stream_,
DeviceCopyDirection::hostToDevice));
}
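
Note: the three nullptr -> this->ort_stream_ changes above (GPT, T5, Whisper) fix the same issue: the host-to-device copy of the sequences buffer is now issued on the operator's own compute stream, so it stays ordered with the other beam-search work queued there. A sketch of the idea using the raw CUDA runtime API (the helper name and signature are hypothetical, not ORT's actual device_copy_int32_func_):

    #include <cuda_runtime.h>
    #include <cstdint>

    // Enqueueing on the caller's stream serializes the copy with the kernels
    // already queued on that stream, instead of relying on the default stream.
    cudaError_t CopyInt32HostToDevice(int32_t* dst, const int32_t* src,
                                      size_t count, cudaStream_t stream) {
      return cudaMemcpyAsync(dst, src, count * sizeof(int32_t),
                             cudaMemcpyHostToDevice, stream);
    }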
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index 3962486d5b5eb..8a466dd9d9c18 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -141,7 +141,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info)
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeWhisper));
ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper);
- no_speech_token = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token", -1LL));
+ // Token ids are defined below in the order that they appear in the tokenizer
+ translate_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("translate_token_id", -1LL));
+ transcribe_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("transcribe_token_id", -1LL));
+ start_of_lm_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("start_of_lm_token_id", -1LL));
+ no_speech_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token_id", -1LL));
+ no_timestamps_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_timestamps_token_id", -1LL));
+ beginning_timestamp_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("beginning_timestamp_token_id", -1LL));
cross_qk_layer_head_input_id = 12;
extra_decoding_ids_input_id = 13;
cross_qk_output_id = 3;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index f6faf2e325f8f..34510902cf309 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -180,7 +180,14 @@ struct IGenerationParameters {
// Parameters for whisper model
bool decoder_output_cross_qk = false;
gsl::span<const int32_t> extra_decoding_ids;
- int32_t no_speech_token = -1;
+
+ // Token ids are defined below in the order that they appear in the tokenizer
+ int32_t translate_token_id = -1;
+ int32_t transcribe_token_id = -1;
+ int32_t start_of_lm_token_id = -1;
+ int32_t no_speech_token_id = -1;
+ int32_t no_timestamps_token_id = -1;
+ int32_t beginning_timestamp_token_id = -1;
void* no_speech_probs = nullptr;
int cross_qk_layer_head_input_id = -1;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
index 4688ff272cee9..3c213ee944119 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -10,6 +10,7 @@
#include "contrib_ops/cpu/transformers/greedy_search_parameters.h"
#include "contrib_ops/cpu/transformers/sampling_parameters.h"
#include "contrib_ops/cpu/transformers/generation_shared.h"
+#include <iostream>
namespace onnxruntime {
namespace contrib {
@@ -34,6 +35,14 @@ struct NextTokenScores {
}
};
+#ifdef DEBUG_GENERATION
+template <typename T>
+void DumpScores(const char* name, const NextTokenScores<T>& next_token_scores) {
+ std::cout << name << std::endl;
+ ORT_UNUSED_PARAMETER(next_token_scores);
+}
+#endif
+
// Interface for all scorers for beam search or beam sample.
template <typename T>
class ILogitsProcessor {
@@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
template <typename T>
class TimestampLogitsProcessor : public ILogitsProcessor<T> {
public:
- TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index)
- : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {}
+ TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|>
+ int start_of_transcript_token_id, // <|startoftranscript|>
+ int translate_token_id, // <|translate|>
+ int transcribe_token_id, // <|transcribe|>
+ int start_of_lm_token_id, // <|startoflm|>
+ int no_timestamps_token_id, // <|notimestamps|>
+ int beginning_timestamp_token_id, // <|0.00|>
+ int max_initial_timestamp_index)
+ : end_of_text_token_id_(end_of_text_token_id),
+ start_of_transcript_token_id_(start_of_transcript_token_id),
+ translate_token_id_(translate_token_id),
+ transcribe_token_id_(transcribe_token_id),
+ start_of_lm_token_id_(start_of_lm_token_id),
+ no_timestamps_token_id_(no_timestamps_token_id),
+ beginning_timestamp_token_id_(beginning_timestamp_token_id),
+ max_initial_timestamp_index_(max_initial_timestamp_index) {}
void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override {
- // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models.
- const int beg_token_id_ = eos_token_id_ + 107;
- const int not_token_id_ = eos_token_id_ + 106;
- const int solm_token_id_ = eos_token_id_ + 105;
- const int sot_token_id_ = eos_token_id_ + 1;
- constexpr int translate_token_id_ = 50358;
- constexpr int transcribe_token_id_ = 50359;
-
const int batch_beam_size = next_token_scores.batch_beam_size;
const int vocab_size = next_token_scores.vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
@@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
size_t sample_begin = 0;
for (size_t j = 0; j < seq_length; j++) {
sample_begin++;
- if (sequence[j] >= beg_token_id_) {
+ if (sequence[j] >= beginning_timestamp_token_id_) {
break;
}
}
@@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
// Suppress tokens
for (int j = 0; j < vocab_size; j++) {
// Suppress notimestamps and solm tokens
- if (j == not_token_id_ || j == solm_token_id_) {
+ if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
// Suppress sot, translate and transcribe tokens
if (seq_length > sample_begin) {
- if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
+ if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}
}
// Timestamps should be in pair except the first one
- const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_;
- const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_;
+ const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_;
+ const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_;
if (last_was_timestamp) {
if (penultimate_was_timestamp) {
// If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated
- for (int j = beg_token_id_; j < vocab_size; j++) {
+ for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
} else {
// If timestamp doesn't show up in pair, generate timestamp
- for (int j = 0; j < eos_token_id_; j++) {
+ for (int j = 0; j < end_of_text_token_id_; j++) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}
@@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
// Find timestamp tokens
std::vector<int32_t> timestamps;
for (const auto& word_id : sequence) {
- if (word_id >= beg_token_id_) {
+ if (word_id >= beginning_timestamp_token_id_) {
timestamps.push_back(word_id);
}
}
@@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
timestamp_last = timestamps.back() + 1;
}
- for (int j = beg_token_id_; j < timestamp_last; j++) {
+ for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}
if (seq_length == sample_begin) {
- const int last_allowed = beg_token_id_ + max_initial_timestamp_index_;
+ const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_;
for (int j = last_allowed + 1; j < vocab_size; j++) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
@@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
float timestamp_logprob = std::numeric_limits<float>::lowest();
{
float logsumexp = 0.0f;
- const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end());
- for (int j = beg_token_id_; j < vocab_size; ++j) {
+ const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end());
+ for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) {
if (beam_token_scores[j] > std::numeric_limits<T>::lowest()) {
logsumexp += expf(beam_token_scores[j] - logprob_max);
}
@@ -258,9 +273,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
}
}
- const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_);
+ const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_);
if (timestamp_logprob > max_text_token_logprob) {
- for (int j = 0; j < beg_token_id_; ++j) {
+ for (int j = 0; j < beginning_timestamp_token_id_; ++j) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}
@@ -272,7 +287,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
}
private:
- int eos_token_id_;
+ int end_of_text_token_id_;
+ int start_of_transcript_token_id_;
+ int translate_token_id_;
+ int transcribe_token_id_;
+ int start_of_lm_token_id_;
+ int no_timestamps_token_id_;
+ int beginning_timestamp_token_id_;
int max_initial_timestamp_index_;
};
@@ -334,7 +355,15 @@ class LogitsProcessorList : public ILogitsProcessorList {
// Add timestamp processor for whisper model
if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) {
constexpr int max_initial_timestamp_index = 50;
- timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<T>>(parameters.eos_token_id, max_initial_timestamp_index);
+ // Token ids are passed below in the order that they appear in the tokenizer
+ timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<T>>(parameters.eos_token_id,
+ parameters.decoder_start_token_id,
+ parameters.translate_token_id,
+ parameters.transcribe_token_id,
+ parameters.start_of_lm_token_id,
+ parameters.no_timestamps_token_id,
+ parameters.beginning_timestamp_token_id,
+ max_initial_timestamp_index);
processor_list_.push_back(timestamp_processor_.get());
}
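
Note: the tail of TimestampLogitsProcessor::Process (the timestamp_logprob block above) implements Whisper's rule that when the total probability mass on timestamp tokens exceeds the probability of the most likely text token, text tokens are suppressed so a timestamp must be emitted. A self-contained sketch of just that comparison, using the same max-shifted logsumexp (the names are mine, not from the patch):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // True when the summed probability of all timestamp tokens
    // (ids >= timestamp_begin) outweighs the best single text token.
    bool TimestampsOutweighText(const std::vector<float>& logprobs, int timestamp_begin) {
      const float max_ts = *std::max_element(logprobs.begin() + timestamp_begin, logprobs.end());
      float sumexp = 0.0f;
      for (int j = timestamp_begin; j < static_cast<int>(logprobs.size()); ++j) {
        sumexp += std::exp(logprobs[j] - max_ts);  // shift by the max for stability
      }
      const float timestamp_logprob = std::log(sumexp) + max_ts;
      const float max_text = *std::max_element(logprobs.begin(), logprobs.begin() + timestamp_begin);
      return timestamp_logprob > max_text;
    }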
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index 380d561bbb23c..b8f8d7691a9b6 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, //
const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper);
if (step == 1 && is_whisper_model && parameters->no_speech_probs) {
cuda::LaunchSaveNoSpeechProbs(
- (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream);
+ (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream);
}
// NOTE: currently we treat extra decoding ids are same
@@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits, //
cudaMemcpyDeviceToHost,
cuda_stream));
constexpr int max_initial_timestamp_index = 50;
- onnxruntime::contrib::transformers::TimestampLogitsProcessor<float> time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index);
+ // Token ids are passed below in the order that they appear in the tokenizer
+ onnxruntime::contrib::transformers::TimestampLogitsProcessor<float> time_logit_processor(parameters->eos_token_id,
+ parameters->decoder_start_token_id,
+ parameters->translate_token_id,
+ parameters->transcribe_token_id,
+ parameters->start_of_lm_token_id,
+ parameters->no_timestamps_token_id,
+ parameters->beginning_timestamp_token_id,
+ max_initial_timestamp_index);
onnxruntime::contrib::transformers::NextTokenScores<float> next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size});
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
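
Note: TimestampLogitsProcessor runs on the CPU, so the CUDA path above copies the next-token scores device-to-host asynchronously and then synchronizes the stream before Process reads them. A sketch of that ordering requirement (buffer names are illustrative, not ORT's):

    #include <cuda_runtime.h>
    #include <vector>

    // The async D2H copy is queued on the op's stream; host code must not
    // read host_scores until the stream has been drained.
    bool CopyScoresToHost(const float* device_scores,
                          std::vector<float>& host_scores, cudaStream_t stream) {
      cudaError_t err = cudaMemcpyAsync(host_scores.data(), device_scores,
                                        host_scores.size() * sizeof(float),
                                        cudaMemcpyDeviceToHost, stream);
      if (err != cudaSuccess) return false;
      return cudaStreamSynchronize(stream) == cudaSuccess;
    }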
diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h
index 61147e4367876..dc45cad692b6e 100644
--- a/onnxruntime/core/framework/execution_providers.h
+++ b/onnxruntime/core/framework/execution_providers.h
@@ -3,7 +3,6 @@
#pragma once
-// #include