Skip to content

Commit

Permalink
whisper-large-v3 compatibility (#1530)
Browse files Browse the repository at this point in the history
* expose n_mels
OpenAI Whisper large-v3 changes the number of mel input features from 80 to 128.
Exposing n_mels is required to propagate the input size to the audio feature extractor.
* fix guessing is_multilingual for large-v3
* alignment heads for large-v3
see #1530 (comment)
* add num_languages property to whisper models
* update comment documentation
---------

Co-authored-by: Valentin Berkes <[email protected]>
  • Loading branch information
funboarder13920 and Valentin Berkes authored Nov 8, 2023
1 parent d0a9227 commit 23f744f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 6 deletions.
12 changes: 12 additions & 0 deletions include/ctranslate2/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ namespace ctranslate2 {
return _is_multilingual;
}

// Returns the mel input feature dimension stored in _n_mels
// (taken from the encoder input size; 80 for earlier Whisper models,
// 128 for large-v3).
size_t n_mels() const {
return _n_mels;
}

// Returns the number of languages stored in _num_languages,
// a value derived from the vocabulary size.
size_t num_languages() const {
return _num_languages;
}

StorageView encode(StorageView features, const bool to_cpu);

std::vector<WhisperGenerationResult>
Expand Down Expand Up @@ -136,6 +144,8 @@ namespace ctranslate2 {
size_t _eot_id;
size_t _no_timestamps_id;
size_t _no_speech_id;
size_t _n_mels;
size_t _num_languages;
bool _is_multilingual;

StorageView maybe_encode(StorageView features);
Expand All @@ -146,6 +156,8 @@ namespace ctranslate2 {
using ReplicaPool::ReplicaPool;

bool is_multilingual() const;
size_t n_mels() const;
size_t num_languages() const;

std::future<StorageView> encode(const StorageView& features, const bool to_cpu);

Expand Down
22 changes: 18 additions & 4 deletions python/cpp/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ namespace ctranslate2 {
return _pool->is_multilingual();
}

// Forwards to the underlying pool and returns its mel input feature dimension.
// Exposed to Python as the read-only property "n_mels".
size_t n_mels() const {
return _pool->n_mels();
}

// Forwards to the underlying pool and returns its number of languages.
// Exposed to Python as the read-only property "num_languages".
size_t num_languages() const {
return _pool->num_languages();
}

// Submits the features to the pool's encode() and calls .get() on the result
// to obtain the encoded StorageView.
// NOTE(review): the pool's encode() appears to return a std::future, so this
// call presumably blocks until the encoder output is ready - confirm.
//   features: input passed through unchanged to the pool.
//   to_cpu:   passed through unchanged to the pool.
StorageView encode(const StorageView& features, const bool to_cpu) {
return _pool->encode(features, to_cpu).get();
}
Expand Down Expand Up @@ -149,6 +157,12 @@ namespace ctranslate2 {
.def_property_readonly("is_multilingual", &WhisperWrapper::is_multilingual,
"Returns ``True`` if this model is multilingual.")

.def_property_readonly("n_mels", &WhisperWrapper::n_mels,
"Returns dimension of mel input features.")

.def_property_readonly("num_languages", &WhisperWrapper::num_languages,
"Returns the number of languages supported.")

.def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
Expand Down Expand Up @@ -201,7 +215,7 @@ namespace ctranslate2 {
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, chunk_length]``.
``[batch_size, n_mels, chunk_length]``.
to_cpu: Copy the encoder output to the CPU before returning the value.
Returns:
Expand Down Expand Up @@ -233,7 +247,7 @@ namespace ctranslate2 {
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, chunk_length]``. This method also accepts the encoded
``[batch_size, n_mels, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
prompts: Batch of initial string tokens or token IDs.
Expand Down Expand Up @@ -271,7 +285,7 @@ namespace ctranslate2 {
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, chunk_length]``. This method also accepts the encoded
``[batch_size, n_mels, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
Expand All @@ -296,7 +310,7 @@ namespace ctranslate2 {
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, chunk_length]``. This method also accepts the encoded
``[batch_size, n_mels, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
start_sequence: The start sequence tokens.
Expand Down
14 changes: 13 additions & 1 deletion python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,7 +1925,7 @@ def main():

# Cross-attention heads that are highly correlated to the word-level timing,
# i.e. the alignment between audio and text tokens.
# Obtained from https://github.com/openai/whisper/blob/v20230306/whisper/__init__.py#L31-L45
# Obtained from https://github.com/openai/whisper/blob/v20231106/whisper/__init__.py#L32-L47
_WHISPER_ALIGNMENT_HEADS = {
"openai/whisper-tiny.en": [
(1, 0),
Expand Down Expand Up @@ -2039,4 +2039,16 @@ def main():
(26, 12),
(27, 15),
],
"openai/whisper-large-v3": [
(7, 0),
(10, 17),
(12, 18),
(13, 12),
(16, 1),
(17, 14),
(19, 11),
(21, 4),
(24, 1),
(25, 6),
],
}
14 changes: 13 additions & 1 deletion src/models/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ namespace ctranslate2 {
_no_speech_id = vocabulary.to_id("<|nospeech|>");
if (_no_speech_id == vocabulary.unk_id())
_no_speech_id = vocabulary.to_id("<|nocaptions|>");
_is_multilingual = vocabulary.size() == 51865;
_is_multilingual = vocabulary.size() >= 51865;
_n_mels = _encoder->input_size();
_num_languages = vocabulary.size() - 51765 - (_is_multilingual ? 1 : 0);
}

StorageView WhisperReplica::encode(StorageView features, const bool to_cpu) {
Expand Down Expand Up @@ -640,6 +642,16 @@ namespace ctranslate2 {
return replica.is_multilingual();
}

// Reports the mel input feature dimension by asking the replica returned
// by get_first_replica().
size_t Whisper::n_mels() const {
  return get_first_replica().n_mels();
}

// Reports the number of languages by asking the replica returned
// by get_first_replica().
size_t Whisper::num_languages() const {
  return get_first_replica().num_languages();
}

std::future<StorageView> Whisper::encode(const StorageView& features, const bool to_cpu) {
return post<StorageView>(
[features = features.sync_copy(), to_cpu](WhisperReplica& replica) mutable {
Expand Down

0 comments on commit 23f744f

Please sign in to comment.