diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 145caf258..a445fe9c1 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -879,6 +879,8 @@ def get_model_spec(self, model):
         spec = whisper_spec.WhisperSpec(
             model.config.encoder_layers,
             model.config.encoder_attention_heads,
+            model.config.decoder_layers,
+            model.config.decoder_attention_heads,
         )
 
         self.set_encoder(spec.encoder, model.model.encoder)
diff --git a/python/ctranslate2/specs/whisper_spec.py b/python/ctranslate2/specs/whisper_spec.py
index fffd5c0f8..e32453e1c 100644
--- a/python/ctranslate2/specs/whisper_spec.py
+++ b/python/ctranslate2/specs/whisper_spec.py
@@ -26,17 +26,27 @@ def __init__(
 class WhisperSpec(model_spec.LanguageModelSpec):
     """Describes a Whisper model."""
 
-    def __init__(self, num_layers, num_heads):
+    def __init__(
+        self,
+        num_encoder_layers,
+        num_encoder_heads,
+        num_decoder_layers,
+        num_decoder_heads,
+    ):
         """Initializes the model specification.
 
         Args:
-          num_layers: The number of encoder and decoder layers.
-          num_heads: The number of attention heads.
+          num_encoder_layers: The number of encoder layers.
+          num_encoder_heads: The number of encoder attention heads.
+          num_decoder_layers: The number of decoder layers.
+          num_decoder_heads: The number of decoder attention heads.
         """
         super().__init__()
-        self.encoder = WhisperEncoderSpec(num_layers, num_heads)
+        self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)
         self.decoder = transformer_spec.TransformerDecoderSpec(
-            num_layers, num_heads, activation=common_spec.Activation.GELU
+            num_decoder_layers,
+            num_decoder_heads,
+            activation=common_spec.Activation.GELU,
         )
         self.decoder.scale_embeddings = False
 
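
For context, a minimal usage sketch of the widened constructor. The keyword
names match the new signature above; the numeric values are illustrative only
(roughly the shape of a distilled Whisper model, whose decoder is much
shallower than its encoder and which the old single num_layers argument could
not express):

    from ctranslate2.specs import whisper_spec

    # Illustrative values only: 32 encoder layers vs. 2 decoder layers,
    # 20 attention heads on both sides.
    spec = whisper_spec.WhisperSpec(
        num_encoder_layers=32,
        num_encoder_heads=20,
        num_decoder_layers=2,
        num_decoder_heads=20,
    )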