Skip to content

Commit

Permalink
Support conversion for distil-whisper (#1529)
Browse files Browse the repository at this point in the history
Previous code was assuming same number of encoder and decoder layer.
Removed this assumptions and obtain the number of layer separately.
  • Loading branch information
chiiyeh authored Nov 7, 2023
1 parent 50e9ba4 commit d0a9227
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
2 changes: 2 additions & 0 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,8 @@ def get_model_spec(self, model):
spec = whisper_spec.WhisperSpec(
model.config.encoder_layers,
model.config.encoder_attention_heads,
model.config.decoder_layers,
model.config.decoder_attention_heads,
)

self.set_encoder(spec.encoder, model.model.encoder)
Expand Down
20 changes: 15 additions & 5 deletions python/ctranslate2/specs/whisper_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,27 @@ def __init__(
class WhisperSpec(model_spec.LanguageModelSpec):
"""Describes a Whisper model."""

def __init__(self, num_layers, num_heads):
def __init__(
self,
num_encoder_layers,
num_encoder_heads,
num_decoder_layers,
num_decoder_heads,
):
"""Initializes the model specification.
Args:
num_layers: The number of encoder and decoder layers.
num_heads: The number of attention heads.
num_encoder_layers: The number of encoder layers.
num_encoder_heads: The number of encoder attention heads.
num_decoder_layers: The number of decoder layers.
num_decoder_heads: The number of decoder attention heads.
"""
super().__init__()
self.encoder = WhisperEncoderSpec(num_layers, num_heads)
self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)
self.decoder = transformer_spec.TransformerDecoderSpec(
num_layers, num_heads, activation=common_spec.Activation.GELU
num_decoder_layers,
num_decoder_heads,
activation=common_spec.Activation.GELU,
)
self.decoder.scale_embeddings = False

Expand Down

0 comments on commit d0a9227

Please sign in to comment.