
Commit

Minimal support for Mistral 7B model. (#1528)
* Minimal support for Mistral: model loader and rotary embedding length handling (sliding-window attention not implemented yet)
vince62s authored Nov 7, 2023
1 parent f92a8a2 commit 50e9ba4
Showing 6 changed files with 121 additions and 2 deletions.
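For context, here is a minimal end-to-end sketch of what this loader enables, assuming the usual CTranslate2 conversion and generation APIs; the output directory, device, and sampling settings below are illustrative:

    import ctranslate2
    import transformers

    # Convert the Hugging Face checkpoint with the new MistralLoader
    # (the same conversion the ct2-transformers-converter CLI performs).
    converter = ctranslate2.converters.TransformersConverter("mistralai/Mistral-7B-v0.1")
    converter.convert("mistral-7b-ct2", quantization="int8")

    # Generate from the converted model.
    generator = ctranslate2.Generator("mistral-7b-ct2", device="cuda")
    tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

    prompt = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, my name is"))
    results = generator.generate_batch([prompt], max_length=64, sampling_topk=10)
    print(tokenizer.decode(tokenizer.convert_tokens_to_ids(results[0].sequences[0])))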
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@ The project implements a custom runtime that applies many performance optimizati
The following model types are currently supported:

* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper
* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, CodeGen, GPTBigCode, Falcon
* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, CodeGen, GPTBigCode, Falcon
* Encoder-only models: BERT, DistilBERT, XLM-RoBERTa

Compatible models should be first converted into an optimized model format. The library includes converters for multiple frameworks:
1 change: 1 addition & 0 deletions include/ctranslate2/layers/attention.h
@@ -70,6 +70,7 @@ namespace ctranslate2 {
const dim_t _num_heads_kv;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
const dim_t _sliding_window;
};

enum class RotaryScalingType {
102 changes: 102 additions & 0 deletions python/ctranslate2/converters/transformers.py
@@ -1337,6 +1337,108 @@ def set_decoder(self, spec, module):
gc.collect()


@register_loader("MistralConfig")
class MistralLoader(ModelLoader):
@property
def architecture_name(self):
return "MistralForCausalLM"

def get_model_spec(self, model):
num_layers = model.config.num_hidden_layers

num_heads = model.config.num_attention_heads
num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
if num_heads_kv == num_heads:
num_heads_kv = None

sliding_window = getattr(model.config, "sliding_window", 0)

rope_scaling = getattr(model.config, "rope_scaling", None)
if rope_scaling:
rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
rotary_scaling_factor = rope_scaling["factor"]

if rotary_scaling_type is None:
raise NotImplementedError(
"RoPE scaling type '%s' is not yet implemented. "
"The following RoPE scaling types are currently supported: %s"
% (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
)
else:
rotary_scaling_type = None
rotary_scaling_factor = 1

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=common_spec.Activation.SWISH,
pre_norm=True,
ffn_glu=True,
rms_norm=True,
rotary_dim=0,
rotary_interleave=False,
rotary_scaling_type=rotary_scaling_type,
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=getattr(model.config, "rope_theta", 10000),
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
)

self.set_decoder(spec.decoder, model.model)
self.set_linear(spec.decoder.projection, model.lm_head)
return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append("<extra_id_%d>" % i)

return tokens

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = tokenizer.bos_token
config.eos_token = tokenizer.eos_token
config.unk_token = tokenizer.unk_token
config.layer_norm_epsilon = model.config.rms_norm_eps

def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight

def set_decoder(self, spec, module):
spec.scale_embeddings = False
self.set_embeddings(spec.embeddings, module.embed_tokens)
self.set_layer_norm(spec.layer_norm, module.norm)

for layer_spec, layer in zip(spec.layer, module.layers):
self.set_layer_norm(
layer_spec.self_attention.layer_norm, layer.input_layernorm
)
self.set_layer_norm(
layer_spec.ffn.layer_norm, layer.post_attention_layernorm
)

wq = layer.self_attn.q_proj.weight
wk = layer.self_attn.k_proj.weight
wv = layer.self_attn.v_proj.weight
wo = layer.self_attn.o_proj.weight

layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
layer_spec.self_attention.linear[1].weight = wo

self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

delattr(layer, "self_attn")
delattr(layer, "mlp")
gc.collect()


@register_loader("MixFormerSequentialConfig")
class MixFormerSequentialLoader(ModelLoader):
@property
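A note on the weight fusion in set_decoder above: Mistral uses grouped-query attention, so the key/value projections are smaller than the query projection and the concatenated linear[0] weight is rectangular. A rough sketch of the shapes using Mistral-7B's published dimensions (hidden size 4096, 32 query heads, 8 key/value heads, head dimension 128):

    import torch

    hidden_size, head_dim = 4096, 128
    num_heads, num_heads_kv = 32, 8

    wq = torch.empty(num_heads * head_dim, hidden_size)     # (4096, 4096)
    wk = torch.empty(num_heads_kv * head_dim, hidden_size)  # (1024, 4096)
    wv = torch.empty(num_heads_kv * head_dim, hidden_size)  # (1024, 4096)

    fused = torch.cat([wq, wk, wv])                         # (6144, 4096)
    # This fused matrix is what self_attention.linear[0] holds after conversion.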
4 changes: 4 additions & 0 deletions python/ctranslate2/specs/attention_spec.py
@@ -25,6 +25,7 @@ def __init__(
rotary_scaling_factor=1,
rotary_base=10000,
num_heads_kv=None,
sliding_window=None,
):
self.queries_scale = model_spec.OPTIONAL

@@ -54,3 +55,6 @@ def __init__(

if num_heads_kv is not None:
self.num_heads_kv = np.dtype("int32").type(num_heads_kv)

if sliding_window is not None:
self.sliding_window = np.dtype("int32").type(sliding_window)
10 changes: 10 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -100,6 +100,7 @@ def __init__(
shared_layer_norm: bool = False,
multi_query_attention: bool = False,
num_heads_kv: Optional[int] = None,
sliding_window: Optional[int] = None,
):
"""Initializes a Transformer decoder specification.
@@ -138,6 +139,7 @@ def __init__(
attention layer norms.
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
sliding_window: Max sequence length to retain in KV Cache.
"""
if parallel_residual:
if not pre_norm:
@@ -196,6 +198,7 @@ def __init__(
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
)
for _ in range(num_layers)
]
@@ -214,6 +217,7 @@ def __init__(
ffn_glu=False,
rms_norm=False,
num_heads_kv=None,
sliding_window=None,
):
self.self_attention = attention_spec.MultiHeadAttentionSpec(
self_attention=True,
@@ -241,6 +245,7 @@ def __init__(
parallel_residual=False,
shared_layer_norm=False,
num_heads_kv=None,
sliding_window=None,
):
self.self_attention = attention_spec.MultiHeadAttentionSpec(
self_attention=True,
@@ -253,12 +258,14 @@ def __init__(
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=rotary_base,
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
)

if with_encoder_attention:
self.attention = attention_spec.MultiHeadAttentionSpec(
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
)

self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
@@ -479,6 +486,7 @@ def from_config(
shared_layer_norm: bool = False,
multi_query_attention: bool = False,
num_heads_kv: Optional[int] = None,
sliding_window: Optional[int] = None,
):
"""Creates a Transformer decoder model specification.
@@ -511,6 +519,7 @@ def from_config(
attention layer norms.
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
sliding_window: Max sequence length to retain in the KV cache.
"""
decoder = TransformerDecoderSpec(
num_layers,
@@ -536,6 +545,7 @@ def __init__(
shared_layer_norm=shared_layer_norm,
multi_query_attention=multi_query_attention,
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
)

return cls(decoder)
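Regarding the new sliding_window argument documented above: it records Mistral's bound on how many past positions the KV cache should keep. Per the commit message, only the value is stored for now; the actual cache truncation is not implemented in this commit. A conceptual sketch of the intended semantics (the function name is illustrative):

    def retained_cache_positions(current_step, sliding_window):
        """Positions a sliding-window KV cache would still hold at a given step."""
        start = max(0, current_step - sliding_window)
        return list(range(start, current_step))

    retained_cache_positions(10, 4)  # [6, 7, 8, 9]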
4 changes: 3 additions & 1 deletion src/layers/attention.cc
@@ -1,4 +1,5 @@
#include "ctranslate2/layers/attention.h"
#include "ctranslate2/ops/split.h"

#include <algorithm>
#include <cmath>
@@ -400,6 +401,7 @@ namespace ctranslate2 {
&& !_relative_position_keys
&& !_relative_position_values)
, _cache_time_dim(_merge_time_and_head_dims ? 1 : 2)
, _sliding_window(model.get_attribute_with_default<int32_t>(scope + "/sliding_window", 0))
{
if (_relative_position_keys)
_maximum_relative_position = (_relative_position_keys->dim(0) - 1) / 2;
@@ -629,7 +631,7 @@ namespace ctranslate2 {

if (!_sin || offset + max_time > _sin.dim(0)) {
const dim_t cur_num_positions = _sin ? _sin.dim(0) : 0;
const dim_t new_num_positions = cur_num_positions + _num_initial_positions;
const dim_t new_num_positions = std::max(offset + max_time, cur_num_positions + _num_initial_positions);
initialize(new_num_positions, dim, device, dtype);
}

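The last hunk above is the "rotary length" part of the commit message: the sin/cos position cache previously grew only by a fixed increment, which could still fall short of the requested positions for long prompts; it now grows to at least offset + max_time. A small sketch of the new growth rule (names and the default increment are illustrative):

    def new_rotary_cache_length(offset, max_time, cur_num_positions, num_initial_positions=2048):
        # Cover every requested position, while still growing by at least the usual step.
        return max(offset + max_time, cur_num_positions + num_initial_positions)

    new_rotary_cache_length(offset=8000, max_time=192, cur_num_positions=6144)  # -> 8192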
