From dd1939d7189cf4122d56b83b1e0262fa92f16a3f Mon Sep 17 00:00:00 2001
From: Light
Date: Thu, 11 Apr 2024 21:47:22 +0800
Subject: [PATCH 1/2] Simple quarot proof of concept.

---
 exllamav2/attn.py   | 30 ++++++++++++++++++++++++++++--
 exllamav2/config.py |  5 +++++
 exllamav2/mlp.py    | 20 ++++++++++++++++++--
 exllamav2/model.py  |  4 ++--
 requirements.txt    |  3 ++-
 5 files changed, 55 insertions(+), 7 deletions(-)

diff --git a/exllamav2/attn.py b/exllamav2/attn.py
index 088c10cb..112a0741 100644
--- a/exllamav2/attn.py
+++ b/exllamav2/attn.py
@@ -15,6 +15,9 @@
 # from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
 # import torch.nn.functional as F
 
+from auto_quarot import hadamard_utils
+import fast_hadamard_transform
+
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from exllamav2.model import ExLlamaV2
@@ -61,6 +64,10 @@ class ExLlamaV2Attention(ExLlamaV2Module):
 
     has_norm: bool
     has_residual: bool
+    quarot: bool
+    had_K: torch.Tensor
+    K: int
+    had_dim: int
 
     class Params:
 
@@ -182,7 +189,8 @@ def __init__(self,
                  key: str,
                  layer_idx: int,
                  has_norm: bool = True,
-                 has_residual: bool = True):
+                 has_residual: bool = True,
+                 quarot: bool = False):
 
         super().__init__(model, key)
 
@@ -191,6 +199,11 @@ def __init__(self,
         self.layer_idx = layer_idx
         self.has_norm = has_norm
         self.has_residual = has_residual
+        self.quarot = quarot
+
+        if self.quarot:
+            self.had_K, self.K = hadamard_utils.get_hadK(model.config.num_attention_heads)
+            self.had_dim = model.config.hidden_size // model.config.num_attention_heads
 
         self.q_handle = None
         self.temp_lora_size = 0
@@ -315,6 +328,9 @@ def load(self):
                                self.model.config.arch.rope_neox_style,
                                q_norm,
                                k_norm)
+
+        if self.quarot and self.had_K is not None:
+            self.had_K = self.had_K.to(self.device_idx)
 
 
     def unload(self):
@@ -447,7 +463,7 @@ def forward(self,
 
         global has_flash_attn
 
-        if self.q_handle is None or intermediates:
+        if self.quarot or self.q_handle is None or intermediates:
             return self.forward_torch(hidden_states,
                                       cache,
                                       attn_params,
@@ -800,6 +816,16 @@ def forward_torch(self,
         if cache is not None:
             cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
 
+        if self.quarot:
+            init_shape = attn_output.shape
+            if self.K == 1:
+                attn_output = fast_hadamard_transform.hadamard_transform(attn_output.reshape(-1, init_shape[-1]//self.had_dim, self.had_dim).transpose(1, 2),
+                                                                         scale=1/math.sqrt(init_shape[-1]//self.had_dim)).transpose(1, 2)
+            else:
+                attn_output = (self.had_K.to(attn_output.dtype) @ attn_output.reshape(-1, init_shape[-1]//self.had_dim, self.had_dim)) / math.sqrt(init_shape[-1]//self.had_dim)
+
+            attn_output = attn_output.reshape(init_shape)
+
         # Output projection
 
         attn_proj = self.o_proj.forward(attn_output, loras = loras)
diff --git a/exllamav2/config.py b/exllamav2/config.py
index 41987af5..bd33594f 100644
--- a/exllamav2/config.py
+++ b/exllamav2/config.py
@@ -93,6 +93,7 @@ class ExLlamaV2Config:
 
     checkpoint_fused_mlp: bool
 
+    quarot: bool = False
 
     def __init__(self,
                  model_dir: str | None = None):
@@ -200,6 +201,10 @@ def prepare(self, no_tensors: bool = False):
                                           "model_max_length",
                                           "max_position_embeddings",
                                           "max_seq_len"], 2048)
+
+        quarot_config = read(read_config, dict, "quarot_config", None)
+        if quarot_config is not None:
+            self.quarot = quarot_config["rotated"]
 
         rs = read(read_config, dict, "rope_scaling", None)
         if rs and "factor" in rs:
diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py
index 6b7fff42..ed404f52 100644
--- a/exllamav2/mlp.py
+++ b/exllamav2/mlp.py
@@ -13,6 +13,8 @@
 if TYPE_CHECKING:
     from exllamav2.model import ExLlamaV2
 
+from auto_quarot import hadamard_utils
+
 
 class ExLlamaV2MLP(ExLlamaV2Module):
 
@@ -30,19 +32,24 @@ class ExLlamaV2MLP(ExLlamaV2Module):
 
     has_norm: bool
     has_residual: bool
+    quarot: bool
+    had_K: torch.Tensor
+    K: int
 
     def __init__(self,
                  model: ExLlamaV2,
                  key: str,
                  layer_idx: int,
                  has_norm: bool = True,
-                 has_residual: bool = True):
+                 has_residual: bool = True,
+                 quarot: bool = False):
 
         super().__init__(model, key)
 
         self.layer_idx = layer_idx
         self.has_norm = has_norm
         self.has_residual = has_residual
+        self.quarot = quarot
 
         self.q_handle = None
         self.temp_lora_size = 0
@@ -50,6 +57,9 @@ def __init__(self,
         hidden_size = self.model.config.hidden_size
         intermediate_size = self.model.config.intermediate_size
 
+        if self.quarot:
+            self.had_K, self.K = hadamard_utils.get_hadK(intermediate_size)
+
         if self.has_norm:
             if self.model.config.arch.norm == "layernorm":
                 self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + self.model.config.arch.norm_key_2)
@@ -136,6 +146,9 @@ def load(self):
                          self.model.config.max_input_len * self.model.config.max_batch_size,
                          self.model.config.arch.mlp_act_func == "gelu",
                          self.has_residual)
+
+        if self.quarot:
+            self.had_K = self.had_K.to(self.device_idx)
 
 
     def unload(self):
@@ -224,7 +237,7 @@ def forward(self,
                 loras: list[ExLlamaV2Lora] | None = None,
                 **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
 
-        if self.q_handle is None or intermediates:
+        if self.quarot or self.q_handle is None or intermediates:
             return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)
 
         if loras is None or self.temp_lora_size == 0:
@@ -271,6 +284,9 @@ def forward_torch(self,
         elif self.model.config.arch.mlp_act_func == "gelu":
             y = F.gelu(up)
 
+        if self.quarot:
+            y = hadamard_utils.matmul_hadU_cuda(y, self.had_K, self.K)
+
         down = self.down_proj.forward(y, loras = loras)
 
         hidden_states = down + residual if self.has_residual else down
diff --git a/exllamav2/model.py b/exllamav2/model.py
index 81b21796..7ec07c6b 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -169,9 +169,9 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
                 pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx)
                 self.modules += [pd]
             else:
-                attn = ExLlamaV2Attention(self, layer_key, layer_idx)
+                attn = ExLlamaV2Attention(self, layer_key, layer_idx, quarot=self.config.quarot)
                 if self.config.arch.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
-                else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
+                else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx, quarot=self.config.quarot)
                 self.modules += [attn, mlp]
 
             if self.config.arch.norm == "layernorm": norm = ExLlamaV2LayerNorm(self, "model.norm")
diff --git a/requirements.txt b/requirements.txt
index a780475c..cbc54fbc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,5 @@ pygments
 websockets
 regex
 numpy
-tokenizers
\ No newline at end of file
+tokenizers
+git+https://github.com/sgsdxzy/AutoQuarot.git
\ No newline at end of file

From 099f80a77eb72f9c86cbf0eb0e64376224919256 Mon Sep 17 00:00:00 2001
From: Light
Date: Sat, 13 Apr 2024 01:05:00 +0800
Subject: [PATCH 2/2] Add quarot for kv.

---
 exllamav2/attn.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/exllamav2/attn.py b/exllamav2/attn.py
index 112a0741..8af106ab 100644
--- a/exllamav2/attn.py
+++ b/exllamav2/attn.py
@@ -65,6 +65,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
     has_norm: bool
     has_residual: bool
     quarot: bool
+    kv_quarot: bool
     had_K: torch.Tensor
     K: int
     had_dim: int
@@ -200,6 +201,7 @@ def __init__(self,
         self.has_norm = has_norm
         self.has_residual = has_residual
         self.quarot = quarot
+        self.kv_quarot = True # should be an option
 
         if self.quarot:
             self.had_K, self.K = hadamard_utils.get_hadK(model.config.num_attention_heads)
@@ -463,7 +465,7 @@ def forward(self,
 
         global has_flash_attn
 
-        if self.quarot or self.q_handle is None or intermediates:
+        if self.quarot or self.kv_quarot or self.q_handle is None or intermediates:
             return self.forward_torch(hidden_states,
                                       cache,
                                       attn_params,
@@ -766,6 +768,12 @@ def forward_torch(self,
         ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, position_offsets, self.model.config.arch.rope_neox_style)
         ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, position_offsets, self.model.config.arch.rope_neox_style)
 
+        # Add another rotation for the keys/queries
+
+        if self.kv_quarot:
+            query_states = fast_hadamard_transform.hadamard_transform(query_states.float(), scale=1/math.sqrt(query_states.shape[-1])).to(query_states.dtype)
+            key_states = fast_hadamard_transform.hadamard_transform(key_states.float(), scale=1/math.sqrt(key_states.shape[-1])).to(key_states.dtype)
+
         # Add keys and values to cache
 
         if cache is not None:
@@ -816,6 +824,8 @@ def forward_torch(self,
         if cache is not None:
             cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
 
+        # QuaRot before output
+
         if self.quarot:
            init_shape = attn_output.shape
            if self.K == 1:
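Note (not part of the patches): the quarot branch added to forward_torch mixes the per-head blocks of attn_output with an orthonormal Hadamard matrix before o_proj; in QuaRot that rotation is meant to flatten activation outliers so the rotated checkpoint quantizes more cleanly. The plain-PyTorch sketch below only illustrates that rotation and is not the patch's code path: build_hadamard and rotate_heads are hypothetical names, and it assumes the head count is a power of two (the had_K/K pair from hadamard_utils.get_hadK also covers sizes that are not).

import math
import torch

def build_hadamard(n: int) -> torch.Tensor:
    # Sylvester construction of an orthonormal Hadamard matrix; n must be a power of two.
    assert n > 0 and n & (n - 1) == 0
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim = 1),
                       torch.cat([h, -h], dim = 1)], dim = 0)
    return h / math.sqrt(n)

def rotate_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # Reshape (bsz, seq, hidden) into per-head blocks and mix the blocks with a
    # num_heads x num_heads Hadamard matrix, as the quarot branch does for attn_output.
    bsz, seq_len, hidden = x.shape
    head_dim = hidden // num_heads
    h = build_hadamard(num_heads).to(x.dtype)
    y = h @ x.reshape(-1, num_heads, head_dim)     # broadcasts over all tokens
    return y.reshape(bsz, seq_len, hidden)

x = torch.randn(1, 8, 4096)                        # e.g. 32 heads of dim 128
y = rotate_heads(x, num_heads = 32)
# The normalized Hadamard matrix is symmetric and orthogonal, so applying the
# rotation twice recovers the input (up to float error).
print(torch.allclose(rotate_heads(y, num_heads = 32), x, atol = 1e-4))

On the config side, the rotation only switches on when the checkpoint's config.json carries a "quarot_config" entry with "rotated": true, which an AutoQuarot-exported model is presumably expected to provide. Patch 2 additionally applies a full-width Hadamard transform to the post-RoPE queries and keys via fast_hadamard_transform, so the KV cache stores rotated keys as well.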