From ca6ada5bec87129068154ecceb9e2c6e53b08701 Mon Sep 17 00:00:00 2001
From: Ivan-Zhou
Date: Sun, 30 Jul 2023 16:47:42 -0700
Subject: [PATCH 01/57] Add Llama config, Mlp, Attention, and RotaryEmb

---
 src/levanter/models/llama.py | 152 +++++++++++++++++++++++++++++++++++
 1 file changed, 152 insertions(+)
 create mode 100644 src/levanter/models/llama.py

diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
new file mode 100644
index 000000000..1dd44a6f9
--- /dev/null
+++ b/src/levanter/models/llama.py
@@ -0,0 +1,152 @@
+import equinox as eqx
+
+import haliax.nn as hnn
+from haliax import Axis, NamedArray
+
+from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin
+from levanter.models.lm_model import LmConfig
+
+
+@LmConfig.register_subclass("llama")
+@dataclass(frozen=True)
+class LlamaConfig(HFCompatConfig):
+    """Config for LlamaModel
+
+    Args:
+        vocab_size (int, optional): vocabulary size of the Llama model. Defaults to 32000.
+        hidden_dim (int, optional): dimension of the hidden state. Defaults to 4096.
+        intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008.
+        num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32.
+        num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32.
+        num_kv_heads (int, optional): number of key/value heads needed for Grouped Query Attention. Defaults to 32.
+        activation_function (str, optional): activation function for the hidden layer. Defaults to "silu".
+        max_position_embeddings (int, optional): maximum length of the position embedding. Defaults to 2048.
+    """
+
+    vocab_size: int = 32000
+    hidden_dim: int = 4096
+    intermediate_dim: int = 11008
+    num_layers: int = 32
+    num_heads: int = 32
+    num_kv_heads: int = 32
+    activation_function: str = "silu"
+    max_position_embeddings: int = 2048
+    initializer_range: float = 0.02
+    use_bias: bool = True
+
+
+class LlamaMlp(eqx.Module):
+    """Multi-layer Perceptron
+    In comparison with GPT2, LlamaMlp adds an up-proj that multiplies with activated gate_proj,
+    before down-proj.
+    """
+
+    gate_proj: hnn.Linear  # projection from Embed to Intermediate
+    up_proj: hnn.Linear  # projection from Embed to Intermediate
+    down_proj: hnn.Linear  # projection from Intermediate to Embed
+    act: Callable = eqx.static_field()
+
+    @staticmethod
+    def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = False) -> "LlamaMlp":
+        k_fc, k_up_proj, k_down_proj = eqx.split_key(key, 3)
+        gate_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias)
+        up_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_up_proj, use_bias=use_bias)
+        down_proj = hnn.Linear.init(Out=Embed, In=Mlp, key=k_down_proj, use_bias=use_bias)
+        if isinstance(activation_fn, str):
+            activation_fn = ACT2FN[activation_fn]
+        act = activation_fn  # type: ignore
+
+    @named_call
+    def __call__(self, x: Tensor, *, key: jnp.ndarray) -> Tensor:
+        hidden_states = self.gate_proj(x)
+        hidden_states = self.act(hidden_states)
+        hidden_states = hidden_states * self.up_proj(x)
+        outputs = self.down_proj(hidden_states)
+        return outputs
+
+
+class LlamaAttention(StateDictSerializationMixin, eqx.Module):
+    config: LlamaConfig = eqx.static_field()
+    q_proj: hnn.Linear  # projection from Embed to query
+    k_proj: hnn.Linear  # projection from Embed to key
+    v_proj: hnn.Linear  # projection from Embed to value
+    o_proj: hnn.Linear  # projection from Heads to output
+    rotary_emb: hnn.RotaryEmbedding  # rotary embedding
+
+    @staticmethod
+    def init(config: Llama2Config, *, key) -> "Llama2Attention":
+        use_bias = config.use_bias
+        Embed = config.Embed
+
+        k_q, k_k, k_v, k_o = eqx.split_key(key, 4)
+        q_proj = hnn.Linear.init(In=Embed, Out=config.Heads * config.HeadDim, key=k_q, use_bias=use_bias)
+        k_proj = hnn.Linear.init(In=Embed, Out=config.KVHeads * config.HeadDim, key=k_k, use_bias=use_bias)
+        v_proj = hnn.Linear.init(In=Embed, Out=config.KVHeads * config.HeadDim, key=k_v, use_bias=use_bias)
+        o_proj = hnn.Linear.init(In=config.Heads * config.HeadDim, Out=Embed, key=k_o, use_bias=use_bias)
+        rotary_emb = _get_rotary_emb(config)
+        return Llama2Attention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb)
+
+    named_call
+
+    def __call__(self, x: NamedArray, mask: Optional[NamedArray], layer_idx, inference: bool = True, *, key):
+        k_q, k_k, k_v = eqx.split_key(key, 3)
+
+
+class LlamaBlock(StateDictSerializationMixin, eqx.Module):
+    pass
+
+
+class LlamaTransformer(StateDictSerializationMixin, eqx.Module):
+    pass
+
+
+class LlamaEmbeddings(StateDictSerializationMixin, eqx.Module):
+    pass
+
+
+class LlamaLMHeadModel(eqx.Module):
+    pass
+
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+
+
+class LlamaRotaryEmbedding(eqx.Module):
+    dim: int
+    max_position_embeddings: int = 2048
+    base: float = 10000
+
+    def setup(self):
+        inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim))
+        self.inv_freq = inv_freq
+
+        # Build here to make the embedding.
+ self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=self.max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + t = jnp.arange(self.max_seq_len_cached) + + freqs = jnp.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = jnp.concatenate((freqs, freqs), axis=-1) + cos_cached = jnp.cos(emb)[None, None, :, :] + sin_cached = jnp.sin(emb)[None, None, :, :] + + return cos_cached, sin_cached + + def __call__(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=seq_len) + + return ( + self.cos_cached[:, :, :seq_len, ...], + self.sin_cached[:, :, :seq_len, ...], + ) + + +def _get_rotary_emb(config: LlamaConfig): + return None From 667a56c65b9b46f84e37f974d41dba533b3455c3 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 30 Jul 2023 17:06:59 -0700 Subject: [PATCH 02/57] address integration test --- src/levanter/models/llama.py | 29 +++++++++++++++++++---------- tests/test_llama.py | 16 ++++++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) create mode 100644 tests/test_llama.py diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 1dd44a6f9..fbead2beb 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -1,9 +1,14 @@ +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Type + import equinox as eqx import haliax.nn as hnn from haliax import Axis, NamedArray +from haliax.jax_utils import named_call from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin +from levanter.compat.torch_serialization import StateDictSerializationMixin from levanter.models.lm_model import LmConfig @@ -57,7 +62,7 @@ def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = False) act = activation_fn # type: ignore @named_call - def __call__(self, x: Tensor, *, key: jnp.ndarray) -> Tensor: + def __call__(self, x: NamedArray) -> NamedArray: hidden_states = self.gate_proj(x) hidden_states = self.act(hidden_states) hidden_states = hidden_states * self.up_proj(x) @@ -71,10 +76,10 @@ class LlamaAttention(StateDictSerializationMixin, eqx.Module): k_proj: hnn.Linear # projection from Embed to key v_proj: hnn.Linear # projection from Embed to value o_proj: hnn.Linear # projection from Heads to output - rotary_emb: hnn.RotaryEmbedding # rotary embedding + # rotary_emb # rotary embedding @staticmethod - def init(config: Llama2Config, *, key) -> "Llama2Attention": + def init(config: LlamaConfig, *, key) -> "LlamaAttention": use_bias = config.use_bias Embed = config.Embed @@ -83,8 +88,8 @@ def init(config: Llama2Config, *, key) -> "Llama2Attention": k_proj = hnn.Linear.init(In=Embed, Out=config.KVHeads * config.HeadDim, key=k_k, use_bias=use_bias) v_proj = hnn.Linear.init(In=Embed, Out=config.KVHeads * config.HeadDim, key=k_v, use_bias=use_bias) o_proj = hnn.Linear.init(In=config.Heads * config.HeadDim, Out=Embed, key=k_o, use_bias=use_bias) - rotary_emb = _get_rotary_emb(config) - return Llama2Attention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) + # rotary_emb = _get_rotary_emb(config) + return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj) named_call @@ -110,17 +115,21 @@ class LlamaLMHeadModel(eqx.Module): import jax import jax.numpy as jnp -from flax import linen as nn 
-class LlamaRotaryEmbedding(eqx.Module): +class LlamaRotaryEmbedding(): dim: int max_position_embeddings: int = 2048 base: float = 10000 + def __init__(self, dim, max_position_embeddings=2048, base=10000): + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.setup() + def setup(self): - inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim)) - self.inv_freq = inv_freq + self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim)) # Build here to make the embedding. self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=self.max_position_embeddings) @@ -137,7 +146,7 @@ def _set_cos_sin_cache(self, seq_len): return cos_cached, sin_cached - def __call__(self, x, seq_len=None): + def __call__(self, x, seq_len: int): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=seq_len) diff --git a/tests/test_llama.py b/tests/test_llama.py new file mode 100644 index 000000000..8aa3465d5 --- /dev/null +++ b/tests/test_llama.py @@ -0,0 +1,16 @@ +from jax import random + +from levanter.models.llama import LlamaRotaryEmbedding + + +def test_llama_rotary_embedding(): + dim = 2048 + max_position_embeddings = 2048 + seq_len = 2048 + base = 10000 + key = random.PRNGKey(0) + rotary_emb = LlamaRotaryEmbedding(dim=dim) + rotary_emb.setup() + x = random.normal(key, (1, 2048)) + x = rotary_emb(x, seq_len=seq_len) + From 0e77701907da857d90d668f4e794319e5d2651ad Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 31 Jul 2023 19:35:50 -0700 Subject: [PATCH 03/57] test compare with hf implementation --- tests/test_llama.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/test_llama.py b/tests/test_llama.py index 8aa3465d5..d07768959 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,16 +1,28 @@ from jax import random +import numpy as np +import torch + +# src/transformers/models/llama/modeling_llama.py +from transformers.models.llama.modeling_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding from levanter.models.llama import LlamaRotaryEmbedding def test_llama_rotary_embedding(): + """Match against HuggingFace's implementation of LlamaRotaryEmbedding.""" dim = 2048 - max_position_embeddings = 2048 seq_len = 2048 - base = 10000 key = random.PRNGKey(0) - rotary_emb = LlamaRotaryEmbedding(dim=dim) - rotary_emb.setup() - x = random.normal(key, (1, 2048)) - x = rotary_emb(x, seq_len=seq_len) + x = random.normal(key, (1, seq_len)) + levanter_rotary_emb = LlamaRotaryEmbedding(dim=dim) + levanter_rotary_emb.setup() + levanter_output = levanter_rotary_emb(x, seq_len=seq_len) + + hf_rotary_emb = HFLlamaRotaryEmbedding(dim=dim, device="cpu") + x_torch = torch.from_numpy(np.array(x)) + hf_output = hf_rotary_emb(x_torch, seq_len=seq_len) + for jax_out, torch_out in zip(levanter_output, hf_output): + torch_out = torch_out.numpy() + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" From ed10ee15339d5c746d55f5c7ffc3e1603929da7e Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 31 Jul 2023 19:40:36 -0700 Subject: [PATCH 04/57] Finish LlamaRotaryEmbedding --- src/levanter/models/llama.py | 20 ++++++++++---------- tests/test_llama.py | 5 ++--- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/levanter/models/llama.py 
b/src/levanter/models/llama.py index fbead2beb..460860d94 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -2,6 +2,8 @@ from typing import Callable, Dict, Optional, Type import equinox as eqx +import jax +import jax.numpy as jnp import haliax.nn as hnn from haliax import Axis, NamedArray @@ -113,23 +115,20 @@ class LlamaLMHeadModel(eqx.Module): pass -import jax -import jax.numpy as jnp - - -class LlamaRotaryEmbedding(): +class LlamaRotaryEmbedding(eqx.Module): dim: int max_position_embeddings: int = 2048 base: float = 10000 + inv_freq: jnp.ndarray = eqx.static_field() + cos_cached: jnp.ndarray = eqx.static_field() + sin_cached: jnp.ndarray = eqx.static_field() + max_seq_len_cached: int = eqx.static_field() def __init__(self, dim, max_position_embeddings=2048, base=10000): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - self.setup() - - def setup(self): - self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim)) + self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim)) # Build here to make the embedding. self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=self.max_position_embeddings) @@ -139,7 +138,8 @@ def _set_cos_sin_cache(self, seq_len): t = jnp.arange(self.max_seq_len_cached) freqs = jnp.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation + # Different from paper but following HF implementation + # It uses a different permutation in order to obtain the same calculation emb = jnp.concatenate((freqs, freqs), axis=-1) cos_cached = jnp.cos(emb)[None, None, :, :] sin_cached = jnp.sin(emb)[None, None, :, :] diff --git a/tests/test_llama.py b/tests/test_llama.py index d07768959..5c515c428 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,11 +1,11 @@ -from jax import random - import numpy as np import torch +from jax import random # src/transformers/models/llama/modeling_llama.py from transformers.models.llama.modeling_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding + from levanter.models.llama import LlamaRotaryEmbedding @@ -16,7 +16,6 @@ def test_llama_rotary_embedding(): key = random.PRNGKey(0) x = random.normal(key, (1, seq_len)) levanter_rotary_emb = LlamaRotaryEmbedding(dim=dim) - levanter_rotary_emb.setup() levanter_output = levanter_rotary_emb(x, seq_len=seq_len) hf_rotary_emb = HFLlamaRotaryEmbedding(dim=dim, device="cpu") From 6034aeedfaa51d68e0ff885df2aeb7bb8f1edb67 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 31 Jul 2023 20:01:46 -0700 Subject: [PATCH 05/57] Add LlamaLinearScalingRotaryEmbedding --- src/levanter/models/llama.py | 28 +++++++++++++++++++++--- tests/test_llama.py | 42 ++++++++++++++++++++++++++---------- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 460860d94..ad4aea1bb 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -128,15 +128,14 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim)) - - # Build here to make the embedding. 
+ self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim)) self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=self.max_position_embeddings) def _set_cos_sin_cache(self, seq_len): self.max_seq_len_cached = seq_len t = jnp.arange(self.max_seq_len_cached) + # Evaluates the Einstein summation convention on the operands. freqs = jnp.einsum("i,j->ij", t, self.inv_freq) # Different from paper but following HF implementation # It uses a different permutation in order to obtain the same calculation @@ -157,5 +156,28 @@ def __call__(self, x, seq_len: int): ) +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling""" + + scaling_factor: float = 1.0 + + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base) + + def _set_cos_sin_cache(self, seq_len): + """The main difference is that the scaling factor is applied to the time axis""" + self.max_seq_len_cached = seq_len + t = jnp.arange(self.max_seq_len_cached) / self.scaling_factor + + freqs = jnp.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = jnp.concatenate((freqs, freqs), axis=-1) + cos_cached = jnp.cos(emb)[None, None, :, :] + sin_cached = jnp.sin(emb)[None, None, :, :] + + return cos_cached, sin_cached + + def _get_rotary_emb(config: LlamaConfig): return None diff --git a/tests/test_llama.py b/tests/test_llama.py index 5c515c428..132e88159 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -3,10 +3,15 @@ from jax import random # src/transformers/models/llama/modeling_llama.py -from transformers.models.llama.modeling_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding -from levanter.models.llama import LlamaRotaryEmbedding +try: + from transformers.models.llama.modeling_llama import ( + LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, + ) +except ImportError: + HFLlamaLinearScalingRotaryEmbedding = None +from levanter.models.llama import LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding def test_llama_rotary_embedding(): @@ -14,14 +19,29 @@ def test_llama_rotary_embedding(): dim = 2048 seq_len = 2048 key = random.PRNGKey(0) - x = random.normal(key, (1, seq_len)) - levanter_rotary_emb = LlamaRotaryEmbedding(dim=dim) - levanter_output = levanter_rotary_emb(x, seq_len=seq_len) + device = "cpu" - hf_rotary_emb = HFLlamaRotaryEmbedding(dim=dim, device="cpu") - x_torch = torch.from_numpy(np.array(x)) - hf_output = hf_rotary_emb(x_torch, seq_len=seq_len) + def test_levanter_against_hf(levanter_class, hf_class): + x = random.normal(key, (1, seq_len)) + x_torch = torch.from_numpy(np.array(x)) - for jax_out, torch_out in zip(levanter_output, hf_output): - torch_out = torch_out.numpy() - assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + levanter_output = levanter_class(x, seq_len=seq_len) + hf_output = hf_class(x_torch, seq_len=seq_len) + + for jax_out, torch_out in zip(levanter_output, hf_output): + torch_out = torch_out.numpy() + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + + # test LlamaRotaryEmbedding + test_levanter_against_hf( + levanter_class=LlamaRotaryEmbedding(dim=dim), + 
hf_class=HFLlamaRotaryEmbedding(dim=dim, device=device), + ) + + # test LlamaLinearScalingRotaryEmbedding + if HFLlamaLinearScalingRotaryEmbedding is not None: + scaling_factor = 2.0 + test_levanter_against_hf( + levanter_class=LlamaLinearScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor), + hf_class=HFLlamaLinearScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor, device=device), + ) From e9a7c560533eaf07d78a86a71f2276797fcecf56 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 31 Jul 2023 20:16:07 -0700 Subject: [PATCH 06/57] Add LlamaDynamicNTKScalingRotaryEmbedding --- src/levanter/models/llama.py | 28 ++++++++++++++++++++++++++++ tests/test_llama.py | 34 ++++++++++++++++++---------------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index ad4aea1bb..beca04ba6 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -179,5 +179,33 @@ def _set_cos_sin_cache(self, seq_len): return cos_cached, sin_cached +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. """ + scaling_factor: float = 1.0 + + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + self.inv_freq = 1.0 / (base ** (jnp.arange(0, self.dim, 2) / self.dim)) + + t = jnp.arange(self.max_seq_len_cached) + + freqs = jnp.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = jnp.concatenate((freqs, freqs), axis=-1) + cos_cached = jnp.cos(emb)[None, None, :, :] + sin_cached = jnp.sin(emb)[None, None, :, :] + + return cos_cached, sin_cached + + def _get_rotary_emb(config: LlamaConfig): return None diff --git a/tests/test_llama.py b/tests/test_llama.py index 132e88159..1c4de61a8 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -2,22 +2,20 @@ import torch from jax import random -# src/transformers/models/llama/modeling_llama.py -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding - -try: - from transformers.models.llama.modeling_llama import ( - LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, - ) -except ImportError: - HFLlamaLinearScalingRotaryEmbedding = None -from levanter.models.llama import LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding +# The latter 2 classes are only available in HuggingFace's transformers 4.30.0 or later +from transformers.models.llama.modeling_llama import ( + LlamaRotaryEmbedding as HFLlamaRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, + LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, +) +from levanter.models.llama import LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding def test_llama_rotary_embedding(): """Match against HuggingFace's implementation of LlamaRotaryEmbedding.""" dim = 2048 seq_len = 2048 + scaling_factor = 2.0 key = random.PRNGKey(0) device = "cpu" @@ -39,9 +37,13 @@ def 
test_levanter_against_hf(levanter_class, hf_class): ) # test LlamaLinearScalingRotaryEmbedding - if HFLlamaLinearScalingRotaryEmbedding is not None: - scaling_factor = 2.0 - test_levanter_against_hf( - levanter_class=LlamaLinearScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor), - hf_class=HFLlamaLinearScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor, device=device), - ) + test_levanter_against_hf( + levanter_class=LlamaLinearScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor), + hf_class=HFLlamaLinearScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor, device=device), + ) + + # test LlamaDynamicNTKScalingRotaryEmbedding + test_levanter_against_hf( + levanter_class=LlamaDynamicNTKScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor), + hf_class=HFLlamaDynamicNTKScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor, device=device), + ) From 6e97f5f48976dbd7bfbcfcaa8312a768f7097564 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 31 Jul 2023 20:23:33 -0700 Subject: [PATCH 07/57] Refactor to simplified class differences --- src/levanter/models/llama.py | 45 +++++++++++++++--------------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index beca04ba6..93cee8e97 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -131,9 +131,15 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000): self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim)) self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=self.max_position_embeddings) + def _get_positional_ids(self): + """A helper function for the convenience of extending to two sub-classes + Here we use a standard positional encoding function, which was described in `Attention is all you need`. + """ + return jnp.arange(self.max_seq_len_cached) + def _set_cos_sin_cache(self, seq_len): self.max_seq_len_cached = seq_len - t = jnp.arange(self.max_seq_len_cached) + t = self._get_positional_ids() # Evaluates the Einstein summation convention on the operands. freqs = jnp.einsum("i,j->ij", t, self.inv_freq) @@ -165,46 +171,31 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base) - def _set_cos_sin_cache(self, seq_len): - """The main difference is that the scaling factor is applied to the time axis""" - self.max_seq_len_cached = seq_len - t = jnp.arange(self.max_seq_len_cached) / self.scaling_factor - - freqs = jnp.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = jnp.concatenate((freqs, freqs), axis=-1) - cos_cached = jnp.cos(emb)[None, None, :, :] - sin_cached = jnp.sin(emb)[None, None, :, :] - - return cos_cached, sin_cached + def _get_positional_ids(self): + """Here we overwrite the function in the base class to implement linear scaling.""" + return jnp.arange(self.max_seq_len_cached) / self.scaling_factor class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
""" + """LlamaRotaryEmbedding extended with Dynamic NTK scaling.""" + scaling_factor: float = 1.0 def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base) - def _set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: + def _get_positional_ids(self): + """Here we overwrite the function in the base class. + Here it adjusts the frequency base dynamically according to the sequence length. + """ + if self.max_seq_len_cached > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) self.inv_freq = 1.0 / (base ** (jnp.arange(0, self.dim, 2) / self.dim)) - t = jnp.arange(self.max_seq_len_cached) - - freqs = jnp.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = jnp.concatenate((freqs, freqs), axis=-1) - cos_cached = jnp.cos(emb)[None, None, :, :] - sin_cached = jnp.sin(emb)[None, None, :, :] - - return cos_cached, sin_cached + return jnp.arange(self.max_seq_len_cached) def _get_rotary_emb(config: LlamaConfig): From 271be4b291acd4494370ec14773ed3a8f5344ea1 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 31 Jul 2023 20:39:30 -0700 Subject: [PATCH 08/57] Implement _get_rotary_emb --- src/levanter/models/llama.py | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 93cee8e97..20bb78c3f 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -28,6 +28,7 @@ class LlamaConfig(HFCompatConfig): num_kv_heads (int, optional): number of key/value heads needed for Grouped Query Attention. Defaults to 32. activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". max_position_embeddings (int, optional): maximum length of the position embedding. Defaults to 2048. + rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. 
""" vocab_size: int = 32000 @@ -40,6 +41,12 @@ class LlamaConfig(HFCompatConfig): max_position_embeddings: int = 2048 initializer_range: float = 0.02 use_bias: bool = True + rope_scaling: Optional[Dict] = None + + # Axis + Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) + Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) class LlamaMlp(eqx.Module): @@ -74,27 +81,29 @@ def __call__(self, x: NamedArray) -> NamedArray: class LlamaAttention(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() + hidden_dim: int = eqx.static_field() + max_position_embeddings: int = eqx.static_field() q_proj: hnn.Linear # projection from Embed to query k_proj: hnn.Linear # projection from Embed to key v_proj: hnn.Linear # projection from Embed to value o_proj: hnn.Linear # projection from Heads to output - # rotary_emb # rotary embedding + rotary_emb: LlamaRotaryEmbedding # rotary embedding @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaAttention": use_bias = config.use_bias Embed = config.Embed - + head_dim = config.HeadSize.size + max_position_embeddings = config.max_position_embeddings k_q, k_k, k_v, k_o = eqx.split_key(key, 4) q_proj = hnn.Linear.init(In=Embed, Out=config.Heads * config.HeadDim, key=k_q, use_bias=use_bias) k_proj = hnn.Linear.init(In=Embed, Out=config.KVHeads * config.HeadDim, key=k_k, use_bias=use_bias) v_proj = hnn.Linear.init(In=Embed, Out=config.KVHeads * config.HeadDim, key=k_v, use_bias=use_bias) o_proj = hnn.Linear.init(In=config.Heads * config.HeadDim, Out=Embed, key=k_o, use_bias=use_bias) - # rotary_emb = _get_rotary_emb(config) - return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj) - - named_call + rotary_emb = _get_rotary_emb(config, head_dim, max_position_embeddings) + return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) + @named_call def __call__(self, x: NamedArray, mask: Optional[NamedArray], layer_idx, inference: bool = True, *, key): k_q, k_k, k_v = eqx.split_key(key, 3) @@ -198,5 +207,15 @@ def _get_positional_ids(self): return jnp.arange(self.max_seq_len_cached) -def _get_rotary_emb(config: LlamaConfig): - return None +def _get_rotary_emb(config: LlamaConfig, head_dim: int, max_position_embeddings: int) -> LlamaRotaryEmbedding: + if config.rope_scaling is None: + return LlamaRotaryEmbedding(head_dim, max_position_embeddings) + else: + scaling_type = config.rope_scaling["type"] + scaling_factor = config.rope_scaling["factor"] + if scaling_type == "linear": + return LlamaRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) + elif scaling_type == "dynamic": + return LlamaRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") From 669e6e76077455508af868432ac5b905198681d3 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Wed, 2 Aug 2023 09:48:53 -0700 Subject: [PATCH 09/57] work on attention --- src/levanter/models/llama.py | 40 +++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 20bb78c3f..8e7036f1e 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -105,7 +105,26 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": @named_call def __call__(self, x: NamedArray, mask: Optional[NamedArray], 
layer_idx, inference: bool = True, *, key): - k_q, k_k, k_v = eqx.split_key(key, 3) + q = self.q_proj(x) # TODO: rearrange and possibly rename + k = self.k_proj(x) + v = self.v_proj(x) + + cos, sin = self.rotary_emb(v, seq_len=self.config.KVHeads.size) + + q, k = _apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + scale = jax.lax.rsqrt(float(self.config.HeadSize.size)) + + attn_weights = hax.dot("head_size", q, k) * scale + attn_weights = attn_weights + mask + + # upcast attention to fp32. This is default for Llama Attention + attn_weights = attn_weights.astype(jnp.float32) + + attn_weights = hnn.softmax(attn_weights, axis="key_position").astype(q.dtype) + attn_output = hax.dot("key_position", attn_weights, v) + + # TODO: continue class LlamaBlock(StateDictSerializationMixin, eqx.Module): @@ -158,8 +177,6 @@ def _set_cos_sin_cache(self, seq_len): cos_cached = jnp.cos(emb)[None, None, :, :] sin_cached = jnp.sin(emb)[None, None, :, :] - return cos_cached, sin_cached - def __call__(self, x, seq_len: int): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: @@ -219,3 +236,20 @@ def _get_rotary_emb(config: LlamaConfig, head_dim: int, max_position_embeddings: return LlamaRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + +def _rotate_half(x): + """Rotates half of the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return jnp.concatenate((-x2, x1), axis=-1) + + +def _apply_rotary_pos_emb(q, k, cos, sin, position_ids) -> Tuple[jnp.ndarray, jnp.ndarray]: + cos = jnp.squeeze(cos, axis=1) # from [1, 1, seq_len, dim] to [seq_len, dim] + sin = jnp.squeeze(sin, axis=1) + cos = cos[position_ids, None] # [seq_len, dim] -> [bs, 1, seq_len, dim] + sin = sin[position_ids, None] + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed From e45b39b803c38467751f562cbb500d1ccd56b1ae Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 6 Aug 2023 11:11:40 -0700 Subject: [PATCH 10/57] test initialize attention --- src/levanter/models/llama.py | 120 +++++++++++++++++++---------------- tests/test_llama.py | 56 ++++++++++++---- 2 files changed, 109 insertions(+), 67 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 8e7036f1e..0b57d3c57 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -1,25 +1,27 @@ -from dataclasses import dataclass -from typing import Callable, Dict, Optional, Type +from dataclasses import dataclass, field +from typing import Callable, Dict, Optional, Type, Tuple import equinox as eqx import jax import jax.numpy as jnp +import jax.random as jrandom import haliax.nn as hnn from haliax import Axis, NamedArray from haliax.jax_utils import named_call -from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin +from levanter.compat.hf_checkpoints import HFCheckpointConverter, LmWithHfSerializationMixin from levanter.compat.torch_serialization import StateDictSerializationMixin from levanter.models.lm_model import LmConfig @LmConfig.register_subclass("llama") @dataclass(frozen=True) -class LlamaConfig(HFCompatConfig): +class LlamaConfig: """Config for LlamaModel Args: + seq_len (int, optional): maximum length of the input sequence. Defaults to 2048. vocab_size (int, optional): vocabulary size of the Llama model. Defaults to 32000. 
hidden_dim (int, optional): dimension of the hidden state. Defaults to 4096. intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008. @@ -31,6 +33,7 @@ class LlamaConfig(HFCompatConfig): rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. """ + seq_len: int = 2048 vocab_size: int = 32000 hidden_dim: int = 4096 intermediate_dim: int = 11008 @@ -41,11 +44,16 @@ class LlamaConfig(HFCompatConfig): max_position_embeddings: int = 2048 initializer_range: float = 0.02 use_bias: bool = True - rope_scaling: Optional[Dict] = None + rope_scaling: Optional[dict] = None # Axis + Pos = property(lambda self: Axis(name="position", size=self.seq_len)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.hidden_dim * self.mlp_scale)) HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) @@ -62,7 +70,7 @@ class LlamaMlp(eqx.Module): @staticmethod def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = False) -> "LlamaMlp": - k_fc, k_up_proj, k_down_proj = eqx.split_key(key, 3) + k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3) gate_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias) up_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_up_proj, use_bias=use_bias) down_proj = hnn.Linear.init(Out=Embed, In=Mlp, key=k_down_proj, use_bias=use_bias) @@ -79,54 +87,6 @@ def __call__(self, x: NamedArray) -> NamedArray: return outputs -class LlamaAttention(StateDictSerializationMixin, eqx.Module): - config: LlamaConfig = eqx.static_field() - hidden_dim: int = eqx.static_field() - max_position_embeddings: int = eqx.static_field() - q_proj: hnn.Linear # projection from Embed to query - k_proj: hnn.Linear # projection from Embed to key - v_proj: hnn.Linear # projection from Embed to value - o_proj: hnn.Linear # projection from Heads to output - rotary_emb: LlamaRotaryEmbedding # rotary embedding - - @staticmethod - def init(config: LlamaConfig, *, key) -> "LlamaAttention": - use_bias = config.use_bias - Embed = config.Embed - head_dim = config.HeadSize.size - max_position_embeddings = config.max_position_embeddings - k_q, k_k, k_v, k_o = eqx.split_key(key, 4) - q_proj = hnn.Linear.init(In=Embed, Out=config.Heads * config.HeadDim, key=k_q, use_bias=use_bias) - k_proj = hnn.Linear.init(In=Embed, Out=config.KVHeads * config.HeadDim, key=k_k, use_bias=use_bias) - v_proj = hnn.Linear.init(In=Embed, Out=config.KVHeads * config.HeadDim, key=k_v, use_bias=use_bias) - o_proj = hnn.Linear.init(In=config.Heads * config.HeadDim, Out=Embed, key=k_o, use_bias=use_bias) - rotary_emb = _get_rotary_emb(config, head_dim, max_position_embeddings) - return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) - - @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], layer_idx, inference: bool = True, *, key): - q = self.q_proj(x) # TODO: rearrange and possibly rename - k = self.k_proj(x) - v = self.v_proj(x) - - cos, sin = self.rotary_emb(v, seq_len=self.config.KVHeads.size) - - q, k = _apply_rotary_pos_emb(q, k, cos, sin, position_ids) - - scale = 
jax.lax.rsqrt(float(self.config.HeadSize.size)) - - attn_weights = hax.dot("head_size", q, k) * scale - attn_weights = attn_weights + mask - - # upcast attention to fp32. This is default for Llama Attention - attn_weights = attn_weights.astype(jnp.float32) - - attn_weights = hnn.softmax(attn_weights, axis="key_position").astype(q.dtype) - attn_output = hax.dot("key_position", attn_weights, v) - - # TODO: continue - - class LlamaBlock(StateDictSerializationMixin, eqx.Module): pass @@ -176,6 +136,7 @@ def _set_cos_sin_cache(self, seq_len): emb = jnp.concatenate((freqs, freqs), axis=-1) cos_cached = jnp.cos(emb)[None, None, :, :] sin_cached = jnp.sin(emb)[None, None, :, :] + return cos_cached, sin_cached def __call__(self, x, seq_len: int): # x: [bs, num_attention_heads, seq_len, head_size] @@ -224,6 +185,53 @@ def _get_positional_ids(self): return jnp.arange(self.max_seq_len_cached) +class LlamaAttention(StateDictSerializationMixin, eqx.Module): + config: LlamaConfig = eqx.static_field() + q_proj: hnn.Linear # projection from Embed to query + k_proj: hnn.Linear # projection from Embed to key + v_proj: hnn.Linear # projection from Embed to value + o_proj: hnn.Linear # projection from Heads to output + rotary_emb: LlamaRotaryEmbedding # rotary embedding + + @staticmethod + def init(config: LlamaConfig, *, key) -> "LlamaAttention": + use_bias = config.use_bias + Embed = config.Embed + head_dim = config.HeadSize.size + max_position_embeddings = config.max_position_embeddings + k_q, k_k, k_v, k_o = jrandom.split(key, 4) + q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias) + k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias) + v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias) + o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias) + rotary_emb = _get_rotary_emb(config, head_dim, max_position_embeddings) + return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) + + @named_call + def __call__(self, x: NamedArray, mask: Optional[NamedArray], layer_idx, inference: bool = True, *, key): + q = self.q_proj(x) # TODO: rearrange and possibly rename + k = self.k_proj(x) + v = self.v_proj(x) + + cos, sin = self.rotary_emb(v, seq_len=self.config.KVHeads.size) + + q, k = _apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + scale = jax.lax.rsqrt(float(self.config.HeadSize.size)) + + attn_weights = hax.dot("head_size", q, k) * scale + attn_weights = attn_weights + mask + + # upcast attention to fp32. 
This is default for Llama Attention + attn_weights = attn_weights.astype(jnp.float32) + + attn_weights = hnn.softmax(attn_weights, axis="key_position").astype(q.dtype) + attn_output = hax.dot("key_position", attn_weights, v) + + attn_output = self.o_proj(attn_output) + return attn_output + + def _get_rotary_emb(config: LlamaConfig, head_dim: int, max_position_embeddings: int) -> LlamaRotaryEmbedding: if config.rope_scaling is None: return LlamaRotaryEmbedding(head_dim, max_position_embeddings) @@ -231,9 +239,9 @@ def _get_rotary_emb(config: LlamaConfig, head_dim: int, max_position_embeddings: scaling_type = config.rope_scaling["type"] scaling_factor = config.rope_scaling["factor"] if scaling_type == "linear": - return LlamaRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) + return LlamaLinearScalingRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) elif scaling_type == "dynamic": - return LlamaRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) + return LlamaDynamicNTKScalingRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/tests/test_llama.py b/tests/test_llama.py index 1c4de61a8..9a7edca4c 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -7,15 +7,23 @@ LlamaRotaryEmbedding as HFLlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, + LlamaAttention as HFLlamaAttention, +) +from levanter.models.llama import ( + LlamaConfig, + LlamaRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaAttention, ) -from levanter.models.llama import LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding +""" def test_llama_rotary_embedding(): - """Match against HuggingFace's implementation of LlamaRotaryEmbedding.""" - dim = 2048 - seq_len = 2048 - scaling_factor = 2.0 + llama_config = _get_llama_config() + hidden_dim = llama_config.hidden_dim + seq_len = llama_config.seq_len + scaling_factor = llama_config.rope_scaling["factor"] key = random.PRNGKey(0) device = "cpu" @@ -32,18 +40,44 @@ def test_levanter_against_hf(levanter_class, hf_class): # test LlamaRotaryEmbedding test_levanter_against_hf( - levanter_class=LlamaRotaryEmbedding(dim=dim), - hf_class=HFLlamaRotaryEmbedding(dim=dim, device=device), + levanter_class=LlamaRotaryEmbedding(dim=hidden_dim), + hf_class=HFLlamaRotaryEmbedding(dim=hidden_dim, device=device), ) # test LlamaLinearScalingRotaryEmbedding test_levanter_against_hf( - levanter_class=LlamaLinearScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor), - hf_class=HFLlamaLinearScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor, device=device), + levanter_class=LlamaLinearScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor), + hf_class=HFLlamaLinearScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor, device=device), ) # test LlamaDynamicNTKScalingRotaryEmbedding test_levanter_against_hf( - levanter_class=LlamaDynamicNTKScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor), - hf_class=HFLlamaDynamicNTKScalingRotaryEmbedding(dim=dim, scaling_factor=scaling_factor, device=device), + levanter_class=LlamaDynamicNTKScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor), + 
hf_class=HFLlamaDynamicNTKScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor, device=device), + ) +""" + + +def test_llama_attention(): + llama_config = _get_llama_config() + key = random.PRNGKey(4) + hf_llama_att = LlamaAttention.init(config=llama_config, key=key) + + + +def _get_llama_config() -> LlamaConfig: + vocab_size = 32000 + hidden_dim = 2048 + num_heads = 16 + num_kv_heads = 16 + rope_scaling = { + "type": "linear", + "factor": 2.0, + } + return LlamaConfig( + vocab_size=vocab_size, + hidden_dim=hidden_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + rope_scaling=rope_scaling, ) From bbcd0ecdc6aec29854e5529918c1887bd6bba8c1 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 7 Aug 2023 21:14:59 -0700 Subject: [PATCH 11/57] _apply_rotary_pos_emb --- src/levanter/models/llama.py | 49 ++++++++++++++++-------- tests/test_llama.py | 73 +++++++++++++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 21 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 0b57d3c57..bbb61f1f1 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import jax.random as jrandom +import haliax as hax import haliax.nn as hnn from haliax import Axis, NamedArray from haliax.jax_utils import named_call @@ -208,12 +209,12 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], layer_idx, inference: bool = True, *, key): + def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids): q = self.q_proj(x) # TODO: rearrange and possibly rename k = self.k_proj(x) v = self.v_proj(x) - cos, sin = self.rotary_emb(v, seq_len=self.config.KVHeads.size) + cos, sin = self.rotary_emb(v, seq_len=self.config.seq_len) q, k = _apply_rotary_pos_emb(q, k, cos, sin, position_ids) @@ -246,18 +247,34 @@ def _get_rotary_emb(config: LlamaConfig, head_dim: int, max_position_embeddings: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") -def _rotate_half(x): - """Rotates half of the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return jnp.concatenate((-x2, x1), axis=-1) - - -def _apply_rotary_pos_emb(q, k, cos, sin, position_ids) -> Tuple[jnp.ndarray, jnp.ndarray]: - cos = jnp.squeeze(cos, axis=1) # from [1, 1, seq_len, dim] to [seq_len, dim] - sin = jnp.squeeze(sin, axis=1) - cos = cos[position_ids, None] # [seq_len, dim] -> [bs, 1, seq_len, dim] - sin = sin[position_ids, None] - q_embed = q * cos + _rotate_half(q) * sin - k_embed = k * cos + _rotate_half(k) * sin +def _rotate_half(x: NamedArray) -> NamedArray: + """Rotates half of the hidden dims of the input and concatenates them. + This is difficult to Haliax, so we do it in regular array ops. + """ + x_array = x.array + x1 = x_array[..., : x_array.shape[-1] // 2] + x2 = x_array[..., x_array.shape[-1] // 2 :] + output = jnp.concatenate((-x2, x1), axis=-1) + return hax.named(output, x.axes) + + +def _apply_rotary_pos_emb( + q: NamedArray, # [batch, seq_len, heads, head_size] + k: NamedArray, # [batch, seq_len, kv_heads, head_size] + cos: jnp.ndarray, # [1, 1, seq_len, head_size] + sin: jnp.ndarray, # [1, 1, seq_len, head_size] + position_ids: jnp.ndarray # [bs, seq_len] +) -> Tuple[NamedArray, NamedArray]: + """Applies rotary position embedding to q and k. 
+ Note that all the multiplication below are element-wise, so I don't find it + helpful to write in haliax. + """ + cos = jnp.squeeze(jnp.squeeze(cos, axis=1), axis=0) # from [1, 1, seq_len, dim] to [seq_len, dim] + sin = jnp.squeeze(jnp.squeeze(sin, axis=1), axis=0) + cos = jnp.expand_dims(cos[position_ids], axis=2) # [batch, seq_len, 1, head_size] + sin = jnp.expand_dims(sin[position_ids], axis=2) # [batch, seq_len, 1, head_size] + q_embed = (q.array * cos) + (_rotate_half(q).array * sin) # [batch, seq_len, heads, head_size] + k_embed = (k.array * cos) + (_rotate_half(k).array * sin) # [batch, seq_len, kv_heads, head_size] + q_embed = hax.named(q_embed, q.axes) + k_embed = hax.named(k_embed, k.axes) return q_embed, k_embed diff --git a/tests/test_llama.py b/tests/test_llama.py index 9a7edca4c..b763df6cc 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,4 +1,7 @@ import numpy as np + +import haliax as hax +import jax.numpy as jnp import torch from jax import random @@ -8,6 +11,8 @@ LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, LlamaAttention as HFLlamaAttention, + apply_rotary_pos_emb as hf_apply_rotary_pos_emb, + rotate_half as hf_rotate_half, ) from levanter.models.llama import ( LlamaConfig, @@ -15,6 +20,8 @@ LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding, LlamaAttention, + _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb, + _rotate_half as levanter_rotate_half, ) @@ -55,21 +62,77 @@ def test_levanter_against_hf(levanter_class, hf_class): levanter_class=LlamaDynamicNTKScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor), hf_class=HFLlamaDynamicNTKScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor, device=device), ) -""" + def test_llama_attention(): llama_config = _get_llama_config() key = random.PRNGKey(4) - hf_llama_att = LlamaAttention.init(config=llama_config, key=key) + levanter_llama_att = LlamaAttention.init(config=llama_config, key=key) + seq_len = llama_config.seq_len + + input_ids = hax.arange(llama_config.Pos, dtype=jnp.int32) + causal_mask = hax.nn.attention.causal_mask(llama_config.Pos, llama_config.KeyPos) + position_ids = random.randint(random.PRNGKey(0), (1, seq_len), 0, llama_config.Pos.size) + levanter_output = levanter_llama_att(input_ids, mask=causal_mask, position_ids=position_ids) +""" + + +def test_apply_rotary_pos_emb(): + llama_config = _get_llama_config() + + Pos = llama_config.Pos + Heads = llama_config.Heads + KVHeads = llama_config.KVHeads + HeadSize = llama_config.HeadSize + Batch = hax.Axis("batch", 2) + + # note here we switch Heads and Pos for the shape of the output tensors + q = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Heads, HeadSize)) + k = hax.random.normal(random.PRNGKey(1), (Batch, Pos, Heads, HeadSize)) + + # Check the output of _rotate_half() from levanter and hf + levanter_out_rf_q = levanter_rotate_half(q) + levanter_out_rf_k = levanter_rotate_half(k) + + q_tensor = torch.from_numpy(np.array(q.array)).transpose(1, 2) # needed for HF + k_tensor = torch.from_numpy(np.array(k.array)).transpose(1, 2) + hf_out_rf_q = hf_rotate_half(q_tensor).transpose(1, 2) # re-transpose to match levanter + hf_out_rf_k = hf_rotate_half(k_tensor).transpose(1, 2) + + _assert_equal_out(levanter_out_rf_q, hf_out_rf_q) + _assert_equal_out(levanter_out_rf_k, hf_out_rf_k) + + # Check the output of _apply_rotary_pos_emb() from levanter and hf + cos = 
random.normal(random.PRNGKey(2), (1, 1, Pos.size, HeadSize.size)) + sin = random.normal(random.PRNGKey(3), (1, 1, Pos.size, HeadSize.size)) + position_ids = random.randint(random.PRNGKey(4), (Batch.size, Pos.size), 0, Pos.size) + + levanter_out_rope_q, levanter_out_rope_k = levanter_apply_rotary_pos_emb(q, k, cos, sin, position_ids) + cos_tensor = torch.from_numpy(np.array(cos)) + sin_tensor = torch.from_numpy(np.array(sin)) + position_ids_tensor = torch.from_numpy(np.array(position_ids)) + + hf_out_rope_q, hf_out_rope_k = hf_apply_rotary_pos_emb( + q_tensor, k_tensor, cos_tensor, sin_tensor, position_ids_tensor + ) + hf_out_rope_q = hf_out_rope_q.transpose(1, 2) # re-transpose to match levanter + hf_out_rope_k = hf_out_rope_k.transpose(1, 2) + _assert_equal_out(levanter_out_rope_q, hf_out_rope_q) + _assert_equal_out(levanter_out_rope_k, hf_out_rope_k) + +def _assert_equal_out(hax_out, torch_out: torch.Tensor): + assert np.isclose( + torch_out.numpy(), np.array(hax_out.array), rtol=1e-2, atol=1e-2 + ).all(), f"{torch_out} != {hax_out}" def _get_llama_config() -> LlamaConfig: vocab_size = 32000 - hidden_dim = 2048 - num_heads = 16 - num_kv_heads = 16 + hidden_dim = 48 + num_heads = 8 + num_kv_heads = 8 rope_scaling = { "type": "linear", "factor": 2.0, From 442b93d2575edff63fd48dbd2b9bdd6a1ffe0561 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 7 Aug 2023 21:32:24 -0700 Subject: [PATCH 12/57] update llama and test --- src/levanter/models/llama.py | 29 ++++++++++++------------ tests/test_llama.py | 43 +++++++++++------------------------- 2 files changed, 27 insertions(+), 45 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index bbb61f1f1..2519e5bd6 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass, field -from typing import Callable, Dict, Optional, Type, Tuple +from dataclasses import dataclass +from typing import Callable, Optional, Tuple import equinox as eqx import jax @@ -11,7 +11,6 @@ from haliax import Axis, NamedArray from haliax.jax_utils import named_call -from levanter.compat.hf_checkpoints import HFCheckpointConverter, LmWithHfSerializationMixin from levanter.compat.torch_serialization import StateDictSerializationMixin from levanter.models.lm_model import LmConfig @@ -78,6 +77,7 @@ def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = False) if isinstance(activation_fn, str): activation_fn = ACT2FN[activation_fn] act = activation_fn # type: ignore + return LlamaMlp(gate_proj, up_proj, down_proj, act) @named_call def __call__(self, x: NamedArray) -> NamedArray: @@ -242,7 +242,9 @@ def _get_rotary_emb(config: LlamaConfig, head_dim: int, max_position_embeddings: if scaling_type == "linear": return LlamaLinearScalingRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) elif scaling_type == "dynamic": - return LlamaDynamicNTKScalingRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) + return LlamaDynamicNTKScalingRotaryEmbedding( + head_dim, max_position_embeddings, scaling_factor=scaling_factor + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") @@ -263,18 +265,15 @@ def _apply_rotary_pos_emb( k: NamedArray, # [batch, seq_len, kv_heads, head_size] cos: jnp.ndarray, # [1, 1, seq_len, head_size] sin: jnp.ndarray, # [1, 1, seq_len, head_size] - position_ids: jnp.ndarray # [bs, seq_len] + position_ids: jnp.ndarray, # [bs, seq_len] ) -> Tuple[NamedArray, NamedArray]: - 
"""Applies rotary position embedding to q and k. - Note that all the multiplication below are element-wise, so I don't find it - helpful to write in haliax. - """ + """Applies rotary position embedding to q and k.""" cos = jnp.squeeze(jnp.squeeze(cos, axis=1), axis=0) # from [1, 1, seq_len, dim] to [seq_len, dim] sin = jnp.squeeze(jnp.squeeze(sin, axis=1), axis=0) - cos = jnp.expand_dims(cos[position_ids], axis=2) # [batch, seq_len, 1, head_size] - sin = jnp.expand_dims(sin[position_ids], axis=2) # [batch, seq_len, 1, head_size] - q_embed = (q.array * cos) + (_rotate_half(q).array * sin) # [batch, seq_len, heads, head_size] - k_embed = (k.array * cos) + (_rotate_half(k).array * sin) # [batch, seq_len, kv_heads, head_size] - q_embed = hax.named(q_embed, q.axes) - k_embed = hax.named(k_embed, k.axes) + cos = cos[position_ids] # [batch, seq_len, head_size] + sin = sin[position_ids] # [batch, seq_len, head_size] + cos = hax.named(cos, ("batch", "position", "head_size")) + sin = hax.named(sin, ("batch", "position", "head_size")) + q_embed = hax.multiply(q, cos) + hax.multiply(_rotate_half(q), sin) + k_embed = hax.multiply(k, cos) + hax.multiply(_rotate_half(k), sin) return q_embed, k_embed diff --git a/tests/test_llama.py b/tests/test_llama.py index b763df6cc..68299410b 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,31 +1,29 @@ -import numpy as np - -import haliax as hax -import jax.numpy as jnp import torch from jax import random # The latter 2 classes are only available in HuggingFace's transformers 4.30.0 or later from transformers.models.llama.modeling_llama import ( - LlamaRotaryEmbedding as HFLlamaRotaryEmbedding, - LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, - LlamaAttention as HFLlamaAttention, - apply_rotary_pos_emb as hf_apply_rotary_pos_emb, - rotate_half as hf_rotate_half, ) +from transformers.models.llama.modeling_llama import ( + LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, +) +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as hf_apply_rotary_pos_emb +from transformers.models.llama.modeling_llama import rotate_half as hf_rotate_half + +import haliax as hax + from levanter.models.llama import ( LlamaConfig, - LlamaRotaryEmbedding, - LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding, - LlamaAttention, - _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb, - _rotate_half as levanter_rotate_half, + LlamaLinearScalingRotaryEmbedding, + LlamaRotaryEmbedding, ) +from levanter.models.llama import _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb +from levanter.models.llama import _rotate_half as levanter_rotate_half -""" def test_llama_rotary_embedding(): llama_config = _get_llama_config() hidden_dim = llama_config.hidden_dim @@ -64,26 +62,11 @@ def test_levanter_against_hf(levanter_class, hf_class): ) - -def test_llama_attention(): - llama_config = _get_llama_config() - key = random.PRNGKey(4) - levanter_llama_att = LlamaAttention.init(config=llama_config, key=key) - seq_len = llama_config.seq_len - - input_ids = hax.arange(llama_config.Pos, dtype=jnp.int32) - causal_mask = hax.nn.attention.causal_mask(llama_config.Pos, llama_config.KeyPos) - position_ids = random.randint(random.PRNGKey(0), (1, seq_len), 0, llama_config.Pos.size) - levanter_output = 
levanter_llama_att(input_ids, mask=causal_mask, position_ids=position_ids) -""" - - def test_apply_rotary_pos_emb(): llama_config = _get_llama_config() Pos = llama_config.Pos Heads = llama_config.Heads - KVHeads = llama_config.KVHeads HeadSize = llama_config.HeadSize Batch = hax.Axis("batch", 2) From 40bb7b462a772e6100e81518cdaa6187a3ca8e46 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 13 Aug 2023 14:20:16 -0700 Subject: [PATCH 13/57] update llama and test --- src/levanter/models/llama.py | 56 +++++++++++++++++------------------- tests/test_llama.py | 23 +++++++++------ 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 2519e5bd6..2d354683a 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -105,20 +105,20 @@ class LlamaLMHeadModel(eqx.Module): class LlamaRotaryEmbedding(eqx.Module): - dim: int - max_position_embeddings: int = 2048 + Embed: Axis + Pos: Axis base: float = 10000 inv_freq: jnp.ndarray = eqx.static_field() cos_cached: jnp.ndarray = eqx.static_field() sin_cached: jnp.ndarray = eqx.static_field() max_seq_len_cached: int = eqx.static_field() - def __init__(self, dim, max_position_embeddings=2048, base=10000): - self.dim = dim - self.max_position_embeddings = max_position_embeddings + def __init__(self, Embed: Axis, Pos: Axis, base: int = 10000): + self.Embed = Embed + self.Pos = Pos self.base = base - self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2) / self.dim)) - self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=self.max_position_embeddings) + self.inv_freq = 1.0 / (self.base ** (hax.arange(Embed.resize(Embed.size // 2), step=2) / Embed.size)).array + self.cos_cached, self.sin_cached = self._set_cos_sin_cache(Pos.size) def _get_positional_ids(self): """A helper function for the convenience of extending to two sub-classes @@ -155,9 +155,9 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): scaling_factor: float = 1.0 - def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + def __init__(self, Embed: Axis, Pos: Axis, base=10000, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base) + super().__init__(Embed, Pos, base) def _get_positional_ids(self): """Here we overwrite the function in the base class to implement linear scaling.""" @@ -169,17 +169,17 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): scaling_factor: float = 1.0 - def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + def __init__(self, Embed, Pos, base=10000, scaling_factor=1.0): self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base) + super().__init__(Embed, Pos, base) def _get_positional_ids(self): """Here we overwrite the function in the base class. Here it adjusts the frequency base dynamically according to the sequence length. 
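
# A minimal numeric sketch of the dynamic-NTK base adjustment implemented just
# below: once the cached sequence length exceeds the trained context, the rotary
# base is enlarged so the lowest frequencies still span the longer window. The
# concrete numbers (dim=64, base=10000.0, factor=2.0, trained_len=2048,
# seq_len=4096) are illustrative assumptions, not values taken from this config.
def ntk_scaled_base(base: float, dim: int, factor: float, trained_len: int, seq_len: int) -> float:
    if seq_len <= trained_len:
        return base
    return base * ((factor * seq_len / trained_len) - (factor - 1)) ** (dim / (dim - 2))

assert ntk_scaled_base(10000.0, 64, 2.0, 2048, 2048) == 10000.0
print(ntk_scaled_base(10000.0, 64, 2.0, 2048, 4096))  # ~31,000: the base roughly triples at 2x context
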
""" - if self.max_seq_len_cached > self.max_position_embeddings: + if self.max_seq_len_cached > self.Pos.size: base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + (self.scaling_factor * self.max_seq_len_cached / self.Pos.size) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) self.inv_freq = 1.0 / (base ** (jnp.arange(0, self.dim, 2) / self.dim)) @@ -198,14 +198,12 @@ class LlamaAttention(StateDictSerializationMixin, eqx.Module): def init(config: LlamaConfig, *, key) -> "LlamaAttention": use_bias = config.use_bias Embed = config.Embed - head_dim = config.HeadSize.size - max_position_embeddings = config.max_position_embeddings k_q, k_k, k_v, k_o = jrandom.split(key, 4) q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias) k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias) v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias) o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias) - rotary_emb = _get_rotary_emb(config, head_dim, max_position_embeddings) + rotary_emb = _get_rotary_emb(config) return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @named_call @@ -233,31 +231,29 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids): return attn_output -def _get_rotary_emb(config: LlamaConfig, head_dim: int, max_position_embeddings: int) -> LlamaRotaryEmbedding: +def _get_rotary_emb(config: LlamaConfig) -> LlamaRotaryEmbedding: + Embed = config.Embed + Pos = config.Pos if config.rope_scaling is None: - return LlamaRotaryEmbedding(head_dim, max_position_embeddings) + return LlamaRotaryEmbedding(Embed, Pos) else: scaling_type = config.rope_scaling["type"] scaling_factor = config.rope_scaling["factor"] if scaling_type == "linear": - return LlamaLinearScalingRotaryEmbedding(head_dim, max_position_embeddings, scaling_factor=scaling_factor) + return LlamaLinearScalingRotaryEmbedding(Embed, Pos, scaling_factor=scaling_factor) elif scaling_type == "dynamic": - return LlamaDynamicNTKScalingRotaryEmbedding( - head_dim, max_position_embeddings, scaling_factor=scaling_factor - ) + return LlamaDynamicNTKScalingRotaryEmbedding(Embed, Pos, scaling_factor=scaling_factor) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _rotate_half(x: NamedArray) -> NamedArray: - """Rotates half of the hidden dims of the input and concatenates them. - This is difficult to Haliax, so we do it in regular array ops. 
- """ - x_array = x.array - x1 = x_array[..., : x_array.shape[-1] // 2] - x2 = x_array[..., x_array.shape[-1] // 2 :] - output = jnp.concatenate((-x2, x1), axis=-1) - return hax.named(output, x.axes) + """Rotates half of the hidden dims of the input and concatenates them.""" + HeadSize = x.axes[-1] + x1 = x[HeadSize, : HeadSize.size // 2] + x2 = x[HeadSize, HeadSize.size // 2 :] + out = hax.concatenate(HeadSize, (-x2, x1)) + return out def _apply_rotary_pos_emb( diff --git a/tests/test_llama.py b/tests/test_llama.py index 68299410b..83eaae56d 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,3 +1,4 @@ +import numpy as np import torch from jax import random @@ -26,8 +27,10 @@ def test_llama_rotary_embedding(): llama_config = _get_llama_config() - hidden_dim = llama_config.hidden_dim - seq_len = llama_config.seq_len + Embed = llama_config.Embed + Pos = llama_config.Pos + hidden_dim = Embed.size + seq_len = Pos.size scaling_factor = llama_config.rope_scaling["factor"] key = random.PRNGKey(0) device = "cpu" @@ -45,20 +48,24 @@ def test_levanter_against_hf(levanter_class, hf_class): # test LlamaRotaryEmbedding test_levanter_against_hf( - levanter_class=LlamaRotaryEmbedding(dim=hidden_dim), - hf_class=HFLlamaRotaryEmbedding(dim=hidden_dim, device=device), + levanter_class=LlamaRotaryEmbedding(Embed=Embed, Pos=Pos), + hf_class=HFLlamaRotaryEmbedding(dim=hidden_dim, max_position_embeddings=seq_len, device=device), ) # test LlamaLinearScalingRotaryEmbedding test_levanter_against_hf( - levanter_class=LlamaLinearScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor), - hf_class=HFLlamaLinearScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor, device=device), + levanter_class=LlamaLinearScalingRotaryEmbedding(Embed=Embed, Pos=Pos, scaling_factor=scaling_factor), + hf_class=HFLlamaLinearScalingRotaryEmbedding( + dim=hidden_dim, max_position_embeddings=seq_len, scaling_factor=scaling_factor, device=device + ), ) # test LlamaDynamicNTKScalingRotaryEmbedding test_levanter_against_hf( - levanter_class=LlamaDynamicNTKScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor), - hf_class=HFLlamaDynamicNTKScalingRotaryEmbedding(dim=hidden_dim, scaling_factor=scaling_factor, device=device), + levanter_class=LlamaDynamicNTKScalingRotaryEmbedding(Embed=Embed, Pos=Pos, scaling_factor=scaling_factor), + hf_class=HFLlamaDynamicNTKScalingRotaryEmbedding( + dim=hidden_dim, max_position_embeddings=seq_len, scaling_factor=scaling_factor, device=device + ), ) From 466c1dbedca05d352555cb99a42a7674d9572a88 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 13 Aug 2023 15:48:52 -0700 Subject: [PATCH 14/57] Finish Llama Attention --- src/levanter/models/llama.py | 27 +++++++++++----- tests/test_llama.py | 61 +++++++++++++++++++++++++++++++++--- 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 2d354683a..d3d64588f 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -200,8 +200,9 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": Embed = config.Embed k_q, k_k, k_v, k_o = jrandom.split(key, 4) q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias) - k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias) - v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias) + # TODO: double check if we should use Heads or 
KV_HEADS here + k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias) + v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias) o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias) rotary_emb = _get_rotary_emb(config) return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @@ -217,14 +218,23 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids): q, k = _apply_rotary_pos_emb(q, k, cos, sin, position_ids) scale = jax.lax.rsqrt(float(self.config.HeadSize.size)) + + # do this first to help keep FP values small + q = q * scale - attn_weights = hax.dot("head_size", q, k) * scale - attn_weights = attn_weights + mask + q = q.astype(jnp.float32) + k = k.astype(jnp.float32) + k = k.rename({"position": "key_position"}) - # upcast attention to fp32. This is default for Llama Attention - attn_weights = attn_weights.astype(jnp.float32) + attn_scores = hax.dot("head_size", q, k) + + if mask is not None: + attn_scores = attn_scores + (1.0 - mask) * -1e9 + + attn_scores = attn_scores.astype(jnp.float32) + attn_weights = hnn.softmax(attn_scores, axis="key_position").astype(x.dtype) + # There's no dropout in llama attention, compared with Gpt2 attention - attn_weights = hnn.softmax(attn_weights, axis="key_position").astype(q.dtype) attn_output = hax.dot("key_position", attn_weights, v) attn_output = self.o_proj(attn_output) @@ -232,7 +242,8 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids): def _get_rotary_emb(config: LlamaConfig) -> LlamaRotaryEmbedding: - Embed = config.Embed + # Note that the embedding here is HeadSize, not the full Embed + Embed = config.HeadSize Pos = config.Pos if config.rope_scaling is None: return LlamaRotaryEmbedding(Embed, Pos) diff --git a/tests/test_llama.py b/tests/test_llama.py index 83eaae56d..c97e97f3c 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -3,6 +3,8 @@ from jax import random # The latter 2 classes are only available in HuggingFace's transformers 4.30.0 or later +from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention +from transformers.models.llama.configuration_llama import LlamaConfig as HFLlamaConfig from transformers.models.llama.modeling_llama import ( LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, ) @@ -16,6 +18,7 @@ import haliax as hax from levanter.models.llama import ( + LlamaAttention, LlamaConfig, LlamaDynamicNTKScalingRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, @@ -112,6 +115,38 @@ def test_apply_rotary_pos_emb(): _assert_equal_out(levanter_out_rope_k, hf_out_rope_k) +def test_llama_attention(): + config = _get_llama_config() + Embed = config.Embed + Pos = config.Pos + Heads = config.Heads + HeadSize = config.HeadSize + Batch = hax.Axis("batch", 2) + x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed)) + mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) + position_ids = random.randint(random.PRNGKey(2), (Batch.size, Pos.size), 0, Pos.size) + # generate a random key that can be splitted into 4 + key = random.PRNGKey(4) + + levanter_attention = LlamaAttention.init(config=config, key=key) + levanter_out = levanter_attention(x, mask, position_ids) + + hf_config = _levanter_config_to_hf_config(config) + hf_attention = HFLlamaAttention(config=hf_config) # (seq_len, kv_seq_len) + # convert attention_mask's shape from (seq_len, kv_seq_len) to 
(batch, 1, seq_len, kv_seq_len) + attention_mask = _hax_to_tensor(mask) + attention_mask = attention_mask.reshape(1, 1, config.Pos.size, config.KeyPos.size).repeat(Batch.size, 1, 1, 1) + + hf_out, _, _ = hf_attention( + hidden_states=_hax_to_tensor(x), + attention_mask=attention_mask, + position_ids=torch.from_numpy(np.array(position_ids)), + ) + + # assert the same shape + assert levanter_out.array.shape == hf_out.shape, f"{levanter_out.shape} != {hf_out.shape}" + + def _assert_equal_out(hax_out, torch_out: torch.Tensor): assert np.isclose( torch_out.numpy(), np.array(hax_out.array), rtol=1e-2, atol=1e-2 @@ -119,18 +154,36 @@ def _assert_equal_out(hax_out, torch_out: torch.Tensor): def _get_llama_config() -> LlamaConfig: - vocab_size = 32000 - hidden_dim = 48 - num_heads = 8 - num_kv_heads = 8 + vocab_size = 1000 + seq_len = 128 + hidden_dim = 16 + num_heads = 4 + num_kv_heads = 4 rope_scaling = { "type": "linear", "factor": 2.0, } return LlamaConfig( + seq_len=seq_len, vocab_size=vocab_size, hidden_dim=hidden_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, rope_scaling=rope_scaling, + max_position_embeddings=seq_len, ) + + +def _levanter_config_to_hf_config(levanter_config: LlamaConfig) -> HFLlamaConfig: + return HFLlamaConfig( + vocab_size=levanter_config.vocab_size, + max_position_embeddings=levanter_config.seq_len, + hidden_size=levanter_config.hidden_dim, + num_attention_heads=levanter_config.num_heads, + num_key_value_heads=levanter_config.num_kv_heads, + rope_scaling=levanter_config.rope_scaling, + ) + + +def _hax_to_tensor(x: hax.NamedArray) -> torch.Tensor: + return torch.from_numpy(np.array(x.array)) From f9db04953b0eb93664e4c458b64670342e300484 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 13 Aug 2023 17:50:41 -0700 Subject: [PATCH 15/57] Finish implementing LlamaLMHeadModel --- src/levanter/models/llama.py | 151 +++++++++++++++++++++++++++++++---- 1 file changed, 134 insertions(+), 17 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index d3d64588f..7a69293a3 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -10,6 +10,7 @@ import haliax.nn as hnn from haliax import Axis, NamedArray from haliax.jax_utils import named_call +from haliax.nn.scan import Stacked from levanter.compat.torch_serialization import StateDictSerializationMixin from levanter.models.lm_model import LmConfig @@ -43,6 +44,11 @@ class LlamaConfig: activation_function: str = "silu" max_position_embeddings: int = 2048 initializer_range: float = 0.02 + layer_norm_epsilon: float = 1e-5 + + gradient_checkpointing: bool = True + gradient_checkpointing_block_size: int = 5 + use_bias: bool = True rope_scaling: Optional[dict] = None @@ -88,22 +94,6 @@ def __call__(self, x: NamedArray) -> NamedArray: return outputs -class LlamaBlock(StateDictSerializationMixin, eqx.Module): - pass - - -class LlamaTransformer(StateDictSerializationMixin, eqx.Module): - pass - - -class LlamaEmbeddings(StateDictSerializationMixin, eqx.Module): - pass - - -class LlamaLMHeadModel(eqx.Module): - pass - - class LlamaRotaryEmbedding(eqx.Module): Embed: Axis Pos: Axis @@ -208,7 +198,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids): + def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *): q = self.q_proj(x) # TODO: rearrange and possibly rename k = self.k_proj(x) v = 
self.v_proj(x) @@ -241,6 +231,133 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids): return attn_output +class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): + config: LlamaConfig = eqx.static_field() + attn: LlamaAttention + mlp: LlamaMLP + ln_1: hnn.LayerNorm # input layernorm + ln_2: hnn.LayerNorm # post attention layernorm + + @staticmethod + def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": + k_attn, k_mlp = jrandom.split(key, 2) + attn = LlamaAttention.init(config, key=k_attn) + mlp = LlamaMLP.init(config, key=key) + ln_1 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=k_attn) + ln_2 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=k_attn) + + return LlamaDecoderLayer(config, attn, mlp, ln_1, ln_2) + + @named_call + def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *): + residual = x + x = self.ln_1(x) + + # self attention and skip connection + attn_output = self.attn(x=x, mask=mask, position_ids=position_ids) + x = residual + attn_output + + # MLP and skip connection + residual = x + x = self.ln_2(x) + mlp_output = self.mlp(x) + output = residual + mlp_output + return output + + +class LlamaTransformer(StateDictSerializationMixin, eqx.Module): + config: LlamaConfig = eqx.static_field() + layers: Stacked[LlamaDecoderLayer] + ln_f: hnn.LayerNorm + + @staticmethod + def init(config: LlamaConfig, *, key) -> "LlamaTransformer": + layers = Stacked.init(config.Layers, LlamaDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)( + config, shaped_rng_split(key, config.num_layers), + ) + ln_f = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=key) + + return LlamaTransformer(config, layers, ln_f) + + @named_call + def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray], *) -> NamedArray: + x = self.layers.fold(x, attn_mask=attn_mask, hax.arange(self.config.Layers)) + x = self.ln_f(x) + + return x + + +class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): + """Similar to GPT2 Embedding but without dropout""" + Vocab: Axis = eqx.static_field() + config: LlamaConfig = eqx.static_field() + + token_embeddings: NamedArray + position_embeddings: NamedArray + + @staticmethod + def init(Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaEmbedding": + k_wte, k_wpe = jrandom.split(key, 2) + + token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) + position_embeddings = hax.random.normal(k_wpe, (config.Pos, config.Embed)) * (config.initializer_range / 2) + + return LlamaEmbedding(Vocab, config, token_embeddings, position_embeddings) + + @named_call + def embed(self, input_ids, *): + input_embeds = self.token_embeddings.take("vocab", input_ids) + position_embeds = self.position_embeddings + + x = input_embeds + position_embeds + + return x + + def unembed(self, x: NamedArray): + return hax.dot("embed", x, self.token_embeddings) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"token_embeddings": "wte.weight", "position_embeddings": "wpe.weight"} + + +class LlamaLMHeadModel(StateDictSerializationMixin, eqx.Module): + transformer: LlamaTransformer + embeddings: LlamaEmbedding + + @property + def config(self): + return self.transformer.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + @property + def Pos(self) -> Axis: + 
return self.config.Pos + + @classmethod + def init(cls, Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2LMHeadModel": + k_t, k_embeddings = jrandom.split(key, 2) + transformer = LlamaTransformer.init(config, key=k_t) + embeddings = LlamaEmbedding.init(Vocab, config, key=k_embeddings) + + return LlamaLMHeadModel(transformer, embeddings) + + def __call__( + self, input_ids: NamedArray, attn_mask: Optional[NamedArray], position_ids, * + ) -> NamedArray: + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask, position_ids=position_ids) + lm_logits = self.embeddings.unembed(x) + + return lm_logits + + def _get_rotary_emb(config: LlamaConfig) -> LlamaRotaryEmbedding: # Note that the embedding here is HeadSize, not the full Embed Embed = config.HeadSize From 6781a8ce9d288c6ef35feff6e227b8aedb0af013 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 13 Aug 2023 17:54:07 -0700 Subject: [PATCH 16/57] fix build --- src/levanter/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 7a69293a3..e4518404d 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -14,6 +14,7 @@ from levanter.compat.torch_serialization import StateDictSerializationMixin from levanter.models.lm_model import LmConfig +from levanter.models.gpt2 import ACT2FN @LmConfig.register_subclass("llama") @@ -208,7 +209,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *): q, k = _apply_rotary_pos_emb(q, k, cos, sin, position_ids) scale = jax.lax.rsqrt(float(self.config.HeadSize.size)) - + # do this first to help keep FP values small q = q * scale From abb70fa9da64d9339e87e15470c5dd51fb1e8d83 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 13 Aug 2023 18:06:01 -0700 Subject: [PATCH 17/57] fix build --- src/levanter/models/llama.py | 20 ++++++++++---------- tests/test_llama.py | 6 +----- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index e4518404d..e34d6d585 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -13,8 +13,8 @@ from haliax.nn.scan import Stacked from levanter.compat.torch_serialization import StateDictSerializationMixin -from levanter.models.lm_model import LmConfig from levanter.models.gpt2 import ACT2FN +from levanter.models.lm_model import LmConfig @LmConfig.register_subclass("llama") @@ -199,7 +199,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *): + def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *args): q = self.q_proj(x) # TODO: rearrange and possibly rename k = self.k_proj(x) v = self.v_proj(x) @@ -250,7 +250,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": return LlamaDecoderLayer(config, attn, mlp, ln_1, ln_2) @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *): + def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *args): residual = x x = self.ln_1(x) @@ -274,15 +274,16 @@ class LlamaTransformer(StateDictSerializationMixin, eqx.Module): @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaTransformer": layers = Stacked.init(config.Layers, LlamaDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)( - config, 
shaped_rng_split(key, config.num_layers), + config, + shaped_rng_split(key, config.num_layers), ) ln_f = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=key) return LlamaTransformer(config, layers, ln_f) @named_call - def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray], *) -> NamedArray: - x = self.layers.fold(x, attn_mask=attn_mask, hax.arange(self.config.Layers)) + def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray], *args) -> NamedArray: + x = self.layers.fold(x, attn_mask=attn_mask, position_ids=hax.arange(self.config.Layers)) x = self.ln_f(x) return x @@ -290,6 +291,7 @@ def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray], *) -> NamedAr class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): """Similar to GPT2 Embedding but without dropout""" + Vocab: Axis = eqx.static_field() config: LlamaConfig = eqx.static_field() @@ -306,7 +308,7 @@ def init(Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaEmbedding": return LlamaEmbedding(Vocab, config, token_embeddings, position_embeddings) @named_call - def embed(self, input_ids, *): + def embed(self, input_ids, *args): input_embeds = self.token_embeddings.take("vocab", input_ids) position_embeds = self.position_embeddings @@ -349,9 +351,7 @@ def init(cls, Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2LMHeadModel": return LlamaLMHeadModel(transformer, embeddings) - def __call__( - self, input_ids: NamedArray, attn_mask: Optional[NamedArray], position_ids, * - ) -> NamedArray: + def __call__(self, input_ids: NamedArray, attn_mask: Optional[NamedArray], position_ids, *args) -> NamedArray: x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask, position_ids=position_ids) lm_logits = self.embeddings.unembed(x) diff --git a/tests/test_llama.py b/tests/test_llama.py index c97e97f3c..61671991d 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,10 +1,8 @@ import numpy as np import torch from jax import random - -# The latter 2 classes are only available in HuggingFace's transformers 4.30.0 or later -from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention from transformers.models.llama.configuration_llama import LlamaConfig as HFLlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention from transformers.models.llama.modeling_llama import ( LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, ) @@ -119,8 +117,6 @@ def test_llama_attention(): config = _get_llama_config() Embed = config.Embed Pos = config.Pos - Heads = config.Heads - HeadSize = config.HeadSize Batch = hax.Axis("batch", 2) x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed)) mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) From f21cd1cee13aeb68379adc271037923da0e7fa01 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Tue, 15 Aug 2023 19:54:18 -0700 Subject: [PATCH 18/57] remove max_position_embeddings --- src/levanter/models/llama.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index e34d6d585..475e78a71 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import equinox as eqx import jax @@ -9,7 +9,7 @@ import haliax as hax import haliax.nn as hnn from haliax import 
Axis, NamedArray -from haliax.jax_utils import named_call +from haliax.jax_utils import named_call, shaped_rng_split from haliax.nn.scan import Stacked from levanter.compat.torch_serialization import StateDictSerializationMixin @@ -31,7 +31,6 @@ class LlamaConfig: num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32. num_kv_heads (int, optional): number of key/value heads needed for Grouped Query Attention. Defaults to 32. activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". - max_position_embeddings (int, optional): maximum length of the position embedding. Defaults to 2048. rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. """ @@ -43,7 +42,6 @@ class LlamaConfig: num_heads: int = 32 num_kv_heads: int = 32 activation_function: str = "silu" - max_position_embeddings: int = 2048 initializer_range: float = 0.02 layer_norm_epsilon: float = 1e-5 @@ -235,7 +233,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *arg class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() attn: LlamaAttention - mlp: LlamaMLP + mlp: LlamaMlp ln_1: hnn.LayerNorm # input layernorm ln_2: hnn.LayerNorm # post attention layernorm @@ -243,7 +241,7 @@ class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": k_attn, k_mlp = jrandom.split(key, 2) attn = LlamaAttention.init(config, key=k_attn) - mlp = LlamaMLP.init(config, key=key) + mlp = LlamaMlp.init(config, key=key) ln_1 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=k_attn) ln_2 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=k_attn) @@ -344,7 +342,7 @@ def Pos(self) -> Axis: return self.config.Pos @classmethod - def init(cls, Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2LMHeadModel": + def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel": k_t, k_embeddings = jrandom.split(key, 2) transformer = LlamaTransformer.init(config, key=k_t) embeddings = LlamaEmbedding.init(Vocab, config, key=k_embeddings) From 6e222b7297250126cc17a8ab5c834c12ded592bf Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Tue, 15 Aug 2023 21:00:58 -0700 Subject: [PATCH 19/57] Fix issues found from testing --- src/levanter/models/llama.py | 17 ++++++++------ tests/test_llama.py | 43 +++++++++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 475e78a71..5eaa279ac 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -58,7 +58,7 @@ class LlamaConfig: Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads)) Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) - Mlp = property(lambda self: Axis(name="mlp", size=self.hidden_dim * self.mlp_scale)) + Mlp = property(lambda self: Axis(name="mlp", size=self.hidden_dim)) # TODO: shall we multiply with mlp_scale? 
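
# A quick sketch of how the integer config fields map onto axis sizes. In the
# reference HF LlamaConfig the MLP width is a separate intermediate_size (11008
# for Llama-7B), so one plausible resolution of the TODO above is to size the Mlp
# axis from intermediate_dim rather than hidden_dim. The numbers below are the
# Llama-7B defaults and are used only for illustration.
hidden_dim, num_heads, intermediate_dim = 4096, 32, 11008
head_size = hidden_dim // num_heads  # HeadSize axis -> 128
print(hidden_dim, head_size, intermediate_dim)  # Embed=4096, HeadSize=128, candidate Mlp=11008
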
HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) @@ -240,10 +240,11 @@ class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": k_attn, k_mlp = jrandom.split(key, 2) + attn = LlamaAttention.init(config, key=k_attn) - mlp = LlamaMlp.init(config, key=key) - ln_1 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=k_attn) - ln_2 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=k_attn) + mlp = LlamaMlp.init(config.Embed, config.Mlp, config.activation_function, key=k_mlp, use_bias=config.use_bias) + ln_1 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + ln_2 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) return LlamaDecoderLayer(config, attn, mlp, ln_1, ln_2) @@ -271,17 +272,19 @@ class LlamaTransformer(StateDictSerializationMixin, eqx.Module): @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaTransformer": + # TODO: here it reports an error that is related to _get_rotary_emb() in LlamaAttention + # TypeError: Output from batched function Axis(name='head_size', size=4) with type is not a valid JAX type layers = Stacked.init(config.Layers, LlamaDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)( config, - shaped_rng_split(key, config.num_layers), + key=shaped_rng_split(key, config.num_layers), ) - ln_f = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias, key=key) + ln_f = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) return LlamaTransformer(config, layers, ln_f) @named_call def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray], *args) -> NamedArray: - x = self.layers.fold(x, attn_mask=attn_mask, position_ids=hax.arange(self.config.Layers)) + x = self.layers.fold(x, mask=attn_mask, position_ids=hax.arange(self.config.Layers)) x = self.ln_f(x) return x diff --git a/tests/test_llama.py b/tests/test_llama.py index 61671991d..3b290a3ab 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -18,8 +18,10 @@ from levanter.models.llama import ( LlamaAttention, LlamaConfig, + LlamaDecoderLayer, LlamaDynamicNTKScalingRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, + LlamaLMHeadModel, LlamaRotaryEmbedding, ) from levanter.models.llama import _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb @@ -115,12 +117,7 @@ def test_apply_rotary_pos_emb(): def test_llama_attention(): config = _get_llama_config() - Embed = config.Embed - Pos = config.Pos - Batch = hax.Axis("batch", 2) - x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed)) - mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) - position_ids = random.randint(random.PRNGKey(2), (Batch.size, Pos.size), 0, Pos.size) + x, mask, position_ids = _get_random_inputs(config) # generate a random key that can be splitted into 4 key = random.PRNGKey(4) @@ -131,7 +128,7 @@ def test_llama_attention(): hf_attention = HFLlamaAttention(config=hf_config) # (seq_len, kv_seq_len) # convert attention_mask's shape from (seq_len, kv_seq_len) to (batch, 1, seq_len, kv_seq_len) attention_mask = _hax_to_tensor(mask) - attention_mask = attention_mask.reshape(1, 1, config.Pos.size, config.KeyPos.size).repeat(Batch.size, 1, 1, 1) + attention_mask = attention_mask.reshape(1, 1, config.Pos.size, 
config.KeyPos.size).repeat(x.axes[0].size, 1, 1, 1) hf_out, _, _ = hf_attention( hidden_states=_hax_to_tensor(x), @@ -143,6 +140,27 @@ def test_llama_attention(): assert levanter_out.array.shape == hf_out.shape, f"{levanter_out.shape} != {hf_out.shape}" +def test_llama_decoder_layer(): + llama_config = _get_llama_config() + key = random.PRNGKey(0) + llama_decoder_layer = LlamaDecoderLayer.init(config=llama_config, key=key) + x, mask, position_ids = _get_random_inputs(llama_config) + levanter_out = llama_decoder_layer(x, mask, position_ids) + assert levanter_out.array.shape == (x.axes[0].size, llama_config.seq_len, llama_config.hidden_dim) + + +def test_llama_lm_head_model(): + llama_config = _get_llama_config() + Vocab = hax.Axis("vocab", llama_config.vocab_size) + # generate a key that can be splitted into 2 + llama_model = LlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) + # generate a random input + x, mask, position_ids = _get_random_inputs(llama_config) + + levanter_out = llama_model(x, mask, position_ids) + assert levanter_out.array.shape == (Batch.size, Pos.size, llama_config.Vocab.size) + + def _assert_equal_out(hax_out, torch_out: torch.Tensor): assert np.isclose( torch_out.numpy(), np.array(hax_out.array), rtol=1e-2, atol=1e-2 @@ -166,10 +184,19 @@ def _get_llama_config() -> LlamaConfig: num_heads=num_heads, num_kv_heads=num_kv_heads, rope_scaling=rope_scaling, - max_position_embeddings=seq_len, ) +def _get_random_inputs(config: LlamaConfig): + Embed = config.Embed + Pos = config.Pos + Batch = hax.Axis("batch", 2) + x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed)) + mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) + position_ids = random.randint(random.PRNGKey(2), (Batch.size, Pos.size), 0, Pos.size) + return x, mask, position_ids + + def _levanter_config_to_hf_config(levanter_config: LlamaConfig) -> HFLlamaConfig: return HFLlamaConfig( vocab_size=levanter_config.vocab_size, From 182327e6d49a6ba92ef3ef4482e3aacd4c9ef1d6 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sat, 19 Aug 2023 17:46:17 -0700 Subject: [PATCH 20/57] Fix issues found from end-to-end tests --- src/levanter/models/llama.py | 49 ++++++++++++++++-------- tests/test_llama.py | 73 +++++++++++++++--------------------- 2 files changed, 65 insertions(+), 57 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 5eaa279ac..13d8469bc 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -94,8 +94,8 @@ def __call__(self, x: NamedArray) -> NamedArray: class LlamaRotaryEmbedding(eqx.Module): - Embed: Axis - Pos: Axis + Embed: Axis = eqx.static_field() + Pos: Axis = eqx.static_field() base: float = 10000 inv_freq: jnp.ndarray = eqx.static_field() cos_cached: jnp.ndarray = eqx.static_field() @@ -197,7 +197,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *args): + def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids: NamedArray, *args): q = self.q_proj(x) # TODO: rearrange and possibly rename k = self.k_proj(x) v = self.v_proj(x) @@ -249,7 +249,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": return LlamaDecoderLayer(config, attn, mlp, ln_1, ln_2) @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids, *args): + def __call__(self, x: NamedArray, mask: 
Optional[NamedArray], position_ids: NamedArray, *args): residual = x x = self.ln_1(x) @@ -283,8 +283,8 @@ def init(config: LlamaConfig, *, key) -> "LlamaTransformer": return LlamaTransformer(config, layers, ln_f) @named_call - def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray], *args) -> NamedArray: - x = self.layers.fold(x, mask=attn_mask, position_ids=hax.arange(self.config.Layers)) + def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray], position_ids: NamedArray, *args) -> NamedArray: + x = self.layers.fold(x, mask=attn_mask, position_ids=position_ids) x = self.ln_f(x) return x @@ -352,7 +352,24 @@ def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel": return LlamaLMHeadModel(transformer, embeddings) - def __call__(self, input_ids: NamedArray, attn_mask: Optional[NamedArray], position_ids, *args) -> NamedArray: + def __call__( + self, + input_ids: NamedArray, + attn_mask: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = None, + *args, + ) -> NamedArray: + """ + Args: + input_ids (NamedArray): [batch, position] + Indices of input sequence tokens in the vocabulary. + attn_mask (NamedArray, optional): [batch, position, seq_len] + Mask to avoid performing attention on the padding token indices of the encoder input. + position_ids (NamedArray, optional): [batch, position] + Indices of positions of each input sequence tokens in the position embeddings. + """ + if position_ids is None: + position_ids = hax.arange(self.Pos).broadcast_axis(input_ids.axes[0]) x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask, position_ids=position_ids) lm_logits = self.embeddings.unembed(x) @@ -387,17 +404,19 @@ def _rotate_half(x: NamedArray) -> NamedArray: def _apply_rotary_pos_emb( - q: NamedArray, # [batch, seq_len, heads, head_size] - k: NamedArray, # [batch, seq_len, kv_heads, head_size] - cos: jnp.ndarray, # [1, 1, seq_len, head_size] - sin: jnp.ndarray, # [1, 1, seq_len, head_size] - position_ids: jnp.ndarray, # [bs, seq_len] + q: NamedArray, # [batch, position, heads, head_size] + k: NamedArray, # [batch, position, kv_heads, head_size] + cos: jnp.ndarray, # [1, 1, position, head_size] + sin: jnp.ndarray, # [1, 1, position, head_size] + position_ids: NamedArray, # [bs, position] ) -> Tuple[NamedArray, NamedArray]: """Applies rotary position embedding to q and k.""" - cos = jnp.squeeze(jnp.squeeze(cos, axis=1), axis=0) # from [1, 1, seq_len, dim] to [seq_len, dim] + cos = jnp.squeeze(jnp.squeeze(cos, axis=1), axis=0) # from [1, 1, position, dim] to [position, dim] sin = jnp.squeeze(jnp.squeeze(sin, axis=1), axis=0) - cos = cos[position_ids] # [batch, seq_len, head_size] - sin = sin[position_ids] # [batch, seq_len, head_size] + # TODO: use NamedArray instead of array + position_ids = position_ids.array # [batch, position] + cos = cos[position_ids] # [batch, position, head_size] + sin = sin[position_ids] # [batch, position, head_size] cos = hax.named(cos, ("batch", "position", "head_size")) sin = hax.named(sin, ("batch", "position", "head_size")) q_embed = hax.multiply(q, cos) + hax.multiply(_rotate_half(q), sin) diff --git a/tests/test_llama.py b/tests/test_llama.py index 3b290a3ab..c1d81494f 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,8 +1,6 @@ import numpy as np -import torch from jax import random from transformers.models.llama.configuration_llama import LlamaConfig as HFLlamaConfig -from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention from 
transformers.models.llama.modeling_llama import ( LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, ) @@ -26,9 +24,13 @@ ) from levanter.models.llama import _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb from levanter.models.llama import _rotate_half as levanter_rotate_half +from test_utils import skip_if_no_torch +@skip_if_no_torch def test_llama_rotary_embedding(): + import torch + llama_config = _get_llama_config() Embed = llama_config.Embed Pos = llama_config.Pos @@ -72,7 +74,15 @@ def test_levanter_against_hf(levanter_class, hf_class): ) +@skip_if_no_torch def test_apply_rotary_pos_emb(): + import torch + + def assert_equal_out(hax_out, torch_out: torch.Tensor): + assert np.isclose( + torch_out.numpy(), np.array(hax_out.array), rtol=1e-2, atol=1e-2 + ).all(), f"{torch_out} != {hax_out}" + llama_config = _get_llama_config() Pos = llama_config.Pos @@ -93,26 +103,26 @@ def test_apply_rotary_pos_emb(): hf_out_rf_q = hf_rotate_half(q_tensor).transpose(1, 2) # re-transpose to match levanter hf_out_rf_k = hf_rotate_half(k_tensor).transpose(1, 2) - _assert_equal_out(levanter_out_rf_q, hf_out_rf_q) - _assert_equal_out(levanter_out_rf_k, hf_out_rf_k) + assert_equal_out(levanter_out_rf_q, hf_out_rf_q) + assert_equal_out(levanter_out_rf_k, hf_out_rf_k) # Check the output of _apply_rotary_pos_emb() from levanter and hf cos = random.normal(random.PRNGKey(2), (1, 1, Pos.size, HeadSize.size)) sin = random.normal(random.PRNGKey(3), (1, 1, Pos.size, HeadSize.size)) - position_ids = random.randint(random.PRNGKey(4), (Batch.size, Pos.size), 0, Pos.size) + position_ids = hax.arange(Pos).broadcast_axis(Batch) levanter_out_rope_q, levanter_out_rope_k = levanter_apply_rotary_pos_emb(q, k, cos, sin, position_ids) cos_tensor = torch.from_numpy(np.array(cos)) sin_tensor = torch.from_numpy(np.array(sin)) - position_ids_tensor = torch.from_numpy(np.array(position_ids)) + position_ids_tensor = torch.from_numpy(np.array(position_ids.array)) hf_out_rope_q, hf_out_rope_k = hf_apply_rotary_pos_emb( q_tensor, k_tensor, cos_tensor, sin_tensor, position_ids_tensor ) hf_out_rope_q = hf_out_rope_q.transpose(1, 2) # re-transpose to match levanter hf_out_rope_k = hf_out_rope_k.transpose(1, 2) - _assert_equal_out(levanter_out_rope_q, hf_out_rope_q) - _assert_equal_out(levanter_out_rope_k, hf_out_rope_k) + assert_equal_out(levanter_out_rope_q, hf_out_rope_q) + assert_equal_out(levanter_out_rope_k, hf_out_rope_k) def test_llama_attention(): @@ -121,23 +131,11 @@ def test_llama_attention(): # generate a random key that can be splitted into 4 key = random.PRNGKey(4) - levanter_attention = LlamaAttention.init(config=config, key=key) - levanter_out = levanter_attention(x, mask, position_ids) - - hf_config = _levanter_config_to_hf_config(config) - hf_attention = HFLlamaAttention(config=hf_config) # (seq_len, kv_seq_len) - # convert attention_mask's shape from (seq_len, kv_seq_len) to (batch, 1, seq_len, kv_seq_len) - attention_mask = _hax_to_tensor(mask) - attention_mask = attention_mask.reshape(1, 1, config.Pos.size, config.KeyPos.size).repeat(x.axes[0].size, 1, 1, 1) - - hf_out, _, _ = hf_attention( - hidden_states=_hax_to_tensor(x), - attention_mask=attention_mask, - position_ids=torch.from_numpy(np.array(position_ids)), - ) + attention = LlamaAttention.init(config=config, key=key) + out = attention(x, mask, position_ids) # assert the same shape - assert levanter_out.array.shape == hf_out.shape, f"{levanter_out.shape} != {hf_out.shape}" + assert out.array.shape == (x.axes[0].size, 
config.seq_len, config.hidden_dim) def test_llama_decoder_layer(): @@ -145,26 +143,21 @@ def test_llama_decoder_layer(): key = random.PRNGKey(0) llama_decoder_layer = LlamaDecoderLayer.init(config=llama_config, key=key) x, mask, position_ids = _get_random_inputs(llama_config) - levanter_out = llama_decoder_layer(x, mask, position_ids) - assert levanter_out.array.shape == (x.axes[0].size, llama_config.seq_len, llama_config.hidden_dim) + out = llama_decoder_layer(x, mask, position_ids) + assert out.array.shape == (x.axes[0].size, llama_config.seq_len, llama_config.hidden_dim) def test_llama_lm_head_model(): llama_config = _get_llama_config() + Batch = hax.Axis("batch", 2) Vocab = hax.Axis("vocab", llama_config.vocab_size) - # generate a key that can be splitted into 2 - llama_model = LlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) - # generate a random input - x, mask, position_ids = _get_random_inputs(llama_config) - - levanter_out = llama_model(x, mask, position_ids) - assert levanter_out.array.shape == (Batch.size, Pos.size, llama_config.Vocab.size) - + Pos = llama_config.Pos + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, llama_config.vocab_size) + mask = hax.nn.attention.causal_mask(Pos, llama_config.KeyPos) -def _assert_equal_out(hax_out, torch_out: torch.Tensor): - assert np.isclose( - torch_out.numpy(), np.array(hax_out.array), rtol=1e-2, atol=1e-2 - ).all(), f"{torch_out} != {hax_out}" + llama_model = LlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) + out = llama_model(input_ids, mask) + assert out.array.shape == (Batch.size, Pos.size, Vocab.size) def _get_llama_config() -> LlamaConfig: @@ -193,7 +186,7 @@ def _get_random_inputs(config: LlamaConfig): Batch = hax.Axis("batch", 2) x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed)) mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) - position_ids = random.randint(random.PRNGKey(2), (Batch.size, Pos.size), 0, Pos.size) + position_ids = hax.arange(Pos).broadcast_axis(Batch) return x, mask, position_ids @@ -206,7 +199,3 @@ def _levanter_config_to_hf_config(levanter_config: LlamaConfig) -> HFLlamaConfig num_key_value_heads=levanter_config.num_kv_heads, rope_scaling=levanter_config.rope_scaling, ) - - -def _hax_to_tensor(x: hax.NamedArray) -> torch.Tensor: - return torch.from_numpy(np.array(x.array)) From 46939d51c3cb8a108e2c14032ea3b9fc1ef2df89 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sat, 19 Aug 2023 19:49:29 -0700 Subject: [PATCH 21/57] Fix torch import issue --- tests/test_llama.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/tests/test_llama.py b/tests/test_llama.py index c1d81494f..db51ac9e0 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,15 +1,5 @@ import numpy as np from jax import random -from transformers.models.llama.configuration_llama import LlamaConfig as HFLlamaConfig -from transformers.models.llama.modeling_llama import ( - LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, -) -from transformers.models.llama.modeling_llama import ( - LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, -) -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as hf_apply_rotary_pos_emb -from transformers.models.llama.modeling_llama import rotate_half as hf_rotate_half import haliax as hax @@ 
-30,6 +20,13 @@ @skip_if_no_torch def test_llama_rotary_embedding(): import torch + from transformers.models.llama.modeling_llama import ( + LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, + ) + from transformers.models.llama.modeling_llama import ( + LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, + ) + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding llama_config = _get_llama_config() Embed = llama_config.Embed @@ -77,6 +74,8 @@ def test_levanter_against_hf(levanter_class, hf_class): @skip_if_no_torch def test_apply_rotary_pos_emb(): import torch + from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as hf_apply_rotary_pos_emb + from transformers.models.llama.modeling_llama import rotate_half as hf_rotate_half def assert_equal_out(hax_out, torch_out: torch.Tensor): assert np.isclose( @@ -188,14 +187,3 @@ def _get_random_inputs(config: LlamaConfig): mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) position_ids = hax.arange(Pos).broadcast_axis(Batch) return x, mask, position_ids - - -def _levanter_config_to_hf_config(levanter_config: LlamaConfig) -> HFLlamaConfig: - return HFLlamaConfig( - vocab_size=levanter_config.vocab_size, - max_position_embeddings=levanter_config.seq_len, - hidden_size=levanter_config.hidden_dim, - num_attention_heads=levanter_config.num_heads, - num_key_value_heads=levanter_config.num_kv_heads, - rope_scaling=levanter_config.rope_scaling, - ) From 92a7f23b80a103e13151e1e227ca6f6ed8ea899a Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sat, 19 Aug 2023 20:44:39 -0700 Subject: [PATCH 22/57] Refactor RoPE --- src/levanter/models/llama.py | 122 +++++++---------------------------- tests/test_llama.py | 71 +++++++------------- 2 files changed, 47 insertions(+), 146 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 13d8469bc..999a3c327 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -94,87 +94,37 @@ def __call__(self, x: NamedArray) -> NamedArray: class LlamaRotaryEmbedding(eqx.Module): - Embed: Axis = eqx.static_field() Pos: Axis = eqx.static_field() - base: float = 10000 - inv_freq: jnp.ndarray = eqx.static_field() cos_cached: jnp.ndarray = eqx.static_field() sin_cached: jnp.ndarray = eqx.static_field() - max_seq_len_cached: int = eqx.static_field() - def __init__(self, Embed: Axis, Pos: Axis, base: int = 10000): - self.Embed = Embed + def __init__(self, HeadSize: Axis, Pos: Axis, base: int = 10000): self.Pos = Pos - self.base = base - self.inv_freq = 1.0 / (self.base ** (hax.arange(Embed.resize(Embed.size // 2), step=2) / Embed.size)).array - self.cos_cached, self.sin_cached = self._set_cos_sin_cache(Pos.size) + self.cos_cached, self.sin_cached = self._get_cos_sin_cache(Pos=Pos, HeadSize=HeadSize, base=base) - def _get_positional_ids(self): - """A helper function for the convenience of extending to two sub-classes - Here we use a standard positional encoding function, which was described in `Attention is all you need`. 
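
# A plain-array sketch of the refactored RoPE path: _get_cos_sin_cache (just
# below) builds frequencies 1 / base**(2i/d), combines them with the position ids
# via an outer product, and duplicates them to the full head size; the rotation
# then applies q * cos + rotate_half(q) * sin. The sizes (head_size=8, seq_len=6,
# base=10000.0) and the random q are illustrative assumptions only.
import jax
import jax.numpy as jnp

head_size, seq_len, base = 8, 6, 10000.0
inv_freq = 1.0 / (base ** (jnp.arange(0, head_size, 2) / head_size))  # [head_size // 2]
freqs = jnp.outer(jnp.arange(seq_len), inv_freq)                      # [seq_len, head_size // 2]
emb = jnp.concatenate((freqs, freqs), axis=-1)                        # [seq_len, head_size]
cos_cached, sin_cached = jnp.cos(emb), jnp.sin(emb)

def rotate_half_ref(x):
    # for a head vector [a, b, c, d] this yields [-c, -d, a, b]
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return jnp.concatenate((-x2, x1), axis=-1)

q = jax.random.normal(jax.random.PRNGKey(0), (seq_len, head_size))
q_rotated = q * cos_cached + rotate_half_ref(q) * sin_cached
print(q_rotated.shape)  # (6, 8)
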
- """ - return jnp.arange(self.max_seq_len_cached) + @staticmethod + def _get_cos_sin_cache(HeadSize: NamedArray, Pos: NamedArray, base: float) -> Tuple[jnp.ndarray, jnp.ndarray]: + HeadHalfSize = HeadSize.resize(HeadSize.size // 2) + inv_freq = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) - def _set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - t = self._get_positional_ids() + position_ids: NamedArray = hax.arange(Pos) # Evaluates the Einstein summation convention on the operands. - freqs = jnp.einsum("i,j->ij", t, self.inv_freq) - # Different from paper but following HF implementation + freqs = position_ids * inv_freq.broadcast_axis(Pos) + # This is different from the paper but alignes with HF implementation: # It uses a different permutation in order to obtain the same calculation - emb = jnp.concatenate((freqs, freqs), axis=-1) - cos_cached = jnp.cos(emb)[None, None, :, :] - sin_cached = jnp.sin(emb)[None, None, :, :] + emb = hax.concatenate(HeadSize, (freqs, freqs)) + cos_cached = hax.cos(emb) + sin_cached = hax.sin(emb) return cos_cached, sin_cached - def __call__(self, x, seq_len: int): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self.cos_cached, self.sin_cached = self._set_cos_sin_cache(seq_len=seq_len) - + def __call__(self, seq_len: int) -> Tuple[NamedArray, NamedArray]: return ( - self.cos_cached[:, :, :seq_len, ...], - self.sin_cached[:, :, :seq_len, ...], + self.cos_cached[self.Pos, :seq_len], + self.sin_cached[self.Pos, :seq_len], ) -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling""" - - scaling_factor: float = 1.0 - - def __init__(self, Embed: Axis, Pos: Axis, base=10000, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(Embed, Pos, base) - - def _get_positional_ids(self): - """Here we overwrite the function in the base class to implement linear scaling.""" - return jnp.arange(self.max_seq_len_cached) / self.scaling_factor - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling.""" - - scaling_factor: float = 1.0 - - def __init__(self, Embed, Pos, base=10000, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(Embed, Pos, base) - - def _get_positional_ids(self): - """Here we overwrite the function in the base class. - Here it adjusts the frequency base dynamically according to the sequence length. 
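
# For contrast with the classes being removed here: linear scaling divides the
# position ids by the factor (position interpolation), while the dynamic-NTK
# variant instead enlarges the rotary base (see the earlier base-adjustment
# sketch). factor=2.0 and seq_len=8 are illustrative values only.
import jax.numpy as jnp

scaling_factor, seq_len = 2.0, 8
plain_positions = jnp.arange(seq_len)             # [0 1 2 3 4 5 6 7]
linear_scaled = plain_positions / scaling_factor  # [0.  0.5 1.  1.5 2.  2.5 3.  3.5]
print(plain_positions, linear_scaled)
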
- """ - if self.max_seq_len_cached > self.Pos.size: - base = self.base * ( - (self.scaling_factor * self.max_seq_len_cached / self.Pos.size) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - self.inv_freq = 1.0 / (base ** (jnp.arange(0, self.dim, 2) / self.dim)) - - return jnp.arange(self.max_seq_len_cached) - - class LlamaAttention(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() q_proj: hnn.Linear # projection from Embed to query @@ -193,7 +143,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias) v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias) o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias) - rotary_emb = _get_rotary_emb(config) + rotary_emb = LlamaRotaryEmbedding(config.HeadSize, config.Pos) return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @named_call @@ -202,9 +152,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids: Name k = self.k_proj(x) v = self.v_proj(x) - cos, sin = self.rotary_emb(v, seq_len=self.config.seq_len) + cos, sin = self.rotary_emb(seq_len=self.config.seq_len) - q, k = _apply_rotary_pos_emb(q, k, cos, sin, position_ids) + q, k = _apply_rotary_pos_emb(self.config.Pos, q, k, cos, sin, position_ids) scale = jax.lax.rsqrt(float(self.config.HeadSize.size)) @@ -377,23 +327,6 @@ def __call__( return lm_logits -def _get_rotary_emb(config: LlamaConfig) -> LlamaRotaryEmbedding: - # Note that the embedding here is HeadSize, not the full Embed - Embed = config.HeadSize - Pos = config.Pos - if config.rope_scaling is None: - return LlamaRotaryEmbedding(Embed, Pos) - else: - scaling_type = config.rope_scaling["type"] - scaling_factor = config.rope_scaling["factor"] - if scaling_type == "linear": - return LlamaLinearScalingRotaryEmbedding(Embed, Pos, scaling_factor=scaling_factor) - elif scaling_type == "dynamic": - return LlamaDynamicNTKScalingRotaryEmbedding(Embed, Pos, scaling_factor=scaling_factor) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _rotate_half(x: NamedArray) -> NamedArray: """Rotates half of the hidden dims of the input and concatenates them.""" HeadSize = x.axes[-1] @@ -404,21 +337,16 @@ def _rotate_half(x: NamedArray) -> NamedArray: def _apply_rotary_pos_emb( + Pos: Axis, q: NamedArray, # [batch, position, heads, head_size] k: NamedArray, # [batch, position, kv_heads, head_size] - cos: jnp.ndarray, # [1, 1, position, head_size] - sin: jnp.ndarray, # [1, 1, position, head_size] + cos: NamedArray, # [position, head_size] + sin: NamedArray, # [position, head_size] position_ids: NamedArray, # [bs, position] ) -> Tuple[NamedArray, NamedArray]: """Applies rotary position embedding to q and k.""" - cos = jnp.squeeze(jnp.squeeze(cos, axis=1), axis=0) # from [1, 1, position, dim] to [position, dim] - sin = jnp.squeeze(jnp.squeeze(sin, axis=1), axis=0) - # TODO: use NamedArray instead of array - position_ids = position_ids.array # [batch, position] - cos = cos[position_ids] # [batch, position, head_size] - sin = sin[position_ids] # [batch, position, head_size] - cos = hax.named(cos, ("batch", "position", "head_size")) - sin = hax.named(sin, ("batch", "position", "head_size")) - q_embed = hax.multiply(q, cos) + hax.multiply(_rotate_half(q), sin) - k_embed = hax.multiply(k, cos) + hax.multiply(_rotate_half(k), sin) + cos = cos[Pos, 
position_ids] # [batch, position, head_size] + sin = sin[Pos, position_ids] # [batch, position, head_size] + q_embed = (q * cos + _rotate_half(q) * sin) + k_embed = (k * cos + _rotate_half(k) * sin) return q_embed, k_embed diff --git a/tests/test_llama.py b/tests/test_llama.py index db51ac9e0..9f26263e4 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -7,8 +7,6 @@ LlamaAttention, LlamaConfig, LlamaDecoderLayer, - LlamaDynamicNTKScalingRotaryEmbedding, - LlamaLinearScalingRotaryEmbedding, LlamaLMHeadModel, LlamaRotaryEmbedding, ) @@ -20,55 +18,27 @@ @skip_if_no_torch def test_llama_rotary_embedding(): import torch - from transformers.models.llama.modeling_llama import ( - LlamaDynamicNTKScalingRotaryEmbedding as HFLlamaDynamicNTKScalingRotaryEmbedding, - ) - from transformers.models.llama.modeling_llama import ( - LlamaLinearScalingRotaryEmbedding as HFLlamaLinearScalingRotaryEmbedding, - ) from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as HFLlamaRotaryEmbedding llama_config = _get_llama_config() - Embed = llama_config.Embed + HeadSize = llama_config.HeadSize Pos = llama_config.Pos - hidden_dim = Embed.size + hidden_dim = HeadSize.size seq_len = Pos.size - scaling_factor = llama_config.rope_scaling["factor"] key = random.PRNGKey(0) device = "cpu" - def test_levanter_against_hf(levanter_class, hf_class): - x = random.normal(key, (1, seq_len)) - x_torch = torch.from_numpy(np.array(x)) - - levanter_output = levanter_class(x, seq_len=seq_len) - hf_output = hf_class(x_torch, seq_len=seq_len) - - for jax_out, torch_out in zip(levanter_output, hf_output): - torch_out = torch_out.numpy() - assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + x = random.normal(key, (1, seq_len)) + x_torch = torch.from_numpy(np.array(x)) - # test LlamaRotaryEmbedding - test_levanter_against_hf( - levanter_class=LlamaRotaryEmbedding(Embed=Embed, Pos=Pos), - hf_class=HFLlamaRotaryEmbedding(dim=hidden_dim, max_position_embeddings=seq_len, device=device), - ) + levanter_rope = LlamaRotaryEmbedding(HeadSize=HeadSize, Pos=Pos) + levanter_output = levanter_rope(seq_len=seq_len) + hf_rope = HFLlamaRotaryEmbedding(dim=hidden_dim, max_position_embeddings=seq_len, device=device) + hf_output = hf_rope(x_torch, seq_len=seq_len) - # test LlamaLinearScalingRotaryEmbedding - test_levanter_against_hf( - levanter_class=LlamaLinearScalingRotaryEmbedding(Embed=Embed, Pos=Pos, scaling_factor=scaling_factor), - hf_class=HFLlamaLinearScalingRotaryEmbedding( - dim=hidden_dim, max_position_embeddings=seq_len, scaling_factor=scaling_factor, device=device - ), - ) - - # test LlamaDynamicNTKScalingRotaryEmbedding - test_levanter_against_hf( - levanter_class=LlamaDynamicNTKScalingRotaryEmbedding(Embed=Embed, Pos=Pos, scaling_factor=scaling_factor), - hf_class=HFLlamaDynamicNTKScalingRotaryEmbedding( - dim=hidden_dim, max_position_embeddings=seq_len, scaling_factor=scaling_factor, device=device - ), - ) + for jax_out, torch_out in zip(levanter_output, hf_output): + torch_out = torch_out.numpy() + assert np.isclose(torch_out, np.array(jax_out.array), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" @skip_if_no_torch @@ -82,6 +52,9 @@ def assert_equal_out(hax_out, torch_out: torch.Tensor): torch_out.numpy(), np.array(hax_out.array), rtol=1e-2, atol=1e-2 ).all(), f"{torch_out} != {hax_out}" + def named_array_to_tensor(named_array): + return torch.from_numpy(np.array(named_array.array)) + llama_config = _get_llama_config() Pos = llama_config.Pos @@ 
-97,8 +70,8 @@ def assert_equal_out(hax_out, torch_out: torch.Tensor): levanter_out_rf_q = levanter_rotate_half(q) levanter_out_rf_k = levanter_rotate_half(k) - q_tensor = torch.from_numpy(np.array(q.array)).transpose(1, 2) # needed for HF - k_tensor = torch.from_numpy(np.array(k.array)).transpose(1, 2) + q_tensor = named_array_to_tensor(q).transpose(1, 2) # needed for HF + k_tensor = named_array_to_tensor(k).transpose(1, 2) hf_out_rf_q = hf_rotate_half(q_tensor).transpose(1, 2) # re-transpose to match levanter hf_out_rf_k = hf_rotate_half(k_tensor).transpose(1, 2) @@ -106,14 +79,14 @@ def assert_equal_out(hax_out, torch_out: torch.Tensor): assert_equal_out(levanter_out_rf_k, hf_out_rf_k) # Check the output of _apply_rotary_pos_emb() from levanter and hf - cos = random.normal(random.PRNGKey(2), (1, 1, Pos.size, HeadSize.size)) - sin = random.normal(random.PRNGKey(3), (1, 1, Pos.size, HeadSize.size)) + cos = hax.random.normal(random.PRNGKey(2), (Pos, HeadSize)) + sin = hax.random.normal(random.PRNGKey(3), (Pos, HeadSize)) position_ids = hax.arange(Pos).broadcast_axis(Batch) - levanter_out_rope_q, levanter_out_rope_k = levanter_apply_rotary_pos_emb(q, k, cos, sin, position_ids) - cos_tensor = torch.from_numpy(np.array(cos)) - sin_tensor = torch.from_numpy(np.array(sin)) - position_ids_tensor = torch.from_numpy(np.array(position_ids.array)) + levanter_out_rope_q, levanter_out_rope_k = levanter_apply_rotary_pos_emb(Pos, q, k, cos, sin, position_ids) + cos_tensor = named_array_to_tensor(cos) + sin_tensor = named_array_to_tensor(sin) + position_ids_tensor = named_array_to_tensor(position_ids) hf_out_rope_q, hf_out_rope_k = hf_apply_rotary_pos_emb( q_tensor, k_tensor, cos_tensor, sin_tensor, position_ids_tensor From 65f5888410d0120644fa824697499654beaa66b8 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sat, 19 Aug 2023 20:45:31 -0700 Subject: [PATCH 23/57] remove () --- src/levanter/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 999a3c327..b190eee31 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -347,6 +347,6 @@ def _apply_rotary_pos_emb( """Applies rotary position embedding to q and k.""" cos = cos[Pos, position_ids] # [batch, position, head_size] sin = sin[Pos, position_ids] # [batch, position, head_size] - q_embed = (q * cos + _rotate_half(q) * sin) - k_embed = (k * cos + _rotate_half(k) * sin) + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin return q_embed, k_embed From 3f028b9c4745838e5d5988ad984cd3b860c88529 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sat, 19 Aug 2023 20:50:19 -0700 Subject: [PATCH 24/57] NamedArray type hint --- src/levanter/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index b190eee31..e54ae6a84 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -105,7 +105,7 @@ def __init__(self, HeadSize: Axis, Pos: Axis, base: int = 10000): @staticmethod def _get_cos_sin_cache(HeadSize: NamedArray, Pos: NamedArray, base: float) -> Tuple[jnp.ndarray, jnp.ndarray]: HeadHalfSize = HeadSize.resize(HeadSize.size // 2) - inv_freq = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) + inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) position_ids: NamedArray = hax.arange(Pos) From 2fb1a3ff9de5eabee9ff4bca9c6a11d41aac1541 Mon Sep 17 
00:00:00 2001 From: Ivan-Zhou Date: Sat, 19 Aug 2023 20:58:39 -0700 Subject: [PATCH 25/57] Remove position_ids --- src/levanter/models/llama.py | 23 +++++++---------------- tests/test_llama.py | 15 +++++++-------- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index e54ae6a84..b3b8476eb 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -147,14 +147,14 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb) @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids: NamedArray, *args): + def __call__(self, x: NamedArray, mask: Optional[NamedArray]): q = self.q_proj(x) # TODO: rearrange and possibly rename k = self.k_proj(x) v = self.v_proj(x) cos, sin = self.rotary_emb(seq_len=self.config.seq_len) - q, k = _apply_rotary_pos_emb(self.config.Pos, q, k, cos, sin, position_ids) + q, k = _apply_rotary_pos_emb(q, k, cos, sin) scale = jax.lax.rsqrt(float(self.config.HeadSize.size)) @@ -199,12 +199,12 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": return LlamaDecoderLayer(config, attn, mlp, ln_1, ln_2) @named_call - def __call__(self, x: NamedArray, mask: Optional[NamedArray], position_ids: NamedArray, *args): + def __call__(self, x: NamedArray, mask: Optional[NamedArray]): residual = x x = self.ln_1(x) # self attention and skip connection - attn_output = self.attn(x=x, mask=mask, position_ids=position_ids) + attn_output = self.attn(x=x, mask=mask) x = residual + attn_output # MLP and skip connection @@ -233,8 +233,8 @@ def init(config: LlamaConfig, *, key) -> "LlamaTransformer": return LlamaTransformer(config, layers, ln_f) @named_call - def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray], position_ids: NamedArray, *args) -> NamedArray: - x = self.layers.fold(x, mask=attn_mask, position_ids=position_ids) + def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray]) -> NamedArray: + x = self.layers.fold(x, mask=attn_mask) x = self.ln_f(x) return x @@ -306,7 +306,6 @@ def __call__( self, input_ids: NamedArray, attn_mask: Optional[NamedArray] = None, - position_ids: Optional[NamedArray] = None, *args, ) -> NamedArray: """ @@ -315,13 +314,9 @@ def __call__( Indices of input sequence tokens in the vocabulary. attn_mask (NamedArray, optional): [batch, position, seq_len] Mask to avoid performing attention on the padding token indices of the encoder input. - position_ids (NamedArray, optional): [batch, position] - Indices of positions of each input sequence tokens in the position embeddings. 
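With `position_ids` removed, the rotation performed by `_apply_rotary_pos_emb` amounts to the following NumPy sketch (assumed shapes only: the trailing axis is the head dimension, and `cos`/`sin` broadcast over the leading axes):

import numpy as np

def rotate_half(x):
    # swap the two halves of the head dimension, negating the second half
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate((-x2, x1), axis=-1)

def apply_rope(q, k, cos, sin):
    # q, k: [..., seq_len, head_size]; cos, sin: [seq_len, head_size]
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot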
""" - if position_ids is None: - position_ids = hax.arange(self.Pos).broadcast_axis(input_ids.axes[0]) x = self.embeddings.embed(input_ids) - x = self.transformer(x, attn_mask=attn_mask, position_ids=position_ids) + x = self.transformer(x, attn_mask=attn_mask) lm_logits = self.embeddings.unembed(x) return lm_logits @@ -337,16 +332,12 @@ def _rotate_half(x: NamedArray) -> NamedArray: def _apply_rotary_pos_emb( - Pos: Axis, q: NamedArray, # [batch, position, heads, head_size] k: NamedArray, # [batch, position, kv_heads, head_size] cos: NamedArray, # [position, head_size] sin: NamedArray, # [position, head_size] - position_ids: NamedArray, # [bs, position] ) -> Tuple[NamedArray, NamedArray]: """Applies rotary position embedding to q and k.""" - cos = cos[Pos, position_ids] # [batch, position, head_size] - sin = sin[Pos, position_ids] # [batch, position, head_size] q_embed = q * cos + _rotate_half(q) * sin k_embed = k * cos + _rotate_half(k) * sin return q_embed, k_embed diff --git a/tests/test_llama.py b/tests/test_llama.py index 9f26263e4..b435282ff 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -81,11 +81,11 @@ def named_array_to_tensor(named_array): # Check the output of _apply_rotary_pos_emb() from levanter and hf cos = hax.random.normal(random.PRNGKey(2), (Pos, HeadSize)) sin = hax.random.normal(random.PRNGKey(3), (Pos, HeadSize)) - position_ids = hax.arange(Pos).broadcast_axis(Batch) - levanter_out_rope_q, levanter_out_rope_k = levanter_apply_rotary_pos_emb(Pos, q, k, cos, sin, position_ids) + levanter_out_rope_q, levanter_out_rope_k = levanter_apply_rotary_pos_emb(q, k, cos, sin) cos_tensor = named_array_to_tensor(cos) sin_tensor = named_array_to_tensor(sin) + position_ids = hax.arange(Pos).broadcast_axis(Batch) position_ids_tensor = named_array_to_tensor(position_ids) hf_out_rope_q, hf_out_rope_k = hf_apply_rotary_pos_emb( @@ -99,12 +99,12 @@ def named_array_to_tensor(named_array): def test_llama_attention(): config = _get_llama_config() - x, mask, position_ids = _get_random_inputs(config) + x, mask = _get_random_inputs(config) # generate a random key that can be splitted into 4 key = random.PRNGKey(4) attention = LlamaAttention.init(config=config, key=key) - out = attention(x, mask, position_ids) + out = attention(x, mask) # assert the same shape assert out.array.shape == (x.axes[0].size, config.seq_len, config.hidden_dim) @@ -114,8 +114,8 @@ def test_llama_decoder_layer(): llama_config = _get_llama_config() key = random.PRNGKey(0) llama_decoder_layer = LlamaDecoderLayer.init(config=llama_config, key=key) - x, mask, position_ids = _get_random_inputs(llama_config) - out = llama_decoder_layer(x, mask, position_ids) + x, mask = _get_random_inputs(llama_config) + out = llama_decoder_layer(x, mask) assert out.array.shape == (x.axes[0].size, llama_config.seq_len, llama_config.hidden_dim) @@ -158,5 +158,4 @@ def _get_random_inputs(config: LlamaConfig): Batch = hax.Axis("batch", 2) x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed)) mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) - position_ids = hax.arange(Pos).broadcast_axis(Batch) - return x, mask, position_ids + return x, mask From 97dadce93d0a1b50cd72e311be5c19ec1c0db2b7 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 20 Aug 2023 14:05:54 -0700 Subject: [PATCH 26/57] from/to HF config --- src/levanter/models/llama.py | 39 +++++++++++++++++++++++++++++++----- tests/test_llama.py | 35 ++++++++++++++++++++++++++------ 2 files changed, 63 insertions(+), 11 deletions(-) diff --git 
a/src/levanter/models/llama.py b/src/levanter/models/llama.py index b3b8476eb..ec54eb67a 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -5,6 +5,8 @@ import jax import jax.numpy as jnp import jax.random as jrandom +from transformers import LlamaConfig as HfLlamaConfig +from transformers import PretrainedConfig as HfConfig import haliax as hax import haliax.nn as hnn @@ -24,23 +26,19 @@ class LlamaConfig: Args: seq_len (int, optional): maximum length of the input sequence. Defaults to 2048. - vocab_size (int, optional): vocabulary size of the Llama model. Defaults to 32000. hidden_dim (int, optional): dimension of the hidden state. Defaults to 4096. intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008. num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32. num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32. - num_kv_heads (int, optional): number of key/value heads needed for Grouped Query Attention. Defaults to 32. activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. """ seq_len: int = 2048 - vocab_size: int = 32000 hidden_dim: int = 4096 intermediate_dim: int = 11008 num_layers: int = 32 num_heads: int = 32 - num_kv_heads: int = 32 activation_function: str = "silu" initializer_range: float = 0.02 layer_norm_epsilon: float = 1e-5 @@ -56,11 +54,42 @@ class LlamaConfig: KeyPos = property(lambda self: self.Pos.alias("key_position")) Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) - KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads)) Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) Mlp = property(lambda self: Axis(name="mlp", size=self.hidden_dim)) # TODO: shall we multiply with mlp_scale? 
HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + @classmethod + def from_hf_config(cls, hf_config: HfConfig): + return LlamaConfig( + seq_len=hf_config.max_position_embeddings, + hidden_dim=hf_config.hidden_size, + intermediate_dim=hf_config.intermediate_size, + num_layers=hf_config.num_hidden_layers, + num_heads=hf_config.num_attention_heads, + activation_function=hf_config.hidden_act, + initializer_range=hf_config.initializer_range, + layer_norm_epsilon=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + ) + + def to_hf_config(self, vocab_size: int = 32000, config_overrides=None) -> HfLlamaConfig: + if config_overrides is None: + config_overrides = {} + + return HfLlamaConfig( + max_position_embeddings=self.seq_len, + hidden_size=self.hidden_dim, + intermediate_size=self.intermediate_dim, + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + hidden_act=self.activation_function, + initializer_range=self.initializer_range, + rms_norm_eps=self.layer_norm_epsilon, + rope_scaling=self.rope_scaling, + vocab_size=vocab_size, + **config_overrides, + ) + class LlamaMlp(eqx.Module): """Multi-layer Perceptron diff --git a/tests/test_llama.py b/tests/test_llama.py index b435282ff..3596418ed 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,4 +1,5 @@ import numpy as np +import transformers from jax import random import haliax as hax @@ -15,6 +16,32 @@ from test_utils import skip_if_no_torch +@skip_if_no_torch +def test_llama_config(): + # load HF config and convert to levanter config + hf_config = transformers.LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf") + llama_config = LlamaConfig.from_hf_config(hf_config) + + # convert back to HF config + config_overrides = { + "_name_or_path": hf_config._name_or_path, + "architectures": hf_config.architectures, + "torch_dtype": hf_config.torch_dtype, + } + new_hf_config = llama_config.to_hf_config( + vocab_size=hf_config.vocab_size, + config_overrides=config_overrides, + ) + + # assert the content in new_hf_config is the same as hf_config + for k in new_hf_config.__dict__.keys(): + if k in ["_commit_hash", "transformers_version"]: + continue + assert getattr(new_hf_config, k) == getattr( + hf_config, k + ), f"{k} {getattr(new_hf_config, k)} != {getattr(hf_config, k)}" + + @skip_if_no_torch def test_llama_rotary_embedding(): import torch @@ -122,9 +149,9 @@ def test_llama_decoder_layer(): def test_llama_lm_head_model(): llama_config = _get_llama_config() Batch = hax.Axis("batch", 2) - Vocab = hax.Axis("vocab", llama_config.vocab_size) + Vocab = hax.Axis("vocab", 1000) Pos = llama_config.Pos - input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, llama_config.vocab_size) + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size) mask = hax.nn.attention.causal_mask(Pos, llama_config.KeyPos) llama_model = LlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0)) @@ -133,21 +160,17 @@ def test_llama_lm_head_model(): def _get_llama_config() -> LlamaConfig: - vocab_size = 1000 seq_len = 128 hidden_dim = 16 num_heads = 4 - num_kv_heads = 4 rope_scaling = { "type": "linear", "factor": 2.0, } return LlamaConfig( seq_len=seq_len, - vocab_size=vocab_size, hidden_dim=hidden_dim, num_heads=num_heads, - num_kv_heads=num_kv_heads, rope_scaling=rope_scaling, ) From fefb93a9d13072585c676fa161f0d66afcb780d7 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 21 Aug 2023 20:30:21 -0700 Subject: [PATCH 27/57] Update to 
state_dict --- src/levanter/models/llama.py | 88 ++++++++++++++++++++++++++++++++---- 1 file changed, 79 insertions(+), 9 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index ec54eb67a..bfbab0d85 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -14,9 +14,19 @@ from haliax.jax_utils import named_call, shaped_rng_split from haliax.nn.scan import Stacked -from levanter.compat.torch_serialization import StateDictSerializationMixin +from levanter.compat.hf_checkpoints import HFCheckpointConverter +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + stack_state_dict, + unflatten_linear_layers, + unstack_state_dict, +) from levanter.models.gpt2 import ACT2FN from levanter.models.lm_model import LmConfig +from levanter.utils.py_utils import cached_classproperty @LmConfig.register_subclass("llama") @@ -46,7 +56,7 @@ class LlamaConfig: gradient_checkpointing: bool = True gradient_checkpointing_block_size: int = 5 - use_bias: bool = True + use_bias: bool = False rope_scaling: Optional[dict] = None # Axis @@ -57,6 +67,11 @@ class LlamaConfig: Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) Mlp = property(lambda self: Axis(name="mlp", size=self.hidden_dim)) # TODO: shall we multiply with mlp_scale? HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + Intermediate = property(lambda self: Axis(name="intermediate", size=self.intermediate_dim)) + + @cached_classproperty + def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore + return HFCheckpointConverter(cls, "meta-llama/Llama-2-7b-hf", trust_remote_code=True) @classmethod def from_hf_config(cls, hf_config: HfConfig): @@ -103,11 +118,11 @@ class LlamaMlp(eqx.Module): act: Callable = eqx.static_field() @staticmethod - def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = False) -> "LlamaMlp": + def init(Embed: Axis, Intermediate: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = False) -> "LlamaMlp": k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3) - gate_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias) - up_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_up_proj, use_bias=use_bias) - down_proj = hnn.Linear.init(Out=Embed, In=Mlp, key=k_down_proj, use_bias=use_bias) + gate_proj = hnn.Linear.init(Out=Intermediate, In=Embed, key=k_fc, use_bias=use_bias) + up_proj = hnn.Linear.init(Out=Intermediate, In=Embed, key=k_up_proj, use_bias=use_bias) + down_proj = hnn.Linear.init(Out=Embed, In=Intermediate, key=k_down_proj, use_bias=use_bias) if isinstance(activation_fn, str): activation_fn = ACT2FN[activation_fn] act = activation_fn # type: ignore @@ -168,7 +183,6 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention": Embed = config.Embed k_q, k_k, k_v, k_o = jrandom.split(key, 4) q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias) - # TODO: double check if we should use Heads or KV_HEADS here k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias) v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias) o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias) @@ -208,6 +222,32 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray]): attn_output = 
self.o_proj(attn_output) return attn_output + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + # unflatten the linear layers of HF state_dict to match the shape of LlamaAttention + d = {} + d.update(unflatten_linear_layers(apply_prefix(prefix, "q_proj"), state_dict, self.q_proj, None)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, None)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, None)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "o_proj"), state_dict, self.o_proj, None)) + + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + # flatten the linear layers of LlamaAttention to match the shape of HF state_dict + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix) + + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "q_proj"), self.q_proj, None)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "k_proj"), self.k_proj, None)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "v_proj"), self.v_proj, None)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "o_proj"), self.o_proj, None)) + + state_dict.update(my_dict) + return state_dict + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"attn": "self_attn"} + class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() @@ -221,7 +261,14 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": k_attn, k_mlp = jrandom.split(key, 2) attn = LlamaAttention.init(config, key=k_attn) - mlp = LlamaMlp.init(config.Embed, config.Mlp, config.activation_function, key=k_mlp, use_bias=config.use_bias) + mlp = LlamaMlp.init( + config.Embed, + config.Intermediate, + config.Mlp, + config.activation_function, + key=k_mlp, + use_bias=config.use_bias, + ) ln_1 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) ln_2 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) @@ -243,6 +290,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray]): output = residual + mlp_output return output + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"ln_1": "input_layernorm", "ln_2": "post_attention_layernorm"} + class LlamaTransformer(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() @@ -268,6 +318,23 @@ def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray]) -> NamedArray return x + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + stacked = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "h")) + out = super().from_state_dict(stacked, prefix=prefix) + return out + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_state_dict: StateDict = {} + super().update_state_dict(my_state_dict, prefix=prefix) + + stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layers")) + state_dict.update(stacked_dict) + + return state_dict + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"ln_f": "norm"} + class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): """Similar to GPT2 Embedding but without dropout""" @@ -300,7 +367,7 @@ def unembed(self, x: NamedArray): return hax.dot("embed", x, self.token_embeddings) def _state_dict_key_map(self) -> 
Dict[str, Optional[str]]: - return {"token_embeddings": "wte.weight", "position_embeddings": "wpe.weight"} + return {"token_embeddings": "model.embed_tokens.weight", "position_embeddings": "wpe.weight"} class LlamaLMHeadModel(StateDictSerializationMixin, eqx.Module): @@ -350,6 +417,9 @@ def __call__( return lm_logits + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"transformer": "model", "embeddings": None} + def _rotate_half(x: NamedArray) -> NamedArray: """Rotates half of the hidden dims of the input and concatenates them.""" From 8239cc3540f15a052bcac4a449c4976ba6c6841f Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 21 Aug 2023 20:35:12 -0700 Subject: [PATCH 28/57] ignore type in default_hf_checkpoint_converter --- src/levanter/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index bfbab0d85..9799b821c 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -71,7 +71,7 @@ class LlamaConfig: @cached_classproperty def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore - return HFCheckpointConverter(cls, "meta-llama/Llama-2-7b-hf", trust_remote_code=True) + return HFCheckpointConverter(cls, "meta-llama/Llama-2-7b-hf", trust_remote_code=True) # type: ignore @classmethod def from_hf_config(cls, hf_config: HfConfig): From 73f1c90c90a3cb7dc16f3aa0e17d4c9cbdaec25b Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 21 Aug 2023 21:14:28 -0700 Subject: [PATCH 29/57] attn to self_attn in LlamaDecoderLayer --- src/levanter/models/llama.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 9799b821c..a743b4150 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -245,9 +245,6 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) state_dict.update(my_dict) return state_dict - def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - return {"attn": "self_attn"} - class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() @@ -291,7 +288,11 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray]): return output def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - return {"ln_1": "input_layernorm", "ln_2": "post_attention_layernorm"} + return { + "attn": "self_attn", + "ln_1": "input_layernorm", + "ln_2": "post_attention_layernorm", + } class LlamaTransformer(StateDictSerializationMixin, eqx.Module): @@ -367,7 +368,7 @@ def unembed(self, x: NamedArray): return hax.dot("embed", x, self.token_embeddings) def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - return {"token_embeddings": "model.embed_tokens.weight", "position_embeddings": "wpe.weight"} + return {"token_embeddings": "model.embed_tokens.weight"} class LlamaLMHeadModel(StateDictSerializationMixin, eqx.Module): From 8f16ec0b08db447e98b49622d90a90d111726fbf Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Tue, 22 Aug 2023 20:18:35 -0700 Subject: [PATCH 30/57] Remove position embed from LlamaEmbedding --- src/levanter/models/llama.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index a743b4150..059ee805b 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -338,30 +338,26 @@ def _state_dict_key_map(self) -> Dict[str, 
Optional[str]]: class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): - """Similar to GPT2 Embedding but without dropout""" + """Similar to GPT2 Embedding, except that: + - Llama doesn't have position embedding in the Embedding layer. + - Llama doesn't use dropout. + """ Vocab: Axis = eqx.static_field() config: LlamaConfig = eqx.static_field() - token_embeddings: NamedArray - position_embeddings: NamedArray @staticmethod def init(Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaEmbedding": k_wte, k_wpe = jrandom.split(key, 2) token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) - position_embeddings = hax.random.normal(k_wpe, (config.Pos, config.Embed)) * (config.initializer_range / 2) - - return LlamaEmbedding(Vocab, config, token_embeddings, position_embeddings) + return LlamaEmbedding(Vocab, config, token_embeddings) @named_call def embed(self, input_ids, *args): input_embeds = self.token_embeddings.take("vocab", input_ids) - position_embeds = self.position_embeddings - - x = input_embeds + position_embeds - + x = input_embeds return x def unembed(self, x: NamedArray): @@ -393,10 +389,9 @@ def Pos(self) -> Axis: @classmethod def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel": - k_t, k_embeddings = jrandom.split(key, 2) + k_t, k_emb = jrandom.split(key, 2) transformer = LlamaTransformer.init(config, key=k_t) - embeddings = LlamaEmbedding.init(Vocab, config, key=k_embeddings) - + embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb) return LlamaLMHeadModel(transformer, embeddings) def __call__( From 5578db2d6d23953ce02d2f455e0941cfa1ce222c Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Tue, 22 Aug 2023 21:16:27 -0700 Subject: [PATCH 31/57] Set out_dims_first_in_dict --- src/levanter/models/llama.py | 41 ++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 059ee805b..174dc98eb 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -106,7 +106,7 @@ def to_hf_config(self, vocab_size: int = 32000, config_overrides=None) -> HfLlam ) -class LlamaMlp(eqx.Module): +class LlamaMlp(eqx.Module, StateDictSerializationMixin): """Multi-layer Perceptron In comparison with GPT2, LlamaMlp adds an up-proj that multiplies with activated gate_proj, before down-proj. 
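The module whose weights are (de)serialized here computes, in plain JAX terms, roughly the following (a bias-free sketch with explicit weight matrices, matching `use_bias=False`; not the haliax API):

import jax
import jax.numpy as jnp

def llama_mlp(x: jnp.ndarray, w_gate, w_up, w_down) -> jnp.ndarray:
    # x: [..., hidden]; w_gate, w_up: [hidden, intermediate]; w_down: [intermediate, hidden]
    return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down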
@@ -136,6 +136,44 @@ def __call__(self, x: NamedArray) -> NamedArray: outputs = self.down_proj(hidden_states) return outputs + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp + d = {} + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "gate_proj"), state_dict, self.gate_proj, out_dims_first_in_dict=True + ) + ) + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "up_proj"), state_dict, self.up_proj, out_dims_first_in_dict=True + ) + ) + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "down_proj"), state_dict, self.down_proj, out_dims_first_in_dict=True + ) + ) + + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix=prefix) + + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "gate_proj"), self.gate_proj, out_dims_first_in_dict=True) + ) + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "up_proj"), self.up_proj, out_dims_first_in_dict=True) + ) + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "down_proj"), self.down_proj, out_dims_first_in_dict=True) + ) + + state_dict.update(my_dict) + return state_dict + class LlamaRotaryEmbedding(eqx.Module): Pos: Axis = eqx.static_field() @@ -398,7 +436,6 @@ def __call__( self, input_ids: NamedArray, attn_mask: Optional[NamedArray] = None, - *args, ) -> NamedArray: """ Args: From d73c6f7ceb636703fcd6ae7c8f3c53c5963cc73f Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Tue, 22 Aug 2023 21:16:42 -0700 Subject: [PATCH 32/57] Add (incomplete) roundtrip test --- tests/test_llama.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_llama.py b/tests/test_llama.py index 3596418ed..6117a04fd 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,3 +1,6 @@ +import tempfile + +import jax import numpy as np import transformers from jax import random @@ -159,6 +162,48 @@ def test_llama_lm_head_model(): assert out.array.shape == (Batch.size, Pos.size, Vocab.size) +@skip_if_no_torch +def test_llama_roundtrip(): + import torch + from transformers import AutoModelForCausalLM + + converter = LlamaConfig.default_hf_checkpoint_converter + + config = _get_llama_config() + Vocab = hax.Axis("vocab", 1000) + + # TODO: load the first torch model with model_id from HF + + # randomly initialize a levanter model + # TODO: use converter.load_pretrained + model = LlamaLMHeadModel.init( + Vocab=Vocab, + config=config, + key=random.PRNGKey(0), + ) + + input = hax.random.randint(random.PRNGKey(0), model.Pos, 0, model.Vocab.size) + attn_mask = hax.nn.attention.causal_mask(model.Pos, model.config.KeyPos) + + def compute(input): + model_output = model(input, attn_mask=attn_mask) + return hax.nn.softmax(model_output, axis=model.Vocab) + + compute = jax.jit(compute) + jax_out = compute(input).array + + with tempfile.TemporaryDirectory() as tmpdir: + converter.save_pretrained(model, tmpdir) + torch_model2 = AutoModelForCausalLM.from_pretrained(tmpdir) + torch_model2.eval() + + torch_out2 = torch_model2(torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0)) + torch_out2 = torch_out2.logits[0].detach().cpu().numpy() + torch_out2 = jax.nn.softmax(torch_out2, axis=-1) + assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" + # assert 
np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" + + def _get_llama_config() -> LlamaConfig: seq_len = 128 hidden_dim = 16 From 79f4fa9df0f73402efad7222c84e9cefc2d8ab74 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Wed, 23 Aug 2023 20:28:54 -0700 Subject: [PATCH 33/57] remove unused Axis form mlp --- src/levanter/models/llama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 174dc98eb..78f8504d1 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Union import equinox as eqx import jax @@ -118,7 +118,9 @@ class LlamaMlp(eqx.Module, StateDictSerializationMixin): act: Callable = eqx.static_field() @staticmethod - def init(Embed: Axis, Intermediate: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = False) -> "LlamaMlp": + def init( + Embed: Axis, Intermediate: Axis, activation_fn: Union[str, Callable], *, key, use_bias: bool = False + ) -> "LlamaMlp": k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3) gate_proj = hnn.Linear.init(Out=Intermediate, In=Embed, key=k_fc, use_bias=use_bias) up_proj = hnn.Linear.init(Out=Intermediate, In=Embed, key=k_up_proj, use_bias=use_bias) @@ -299,7 +301,6 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": mlp = LlamaMlp.init( config.Embed, config.Intermediate, - config.Mlp, config.activation_function, key=k_mlp, use_bias=config.use_bias, From ee6e11c619fa4e44993fdb368f59d0f6f44628c2 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Wed, 23 Aug 2023 21:09:31 -0700 Subject: [PATCH 34/57] Rename v's position --- src/levanter/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 78f8504d1..2fdf12229 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -247,6 +247,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray]): q = q.astype(jnp.float32) k = k.astype(jnp.float32) k = k.rename({"position": "key_position"}) + v = v.rename({"position": "key_position"}) attn_scores = hax.dot("head_size", q, k) From f06e3fa2242e654b5640a3d4f413f58d1f3d526d Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 23 Aug 2023 23:38:55 -0700 Subject: [PATCH 35/57] pretty sure this is the problem: weights weren't being deserialized/reserialized properly --- .pre-commit-config.yaml | 2 +- src/levanter/models/llama.py | 16 ++++++++-------- tests/test_llama.py | 34 +++++++++++++++++++++++++++++++--- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f59ee87fd..38fe4bb40 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,4 +38,4 @@ repos: hooks: - id: mypy args: [--ignore-missing-imports] - additional_dependencies: [wandb] + additional_dependencies: [wandb, types-PyYAML] diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 2fdf12229..7e851d26b 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -266,10 +266,10 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray]): def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): # unflatten the linear layers of HF state_dict to match the shape of LlamaAttention d = {} - d.update(unflatten_linear_layers(apply_prefix(prefix, 
"q_proj"), state_dict, self.q_proj, None)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, None)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, None)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "o_proj"), state_dict, self.o_proj, None)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "q_proj"), state_dict, self.q_proj, True)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, True)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, True)) + d.update(unflatten_linear_layers(apply_prefix(prefix, "o_proj"), state_dict, self.o_proj, True)) return super().from_state_dict(d, prefix) @@ -278,10 +278,10 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) my_dict: StateDict = {} super().update_state_dict(my_dict, prefix) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "q_proj"), self.q_proj, None)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "k_proj"), self.k_proj, None)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "v_proj"), self.v_proj, None)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "o_proj"), self.o_proj, None)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "q_proj"), self.q_proj, True)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "k_proj"), self.k_proj, True)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "v_proj"), self.v_proj, True)) + my_dict.update(flatten_linear_layers(apply_prefix(prefix, "o_proj"), self.o_proj, True)) state_dict.update(my_dict) return state_dict diff --git a/tests/test_llama.py b/tests/test_llama.py index 6117a04fd..02d4da65f 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -130,8 +130,7 @@ def named_array_to_tensor(named_array): def test_llama_attention(): config = _get_llama_config() x, mask = _get_random_inputs(config) - # generate a random key that can be splitted into 4 - key = random.PRNGKey(4) + key = random.PRNGKey(0) attention = LlamaAttention.init(config=config, key=key) out = attention(x, mask) @@ -140,6 +139,35 @@ def test_llama_attention(): assert out.array.shape == (x.axes[0].size, config.seq_len, config.hidden_dim) +@skip_if_no_torch +def test_llama_attention_vs_hf(): + import torch + from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention + + config = _get_llama_config() + + attention = LlamaAttention.init(config=config, key=random.PRNGKey(0)) + + state = attention.to_state_dict() + state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} + hf_attention = HFLlamaAttention(config.to_hf_config()) + hf_attention.load_state_dict(state, strict=False) + + x, mask = _get_random_inputs(config) + x_torch = torch.from_numpy(np.array(x.array)) + mask_torch = torch.from_numpy(np.array(mask.array)).broadcast_to((2, 1, config.seq_len, config.seq_len)) + + # the torch mask is really a bias, so we need to invert it and make it a big negative number + mask_torch = (mask_torch == 0).float() * -1e9 + + out = attention(x, mask) + hf_out = hf_attention(x_torch, mask_torch) + + assert np.isclose( + hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-2, atol=1e-2 + ).all(), f"{hf_out[0]} != {out}" + + def test_llama_decoder_layer(): llama_config = _get_llama_config() key = random.PRNGKey(0) @@ -201,7 +229,7 @@ def compute(input): torch_out2 = 
torch_out2.logits[0].detach().cpu().numpy() torch_out2 = jax.nn.softmax(torch_out2, axis=-1) assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" - # assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" + assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}" def _get_llama_config() -> LlamaConfig: From 41b103c09678184b614661a2897714248123b07a Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 23 Aug 2023 23:43:48 -0700 Subject: [PATCH 36/57] document when to use out_dims_first_in_dict --- src/levanter/compat/torch_serialization.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/levanter/compat/torch_serialization.py b/src/levanter/compat/torch_serialization.py index b3b30f9df..411f519c4 100644 --- a/src/levanter/compat/torch_serialization.py +++ b/src/levanter/compat/torch_serialization.py @@ -168,6 +168,10 @@ def flatten_linear_layers(prefix: Optional[str], tree: PyTree, out_dims_first_in linear layers can have arbitrary dimensions, grouped into input and output axes. This function flattens the linear layers in a state dict into a 2d weight matrix and a 1d bias vector. + **You should use out_dims_first_in_dict=True if you're using this to convert a PyTorch model to Haliax and the + PyTorch model uses Linear. If the PyTorch model uses Conv1d, use False.** None is probably not what you want, + except in very specific cases. + :param prefix: prefix to apply to the keys in the state dict :param tree: :param out_dims_first_in_dict: if True, the output dimensions will be the first axis in the flattened weight matrix. @@ -216,6 +220,10 @@ def unflatten_linear_layers( linear layers can have arbitrary dimensions, grouped into input and output axes. This function unflattens the linear layers in a state dict into a 2d weight matrix and a 1d bias vector. + **You should use out_dims_first_in_dict=True if you're using this to convert a PyTorch model to Haliax and the + PyTorch model uses Linear. If the PyTorch model uses Conv1d, use False.** None is probably not what you want, + except in very specific cases. 
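Concretely, the distinction documented above is just a transpose of the 2-D weight; a small NumPy illustration with hypothetical shapes:

import numpy as np

# torch.nn.Linear stores its weight as [out_features, in_features] (out dims first),
# while a GPT-2 style Conv1D stores the same weight as [in_features, out_features].
w_linear = np.random.randn(8, 4)   # [out, in]
w_conv1d = w_linear.T              # [in, out]
x = np.random.randn(4)
assert np.allclose(w_linear @ x, x @ w_conv1d)  # one layer, two checkpoint layouts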
+ :param prefix: prefix to apply to the keys in the state dict :param statedict: the state dict to source the flattened weights from :param layer: the exemplar layer to use for unflattening From 6ed1ee5dd02bdee60d078e35c4f9c8a36755eafb Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 27 Aug 2023 18:11:57 -0700 Subject: [PATCH 37/57] Add LlamaRMSNorm and add more consistency --- src/levanter/models/llama.py | 57 +++++++++++++++++++++++++------ tests/test_llama.py | 66 ++++++++++++++++++++++++++---------- 2 files changed, 95 insertions(+), 28 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 7e851d26b..2e8fb7026 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -10,7 +10,7 @@ import haliax as hax import haliax.nn as hnn -from haliax import Axis, NamedArray +from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import named_call, shaped_rng_split from haliax.nn.scan import Stacked @@ -28,6 +28,8 @@ from levanter.models.lm_model import LmConfig from levanter.utils.py_utils import cached_classproperty +jax.config.update("jax_disable_jit", True) + @LmConfig.register_subclass("llama") @dataclass(frozen=True) @@ -287,12 +289,48 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) return state_dict +class LlamaRMSNorm(hnn.LayerNorm): + """It is a modified version of LayerNorm. + The main changes are: + 1. The variance is defined as the average of square, versus the original + definition as the average of the squared deviations from the mean. + 2. The output is defined as x * inv, without minusing the mean. + 3. The default value of eps is set to 1e-6 and use_bias to False. + """ + + @staticmethod + def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = False): + if use_weight: + weight = hax.ones(axis) + else: + weight = None + if use_bias: + bias = hax.zeros(axis) + else: + bias = None + + return LlamaRMSNorm(axis, weight, bias, eps) + + def __call__(self, x: NamedArray) -> NamedArray: + # This gives a different result than jnp.var(), which is + # defined as the average of the squared deviations from the mean + var = hax.mean(hax.square(x), axis=self.axis) + inv = hax.rsqrt(var + self.eps) + out = x * inv + + if self.weight is not None: + out = self.weight * out + if self.bias is not None: + out = out + self.bias + return out + + class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() attn: LlamaAttention mlp: LlamaMlp - ln_1: hnn.LayerNorm # input layernorm - ln_2: hnn.LayerNorm # post attention layernorm + ln_1: LlamaRMSNorm # input layernorm + ln_2: LlamaRMSNorm # post attention layernorm @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": @@ -306,8 +344,8 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": key=k_mlp, use_bias=config.use_bias, ) - ln_1 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) - ln_2 = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + ln_1 = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + ln_2 = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) return LlamaDecoderLayer(config, attn, mlp, ln_1, ln_2) @@ -338,7 +376,7 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]: class LlamaTransformer(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = 
eqx.static_field() layers: Stacked[LlamaDecoderLayer] - ln_f: hnn.LayerNorm + ln_f: LlamaRMSNorm @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaTransformer": @@ -348,7 +386,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaTransformer": config, key=shaped_rng_split(key, config.num_layers), ) - ln_f = hnn.LayerNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + ln_f = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) return LlamaTransformer(config, layers, ln_f) @@ -360,7 +398,7 @@ def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray]) -> NamedArray return x def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - stacked = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "h")) + stacked = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layers")) out = super().from_state_dict(stacked, prefix=prefix) return out @@ -389,7 +427,7 @@ class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): @staticmethod def init(Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaEmbedding": - k_wte, k_wpe = jrandom.split(key, 2) + k_wte = jrandom.split(key, 1) token_embeddings = hax.random.normal(k_wte, (Vocab, config.Embed)) return LlamaEmbedding(Vocab, config, token_embeddings) @@ -449,7 +487,6 @@ def __call__( x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask) lm_logits = self.embeddings.unembed(x) - return lm_logits def _state_dict_key_map(self) -> Dict[str, Optional[str]]: diff --git a/tests/test_llama.py b/tests/test_llama.py index 02d4da65f..ef4ec4f73 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -13,6 +13,7 @@ LlamaDecoderLayer, LlamaLMHeadModel, LlamaRotaryEmbedding, + LlamaRMSNorm, ) from levanter.models.llama import _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb from levanter.models.llama import _rotate_half as levanter_rotate_half @@ -127,20 +128,8 @@ def named_array_to_tensor(named_array): assert_equal_out(levanter_out_rope_k, hf_out_rope_k) -def test_llama_attention(): - config = _get_llama_config() - x, mask = _get_random_inputs(config) - key = random.PRNGKey(0) - - attention = LlamaAttention.init(config=config, key=key) - out = attention(x, mask) - - # assert the same shape - assert out.array.shape == (x.axes[0].size, config.seq_len, config.hidden_dim) - - @skip_if_no_torch -def test_llama_attention_vs_hf(): +def test_llama_attention(): import torch from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention @@ -151,11 +140,12 @@ def test_llama_attention_vs_hf(): state = attention.to_state_dict() state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} hf_attention = HFLlamaAttention(config.to_hf_config()) - hf_attention.load_state_dict(state, strict=False) + hf_attention.load_state_dict(state, strict=True) x, mask = _get_random_inputs(config) x_torch = torch.from_numpy(np.array(x.array)) - mask_torch = torch.from_numpy(np.array(mask.array)).broadcast_to((2, 1, config.seq_len, config.seq_len)) + batch_size = x_torch.shape[0] + mask_torch = torch.from_numpy(np.array(mask.array)).broadcast_to((batch_size, 1, -1, -1)) # the torch mask is really a bias, so we need to invert it and make it a big negative number mask_torch = (mask_torch == 0).float() * -1e9 @@ -164,17 +154,56 @@ def test_llama_attention_vs_hf(): hf_out = hf_attention(x_torch, mask_torch) assert np.isclose( - hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-2, atol=1e-2 + 
hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-4, atol=1e-4 ).all(), f"{hf_out[0]} != {out}" +@skip_if_no_torch +def test_llama_rms_norm(): + import torch + from transformers.models.llama.modeling_llama import LlamaRMSNorm as HFLlamaRMSNorm + + config = _get_llama_config() + ln = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + hf_ln = HFLlamaRMSNorm(config.Embed.size, eps=config.layer_norm_epsilon) + + x, _ = _get_random_inputs(config) + x_torch = torch.from_numpy(np.array(x.array)) + + out = ln(x) + hf_out = hf_ln(x_torch) + + assert np.isclose( + hf_out.detach().cpu().numpy(), np.array(out.array), rtol=1e-6, atol=1e-6 + ).all(), f"{hf_out} != {out}" + + +@skip_if_no_torch def test_llama_decoder_layer(): + import torch + from transformers.models.llama.modeling_llama import LlamaDecoderLayer as HFLlamaDecoderLayer + llama_config = _get_llama_config() key = random.PRNGKey(0) llama_decoder_layer = LlamaDecoderLayer.init(config=llama_config, key=key) + + state = llama_decoder_layer.to_state_dict() + state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} + hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config()) + hf_decoder_layer.load_state_dict(state, strict=True) + x, mask = _get_random_inputs(llama_config) + x_torch = torch.from_numpy(np.array(x.array)) + batch_size = x_torch.shape[0] + mask_torch = torch.from_numpy(np.array(mask.array)).broadcast_to((batch_size, 1, -1, -1)) + mask_torch = (mask_torch == 0).float() * -1e9 + out = llama_decoder_layer(x, mask) - assert out.array.shape == (x.axes[0].size, llama_config.seq_len, llama_config.hidden_dim) + hf_out = hf_decoder_layer(x_torch, mask_torch) + + assert np.isclose( + hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-4, atol=1e-4 + ).all(), f"{hf_out[0]} != {out}" def test_llama_lm_head_model(): @@ -212,6 +241,7 @@ def test_llama_roundtrip(): input = hax.random.randint(random.PRNGKey(0), model.Pos, 0, model.Vocab.size) attn_mask = hax.nn.attention.causal_mask(model.Pos, model.config.KeyPos) + input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) def compute(input): model_output = model(input, attn_mask=attn_mask) @@ -225,7 +255,7 @@ def compute(input): torch_model2 = AutoModelForCausalLM.from_pretrained(tmpdir) torch_model2.eval() - torch_out2 = torch_model2(torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0)) + torch_out2 = torch_model2(input_torch) torch_out2 = torch_out2.logits[0].detach().cpu().numpy() torch_out2 = jax.nn.softmax(torch_out2, axis=-1) assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" From 9c669fd8827458347d8ccdaca82134af1ee395d1 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Sun, 27 Aug 2023 18:12:43 -0700 Subject: [PATCH 38/57] Fix issues from pre-commit tests --- src/levanter/models/llama.py | 1 + tests/test_llama.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 2e8fb7026..93444a367 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -28,6 +28,7 @@ from levanter.models.lm_model import LmConfig from levanter.utils.py_utils import cached_classproperty + jax.config.update("jax_disable_jit", True) diff --git a/tests/test_llama.py b/tests/test_llama.py index ef4ec4f73..cc19ee6b3 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -12,8 +12,8 @@ LlamaConfig, LlamaDecoderLayer, LlamaLMHeadModel, - LlamaRotaryEmbedding, 
LlamaRMSNorm, + LlamaRotaryEmbedding, ) from levanter.models.llama import _apply_rotary_pos_emb as levanter_apply_rotary_pos_emb from levanter.models.llama import _rotate_half as levanter_rotate_half From 6245f9b62a45931b7e787c1141e08ff1e727e5ab Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 28 Aug 2023 15:16:45 -0700 Subject: [PATCH 39/57] tie llama weights by default --- src/levanter/models/llama.py | 10 ++++++---- tests/test_llama.py | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 93444a367..ec691296a 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -29,9 +29,6 @@ from levanter.utils.py_utils import cached_classproperty -jax.config.update("jax_disable_jit", True) - - @LmConfig.register_subclass("llama") @dataclass(frozen=True) class LlamaConfig: @@ -74,7 +71,12 @@ class LlamaConfig: @cached_classproperty def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore - return HFCheckpointConverter(cls, "meta-llama/Llama-2-7b-hf", trust_remote_code=True) # type: ignore + return HFCheckpointConverter( + cls, # type: ignore + "meta-llama/Llama-2-7b-hf", + trust_remote_code=True, + config_overrides={"tie_word_embeddings": True}, + ) @classmethod def from_hf_config(cls, hf_config: HfConfig): diff --git a/tests/test_llama.py b/tests/test_llama.py index cc19ee6b3..17db0edf4 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -247,7 +247,7 @@ def compute(input): model_output = model(input, attn_mask=attn_mask) return hax.nn.softmax(model_output, axis=model.Vocab) - compute = jax.jit(compute) + # compute = jax.jit(compute) jax_out = compute(input).array with tempfile.TemporaryDirectory() as tmpdir: @@ -275,6 +275,8 @@ def _get_llama_config() -> LlamaConfig: hidden_dim=hidden_dim, num_heads=num_heads, rope_scaling=rope_scaling, + # disable for tests so debugging is easier + gradient_checkpointing=False, ) From 190cba30a3d827ddbbcd4633e9924b87d9a5deec Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 28 Aug 2023 16:16:39 -0700 Subject: [PATCH 40/57] add a todo for ivan --- src/levanter/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index ec691296a..f2a69955a 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -44,6 +44,8 @@ class LlamaConfig: rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. 
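Because the checkpoint converter above overrides `tie_word_embeddings=True`, the same token-embedding matrix serves as the output projection; schematically (a NumPy sketch, not the haliax API):

import numpy as np

vocab, hidden = 1000, 16
embeddings = np.random.randn(vocab, hidden) * 0.02  # one matrix, used twice
token_ids = np.array([3, 17, 42])
x = embeddings[token_ids]    # embed:   [3, hidden]
logits = x @ embeddings.T    # unembed: [3, vocab], tied to the same weights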
""" + # TODO(ivan): add tying of embeddings, default it to false to match the original model + seq_len: int = 2048 hidden_dim: int = 4096 intermediate_dim: int = 11008 From a022bb1027ff621bce2ac4143239abc9b5fbab07 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 28 Aug 2023 16:36:33 -0700 Subject: [PATCH 41/57] make test pass even without auth token --- src/levanter/compat/hf_checkpoints.py | 37 ++++++++++++++++++++++----- src/levanter/models/llama.py | 2 ++ tests/test_llama.py | 4 +-- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index 3f282c325..5b6238f0c 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -20,7 +20,7 @@ import safetensors import safetensors.numpy from huggingface_hub import hf_hub_download, snapshot_download -from huggingface_hub.utils import EntryNotFoundError, HFValidationError +from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError from jax.experimental.multihost_utils import sync_global_devices from jax.random import PRNGKey from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer @@ -302,8 +302,13 @@ def default_config(self) -> LevConfig: return self.config_from_hf_config(self.default_hf_config) def HFAutoModelClass(self, auto_class: Type[AutoModel] = AutoModelForCausalLM) -> Type[AutoModel]: - # figure out the - config = self.hf_config_from_hf_checkpoint() + # first, see if it's a built-in model + try: + return auto_class._model_mapping[self.HfConfigClass] + except KeyError: + pass + + config = self.default_hf_config cls_name = auto_class.__name__ if hasattr(config, "auto_map") and cls_name in config.auto_map: class_ref = config.auto_map[cls_name] @@ -503,10 +508,26 @@ def _save_pretrained_local( dict_config = config.to_dict() # copy over the default keys - for k in KEYS_TO_COPY_FROM_BASE_CONFIG: - attr = getattr(self.default_hf_config, k, None) - if attr is not None: - dict_config[k] = attr + try: + for k in KEYS_TO_COPY_FROM_BASE_CONFIG: + attr = getattr(self.default_hf_config, k, None) + if attr is not None: + dict_config[k] = attr + # except GatedRepoError: + # warnings.warn("Could not copy keys from base config because the repo is gated. Making assumptions.") + except Exception as e: # noqa + if isinstance(e, GatedRepoError) or isinstance(e.__cause__, GatedRepoError): + warnings.warn("Could not copy keys from base config because the repo is gated. 
Making assumptions.") + + # this is probably llama, but in general we just need to set the auto_map and architectures + dict_config["auto_map"] = { + "AutoModelForCausalLM": self.HFAutoModelClass(AutoModelForCausalLM).__qualname__, + "AutoConfig": self.HfConfigClass.__qualname__, + } + + dict_config["architectures"] = [self.HFAutoModelClass(AutoModelForCausalLM).__name__] + else: + raise if self.config_overrides: dict_config.update(self.config_overrides) @@ -585,6 +606,8 @@ def _save_code_local(self, path): attributes_path = hf_hub_download(repo_id=repo, filename=".gitattributes", revision=revision) except EntryNotFoundError: attributes_path = None + except GatedRepoError: + attributes_path = None if attributes_path is None: warnings.warn("HF Export - No .gitattributes file found, using a heuristic to decide what to save") diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index f2a69955a..8936d3f43 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -78,6 +78,8 @@ def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"] "meta-llama/Llama-2-7b-hf", trust_remote_code=True, config_overrides={"tie_word_embeddings": True}, + tokenizer="hf-internal-testing/llama-tokenizer", + HfConfigClass=HfLlamaConfig, ) @classmethod diff --git a/tests/test_llama.py b/tests/test_llama.py index 17db0edf4..37adf7bfd 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -247,11 +247,11 @@ def compute(input): model_output = model(input, attn_mask=attn_mask) return hax.nn.softmax(model_output, axis=model.Vocab) - # compute = jax.jit(compute) + compute = jax.jit(compute) jax_out = compute(input).array with tempfile.TemporaryDirectory() as tmpdir: - converter.save_pretrained(model, tmpdir) + converter.save_pretrained(model, tmpdir, save_reference_code=False) torch_model2 = AutoModelForCausalLM.from_pretrained(tmpdir) torch_model2.eval() From ec742c453a624959f856719518087987c7c1a532 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 28 Aug 2023 19:26:39 -0700 Subject: [PATCH 42/57] Intermediate -> Mlp --- src/levanter/models/llama.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 8936d3f43..62e1fcabc 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -43,9 +43,6 @@ class LlamaConfig: activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. """ - - # TODO(ivan): add tying of embeddings, default it to false to match the original model - seq_len: int = 2048 hidden_dim: int = 4096 intermediate_dim: int = 11008 @@ -67,9 +64,8 @@ class LlamaConfig: Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) - Mlp = property(lambda self: Axis(name="mlp", size=self.hidden_dim)) # TODO: shall we multiply with mlp_scale? 
+ Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim)) HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) - Intermediate = property(lambda self: Axis(name="intermediate", size=self.intermediate_dim)) @cached_classproperty def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore @@ -121,19 +117,19 @@ class LlamaMlp(eqx.Module, StateDictSerializationMixin): before down-proj. """ - gate_proj: hnn.Linear # projection from Embed to Intermediate - up_proj: hnn.Linear # projection from Embed to Intermediate - down_proj: hnn.Linear # projection from Intermediate to Embed + gate_proj: hnn.Linear # projection from Embed to Mlp + up_proj: hnn.Linear # projection from Embed to Mlp + down_proj: hnn.Linear # projection from Mlp to Embed act: Callable = eqx.static_field() @staticmethod def init( - Embed: Axis, Intermediate: Axis, activation_fn: Union[str, Callable], *, key, use_bias: bool = False + Embed: Axis, Mlp: Axis, activation_fn: Union[str, Callable], *, key, use_bias: bool = False ) -> "LlamaMlp": k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3) - gate_proj = hnn.Linear.init(Out=Intermediate, In=Embed, key=k_fc, use_bias=use_bias) - up_proj = hnn.Linear.init(Out=Intermediate, In=Embed, key=k_up_proj, use_bias=use_bias) - down_proj = hnn.Linear.init(Out=Embed, In=Intermediate, key=k_down_proj, use_bias=use_bias) + gate_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias) + up_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_up_proj, use_bias=use_bias) + down_proj = hnn.Linear.init(Out=Embed, In=Mlp, key=k_down_proj, use_bias=use_bias) if isinstance(activation_fn, str): activation_fn = ACT2FN[activation_fn] act = activation_fn # type: ignore @@ -346,7 +342,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": attn = LlamaAttention.init(config, key=k_attn) mlp = LlamaMlp.init( config.Embed, - config.Intermediate, + config.Mlp, config.activation_function, key=k_mlp, use_bias=config.use_bias, From a64e63a56a1908f43d449bea5e6267e66c9b3016 Mon Sep 17 00:00:00 2001 From: Ivan Zhou Date: Mon, 28 Aug 2023 19:28:36 -0700 Subject: [PATCH 43/57] Update src/levanter/models/llama.py Co-authored-by: David Hall --- src/levanter/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 8936d3f43..b21dc545d 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -202,7 +202,6 @@ def _get_cos_sin_cache(HeadSize: NamedArray, Pos: NamedArray, base: float) -> Tu position_ids: NamedArray = hax.arange(Pos) - # Evaluates the Einstein summation convention on the operands. 
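# Illustrative sketch (not from the patch series): a plain-NumPy version of the
# cos/sin cache built in _get_cos_sin_cache. Inverse frequencies cover half the
# head size, are multiplied by the positions, and the two halves are concatenated
# (the HF-style permutation the surrounding comments refer to) before taking
# cos/sin. Sizes below are assumptions for illustration.
import numpy as np

head_size, seq_len, base = 8, 16, 10000.0
inv_freq = 1.0 / (base ** (np.arange(0, head_size, 2) / head_size))  # (head_size/2,)
positions = np.arange(seq_len)                                       # (seq_len,)
freqs = np.outer(positions, inv_freq)                                # (seq_len, head_size/2)
emb = np.concatenate([freqs, freqs], axis=-1)                        # (seq_len, head_size)
cos_cached, sin_cached = np.cos(emb), np.sin(emb)
assert cos_cached.shape == sin_cached.shape == (seq_len, head_size)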
freqs = position_ids * inv_freq.broadcast_axis(Pos) # This is different from the paper but alignes with HF implementation: # It uses a different permutation in order to obtain the same calculation From 1bd25ca37557d36ac07a3b2765c83e6aab8a30c2 Mon Sep 17 00:00:00 2001 From: Ivan Zhou Date: Mon, 28 Aug 2023 19:30:09 -0700 Subject: [PATCH 44/57] Update src/levanter/models/llama.py Co-authored-by: David Hall --- src/levanter/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 376032c7d..175410423 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -203,7 +203,7 @@ def _get_cos_sin_cache(HeadSize: NamedArray, Pos: NamedArray, base: float) -> Tu # It uses a different permutation in order to obtain the same calculation emb = hax.concatenate(HeadSize, (freqs, freqs)) cos_cached = hax.cos(emb) - sin_cached = hax.sin(emb) + # This is different from the paper but aligns with HF implementation: return cos_cached, sin_cached def __call__(self, seq_len: int) -> Tuple[NamedArray, NamedArray]: From d7412d4538a550d41f96875d3ed9711efdf5dd76 Mon Sep 17 00:00:00 2001 From: Ivan Zhou Date: Mon, 28 Aug 2023 19:30:19 -0700 Subject: [PATCH 45/57] Update src/levanter/models/llama.py Co-authored-by: David Hall --- src/levanter/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 175410423..bc2235dfe 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -353,10 +353,9 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": @named_call def __call__(self, x: NamedArray, mask: Optional[NamedArray]): + # self attention and skip connection residual = x x = self.ln_1(x) - - # self attention and skip connection attn_output = self.attn(x=x, mask=mask) x = residual + attn_output From 2e3c1fb4db659555bb759366b41b167371f16389 Mon Sep 17 00:00:00 2001 From: Ivan Zhou Date: Mon, 28 Aug 2023 19:30:41 -0700 Subject: [PATCH 46/57] Update src/levanter/models/llama.py Co-authored-by: David Hall --- src/levanter/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index bc2235dfe..edf32a62c 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -381,8 +381,6 @@ class LlamaTransformer(StateDictSerializationMixin, eqx.Module): @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaTransformer": - # TODO: here it reports an error that is related to _get_rotary_emb() in LlamaAttention - # TypeError: Output from batched function Axis(name='head_size', size=4) with type is not a valid JAX type layers = Stacked.init(config.Layers, LlamaDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)( config, key=shaped_rng_split(key, config.num_layers), From 05f684fa399e9361d4fc40b814c277204c2b3742 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 28 Aug 2023 19:38:27 -0700 Subject: [PATCH 47/57] Fix issues from pre-commit checks --- src/levanter/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index edf32a62c..14e408c3e 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -43,6 +43,7 @@ class LlamaConfig: activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". 
rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. """ + seq_len: int = 2048 hidden_dim: int = 4096 intermediate_dim: int = 11008 @@ -203,6 +204,7 @@ def _get_cos_sin_cache(HeadSize: NamedArray, Pos: NamedArray, base: float) -> Tu # It uses a different permutation in order to obtain the same calculation emb = hax.concatenate(HeadSize, (freqs, freqs)) cos_cached = hax.cos(emb) + sin_cached = hax.sin(emb) # This is different from the paper but aligns with HF implementation: return cos_cached, sin_cached From f6157dfacff160f228dba43e1ce66fdf84fc98f1 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 28 Aug 2023 20:16:10 -0700 Subject: [PATCH 48/57] Start from llama 2 hf in roundtrip --- src/levanter/models/llama.py | 15 +++++++++++- tests/test_llama.py | 44 ++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 14e408c3e..c325ceb1b 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -93,7 +93,19 @@ def from_hf_config(cls, hf_config: HfConfig): rope_scaling=hf_config.rope_scaling, ) - def to_hf_config(self, vocab_size: int = 32000, config_overrides=None) -> HfLlamaConfig: + def to_hf_config( + self, vocab_size: int = 32000, tie_word_embeddings: bool = False, config_overrides: Optional[Dict] = None + ) -> HfLlamaConfig: + """Convert to HuggingFace's LlamaConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + tie_word_embeddings (bool, optional): Whether to tie weight embeddings. HuggingFace's default value is False + config_overrides (dict, optional): Overrides for the config. Defaults to None. + + Returns: + HfLlamaConfig: HuggingFace's LlamaConfig + """ if config_overrides is None: config_overrides = {} @@ -108,6 +120,7 @@ def to_hf_config(self, vocab_size: int = 32000, config_overrides=None) -> HfLlam rms_norm_eps=self.layer_norm_epsilon, rope_scaling=self.rope_scaling, vocab_size=vocab_size, + tie_word_embeddings=tie_word_embeddings, **config_overrides, ) diff --git a/tests/test_llama.py b/tests/test_llama.py index 37adf7bfd..ac3d47eb4 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -222,26 +222,27 @@ def test_llama_lm_head_model(): @skip_if_no_torch def test_llama_roundtrip(): import torch - from transformers import AutoModelForCausalLM + from transformers import AutoModelForCausalLM, LlamaForCausalLM converter = LlamaConfig.default_hf_checkpoint_converter - config = _get_llama_config() - Vocab = hax.Axis("vocab", 1000) + config = LlamaConfig() + Vocab = hax.Axis("vocab", 32000) - # TODO: load the first torch model with model_id from HF + # Make input and attn_mask + input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + attn_mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) + input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) - # randomly initialize a levanter model - # TODO: use converter.load_pretrained - model = LlamaLMHeadModel.init( - Vocab=Vocab, - config=config, - key=random.PRNGKey(0), - ) + torch_config = config.to_hf_config(vocab_size=Vocab.size) + torch_model = LlamaForCausalLM(torch_config) + torch_model.eval() - input = hax.random.randint(random.PRNGKey(0), model.Pos, 0, model.Vocab.size) - attn_mask = hax.nn.attention.causal_mask(model.Pos, model.config.KeyPos) - input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) + 
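# Illustrative sketch (not from the patch series) of the comparison pattern the
# roundtrip test below uses: logits from the HF torch model and the levanter model
# are softmaxed and then compared elementwise with np.isclose. This helper is a
# standalone restatement of that pattern, not code from the repository.
import numpy as np

def probs_close(logits_a, logits_b, rtol=1e-2, atol=1e-2):
    """Softmax two (seq_len, vocab) logit arrays and compare them elementwise."""
    def softmax(a):
        a = a - a.max(axis=-1, keepdims=True)
        e = np.exp(a)
        return e / e.sum(axis=-1, keepdims=True)
    p, q = softmax(np.asarray(logits_a)), softmax(np.asarray(logits_b))
    return p.shape == q.shape and bool(np.isclose(p, q, rtol=rtol, atol=atol).all())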
torch_out = torch_model(input_torch) + torch_out = torch_out.logits[0].detach().cpu().numpy() + torch_out = jax.nn.softmax(torch_out, axis=-1) + + model = converter.load_pretrained(LlamaLMHeadModel) def compute(input): model_output = model(input, attn_mask=attn_mask) @@ -250,6 +251,9 @@ def compute(input): compute = jax.jit(compute) jax_out = compute(input).array + assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + with tempfile.TemporaryDirectory() as tmpdir: converter.save_pretrained(model, tmpdir, save_reference_code=False) torch_model2 = AutoModelForCausalLM.from_pretrained(tmpdir) @@ -263,20 +267,16 @@ def compute(input): def _get_llama_config() -> LlamaConfig: - seq_len = 128 - hidden_dim = 16 - num_heads = 4 rope_scaling = { "type": "linear", "factor": 2.0, } return LlamaConfig( - seq_len=seq_len, - hidden_dim=hidden_dim, - num_heads=num_heads, + seq_len=128, + hidden_dim=16, + num_heads=4, rope_scaling=rope_scaling, - # disable for tests so debugging is easier - gradient_checkpointing=False, + gradient_checkpointing=False, # disable for tests so debugging is easier ) From 5109b55877f8213c44260bcbd19bc094bd54437a Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Mon, 28 Aug 2023 20:49:36 -0700 Subject: [PATCH 49/57] Update model_id in the round trip test --- tests/test_llama.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_llama.py b/tests/test_llama.py index ac3d47eb4..4c55b5797 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -7,6 +7,7 @@ import haliax as hax +from levanter.compat.hf_checkpoints import RepoRef from levanter.models.llama import ( LlamaAttention, LlamaConfig, @@ -224,25 +225,30 @@ def test_llama_roundtrip(): import torch from transformers import AutoModelForCausalLM, LlamaForCausalLM + model_id = "stanford-crfm/levanter-llama-test" converter = LlamaConfig.default_hf_checkpoint_converter - config = LlamaConfig() - Vocab = hax.Axis("vocab", 32000) + config = LlamaConfig( + seq_len=128, + hidden_dim=16, + num_heads=4, + gradient_checkpointing=False, + ) + Vocab = hax.Axis("vocab", 1000) # Make input and attn_mask input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) attn_mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) - torch_config = config.to_hf_config(vocab_size=Vocab.size) - torch_model = LlamaForCausalLM(torch_config) + torch_model = LlamaForCausalLM.from_pretrained(model_id) torch_model.eval() torch_out = torch_model(input_torch) torch_out = torch_out.logits[0].detach().cpu().numpy() torch_out = jax.nn.softmax(torch_out, axis=-1) - model = converter.load_pretrained(LlamaLMHeadModel) + model = converter.load_pretrained(LlamaLMHeadModel, RepoRef(model_id)) def compute(input): model_output = model(input, attn_mask=attn_mask) From 0a5d36d7875843c6b7ca4f931e2ce2183d720743 Mon Sep 17 00:00:00 2001 From: Ivan-Zhou Date: Tue, 29 Aug 2023 21:32:22 -0700 Subject: [PATCH 50/57] Untie weight at LMHead Linear Layer --- src/levanter/models/llama.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index c325ceb1b..32ea1b6e3 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -74,7 +74,6 @@ def default_hf_checkpoint_converter(cls) 
-> HFCheckpointConverter["LlamaConfig"] cls, # type: ignore "meta-llama/Llama-2-7b-hf", trust_remote_code=True, - config_overrides={"tie_word_embeddings": True}, tokenizer="hf-internal-testing/llama-tokenizer", HfConfigClass=HfLlamaConfig, ) @@ -93,14 +92,11 @@ def from_hf_config(cls, hf_config: HfConfig): rope_scaling=hf_config.rope_scaling, ) - def to_hf_config( - self, vocab_size: int = 32000, tie_word_embeddings: bool = False, config_overrides: Optional[Dict] = None - ) -> HfLlamaConfig: + def to_hf_config(self, vocab_size: int = 32000, config_overrides: Optional[Dict] = None) -> HfLlamaConfig: """Convert to HuggingFace's LlamaConfig Args: vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. - tie_word_embeddings (bool, optional): Whether to tie weight embeddings. HuggingFace's default value is False config_overrides (dict, optional): Overrides for the config. Defaults to None. Returns: @@ -120,7 +116,6 @@ def to_hf_config( rms_norm_eps=self.layer_norm_epsilon, rope_scaling=self.rope_scaling, vocab_size=vocab_size, - tie_word_embeddings=tie_word_embeddings, **config_overrides, ) @@ -462,6 +457,7 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]: class LlamaLMHeadModel(StateDictSerializationMixin, eqx.Module): transformer: LlamaTransformer embeddings: LlamaEmbedding + lm_head: hnn.Linear @property def config(self): @@ -484,7 +480,8 @@ def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel": k_t, k_emb = jrandom.split(key, 2) transformer = LlamaTransformer.init(config, key=k_t) embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb) - return LlamaLMHeadModel(transformer, embeddings) + lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=config.use_bias) + return LlamaLMHeadModel(transformer, embeddings, lm_head) def __call__( self, @@ -500,12 +497,33 @@ def __call__( """ x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask) - lm_logits = self.embeddings.unembed(x) + lm_logits = self.lm_head(x) return lm_logits def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"transformer": "model", "embeddings": None} + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp + d = {} + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True + ) + ) + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_dict: StateDict = {} + super().update_state_dict(my_dict, prefix=prefix) + + my_dict.update( + flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True) + ) + + state_dict.update(my_dict) + return state_dict + def _rotate_half(x: NamedArray) -> NamedArray: """Rotates half of the hidden dims of the input and concatenates them.""" From c0569426b29fa3b27e46f19fc5fa3fc0cd0598d7 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 30 Aug 2023 09:12:11 -0700 Subject: [PATCH 51/57] fix round trip test, use compile time eval for the cos/sin cache --- src/levanter/models/llama.py | 30 ++++++++++++++++++------------ tests/test_llama.py | 35 +++++++++++++++++++---------------- 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 32ea1b6e3..a7003c8ec 100644 --- a/src/levanter/models/llama.py +++ 
b/src/levanter/models/llama.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Type, Union import equinox as eqx import jax @@ -14,7 +14,7 @@ from haliax.jax_utils import named_call, shaped_rng_split from haliax.nn.scan import Stacked -from levanter.compat.hf_checkpoints import HFCheckpointConverter +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig from levanter.compat.torch_serialization import ( StateDict, StateDictSerializationMixin, @@ -31,7 +31,7 @@ @LmConfig.register_subclass("llama") @dataclass(frozen=True) -class LlamaConfig: +class LlamaConfig(HFCompatConfig): """Config for LlamaModel Args: @@ -92,7 +92,7 @@ def from_hf_config(cls, hf_config: HfConfig): rope_scaling=hf_config.rope_scaling, ) - def to_hf_config(self, vocab_size: int = 32000, config_overrides: Optional[Dict] = None) -> HfLlamaConfig: + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfLlamaConfig: """Convert to HuggingFace's LlamaConfig Args: @@ -119,6 +119,10 @@ def to_hf_config(self, vocab_size: int = 32000, config_overrides: Optional[Dict] **config_overrides, ) + @property + def model_type(cls) -> Type["LlamaLMHeadModel"]: + return LlamaLMHeadModel + class LlamaMlp(eqx.Module, StateDictSerializationMixin): """Multi-layer Perceptron @@ -191,17 +195,19 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) return state_dict -class LlamaRotaryEmbedding(eqx.Module): - Pos: Axis = eqx.static_field() - cos_cached: jnp.ndarray = eqx.static_field() - sin_cached: jnp.ndarray = eqx.static_field() +class LlamaRotaryEmbedding(eqx.Module, StateDictSerializationMixin): + Pos: Axis = eqx.field(static=True) + cos_cached: NamedArray = eqx.field(static=True) + sin_cached: NamedArray = eqx.field(static=True) def __init__(self, HeadSize: Axis, Pos: Axis, base: int = 10000): self.Pos = Pos - self.cos_cached, self.sin_cached = self._get_cos_sin_cache(Pos=Pos, HeadSize=HeadSize, base=base) + # this must be compile-time b/c we want to store them in a static field + with jax.ensure_compile_time_eval(): + self.cos_cached, self.sin_cached = self._get_cos_sin_cache(Pos=Pos, HeadSize=HeadSize, base=base) @staticmethod - def _get_cos_sin_cache(HeadSize: NamedArray, Pos: NamedArray, base: float) -> Tuple[jnp.ndarray, jnp.ndarray]: + def _get_cos_sin_cache(HeadSize: hax.Axis, Pos: hax.Axis, base: float) -> Tuple[NamedArray, NamedArray]: HeadHalfSize = HeadSize.resize(HeadSize.size // 2) inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) @@ -480,7 +486,7 @@ def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel": k_t, k_emb = jrandom.split(key, 2) transformer = LlamaTransformer.init(config, key=k_t) embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb) - lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=config.use_bias) + lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True) return LlamaLMHeadModel(transformer, embeddings, lm_head) def __call__( @@ -505,7 +511,7 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]: def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp - d = {} + d = state_dict.copy() d.update( unflatten_linear_layers( apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, 
out_dims_first_in_dict=True diff --git a/tests/test_llama.py b/tests/test_llama.py index 4c55b5797..d9de5c42f 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -7,7 +7,6 @@ import haliax as hax -from levanter.compat.hf_checkpoints import RepoRef from levanter.models.llama import ( LlamaAttention, LlamaConfig, @@ -140,7 +139,7 @@ def test_llama_attention(): state = attention.to_state_dict() state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} - hf_attention = HFLlamaAttention(config.to_hf_config()) + hf_attention = HFLlamaAttention(config.to_hf_config(32000)) hf_attention.load_state_dict(state, strict=True) x, mask = _get_random_inputs(config) @@ -190,7 +189,7 @@ def test_llama_decoder_layer(): state = llama_decoder_layer.to_state_dict() state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} - hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config()) + hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000)) hf_decoder_layer.load_state_dict(state, strict=True) x, mask = _get_random_inputs(llama_config) @@ -225,7 +224,6 @@ def test_llama_roundtrip(): import torch from transformers import AutoModelForCausalLM, LlamaForCausalLM - model_id = "stanford-crfm/levanter-llama-test" converter = LlamaConfig.default_hf_checkpoint_converter config = LlamaConfig( @@ -235,34 +233,39 @@ def test_llama_roundtrip(): gradient_checkpointing=False, ) Vocab = hax.Axis("vocab", 1000) + hf_config = config.to_hf_config(Vocab.size) # Make input and attn_mask input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) attn_mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos) input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) - torch_model = LlamaForCausalLM.from_pretrained(model_id) + torch.random.manual_seed(0) + + torch_model = LlamaForCausalLM(hf_config) torch_model.eval() torch_out = torch_model(input_torch) torch_out = torch_out.logits[0].detach().cpu().numpy() torch_out = jax.nn.softmax(torch_out, axis=-1) - model = converter.load_pretrained(LlamaLMHeadModel, RepoRef(model_id)) + with tempfile.TemporaryDirectory() as tmpdir: + torch_model.save_pretrained(f"{tmpdir}/torch_model") + + model = converter.load_pretrained(LlamaLMHeadModel, f"{tmpdir}/torch_model") - def compute(input): - model_output = model(input, attn_mask=attn_mask) - return hax.nn.softmax(model_output, axis=model.Vocab) + def compute(input): + model_output = model(input, attn_mask=attn_mask) + return hax.nn.softmax(model_output, axis=model.Vocab) - compute = jax.jit(compute) - jax_out = compute(input).array + # compute = jax.jit(compute) + jax_out = compute(input).array - assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" - assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" + assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}" - with tempfile.TemporaryDirectory() as tmpdir: - converter.save_pretrained(model, tmpdir, save_reference_code=False) - torch_model2 = AutoModelForCausalLM.from_pretrained(tmpdir) + converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) + torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model") torch_model2.eval() torch_out2 = torch_model2(input_torch) From 83802839818e9f8ea266c155b0804629a4d2a4f3 Mon Sep 17 00:00:00 2001 From: David Hall 
Date: Wed, 30 Aug 2023 09:20:19 -0700 Subject: [PATCH 52/57] Update src/levanter/models/llama.py --- src/levanter/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index a7003c8ec..b1920e092 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -214,7 +214,7 @@ def _get_cos_sin_cache(HeadSize: hax.Axis, Pos: hax.Axis, base: float) -> Tuple[ position_ids: NamedArray = hax.arange(Pos) freqs = position_ids * inv_freq.broadcast_axis(Pos) - # This is different from the paper but alignes with HF implementation: + # This is different from the paper but aligns with HF implementation: # It uses a different permutation in order to obtain the same calculation emb = hax.concatenate(HeadSize, (freqs, freqs)) cos_cached = hax.cos(emb) From 93eeebecbf1de75f3ca4e7df9e809a6b88d36dca Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 30 Aug 2023 09:20:56 -0700 Subject: [PATCH 53/57] Update tests/test_llama.py --- tests/test_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_llama.py b/tests/test_llama.py index d9de5c42f..d542490bc 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -258,7 +258,7 @@ def compute(input): model_output = model(input, attn_mask=attn_mask) return hax.nn.softmax(model_output, axis=model.Vocab) - # compute = jax.jit(compute) + compute = jax.jit(compute) jax_out = compute(input).array assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" From dca91578d391e5ac77580f2b15dd1b19fd587699 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 30 Aug 2023 10:18:46 -0700 Subject: [PATCH 54/57] let's just use llama names where reasonable --- src/levanter/models/llama.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index b1920e092..41df853f9 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -345,10 +345,10 @@ def __call__(self, x: NamedArray) -> NamedArray: class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() - attn: LlamaAttention + self_attn: LlamaAttention mlp: LlamaMlp - ln_1: LlamaRMSNorm # input layernorm - ln_2: LlamaRMSNorm # post attention layernorm + input_layernorm: LlamaRMSNorm + post_attention_layernorm: LlamaRMSNorm @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": @@ -371,29 +371,22 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer": def __call__(self, x: NamedArray, mask: Optional[NamedArray]): # self attention and skip connection residual = x - x = self.ln_1(x) - attn_output = self.attn(x=x, mask=mask) + x = self.input_layernorm(x) + attn_output = self.self_attn(x=x, mask=mask) x = residual + attn_output # MLP and skip connection residual = x - x = self.ln_2(x) + x = self.post_attention_layernorm(x) mlp_output = self.mlp(x) output = residual + mlp_output return output - def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - return { - "attn": "self_attn", - "ln_1": "input_layernorm", - "ln_2": "post_attention_layernorm", - } - class LlamaTransformer(StateDictSerializationMixin, eqx.Module): config: LlamaConfig = eqx.static_field() layers: Stacked[LlamaDecoderLayer] - ln_f: LlamaRMSNorm + norm: LlamaRMSNorm @staticmethod def init(config: LlamaConfig, *, key) -> "LlamaTransformer": @@ -408,7 +401,7 @@ def init(config: LlamaConfig, *, key) -> 
"LlamaTransformer": @named_call def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray]) -> NamedArray: x = self.layers.fold(x, mask=attn_mask) - x = self.ln_f(x) + x = self.norm(x) return x @@ -426,9 +419,6 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) return state_dict - def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - return {"ln_f": "norm"} - class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): """Similar to GPT2 Embedding, except that: From 9ccd7952ce42ee1390c7e807d370ce513e435a96 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 30 Aug 2023 10:41:51 -0700 Subject: [PATCH 55/57] implement LmHeadModel in LLama --- src/levanter/models/llama.py | 11 +++++------ src/levanter/models/lm_model.py | 3 +-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 41df853f9..98357d15c 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -25,7 +25,7 @@ unstack_state_dict, ) from levanter.models.gpt2 import ACT2FN -from levanter.models.lm_model import LmConfig +from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.utils.py_utils import cached_classproperty @@ -450,7 +450,7 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"token_embeddings": "model.embed_tokens.weight"} -class LlamaLMHeadModel(StateDictSerializationMixin, eqx.Module): +class LlamaLMHeadModel(eqx.Module, LmHeadModel[LlamaConfig], StateDictSerializationMixin): transformer: LlamaTransformer embeddings: LlamaEmbedding lm_head: hnn.Linear @@ -467,10 +467,6 @@ def vocab_size(self) -> int: def Vocab(self) -> Axis: return self.embeddings.Vocab - @property - def Pos(self) -> Axis: - return self.config.Pos - @classmethod def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel": k_t, k_emb = jrandom.split(key, 2) @@ -483,6 +479,9 @@ def __call__( self, input_ids: NamedArray, attn_mask: Optional[NamedArray] = None, + *, + inference: bool = False, + key=None, ) -> NamedArray: """ Args: diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index d384d6799..c8fc03283 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -58,9 +58,8 @@ def Vocab(self) -> Axis: pass @property - @abc.abstractmethod def Pos(self) -> Axis: - pass + return self.config.Pos @classmethod @abc.abstractmethod From 5d8a10268c3bfa6d6640562de37a0e9df5da3e00 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 30 Aug 2023 10:42:11 -0700 Subject: [PATCH 56/57] use haliax's built in attention --- src/levanter/models/llama.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 98357d15c..c3e4bda9b 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -259,26 +259,13 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray]): q, k = _apply_rotary_pos_emb(q, k, cos, sin) - scale = jax.lax.rsqrt(float(self.config.HeadSize.size)) - - # do this first to help keep FP values small - q = q * scale - q = q.astype(jnp.float32) - k = k.astype(jnp.float32) - k = k.rename({"position": "key_position"}) + k = k.astype(jnp.float32).rename({"position": "key_position"}) v = v.rename({"position": "key_position"}) - attn_scores = hax.dot("head_size", q, k) - - if mask is not None: - attn_scores = attn_scores + (1.0 - mask) * -1e9 - - attn_scores = attn_scores.astype(jnp.float32) - 
attn_weights = hnn.softmax(attn_scores, axis="key_position").astype(x.dtype) - # There's no dropout in llama attention, compared with Gpt2 attention - - attn_output = hax.dot("key_position", attn_weights, v) + c = self.config + attn_output = hnn.attention.dot_product_attention(c.Pos, c.KeyPos, c.HeadSize, q, k, v, mask) + attn_output = attn_output.astype(x.dtype) attn_output = self.o_proj(attn_output) return attn_output From 8c6a20dce9467f93af896d890d9c6e1228704814 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 30 Aug 2023 11:08:45 -0700 Subject: [PATCH 57/57] update for latest main: tokenizer resizing --- src/levanter/compat/hf_checkpoints.py | 14 ++++++++++---- src/levanter/models/llama.py | 17 ++++++++++++++++- tests/test_llama.py | 4 +++- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index 027df1aed..c6ebdf4e7 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -439,6 +439,7 @@ def load_pretrained( lm_model_cls: Union[Type[LmWithHfSerializationMixin], LevConfig], ref: Optional[Union[str, RepoRef]] = None, axis_mapping: Optional[ResourceMapping] = None, + resize_vocab_to_match_tokenizer: bool = True, ) -> LmWithHfSerializationMixin: """ Loads a levanter model from a huggingface checkpoint. @@ -477,10 +478,15 @@ def load_pretrained( # Vocab: next, we resize the desired actual size if Vocab.size != tokenizer_Vocab.size: - logger.info( - f"Resizing model from {Vocab.size} to {tokenizer_Vocab.size} to match tokenizer vocab size" - ) - lev_model = lev_model.resize_vocab(tokenizer_Vocab.size) + if resize_vocab_to_match_tokenizer: + logger.info( + f"Resizing model from {Vocab.size} to {tokenizer_Vocab.size} to match tokenizer vocab size" + ) + lev_model = lev_model.resize_vocab(tokenizer_Vocab.size) + else: + logger.warning( + f"Model vocab size ({Vocab.size}) does not match tokenizer vocab size ({tokenizer_Vocab.size})" + ) if axis_mapping is not None: lev_model = haliax.shard_with_axis_mapping(lev_model, axis_mapping) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index c3e4bda9b..bd3638ccd 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -1,3 +1,4 @@ +import dataclasses from dataclasses import dataclass from typing import Callable, Dict, Optional, Tuple, Type, Union @@ -5,13 +6,14 @@ import jax import jax.numpy as jnp import jax.random as jrandom +from jaxtyping import PRNGKeyArray from transformers import LlamaConfig as HfLlamaConfig from transformers import PretrainedConfig as HfConfig import haliax as hax import haliax.nn as hnn from haliax import Axis, AxisSpec, NamedArray -from haliax.jax_utils import named_call, shaped_rng_split +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split from haliax.nn.scan import Stacked from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig @@ -436,6 +438,10 @@ def unembed(self, x: NamedArray): def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"token_embeddings": "model.embed_tokens.weight"} + def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): + new_weights = hax.tree_util.resize_axis(self.token_embeddings, self.Vocab, new_size, key=key) + return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) + class LlamaLMHeadModel(eqx.Module, LmHeadModel[LlamaConfig], StateDictSerializationMixin): transformer: LlamaTransformer @@ 
-482,6 +488,15 @@ def __call__( lm_logits = self.lm_head(x) return lm_logits + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]": + new_Vocab = self.Vocab.resize(new_size) + k1, k2 = maybe_rng_split(key, 2) + new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1) + new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2) + new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix) + + return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head) + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"transformer": "model", "embeddings": None} diff --git a/tests/test_llama.py b/tests/test_llama.py index d542490bc..5106ab0f9 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -252,7 +252,9 @@ def test_llama_roundtrip(): with tempfile.TemporaryDirectory() as tmpdir: torch_model.save_pretrained(f"{tmpdir}/torch_model") - model = converter.load_pretrained(LlamaLMHeadModel, f"{tmpdir}/torch_model") + model = converter.load_pretrained( + LlamaLMHeadModel, f"{tmpdir}/torch_model", resize_vocab_to_match_tokenizer=False + ) def compute(input): model_output = model(input, attn_mask=attn_mask)
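# Illustrative sketch (not from the patch series): resize_vocab exists because a
# checkpoint's vocab axis (e.g. the toy Vocab of 1000 used in these tests) can
# differ from the tokenizer's vocabulary (32000 for the Llama tokenizer), and with
# untied weights both the embedding table and lm_head must be resized. The NumPy
# sketch below grows or truncates a (vocab, hidden) matrix along the vocab axis;
# how new rows are initialized here is an assumption, not necessarily what
# hax.tree_util.resize_axis does.
import numpy as np

def resize_vocab_rows(weight, new_vocab, rng, init_scale=0.02):
    """weight: (vocab, hidden). Truncate to shrink, append random rows to grow."""
    old_vocab, hidden = weight.shape
    if new_vocab <= old_vocab:
        return weight[:new_vocab]
    extra = init_scale * rng.standard_normal((new_vocab - old_vocab, hidden))
    return np.concatenate([weight, extra], axis=0)

rng = np.random.default_rng(0)
embed_table = rng.standard_normal((1000, 16))            # toy sizes from the tests
assert resize_vocab_rows(embed_table, 32000, rng).shape == (32000, 16)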