diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f59ee87fd..38fe4bb40 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,4 +38,4 @@ repos: hooks: - id: mypy args: [--ignore-missing-imports] - additional_dependencies: [wandb] + additional_dependencies: [wandb, types-PyYAML] diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index a598ef78e..c6ebdf4e7 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -20,7 +20,7 @@ import safetensors import safetensors.numpy from huggingface_hub import hf_hub_download, snapshot_download -from huggingface_hub.utils import EntryNotFoundError, HFValidationError +from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError from jax.experimental.multihost_utils import sync_global_devices from jax.random import PRNGKey from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer @@ -302,8 +302,13 @@ def default_config(self) -> LevConfig: return self.config_from_hf_config(self.default_hf_config) def HFAutoModelClass(self, auto_class: Type[AutoModel] = AutoModelForCausalLM) -> Type[AutoModel]: - # figure out the - config = self.hf_config_from_hf_checkpoint() + # first, see if it's a built-in model + try: + return auto_class._model_mapping[self.HfConfigClass] + except KeyError: + pass + + config = self.default_hf_config cls_name = auto_class.__name__ if hasattr(config, "auto_map") and cls_name in config.auto_map: class_ref = config.auto_map[cls_name] @@ -434,6 +439,7 @@ def load_pretrained( lm_model_cls: Union[Type[LmWithHfSerializationMixin], LevConfig], ref: Optional[Union[str, RepoRef]] = None, axis_mapping: Optional[ResourceMapping] = None, + resize_vocab_to_match_tokenizer: bool = True, ) -> LmWithHfSerializationMixin: """ Loads a levanter model from a huggingface checkpoint. @@ -472,10 +478,15 @@ def load_pretrained( # Vocab: next, we resize the desired actual size if Vocab.size != tokenizer_Vocab.size: - logger.info( - f"Resizing model from {Vocab.size} to {tokenizer_Vocab.size} to match tokenizer vocab size" - ) - lev_model = lev_model.resize_vocab(tokenizer_Vocab.size) + if resize_vocab_to_match_tokenizer: + logger.info( + f"Resizing model from {Vocab.size} to {tokenizer_Vocab.size} to match tokenizer vocab size" + ) + lev_model = lev_model.resize_vocab(tokenizer_Vocab.size) + else: + logger.warning( + f"Model vocab size ({Vocab.size}) does not match tokenizer vocab size ({tokenizer_Vocab.size})" + ) if axis_mapping is not None: lev_model = haliax.shard_with_axis_mapping(lev_model, axis_mapping) @@ -520,10 +531,26 @@ def _save_pretrained_local( dict_config = config.to_dict() # copy over the default keys - for k in KEYS_TO_COPY_FROM_BASE_CONFIG: - attr = getattr(self.default_hf_config, k, None) - if attr is not None: - dict_config[k] = attr + try: + for k in KEYS_TO_COPY_FROM_BASE_CONFIG: + attr = getattr(self.default_hf_config, k, None) + if attr is not None: + dict_config[k] = attr + # except GatedRepoError: + # warnings.warn("Could not copy keys from base config because the repo is gated. Making assumptions.") + except Exception as e: # noqa + if isinstance(e, GatedRepoError) or isinstance(e.__cause__, GatedRepoError): + warnings.warn("Could not copy keys from base config because the repo is gated. 
Making assumptions.") + + # this is probably llama, but in general we just need to set the auto_map and architectures + dict_config["auto_map"] = { + "AutoModelForCausalLM": self.HFAutoModelClass(AutoModelForCausalLM).__qualname__, + "AutoConfig": self.HfConfigClass.__qualname__, + } + + dict_config["architectures"] = [self.HFAutoModelClass(AutoModelForCausalLM).__name__] + else: + raise if self.config_overrides: dict_config.update(self.config_overrides) @@ -602,6 +629,8 @@ def _save_code_local(self, path): attributes_path = hf_hub_download(repo_id=repo, filename=".gitattributes", revision=revision) except EntryNotFoundError: attributes_path = None + except GatedRepoError: + attributes_path = None if attributes_path is None: warnings.warn("HF Export - No .gitattributes file found, using a heuristic to decide what to save") diff --git a/src/levanter/compat/torch_serialization.py b/src/levanter/compat/torch_serialization.py index b3b30f9df..411f519c4 100644 --- a/src/levanter/compat/torch_serialization.py +++ b/src/levanter/compat/torch_serialization.py @@ -168,6 +168,10 @@ def flatten_linear_layers(prefix: Optional[str], tree: PyTree, out_dims_first_in linear layers can have arbitrary dimensions, grouped into input and output axes. This function flattens the linear layers in a state dict into a 2d weight matrix and a 1d bias vector. + **You should use out_dims_first_in_dict=True if you're using this to convert a PyTorch model to Haliax and the + PyTorch model uses Linear. If the PyTorch model uses Conv1d, use False.** None is probably not what you want, + except in very specific cases. + :param prefix: prefix to apply to the keys in the state dict :param tree: :param out_dims_first_in_dict: if True, the output dimensions will be the first axis in the flattened weight matrix. @@ -216,6 +220,10 @@ def unflatten_linear_layers( linear layers can have arbitrary dimensions, grouped into input and output axes. This function unflattens the linear layers in a state dict into a 2d weight matrix and a 1d bias vector. + **You should use out_dims_first_in_dict=True if you're using this to convert a PyTorch model to Haliax and the + PyTorch model uses Linear. If the PyTorch model uses Conv1d, use False.** None is probably not what you want, + except in very specific cases. 
+ :param prefix: prefix to apply to the keys in the state dict :param statedict: the state dict to source the flattened weights from :param layer: the exemplar layer to use for unflattening diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py new file mode 100644 index 000000000..bd3638ccd --- /dev/null +++ b/src/levanter/models/llama.py @@ -0,0 +1,543 @@ +import dataclasses +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jrandom +from jaxtyping import PRNGKeyArray +from transformers import LlamaConfig as HfLlamaConfig +from transformers import PretrainedConfig as HfConfig + +import haliax as hax +import haliax.nn as hnn +from haliax import Axis, AxisSpec, NamedArray +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split +from haliax.nn.scan import Stacked + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + stack_state_dict, + unflatten_linear_layers, + unstack_state_dict, +) +from levanter.models.gpt2 import ACT2FN +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.utils.py_utils import cached_classproperty + + +@LmConfig.register_subclass("llama") +@dataclass(frozen=True) +class LlamaConfig(HFCompatConfig): + """Config for LlamaModel + + Args: + seq_len (int, optional): maximum length of the input sequence. Defaults to 2048. + hidden_dim (int, optional): dimension of the hidden state. Defaults to 4096. + intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008. + num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32. + num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32. + activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". + rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. 
+ """ + + seq_len: int = 2048 + hidden_dim: int = 4096 + intermediate_dim: int = 11008 + num_layers: int = 32 + num_heads: int = 32 + activation_function: str = "silu" + initializer_range: float = 0.02 + layer_norm_epsilon: float = 1e-5 + + gradient_checkpointing: bool = True + gradient_checkpointing_block_size: int = 5 + + use_bias: bool = False + rope_scaling: Optional[dict] = None + + # Axis + Pos = property(lambda self: Axis(name="position", size=self.seq_len)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) + Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + + @cached_classproperty + def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore + return HFCheckpointConverter( + cls, # type: ignore + "meta-llama/Llama-2-7b-hf", + trust_remote_code=True, + tokenizer="hf-internal-testing/llama-tokenizer", + HfConfigClass=HfLlamaConfig, + ) + + @classmethod + def from_hf_config(cls, hf_config: HfConfig): + return LlamaConfig( + seq_len=hf_config.max_position_embeddings, + hidden_dim=hf_config.hidden_size, + intermediate_dim=hf_config.intermediate_size, + num_layers=hf_config.num_hidden_layers, + num_heads=hf_config.num_attention_heads, + activation_function=hf_config.hidden_act, + initializer_range=hf_config.initializer_range, + layer_norm_epsilon=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfLlamaConfig: + """Convert to HuggingFace's LlamaConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. + + Returns: + HfLlamaConfig: HuggingFace's LlamaConfig + """ + if config_overrides is None: + config_overrides = {} + + return HfLlamaConfig( + max_position_embeddings=self.seq_len, + hidden_size=self.hidden_dim, + intermediate_size=self.intermediate_dim, + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + hidden_act=self.activation_function, + initializer_range=self.initializer_range, + rms_norm_eps=self.layer_norm_epsilon, + rope_scaling=self.rope_scaling, + vocab_size=vocab_size, + **config_overrides, + ) + + @property + def model_type(cls) -> Type["LlamaLMHeadModel"]: + return LlamaLMHeadModel + + +class LlamaMlp(eqx.Module, StateDictSerializationMixin): + """Multi-layer Perceptron + In comparison with GPT2, LlamaMlp adds an up-proj that multiplies with activated gate_proj, + before down-proj. 
+ """ + + gate_proj: hnn.Linear # projection from Embed to Mlp + up_proj: hnn.Linear # projection from Embed to Mlp + down_proj: hnn.Linear # projection from Mlp to Embed + act: Callable = eqx.static_field() + + @staticmethod + def init( + Embed: Axis, Mlp: Axis, activation_fn: Union[str, Callable], *, key, use_bias: bool = False + ) -> "LlamaMlp": + k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3) + gate_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias) + up_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_up_proj, use_bias=use_bias) + down_proj = hnn.Linear.init(Out=Embed, In=Mlp, key=k_down_proj, use_bias=use_bias) + if isinstance(activation_fn, str): + activation_fn = ACT2FN[activation_fn] + act = activation_fn # type: ignore + return LlamaMlp(gate_proj, up_proj, down_proj, act) + + @named_call + def __call__(self, x: NamedArray) -> NamedArray: + hidden_states = self.gate_proj(x) + hidden_states = self.act(hidden_states) + hidden_states = hidden_states * self.up_proj(x) + outputs = self.down_proj(hidden_states) + return outputs + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp + d = {} + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "gate_proj"), state_dict, self.gate_proj, out_dims_first_in_dict=True + ) + ) + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "up_proj"), state_dict, self.up_proj, out_dims_first_in_dict=True + ) + ) + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "down_proj"), state_dict, self.down_proj, out_dims_first_in_dict=True + ) + ) + + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix=prefix) + + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "gate_proj"), self.gate_proj, out_dims_first_in_dict=True) + ) + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "up_proj"), self.up_proj, out_dims_first_in_dict=True) + ) + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "down_proj"), self.down_proj, out_dims_first_in_dict=True) + ) + + state_dict.update(my_dict) + return state_dict + + +class LlamaRotaryEmbedding(eqx.Module, StateDictSerializationMixin): + Pos: Axis = eqx.field(static=True) + cos_cached: NamedArray = eqx.field(static=True) + sin_cached: NamedArray = eqx.field(static=True) + + def __init__(self, HeadSize: Axis, Pos: Axis, base: int = 10000): + self.Pos = Pos + # this must be compile-time b/c we want to store them in a static field + with jax.ensure_compile_time_eval(): + self.cos_cached, self.sin_cached = self._get_cos_sin_cache(Pos=Pos, HeadSize=HeadSize, base=base) + + @staticmethod + def _get_cos_sin_cache(HeadSize: hax.Axis, Pos: hax.Axis, base: float) -> Tuple[NamedArray, NamedArray]: + HeadHalfSize = HeadSize.resize(HeadSize.size // 2) + inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) + + position_ids: NamedArray = hax.arange(Pos) + + freqs = position_ids * inv_freq.broadcast_axis(Pos) + # This is different from the paper but aligns with HF implementation: + # It uses a different permutation in order to obtain the same calculation + emb = hax.concatenate(HeadSize, (freqs, freqs)) + cos_cached = hax.cos(emb) + sin_cached = hax.sin(emb) + # This is different from the paper but aligns with HF implementation: + return cos_cached, sin_cached + 
+ def __call__(self, seq_len: int) -> Tuple[NamedArray, NamedArray]: + return ( + self.cos_cached[self.Pos, :seq_len], + self.sin_cached[self.Pos, :seq_len], + ) + + +class LlamaAttention(StateDictSerializationMixin, eqx.Module): + config: LlamaConfig = eqx.static_field() + q_proj: hnn.Linear # projection from Embed to query + k_proj: hnn.Linear # projection from Embed to key + v_proj: hnn.Linear # projection from Embed to value + o_proj: hnn.Linear # projection from Heads to output + rotary_emb: LlamaRotaryEmbedding # rotary embedding + + @staticmethod + def init(config: LlamaConfig, *, key) -> "LlamaAttention": + use_bias = config.use_bias + Embed = config.Embed + k_q, k_k, k_v, k_o = jrandom.split(key, 4) + q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias) + k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias) + v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias) + o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias) + rotary_emb = LlamaRotaryEmbedding(config.HeadSize, config.Pos) + return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) + + @named_call + def __call__(self, x: NamedArray, mask: Optional[NamedArray]): + q = self.q_proj(x) # TODO: rearrange and possibly rename + k = self.k_proj(x) + v = self.v_proj(x) + + cos, sin = self.rotary_emb(seq_len=self.config.seq_len) + + q, k = _apply_rotary_pos_emb(q, k, cos, sin) + + q = q.astype(jnp.float32) + k = k.astype(jnp.float32).rename({"position": "key_position"}) + v = v.rename({"position": "key_position"}) + + c = self.config + attn_output = hnn.attention.dot_product_attention(c.Pos, c.KeyPos, c.HeadSize, q, k, v, mask) + attn_output = attn_output.astype(x.dtype) + + attn_output = self.o_proj(attn_output) + return attn_output + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + # unflatten the linear layers of HF state_dict to match the shape of LlamaAttention + d = {} + d.update(unflatten_linear_layers(apply_prefix(prefix, "q_proj"), state_dict, self.q_proj, True)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, True)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, True)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "o_proj"), state_dict, self.o_proj, True)) + + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + # flatten the linear layers of LlamaAttention to match the shape of HF state_dict + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix) + + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "q_proj"), self.q_proj, True)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "k_proj"), self.k_proj, True)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "v_proj"), self.v_proj, True)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "o_proj"), self.o_proj, True)) + + state_dict.update(my_dict) + return state_dict + + +class LlamaRMSNorm(hnn.LayerNorm): + """It is a modified version of LayerNorm. + The main changes are: + 1. The variance is defined as the average of square, versus the original + definition as the average of the squared deviations from the mean. + 2. The output is defined as x * inv, without minusing the mean. + 3. 
The default value of eps is set to 1e-6 and use_bias to False. + """ + + @staticmethod + def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False): + if use_weight: + weight = hax.ones(axis) + else: + weight = None + if use_bias: + bias = hax.zeros(axis) + else: + bias = None + + return LlamaRMSNorm(axis, weight, bias, eps) + + def __call__(self, x: NamedArray) -> NamedArray: + # This gives a different result than jnp.var(), which is + # defined as the average of the squared deviations from the mean + var = hax.mean(hax.square(x), axis=self.axis) + inv = hax.rsqrt(var + self.eps) + out = x * inv + + if self.weight is not None: + out = self.weight * out + if self.bias is not None: + out = out + self.bias + return out + + +class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): + config: LlamaConfig = eqx.static_field() + self_attn: LlamaAttention + mlp: LlamaMlp + input_layernorm: LlamaRMSNorm + post_attention_layernorm: LlamaRMSNorm + + @staticmethod + def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": + k_attn, k_mlp = jrandom.split(key, 2) + + attn = LlamaAttention.init(config, key=k_attn) + mlp = LlamaMlp.init( + config.Embed, + config.Mlp, + config.activation_function, + key=k_mlp, + use_bias=config.use_bias, + ) + ln_1 = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + ln_2 = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + + return LlamaDecoderLayer(config, attn, mlp, ln_1, ln_2) + + @named_call + def __call__(self, x: NamedArray, mask: Optional[NamedArray]): + # self attention and skip connection + residual = x + x = self.input_layernorm(x) + attn_output = self.self_attn(x=x, mask=mask) + x = residual + attn_output + + # MLP and skip connection + residual = x + x = self.post_attention_layernorm(x) + mlp_output = self.mlp(x) + output = residual + mlp_output + return output + + +class LlamaTransformer(StateDictSerializationMixin, eqx.Module): + config: LlamaConfig = eqx.static_field() + layers: Stacked[LlamaDecoderLayer] + norm: LlamaRMSNorm + + @staticmethod + def init(config: LlamaConfig, *, key) -> "LlamaTransformer": + layers = Stacked.init(config.Layers, LlamaDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)( + config, + key=shaped_rng_split(key, config.num_layers), + ) + ln_f = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + + return LlamaTransformer(config, layers, ln_f) + + @named_call + def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray]) -> NamedArray: + x = self.layers.fold(x, mask=attn_mask) + x = self.norm(x) + + return x + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + stacked = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layers")) + out = super().from_state_dict(stacked, prefix=prefix) + return out + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_state_dict: StateDict = {} + super().update_state_dict(my_state_dict, prefix=prefix) + + stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layers")) + state_dict.update(stacked_dict) + + return state_dict + + +class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): + """Similar to GPT2 Embedding, except that: + - Llama doesn't have position embedding in the Embedding layer. + - Llama doesn't use dropout. 
+ """ + + Vocab: Axis = eqx.static_field() + config: LlamaConfig = eqx.static_field() + token_embeddings: NamedArray + + @staticmethod + def init(Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaEmbedding": + k_wte = jrandom.split(key, 1) + + token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) + return LlamaEmbedding(Vocab, config, token_embeddings) + + @named_call + def embed(self, input_ids, *args): + input_embeds = self.token_embeddings.take("vocab", input_ids) + x = input_embeds + return x + + def unembed(self, x: NamedArray): + return hax.dot("embed", x, self.token_embeddings) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"token_embeddings": "model.embed_tokens.weight"} + + def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): + new_weights = hax.tree_util.resize_axis(self.token_embeddings, self.Vocab, new_size, key=key) + return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) + + +class LlamaLMHeadModel(eqx.Module, LmHeadModel[LlamaConfig], StateDictSerializationMixin): + transformer: LlamaTransformer + embeddings: LlamaEmbedding + lm_head: hnn.Linear + + @property + def config(self): + return self.transformer.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + @classmethod + def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel": + k_t, k_emb = jrandom.split(key, 2) + transformer = LlamaTransformer.init(config, key=k_t) + embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb) + lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True) + return LlamaLMHeadModel(transformer, embeddings, lm_head) + + def __call__( + self, + input_ids: NamedArray, + attn_mask: Optional[NamedArray] = None, + *, + inference: bool = False, + key=None, + ) -> NamedArray: + """ + Args: + input_ids (NamedArray): [batch, position] + Indices of input sequence tokens in the vocabulary. + attn_mask (NamedArray, optional): [batch, position, seq_len] + Mask to avoid performing attention on the padding token indices of the encoder input. 
+ """ + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask) + lm_logits = self.lm_head(x) + return lm_logits + + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]": + new_Vocab = self.Vocab.resize(new_size) + k1, k2 = maybe_rng_split(key, 2) + new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1) + new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2) + new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix) + + return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"transformer": "model", "embeddings": None} + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp + d = state_dict.copy() + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True + ) + ) + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix=prefix) + + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True) + ) + + state_dict.update(my_dict) + return state_dict + + +def _rotate_half(x: NamedArray) -> NamedArray: + """Rotates half of the hidden dims of the input and concatenates them.""" + HeadSize = x.axes[-1] + x1 = x[HeadSize, : HeadSize.size // 2] + x2 = x[HeadSize, HeadSize.size // 2 :] + out = hax.concatenate(HeadSize, (-x2, x1)) + return out + + +def _apply_rotary_pos_emb( + q: NamedArray, # [batch, position, heads, head_size] + k: NamedArray, # [batch, position, kv_heads, head_size] + cos: NamedArray, # [position, head_size] + sin: NamedArray, # [position, head_size] +) -> Tuple[NamedArray, NamedArray]: + """Applies rotary position embedding to q and k.""" + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 42dc81ba6..1691768a1 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -60,9 +60,8 @@ def Vocab(self) -> Axis: pass @property - @abc.abstractmethod def Pos(self) -> Axis: - pass + return self.config.Pos @classmethod @abc.abstractmethod diff --git a/tests/test_llama.py b/tests/test_llama.py new file mode 100644 index 000000000..5106ab0f9 --- /dev/null +++ b/tests/test_llama.py @@ -0,0 +1,300 @@ +import tempfile + +import jax +import numpy as np +import transformers +from jax import random + +import haliax as hax + +from levanter.models.llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaLMHeadModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from levanter.models.llama import _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb +from levanter.models.llama import _rotate_half as levanter_rotate_half +from test_utils import skip_if_no_torch + + +@skip_if_no_torch +def test_llama_config(): + # load HF config and convert to levanter config + hf_config = transformers.LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf") + llama_config = LlamaConfig.from_hf_config(hf_config) + + # convert back to HF config + config_overrides = { + "_name_or_path": hf_config._name_or_path, + 
"architectures": hf_config.architectures, + "torch_dtype": hf_config.torch_dtype, + } + new_hf_config = llama_config.to_hf_config( + vocab_size=hf_config.vocab_size, + config_overrides=config_overrides, + ) + + # assert the content in new_hf_config is the same as hf_config + for k in new_hf_config.__dict__.keys(): + if k in ["_commit_hash", "transformers_version"]: + continue + assert getattr(new_hf_config, k) == getattr( + hf_config, k + ), f"{k} {getattr(new_hf_config, k)} != {getattr(hf_config, k)}" + + +@skip_if_no_torch +def test_llama_rotary_embedding(): + import torch + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding + + llama_config = _get_llama_config() + HeadSize = llama_config.HeadSize + Pos = llama_config.Pos + hidden_dim = HeadSize.size + seq_len = Pos.size + key = random.PRNGKey(0) + device = "cpu" + + x = random.normal(key, (1, seq_len)) + x_torch = torch.from_numpy(np.array(x)) + + levanter_rope = LlamaRotaryEmbedding(HeadSize=HeadSize, Pos=Pos) + levanter_output = levanter_rope(seq_len=seq_len) + hf_rope = HFLlamaRotaryEmbedding(dim=hidden_dim, max_position_embeddings=seq_len, device=device) + hf_output = hf_rope(x_torch, seq_len=seq_len) + + for jax_out, torch_out in zip(levanter_output, hf_output): + torch_out = torch_out.numpy() + assert np.isclose(torch_out, np.array(jax_out.array), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + + +@skip_if_no_torch +def test_apply_rotary_pos_emb(): + import torch + from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as hf_apply_rotary_pos_emb + from transformers.models.llama.modeling_llama import rotate_half as hf_rotate_half + + def assert_equal_out(hax_out, torch_out: torch.Tensor): + assert np.isclose( + torch_out.numpy(), np.array(hax_out.array), rtol=1e-2, atol=1e-2 + ).all(), f"{torch_out} != {hax_out}" + + def named_array_to_tensor(named_array): + return torch.from_numpy(np.array(named_array.array)) + + llama_config = _get_llama_config() + + Pos = llama_config.Pos + Heads = llama_config.Heads + HeadSize = llama_config.HeadSize + Batch = hax.Axis("batch", 2) + + # note here we switch Heads and Pos for the shape of the output tensors + q = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Heads, HeadSize)) + k = hax.random.normal(random.PRNGKey(1), (Batch, Pos, Heads, HeadSize)) + + # Check the output of _rotate_half() from levanter and hf + levanter_out_rf_q = levanter_rotate_half(q) + levanter_out_rf_k = levanter_rotate_half(k) + + q_tensor = named_array_to_tensor(q).transpose(1, 2) # needed for HF + k_tensor = named_array_to_tensor(k).transpose(1, 2) + hf_out_rf_q = hf_rotate_half(q_tensor).transpose(1, 2) # re-transpose to match levanter + hf_out_rf_k = hf_rotate_half(k_tensor).transpose(1, 2) + + assert_equal_out(levanter_out_rf_q, hf_out_rf_q) + assert_equal_out(levanter_out_rf_k, hf_out_rf_k) + + # Check the output of _apply_rotary_pos_emb() from levanter and hf + cos = hax.random.normal(random.PRNGKey(2), (Pos, HeadSize)) + sin = hax.random.normal(random.PRNGKey(3), (Pos, HeadSize)) + + levanter_out_rope_q, levanter_out_rope_k = levanter_apply_rotary_pos_emb(q, k, cos, sin) + cos_tensor = named_array_to_tensor(cos) + sin_tensor = named_array_to_tensor(sin) + position_ids = hax.arange(Pos).broadcast_axis(Batch) + position_ids_tensor = named_array_to_tensor(position_ids) + + hf_out_rope_q, hf_out_rope_k = hf_apply_rotary_pos_emb( + q_tensor, k_tensor, cos_tensor, sin_tensor, position_ids_tensor + ) + hf_out_rope_q = 
hf_out_rope_q.transpose(1, 2) # re-transpose to match levanter + hf_out_rope_k = hf_out_rope_k.transpose(1, 2) + assert_equal_out(levanter_out_rope_q, hf_out_rope_q) + assert_equal_out(levanter_out_rope_k, hf_out_rope_k) + + +@skip_if_no_torch +def test_llama_attention(): + import torch + from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention + + config = _get_llama_config() + + attention = LlamaAttention.init(config=config, key=random.PRNGKey(0)) + + state = attention.to_state_dict() + state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} + hf_attention = HFLlamaAttention(config.to_hf_config(32000)) + hf_attention.load_state_dict(state, strict=True) + + x, mask = _get_random_inputs(config) + x_torch = torch.from_numpy(np.array(x.array)) + batch_size = x_torch.shape[0] + mask_torch = torch.from_numpy(np.array(mask.array)).broadcast_to((batch_size, 1, -1, -1)) + + # the torch mask is really a bias, so we need to invert it and make it a big negative number + mask_torch = (mask_torch == 0).float() * -1e9 + + out = attention(x, mask) + hf_out = hf_attention(x_torch, mask_torch) + + assert np.isclose( + hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-4, atol=1e-4 + ).all(), f"{hf_out[0]} != {out}" + + +@skip_if_no_torch +def test_llama_rms_norm(): + import torch + from transformers.models.llama.modeling_llama import LlamaRMSNorm as HFLlamaRMSNorm + + config = _get_llama_config() + ln = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + hf_ln = HFLlamaRMSNorm(config.Embed.size, eps=config.layer_norm_epsilon) + + x, _ = _get_random_inputs(config) + x_torch = torch.from_numpy(np.array(x.array)) + + out = ln(x) + hf_out = hf_ln(x_torch) + + assert np.isclose( + hf_out.detach().cpu().numpy(), np.array(out.array), rtol=1e-6, atol=1e-6 + ).all(), f"{hf_out} != {out}" + + +@skip_if_no_torch +def test_llama_decoder_layer(): + import torch + from transformers.models.llama.modeling_llama import LlamaDecoderLayer as HFLlamaDecoderLayer + + llama_config = _get_llama_config() + key = random.PRNGKey(0) + llama_decoder_layer = LlamaDecoderLayer.init(config=llama_config, key=key) + + state = llama_decoder_layer.to_state_dict() + state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} + hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000)) + hf_decoder_layer.load_state_dict(state, strict=True) + + x, mask = _get_random_inputs(llama_config) + x_torch = torch.from_numpy(np.array(x.array)) + batch_size = x_torch.shape[0] + mask_torch = torch.from_numpy(np.array(mask.array)).broadcast_to((batch_size, 1, -1, -1)) + mask_torch = (mask_torch == 0).float() * -1e9 + + out = llama_decoder_layer(x, mask) + hf_out = hf_decoder_layer(x_torch, mask_torch) + + assert np.isclose( + hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-4, atol=1e-4 + ).all(), f"{hf_out[0]} != {out}" + + +def test_llama_lm_head_model(): + llama_config = _get_llama_config() + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 1000) + Pos = llama_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) + mask = hax.nn.attention.causal_mask(Pos, llama_config.KeyPos) + + llama_model = LlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) + out = llama_model(input_ids, mask) + assert out.array.shape == (Batch.size, Pos.size, Vocab.size) + + +@skip_if_no_torch +def test_llama_roundtrip(): + import torch + from transformers import 
AutoModelForCausalLM, LlamaForCausalLM + + converter = LlamaConfig.default_hf_checkpoint_converter + + config = LlamaConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + gradient_checkpointing=False, + ) + Vocab = hax.Axis("vocab", 1000) + hf_config = config.to_hf_config(Vocab.size) + + # Make input and attn_mask + input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + attn_mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) + input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) + + torch.random.manual_seed(0) + + torch_model = LlamaForCausalLM(hf_config) + torch_model.eval() + + torch_out = torch_model(input_torch) + torch_out = torch_out.logits[0].detach().cpu().numpy() + torch_out = jax.nn.softmax(torch_out, axis=-1) + + with tempfile.TemporaryDirectory() as tmpdir: + torch_model.save_pretrained(f"{tmpdir}/torch_model") + + model = converter.load_pretrained( + LlamaLMHeadModel, f"{tmpdir}/torch_model", resize_vocab_to_match_tokenizer=False + ) + + def compute(input): + model_output = model(input, attn_mask=attn_mask) + return hax.nn.softmax(model_output, axis=model.Vocab) + + compute = jax.jit(compute) + jax_out = compute(input).array + + assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + + converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) + torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model") + torch_model2.eval() + + torch_out2 = torch_model2(input_torch) + torch_out2 = torch_out2.logits[0].detach().cpu().numpy() + torch_out2 = jax.nn.softmax(torch_out2, axis=-1) + assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" + assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" + + +def _get_llama_config() -> LlamaConfig: + rope_scaling = { + "type": "linear", + "factor": 2.0, + } + return LlamaConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + rope_scaling=rope_scaling, + gradient_checkpointing=False, # disable for tests so debugging is easier + ) + + +def _get_random_inputs(config: LlamaConfig): + Embed = config.Embed + Pos = config.Pos + Batch = hax.Axis("batch", 2) + x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed)) + mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) + return x, mask
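
For reference, a minimal usage sketch of the converter API exercised by the tests above. The checkpoint paths are placeholders (the gated meta-llama repo requires HF authentication, which is why test_llama_roundtrip round-trips through a locally saved LlamaForCausalLM instead of the hub checkpoint); this is an illustrative sketch, not part of the diff.

# Illustrative sketch only; "path/to/llama_hf_checkpoint" is a placeholder for a local
# HF-format Llama checkpoint (e.g. one produced by LlamaForCausalLM.save_pretrained).
import jax.random as jrandom
import haliax as hax
from levanter.models.llama import LlamaConfig, LlamaLMHeadModel

converter = LlamaConfig.default_hf_checkpoint_converter

# Load HF weights into the Levanter Llama implementation. Passing
# resize_vocab_to_match_tokenizer=False (the flag added in this PR) keeps the
# checkpoint's vocab size instead of resizing the embeddings to the tokenizer's vocab.
model = converter.load_pretrained(
    LlamaLMHeadModel, "path/to/llama_hf_checkpoint", resize_vocab_to_match_tokenizer=False
)

# Run a forward pass with a causal mask, as in test_llama_roundtrip above.
input_ids = hax.random.randint(jrandom.PRNGKey(0), model.config.Pos, 0, model.Vocab.size)
attn_mask = hax.nn.attention.causal_mask(model.config.Pos, model.config.KeyPos)
logits = model(input_ids, attn_mask=attn_mask)

# Export back to HF format (save_reference_code=False skips copying the modeling code).
converter.save_pretrained(model, "path/to/exported_hf_checkpoint", save_reference_code=False)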